Repository: BrainCog-X/Brain-Cog
Branch: main
Commit: f9b879f75da2
Files: 685
Total size: 4.5 MB
Directory structure:
gitextract_qe2qoke6/
├── .gitignore
├── LICENSE
├── README.md
├── braincog/
│ ├── __init__.py
│ ├── base/
│ │ ├── __init__.py
│ │ ├── brainarea/
│ │ │ ├── BrainArea.py
│ │ │ ├── IPL.py
│ │ │ ├── Insula.py
│ │ │ ├── PFC.py
│ │ │ ├── __init__.py
│ │ │ ├── basalganglia.py
│ │ │ └── dACC.py
│ │ ├── connection/
│ │ │ ├── CustomLinear.py
│ │ │ ├── __init__.py
│ │ │ └── layer.py
│ │ ├── conversion/
│ │ │ ├── __init__.py
│ │ │ ├── convertor.py
│ │ │ ├── merge.py
│ │ │ └── spicalib.py
│ │ ├── encoder/
│ │ │ ├── __init__.py
│ │ │ ├── encoder.py
│ │ │ ├── population_coding.py
│ │ │ └── qs_coding.py
│ │ ├── learningrule/
│ │ │ ├── BCM.py
│ │ │ ├── Hebb.py
│ │ │ ├── RSTDP.py
│ │ │ ├── STDP.py
│ │ │ ├── STP.py
│ │ │ └── __init__.py
│ │ ├── node/
│ │ │ ├── __init__.py
│ │ │ └── node.py
│ │ ├── strategy/
│ │ │ ├── LateralInhibition.py
│ │ │ ├── __init__.py
│ │ │ └── surrogate.py
│ │ └── utils/
│ │ ├── __init__.py
│ │ ├── criterions.py
│ │ └── visualization.py
│ ├── datasets/
│ │ ├── CUB2002011.py
│ │ ├── ESimagenet/
│ │ │ ├── ES_imagenet.py
│ │ │ ├── __init__.py
│ │ │ └── reconstructed_ES_imagenet.py
│ │ ├── NOmniglot/
│ │ │ ├── NOmniglot.py
│ │ │ ├── __init__.py
│ │ │ ├── nomniglot_full.py
│ │ │ ├── nomniglot_nw_ks.py
│ │ │ ├── nomniglot_pair.py
│ │ │ └── utils.py
│ │ ├── StanfordDogs.py
│ │ ├── TinyImageNet.py
│ │ ├── __init__.py
│ │ ├── bullying10k/
│ │ │ ├── __init__.py
│ │ │ └── bullying10k.py
│ │ ├── cut_mix.py
│ │ ├── datasets.py
│ │ ├── gen_input_signal.py
│ │ ├── hmdb_dvs/
│ │ │ ├── __init__.py
│ │ │ └── hmdb_dvs.py
│ │ ├── ncaltech101/
│ │ │ ├── __init__.py
│ │ │ └── ncaltech101.py
│ │ ├── rand_aug.py
│ │ ├── scripts/
│ │ │ ├── testlist01.txt
│ │ │ └── ucf101_dvs_preprocessing.py
│ │ ├── ucf101_dvs/
│ │ │ ├── __init__.py
│ │ │ └── ucf101_dvs.py
│ │ └── utils.py
│ ├── model_zoo/
│ │ ├── NeuEvo/
│ │ │ ├── __init__.py
│ │ │ ├── architect.py
│ │ │ ├── genotypes.py
│ │ │ ├── model.py
│ │ │ ├── model_search.py
│ │ │ ├── operations.py
│ │ │ └── others.py
│ │ ├── __init__.py
│ │ ├── backeinet.py
│ │ ├── base_module.py
│ │ ├── bdmsnn.py
│ │ ├── convnet.py
│ │ ├── fc_snn.py
│ │ ├── glsnn.py
│ │ ├── linearNet.py
│ │ ├── nonlinearNet.py
│ │ ├── qsnn.py
│ │ ├── resnet.py
│ │ ├── resnet19_snn.py
│ │ ├── rsnn.py
│ │ ├── sew_resnet.py
│ │ └── vgg_snn.py
│ └── utils.py
├── docs/
│ ├── Makefile
│ ├── make.bat
│ └── source/
│ ├── conf.py
│ ├── examples/
│ │ ├── Brain_Cognitive_Function_Simulation/
│ │ │ ├── drosophila.md
│ │ │ └── index.rst
│ │ ├── Decision_Making/
│ │ │ ├── BDM_SNN.md
│ │ │ ├── RL.md
│ │ │ └── index.rst
│ │ ├── Knowledge_Representation_and_Reasoning/
│ │ │ ├── CKRGSNN.md
│ │ │ ├── CRSNN.md
│ │ │ ├── SPSNN.md
│ │ │ ├── index.rst
│ │ │ └── musicMemory.md
│ │ ├── Multi-scale_Brain_Structure_Simulation/
│ │ │ ├── Corticothalamic_minicolumn.md
│ │ │ ├── HumanBrain.md
│ │ │ ├── Human_PFC.md
│ │ │ ├── MacaqueBrain.md
│ │ │ ├── index.rst
│ │ │ └── mouse_brain.md
│ │ ├── Perception_and_Learning/
│ │ │ ├── Conversion.md
│ │ │ ├── MultisensoryIntegration.md
│ │ │ ├── QSNN.md
│ │ │ ├── UnsupervisedSTDP.md
│ │ │ ├── img_cls/
│ │ │ │ ├── bp.md
│ │ │ │ ├── glsnn.md
│ │ │ │ └── index.rst
│ │ │ └── index.rst
│ │ ├── Social_Cognition/
│ │ │ ├── Mirror_Test.md
│ │ │ ├── ToM.md
│ │ │ └── index.rst
│ │ └── index.rst
│ ├── index.rst
│ ├── modules.rst
│ └── setup.rst
├── docs.md
├── documents/
│ ├── Data_engine.md
│ ├── Lectures.md
│ ├── Pub_brain_inspired_AI.md
│ ├── Pub_brain_simulation.md
│ ├── Pub_sh_codesign.md
│ ├── Publication.md
│ └── Tutorial.md
├── examples/
│ ├── Brain_Cognitive_Function_Simulation/
│ │ └── drosophila/
│ │ ├── README.md
│ │ └── drosophila.py
│ ├── Embodied_Cognition/
│ │ └── RHI/
│ │ ├── RHI_Test.py
│ │ ├── RHI_Train.py
│ │ └── ReadMe.md
│ ├── Hardware_acceleration/
│ │ ├── README.md
│ │ ├── firefly_v1_schedule_on_pynq.py
│ │ ├── standalone_utils.py
│ │ ├── ultra96_test.py
│ │ └── zcu104_test.py
│ ├── Knowledge_Representation_and_Reasoning/
│ │ ├── CKRGSNN/
│ │ │ ├── README.md
│ │ │ ├── main.py
│ │ │ └── sub_Conceptnet.csv
│ │ ├── CRSNN/
│ │ │ ├── README.md
│ │ │ └── main.py
│ │ ├── SPSNN/
│ │ │ ├── README.md
│ │ │ └── main.py
│ │ └── musicMemory/
│ │ ├── Areas/
│ │ │ ├── apac.py
│ │ │ ├── cortex.py
│ │ │ ├── pac.py
│ │ │ └── pfc.py
│ │ ├── Modal/
│ │ │ ├── PAC.py
│ │ │ ├── cluster.py
│ │ │ ├── composercluster.py
│ │ │ ├── composerlayer.py
│ │ │ ├── composerlifneuron.py
│ │ │ ├── genrecluster.py
│ │ │ ├── genrelayer.py
│ │ │ ├── genrelifneuron.py
│ │ │ ├── izhikevichneuron.py
│ │ │ ├── layer.py
│ │ │ ├── lifneuron.py
│ │ │ ├── note.py
│ │ │ ├── notecluster.py
│ │ │ ├── notelifneuron.py
│ │ │ ├── notesequencelayer.py
│ │ │ ├── pitch.py
│ │ │ ├── sequencelayer.py
│ │ │ ├── sequencememory.py
│ │ │ ├── synapse.py
│ │ │ ├── tempocluster.py
│ │ │ ├── tempolifneuron.py
│ │ │ ├── temposequencelayer.py
│ │ │ ├── titlecluster.py
│ │ │ ├── titlelayer.py
│ │ │ └── titlelifneuron.py
│ │ ├── README.md
│ │ ├── api/
│ │ │ └── music_engine_api.py
│ │ ├── conf/
│ │ │ ├── GenreData.txt
│ │ │ ├── MIDIData.txt
│ │ │ └── conf.py
│ │ ├── inputs/
│ │ │ ├── 1.txt
│ │ │ ├── Data.txt
│ │ │ ├── GenreData.txt
│ │ │ ├── MIDIData.txt
│ │ │ ├── chords.csv
│ │ │ ├── chords.xlsx
│ │ │ ├── information.csv
│ │ │ ├── keyIndex.csv
│ │ │ ├── keys.csv
│ │ │ ├── keys.xlsx
│ │ │ ├── modeindex.csv
│ │ │ ├── modeindex.xlsx
│ │ │ ├── pitch2midi.csv
│ │ │ └── tones2.csv
│ │ ├── result_output/
│ │ │ └── tone learning/
│ │ │ ├── C major_20241121155522.mid
│ │ │ ├── C major_20241122093822.mid
│ │ │ ├── C major_20241122094000.mid
│ │ │ ├── C major_20241122094419.mid
│ │ │ └── C major_20241122094736.mid
│ │ ├── task/
│ │ │ ├── Bach_generated.mid
│ │ │ ├── Classical_generated.mid
│ │ │ ├── Sonate C Major.Mid_recall.mid
│ │ │ ├── melody_generated.mid
│ │ │ ├── mode-conditioned learning.py
│ │ │ ├── musicGeneration.py
│ │ │ └── musicMemory.py
│ │ ├── testData/
│ │ │ ├── Bach/
│ │ │ │ └── prelude C major.mid
│ │ │ ├── JayZhou/
│ │ │ │ └── rainbow.mid
│ │ │ └── Mozart/
│ │ │ └── Sonate C major.mid
│ │ └── tools/
│ │ ├── __init__.py
│ │ ├── generateData.py
│ │ ├── hamonydataset_test.py
│ │ ├── msg.py
│ │ ├── msgq.py
│ │ ├── oscillations.py
│ │ ├── position.txt
│ │ ├── readjson.py
│ │ ├── testSound.py
│ │ ├── testmusic21.py
│ │ ├── testopengl.py
│ │ ├── testwave.py
│ │ └── xmlParser.py
│ ├── MotorControl/
│ │ └── experimental/
│ │ ├── README.md
│ │ ├── brain_area.py
│ │ ├── main.py
│ │ └── model.py
│ ├── Multiscale_Brain_Structure_Simulation/
│ │ ├── CorticothalamicColumn/
│ │ │ ├── README.md
│ │ │ ├── data/
│ │ │ │ ├── __init__.py
│ │ │ │ └── globaldata.py
│ │ │ ├── main.py
│ │ │ ├── model/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── cortex.py
│ │ │ │ ├── cortex_thalamus.py
│ │ │ │ ├── dendrite.py
│ │ │ │ ├── fire.csv
│ │ │ │ ├── layer.py
│ │ │ │ ├── synapse.py
│ │ │ │ └── thalamus.py
│ │ │ └── tools/
│ │ │ ├── __init__.py
│ │ │ ├── cortical.csv
│ │ │ ├── exdata.py
│ │ │ ├── layer.csv
│ │ │ ├── neuron.csv
│ │ │ └── synapse.csv
│ │ ├── Corticothalamic_Brain_Model/
│ │ │ ├── Bioinformatics_propofol_circle.py
│ │ │ ├── Readme.md
│ │ │ └── spectrogram.py
│ │ ├── HumanBrain/
│ │ │ ├── README.md
│ │ │ ├── human_brain.py
│ │ │ └── human_multi.py
│ │ ├── Human_Brain_Model/
│ │ │ ├── NA.py
│ │ │ ├── Readme.md
│ │ │ ├── gc.py
│ │ │ ├── main_246.py
│ │ │ ├── main_84.py
│ │ │ ├── pci.py
│ │ │ ├── pci_246.py
│ │ │ └── spectrogram.py
│ │ ├── Human_PFC_Model/
│ │ │ ├── README.md
│ │ │ └── Six_Layer_PFC.py
│ │ ├── MacaqueBrain/
│ │ │ ├── README.md
│ │ │ └── macaque_brain.py
│ │ └── MouseBrain/
│ │ ├── README.md
│ │ └── mouse_brain.py
│ ├── Perception_and_Learning/
│ │ ├── Conversion/
│ │ │ ├── burst_conversion/
│ │ │ │ ├── CIFAR10_VGG16.py
│ │ │ │ ├── README.md
│ │ │ │ └── converted_CIFAR10.py
│ │ │ └── msat_conversion/
│ │ │ ├── CIFAR10_VGG16.py
│ │ │ ├── README.md
│ │ │ ├── converted_CIFAR10.py
│ │ │ └── convertor.py
│ │ ├── IllusionPerception/
│ │ │ └── AbuttingGratingIllusion/
│ │ │ ├── distortion/
│ │ │ │ ├── __init__.py
│ │ │ │ └── abutting_grating_illusion/
│ │ │ │ ├── __init__.py
│ │ │ │ └── abutting_grating_distortion.py
│ │ │ └── main.py
│ │ ├── MultisensoryIntegration/
│ │ │ ├── README.md
│ │ │ └── code/
│ │ │ ├── MultisensoryIntegrationDEMO_AM.py
│ │ │ ├── MultisensoryIntegrationDEMO_IM.py
│ │ │ └── measure_and_visualization.py
│ │ ├── NeuEvo/
│ │ │ ├── auto_augment.py
│ │ │ ├── main.py
│ │ │ ├── separate_loss.py
│ │ │ ├── train.py
│ │ │ ├── train_search.py
│ │ │ └── utils.py
│ │ ├── QSNN/
│ │ │ ├── README.md
│ │ │ └── main.py
│ │ ├── UnsupervisedSTDP/
│ │ │ ├── Readme.md
│ │ │ └── codef.py
│ │ └── img_cls/
│ │ ├── bp/
│ │ │ ├── README.md
│ │ │ ├── main.py
│ │ │ ├── main_backei.py
│ │ │ └── main_simplified.py
│ │ ├── glsnn/
│ │ │ ├── README.md
│ │ │ └── cls_glsnn.py
│ │ ├── spiking_capsnet/
│ │ │ ├── README.md
│ │ │ └── spikingcaps.py
│ │ └── transfer_for_dvs/
│ │ ├── GradCAM_visualization.py
│ │ ├── README.md
│ │ ├── datasets.py
│ │ ├── main.py
│ │ ├── main_transfer.py
│ │ └── main_visual_losslandscape.py
│ ├── Snn_safety/
│ │ ├── DPSNN/
│ │ │ ├── Readme.txt
│ │ │ ├── load_data.py
│ │ │ ├── main_dpsnn.py
│ │ │ └── model.py
│ │ └── RandHet-SNN/
│ │ ├── README.md
│ │ ├── evaluate.py
│ │ ├── my_node.py
│ │ ├── sew_resnet.py
│ │ ├── train.py
│ │ └── utils.py
│ ├── Social_Cognition/
│ │ ├── FOToM/
│ │ │ ├── algorithms/
│ │ │ │ ├── ToM_class.py
│ │ │ │ ├── __init__.py
│ │ │ │ ├── maddpg.py
│ │ │ │ └── tom11.py
│ │ │ ├── common/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── distributions.py
│ │ │ │ ├── tile_images.py
│ │ │ │ └── vec_env/
│ │ │ │ ├── __init__.py
│ │ │ │ └── vec_env.py
│ │ │ ├── evaluate.py
│ │ │ ├── main.py
│ │ │ ├── multiagent/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── core.py
│ │ │ │ ├── environment.py
│ │ │ │ ├── multi_discrete.py
│ │ │ │ ├── policy.py
│ │ │ │ ├── rendering.py
│ │ │ │ ├── scenario.py
│ │ │ │ └── scenarios/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── hetero_spread.py
│ │ │ │ ├── simple.py
│ │ │ │ ├── simple_adversary.py
│ │ │ │ ├── simple_crypto.py
│ │ │ │ ├── simple_push.py
│ │ │ │ ├── simple_reference.py
│ │ │ │ ├── simple_speaker_listener.py
│ │ │ │ ├── simple_spread.py
│ │ │ │ ├── simple_tag.py
│ │ │ │ └── simple_world_comm.py
│ │ │ ├── readme.md
│ │ │ └── utils/
│ │ │ ├── __init__.py
│ │ │ ├── agents.py
│ │ │ ├── buffer.py
│ │ │ ├── env_wrappers.py
│ │ │ ├── make_env.py
│ │ │ ├── misc.py
│ │ │ ├── multiprocessing.py
│ │ │ ├── networks.py
│ │ │ └── noise.py
│ │ ├── Intention_Prediction/
│ │ │ └── Intention_Prediction.py
│ │ ├── MAToM-SNN/
│ │ │ ├── LICENSE
│ │ │ ├── MPE/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── agents/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ └── agents.py
│ │ │ │ ├── common/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── distributions.py
│ │ │ │ │ ├── tile_images.py
│ │ │ │ │ └── vec_env/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ └── vec_env.py
│ │ │ │ ├── main.py
│ │ │ │ ├── multiagent/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ └── scenarios/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── simple.py
│ │ │ │ │ ├── simple_crypto.py
│ │ │ │ │ ├── simple_push.py
│ │ │ │ │ ├── simple_reference.py
│ │ │ │ │ ├── simple_speaker_listener.py
│ │ │ │ │ ├── simple_spread.py
│ │ │ │ │ └── simple_world_comm.py
│ │ │ │ ├── policy/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ └── maddpg.py
│ │ │ │ └── utils/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── buffer.py
│ │ │ │ ├── env_wrappers.py
│ │ │ │ ├── make_env.py
│ │ │ │ ├── misc.py
│ │ │ │ ├── multiprocessing.py
│ │ │ │ ├── networks.py
│ │ │ │ └── noise.py
│ │ │ ├── README.md
│ │ │ └── STAG/
│ │ │ ├── agents/
│ │ │ │ ├── __init__.py
│ │ │ │ └── sagent.py
│ │ │ ├── common_sr/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── arguments.py
│ │ │ │ ├── dummy_vec_env.py
│ │ │ │ ├── multiprocessing_env.py
│ │ │ │ ├── replay_buffer.py
│ │ │ │ └── srollout.py
│ │ │ ├── envs/
│ │ │ │ ├── Stag_Hunt_env.py
│ │ │ │ ├── __init__.py
│ │ │ │ ├── abstract.py
│ │ │ │ └── constants.py
│ │ │ ├── main_spiking.py
│ │ │ ├── network/
│ │ │ │ ├── __init__.py
│ │ │ │ └── spiking_net.py
│ │ │ ├── policy/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── dqn.py
│ │ │ │ ├── stomvdn.py
│ │ │ │ └── svdn.py
│ │ │ ├── preprocessoing/
│ │ │ │ ├── __init__.py
│ │ │ │ └── common.py
│ │ │ └── runner.py
│ │ ├── ReadMe.md
│ │ ├── SmashVat/
│ │ │ ├── dqn.py
│ │ │ ├── environment.py
│ │ │ ├── main.py
│ │ │ ├── manual_control.py
│ │ │ ├── qnets.py
│ │ │ ├── side_effect_eval.py
│ │ │ └── window.py
│ │ ├── ToCM/
│ │ │ ├── README.md
│ │ │ ├── agent/
│ │ │ │ ├── controllers/
│ │ │ │ │ └── ToCMController.py
│ │ │ │ ├── learners/
│ │ │ │ │ └── ToCMLearner.py
│ │ │ │ ├── memory/
│ │ │ │ │ └── ToCMMemory.py
│ │ │ │ ├── models/
│ │ │ │ │ └── ToCMModel.py
│ │ │ │ ├── optim/
│ │ │ │ │ ├── loss.py
│ │ │ │ │ └── utils.py
│ │ │ │ ├── runners/
│ │ │ │ │ └── ToCMRunner.py
│ │ │ │ ├── utils/
│ │ │ │ │ └── params.py
│ │ │ │ └── workers/
│ │ │ │ └── ToCMWorker.py
│ │ │ ├── configs/
│ │ │ │ ├── Config.py
│ │ │ │ ├── EnvConfigs.py
│ │ │ │ ├── Experiment.py
│ │ │ │ ├── ToCM/
│ │ │ │ │ ├── ToCMAgentConfig.py
│ │ │ │ │ ├── ToCMControllerConfig.py
│ │ │ │ │ ├── ToCMLearnerConfig.py
│ │ │ │ │ └── optimal/
│ │ │ │ │ └── starcraft/
│ │ │ │ │ ├── AgentConfig.py
│ │ │ │ │ └── LearnerConfig.py
│ │ │ │ └── __init__.py
│ │ │ ├── env/
│ │ │ │ ├── mpe/
│ │ │ │ │ └── MPE.py
│ │ │ │ └── starcraft/
│ │ │ │ └── StarCraft.py
│ │ │ ├── environments.py
│ │ │ ├── mpe/
│ │ │ │ ├── MPE_Env.py
│ │ │ │ ├── __init__.py
│ │ │ │ ├── core.py
│ │ │ │ ├── environment.py
│ │ │ │ ├── multi_discrete.py
│ │ │ │ ├── rendering.py
│ │ │ │ ├── scenario.py
│ │ │ │ └── scenarios/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── hetero_spread.py
│ │ │ │ ├── simple_adversary.py
│ │ │ │ ├── simple_crypto.py
│ │ │ │ ├── simple_crypto_display.py
│ │ │ │ ├── simple_push.py
│ │ │ │ ├── simple_reference.py
│ │ │ │ ├── simple_speaker_listener.py
│ │ │ │ ├── simple_spread.py
│ │ │ │ ├── simple_tag.py
│ │ │ │ └── simple_world_comm.py
│ │ │ ├── networks/
│ │ │ │ ├── ToCM/
│ │ │ │ │ ├── action.py
│ │ │ │ │ ├── critic.py
│ │ │ │ │ ├── dense.py
│ │ │ │ │ ├── rnns.py
│ │ │ │ │ ├── utils.py
│ │ │ │ │ └── vae.py
│ │ │ │ └── transformer/
│ │ │ │ └── layers.py
│ │ │ ├── requirements.txt
│ │ │ ├── run.sh
│ │ │ ├── smac/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── bin/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ └── map_list.py
│ │ │ │ ├── env/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── multiagentenv.py
│ │ │ │ │ ├── pettingzoo/
│ │ │ │ │ │ ├── StarCraft2PZEnv.py
│ │ │ │ │ │ ├── __init__.py
│ │ │ │ │ │ └── test/
│ │ │ │ │ │ ├── __init__.py
│ │ │ │ │ │ ├── all_test.py
│ │ │ │ │ │ └── smac_pettingzoo_test.py
│ │ │ │ │ └── starcraft2/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── maps/
│ │ │ │ │ │ ├── SMAC_Maps/
│ │ │ │ │ │ │ └── 2s_vs_1sc.SC2Map
│ │ │ │ │ │ ├── __init__.py
│ │ │ │ │ │ └── smac_maps.py
│ │ │ │ │ ├── render.py
│ │ │ │ │ └── starcraft2.py
│ │ │ │ └── examples/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── pettingzoo/
│ │ │ │ │ ├── README.rst
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ └── pettingzoo_demo.py
│ │ │ │ ├── random_agents.py
│ │ │ │ └── rllib/
│ │ │ │ ├── README.rst
│ │ │ │ ├── __init__.py
│ │ │ │ ├── env.py
│ │ │ │ ├── model.py
│ │ │ │ ├── run_ppo.py
│ │ │ │ └── run_qmix.py
│ │ │ ├── train.py
│ │ │ └── utils/
│ │ │ ├── __init__.py
│ │ │ ├── mlp_buffer.py
│ │ │ ├── mlp_nstep_buffer.py
│ │ │ ├── popart.py
│ │ │ ├── rec_buffer.py
│ │ │ ├── segment_tree.py
│ │ │ └── util.py
│ │ ├── ToM/
│ │ │ ├── BrainArea/
│ │ │ │ ├── PFC_ToM.py
│ │ │ │ ├── TPJ.py
│ │ │ │ ├── __init__.py
│ │ │ │ ├── dACC.py
│ │ │ │ ├── one_hot.py
│ │ │ │ └── test.py
│ │ │ ├── README.md
│ │ │ ├── __init__.py
│ │ │ ├── data/
│ │ │ │ ├── NPC_assessment.csv
│ │ │ │ ├── agent_assessment.csv
│ │ │ │ ├── injury_memory.txt
│ │ │ │ ├── injury_value.txt
│ │ │ │ └── one_hot.py
│ │ │ ├── env/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── env.py
│ │ │ │ ├── env3_train_env00.py
│ │ │ │ └── env3_train_env01.py
│ │ │ ├── main_ToM.py
│ │ │ ├── main_both.py
│ │ │ ├── rulebasedpolicy/
│ │ │ │ ├── Find_a_way.py
│ │ │ │ ├── __init__.py
│ │ │ │ ├── a_star.py
│ │ │ │ ├── load_statedata.py
│ │ │ │ ├── point.py
│ │ │ │ ├── random_map.py
│ │ │ │ ├── statedata_pre.py
│ │ │ │ ├── train.txt
│ │ │ │ └── world_model.py
│ │ │ └── utils/
│ │ │ ├── Encoder.py
│ │ │ └── one_hot.py
│ │ ├── affective_empathy/
│ │ │ ├── BAE-SNN/
│ │ │ │ ├── BAESNN.py
│ │ │ │ ├── README.md
│ │ │ │ ├── env_poly.py
│ │ │ │ └── env_two_poly.py
│ │ │ ├── BEEAD-SNN/
│ │ │ │ ├── BEEAD-SNN.py
│ │ │ │ ├── README.md
│ │ │ │ ├── RL_Brain.py
│ │ │ │ ├── env.py
│ │ │ │ ├── env_poly_SNN.py
│ │ │ │ ├── rsnn.py
│ │ │ │ ├── sd_env.py
│ │ │ │ └── snowdrift_main.py
│ │ │ └── BRP-SNN/
│ │ │ ├── BRP-SNN.py
│ │ │ ├── README.md
│ │ │ ├── env_poly_SNN.py
│ │ │ └── env_two_poly_SNN.py
│ │ └── mirror_test/
│ │ ├── README.md
│ │ └── mirror_test.py
│ ├── Spiking-Transformers/
│ │ ├── LIFNode.py
│ │ ├── README.md
│ │ ├── datasets.py
│ │ ├── main.py
│ │ └── models/
│ │ ├── spike_driven_transformer.py
│ │ ├── spike_driven_transformer_dvs.py
│ │ ├── spike_driven_transformer_v2.py
│ │ ├── spike_driven_transformer_v2_dvs.py
│ │ ├── spikformer.py
│ │ └── spikformer_dvs.py
│ ├── Structural_Development/
│ │ ├── DPAP/
│ │ │ ├── README.md
│ │ │ ├── mask_model.py
│ │ │ ├── prun_main.py
│ │ │ └── utils.py
│ │ ├── DSD-SNN/
│ │ │ ├── README.md
│ │ │ └── cifar100/
│ │ │ ├── available.py
│ │ │ ├── main_simplified.py
│ │ │ ├── manipulate.py
│ │ │ ├── maskcl2.py
│ │ │ └── vgg_snn.py
│ │ ├── ELSM/
│ │ │ ├── evolve.py
│ │ │ ├── lsm.py
│ │ │ ├── model.py
│ │ │ ├── nsganet.py
│ │ │ └── spikes.py
│ │ ├── SCA-SNN/
│ │ │ ├── README.md
│ │ │ ├── configs/
│ │ │ │ └── train.yaml
│ │ │ ├── inclearn/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── convnet/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── classifier.py
│ │ │ │ │ ├── imbalance.py
│ │ │ │ │ ├── maskcl2.py
│ │ │ │ │ ├── network.py
│ │ │ │ │ ├── resnet.py
│ │ │ │ │ ├── sew_resnet.py
│ │ │ │ │ └── utils.py
│ │ │ │ ├── datasets/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── data.py
│ │ │ │ │ └── dataset.py
│ │ │ │ ├── models/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── base.py
│ │ │ │ │ └── incmodel.py
│ │ │ │ └── tools/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── autoaugment_extra.py
│ │ │ │ ├── cutout.py
│ │ │ │ ├── data_utils.py
│ │ │ │ ├── factory.py
│ │ │ │ ├── memory.py
│ │ │ │ ├── metrics.py
│ │ │ │ ├── results_utils.py
│ │ │ │ ├── scheduler.py
│ │ │ │ ├── similar.py
│ │ │ │ └── utils.py
│ │ │ └── main.py
│ │ └── SD-SNN/
│ │ ├── README.md
│ │ ├── main.py
│ │ ├── prun_and_generation.py
│ │ ├── snn_model.py
│ │ └── utils.py
│ ├── Structure_Evolution/
│ │ ├── Adaptive_lsm/
│ │ │ ├── BrainCog-Version/
│ │ │ │ ├── README.md
│ │ │ │ ├── brid.py
│ │ │ │ ├── lsmmodel.py
│ │ │ │ ├── maze.py
│ │ │ │ └── tools/
│ │ │ │ ├── EnuGlobalNetwork.py
│ │ │ │ ├── ExperimentEnvGlobalNetworkSurvival.py
│ │ │ │ ├── MazeTurnEnvVec.py
│ │ │ │ └── nsganet.py
│ │ │ └── raw/
│ │ │ ├── BCM.py
│ │ │ ├── README.md
│ │ │ ├── lstm.py
│ │ │ ├── main.py
│ │ │ ├── pltbcm.py
│ │ │ ├── pltrank.py
│ │ │ ├── q_l.py
│ │ │ └── tools/
│ │ │ ├── EnuGlobalNetwork.py
│ │ │ ├── ExperimentEnvGlobalNetworkSurvival.py
│ │ │ └── MazeTurnEnvVec.py
│ │ ├── EB-NAS/
│ │ │ ├── acc_predictor/
│ │ │ │ ├── adaptive_switching.py
│ │ │ │ ├── carts.py
│ │ │ │ ├── factory.py
│ │ │ │ ├── gp.py
│ │ │ │ ├── mlp.py
│ │ │ │ └── rbf.py
│ │ │ ├── cellmodel.py
│ │ │ ├── ebnas.py
│ │ │ ├── micro_encoding.py
│ │ │ ├── motifs.py
│ │ │ ├── nsganet.py
│ │ │ ├── operations.py
│ │ │ ├── readme.md
│ │ │ ├── single_genome.py
│ │ │ └── tm.py
│ │ ├── ELSM/
│ │ │ ├── README.md
│ │ │ ├── evolve.py
│ │ │ ├── lsm.py
│ │ │ ├── model.py
│ │ │ ├── nsganet.py
│ │ │ └── spikes.py
│ │ └── MSE-NAS/
│ │ ├── auto_augment.py
│ │ ├── cellmodel.py
│ │ ├── evolution.py
│ │ ├── loss_f.py
│ │ ├── micro_encoding.py
│ │ ├── motifs.py
│ │ ├── nsganet.py
│ │ ├── obj.py
│ │ ├── operations.py
│ │ ├── readme.md
│ │ ├── tm.py
│ │ └── utils.py
│ ├── TIM/
│ │ ├── README.md
│ │ ├── main.py
│ │ ├── models/
│ │ │ ├── TIM.py
│ │ │ ├── spikformer_braincog_DVS.py
│ │ │ └── spikformer_braincog_SHD.py
│ │ └── utils/
│ │ ├── MyGrad.py
│ │ ├── MyNode.py
│ │ └── datasets.py
│ └── decision_making/
│ ├── BDM-SNN/
│ │ ├── BDM-SNN-UAV.py
│ │ ├── BDM-SNN-hh.py
│ │ ├── BDM-SNN.py
│ │ ├── README.md
│ │ └── decisionmaking.py
│ ├── RL/
│ │ ├── README.md
│ │ ├── atari/
│ │ │ ├── __init__.py
│ │ │ └── atari_wrapper.py
│ │ ├── mcs-fqf/
│ │ │ ├── discrete.py
│ │ │ ├── main.py
│ │ │ ├── network.py
│ │ │ └── policy.py
│ │ ├── requirements.txt
│ │ ├── sdqn/
│ │ │ ├── main.py
│ │ │ └── network.py
│ │ └── utils/
│ │ ├── __init__.py
│ │ └── normalization.py
│ └── swarm/
│ ├── Collision-Avoidance.py
│ └── README.md
├── requirements.txt
└── setup.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
.idea
*.egg-info/
eggs/
.eggs/
*.exe
*.pyc
/.vscode/
*.code-workspace
__pycache__
# Sphinx documentation
docs/_build/
docs/build/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# event data
*.bin
*.dat
*.pt
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# BrainCog
---
BrainCog is an open source spiking neural network based brain-inspired
cognitive intelligence engine for Brain-inspired Artificial Intelligence, Brain-inspired Embodied AI, and brain simulation. More information on BrainCog can be found on its homepage http://www.brain-cog.network/
The current version of BrainCog contains at least 50 functional spiking neural network algorithms (including but not limited to perception and learning, decision making, knowledge representation and reasoning, motor control, social cognition, etc.) built on BrainCog infrastructure, and BrainCog also provides brain simulations of drosophila, rodent, monkey, and human brains at multiple scales based on spiking neural networks. More details at http://www.brain-cog.network/docs/
BrainCog is a community-based effort for spiking neural network based artificial intelligence, and we welcome all forms of contributions, from contributing to the development of core components to contributing applications.
BrainCog provides essential and fundamental components to model biological and artificial intelligence.

Our paper is published in [Patterns](https://www.cell.com/patterns/fulltext/S2666-3899(23)00144-7?_returnURL=https%3A%2F%2Flinkinghub.elsevier.com%2Fretrieve%2Fpii%2FS2666389923001447%3Fshowall%3Dtrue). If you use BrainCog in your research, the following paper can be cited as the source for BrainCog.
```bib
@article{Zeng2023,
doi = {10.1016/j.patter.2023.100789},
url = {https://doi.org/10.1016/j.patter.2023.100789},
year = {2023},
month = jul,
publisher = {Cell Press},
pages = {100789},
author = {Yi Zeng and Dongcheng Zhao and Feifei Zhao and Guobin Shen and Yiting Dong and Enmeng Lu and Qian Zhang and Yinqian Sun and Qian Liang and Yuxuan Zhao and Zhuoya Zhao and Hongjian Fang and Yuwei Wang and Yang Li and Xin Liu and Chengcheng Du and Qingqun Kong and Zizhe Ruan and Weida Bi},
title = {{BrainCog}: A spiking neural network based, brain-inspired cognitive intelligence engine for brain-inspired {AI} and brain simulation},
journal = {Patterns}
}
```
## Brain-Inspired AI
BrainCog currently provides cognitive function components that can be classified
into the following categories:
* Perception and Learning
* Knowledge Representation and Reasoning
* Decision Making
* Motor Control
* Social Cognition
* Development and Evolution
* Safety and Security
## Brain Simulation
BrainCog currently includes two parts for brain simulation:
* Brain Cognitive Function Simulation
* Multi-scale Brain Structure Simulation
The anatomical and imaging data is used to support our simulation from various aspects.
## Software-Hardware Codesign (BrainCog Firefly)
BrainCog currently provides `hardware acceleration` for spiking neural network based brain-inspired AI.
The following papers present the most recent advancements of the BrainCog FireFly series for software-hardware co-design for brain-inspired AI.
* Tenglong Li, Jindong Li, Guobin Shen, Dongcheng Zhao, Qian Zhang, Yi Zeng. FireFly-S: Exploiting Dual-Side Sparsity for Spiking Neural Networks Acceleration With Reconfigurable Spatial Architecture. IEEE Transactions on Circuits and Systems I (TCAS-I), 2024.(https://doi.org/10.1109/TCSI.2024.3496554)
* Jindong Li, Guobin Shen, Dongcheng Zhao, Qian Zhang, Yi Zeng. Firefly v2: Advancing hardware support for high-performance spiking neural network with a spatiotemporal fpga accelerator. IEEE Transactions on Computer-Aided Design of Integrated Circuits and Systems, 2024. (https://ieeexplore.ieee.org/abstract/document/10478105/)
* Jindong Li, Guobin Shen, Dongcheng Zhao, Qian Zhang, Yi Zeng. FireFly: A High-Throughput Hardware Accelerator for Spiking Neural Networks With Efficient DSP and Memory Optimization. IEEE Transactions on Very Large Scale Integration (VLSI) Systems, 2023. (https://ieeexplore.ieee.org/document/10143752)
## Embodied AI and Robotics (BrainCog Embot)


BrainCog Embot is an Embodied AI platform under the Brain-inspired Cognitive Intelligence Engine (BrainCog) framework, which is an open-source Brain-inspired AI platform based on Spiking Neural Network.
The following papers present the most recent advancements of BrainCog Embot:
* Qianhao Wang, Yinqian Sun, Enmeng Lu, Qian Zhang, Yi Zeng. Brain-Inspired Action Generation with Spiking Transformer Diffusion Policy Model. Advances in Brain Inspired Cognitive Systems (BICS), 2024.(https://link.springer.com/chapter/10.1007/978-981-96-2882-7_23)
* Yinqian Sun, Feifei Zhao, Mingyang Lv, Yi Zeng. Implementing Spiking World Model with Multi-Compartment Neurons for Model-based Reinforcement Learning, 2025. (https://arxiv.org/abs/2503.00713)
* Qianhao Wang, Yinqian Sun, Enmeng Lu, Qian Zhang, Yi Zeng. MTDP: Modulated Transformer Diffusion Policy Model, 2025. (https://arxiv.org/abs/2502.09029)
## Resources
### [[Lectures]](https://github.com/BrainCog-X/Brain-Cog/blob/main/documents/Lectures.md) | [[Tutorial]](https://github.com/BrainCog-X/Brain-Cog/blob/main/documents/Tutorial.md)
## Publications using BrainCog
### [[Brain Inspired AI]](https://github.com/BrainCog-X/Brain-Cog/blob/main/documents/Publication.md) | [[Brain Simulation]](https://github.com/BrainCog-X/Brain-Cog/blob/main/documents/Pub_brain_simulation.md) | [[Software-Hardware Co-design]](https://github.com/BrainCog-X/Brain-Cog/blob/main/documents/Pub_sh_codesign.md)
## BrainCog Data Engine
### [BrainCog Data Engine](https://github.com/BrainCog-X/Brain-Cog/blob/main/documents/Data_engine.md)
## Requirements:
* numpy
* scipy
* h5py
* torch
* torchvision
* torchaudio
* timm == 0.6.13
* scikit-learn
* einops
* thop
* pyyaml
* matplotlib
* seaborn
* pygame
* dv
* tensorboard
* tonic
## Install
### Install Online
1. You can install braincog by running:
> `pip install braincog`
2. Also, install from github by running:
> `pip install git+https://github.com/braincog-X/Brain-Cog.git`
### Install locally
1. If you are a developer, it is recommended to download or clone
braincog from github.
> `git clone https://github.com/braincog-X/Brain-Cog.git`
2. Enter the folder of braincog
> `cd Brain-Cog`
3. Install braincog locally
> `pip install -e .`
## Example
1. Examples for Image Classification
```shell
cd ./examples/Perception_and_Learning/img_cls/bp
python main.py --model cifar_convnet --dataset cifar10 --node-type LIFNode --step 8 --device 0
```
2. Examples for Event Classification
```shell
cd ./examples/Perception_and_Learning/img_cls/bp
python main.py --model dvs_convnet --node-type LIFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun QGateGrad --device 0
```
Other BrainCog features and tutorials can be found at http://www.brain-cog.network/docs/
## BrainCog Assistant
Please add our BrainCog Assistant via WeChat and we will invite you to our WeChat developer group.

## Maintenance
This project is led by
**1.Brain-inspired Cognitive Intelligence Lab, Institute of Automation, Chinese Academy of Sciences http://www.braincog.ai/**
**2.Center for Long-term Artificial Intelligence (CLAI) http://long-term-ai.center/**
================================================
FILE: braincog/__init__.py
================================================
# __all__ = ['base', 'datasets', 'model_zoo', 'utils']
#
# from . import (
# base,
# datasets,
# model_zoo,
# utils
# )
================================================
FILE: braincog/base/__init__.py
================================================
# Public submodules of braincog.base.
# Fix: 'strategy' was imported below but missing from __all__, so
# `from braincog.base import *` did not re-export it consistently.
__all__ = ['node', 'strategy', 'connection', 'learningrule', 'brainarea', 'encoder', 'utils', 'conversion']
from . import (
    node,
    strategy,
    connection,
    conversion,
    learningrule,
    brainarea,
    utils,
    encoder
)
================================================
FILE: braincog/base/brainarea/BrainArea.py
================================================
import numpy as np
import torch, os, sys
from torch import nn
from torch.nn import Parameter
import abc
import math
from abc import ABC
import numpy as np
import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
from braincog.base.node.node import *
from braincog.base.learningrule.STDP import *
from braincog.base.connection.CustomLinear import *
class BrainArea(nn.Module, abc.ABC):
    """Abstract base class for brain-area models.

    Concrete areas implement ``forward`` to propagate spikes through the
    area; ``reset`` clears any internal state between simulations.
    """

    @abc.abstractmethod
    def __init__(self):
        """Initialize the underlying ``nn.Module``; subclasses add state."""
        super().__init__()

    @abc.abstractmethod
    def forward(self, x):
        """Run one forward step.

        :param x: input spikes
        :return: output spikes
        """
        return x

    def reset(self):
        """Clear internal state; the base class keeps none."""
        pass
class ThreePointForward(BrainArea):
    """Feed-forward brain area: three IF populations chained in series."""

    def __init__(self, w1, w2, w3):
        """
        :param w1: weight matrix of the first connection
        :param w2: weight matrix of the second connection
        :param w3: weight matrix of the third connection
        """
        super().__init__()
        self.node = [IFNode(), IFNode(), IFNode()]
        self.connection = [CustomLinear(w) for w in (w1, w2, w3)]
        self.stdp = [STDP(n, c) for n, c in zip(self.node, self.connection)]

    def forward(self, x):
        """Propagate spikes through the three stages.

        :param x: input spikes
        :return: (output spikes, flattened tuple of STDP weight updates)
        """
        updates = []
        for rule in self.stdp:
            x, dw = rule(x)
            updates.extend(dw)
        return x, tuple(updates)
class Feedback(BrainArea):
    """Two-population area with a recurrent feedback connection."""

    def __init__(self, w1, w2, w3):
        """
        :param w1: feed-forward weight into the first population
        :param w2: weight from the first to the second population
        :param w3: feedback weight from the second population
        """
        super().__init__()
        self.node = [IFNode(), IFNode()]
        self.connection = [CustomLinear(w1), CustomLinear(w2), CustomLinear(w3)]
        self.stdp = [
            MutliInputSTDP(self.node[0], [self.connection[0], self.connection[2]]),
            STDP(self.node[1], self.connection[1]),
        ]
        # Output of the second population, fed back on the next step.
        self.x1 = torch.zeros(1, w3.shape[0])

    def forward(self, x):
        """One step: external input plus previous feedback drive stage one.

        :param x: input spikes
        :return: (second-population spikes, flattened weight updates)
        """
        hidden, dw1 = self.stdp[0](x, self.x1)
        self.x1, dw2 = self.stdp[1](hidden)
        return self.x1, (*dw1, *dw2)

    def reset(self):
        """Zero the stored feedback activity."""
        self.x1 *= 0
class TwoInOneOut(BrainArea):
    """Single IF population driven by two independent input streams."""

    def __init__(self, w1, w2):
        """
        :param w1: weight matrix for the first input stream
        :param w2: weight matrix for the second input stream
        """
        super().__init__()
        self.node = [IFNode()]
        self.connection = [CustomLinear(w1), CustomLinear(w2)]
        self.stdp = [MutliInputSTDP(self.node[0], [self.connection[0], self.connection[1]])]

    def forward(self, x1, x2):
        """Merge both streams into the single population.

        :param x1: spikes on the first stream
        :param x2: spikes on the second stream
        :return: (output spikes, STDP weight updates)
        """
        out, dw = self.stdp[0](x1, x2)
        return out, dw
class SelfConnectionArea(BrainArea):
    """Single population with an external input and a self-connection."""

    def __init__(self, w1, w2):
        """
        :param w1: weight matrix for the external input
        :param w2: weight matrix of the recurrent self-connection
        """
        super().__init__()
        self.node = [IFNode()]
        self.connection = [CustomLinear(w1), CustomLinear(w2)]
        self.stdp = [MutliInputSTDP(self.node[0], [self.connection[0], self.connection[1]])]
        # Previous output, re-injected through the self-connection.
        self.x1 = torch.zeros(1, w2.shape[0])

    def forward(self, x):
        """One step driven by the input and the previous output.

        :param x: input spikes
        :return: (output spikes, STDP weight updates)
        """
        self.x1, dw = self.stdp[0](x, self.x1)
        return self.x1, dw

    def reset(self):
        """Zero the recurrent activity."""
        self.x1 *= 0
if __name__ == "__main__":
    # Smoke test: drive a TwoInOneOut area with constant input for 20 steps.
    steps = 20
    w1 = torch.tensor([[1., 1], [1, 1]])
    w2 = torch.tensor([[1., 1], [1, 1]])
    w3 = torch.tensor([[0.4, 0.4], [0.4, 0.4]])
    ba = TwoInOneOut(w1, w2)
    for _ in range(steps):
        x = ba(torch.tensor([[0.1, 0.1]]), torch.tensor([[0.1, 0.1]]))
        print(x[0])
================================================
FILE: braincog/base/brainarea/IPL.py
================================================
from braincog.base.learningrule.STDP import *
from braincog.base.node.node import *
from braincog.base.connection.CustomLinear import *
import random
import numpy as np
import torch
import os
import sys
from torch import nn
from torch.nn import Parameter
import abc
import math
from abc import ABC
import numpy as np
import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
import matplotlib.pyplot as plt
from braincog.base.strategy.surrogate import *
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
class IPLNet(nn.Module):
    """
    inferior parietal lobule (IPL)

    Two Izhikevich populations: IPLm driven by vPMC input, and IPLv driven
    by STS input plus a plastic IPLm->IPLv connection.
    """
    def __init__(self, connection):
        """
        Setting the network structure of IPL

        :param connection: sequence of three connections:
            [0] vPMC input -> IPLm, [1] STS input -> IPLv, [2] IPLm -> IPLv
        """
        super().__init__()
        # IPLM, IPLV: two Izhikevich sub-populations
        self.num_subMB = 2
        self.node = [IzhNodeMU(threshold=30., a=0.02, b=0.2, c=-65., d=6., mem=-70.) for i in range(self.num_subMB)]
        self.connection = connection
        self.learning_rule = []
        self.learning_rule.append(STDP(self.node[0], self.connection[0]))  # vPMC_input-IPLM
        self.learning_rule.append(MutliInputSTDP(self.node[1], [self.connection[1], self.connection[2]]))  # STS_input-IPLV, IPLM-IPLV
        # Output placeholders, sized from the output dim of each connection.
        self.out_IPLM = torch.zeros((self.connection[0].weight.shape[1]), dtype=torch.float)
        self.out_IPLV = torch.zeros((self.connection[1].weight.shape[1]), dtype=torch.float)
    def forward(self, input1, input2):  # input from vPMC and STS
        """
        Calculate the output of IPLv and the weight update between IPLm and IPLv
        :param input1: input from vPMC
        :param input2: input from STS
        :return: output of IPLv, weight update between IPLm and IPLv
        """
        # Note: IPLm is driven directly (learning_rule[0] is not used here).
        self.out_IPLM = self.node[0](self.connection[0](input1))
        self.out_IPLV, dw_IPLv = self.learning_rule[1](input2, self.out_IPLM)
        # When exactly one IPLv neuron fired, scale the IPLm->IPLv update by
        # one entry of the STS->IPLv update picked at the firing index.
        # NOTE(review): this assumes dw_IPLv is (dw_STS, dw_IPLm) with
        # matrices indexable at [idx][idx] — confirm against MutliInputSTDP.
        if sum(sum(self.out_IPLV)) == 1:
            dw_IPLv = dw_IPLv[0][torch.nonzero(dw_IPLv[1])[0][1]][torch.nonzero(dw_IPLv[1])[0][1]] * dw_IPLv[1]
        else:
            dw_IPLv = dw_IPLv[0]
        return self.out_IPLV, dw_IPLv
    def UpdateWeight(self, i, dw):
        """
        Update the weight
        :param i: index of the connection to update
        :param dw: weight update
        :return: None
        """
        self.connection[i].update(dw)
    def reset(self):
        """
        reset the network (neurons and learning-rule traces)
        :return: None
        """
        for i in range(self.num_subMB):
            self.node[i].n_reset()
        for i in range(len(self.learning_rule)):
            self.learning_rule[i].reset()
    def getweight(self):
        """
        Get the connection and weight in IPL
        :return: connection
        """
        return self.connection
================================================
FILE: braincog/base/brainarea/Insula.py
================================================
import numpy as np
import torch,os,sys
from torch import nn
from torch.nn import Parameter
import abc
import math
from abc import ABC
import numpy as np
import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
import matplotlib.pyplot as plt
from braincog.base.strategy.surrogate import *
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import random
from braincog.base.connection.CustomLinear import *
from braincog.base.node.node import *
from braincog.base.learningrule.STDP import *
class InsulaNet(nn.Module):
    """
    Insula: a single Izhikevich population integrating IPLv and STS input
    through a multi-input STDP rule.
    """
    def __init__(self,connection):
        """
        Setting the network structure of Insula

        :param connection: sequence of two connections:
            [0] IPLv -> Insula, [1] STS -> Insula
        """
        super().__init__()
        # Insula: one Izhikevich sub-population
        self.num_subMB = 1
        self.node = [IzhNodeMU(threshold=30., a=0.02, b=0.2, c=-65., d=6., mem=-70.) for i in range(self.num_subMB)]
        self.connection = connection
        self.learning_rule = []
        self.learning_rule.append(MutliInputSTDP(self.node[0], [self.connection[0],self.connection[1]]))# IPLv-Insula, STS-Insula
        # Output placeholder sized from the STS->Insula connection.
        self.Insula=torch.zeros((self.connection[1].weight.shape[1]), dtype=torch.float)
    def forward(self, input1, input2): # input from IPLv and STS
        """
        Calculate the output of Insula
        :param input1: input from IPLv
        :param input2: input from STS
        :return: output spikes of Insula (the STDP weight update is discarded)
        """
        self.out_Insula, dw_Insula = self.learning_rule[0](input1, input2)
        return self.out_Insula
    def UpdateWeight(self,i,dw):
        """
        Update the weight
        :param i: index of the connection to update
        :param dw: weight update
        :return: None
        """
        self.connection[i].update(dw)
    def reset(self):
        """
        reset the network (neurons and learning-rule traces)
        :return: None
        """
        for i in range(self.num_subMB):
            self.node[i].n_reset()
        for i in range(len(self.learning_rule)):
            self.learning_rule[i].reset()
    def getweight(self):
        """
        Get the connection and weight in Insula
        :return: connection
        """
        return self.connection
================================================
FILE: braincog/base/brainarea/PFC.py
================================================
import torch
from torch import nn
from braincog.base.brainarea import BrainArea
from braincog.model_zoo.base_module import BaseLinearModule, BaseModule
class PFC:
    """Prefrontal cortex placeholder: passes input through unchanged."""

    def __init__(self):
        """No state of its own; cooperates with multiple inheritance."""
        super().__init__()

    def forward(self, x):
        """Return the input unchanged.

        :param x: input value
        :return: x
        """
        return x

    def reset(self):
        """Nothing to reset."""
        pass
class dlPFC(BaseModule, PFC):
    """
    Dorsolateral PFC modelled as a spiking linear layer with an
    eligibility trace.
    """

    def __init__(self,
                 step,
                 encode_type,
                 in_features: int,
                 out_features: int,
                 bias,
                 *args,
                 **kwargs):
        """
        :param step: number of simulation time steps
        :param encode_type: encoder type forwarded to BaseModule
        :param in_features: input dimension
        :param out_features: output dimension
        :param bias: whether the linear connection uses a bias
        """
        super().__init__(step, encode_type, *args, **kwargs)
        self.bias = bias
        self.in_features = in_features
        self.out_features = out_features
        self.fc = self._create_fc()
        self.c = self._rest_c()

    def _rest_c(self):
        """Draw a fresh random eligibility trace of shape (out, in)."""
        return torch.rand((self.out_features, self.in_features))

    def _create_fc(self):
        """
        the connection of the SNN linear
        @return: nn.Linear
        """
        return nn.Linear(in_features=self.in_features,
                         out_features=self.out_features,
                         bias=self.bias)
================================================
FILE: braincog/base/brainarea/__init__.py
================================================
from .basalganglia import basalganglia
from .BrainArea import BrainArea, ThreePointForward, Feedback, TwoInOneOut, SelfConnectionArea
from .Insula import InsulaNet
from .IPL import IPLNet
from .PFC import PFC, dlPFC
__all__ = [
'basalganglia',
'BrainArea', 'ThreePointForward', 'Feedback', 'TwoInOneOut', 'SelfConnectionArea',
'InsulaNet',
'IPLNet',
'PFC', 'dlPFC'
]
================================================
FILE: braincog/base/brainarea/basalganglia.py
================================================
import numpy as np
import torch
import os
import sys
from torch import nn
from torch.nn import Parameter
import abc
import math
from abc import ABC
import numpy as np
import torch
import torch.nn.functional as F
from braincog.base.strategy.surrogate import *
from braincog.base.node.node import IFNode, SimHHNode
from braincog.base.learningrule.STDP import STDP, MutliInputSTDP
from braincog.base.connection.CustomLinear import CustomLinear
class basalganglia(nn.Module):
    """
    Basal Ganglia: striatum (D1/D2), STN, GPe and GPi populations with a
    fixed connection topology and STDP learning rules.
    """

    def __init__(self, ns, na, we, wi, node_type):
        """
        :param ns: number of states
        :param na: number of actions
        :param we: excitatory connection weight
        :param wi: inhibitory connection weight
        :param node_type: neuron model to use, "hh" or "lif"
        """
        super().__init__()
        num_state = ns
        num_action = na
        num_STN = 2
        weight_exc = we
        weight_inh = wi
        # connections: 0 DLPFC-StrD1  1 DLPFC-StrD2  2 DLPFC-STN  3 StrD1-GPi
        #              4 StrD2-GPe  5 Gpe-Gpi  6 STN-Gpi  7 STN-Gpe  8 Gpe-STN
        bg_connection = []
        bg_con_mask = []
        # DLPFC-StrD1: each state projects to its (state, action) striatal column.
        con_matrix1 = torch.zeros((num_state, num_state * num_action), dtype=torch.float)
        for i in range(num_state):
            for j in range(num_action):
                con_matrix1[i, i * num_action + j] = 1
        bg_con_mask.append(con_matrix1)
        bg_connection.append(CustomLinear(weight_exc * con_matrix1, con_matrix1))
        # DLPFC-StrD2: same topology as DLPFC-StrD1.
        bg_connection.append(CustomLinear(weight_exc * con_matrix1, con_matrix1))
        bg_con_mask.append(con_matrix1)
        # DLPFC-STN: all states drive both STN units.
        con_matrix3 = torch.ones((num_state, num_STN), dtype=torch.float)
        bg_con_mask.append(con_matrix3)
        bg_connection.append(CustomLinear(weight_exc * con_matrix3, con_matrix3))
        # StrD1-GPi: (state, action) columns converge onto their action.
        con_matrix4 = torch.zeros((num_state * num_action, num_action), dtype=torch.float)
        for i in range(num_state):
            for j in range(num_action):
                con_matrix4[i * num_action + j, j] = 1
        bg_con_mask.append(con_matrix4)
        bg_connection.append(CustomLinear(weight_inh * con_matrix4, con_matrix4))
        # StrD2-GPe: same topology, inhibitory.
        bg_con_mask.append(con_matrix4)
        bg_connection.append(CustomLinear(weight_inh * con_matrix4, con_matrix4))
        # Gpe-Gpi: one-to-one inhibition.
        con_matrix5 = torch.eye((num_action), dtype=torch.float)
        bg_con_mask.append(con_matrix5)
        bg_connection.append(CustomLinear(weight_inh * con_matrix5, con_matrix5))
        # STN-Gpi: diffuse excitation at half strength.
        con_matrix6 = torch.ones((num_STN, num_action), dtype=torch.float)
        bg_con_mask.append(con_matrix6)
        bg_connection.append(CustomLinear(0.5 * weight_exc * con_matrix6, con_matrix6))
        # STN-Gpe
        bg_con_mask.append(con_matrix6)
        bg_connection.append(CustomLinear(0.5 * weight_exc * con_matrix6, con_matrix6))
        # Gpe-STN: diffuse inhibition at half strength.
        con_matrix7 = torch.ones((num_action, num_STN), dtype=torch.float)
        bg_con_mask.append(con_matrix7)
        bg_connection.append(CustomLinear(0.5 * weight_inh * con_matrix7, con_matrix7))
        # Five populations: StrD1, StrD2, STN, GPe, GPi.
        self.num_subBG = 5
        self.node_type = node_type
        if self.node_type == "hh":
            self.node = [SimHHNode() for i in range(self.num_subBG)]
        if self.node_type == "lif":
            self.node = [IFNode() for i in range(self.num_subBG)]
        self.connection = bg_connection
        self.mask = bg_con_mask
        self.learning_rule = []
        trace_stdp = 0.99
        self.learning_rule.append(STDP(self.node[0], self.connection[0], trace_stdp))  # DLPFC-StrD1
        self.learning_rule.append(STDP(self.node[1], self.connection[1], trace_stdp))  # DLPFC-StrD2
        self.learning_rule.append(MutliInputSTDP(self.node[2], [self.connection[2], self.connection[8]]))  # DLPFC-STN
        self.learning_rule.append(MutliInputSTDP(self.node[3], [self.connection[4], self.connection[7]]))  # StrD2-GPe STN-Gpe
        self.learning_rule.append(MutliInputSTDP(self.node[4], [self.connection[3], self.connection[5], self.connection[6]]))  # StrD1-GPi Gpe-Gpi STN-Gpi
        # Output placeholders for each population.
        self.out_StrD1 = torch.zeros((self.connection[0].weight.shape[1]), dtype=torch.float)
        self.out_StrD2 = torch.zeros((self.connection[1].weight.shape[1]), dtype=torch.float)
        self.out_STN = torch.zeros((self.connection[2].weight.shape[1]), dtype=torch.float)
        self.out_Gpi = torch.zeros((self.connection[3].weight.shape[1]), dtype=torch.float)
        self.out_Gpe = torch.zeros((self.connection[4].weight.shape[1]), dtype=torch.float)

    def forward(self, input):
        """
        Compute the basal-ganglia output for the current input.
        :param input: input current
        :return: GPi output spikes
        """
        self.out_StrD1, dw_strd1 = self.learning_rule[0](input)
        self.out_StrD2, dw_strd2 = self.learning_rule[1](input)
        self.out_STN, dw_stn = self.learning_rule[2](input, self.out_Gpe)
        self.out_Gpe, dw_gpe = self.learning_rule[3](self.out_StrD2, self.out_STN)
        self.out_Gpi, dw_gpi = self.learning_rule[4](self.out_StrD1, self.out_Gpe, self.out_STN)
        return self.out_Gpi

    def UpdateWeight(self, i, dw):
        """
        Update the weight of the i-th connection by dw, then L1-normalise rows.
        :param i: index of the connection to update
        :param dw: weight increment
        :return: None
        """
        self.connection[i].update(dw)
        self.connection[i].weight.data = F.normalize(self.connection[i].weight.data.float(), p=1, dim=1)

    def reset(self):
        """
        Reset neuron state and learning-rule traces.
        :return: None
        """
        # BUG FIX: the original iterated over `self.num_subMB`, an attribute
        # this class never defines (it belongs to IPLNet/InsulaNet), so
        # reset() always raised AttributeError. Use the population count.
        for i in range(self.num_subBG):
            self.node[i].n_reset()
        for i in range(len(self.learning_rule)):
            self.learning_rule[i].reset()

    def getweight(self):
        """
        Get the connections (including weights) of the basal ganglia.
        :return: list of connections
        """
        return self.connection

    def getmask(self):
        """
        Get the connectivity masks of the basal ganglia.
        :return: list of mask matrices
        """
        return self.mask
if __name__ == "__main__":
    # BUG FIX: the original call omitted the required `node_type` argument
    # (basalganglia(4, 2, 0.2, -4)), which raised a TypeError at startup.
    BG = basalganglia(4, 2, 0.2, -4, "lif")
    con = BG.getweight()
    print(con)
================================================
FILE: braincog/base/brainarea/dACC.py
================================================
import torch
import matplotlib.pyplot as plt
import numpy as np
np.set_printoptions(threshold=np.inf)
from utils.one_hot import *
import os
import time
import sys
from tqdm import tqdm
from braincog.base.encoder.population_coding import *
from braincog.model_zoo.base_module import BaseLinearModule, BaseModule
from braincog.base.learningrule.STDP import *
import sys
sys.path.append("..")
class dACC(BaseModule):
    """
    dorsal anterior cingulate cortex (dACC): a two-layer spiking linear
    model trained online with STDP and an eligibility trace.
    """
    def __init__(self,
                 step,
                 encode_type,
                 in_features:int,
                 out_features:int,
                 bias,
                 node,
                 *args,
                 **kwargs):
        """
        :param step: number of simulation time steps per input sample
        :param encode_type: encoder type forwarded to BaseModule
        :param in_features: input dimension
        :param out_features: output dimension
        :param bias: whether the linear connection uses a bias
        :param node: neuron node class, instantiated for both layers
        """
        super().__init__(step, encode_type, *args, **kwargs)
        self.bias = bias
        self.in_features = in_features
        self.out_features = out_features
        # Input-layer neuron (higher threshold) and its class handle.
        self.node1 = node(threshold=0.5, tau=2.)
        self.node_name1 = node
        # Output-layer neuron (lower threshold) and its class handle.
        self.node2 = node(threshold=0.1, tau=2.)
        self.node_name2 = node
        self.fc = self._create_fc()
        self.c = self._rest_c()
    def _rest_c(self):
        """Return a freshly randomised eligibility trace of shape (out, in)."""
        c = torch.rand((self.out_features, self.in_features))  # eligibility trace
        return c
    def _create_fc(self):
        """
        the connection of the SNN linear
        @return: nn.Linear
        """
        fc = nn.Linear(in_features=self.in_features,
                       out_features=self.out_features, bias=self.bias)
        return fc
    def update_c(self, c, STDP, tau_c=0.2):
        """
        update the trace of eligibility
        @param c: a tensor to record eligibility
        @param STDP: the results of STDP
        @param tau_c: the parameter of trace decay
        @return: a update tensor to record eligibility
        Equation:
            delta_c = (-(c / tau_c) + STDP) * dela_t
            c = c + delta_c
        reference:
        """
        # NOTE(review): the implementation only adds tau_c * STDP and omits
        # the decay term -(c / tau_c) stated above — confirm intended.
        c = c + tau_c * STDP
        return c
    def forward(self, inputs, epoch):
        """
        decision
        @param inputs: state, indexable as inputs[i, :] per sample
        @param epoch: training epoch
        @return: action
        """
        # NOTE(review): `epoch` is not used anywhere in this method.
        output = []
        # A fresh STDP rule over the shared fc connection; the eligibility
        # trace is re-randomised for every call.
        stdp = STDP(self.node2, self.fc, decay=0.80)
        self.c = self._rest_c()
        # stdp.connection.weight.data = torch.rand((self.out_features, self.in_features))
        for i in range(inputs.shape[0]):
            for t in range(self.step):
                l1_in = torch.tensor(inputs[i, :])
                l1_out = self.node1(l1_in).unsqueeze(0)  # pre-synaptic spikes: l1_out
                l2_out, dw = stdp(l1_out)  # dw -- STDP weight update
                self.c = self.update_c(self.c, dw[0])
            # One scalar per sample: the minimum of the last step's output.
            output.append(torch.min(l2_out))
            # output.append((l2_out.any() == 0).cpu().detach().numpy().tolist())
        return output
# if __name__ == '__main__':
# np.random.seed(6)
# T = 5
# num_popneurons = 2
# safety = 2
# epoch = 50
# file_name = "/home/zhaozhuoya/braincog/examples/ToM/data/injury_value.txt"
# state = []
# with open(file_name) as f:
# data = []
# data_split = f.readlines() #
# for i in data_split:
# state.append(one_hot(int(i[0])))
#
# output = np.array(state)
# train_y = output
# test_y = output[79:82]#output[12].reshape(1,2)
#
# file_name = "/home/zhaozhuoya/braincog/examples/ToM/data/injury_memory.txt"
# state = []
# with open(file_name) as f:
# data_split = f.readlines()
# for i in data_split:
# data = []
# data.append(int(bool(abs(int(i[2]) - int(i[18]))))*10)
# data.append(int(bool(abs(int(i[5]) - int(i[21]))))*10)
# state.append(data)
# input = np.array(state)
# train_x = input
# test_x = input[79:82]
# dACC_net = dACC(step=T, encode_type='rate', bias=True,
# in_features=num_popneurons, out_features=safety,
# node=node.LIFNode)
# dACC_net.fc.weight.data = torch.rand((safety, num_popneurons))
# dACC_net.load_state_dict(torch.load('./checkpoint/dACC_net.pth')['dacc'])
# output = dACC_net(inputs=train_x, epoch=50)
# for i in range(len(output)):
# print(output[i], train_x[i])
# torch.save({'dacc': dACC_net.state_dict()}, os.path.join('./checkpoint', 'dACC_net.pth'))
# dACC_net.load_state_dict(torch.load('./checkpoint/dACC_net.pth')['dacc'])
# output = dACC_net(inputs=test_x, epoch=50)
# for i in range(len(test_x)):
#
# print(output[i],test_x[i])
================================================
FILE: braincog/base/connection/CustomLinear.py
================================================
import os
import sys
import numpy as np
import torch
from torch import nn
from torch import einsum
import torch.nn.functional as F
class CustomLinear(nn.Module):
    """User-defined linear connection, typically used in STDP computations."""

    def __init__(self, weight, mask=None):
        """
        :param weight: initial weight matrix (becomes a trainable Parameter)
        :param mask: optional connectivity mask applied to weight updates
        """
        super().__init__()
        self.weight = nn.Parameter(weight, requires_grad=True)
        self.mask = mask

    def forward(self, x: torch.Tensor):
        """Project the input through the weight matrix.

        :param x: input tensor, shape [N]
        :return: x @ weight, shape [C]
        """
        return x.matmul(self.weight)

    def update(self, dw):
        """Apply a weight increment in place, masked if a mask was given.

        :param dw: weight increment (mutated in place when masked)
        """
        with torch.no_grad():
            if self.mask is not None:
                dw *= self.mask
            self.weight.data += dw
================================================
FILE: braincog/base/connection/__init__.py
================================================
from .CustomLinear import CustomLinear
from .layer import VotingLayer, WTALayer, NDropout, ThresholdDependentBatchNorm2d, LayerNorm, SMaxPool, LIPool
__all__ = [
'CustomLinear',
'VotingLayer', 'WTALayer', 'NDropout', 'ThresholdDependentBatchNorm2d', 'LayerNorm', 'SMaxPool', 'LIPool'
]
================================================
FILE: braincog/base/connection/layer.py
================================================
import warnings
import math
import numpy as np
import torch
from torch import nn
from torch import einsum
from torch.nn.modules.batchnorm import _BatchNorm
import torch.nn.functional as F
from torch.nn import Parameter
from einops import rearrange
class VotingLayer(nn.Module):
    """Output layer for SNNs: groups of neurons vote for each class.

    :param voter_num: number of neurons per class; their outputs are averaged
    """

    def __init__(self, voter_num: int):
        super().__init__()
        self.voting = nn.AvgPool1d(voter_num, voter_num)

    def forward(self, x: torch.Tensor):
        """Average each consecutive group of ``voter_num`` entries.

        x: [N, voter_num * C] -> [N, C]
        """
        expanded = x.unsqueeze(1)
        return self.voting(expanded).squeeze(1)
class WTALayer(nn.Module):
    """Winner-take-all applied after an SNN layer: keep only the winners.

    Winners are selected on the spike values scaled by uniform noise, so
    ties between equal spikes are broken at random.

    :param k: number of outputs kept per sample along dim 1 (default 1)
    """

    def __init__(self, k=1):
        super().__init__()
        self.k = k

    def forward(self, x: torch.Tensor):
        # x.shape = [N, C, W, H] -> same shape, losers zeroed out.
        scores = x * torch.rand(x.shape, device=x.device)
        if self.k > 1:
            kth = scores.topk(self.k, dim=1)[0][:, -1:]
            keep = (scores >= kth).float()
        else:
            keep = (scores >= scores.max(1, True)[0]).float()
        return x * keep
class NDropout(nn.Module):
    """Dropout whose mask is frozen across the time steps of one sample.

    Behaves like standard dropout except that the mask is sampled once and
    reused until :meth:`n_reset` is called.
    """

    def __init__(self, p):
        super(NDropout, self).__init__()
        self.p = p
        self.mask = None

    def n_reset(self):
        """
        Drop the cached mask so the next forward pass samples a new one.
        :return:
        """
        self.mask = None

    def create_mask(self, x):
        """
        Sample a dropout mask shaped like the input.
        :param x: reference tensor; the mask matches its shape
        :return:
        """
        self.mask = F.dropout(torch.ones_like(x.data), self.p, training=True)

    def forward(self, x):
        if not self.training:
            return x
        if self.mask is None:
            self.create_mask(x)
        return self.mask * x
class WSConv2d(nn.Conv2d):
    """Conv2d with weight standardization: zero-mean, unit-std kernels."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, gain=True):
        super(WSConv2d, self).__init__(in_channels, out_channels, kernel_size,
                                       stride, padding, dilation, groups, bias)
        # Optional learnable per-filter gain applied after standardization.
        if gain:
            self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1))
        else:
            self.gain = 1.

    def forward(self, x):
        w = self.weight
        # Per-filter mean, reduced one dimension at a time (keeps dims).
        mean = w.mean(dim=1, keepdim=True).mean(dim=2,
                                                keepdim=True).mean(dim=3, keepdim=True)
        w = w - mean
        # Per-filter std with a small floor for numerical stability.
        std = w.view(w.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
        w = self.gain * w / std.expand_as(w)
        return F.conv2d(x, w, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
class ThresholdDependentBatchNorm2d(_BatchNorm):
    """
    tdBN: threshold-dependent batch normalization for SNNs.
    https://ojs.aaai.org/index.php/AAAI/article/view/17320
    """

    def __init__(self, num_features, alpha: float, threshold: float = .5,
                 layer_by_layer: bool = True, affine: bool = True, **kwargs):
        self.alpha = alpha
        self.threshold = threshold
        super().__init__(num_features=num_features, affine=affine)
        assert layer_by_layer, \
            'tdBN may works in step-by-step mode, which will not take temporal dimension into batch norm'
        assert self.affine, 'ThresholdDependentBatchNorm needs to set `affine = True`!'
        # Initialise the affine scale so activations are normalised to
        # alpha times the firing threshold.
        torch.nn.init.constant_(self.weight, alpha * threshold)

    def _check_input_dim(self, input):
        """Reject anything that is not a 4-D (N, C, H, W) tensor."""
        if input.dim() != 4:
            raise ValueError("expected 4D input (got {}D input)".format(input.dim()))

    def forward(self, input):
        """Standard BatchNorm forward; time is assumed folded into batch."""
        return super().forward(input)
class TEBN(nn.Module):
    """Temporal Effective Batch Normalization.

    Applies BatchNorm3d over (T, C, B, W, H) and then a learnable
    per-time-step scale ``p``.

    :param num_features: number of channels C
    :param step: number of time steps T folded into the batch dimension
    :param eps: unused, kept for interface compatibility
    :param momentum: unused, kept for interface compatibility
    """

    def __init__(self, num_features, step, eps=1e-5, momentum=0.1, **kwargs):
        super(TEBN, self).__init__()
        self.bn = nn.BatchNorm3d(num_features)
        # BUG FIX: `p` was hard-coded to 4 time steps, so any `step != 4`
        # failed to broadcast in forward. Size it from `step` instead
        # (identical behaviour for the original step=4 case).
        self.p = nn.Parameter(torch.ones(step, 1, 1, 1, 1))
        self.step = step

    def forward(self, input):
        t = self.step
        b = input.shape[0] // t
        # (t*b, c, w, h) -> (t, c, b, w, h) so BatchNorm3d normalises per
        # channel across time, batch and space.
        y = input.view(t, b, *input.shape[1:]).permute(0, 2, 1, 3, 4)
        y = self.bn(y)
        # Back to (t, b, c, w, h) and apply the per-step scale.
        y = y.permute(0, 2, 1, 3, 4) * self.p
        # Fold time back into the batch dimension.
        return y.reshape(t * b, *input.shape[1:])
class LayerNorm(nn.Module):
    """ LayerNorm supporting channels_last (default) and channels_first.

    channels_last expects inputs shaped (batch, height, width, channels);
    channels_first expects (batch, channels, height, width).
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        # channels_first: normalise manually over the channel dimension.
        mean = x.mean(1, keepdim=True)
        var = (x - mean).pow(2).mean(1, keepdim=True)
        normed = (x - mean) / torch.sqrt(var + self.eps)
        return self.weight[:, None, None] * normed + self.bias[:, None, None]
class SMaxPool(nn.Module):
    """Drop-in max-pool replacement for ANN-to-SNN conversion.

    Passes the spike of the neuron with the highest accumulated firing
    rate, which satisfies the needs of a general max pooling layer.
    Reference:
        https://arxiv.org/abs/1612.04052
    """

    def __init__(self, child):
        super(SMaxPool, self).__init__()
        self.opration = child
        self.sumspike = 0

    def forward(self, x):
        self.sumspike += x
        # Scale the accumulator so the running winner dominates the pool,
        # then recover the current step's spike by subtraction.
        base = self.opration(self.sumspike * 1000)
        with_input = self.opration(x + self.sumspike * 1000)
        return with_input - base

    def reset(self):
        """Clear the accumulated spike count."""
        self.sumspike = 0
class LIPool(nn.Module):
    r"""Exact max-pool replacement for ANN-to-SNN conversion.

    LIPooling uses lateral inhibition so that the maximum emitted by the
    converted SNN matches the expected value.
    Reference:
        https://arxiv.org/abs/2204.13271
    """

    def __init__(self, child=None):
        super(LIPool, self).__init__()
        if child is None:
            raise NotImplementedError("child should be Pooling operation with torch.")
        self.opration = child
        self.sumspike = 0

    def forward(self, x):
        self.sumspike += x
        out = self.opration(self.sumspike)
        # Lateral inhibition: subtract the pooled winner (upsampled back
        # to input resolution) from the accumulator.
        self.sumspike -= F.interpolate(out, scale_factor=2, mode='nearest')
        return out

    def reset(self):
        """Clear the accumulated spikes."""
        self.sumspike = 0
class CustomLinear(nn.Module):
    """Linear layer whose weight is constrained to be lower-triangular.

    The weight starts as the identity; a fixed lower-triangular mask is
    applied on every forward pass.
    """

    def __init__(self, in_channels, out_channels, bias=True):
        super(CustomLinear, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Identity initialisation; the buffered mask keeps gradients from
        # ever populating the upper triangle.
        self.weight = Parameter(torch.diag(torch.ones(self.in_channels)), requires_grad=True)
        self.register_buffer('mask', torch.tril(torch.ones(self.in_channels, self.in_channels), diagonal=0))
        if bias:
            self.bias = Parameter(torch.zeros(out_channels), requires_grad=True)
        else:
            self.register_parameter('bias', None)

    def forward(self, inputs):
        masked = self.mask * self.weight
        return F.linear(inputs, masked, self.bias)
================================================
FILE: braincog/base/conversion/__init__.py
================================================
from .convertor import HookScale, Hookoutput, Scale, Convertor, SNode
from .merge import mergeConvBN, merge
__all__ = [
'Hookoutput', 'HookScale', 'Scale', 'Convertor', 'SNode',
'merge', 'mergeConvBN'
]
================================================
FILE: braincog/base/conversion/convertor.py
================================================
import torch
import torch.nn as nn
from braincog.base.connection.layer import SMaxPool, LIPool
from .merge import mergeConvBN
from .spicalib import SpiCalib
import types
class HookScale(nn.Module):
    """Records the p-th percentile maximum of the activations after each ReLU.

    For channelnorm: the maximum is taken per channel with ``torch.quantile``.
    Otherwise: the flattened activations are sorted and the percentile picked
    by index, because ``quantile`` has an element-count limit and can fail for
    large batches.
    """

    def __init__(self,
                 p: float = 0.9995,
                 channelnorm: bool = False,
                 gamma: float = 0.999,
                 ):
        super().__init__()
        # Both modes start from a scalar zero buffer (the original had two
        # identical branches here); the channelnorm path broadcasts it into a
        # per-channel tensor on the first forward pass via torch.max.
        self.register_buffer('scale', torch.tensor(0.0))
        self.p = p
        self.channelnorm = channelnorm
        self.gamma = gamma

    def forward(self, x):
        # Clip activations at gamma to cap burst spikes (arXiv:2204.13271).
        x = torch.where(x.detach() < self.gamma, x.detach(),
                        torch.tensor(self.gamma, dtype=x.dtype, device=x.device))
        if len(x.shape) == 4 and self.channelnorm:
            num_channel = x.shape[1]
            tmp = torch.quantile(x.permute(1, 0, 2, 3).reshape(num_channel, -1), self.p, dim=1,
                                 interpolation='lower') + 1e-10
            self.scale = torch.max(tmp, self.scale)
        else:
            sort, _ = torch.sort(x.view(-1))
            # Manual percentile: index into the sorted activations.
            self.scale = torch.max(sort[int(sort.shape[0] * self.p) - 1], self.scale)
        return x
class Hookoutput(nn.Module):
    """Wrapper used during pseudo-conversion to monitor a module's output.

    Keeps a detached copy of the wrapped module's (ReLU / ClipQuan) most
    recent output in ``self.activation``.
    """

    def __init__(self, module):
        super(Hookoutput, self).__init__()
        self.activation = 0.     # detached copy of the last output
        self.operation = module  # the wrapped module

    def forward(self, x):
        out = self.operation(x)
        self.activation = out.detach()
        return out
class Scale(nn.Module):
    """Scales the values of the forward pass by a fixed factor.

    :param scale: a scalar (float or 0-dim tensor) or a per-channel 1-D tensor.
    """

    def __init__(self, scale=1.0):
        super().__init__()
        # register_buffer requires a tensor; the original passed the raw float
        # default straight through, so Scale() with no argument raised TypeError.
        self.register_buffer('scale', torch.as_tensor(scale))

    def forward(self, x):
        if len(self.scale.shape) == 1:
            # Per-channel scale: broadcast (C,) -> (1, C, 1, 1) over 4-D input.
            return self.scale.unsqueeze(0).unsqueeze(2).unsqueeze(3).expand_as(x) * x
        else:
            return self.scale * x
def reset(self):
    """Reset the spiking state of a converted network.

    The converted model originates from an ANN, so the newly attached spiking
    modules must be reset between samples: walk the children, call reset() on
    every SNode / LIPool / SMaxPool, and recurse into everything else.
    (Attached to converted models via types.MethodType in Convertor.forward.)
    """
    for module in self.children():
        if isinstance(module, (SNode, LIPool, SMaxPool)):
            module.reset()
        else:
            reset(module)
class Convertor(nn.Module):
    """ANN-to-SNN convertor.

    Converts a complete pytorch model. A few batches from ``dataloader`` are
    used to record per-layer activation maxima; ``p`` selects the p-th
    percentile of the activations.

    channelnorm: https://arxiv.org/abs/1903.06530
        take the maximum per channel and normalise the weights accordingly
    gamma: https://arxiv.org/abs/2204.13271
        caps the number of burst spikes; bursts raise a neuron's firing
        capacity and reduce residual information
    lipool: https://arxiv.org/abs/2204.13271
        use lateral inhibition (LIPool) to convert max pooling exactly
    soft_mode: https://arxiv.org/abs/1612.04052
        soft reset, reduces information loss when neurons reset
    merge: whether to fuse adjacent Conv and BN layers
    batch_num: how many batches of the dataloader to consume
    spicalib: allowance for spike calibration (0 disables it)
    """

    def __init__(self,
                 dataloader,
                 device=None,
                 p=0.9995,
                 channelnorm=False,
                 lipool=True,
                 gamma=1,
                 soft_mode=True,
                 merge=True,
                 batch_num=1,
                 spicalib=0
                 ):
        super(Convertor, self).__init__()
        self.dataloader = dataloader
        self.device = device
        self.p = p
        self.channelnorm = channelnorm
        self.lipool = lipool
        self.gamma = gamma
        self.soft_mode = soft_mode
        self.merge = merge
        self.batch_num = batch_num
        self.spicalib = spicalib

    def forward(self, model):
        """Convert ``model`` in place and return it with a bound reset()."""
        model.eval()
        model = Convertor.register_hook(model, self.p, self.channelnorm, self.gamma)
        model = Convertor.get_percentile(model, self.dataloader, self.device, batch_num=self.batch_num)
        model = mergeConvBN(model) if self.merge else model
        model = Convertor.replace_for_spike(model, self.lipool, self.soft_mode, self.gamma, self.spicalib)
        # attach the module-level reset() so callers can clear spiking state
        model.reset = types.MethodType(reset, model)
        return model

    @staticmethod
    def register_hook(model, p=0.99, channelnorm=False, gamma=0.999):
        """ Reference: https://github.com/fangwei123456/spikingjelly

        Registers a HookScale after every ReLU in the network. In simulation
        this is equivalent to weight normalisation and extends easily to
        networks of arbitrary structure.
        """
        children = list(model.named_children())
        for name, child in children:
            if isinstance(child, nn.ReLU):
                model._modules[name] = nn.Sequential(nn.ReLU(), HookScale(p, channelnorm, gamma))
            else:
                Convertor.register_hook(child, p, channelnorm, gamma)
        return model

    @staticmethod
    def get_percentile(model, dataloader, device, batch_num=1):
        """Run ``batch_num`` batches through a HookScale-instrumented model."""
        for idx, (data, _) in enumerate(dataloader):
            # check before moving the batch, so no extra batch is transferred
            if idx >= batch_num:
                break
            data = data.to(device)
            model(data)
        return model

    @staticmethod
    def replace_for_spike(model, lipool=True, soft_mode=True, gamma=1, spicalib=0):
        """Turn a prepared ANN into an SNN in place.

        ReLU units (with their HookScale) become Scale/SNode/SpiCalib/Scale
        stacks; max-pool layers become LIPool or SMaxPool depending on
        ``lipool``.
        """
        children = list(model.named_children())
        for name, child in children:
            if isinstance(child, nn.Sequential) and len(child) == 2 and \
                    isinstance(child[0], nn.ReLU) and isinstance(child[1], HookScale):
                model._modules[name] = nn.Sequential(
                    Scale(1.0 / child[1].scale),
                    SNode(soft_mode, gamma),
                    SpiCalib(spicalib),
                    Scale(child[1].scale)
                )
            elif isinstance(child, nn.MaxPool2d):
                model._modules[name] = LIPool(child) if lipool else SMaxPool(child)
            else:
                # fixed: spicalib was previously dropped in the recursion, so
                # every nested module silently got the default allowance of 0
                Convertor.replace_for_spike(child, lipool, soft_mode, gamma, spicalib)
        return model
class SNode(nn.Module):
    """Neuron model for the converted SNN.

    With gamma=1 this is an IF neuron; gamma > 1 enables the burst neuron
    model (up to ``gamma`` spikes per step). ``soft_mode`` selects the reset
    rule; soft reset greatly reduces the information lost at reset time.
    """

    def __init__(self, soft_mode=False, gamma=5):
        super(SNode, self).__init__()
        self.threshold = 1.0   # firing threshold
        self.soft_mode = soft_mode
        self.gamma = gamma     # maximum spikes per timestep (burst cap)
        self.mem = 0           # membrane potential
        self.spike = 0         # spikes emitted at the current step

    def forward(self, x):
        self.mem = self.mem + x
        # number of thresholds crossed, clamped to the burst cap
        self.spike = (self.mem / self.threshold).floor().clamp(min=0, max=self.gamma)
        # fixed: the original evaluated `self.hard_reset` without calling it,
        # so the hard reset never ran when soft_mode was False
        if self.soft_mode:
            self.soft_reset()
        else:
            self.hard_reset()
        return self.spike

    def hard_reset(self):
        """Hard reset: membrane potential is cleared where a spike occurred."""
        self.mem = self.mem * (1 - self.spike.detach())

    def soft_reset(self):
        """Soft reset: subtract the threshold once per emitted spike."""
        self.mem = self.mem - self.threshold * self.spike.detach()

    def reset(self):
        """Clear all neuron state."""
        self.mem = 0
        self.spike = 0
================================================
FILE: braincog/base/conversion/merge.py
================================================
import torch
import torch.nn as nn
def mergeConvBN(m):
    """Fuse Conv2d/BatchNorm2d pairs inside a module tree, in place.

    Each BatchNorm2d is folded into the most recently seen Conv2d among its
    siblings and replaced by nn.Identity().

    NOTE(review): the "most recent conv" is remembered across intervening
    siblings (e.g. Conv -> ReLU -> BN would still be fused); confirm that
    convs and their BNs are always adjacent in the models converted here.
    """
    children = list(m.named_children())
    c, cn = None, None
    for name, child in children:
        # fixed: guard against a BN with no preceding conv, which previously
        # crashed inside merge(None, bn)
        if isinstance(child, nn.BatchNorm2d) and c is not None:
            m._modules[cn] = merge(c, child)
            m._modules[name] = nn.Identity()
            c = None
        elif isinstance(child, nn.Conv2d):
            c = child
            cn = name
        else:
            mergeConvBN(child)
    return m
def merge(conv, bn):
    """Fuse a Conv2d and a BatchNorm2d (eval-mode statistics) into one conv.

    :param conv: the convolution instance
    :param bn: the batch-norm instance (its running statistics are used)
    :return: a new Conv2d whose output equals bn(conv(x))
    """
    with torch.no_grad():
        # BN convention: gamma is the learned scale (bn.weight), beta the shift
        # (bn.bias); the original code had the two names swapped.
        gamma, beta = bn.weight, bn.bias
        mean, var_sqrt = bn.running_mean, torch.sqrt(bn.running_var + bn.eps)
        b = conv.bias if conv.bias is not None else mean.new_zeros(mean.shape)
        w = conv.weight * (gamma / var_sqrt).reshape([conv.out_channels, 1, 1, 1])
        b = (b - mean) / var_sqrt * gamma + beta
    # fixed: preserve dilation, groups and padding_mode — previously the fused
    # conv silently fell back to their defaults, breaking e.g. grouped convs
    fused_conv = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                           stride=conv.stride, padding=conv.padding,
                           dilation=conv.dilation, groups=conv.groups,
                           bias=True, padding_mode=conv.padding_mode)
    fused_conv.weight = nn.Parameter(w)
    fused_conv.bias = nn.Parameter(b)
    return fused_conv
================================================
FILE: braincog/base/conversion/spicalib.py
================================================
import torch
import torch.nn as nn
class SpiCalib(nn.Module):
    """Spike calibration layer for converted SNNs.

    Tracks, per neuron, the running mean of inter-spike intervals and
    subtracts a spike from neurons that have stayed silent longer than their
    average interval plus ``allowance`` (arXiv:2204.13271, used by Convertor).

    :param allowance: slack in timesteps before a correction fires; 0 disables
        the calibration entirely and turns this layer into a pass-through.
    """

    def __init__(self, allowance):
        super(SpiCalib, self).__init__()
        self.allowance = allowance
        self.sumspike = 0  # running sum of the (possibly corrected) outputs
        self.t = 0         # current timestep within the run

    def forward(self, x):
        if self.allowance == 0:
            # calibration disabled
            return x
        if self.t == 0:
            # lazily (re)allocate per-neuron statistics at the start of a run
            self.last_spike = torch.zeros_like(x)  # timestep of the last spike
            self.avg_time = torch.zeros_like(x)    # mean inter-spike interval
            self.num_spike = torch.zeros_like(x)   # spike count so far
        SPIKE_MASK = x > 0
        self.num_spike[SPIKE_MASK] += 1
        # incremental running mean of inter-spike intervals for spiking neurons
        # (num_spike >= 1 under the mask, so the division is safe)
        self.avg_time[SPIKE_MASK] = (self.t - self.last_spike + self.avg_time * (self.num_spike - 1))[SPIKE_MASK] / \
                                    self.num_spike[SPIKE_MASK]
        self.last_spike[SPIKE_MASK] = self.t
        # neurons silent for longer than average interval + allowance get one
        # spike subtracted as a downstream correction
        SIN_MASK = self.t - self.last_spike > self.avg_time + self.allowance
        x[SIN_MASK] -= 1.0
        self.sumspike += x
        # do not let the cumulative output drop below -1: cancel the correction
        x[self.sumspike <= -1] = 0
        self.t += 1
        return x

    def reset(self):
        """Reset the run state; the per-neuron statistic tensors are
        re-created on the next forward pass (when t == 0)."""
        self.sumspike = 0
        self.t = 0
================================================
FILE: braincog/base/encoder/__init__.py
================================================
from .encoder import Encoder
from .population_coding import PEncoder
from.qs_coding import QSEncoder
__all__ = [
'Encoder',
'PEncoder',
'QSEncoder'
]
================================================
FILE: braincog/base/encoder/encoder.py
================================================
import torch
import torch.nn as nn
from einops import rearrange, repeat
from braincog.base.strategy.surrogate import GateGrad
class AutoEncoder(nn.Module):
    """Trainable encoder expanding every scalar input into a ``step``-long code.

    Two fully connected layers map each input element to ``step`` values which
    are squashed into (0, 1); when ``spike_output`` is set, a
    surrogate-gradient threshold (GateGrad) binarises them into spikes.
    """

    def __init__(self, step, spike_output=True):
        super(AutoEncoder, self).__init__()
        self.step = step
        self.spike_output = spike_output
        self.sigmoid = nn.Sigmoid()
        self.fc1 = nn.Linear(1, self.step)
        self.fc2 = nn.Linear(self.step, self.step)
        self.relu = nn.ReLU()
        self.act_fun = GateGrad()

    def forward(self, x):
        in_shape = x.shape
        # flatten to (N, 1): every element is encoded independently
        hidden = self.relu(self.fc1(x.view(-1, 1)))
        # (N, step) -> (step, N), then squash into (0, 1)
        code = self.sigmoid(self.fc2(hidden).t())
        if self.spike_output:
            code = self.act_fun(code)
        return code.view(self.step, *in_shape)
# class TransEncoder(nn.Module):
# def __init__(self, step):
# super(TransEncoder, self).__init__()
# self.step = step
# self.trans = Transformer(dim=128, depth=3, heads=8, dim_head=, mlp_dim, dropout=0.)
class Encoder(nn.Module):
    '''
    Encode static images into spike trains (event data passes through).

    :param step: number of simulation timesteps
    :param encode_type: one of ``direct``, ``ttfs``, ``rate``, ``phase``, ``auto``
    :param temporal_flatten: concatenate the temporal dim into the channel dim
    :param layer_by_layer: fold time into the batch dim for layer-wise inference
    Output layout is (step, batch_size, ...) unless a flag above reshapes it.
    '''

    def __init__(self, step, encode_type='ttfs', *args, **kwargs):
        super(Encoder, self).__init__()
        self.step = step
        # resolve the chosen encoding method once at construction time
        self.fun = getattr(self, encode_type)
        self.encode_type = encode_type
        self.temporal_flatten = kwargs.get('temporal_flatten', False)
        self.layer_by_layer = kwargs.get('layer_by_layer', False)
        self.no_encode = kwargs.get('adaptive_node', False)
        self.groups = kwargs.get('n_groups', 1)

    def forward(self, inputs, deletion_prob=None, shift_var=None):
        if len(inputs.shape) == 5:  # DVS data: (b, t, c, w, h) -> (t, b, c, w, h)
            outputs = inputs.permute(1, 0, 2, 3, 4).contiguous()
        elif len(inputs.shape) == 3:  # DAS data: (b, t, c) -> (t, b, c)
            outputs = inputs.permute(1, 0, 2).contiguous()
        else:
            if self.encode_type == 'auto' and isinstance(self.fun, nn.Module):
                # Keep a learned encoder module on the inputs' device. The
                # original read `self.fun.device` unconditionally, which raised
                # AttributeError because self.fun is a bound method by default.
                self.fun.to(inputs.device)
            outputs = self.fun(inputs)

        if deletion_prob:
            outputs = self.delete(outputs, deletion_prob)
        if shift_var:
            outputs = self.shift(outputs, shift_var)

        if self.temporal_flatten or self.no_encode:
            outputs = rearrange(outputs, 't b c w h -> 1 b (t c) w h')
        elif self.groups != 1:
            outputs = rearrange(outputs, 't b c w h -> b (c t) w h')
        elif self.layer_by_layer:
            if len(inputs.shape) == 3:
                outputs = rearrange(outputs, 't b c-> (t b) c')
            else:
                outputs = rearrange(outputs, 't b c w h -> (t b) c w h')

        return outputs

    @torch.no_grad()
    def direct(self, inputs):
        """
        Direct encoding: repeat the static input at every timestep.

        :param inputs: tensor of shape (b, c, w, h)
        :return: tensor of shape (t, b, c, w, h)
        """
        outputs = repeat(inputs, 'b c w h -> t b c w h', t=self.step)
        return outputs

    def auto(self, inputs):
        """
        Learned encoding via a trainable encoder module.

        NOTE(review): __init__'s getattr binds self.fun to this very method,
        so calling self.fun here recurses; a learned encoder (e.g.
        AutoEncoder) must be assigned to self.fun for this path to work.
        """
        outputs = self.fun(inputs)
        return outputs

    @torch.no_grad()
    def ttfs(self, inputs):
        """
        Time-to-First-Spike encoding: stronger inputs spike earlier, with
        amplitude 1 / (i + 1) at step i.

        :param inputs: static data, values assumed in [0, 1]
        :return: encoded data of shape (step, *inputs.shape)
        """
        shape = (self.step,) + inputs.shape
        # fixed: use the inputs' device; self.device does not exist on nn.Module
        outputs = torch.zeros(shape, device=inputs.device)
        for i in range(self.step):
            mask = (inputs * self.step <= (self.step - i)
                    ) & (inputs * self.step > (self.step - i - 1))
            outputs[i, mask] = 1 / (i + 1)
        return outputs

    @torch.no_grad()
    def rate(self, inputs):
        """
        Rate coding: at every step emit a Bernoulli spike with probability
        equal to the input intensity.
        """
        shape = (self.step,) + inputs.shape
        return (inputs > torch.rand(shape, device=inputs.device)).float()

    @torch.no_grad()
    def phase(self, inputs):
        """
        Phase coding: the 8-bit binary expansion of the intensity selects the
        spikes; weights halve each step and the pattern repeats after 8 steps.

        :param inputs: static data, values assumed in [0, 1]
        :return: encoded data of shape (step, *inputs.shape)
        """
        shape = (self.step,) + inputs.shape
        # fixed: use the inputs' device; self.device does not exist on nn.Module
        outputs = torch.zeros(shape, device=inputs.device)
        inputs = (inputs * 256).long()
        val = 1.
        for i in range(self.step):
            if i < 8:
                mask = (inputs >> (8 - i - 1)) & 1 != 0
                outputs[i, mask] = val
                val /= 2.
            else:
                outputs[i] = outputs[i % 8]
        return outputs

    @torch.no_grad()
    def delete(self, inputs, prob):
        """
        Randomly delete spikes from encoded data (noise augmentation).

        :param inputs: encoded data (modified in place)
        :param prob: deletion threshold
        :return: data with some spikes zeroed

        NOTE(review): torch.randn_like samples a standard normal, so `prob`
        is not a true deletion probability — confirm whether rand_like was
        intended.
        """
        # fixed: randn_like inherits the inputs' device; the explicit
        # device=self.device crashed because the attribute does not exist
        mask = (inputs >= 0) & (torch.randn_like(inputs) < prob)
        inputs[mask] = 0.
        return inputs

    @torch.no_grad()
    def shift(self, inputs, var):
        """
        Randomly shift spikes in time (noise augmentation).

        :param inputs: encoded data
        :param var: standard deviation of the temporal shift
        :return: shifted data
        """
        # TODO: Real-time shift
        outputs = torch.zeros_like(inputs)
        for step in range(self.step):
            shift = (var * torch.randn(1)).round_() + step
            shift.clamp_(min=0, max=self.step - 1)
            outputs[step] += inputs[int(shift)]
        return outputs
================================================
FILE: braincog/base/encoder/population_coding.py
================================================
import torch
import torch.nn as nn
import torchvision.utils
class PEncoder(nn.Module):
    """
    Population coding: each scalar feature is represented by the responses of
    ``m`` gaussian receptive-field neurons.

    :param step: time steps
    :param encode_type: encoder type (str), 'population_time' or 'population_voltage'
    """

    def __init__(self, step, encode_type):
        super().__init__()
        self.step = step
        # resolve the chosen encoding method once at construction time
        self.fun = getattr(self, encode_type)

    def forward(self, inputs, num_popneurons, *args, **kwargs):
        outputs = self.fun(inputs, num_popneurons, *args, **kwargs)
        return outputs

    @torch.no_grad()
    def population_time(self, inputs, m):
        """
        one feature will be encoded into gauss_neurons
        the center of i-th neuron is: gauss --
        .. math::
            \\mu u_i = I_min + (2i-3)/2(I_max-I_min)/(m -2)
        the width of i-th neuron is : gauss --
        .. math::
            \\sigma sigma_i = \\frac{1}{1.5}\\frac{(I_max-I_min)}{m - 2}
        :param inputs: (N_num, N_feature) array
        :param m: the number of the gaussian neurons
        i : the i_th gauss_neuron
        1.5: experience value
        popneurons_spike_t: gauss -- function
        I_min = min(inputs)
        I_max = max(inputs)
        :return: (step, num_gauss_neuron)
        """
        # m = self.step
        I_min, I_max = torch.min(inputs), torch.max(inputs)
        mu = [i for i in range(0, m)]
        # gaussian receptive-field centres; NOTE(review): i runs from 0 here
        # while the docstring formula is 1-based — confirm the off-by-one
        mu = torch.ones((1, m)) * I_min + ((2 * torch.tensor(mu) - 3) / 2) * ((I_max-I_min) / (m -2))
        sigma = (1 / 1.5) * ((I_max-I_min) / (m -2))
        # shape = (self.step,) + inputs.shape
        shape = (self.step,m)
        popneurons_spike_t = torch.zeros(((m,) + inputs.shape))
        for i in range(m):
            # gaussian response of neuron i to every input element
            popneurons_spike_t[i, :] = torch.exp(-(inputs - mu[0, i]) ** 2 / (2 * sigma * sigma))
        # convert response strength into an integer spike time in [0, step)
        spike_time = (self.step * popneurons_spike_t).type(torch.int)
        spikes = torch.zeros(shape)
        for spike_time_k in range(self.step):
            # a neuron fires at timestep k if any of its responses maps to k
            if torch.where(spike_time == spike_time_k)[1].numel() != 0:
                spikes[spike_time_k][torch.where(spike_time == spike_time_k)[0]] = 1
        return spikes

    @torch.no_grad()
    def population_voltage(self, inputs, m, VTH):
        r'''
        The more similar the input is to the mean,
        the more sensitive the neuron corresponding to the mean is to the input.
        You can change the mean.
        :param inputs: (N_num, N_feature) array
        :param m : the number of the gaussian neurons
        :param VTH : threshold voltage
        i : the i_th gauss_neuron
        one feature will be encoded into gauss_neurons
        the center of i-th neuron is: gauss -- \mu u_i = I_min + (2i-3)/2(I_max-I_min)/(m -2)
        the width of i-th neuron is : gauss -- \sigma sigma_i = 1/1.5(I_max-I_min)/(m -2)   1.5: experience value
        popneuron_v: gauss -- function
        I_min = min(inputs)
        I_max = max(inputs)
        :return: (step, num_gauss_neuron, dim_inputs)
        '''
        ENCODER_REGULAR_VTH = VTH
        I_min, I_max = torch.min(inputs), torch.max(inputs)
        mu = [i for i in range(0, m)]
        # same gaussian centres/widths as population_time (see note there)
        mu = torch.ones((1, m)) * I_min + ((2 * torch.tensor(mu) - 3) / 2) * ((I_max-I_min) / (m -2))
        sigma = (1 / 1.5) * ((I_max-I_min) / (m -2))
        popneuron_v = torch.zeros(((m,) + inputs.shape))
        delta_v = torch.zeros(((m,) + inputs.shape))
        for i in range(m):
            # per-step membrane increment: gaussian similarity of input to mu_i
            delta_v[i] = torch.exp(-(inputs - mu[0, i]) ** 2 / (2 * sigma * sigma))
        spikes = torch.zeros((self.step,) + ((m,) + inputs.shape))
        for spike_time_k in range(self.step):
            # integrate-and-fire with regular threshold and soft reset
            popneuron_v = popneuron_v + delta_v
            spikes[spike_time_k][torch.where(popneuron_v.ge(ENCODER_REGULAR_VTH))] = 1
            popneuron_v = popneuron_v - spikes[spike_time_k] * ENCODER_REGULAR_VTH
        popneuron_rate = torch.sum(spikes, dim=0)/self.step
        return spikes, popneuron_rate
## test
# if __name__ == '__main__':
# a = (torch.rand((2,4))*10).type(torch.int)
# print(a)
# pencoder = PEncoder(10, 'population_time')
# spikes=pencoder(inputs=a, num_popneurons=3)
# print(spikes, spikes.shape)
# pencoder = PEncoder(10, 'population_voltage')
# spikes, popneuron_rate = pencoder(inputs=a, num_popneurons=5, VTH=0.99)
# print(spikes, spikes.shape)
================================================
FILE: braincog/base/encoder/qs_coding.py
================================================
from signal import signal
from subprocess import call
import numpy as np
import random
import copy
class QSEncoder:
    """
    QS Encoding.

    :param lambda_max: maximum firing rate
    :param steps: length of the spike emission period T
    :param sig_len: length of the spike emission window
    :param shift: whether to invert the background
    :param noise: whether to add noise
    :param noise_rate: fraction of pixels flipped as noise
    :param eps: small constant guarding against overflow in divisions
    """

    def __init__(self,
                 lambda_max,
                 steps,
                 sig_len,
                 shift=False,
                 noise=None,
                 noise_rate=None,
                 eps=1e-6
                 ) -> None:
        self._lambda_max = lambda_max
        self._steps = steps
        self._sig_len = sig_len
        self._shift = shift
        self._noise = noise
        self._noise_rate = noise_rate
        self._eps = eps

    def __call__(self, image, image_delta, image_ori, image_ori_delta):
        """
        Convert an image into spike trains.

        :param image: background-inverted image (1-D pixel array)
        :param image_delta: perturbed image, used to compute the phase
        :param image_ori: original image
        :param image_ori_delta: original perturbed image
        :return: spike array of shape (num_pixels, steps)
        """
        if self._noise:
            signals = self.noise_trans(image, image_ori, image_ori_delta)
        elif self._shift:
            signals = self.shift_trans(image, image_delta, image_ori, image_ori_delta)
        else:
            # plain Poisson encoding inside the first sig_len steps
            # (fixed: `self.steps` -> `self._steps`)
            signals = np.zeros((self._steps, image.shape[0]))
            signal_possion = np.random.poisson(image, (self._sig_len, image.shape[0]))
            signals[:self._sig_len] = signal_possion[:]
        # fixed: previously returned `signal.T` — the function imported from
        # the `signal` module — instead of the computed `signals` array
        return signals.T

    def shift_trans(self, image, image_delta, image_ori, image_ori_delta):
        """
        Convert a background-inverted image into a (time-shifted) spike train.

        :param image: background-inverted image
        :param image_delta: perturbed image, used to compute the phase
        :param image_ori: original image
        :param image_ori_delta: original perturbed image
        :return: spike array of shape (steps, num_pixels)
        """
        signal = np.zeros((self._steps, image.shape[0]))
        assert image_ori is not None
        # fixed: `self.noise` did not exist (AttributeError); this branch is
        # only valid when noise mode is off
        assert not self._noise
        assert image_delta is not None
        assert image_ori_delta is not None
        image_ori_reverse = self._lambda_max - image_ori
        image_ori_delta_reverse = self._lambda_max - image_ori_delta
        # projection of the observed intensity onto the (ori, reversed) basis
        zeta = image / (image_ori**2 + image_ori_reverse**2) ** 0.5
        zeta_delta = image_delta / (image_ori_delta**2 + image_ori_delta_reverse**2)**0.5
        idx_left = zeta < zeta_delta
        phi = np.arctan(image_ori / (image_ori_reverse + self._eps))
        zeta = np.clip(zeta, -1, 1)
        zeta = np.arcsin(zeta)
        # two arcsin branches; pick per pixel depending on the perturbation sign
        theta1 = zeta - phi
        theta2 = np.pi - zeta - phi
        theta = np.zeros(theta1.shape)
        theta[idx_left] = theta1[idx_left]
        theta[~idx_left] = theta2[~idx_left]
        theta = np.mean(theta)
        cos_theta = np.cos(theta)
        sin_theta = np.sin(theta)
        spike_rate = np.abs((self._lambda_max * sin_theta - image) / (sin_theta - cos_theta + self._eps))
        signal_possion = np.random.poisson(spike_rate, (self._sig_len, spike_rate.shape[0]))
        # phase determines where the emission window starts within the period
        shift_step = np.rint(np.clip(2 * theta / np.pi, a_min=0, a_max=1.0) * (self._steps - self._sig_len))
        # fixed: np.int was removed in NumPy 1.24; use the builtin int
        shift_step = shift_step.astype(int)
        signal[shift_step:shift_step + self._sig_len] = signal_possion[:]
        # fixed: the original forgot to return, so __call__ crashed on None
        return signal

    def noise_trans(self, image, image_ori, image_ori_delta):
        """
        Convert a noisy image into a spike train.

        :param image: background-inverted image (unused except for noise mixing)
        :param image_ori: original image
        :param image_ori_delta: original perturbed image (recomputed locally)
        :return: spike array of shape (steps, num_pixels)
        """
        signal = np.zeros((self._steps, image.shape[0]))
        assert image_ori is not None
        assert self._shift is False
        assert self._noise_rate is not None
        # the passed-in delta image is ignored and rebuilt from the original
        image_ori_delta = copy.deepcopy(image_ori)
        idx = image_ori_delta < (self._lambda_max - 0.001)
        image_ori_delta[idx] += 0.001
        image_ori_reverse = self._lambda_max - image_ori
        image_ori_delta_reverse = self._lambda_max - image_ori_delta
        image_noise, image_delta_noise = self.reverse_pixels(image_ori, image_ori_delta, noise_rate=self._noise_rate)
        zeta = image_noise / (image_ori**2 + image_ori_reverse**2)**0.5
        zeta_delta = image_delta_noise / (image_ori_delta**2 + image_ori_delta_reverse**2)**0.5
        idx_left = zeta < zeta_delta
        phi = np.arctan(image_ori / (image_ori_reverse + self._eps))
        zeta = np.clip(zeta, -1, 1)
        zeta = np.arcsin(zeta)
        theta1 = zeta - phi
        theta2 = np.pi - zeta - phi
        theta = np.zeros(theta1.shape)
        theta[idx_left] = theta1[idx_left]
        theta[~idx_left] = theta2[~idx_left]
        theta = np.mean(theta)
        cos_theta = np.cos(theta)
        sin_theta = np.sin(theta)
        spike_rate = np.abs((self._lambda_max * sin_theta - image_noise) / (sin_theta - cos_theta + self._eps))
        signal_possion = np.random.poisson(spike_rate, (self._sig_len, spike_rate.shape[0]))
        shift_step = np.rint(np.clip(2 * theta / np.pi, a_min=0, a_max=1.0) * (self._steps - self._sig_len))
        # fixed: np.int was removed in NumPy 1.24; use the builtin int
        shift_step = shift_step.astype(int)
        signal[shift_step:shift_step + self._sig_len] = signal_possion[:]
        return signal

    def reverse_pixels(self, image, image_delta, noise_rate, flip_bits=None):
        """
        Invert a random subset of pixels (value v becomes lambda_max - v).

        :param flip_bits: explicit indices to flip; sampled from noise_rate if None
        :return: (flipped image, flipped delta image)
        """
        if flip_bits is None:
            N = int(noise_rate * image.shape[0])
            flip_bits = random.sample(range(image.shape[0]), N)
        img = copy.copy(image)
        img_delta = copy.copy(image_delta)
        img[flip_bits] = self._lambda_max - img[flip_bits]
        img_delta[flip_bits] = self._lambda_max - img_delta[flip_bits]
        return img, img_delta
================================================
FILE: braincog/base/learningrule/BCM.py
================================================
import numpy as np
import torch
import os
import sys
from torch import nn
from torch.nn import Parameter
import abc
import math
from abc import ABC
import numpy as np
import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
from braincog.base.node import *
class BCM(nn.Module):
    """
    BCM learning rule with multiple input groups feeding one neuron population.

    The weight update follows dw = phi(s) * x - lambda * w, where
    phi(s) = s * (s - theta) and theta is a sliding average of the
    postsynaptic activity (see `cfunc`). The update is obtained through an
    autograd trick: the postsynaptic current i is overwritten in place with
    phi(s) while its graph is kept, so grad(outputs=i, grad_outputs=i) yields
    outer(phi(s), input) for each connection.
    """

    def __init__(self, node, connection, cfunc=None, weightdecay=0.99, tau=10):
        """
        :param node: neuron instance, e.g. IFNode, LIFNode
        :param connection: connection instance or list of instances (each a single op)
        :param cfunc: BCM rate function, default y(y-th)
        :param weightdecay: weight decay coefficient, default 0.99
        :param tau: time constant of the sliding-threshold update
        """
        super().__init__()
        self.node = node
        self.connection = connection
        if not isinstance(connection, list):
            self.connection = [self.connection]
        # NOTE(review): the cfunc argument is accepted but never stored;
        # the method self.cfunc below (y * (y - theta)) is always used.
        self.weightdecay = weightdecay
        self.tau = tau
        self.threshold = 0  # sliding modification threshold theta

    def forward(self, *x):
        """
        Run one step.

        :return: (s, dw) — output spikes and per-connection weight updates
        """
        i = 0
        x = [xi.clone().detach() for xi in x]
        for xi, coni in zip(x, self.connection):
            i += coni(xi)
        with torch.no_grad():
            s = self.node(i)
            # replace the values of i by phi(s) while keeping the autograd graph
            i.data += self.cfunc(s) - i.data
        dw = torch.autograd.grad(outputs=i, inputs=[i.weight for i in self.connection], grad_outputs=i)
        for dwi, i in zip(dw, self.connection):
            # weight-decay term; note the in-place update of the grad tensors
            dwi -= (1 - self.weightdecay) * i.weight
        return s, dw

    def cfunc(self, s):
        """
        BCM modification function phi(s) = s * (s - theta); theta is an
        exponential running average of s with time constant tau.
        """
        self.threshold = ((self.tau - 1) * self.threshold + s) / self.tau
        return (s * (s - self.threshold)).detach()

    def reset(self):
        """
        Reset the sliding threshold.
        """
        self.threshold = 0
        pass
================================================
FILE: braincog/base/learningrule/Hebb.py
================================================
import numpy as np
import torch
import os
import sys
from torch import nn
from torch.nn import Parameter
import abc
import math
from abc import ABC
import numpy as np
import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
from braincog.base.node.node import *
class Hebb(nn.Module):
"""
Hebb learning rule 多组神经元输入到该节点
"""
def __init__(self, node, connection):
"""
:param node:node神经元类型实例如IFNode LIFNode
:param connection:连接 类的实例列表 里面只能有一个操作
"""
super().__init__()
self.node = node
self.connection = connection
self.trace = [None for i in self.connection]
def forward(self, *x):
"""
计算前向传播过程
:return:s是脉冲 dw更新量
"""
i = 0
x = [xi.clone().detach() for xi in x]
for xi, coni in zip(x, self.connection):
i += coni(xi)
with torch.no_grad():
s = self.node(i)
i.data += s - i.data
dw = torch.autograd.grad(outputs=i, inputs=[i.weight for i in self.connection], grad_outputs=i)
return s, dw
def reset(self):
"""
重置
"""
self.trace = [None for i in self.connection]
if __name__ == "__main__":
    # Smoke test: two all-ones 2x2 linear connections feeding an IF neuron;
    # prints the Hebbian weight update at each of 10 steps.
    # (IFNode comes from braincog.base.node.node via the wildcard import.)
    node = IFNode()
    linear1 = nn.Linear(2, 2, bias=False)
    linear2 = nn.Linear(2, 2, bias=False)
    linear1.weight.data = torch.tensor([[1., 1], [1, 1]], requires_grad=True)
    linear2.weight.data = torch.tensor([[1., 1], [1, 1]], requires_grad=True)
    hebb = Hebb(node, [linear1, linear2])
    for i in range(10):
        x, dw1 = hebb(torch.tensor([1.1, 1.1]), torch.tensor([1.1, 1.1]))
        print(dw1)
================================================
FILE: braincog/base/learningrule/RSTDP.py
================================================
import numpy as np
import torch
import os
import sys
from torch import nn
from torch.nn import Parameter
import abc
import math
from abc import ABC
import numpy as np
import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
from braincog.base.node import *
class RSTDP(nn.Module):
    """
    Reward-modulated STDP (R-STDP): scales the raw STDP update by a trace
    accumulated from the reward signal r.
    """

    def __init__(self, node, connection, decay=0.99, reward_decay=0.5):
        """
        :param node: neuron instance, e.g. IFNode, LIFNode
        :param connection: connection instance or list of instances (each a single op)
        :param decay: multiplicative decay of the traces
        :param reward_decay: reward decay factor
        """
        super().__init__()
        self.node = node
        self.connection = connection
        if not isinstance(connection, list):
            self.connection = [self.connection]
        self.trace = [None for i in self.connection]
        self.decay = decay
        # NOTE(review): reward_decay is stored but never used — confirm intent.
        self.reward_decay = reward_decay
        # NOTE(review): STDP's second argument should be a connection, but the
        # node is passed here; likely intended: STDP(self.node,
        # self.connection[0], self.decay).
        self.stdp = STDP(self.node, self.node, self.decay)

    def forward(self, *x, r):
        """
        Run one step.

        :param x: input spike tensors
        :param r: reward signal
        :return: (s, dw * trace) — spikes and reward-modulated update
        """
        # NOTE(review): STDP.forward expects a single tensor but receives the
        # tuple `x` here, and `dw * trace` multiplies a tuple by a list — this
        # path looks broken as written; verify against callers before relying
        # on it.
        s, dw = self.stdp(x)
        trace = self.cal_trace(r)
        return s, dw * trace

    def cal_trace(self, x):
        """
        Accumulate the decaying reward trace: trace = trace * decay + x.
        """
        for i in range(len(x)):
            if self.trace[i] is None:
                self.trace[i] = Parameter(x[i].clone().detach(), requires_grad=False)
            else:
                self.trace[i] *= self.decay
                self.trace[i] += x[i].detach()
        return self.trace

    def reset(self):
        # clear all reward traces
        self.trace = [None for i in self.connection]
================================================
FILE: braincog/base/learningrule/STDP.py
================================================
import numpy as np
import torch
import os
import sys
from torch import nn
from torch.nn import Parameter
import abc
import math
from abc import ABC
import numpy as np
import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
from braincog.base.node.node import *
class STDP(nn.Module):
    """
    STDP learning rule (single connection).

    Uses an autograd trick: after the forward pass, the postsynaptic current i
    is overwritten in place with the output spikes and the input x with its
    presynaptic trace, so torch.autograd.grad(outputs=i, grad_outputs=i)
    yields the STDP update outer(post_spike, pre_trace).
    """

    def __init__(self, node, connection, decay=0.99):
        """
        :param node: neuron instance, e.g. IFNode, LIFNode
        :param connection: connection instance (must contain a single operation)
        :param decay: multiplicative decay of the presynaptic trace
        """
        super().__init__()
        self.node = node
        self.connection = connection
        self.trace = None
        self.decay = decay

    def forward(self, x):
        """
        Run one step.

        :return: (s, dw) — output spikes and the weight update
        """
        x = x.clone().detach()
        i = self.connection(x)
        with torch.no_grad():
            s = self.node(i)
            # replace the values of i by the spikes, keeping the autograd graph
            i.data += s - i.data
            trace = self.cal_trace(x)
            # replace the input values by the presynaptic trace
            x.data += trace - x.data
        dw = torch.autograd.grad(outputs=i, inputs=self.connection.weight, grad_outputs=i)
        return s, dw

    def cal_trace(self, x):
        """
        Update and return the presynaptic trace: trace = trace * decay + x.
        """
        if self.trace is None:
            self.trace = Parameter(x.clone().detach(), requires_grad=False)
        else:
            self.trace *= self.decay
            self.trace += x
        return self.trace.detach()

    def reset(self):
        """
        Reset the trace.
        """
        self.trace = None
class MutliInputSTDP(nn.Module):
    """
    STDP learning rule for multiple input groups feeding one neuron population.

    Same autograd trick as STDP: i is overwritten with the spikes and each
    input with its presynaptic trace, so the gradient yields
    outer(post_spike, pre_trace) per connection.
    """

    def __init__(self, node, connection, decay=0.99):
        """
        :param node: neuron instance, e.g. IFNode, LIFNode
        :param connection: list of connection instances (each a single operation)
        :param decay: multiplicative decay of the presynaptic traces
        """
        super().__init__()
        self.node = node
        self.connection = connection
        self.trace = [None for i in self.connection]
        self.decay = decay

    def forward(self, *x):
        """
        Run one step.

        :return: (s, dw) — output spikes and per-connection weight updates
        """
        i = 0
        x = [xi.clone().detach() for xi in x]
        for xi, coni in zip(x, self.connection):
            i += coni(xi)
        with torch.no_grad():
            s = self.node(i)
            # replace the values of i by the spikes, keeping the autograd graph
            i.data += s - i.data
            trace = self.cal_trace(x)
            # replace each input's values by its presynaptic trace
            for xi, ti in zip(x, trace):
                xi.data += ti - xi.data
        dw = torch.autograd.grad(outputs=i, inputs=[i.weight for i in self.connection], grad_outputs=i)
        return s, dw

    def cal_trace(self, x):
        """
        Update and return the per-connection traces: trace = trace * decay + x.
        """
        for i in range(len(x)):
            if self.trace[i] is None:
                self.trace[i] = Parameter(x[i].clone().detach(), requires_grad=False)
            else:
                self.trace[i] *= self.decay
                self.trace[i] += x[i].detach()
        return self.trace

    def reset(self):
        """
        Reset all traces.
        """
        self.trace = [None for i in self.connection]
class LTP(MutliInputSTDP):
    """
    Long-term potentiation: a named alias of MutliInputSTDP, pairing the
    current output spike with the presynaptic input traces.
    """
    pass
class LTD(nn.Module):
    """
    Long-term depression term of STDP for multiple input groups.

    The postsynaptic current i is overwritten with the *previous* output-spike
    trace (see cal_trace), so the gradient pairs old postsynaptic activity
    with the current input: dw = outer(old_output_trace, x).
    """

    def __init__(self, node, connection, decay=0.99):
        """
        :param node: neuron instance, e.g. IFNode, LIFNode
        :param connection: list of connection instances (each a single operation)
        :param decay: multiplicative decay of the output trace
        """
        super().__init__()
        self.node = node
        self.connection = connection
        self.trace = None
        self.decay = decay

    def forward(self, *x):
        """
        Run one step.

        :return: (s, dw) — output spikes and per-connection weight updates
        """
        i = 0
        x = [xi.clone().detach() for xi in x]
        for xi, coni in zip(x, self.connection):
            i += coni(xi)
        with torch.no_grad():
            s = self.node(i)
            # trace returned here is the value *before* the current spike is added
            trace = self.cal_trace(s)
            i.data += trace - i.data
        dw = torch.autograd.grad(outputs=i, inputs=[i.weight for i in self.connection], grad_outputs=i)
        return s, dw

    def cal_trace(self, x):
        """
        Decay the output trace, return its previous value, then add x.
        """
        if self.trace is None:
            self.trace = Parameter(torch.zeros_like(x), requires_grad=False)
        else:
            self.trace *= self.decay
        trace = self.trace.clone().detach()
        self.trace += x
        return trace

    def reset(self):
        """
        Reset the trace.
        """
        self.trace = None
class FullSTDP(nn.Module):
    """
    Full STDP rule with both potentiation and depression terms, multiple inputs.

    Per step it returns (s, dw2, dw1):
      - dw2 pairs the current output spike with the presynaptic input traces
        (LTP term, outer(s, in_trace)),
      - dw1 pairs the previous output-spike trace with the raw input
        (LTD term, outer(old_out_trace, x)).
    """

    def __init__(self, node, connection, decay=0.99, decay2=0.99):
        """
        :param node: neuron instance, e.g. IFNode, LIFNode
        :param connection: list of connection instances (each a single operation)
        :param decay: decay of the presynaptic (input) traces
        :param decay2: decay of the postsynaptic (output) trace
        """
        super().__init__()
        self.node = node
        self.connection = connection
        self.tracein = [None for i in self.connection]
        self.traceout = None
        self.decay = decay
        self.decay2 = decay2

    def forward(self, *x):
        """
        Run one step.

        :return: (s, dw2, dw1) — spikes, LTP updates and LTD updates
        """
        i = 0
        x = [xi.clone().detach() for xi in x]
        for xi, coni in zip(x, self.connection):
            i += coni(xi)
        with torch.no_grad():
            s = self.node(i)
            # i now holds the *previous* output trace -> dw1 = outer(old_trace, x)
            traceout = self.cal_traceout(s)
            i.data += traceout - i.data
        dw1 = torch.autograd.grad(outputs=i, inputs=[i.weight for i in self.connection], retain_graph=True,
                                  grad_outputs=i)
        with torch.no_grad():
            # now i holds the current spikes and x the input traces
            # -> dw2 = outer(s, in_trace)
            i.data += s - i.data
            tracein = self.cal_tracein(x)
            for xi, ti in zip(x, tracein):
                xi.data += ti - xi.data
        dw2 = torch.autograd.grad(outputs=i, inputs=[i.weight for i in self.connection], grad_outputs=i)
        return s, dw2, dw1

    def cal_tracein(self, x):
        """
        Update and return the presynaptic traces: trace = trace * decay + x.
        """
        for i in range(len(x)):
            if self.tracein[i] is None:
                self.tracein[i] = Parameter(x[i].clone().detach(), requires_grad=False)
            else:
                self.tracein[i] *= self.decay
                self.tracein[i] += x[i].detach()
        return self.tracein

    def cal_traceout(self, x):
        """
        Decay the output trace, return its previous value, then add x.
        """
        if self.traceout is None:
            self.traceout = Parameter(torch.zeros_like(x), requires_grad=False)
        else:
            self.traceout *= self.decay2
        trace = self.traceout.clone().detach()
        self.traceout += x
        return trace

    def reset(self):
        """
        Reset all traces.

        Fixed: the original swapped the two attributes (list into traceout,
        None into tracein), leaving the state inconsistent with __init__ and
        breaking the next forward pass after a reset.
        """
        self.tracein = [None for i in self.connection]
        self.traceout = None
if __name__ == "__main__":
    # Smoke test: despite the variable name `stdp`, this exercises the LTD
    # rule with two all-ones 2x2 linear connections feeding an IF neuron and
    # prints the weight update at each of 10 steps.
    # (IFNode comes from braincog.base.node.node via the wildcard import.)
    node = IFNode()
    linear1 = nn.Linear(2, 2, bias=False)
    linear2 = nn.Linear(2, 2, bias=False)
    linear1.weight.data = torch.tensor([[1., 1], [1, 1]], requires_grad=True)
    linear2.weight.data = torch.tensor([[1., 1], [1, 1]], requires_grad=True)
    stdp = LTD(node, [linear1, linear2])
    for i in range(10):
        x, dw1 = stdp(torch.tensor([1.1, 1.1]), torch.tensor([1.1, 1.1]))
        print(dw1)
================================================
FILE: braincog/base/learningrule/STP.py
================================================
import math
class short_time():
    """
    Short-term synaptic plasticity (STP) variables, see Tsodyks & Markram 1997.

    :param Syn: synaptic plasticity structure (carries uprev/Rprev histories
        and the time constants tc_fac / tc_rec and utilisation `use`)
    :param ISI: inter-spike interval
    :param Nsp: presynaptic spike index
    """

    def __init__(self, SizeHistOutput):
        super().__init__()
        # length of the circular spike-history buffers (uprev / Rprev)
        self.SizeHistOutput = SizeHistOutput

    def syndepr(self, Syn=None, ISI=None, Nsp=None):
        """
        One short-term depression/facilitation update; returns the effective
        release R * u for the current spike.
        """
        SizeHistOutput = self.SizeHistOutput
        # facilitation variable decays toward baseline with tc_fac over the ISI
        qu = Syn.uprev[Nsp] * math.exp(-ISI / Syn.tc_fac)
        # recovery of synaptic resources with time constant tc_rec
        qR = math.exp(-ISI / Syn.tc_rec)
        u = qu + Syn.use * (1.0 - qu)
        R = Syn.Rprev[Nsp] * (1.0 - Syn.uprev[Nsp]) * qR + 1.0 - qR
        # store for the next spike in the circular history buffer
        Syn.uprev[(Nsp + 1) % SizeHistOutput] = u
        Syn.Rprev[(Nsp + 1) % SizeHistOutput] = R
        return R * u
def set_gsyn(self, np=None, dt=None, v=None, NoiseSyn=None):
"""
突触电流参数计算
"""
Isyn = 0
gsyn_AN = 0
gsyn_G = 0
for j in range(np.NumSynType):
syn = np.STList[j]
sgate = 1.0
if (syn.Mg_gate > 0.0):
sgate = syn.Mg_gate / (1.0 + syn.Mg_fac * math.exp(syn.Mg_slope * (syn.Mg_half - v[0])))
Isyn += sgate * (
np.gfOFFsyn[j] * math.exp(-dt / syn.tc_off) - np.gfONsyn[j] * math.exp(-dt / syn.tc_on)) * (
syn.Erev - v[0])
if (syn.Erev == 0.0):
gsyn_AN = gsyn_AN + sgate * (
np.gfOFFsyn[j] * math.exp(-dt / syn.tc_off) - np.gfONsyn[j] * math.exp(-dt / syn.tc_on))
else:
gsyn_G = gsyn_G + sgate * (
np.gfOFFsyn[j] * math.exp(-dt / syn.tc_off) - np.gfONsyn[j] * math.exp(-dt / syn.tc_on))
for j in range(NoiseSyn.NumSyn):
syn = NoiseSyn.Syn[j].STPtr
sgate = 1.0
if (syn.Mg_gate > 0.0):
sgate = syn.Mg_gate / (1.0 + syn.Mg_fac * math.exp(syn.Mg_slope * (syn.Mg_half - v)))
Isyn += sgate * (
np.gfOFFnoise[j] * math.exp(-dt / syn.tc_off) - np.gfONnoise[j] * math.exp(-dt / syn.tc_on)) * (
syn.Erev - v)
if (syn.Erev == 0.0):
gsyn_AN = gsyn_AN + sgate * (
np.gfOFFnoise[j] * math.exp(-dt / syn.tc_off) - np.gfONnoise[j] * math.exp(-dt / syn.tc_on))
else:
gsyn_G = gsyn_G + sgate * (
np.gfOFFnoise[j] * math.exp(-dt / syn.tc_off) - np.gfONnoise[j] * math.exp(-dt / syn.tc_on))
I_tot = Isyn + np.Iinj
return gsyn_AN, I_tot, gsyn_G
def IDderiv(self, np=None, v=None, dt=None, dv=None, NoiseSyn=None, flag_dv=None):
"""
定义模型的常微分方程计算单个神经元常微分方程
:param np:神经元参数
:param v:当前变量
:param dt:时间步长
"""
Isyn = 0
gsyn_G = 0
gsyn_AN = 0
for j in range(np.NumSynType):
syn = np.STList[j]
sgate = 1.0
if (syn.Mg_gate > 0.0):
sgate = syn.Mg_gate / (1.0 + syn.Mg_fac * math.exp(syn.Mg_slope * (syn.Mg_half - v[0])))
Isyn += sgate * (
np.gfOFFsyn[j] * math.exp(-dt / syn.tc_off) - np.gfONsyn[j] * math.exp(-dt / syn.tc_on)) * (
syn.Erev - v[0])
if (syn.Erev == 0.0):
gsyn_AN = gsyn_AN + sgate * (
np.gfOFFsyn[j] * math.exp(-dt / syn.tc_off) - np.gfONsyn[j] * math.exp(-dt / syn.tc_on))
else:
gsyn_G = gsyn_G + sgate * (
np.gfOFFsyn[j] * math.exp(-dt / syn.tc_off) - np.gfONsyn[j] * math.exp(-dt / syn.tc_on))
for j in range(NoiseSyn.NumSyn):
syn = NoiseSyn.Syn[j].STPtr
sgate = 1.0
if (syn.Mg_gate > 0.0):
sgate = syn.Mg_gate / (1.0 + syn.Mg_fac * math.exp(syn.Mg_slope * (syn.Mg_half - v[0])))
Isyn += sgate * (
np.gfOFFnoise[j] * math.exp(-dt / syn.tc_off) - np.gfONnoise[j] * math.exp(-dt / syn.tc_on)) * (
syn.Erev - v[0])
if (syn.Erev == 0.0):
gsyn_AN = gsyn_AN + sgate * (
np.gfOFFnoise[j] * math.exp(-dt / syn.tc_off) - np.gfONnoise[j] * math.exp(-dt / syn.tc_on))
else:
gsyn_G = gsyn_G + sgate * (
np.gfOFFnoise[j] * math.exp(-dt / syn.tc_off) - np.gfONnoise[j] * math.exp(-dt / syn.tc_on))
I_ex = np.gL * np.sf * math.exp((v[0] - np.Vth) / np.sf)
wV = np.Iinj + Isyn - np.gL * (v[0] - np.EL) + I_ex
D0 = (np.Cm / np.gL) * wV
if ((
np.Iinj + Isyn) >= np.I_ref and flag_dv == 0):
dv[0] = -(np.gL / np.Cm) * (v[0] - np.v_dep)
flag_regime_osc = 0
else:
dv[0] = (np.Iinj - np.gL * (v[0] - np.EL) - v[1] + I_ex + Isyn) / np.Cm
flag_regime_osc = 1
dD0 = np.Cm * (math.exp((v[0] - np.Vth) / np.sf) - 1)
if ((v[1] > wV - D0 / np.tcw) and (v[1] < wV + D0 / np.tcw) and v[0] <= np.Vth and (
np.Iinj + Isyn) < np.I_ref):
dv[1] = -(np.gL * (1 - math.exp((v[0] - np.Vth) / np.sf)) + dD0 / np.tcw) * dv[0]
else:
dv[1] = 0
I_tot = Isyn + np.Iinj
return wV, D0, gsyn_AN, gsyn_G, I_tot, dv
def update(self, np=None, dt=None, NoiseSyn=None, flag_dv=None):
"""
用二阶显式龙格-库塔法积分常微分方程
:param np:神经元参数
:param dt:时间步长
"""
nvar = 2
v = [0] * 2
dv1 = [0] * 2
dv2 = [0] * 2
for i in range(nvar):
v[i] = np.v[i]
wV, D0, gsyn_AN, gsyn_G, I_tot, dv1 = short_time(self.SizeHistOutput).IDderiv(np, v, 0.0, dv1, NoiseSyn, flag_dv)
for i in range(nvar):
v[i] += dt * dv1[i]
wV, D0, gsyn_AN, gsyn_G, I_tot, dv2 = short_time(self.SizeHistOutput).IDderiv(np, v, 0.0, dv2, NoiseSyn, flag_dv)
for i in range(nvar):
np.v[i] += dt / 2.0 * (dv1[i] + dv2[i])
np.dv[i] = dt / 2.0 * (dv1[i] + dv2[i])
if ((np.v[1] > wV - D0 / np.tcw) and (np.v[1] < wV + D0 / np.tcw) and np.v[0] <= np.Vth):
np.v[1] = wV - (D0 / np.tcw)
return np, gsyn_AN, gsyn_G, I_tot
================================================
FILE: braincog/base/learningrule/__init__.py
================================================
from .BCM import BCM
from .Hebb import Hebb
from .RSTDP import RSTDP
from .STDP import STDP, MutliInputSTDP, LTP, LTD, FullSTDP
from .STP import short_time
__all__ = [
'BCM',
"Hebb",
'RSTDP',
'STDP', 'MutliInputSTDP', 'LTP', 'LTD', 'FullSTDP',
'short_time'
]
================================================
FILE: braincog/base/node/__init__.py
================================================
from .node import *
================================================
FILE: braincog/base/node/node.py
================================================
# encoding: utf-8
# Author : Floyed
# Datetime : 2022/4/10 18:46
# User : Floyed
# Product : PyCharm
# Project : braincog
# File : node.py
# explain : neuron node types
import abc
import math
from abc import ABC
import numpy as np
import random
import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
from einops import rearrange, repeat
from braincog.base.connection.layer import CustomLinear
from braincog.base.strategy.surrogate import *
class BaseNode(nn.Module, abc.ABC):
    """
    Abstract base class for spiking neuron models.

    :param threshold: membrane-potential threshold for emitting a spike
    :param v_reset: resting potential
    :param dt: simulation time step
    :param step: number of simulation steps
    :param requires_thres_grad: whether the threshold needs gradients, default ``False``
    :param sigmoid_thres: whether to squash the threshold into [0, 1] with a sigmoid, default ``False``
    :param requires_fp: whether to record feature maps (spikes) during inference; costs extra memory/time, default ``False``
    :param layer_by_layer: whether to compute all steps at once; usually faster for large models, default ``False``
    :param n_groups: number of weight groups over time steps, default ``1`` (no grouping)
    :param mem_detach: (kwargs) whether to detach the previous membrane potential from the graph
    :param args: extra positional arguments
    :param kwargs: extra keyword arguments
    """
    def __init__(self,
                 threshold=.5,
                 v_reset=0.,
                 dt=1.,
                 step=8,
                 requires_thres_grad=False,
                 sigmoid_thres=False,
                 requires_fp=False,
                 layer_by_layer=False,
                 n_groups=1,
                 *args,
                 **kwargs):
        super(BaseNode, self).__init__()
        self.threshold = Parameter(torch.tensor(threshold), requires_grad=requires_thres_grad)
        self.sigmoid_thres = sigmoid_thres
        self.mem = 0.      # membrane potential (scalar until first tensor input)
        self.spike = 0.    # last emitted spike
        self.dt = dt
        self.feature_map = []   # per-step spikes when requires_fp is True
        self.mem_collect = []   # per-step membrane potentials when requires_mem is True
        self.requires_fp = requires_fp
        self.v_reset = v_reset
        self.step = step
        self.layer_by_layer = layer_by_layer
        self.groups = n_groups
        self.mem_detach = kwargs['mem_detach'] if 'mem_detach' in kwargs else False
        self.requires_mem = kwargs['requires_mem'] if 'requires_mem' in kwargs else False

    @abc.abstractmethod
    def calc_spike(self):
        """
        Decide from the current ``mem`` whether to fire, and reset accordingly.

        :return: None
        """
        pass

    def integral(self, inputs):
        """
        Accumulate the effect of the current inputs on the membrane potential.

        :param inputs: current synaptic input current
        :type inputs: torch.tensor
        :return: None
        """
        pass

    def get_thres(self):
        # Optionally sigmoid-squash the threshold into [0, 1].
        return self.threshold if not self.sigmoid_thres else self.threshold.sigmoid()

    def rearrange2node(self, inputs):
        # Reshape a batched input into (step, batch, ...) for stepwise processing.
        if self.groups != 1:
            if len(inputs.shape) == 4:
                outputs = rearrange(inputs, 'b (c t) w h -> t b c w h', t=self.step)
            elif len(inputs.shape) == 2:
                outputs = rearrange(inputs, 'b (c t) -> t b c', t=self.step)
            else:
                raise NotImplementedError
        elif self.layer_by_layer:
            if len(inputs.shape) == 4:
                outputs = rearrange(inputs, '(t b) c w h -> t b c w h', t=self.step)
            elif len(inputs.shape) == 3:
                outputs = rearrange(inputs, '(t b) n c -> t b n c', t=self.step)
            elif len(inputs.shape) == 2:
                outputs = rearrange(inputs, '(t b) c -> t b c', t=self.step)
            else:
                raise NotImplementedError
        else:
            outputs = inputs
        return outputs

    def rearrange2op(self, inputs):
        # Inverse of rearrange2node: fold the step axis back into the batch/channel axis.
        if self.groups != 1:
            if len(inputs.shape) == 5:
                outputs = rearrange(inputs, 't b c w h -> b (c t) w h')
            elif len(inputs.shape) == 3:
                outputs = rearrange(inputs, ' t b c -> b (c t)')
            else:
                raise NotImplementedError
        elif self.layer_by_layer:
            if len(inputs.shape) == 5:
                outputs = rearrange(inputs, 't b c w h -> (t b) c w h')
            elif len(inputs.shape) == 4:
                outputs = rearrange(inputs, ' t b n c -> (t b) n c')
            elif len(inputs.shape) == 3:
                outputs = rearrange(inputs, ' t b c -> (t b) c')
            else:
                raise NotImplementedError
        else:
            outputs = inputs
        return outputs

    def forward(self, inputs):
        """
        Default ``torch.nn.Module`` entry point: integrates the input current
        and computes the output spikes. When ``self.requires_fp is True``,
        ``self.feature_map`` records the spike trace at every step.

        Three paths: a "parallel" path (subclass sets ``self.parallel``) that
        handles all steps in one call, a layer-by-layer/grouped path that loops
        over ``self.step``, and a plain single-step path.

        :param inputs: input membrane current
        :return: output spikes
        """
        if hasattr(self, 'parallel') and self.parallel is True:
            inputs = self.rearrange2node(inputs)
            if self.mem_detach and hasattr(self.mem, 'detach'):
                self.mem = self.mem.detach()
                self.spike = self.spike.detach()
            self.integral(inputs)
            self.calc_spike()
            if self.requires_fp is True:
                self.feature_map.append(self.spike)
            if self.requires_mem is True:
                self.mem_collect.append(self.mem)
            return self.rearrange2op(self.spike)
        elif self.layer_by_layer or self.groups != 1:
            inputs = self.rearrange2node(inputs)
            outputs = []
            for i in range(self.step):
                if self.mem_detach and hasattr(self.mem, 'detach'):
                    self.mem = self.mem.detach()
                    self.spike = self.spike.detach()
                self.integral(inputs[i])
                self.calc_spike()
                if self.requires_fp is True:
                    self.feature_map.append(self.spike)
                if self.requires_mem is True:
                    self.mem_collect.append(self.mem)
                outputs.append(self.spike)
            outputs = torch.stack(outputs)
            outputs = self.rearrange2op(outputs)
            return outputs
        else:
            if self.mem_detach and hasattr(self.mem, 'detach'):
                self.mem = self.mem.detach()
                self.spike = self.spike.detach()
            self.integral(inputs)
            self.calc_spike()
            if self.requires_fp is True:
                self.feature_map.append(self.spike)
            if self.requires_mem is True:
                self.mem_collect.append(self.mem)
            return self.spike

    def n_reset(self):
        """
        Reset all neuron state; call between two unrelated inputs.

        :return: None
        """
        self.mem = self.v_reset
        self.spike = 0.
        self.feature_map = []
        self.mem_collect = []

    def get_n_attr(self, attr):
        # Return the named attribute, or None when absent.
        if hasattr(self, attr):
            return getattr(self, attr)
        else:
            return None

    def set_n_warm_up(self, flag):
        """
        Some training schedules treat the neuron as an ANN activation during
        the first epochs; this toggles that behaviour.

        :param flag: True: act as an activation function, False: normal spiking
        :return: None
        """
        self.warm_up = flag

    def set_n_threshold(self, thresh):
        """
        Dynamically set the neuron's threshold.

        :param thresh: threshold value
        :return:
        """
        self.threshold = Parameter(torch.tensor(thresh, dtype=torch.float), requires_grad=False)

    def set_n_tau(self, tau):
        """
        Dynamically set the decay constant; only valid for leaky neurons.

        :param tau: decay constant
        :return:
        """
        if hasattr(self, 'tau'):
            self.tau = Parameter(torch.tensor(tau, dtype=torch.float), requires_grad=False)
        else:
            raise NotImplementedError
#============================================================================
# Base class for multi-compartment nodes
class BaseMCNode(nn.Module, abc.ABC):
    """
    Base class for multi-compartment neuron models.

    :param threshold: firing threshold
    :param v_reset: resting potential
    :param comps: names of the compartments, e.g. ``["apical", "basal", "soma"]``
    """
    def __init__(self,
                 threshold=1.0,
                 v_reset=0.,
                 comps=[]):
        super().__init__()
        self.threshold = Parameter(torch.tensor(threshold), requires_grad=False)
        self.v_reset = v_reset
        assert len(comps) != 0
        # One membrane potential per compartment, lazily initialised.
        self.mems = dict()
        for c in comps:
            self.mems[c] = None
        self.spike = None
        self.warm_up = False

    @abc.abstractmethod
    def calc_spike(self):
        pass

    @abc.abstractmethod
    def integral(self, inputs):
        pass

    def forward(self, inputs: dict):
        '''
        Params:
            inputs dict: inputs for every compartment of the neuron
        '''
        if self.warm_up:
            # Warm-up mode: pass the inputs straight through.
            return inputs
        else:
            self.integral(**inputs)
            self.calc_spike()
            return self.spike

    def n_reset(self):
        """Reset every compartment potential to the resting value."""
        for c in self.mems.keys():
            self.mems[c] = self.v_reset
        self.spike = 0.0

    def get_n_fire_rate(self):
        """Return the fraction of elements whose spike reaches the threshold."""
        if self.spike is None:
            return 0.
        # Bug fix: ``np.product`` is deprecated and removed in NumPy 2.0;
        # ``np.prod`` is the supported equivalent.
        return float((self.spike.detach() >= self.threshold).sum()) / float(np.prod(self.spike.shape))

    def set_n_warm_up(self, flag):
        self.warm_up = flag

    def set_n_threshold(self, thresh):
        self.threshold = Parameter(torch.tensor(thresh, dtype=torch.float), requires_grad=False)
class ThreeCompNode(BaseMCNode):
    """
    Three-compartment neuron model (basal dendrite, apical dendrite, soma).

    :param threshold: firing threshold of the soma
    :param tau: somatic membrane time constant (controls somatic decay)
    :param tau_basal: basal-dendrite membrane time constant
    :param tau_apical: apical-dendrite membrane time constant
    :param v_reset: resting potential
    :param comps: compartment names, e.g. ``['basal', 'apical', 'soma']``
    :param act_fun: surrogate-gradient spike function
    """
    def __init__(self,
                 threshold=1.0,
                 tau=2.0,
                 tau_basal=2.0,
                 tau_apical=2.0,
                 v_reset=0.0,
                 comps=['basal', 'apical', 'soma'],
                 act_fun=AtanGrad):
        g_B = 0.6   # NOTE(review): unused constants kept from the original
        g_L = 0.05
        super().__init__(threshold, v_reset, comps)
        self.tau = tau
        self.tau_basal = tau_basal
        self.tau_apical = tau_apical
        self.act_fun = act_fun(alpha=tau, requires_grad=False)

    def integral(self, basal_inputs, apical_inputs):
        '''
        Params:
            basal_inputs torch.Tensor: input for the basal dendrite
            apical_inputs torch.Tensor: input for the apical dendrite
        '''
        # Leaky integration in each dendrite, then couple both into the soma.
        mems = self.mems
        mems['basal'] = (mems['basal'] + basal_inputs) / self.tau_basal
        mems['apical'] = (mems['apical'] + apical_inputs) / self.tau_apical
        mems['soma'] = mems['soma'] + (mems['apical'] + mems['basal'] - mems['soma']) / self.tau

    def calc_spike(self):
        # Fire on the somatic potential and soft-reset every compartment.
        self.spike = self.act_fun(self.mems['soma'] - self.threshold)
        keep = 1. - self.spike.detach()
        for comp in ('soma', 'basal', 'apical'):
            self.mems[comp] = self.mems[comp] * keep
#============================================================================
# Static (ANN-style) test nodes: no membrane accumulation
class ReLUNode(BaseNode):
    """
    ReLU 'neuron' used to test identically-wired ANNs: no temporal dynamics,
    the output is simply ``relu(x)``.
    """
    def __init__(self,
                 *args,
                 **kwargs):
        super().__init__(requires_fp=False, *args, **kwargs)
        self.act_fun = nn.ReLU()

    def _record(self):
        # Optional book-keeping of spikes / membrane potentials.
        if self.requires_fp is True:
            self.feature_map.append(self.spike)
        if self.requires_mem is True:
            self.mem_collect.append(self.mem)

    def forward(self, x):
        """
        See ``BaseNode``; applies ReLU directly without integration.

        :param x: input tensor
        :return: relu(x)
        """
        self.spike = self.act_fun(x)
        self._record()
        return self.spike

    def calc_spike(self):
        pass
class BiasReLUNode(BaseNode):
    """
    ReLU 'neuron' for ANN tests with identical wiring; a constant current of
    0.1 is injected at every step so the unit activates more easily.
    """
    def __init__(self,
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.act_fun = nn.ReLU()

    def forward(self, x):
        self.spike = self.act_fun(x + 0.1)
        if self.requires_fp is True:
            # NOTE(review): `+=` extends the list element-wise with the tensor,
            # unlike the `.append` used elsewhere — preserved from the original.
            self.feature_map += self.spike
        return self.spike

    def calc_spike(self):
        pass
# ============================================================================
# Nodes for SNNs
class IFNode(BaseNode):
    """
    Integrate-and-Fire neuron.

    :param threshold: firing threshold
    :param act_fun: surrogate-gradient function, default ``AtanGrad``
    Remaining parameters are forwarded to ``BaseNode``.
    """
    def __init__(self, threshold=.5, act_fun=AtanGrad, *args, **kwargs):
        super().__init__(threshold, *args, **kwargs)
        # act_fun may arrive as a string (e.g. from a config); resolve it.
        # NOTE(review): eval on a config-supplied string — trusted input only.
        if isinstance(act_fun, str):
            act_fun = eval(act_fun)
        self.act_fun = act_fun(alpha=2., requires_grad=False)

    def integral(self, inputs):
        # Pure accumulation: no leak term.
        self.mem = self.mem + inputs * self.dt

    def calc_spike(self):
        shifted = self.mem - self.get_thres()
        self.spike = self.act_fun(shifted)
        # Soft reset: zero the membrane wherever a spike fired.
        self.mem = self.mem * (1 - self.spike.detach())
class LIFNode(BaseNode):
    """
    Leaky Integrate-and-Fire neuron.

    :param threshold: firing threshold
    :param tau: membrane time constant controlling the leak
    :param act_fun: surrogate-gradient function, default ``QGateGrad``
    Remaining parameters are forwarded to ``BaseNode``.
    """
    def __init__(self, threshold=0.5, tau=2., act_fun=QGateGrad, *args, **kwargs):
        super().__init__(threshold, *args, **kwargs)
        self.tau = tau
        if isinstance(act_fun, str):
            act_fun = eval(act_fun)
        self.act_fun = act_fun(alpha=2., requires_grad=False)

    def integral(self, inputs):
        # First-order leak toward the input current.
        delta = (inputs - self.mem) / self.tau
        self.mem = self.mem + delta

    def calc_spike(self):
        # Uses self.threshold directly (not get_thres), as in the original.
        over = self.mem - self.threshold
        self.spike = self.act_fun(over)
        self.mem = self.mem * (1 - self.spike.detach())
class BurstLIFNode(LIFNode):
    """LIF neuron whose spike values above 1.0 are amplified by a burst factor."""

    def __init__(self, threshold=.5, tau=2., act_fun=RoundGrad, *args, **kwargs):
        super().__init__(threshold=threshold, tau=tau, act_fun=act_fun, *args, **kwargs)
        self.burst_factor = 1.5

    def calc_spike(self):
        LIFNode.calc_spike(self)
        amplified = self.burst_factor * self.spike
        self.spike = torch.where(self.spike > 1., amplified, self.spike)
class BackEINode(BaseNode):
    """
    Neuron with a self-feedback connection and excitatory/inhibitory gating.
    Reference: https://www.sciencedirect.com/science/article/pii/S0893608022002520

    :param threshold: firing threshold
    :param decay: multiplicative membrane decay applied each step
    :param act_fun: surrogate gradient for spiking
    :param th_fun: surrogate gradient for the excitatory/inhibitory gate
    :param channel: channel count of the feedback/EI convolutions
    :param if_back: whether to use the self-feedback connection
    :param if_ei: whether to use excitatory/inhibitory gating
    :param cfg_backei: half kernel size of the feedback/EI convolutions
    :param args: extra positional arguments
    :param kwargs: extra keyword arguments
    """
    def __init__(self, threshold=0.5, decay=0.2, act_fun=BackEIGateGrad, th_fun=EIGrad, channel=40, if_back=True,
                 if_ei=True, cfg_backei=2, *args, **kwargs):
        super().__init__(threshold, *args, **kwargs)
        self.decay = decay
        if isinstance(act_fun, str):
            act_fun = eval(act_fun)
        if isinstance(th_fun, str):
            th_fun = eval(th_fun)
        self.act_fun = act_fun()
        self.th_fun = th_fun()
        self.channel = channel
        self.if_back = if_back
        if self.if_back:
            self.back = nn.Conv2d(channel, channel, kernel_size=2 * cfg_backei + 1, stride=1, padding=cfg_backei)
        self.if_ei = if_ei
        if self.if_ei:
            self.ei = nn.Conv2d(channel, channel, kernel_size=2 * cfg_backei + 1, stride=1, padding=cfg_backei)

    def integral(self, inputs):
        if self.mem is None:
            # Lazily initialise state with the shape of the first input.
            self.mem = torch.zeros_like(inputs)
            self.spike = torch.zeros_like(inputs)
        self.mem = self.decay * self.mem
        if self.if_back:
            # Fix: ``F.sigmoid`` is deprecated; ``torch.sigmoid`` is the
            # supported equivalent with identical numerics.
            self.mem += torch.sigmoid(self.back(self.spike)) * inputs
        else:
            self.mem += inputs

    def calc_spike(self):
        if self.if_ei:
            # EI gate computed on the pre-reset membrane potential.
            ei_gate = self.th_fun(self.ei(self.mem))
            self.spike = self.act_fun(self.mem - self.threshold)
            self.mem = self.mem * (1 - self.spike)
            self.spike = ei_gate * self.spike
        else:
            self.spike = self.act_fun(self.mem - self.threshold)
            self.mem = self.mem * (1 - self.spike)

    def n_reset(self):
        """Clear membrane/spike state and the recorded maps."""
        self.mem = None
        self.spike = None
        self.feature_map = []
        self.mem_collect = []
class NoiseLIFNode(LIFNode):
    """
    Noisy Leaky Integrate-and-Fire neuron.

    Injects noise drawn from a learnable Beta distribution, default
    ``Beta(log(2), log(6))``.

    :param threshold: firing threshold
    :param tau: membrane time constant controlling the leak
    :param act_fun: surrogate-gradient function, default ``GateGrad``
    :param log_alpha: parameter ``a`` of the Beta distribution (log-space)
    :param log_beta: parameter ``b`` of the Beta distribution (log-space)
    Remaining parameters are forwarded to ``LIFNode``/``BaseNode``.
    """
    def __init__(self,
                 threshold=1,
                 tau=2.,
                 act_fun=GateGrad,
                 log_alpha=np.log(2),
                 log_beta=np.log(6),
                 *args,
                 **kwargs):
        super().__init__(threshold=threshold, tau=tau, act_fun=act_fun, *args, **kwargs)
        # Learnable log-parameters of the Beta noise distribution.
        self.log_alpha = Parameter(torch.as_tensor(log_alpha), requires_grad=True)
        self.log_beta = Parameter(torch.as_tensor(log_beta), requires_grad=True)

    def integral(self, inputs):  # b, c, w, h / b, c
        alpha, beta = torch.exp(self.log_alpha), torch.exp(self.log_beta)
        mu = alpha / (alpha + beta)
        var = ((alpha + 1) * alpha) / ((alpha + beta + 1) * (alpha + beta))
        noise = torch.distributions.beta.Beta(alpha, beta).sample(inputs.shape) * self.get_thres()
        # Straight-through trick: the forward value stays `noise` (the ratio and
        # difference cancel numerically), while gradients reach alpha/beta
        # through `var` and `mu`.
        noise = noise * var / var.detach() + mu - mu.detach()
        self.mem = self.mem + ((inputs - self.mem) / self.tau + noise) * self.dt
class BiasLIFNode(BaseNode):
    """
    LIF neuron with a constant bias current of 0.1 injected at every step;
    used for testing networks with inhibitory / feedback connections.

    :param threshold: firing threshold
    :param tau: membrane time constant controlling the leak
    :param act_fun: surrogate-gradient function, default ``AtanGrad``
    Remaining parameters are forwarded to ``BaseNode``.
    """
    def __init__(self, threshold=1., tau=2., act_fun=AtanGrad, *args, **kwargs):
        super().__init__(threshold, *args, **kwargs)
        self.tau = tau
        if isinstance(act_fun, str):
            act_fun = eval(act_fun)
        self.act_fun = act_fun(alpha=2., requires_grad=False)

    def integral(self, inputs):
        # Leaky integration plus the constant bias current.
        leak = (inputs - self.mem) / self.tau
        self.mem = self.mem + leak * self.dt + 0.1

    def calc_spike(self):
        over = self.mem - self.get_thres()
        self.spike = self.act_fun(over)
        self.mem = self.mem * (1 - self.spike.detach())
class LIFSTDPNode(BaseNode):
    """
    LIF node used for STDP learning: the membrane decays multiplicatively by
    ``tau`` and the input current is added directly afterwards.
    """
    def __init__(self, threshold=1., tau=2., act_fun=AtanGrad, *args, **kwargs):
        super().__init__(threshold, *args, **kwargs)
        self.tau = tau
        if isinstance(act_fun, str):
            act_fun = eval(act_fun)
        self.act_fun = act_fun(alpha=2., requires_grad=False)

    def integral(self, inputs):
        # Here tau multiplies the membrane (a decay factor), unlike the
        # divisor-style time constant in LIFNode.
        self.mem = self.mem * self.tau + inputs

    def calc_spike(self):
        over_thres = self.mem - self.threshold
        self.spike = self.act_fun(over_thres)
        self.mem = self.mem * (1 - self.spike.detach())

    def requires_activation(self):
        # STDP training does not need a separate activation stage.
        return False
class PLIFNode(BaseNode):
    """
    Parametric LIF: the decay is a learnable parameter updated by backprop.
    Reference: https://arxiv.org/abs/2007.05785

    :param threshold: firing threshold
    :param tau: initial membrane time constant
    :param act_fun: surrogate-gradient function, default ``AtanGrad``
    Remaining parameters are forwarded to ``BaseNode``.
    """
    def __init__(self, threshold=1., tau=2., act_fun=AtanGrad, *args, **kwargs):
        super().__init__(threshold, *args, **kwargs)
        if isinstance(act_fun, str):
            act_fun = eval(act_fun)
        self.act_fun = act_fun(alpha=2., requires_grad=True)
        # Initialise w so that sigmoid(w) == 1 / tau.
        self.w = nn.Parameter(torch.as_tensor(-math.log(tau - 1.)))

    def integral(self, inputs):
        self.mem = self.mem + ((inputs - self.mem) * self.w.sigmoid()) * self.dt

    def calc_spike(self):
        gap = self.mem - self.get_thres()
        self.spike = self.act_fun(gap)
        self.mem = self.mem * (1 - self.spike.detach())
class PSU(BaseNode):
    """
    Parallel Spiking Unit: membrane potentials for all ``step`` time steps are
    computed at once from two precomputed transition matrices ``m1``/``m2``
    (built by ``generate_matrix``, defined elsewhere in this module).
    NOTE(review): the exact semantics of ``generate_matrix`` are not visible
    here — presumably it encodes the LIF decay over ``step`` steps; confirm at
    its definition.
    """
    def __init__(self, threshold=1., tau=2., act_fun=AtanGrad, *args, **kwargs):
        super().__init__(threshold, *args, **kwargs)
        init_w = -math.log(tau - 1.)  # NOTE(review): unused here, kept from the original
        if isinstance(act_fun, str):
            act_fun = eval(act_fun)
        self.parallel = True  # selects BaseNode.forward's parallel path
        self.act_fun = act_fun(alpha=2., requires_grad=True)
        T = self.step
        m1, m2 = generate_matrix(T, tau)
        self.register_buffer('m1', m1)
        self.register_buffer('m2', m2)
        self.m2 *= self.threshold

    def integral(self, inputs):
        # inputs arrive as (T, ...) after rearrange2node; flatten(1) -> (T, N).
        d1 = self.m1 @ inputs.flatten(1)
        self.mem = (d1 + self.m2 @ d1.sigmoid()).view(inputs.shape)

    def calc_spike(self):
        # No reset: spikes for all steps are emitted in one shot.
        self.spike = self.act_fun(self.mem - self.threshold)
class IPSU(BaseNode):
    """
    Parallel Spiking Unit with a learnable input transition: like ``PSU`` but
    the first transition is a learnable linear map over the ``step`` time
    dimension, masked lower-triangular to stay causal. ``generate_matrix`` is
    defined elsewhere in this module.
    """
    def masked_weight(self):
        # Lower-triangular mask keeps the temporal map causal.
        return self.fc.weight * self.mask0

    def __init__(self, threshold=1., tau=2., act_fun=AtanGrad, *args, **kwargs):
        super().__init__(threshold, *args, **kwargs)
        init_w = -math.log(tau - 1.)  # NOTE(review): unused here, kept from the original
        if isinstance(act_fun, str):
            act_fun = eval(act_fun)
        self.parallel = True  # selects BaseNode.forward's parallel path
        self.act_fun = act_fun(alpha=2., requires_grad=True)
        T = self.step
        matrix, matrix2 = generate_matrix(T, tau)
        self.register_buffer('m1', matrix)  # NOTE(review): m1 is registered but unused below
        self.register_buffer('m2', matrix2)
        self.fc = nn.Linear(T, T)
        nn.init.constant_(self.fc.bias, 0.)
        nn.init.kaiming_normal_(self.fc.weight, mode='fan_out', nonlinearity='relu')
        mask0 = torch.tril(torch.ones([T, T]))
        self.register_buffer('mask0', mask0)

    def integral(self, inputs):
        # Causal learnable map over time plus the fixed second transition.
        d1 = torch.addmm(self.fc.bias.unsqueeze(1), self.masked_weight(), inputs.flatten((1)))
        self.mem = (d1 + self.m2 @ inputs.flatten(1)).view(inputs.shape)

    def calc_spike(self):
        self.spike = self.act_fun(self.mem - self.threshold)
class RPSU(BaseNode):
    """
    Parallel Spiking Unit with a learnable recurrent-style branch: combines the
    fixed transition ``m1`` with a causal (lower-triangular) learnable linear
    map whose sigmoid is fed through ``m2``. ``generate_matrix`` is defined
    elsewhere in this module.
    """
    def masked_weight(self):
        # Lower-triangular mask keeps the temporal map causal.
        return self.fc.weight * self.mask0

    def __init__(self, threshold=1., tau=2., act_fun=AtanGrad, *args, **kwargs):
        super().__init__(threshold, *args, **kwargs)
        init_w = -math.log(tau - 1.)  # NOTE(review): unused here, kept from the original
        if isinstance(act_fun, str):
            act_fun = eval(act_fun)
        self.parallel = True  # selects BaseNode.forward's parallel path
        self.act_fun = act_fun(alpha=2., requires_grad=True)
        T = self.step
        matrix, matrix2 = generate_matrix(T, tau)
        self.register_buffer('m1', matrix)
        self.register_buffer('m2', matrix2)
        self.fc = nn.Linear(T, T)
        nn.init.constant_(self.fc.bias, 0.)
        nn.init.kaiming_normal_(self.fc.weight, mode='fan_out', nonlinearity='relu')
        mask0 = torch.tril(torch.ones([T, T]))
        self.register_buffer('mask0', mask0)

    def integral(self, inputs):
        d1 = self.m1 @ inputs.flatten(1)
        d2 = torch.addmm(self.fc.bias.unsqueeze(1), self.masked_weight(), inputs.flatten((1)))
        self.mem = (d1 + self.m2 @ d2.sigmoid()).view(inputs.shape)

    def calc_spike(self):
        self.spike = self.act_fun(self.mem - self.threshold)
class SPSN(BaseNode):
    """
    Stochastic Parallel Spiking Neuron: membrane potentials for all steps are
    squashed through a sigmoid into firing probabilities and spikes are drawn
    with ``torch.bernoulli`` (non-differentiable sampling).
    """
    def __init__(self, threshold=1., tau=2., act_fun=AtanGrad, *args, **kwargs):
        super().__init__(threshold, *args, **kwargs)
        init_w = -math.log(tau - 1.)  # NOTE(review): unused here, kept from the original
        if isinstance(act_fun, str):
            act_fun = eval(act_fun)
        self.parallel = True  # selects BaseNode.forward's parallel path
        self.act_fun = act_fun(alpha=2., requires_grad=True)
        m1, m2 = generate_matrix(self.step, tau)
        # NOTE(review): only m1 is registered; m2 is discarded in this variant.
        self.register_buffer('m1', m1)

    def integral(self, inputs):
        self.mem = (self.m1 @ inputs.flatten(1)).sigmoid().view(inputs.shape)

    def calc_spike(self):
        self.spike = torch.bernoulli(self.mem)
class NoisePLIFNode(PLIFNode):
    """
    Noisy Parametric LIF: PLIF whose integration adds noise from a learnable
    Beta distribution.

    :param threshold: firing threshold
    :param tau: initial membrane time constant
    :param act_fun: surrogate-gradient function, default ``GateGrad``
    ``log_alpha`` / ``log_beta`` may be supplied via kwargs (defaults
    ``log(2)`` / ``log(6)``) and parameterise the Beta noise distribution.
    Remaining parameters are forwarded to ``PLIFNode``/``BaseNode``.
    """
    def __init__(self,
                 threshold=1,
                 tau=2.,
                 act_fun=GateGrad,
                 *args,
                 **kwargs):
        super().__init__(threshold=threshold, tau=tau, act_fun=act_fun, *args, **kwargs)
        log_alpha = kwargs['log_alpha'] if 'log_alpha' in kwargs else np.log(2)
        log_beta = kwargs['log_beta'] if 'log_beta' in kwargs else np.log(6)
        # Learnable log-parameters of the Beta noise distribution.
        self.log_alpha = Parameter(torch.as_tensor(log_alpha), requires_grad=True)
        self.log_beta = Parameter(torch.as_tensor(log_beta), requires_grad=True)

    def integral(self, inputs):  # b, c, w, h / b, c
        alpha, beta = torch.exp(self.log_alpha), torch.exp(self.log_beta)
        mu = alpha / (alpha + beta)
        var = ((alpha + 1) * alpha) / ((alpha + beta + 1) * (alpha + beta))
        noise = torch.distributions.beta.Beta(alpha, beta).sample(inputs.shape) * self.get_thres()
        # Straight-through trick: forward value stays `noise`, while gradients
        # reach alpha/beta through `var` and `mu`.
        noise = noise * var / var.detach() + mu - mu.detach()
        self.mem = self.mem + ((inputs - self.mem) * self.w.sigmoid() + noise) * self.dt
class BiasPLIFNode(BaseNode):
    """
    Parametric LIF with a constant bias current of 0.1 injected every step.

    :param threshold: firing threshold
    :param tau: initial membrane time constant
    :param act_fun: surrogate-gradient function, default ``AtanGrad``
    Remaining parameters are forwarded to ``BaseNode``.
    """
    def __init__(self, threshold=1., tau=2., act_fun=AtanGrad, *args, **kwargs):
        super().__init__(threshold, *args, **kwargs)
        if isinstance(act_fun, str):
            act_fun = eval(act_fun)
        self.act_fun = act_fun(alpha=2., requires_grad=True)
        # Initialise w so that sigmoid(w) == 1 / tau.
        self.w = nn.Parameter(torch.as_tensor(-math.log(tau - 1.)))

    def integral(self, inputs):
        drive = (inputs - self.mem) * self.w.sigmoid() + 0.1
        self.mem = self.mem + drive * self.dt

    def calc_spike(self):
        gap = self.mem - self.get_thres()
        self.spike = self.act_fun(gap)
        self.mem = self.mem * (1 - self.spike.detach())
class DoubleSidePLIFNode(LIFNode):
    """
    PLIF variant that can emit both positive and negative spikes.

    :param threshold: firing threshold (applies to both signs)
    :param tau: membrane time constant
    :param act_fun: surrogate-gradient function, default ``AtanGrad``
    Remaining parameters are forwarded to ``LIFNode``/``BaseNode``.
    """
    def __init__(self,
                 threshold=.5,
                 tau=2.,
                 act_fun=AtanGrad,
                 *args,
                 **kwargs):
        super().__init__(threshold, tau, act_fun, *args, **kwargs)
        if isinstance(act_fun, str):
            act_fun = eval(act_fun)
        self.act_fun = act_fun(alpha=2., requires_grad=True)

    def calc_spike(self):
        # Bug fix: the second term called `self.get_thres` without parentheses,
        # which subtracts a bound method from a tensor and raises a TypeError
        # at runtime; it must be `self.get_thres()`.
        self.spike = self.act_fun(self.mem - self.get_thres()) - self.act_fun(self.get_thres() - self.mem)
        # Reset wherever a spike of either sign fired.
        self.mem = self.mem * (1. - torch.abs(self.spike.detach()))
class IzhNode(BaseNode):
    """
    Izhikevich spiking neuron.

    :param threshold: firing threshold
    :param tau: membrane time constant (stored; not used by the update rule)
    :param act_fun: surrogate-gradient function, default ``AtanGrad``
    Model constants ``a``, ``b``, ``c``, ``d`` and the step ``dt`` may be
    supplied via kwargs.
    """
    def __init__(self, threshold=1., tau=2., act_fun=AtanGrad, *args, **kwargs):
        super().__init__(threshold, *args, **kwargs)
        self.tau = tau
        if isinstance(act_fun, str):
            act_fun = eval(act_fun)
        self.act_fun = act_fun(alpha=2., requires_grad=False)
        self.a = kwargs['a'] if 'a' in kwargs else 0.02
        self.b = kwargs['b'] if 'b' in kwargs else 0.2
        self.c = kwargs['c'] if 'c' in kwargs else -55.
        self.d = kwargs['d'] if 'd' in kwargs else -2.
        '''
        Continuous dynamics:
            v' = 0.04v^2 + 5v + 140 - u + I
            u' = a(bv - u)
        Discretised below; on a spike:
            if v >= thresh:
                v = c
                u = u + d
        '''
        # Initialise the membrane potential and the recovery variable u.
        self.mem = 0.
        self.u = 0.
        self.dt = kwargs['dt'] if 'dt' in kwargs else 1.

    def integral(self, inputs):
        # Euler step of v first; the u update then reads the *updated* v
        # (order preserved from the original).
        self.mem = self.mem + self.dt * (0.04 * self.mem * self.mem + 5 * self.mem - self.u + 140 + inputs)
        self.u = self.u + self.dt * (self.a * self.b * self.mem - self.a * self.u)

    def calc_spike(self):
        self.spike = self.act_fun(self.mem - self.get_thres())  # spike when above threshold
        # On spike: v -> c and u -> u + d.
        self.mem = self.mem * (1 - self.spike.detach()) + self.spike.detach() * self.c
        self.u = self.u + self.spike.detach() * self.d

    def n_reset(self):
        self.mem = 0.
        self.u = 0.
        self.spike = 0.
class IzhNodeMU(BaseNode):
    """
    Izhikevich spiking neuron, multi-parameter variant: initial ``mem``/``u``
    may also be supplied via kwargs, and ``n_reset`` restores ``mem`` to -70.

    :param threshold: firing threshold
    :param tau: membrane time constant (stored; not used by the update rule)
    :param act_fun: surrogate-gradient function, default ``AtanGrad``
    Model constants ``a``, ``b``, ``c``, ``d``, initial ``mem``/``u`` and the
    step ``dt`` may be supplied via kwargs.
    """
    def __init__(self, threshold=1., tau=2., act_fun=AtanGrad, *args, **kwargs):
        super().__init__(threshold, *args, **kwargs)
        self.tau = tau
        if isinstance(act_fun, str):
            act_fun = eval(act_fun)
        self.act_fun = act_fun(alpha=2., requires_grad=False)
        self.a = kwargs['a'] if 'a' in kwargs else 0.02
        self.b = kwargs['b'] if 'b' in kwargs else 0.2
        self.c = kwargs['c'] if 'c' in kwargs else -55.
        self.d = kwargs['d'] if 'd' in kwargs else -2.
        self.mem = kwargs['mem'] if 'mem' in kwargs else 0.
        self.u = kwargs['u'] if 'u' in kwargs else 0.
        self.dt = kwargs['dt'] if 'dt' in kwargs else 1.

    def integral(self, inputs):
        # Euler step of v first; the u update then reads the *updated* v
        # (order preserved from the original).
        self.mem = self.mem + self.dt * (0.04 * self.mem * self.mem + 5 * self.mem - self.u + 140 + inputs)
        self.u = self.u + self.dt * (self.a * self.b * self.mem - self.a * self.u)

    def calc_spike(self):
        # Uses self.threshold directly (not get_thres), as in the original.
        self.spike = self.act_fun(self.mem - self.threshold)
        self.mem = self.mem * (1 - self.spike.detach()) + self.spike.detach() * self.c
        self.u = self.u + self.spike.detach() * self.d

    def n_reset(self):
        # NOTE(review): unlike IzhNode, reset restores mem to -70 here.
        self.mem = -70.
        self.u = 0.
        self.spike = 0.

    def requires_activation(self):
        return False
class DGLIFNode(BaseNode):
    """
    LIF neuron emitting graded (analog-valued) spikes.
    Reference: https://arxiv.org/abs/2110.08858
    :param threshold: firing threshold of the neuron
    :param tau: membrane time constant controlling potential decay
    """
    def __init__(self, threshold=.5, tau=2., *args, **kwargs):
        super().__init__(threshold, tau, *args, **kwargs)
        self.act = nn.ReLU()
        self.tau = tau
    def integral(self, inputs):
        # Inputs are rectified before charging the membrane.
        inputs = self.act(inputs)
        self.mem = self.mem + ((inputs - self.mem) / self.tau) * self.dt
    def calc_spike(self):
        # Graded spike: the membrane value itself is emitted where it exceeds
        # the threshold, zero elsewhere.
        spike = self.mem.clone()
        spike[(spike < self.get_thres())] = 0.
        # BUG-FIX (dead code removed): the original also built a
        # straight-through estimator expression for self.spike, but
        # immediately overwrote it with `spike`, so it never had any effect.
        # If a binary straight-through output was intended, that is a separate
        # behavioral change to make deliberately.
        self.spike = spike
        # Hard-reset every membrane that reached the threshold.
        self.mem = torch.where(self.mem >= self.get_thres(), torch.zeros_like(self.mem), self.mem)
class HTDGLIFNode(IFNode):
    """
    IF neuron emitting graded spikes, with an optional ANN-style warm-up mode.
    Reference: https://arxiv.org/abs/2110.08858
    :param threshold: firing threshold of the neuron
    :param tau: membrane time constant controlling potential decay
    """
    def __init__(self, threshold=.5, tau=2., *args, **kwargs):
        super().__init__(threshold, *args, **kwargs)
        # While warm_up is True, forward() behaves as a plain ReLU.
        self.warm_up = False
    def calc_spike(self):
        # Graded spike: membrane value where above threshold, zero elsewhere.
        spike = self.mem.clone()
        spike[(spike < self.get_thres())] = 0.
        # BUG-FIX (dead code removed): the original assigned a straight-through
        # estimator expression to self.spike and immediately overwrote it with
        # `spike`; the expression never took effect.
        self.spike = spike
        self.mem = torch.where(self.mem >= self.get_thres(), torch.zeros_like(self.mem), self.mem)
        # Leak term: forward value unchanged (0.2*s - 0.2*s.detach() == 0),
        # but it injects an extra gradient path through the spike.
        self.mem = (self.mem + 0.2 * self.spike - 0.2 * self.spike.detach()) * self.dt
    def forward(self, inputs):
        if self.warm_up:
            return F.relu(inputs)
        else:
            # Deliberately skips IFNode.forward and calls BaseNode.forward
            # on the rectified input.
            return super(IFNode, self).forward(F.relu(inputs))
class SimHHNode(BaseNode):
    """
    Simplified Hodgkin-Huxley neuron model.
    :param threshold: membrane threshold a neuron must reach to fire
    :param v_reset: resting potential
    :param dt: time step
    :param step: number of simulation steps
    :param tau: membrane time constant controlling potential decay
    :param act_fun: surrogate gradient function, default ``surrogate.AtanGrad``
    :param args: other positional arguments
    :param kwargs: other keyword arguments
    """
    def __init__(self, threshold=50., tau=2., act_fun=AtanGrad, *args, **kwargs):
        super().__init__(threshold, *args, **kwargs)
        self.tau = tau
        if isinstance(act_fun, str):
            # NOTE(review): eval on a caller-supplied string; trusted names only.
            act_fun = eval(act_fun)
        '''
        I = Cm dV/dt + g_k*n^4*(V_m-V_k) + g_Na*m^3*h*(V_m-V_Na) + g_l*(V_m - V_L)
        '''
        self.act_fun = act_fun(alpha=2., requires_grad=False)
        # NOTE(review): non-classical HH constants (the trailing comments hint
        # at the classical g_K=36 and V_K=-12) -- confirm these are intended.
        self.g_Na, self.g_K, self.g_l = torch.tensor(120.), torch.tensor(120), torch.tensor(0.3)  # k 36
        self.V_Na, self.V_K, self.V_l = torch.tensor(120.), torch.tensor(-120.), torch.tensor(10.6)  # k -12
        # Gating variables m, n, h all start at 0.
        self.m, self.n, self.h = torch.tensor(0), torch.tensor(0), torch.tensor(0)
        self.mem = 0
        self.dt = 0.01
    def integral(self, inputs):
        # Channel currents at the current membrane potential.
        self.I_Na = torch.pow(self.m, 3) * self.g_Na * self.h * (self.mem - self.V_Na)
        self.I_K = torch.pow(self.n, 4) * self.g_K * (self.mem - self.V_K)
        self.I_L = self.g_l * (self.mem - self.V_l)
        # Membrane update; 0.02 plays the role of the capacitance Cm.
        self.mem = self.mem + self.dt * (inputs - self.I_Na - self.I_K - self.I_L) / 0.02
        # non Na
        # self.mem = self.mem + 0.01 * (inputs - self.I_K - self.I_L) / 0.02 #decayed
        # NON k
        # self.mem = self.mem + 0.01 * (inputs - self.I_Na - self.I_L) / 0.02 #increase
        # Rate constants and forward-Euler updates of the gating variables.
        self.alpha_n = 0.01 * (self.mem + 10.0) / (1 - torch.exp(-(self.mem + 10.0) / 10))
        self.beta_n = 0.125 * torch.exp(-(self.mem) / 80)
        self.alpha_m = 0.1 * (self.mem + 25) / (1 - torch.exp(-(self.mem + 25) / 10))
        self.beta_m = 4 * torch.exp(-(self.mem) / 18)
        self.alpha_h = 0.07 * torch.exp(-(self.mem) / 20)
        self.beta_h = 1 / (1 + torch.exp(-(self.mem + 30) / 10))
        self.n = self.n + self.dt * (self.alpha_n * (1 - self.n) - self.beta_n * self.n)
        self.m = self.m + self.dt * (self.alpha_m * (1 - self.m) - self.beta_m * self.m)
        self.h = self.h + self.dt * (self.alpha_h * (1 - self.h) - self.beta_h * self.h)
    def calc_spike(self):
        self.spike = self.act_fun(self.mem - self.threshold)
        self.mem = self.mem * (1 - self.spike.detach())  # hard reset of fired neurons
    def forward(self, inputs):
        # Single simulation step: integrate, then fire.
        self.integral(inputs)
        self.calc_spike()
        return self.spike
    def n_reset(self):
        # Clear membrane, spike and gating variables between samples.
        self.mem = 0.
        self.spike = 0.
        self.m, self.n, self.h = torch.tensor(0), torch.tensor(0), torch.tensor(0)
    def requires_activation(self):
        return False
class CTIzhNode(IzhNode):
    """
    Izhikevich neuron extended with cortical-network bookkeeping: named
    identity, excitatory/inhibitory flag, adjacency to postsynaptic neurons,
    and Izhikevich-2007-style parameters (Vr, Vt, Vpeak, k, capacitance).
    """
    def __init__(self, threshold=1., tau=2., act_fun=AtanGrad, *args, **kwargs):
        super().__init__(threshold, tau, act_fun, *args, **kwargs)
        self.name = kwargs['name'] if 'name' in kwargs else ''
        # String flag: 'TRUE' marks an excitatory neuron, anything else inhibitory.
        self.excitability = kwargs['excitability'] if 'excitability' in kwargs else 'TRUE'
        self.spikepattern = kwargs['spikepattern'] if 'spikepattern' in kwargs else 'RS'
        self.synnum = kwargs['synnum'] if 'synnum' in kwargs else 0
        self.locationlayer = kwargs['locationlayer'] if 'locationlayer' in kwargs else ''
        self.adjneuronlist = {}
        self.proximal_dendrites = []
        self.distal_dendrites = []
        self.totalindex = kwargs['totalindex'] if 'totalindex' in kwargs else 0
        self.colindex = 0
        self.state = 'inactive'
        self.Gup = kwargs['Gup'] if 'Gup' in kwargs else 0.0
        self.Gdown = kwargs['Gdown'] if 'Gdown' in kwargs else 0.0
        self.Vr = kwargs['Vr'] if 'Vr' in kwargs else 0.0
        self.Vt = kwargs['Vt'] if 'Vt' in kwargs else 0.0
        self.Vpeak = kwargs['Vpeak'] if 'Vpeak' in kwargs else 0.0
        # NOTE(review): the misspelled attribute name 'capicitance' is kept
        # because integral() reads it; the kwarg itself is 'capacitance'.
        self.capicitance = kwargs['capacitance'] if 'capacitance' in kwargs else 0.0
        self.k = kwargs['k'] if 'k' in kwargs else 0.0
        self.mem = -65
        self.vtmp = -65
        self.u = -13.0
        self.spike = 0
        self.dc = 0
    def integral(self, inputs):
        # Izhikevich (2007) quadratic model: C*dv/dt = k(v-Vr)(v-Vt) - u + I.
        self.mem += self.dt * (
                self.k * (self.mem - self.Vr) * (self.mem - self.Vt) - self.u + inputs) / self.capicitance
        self.u += self.dt * (self.a * (self.b * (self.mem - self.Vr) - self.u))
    def calc_spike(self):
        # Hard (non-differentiable) reset once the peak potential is reached.
        if self.mem >= self.Vpeak:
            self.mem = self.c
            self.u = self.u + self.d
            self.spike = 1
            self.spreadMarkPostNeurons()
    def spreadMarkPostNeurons(self):
        # Push a random DC drive onto every postsynaptic neighbour; the sign
        # follows this neuron's excitability. (The loop variable 'list'
        # shadows the builtin and is unused.)
        for post, list in self.adjneuronlist.items():
            if self.excitability == "TRUE":
                post.dc = random.randint(140, 160)
            else:
                post.dc = random.randint(-160, -140)
class adth(BaseNode):
    """
    The adaptive Exponential Integrate-and-Fire model (aEIF)
    :param args: Other parameters
    :param kwargs: Other parameters
    """
    def __init__(self, *args, **kwargs):
        super().__init__(requires_fp=False, *args, **kwargs)
    def adthNode(self, v, dt, c_m, g_m, alpha_w, ad, Ieff, Ichem, Igap, tau_ad, beta_ad, vt, vm1):
        """
        Calculate the neurons that discharge after the current threshold is reached.
        :param v: current neuron voltage
        :param dt: time step
        :param c_m: membrane capacitance
        :param g_m: membrane conductance
        :param alpha_w: adaptation coupling strength
        :param ad: adaptive variable
        :param Ieff: effective input current
        :param Ichem: chemical synapse current
        :param Igap: gap-junction current
        :param tau_ad: adaptation time constant
        :param beta_ad: adaptation feedback strength
        :param vt: firing threshold
        :param vm1: voltage from the previous step
        :return: (v, ad, vv, vm1); vv is 1 where v crossed vt from below
        NOTE(review): `.astype` implies v/vm1 are numpy arrays, not torch
        tensors -- confirm against callers.
        """
        v = v + dt / c_m * (-g_m * v + alpha_w * ad + Ieff + Ichem + Igap)
        ad = ad + dt / tau_ad * (-ad + beta_ad * v)
        vv = (v >= vt).astype(int) * (vm1 < vt).astype(int)
        vm1 = v
        return v, ad, vv, vm1
    def calc_spike(self):
        # State updates are done explicitly in adthNode; nothing to do here.
        pass
class HHNode(BaseNode):
    """
    Hodgkin-Huxley neuron group for brain simulation.
    p: [threshold, g_Na, g_K, g_l, V_Na, V_K, V_l, C]
    :param p: per-neuron parameter arrays, indexed as listed above
    :param dt: simulation time step
    :param device: torch device on which the state tensors live
    :param act_fun: surrogate gradient function, default ``surrogate.AtanGrad``
    """
    def __init__(self, p, dt, device, act_fun=AtanGrad, *args, **kwargs):
        super().__init__(threshold=p[0], *args, **kwargs)
        if isinstance(act_fun, str):
            # NOTE(review): eval on a caller-supplied string; trusted names only.
            act_fun = eval(act_fun)
        '''
        I = Cm dV/dt + g_k*n^4*(V_m-V_k) + g_Na*m^3*h*(V_m-V_Na) + g_l*(V_m - V_L)
        '''
        self.neuron_num = len(p[0])
        self.act_fun = act_fun(alpha=2., requires_grad=False)
        self.tau_I = 3  # time constant used to filter synaptic input
        # Channel conductances and reversal potentials.
        self.g_Na = torch.tensor(p[1])
        self.g_K = torch.tensor(p[2])
        self.g_l = torch.tensor(p[3])
        self.V_Na = torch.tensor(p[4])
        self.V_K = torch.tensor(p[5])
        self.V_l = torch.tensor(p[6])
        self.C = torch.tensor(p[7])
        # Gating variables initialized near their classical resting values.
        self.m = 0.05 * torch.ones(self.neuron_num, device=device, requires_grad=False)
        self.n = 0.31 * torch.ones(self.neuron_num, device=device, requires_grad=False)
        self.h = 0.59 * torch.ones(self.neuron_num, device=device, requires_grad=False)
        self.v_reset = 0
        self.dt = dt
        self.dt_over_tau = self.dt / self.tau_I
        self.sqrt_coeff = math.sqrt(1 / (2 * (1 / self.dt_over_tau)))
        # Mean/std of the background noise current.
        self.mu = 10
        self.sig = 12
        self.mem = torch.tensor(self.v_reset, device=device, requires_grad=False)
        self.mem_p = self.mem  # membrane value from the previous step
        self.spike = torch.zeros(self.neuron_num, device=device, requires_grad=False)
        self.Iback = torch.zeros(self.neuron_num, device=device, requires_grad=False)
        self.Ieff = torch.zeros(self.neuron_num, device=device, requires_grad=False)
    def integral(self, inputs):
        # Rate constants (resting potential at 0 mV convention).
        self.alpha_n = (0.1 - 0.01 * self.mem) / (torch.exp(1 - 0.1 * self.mem) - 1)
        self.alpha_m = (2.5 - 0.1 * self.mem) / (torch.exp(2.5 - 0.1 * self.mem) - 1)
        self.alpha_h = 0.07 * torch.exp(-self.mem / 20.0)
        self.beta_n = 0.125 * torch.exp(-self.mem / 80.0)
        self.beta_m = 4.0 * torch.exp(-self.mem / 18.0)
        self.beta_h = 1 / (torch.exp(3 - 0.1 * self.mem) + 1)
        # Forward-Euler updates of the gating variables.
        self.n = self.n + self.dt * (self.alpha_n * (1 - self.n) - self.beta_n * self.n)
        self.m = self.m + self.dt * (self.alpha_m * (1 - self.m) - self.beta_m * self.m)
        self.h = self.h + self.dt * (self.alpha_h * (1 - self.h) - self.beta_h * self.h)
        # Channel currents and membrane update.
        self.I_Na = torch.pow(self.m, 3) * self.g_Na * self.h * (self.mem - self.V_Na)
        self.I_K = torch.pow(self.n, 4) * self.g_K * (self.mem - self.V_K)
        self.I_L = self.g_l * (self.mem - self.V_l)
        self.mem_p = self.mem
        self.mem = self.mem + self.dt * (inputs - self.I_Na - self.I_K - self.I_L) / self.C
    def calc_spike(self):
        # Spike only on an upward threshold crossing within this step.
        self.spike = (self.threshold > self.mem_p).float() * (self.mem > self.threshold).float()
    def forward(self, inputs):
        # Single step: integrate, then detect crossings; returns both the
        # spike train and the membrane potential.
        self.integral(inputs)
        self.calc_spike()
        return self.spike, self.mem
    def requires_activation(self):
        return False
class aEIF(BaseNode):
    """
    The adaptive Exponential Integrate-and-Fire model (aEIF).
    This class defines the membrane, spike, current and parameters of a
    neuron group of a specific type.
    :param p: parameter list, see __init__
    :param dt: simulation time step
    :param device: torch device on which the state tensors live
    :param args: Other parameters
    :param kwargs: Other parameters
    """
    def __init__(self, p, dt, device, *args, **kwargs):
        """
        p:[threshold, v_reset, c_m, tao_w, alpha_ad, beta_ad]
        """
        super().__init__(threshold=p[0], requires_fp=False, *args, **kwargs)
        self.neuron_num = len(p[0])
        self.g_m = 0.1  # neuron conduction
        self.dt = dt
        self.tau_I = 3  # Time constant to filter the synaptic inputs
        self.Delta_T = 0.5  # sharpness of the exponential spike-initiation term
        self.v_reset = p[1]  # membrane potential reset to v_reset after fire spike
        self.c_m = p[2]
        self.tau_w = p[3]  # Time constant of adaption coupling
        self.alpha_ad = p[4]
        self.beta_ad = p[5]
        self.refrac = 5 / self.dt  # refractory period, in steps
        self.dt_over_tau = self.dt / self.tau_I
        self.sqrt_coeff = math.sqrt(1 / (2 * (1 / self.dt_over_tau)))
        self.mem = self.v_reset
        self.spike = torch.zeros(self.neuron_num, device=device, requires_grad=False)
        self.ad = torch.zeros(self.neuron_num, device=device, requires_grad=False)
        # Refractory counters start at random phases.
        self.ref = torch.randint(0, int(self.refrac + 1), (1, self.neuron_num), device=device, requires_grad=False).squeeze(
            0)  # refractory counter
        self.ref = self.ref.float()
        # Mean/std of the background noise current.
        self.mu = 10
        self.sig = 12
        self.Iback = torch.zeros(self.neuron_num, device=device, requires_grad=False)
        self.Ieff = torch.zeros(self.neuron_num, device=device, requires_grad=False)
    def integral(self, inputs):
        # (self.ref > self.refrac) gates the update: neurons still inside the
        # refractory window are frozen.
        self.mem = self.mem + (self.ref > self.refrac) * self.dt / self.c_m * \
                   (-self.g_m * (self.mem - self.v_reset) + self.g_m * self.Delta_T *
                    torch.exp((self.mem - self.threshold) / self.Delta_T) +
                    self.alpha_ad * self.ad + inputs)
        self.ad = self.ad + (self.ref > self.refrac) * self.dt / self.tau_w * \
                  (-self.ad + self.beta_ad * (self.mem - self.v_reset))
    def calc_spike(self):
        self.spike = (self.mem > self.threshold).float()
        # Restart the refractory counter on a spike, otherwise keep counting.
        self.ref = self.ref * (1 - self.spike) + 1
        self.ad = self.ad + self.spike * 30  # spike-triggered adaptation jump
        # NOTE(review): mixes self.spike and self.spike.detach() in one
        # expression -- confirm the asymmetric gradient path is intended.
        self.mem = self.spike * self.v_reset + (1 - self.spike.detach()) * self.mem
    def forward(self, inputs):
        # aeifnode_cuda.forward(self.threshold, self.c_m, self.alpha_w, self.beta_ad, inputs, self.ref, self.ad, self.mem, self.spike)
        self.integral(inputs)
        self.calc_spike()
        return self.spike, self.mem
class LIAFNode(BaseNode):
    """
    Leaky Integrate and Analog Fire (LIAF), Reference: https://ieeexplore.ieee.org/abstract/document/9429228
    Same membrane update as LIF, but the value propagated forward is analog;
    the threshold and membrane update still follow the spiking rule.
    :param spike_act: surrogate spike function used for the membrane reset;
        default is a fresh ``BackEIGateGrad()`` per instance
    :param act_fun: activation producing the analog output [ReLU, SELU, LeakyReLU]
    :param threshold_related: if True, self.spike = act_fun(mem - threshold)
    :note that BaseNode return self.spike, and here self.spike is analog value.
    """
    def __init__(self, spike_act=None, act_fun="SELU", threshold=0.5, tau=2., threshold_related=True, *args, **kwargs):
        super().__init__(threshold, *args, **kwargs)
        # BUG-FIX: the default used to be `spike_act=BackEIGateGrad()`, a
        # mutable default argument -- one module instance created at def time
        # and shared by every LIAFNode. Build a fresh instance per node.
        if spike_act is None:
            spike_act = BackEIGateGrad()
        if isinstance(act_fun, str):
            # NOTE(review): eval on a caller-supplied string; trusted names only.
            act_fun = eval("nn." + act_fun + "()")
        self.tau = tau
        self.act_fun = act_fun
        self.spike_act = spike_act
        self.threshold_related = threshold_related
    def integral(self, inputs):
        # Standard LIF leaky integration.
        self.mem = self.mem + (inputs - self.mem) / self.tau
    def calc_spike(self):
        # Analog output via act_fun; binary spike (spike_act) only drives the reset.
        if self.threshold_related:
            spike_tmp = self.act_fun(self.mem - self.threshold)
        else:
            spike_tmp = self.act_fun(self.mem)
        self.spike = self.spike_act(self.mem - self.threshold)
        self.mem = self.mem * (1 - self.spike)
        self.spike = spike_tmp
class OnlineLIFNode(BaseNode):
    """
    Online-update Leaky Integrate and Fire.
    Same dynamics as LIF, but temporal information is detached from the
    autograd graph during backprop, enabling online updates; memory usage is
    constant instead of growing linearly with the number of steps.
    Usage requirements: 1. move the t-step forward loop out of model_zoo into
    main.py; 2. insert a Replace function between Conv and OnlineLIFNode
    layers so temporal propagation is detached while the layer's spatial
    gradients are still computed; 3. do not use BN -- use weight
    standardization instead.
    Unlike OTTT, all temporal information is discarded; see
    https://arxiv.org/abs/2302.14311. To keep temporal credit assignment,
    compute self.rate_tracking as in https://github.com/pkuxmq/OTTT-SNN.
    """
    def __init__(self, threshold=0.5, tau=2., act_fun=QGateGrad, init=False, *args, **kwargs):
        super().__init__(threshold, *args, **kwargs)
        self.tau = tau
        if isinstance(act_fun, str):
            # NOTE(review): eval on a caller-supplied string; trusted names only.
            act_fun = eval(act_fun)
        self.act_fun = act_fun(alpha=2., requires_grad=False)
        self.rate_tracking = None
        # NOTE(review): the `init` argument is ignored; the flag is forced to
        # True so the first integral() call allocates mem -- confirm intended.
        self.init = True
    def integral(self, inputs):
        if self.init is True:
            # Lazily allocate the membrane with the input's shape.
            self.mem = torch.zeros_like(inputs)
            self.init = False
        # Detach the previous membrane so gradients do not flow through time.
        self.mem = self.mem.detach() + (inputs - self.mem.detach()) / self.tau
    def calc_spike(self):
        self.spike = self.act_fun(self.mem - self.threshold)
        self.mem = self.mem * (1 - self.spike.detach())
        with torch.no_grad():
            if self.rate_tracking == None:
                self.rate_tracking = self.spike.clone().detach()
        # Output is the spike concatenated with the tracked rate along dim 0.
        self.spike = torch.cat((self.spike, self.rate_tracking), dim=0)
class AdaptiveNode(LIFNode):
    """
    LIF neuron preceded by a learnable temporal encoder that mixes
    information across the step dimension before the LIF dynamics run.
    :param threshold: firing threshold
    :param act_fun: surrogate gradient function
    :param step: number of simulation steps (the axis the encoder acts on)
    :param spike_output: unused here; kept for interface compatibility
    :param kwargs: may contain ``n_encode_type`` in {'linear','mlp','att','conv'}
    """
    def __init__(self, threshold=1., act_fun=QGateGrad, step=10, spike_output=True, *args, **kwargs):
        super().__init__(threshold=threshold, step=step, **kwargs)
        self.n_encode_type = kwargs['n_encode_type'] if 'n_encode_type' in kwargs else 'linear'
        if isinstance(act_fun, str):
            # NOTE(review): eval on a caller-supplied string; trusted names only.
            act_fun = eval(act_fun)
        self.act_fun = act_fun(alpha=2., requires_grad=False)
        # self.act_fun = BinaryActivation()
        print(self.n_encode_type)
        # Temporal encoder operating on the step axis.
        if self.n_encode_type == 'linear':
            self.encoder = nn.Sequential(
                CustomLinear(self.step, self.step)
            )
        elif self.n_encode_type == 'mlp':
            # Direct
            self.encoder = nn.Sequential(
                CustomLinear(self.step, self.step),
                nn.ReLU(),
                CustomLinear(self.step, self.step),
                nn.ReLU(),
                CustomLinear(self.step, self.step),
                nn.ReLU(),
                CustomLinear(self.step, self.step),
            )
        elif self.n_encode_type == 'att':
            # -> SE block
            self.encoder = nn.Sequential(
                nn.Linear(self.step, self.step),
                nn.ReLU(),
                nn.Linear(self.step, self.step),
                nn.ReLU(),
                nn.Linear(self.step, self.step),
                nn.Sigmoid()
            )
        elif self.n_encode_type == 'conv':
            self.encoder = nn.Sequential(
                nn.Linear(self.step, self.step),
                nn.ReLU(),
                nn.Linear(self.step, self.step),
            )
            # self.init_weight()
        else:
            raise NotImplementedError('Unrecognizable categories {}.'.format(self.n_encode_type))
        self.saved_mem = 0.
    def init_weight(self):
        # NOTE(review): targets nn.Conv1d, but every encoder above is built
        # from Linear layers, so this is currently a no-op -- confirm.
        for mod in self.encoder.modules():
            if isinstance(mod, nn.Conv1d):
                mod.weight.data[:, :, 4] = 1. / mod.weight.shape[0]
                mod.weight.data[:, :, [0, 1, 2, 3, 5, 6, 7, 8]] = 0.
                mod.bias.data[:] = 0.
    def forward(self, inputs):  # (t b) c w h
        # Move the step axis to the last dimension for the encoder.
        if self.n_encode_type != 'conv':
            x = rearrange(inputs, '(t b) ... -> b ... t', t=self.step)
        else:
            c, w, h = inputs.shape[1:]
            x = rearrange(inputs, '(t b) c w h -> (b c w h) 1 t', t=self.step)
        if self.n_encode_type != 'att':
            x = self.encoder(x)  # Direct
        else:
            x = x * self.encoder(x)  # SE Block
        # Restore the original (t b) ... layout.
        if self.n_encode_type != 'conv':
            x = rearrange(x, 'b ... t -> (t b) ...')
        else:
            x = rearrange(x, '(b c w h) 1 t -> (t b) c w h', c=c, w=w, h=h)
        # self.spike = self.act_fun(x - 0.5)
        # # print(self.spike.mean())
        # # print(self.requires_fp)
        # if self.requires_fp:
        #     spike = rearrange(self.spike, '(t b) c w h -> t b c w h', t=self.step)
        #     for t in range(self.step):
        #         # print(t, float(spike[t].mean()), float(spike[t].std()))
        #         self.feature_map.append(spike[t])
        # self.saved_mem = x
        # return self.spike
        # Encoded input goes through the regular LIF forward.
        return super().forward(x)
    # def get_thres(self):
    #     mem_relu = F.relu(self.mem.detach())
    #     return mem_relu[mem_relu > 0.].median()
    def n_reset(self):
        super().n_reset()
        self.saved_mem = 0.
================================================
FILE: braincog/base/strategy/LateralInhibition.py
================================================
import warnings
import torch
from torch import nn
import torch.nn.functional as F
class LateralInhibition(nn.Module):
    """
    Lateral inhibition: neurons that fire suppress the membrane potential of
    the other neurons in the same layer (applied directly on ``node.mem``).

    :param node: neuron node whose ``mem`` attribute is inhibited in place
    :param inh: inhibition strength
    :param mode: "constant", "max" or "threshold" scaling of the inhibition
    """
    def __init__(self, node, inh, mode="constant"):
        super().__init__()
        self.inh = inh
        self.node = node
        self.mode = mode
    def forward(self, x: torch.Tensor, xori=None):
        # x: [N, C, W, H]; returns x unchanged, mutates self.node.mem.
        winner = x.max(1, True)[0]
        if self.mode == "constant":
            self.node.mem = self.node.mem - self.inh * (winner - x)
        elif self.mode == "max":
            # Inhibition scaled by the (detached) per-sample maximum of xori.
            scale = xori.max(1, True)[0].detach()
            self.node.mem = self.node.mem - self.inh * scale * (winner - x)
        elif self.mode == "threshold":
            self.node.mem = self.node.mem - self.inh * self.node.threshold * (winner - x)
        # Unknown modes leave the membrane untouched.
        return x
================================================
FILE: braincog/base/strategy/__init__.py
================================================
__all__ = ['surrogate', 'LateralInhibition']
from . import (
surrogate,
LateralInhibition
)
================================================
FILE: braincog/base/strategy/surrogate.py
================================================
import math
import torch
from torch import nn
from torch.nn import functional as F
def heaviside(x):
    """Step function: 1 where ``x >= 0``, else 0, returned in ``x``'s dtype."""
    return torch.ge(x, 0.).to(x.dtype)
class SurrogateFunctionBase(nn.Module):
    """
    Base class for surrogate gradient functions.

    :param alpha: shape parameter forwarded to the concrete surrogate.
    :param requires_grad: whether ``alpha`` is trainable, default ``True``.
    """
    def __init__(self, alpha, requires_grad=True):
        super().__init__()
        alpha_tensor = torch.tensor(alpha, dtype=torch.float)
        self.alpha = nn.Parameter(alpha_tensor, requires_grad=requires_grad)
    @staticmethod
    def act_fun(x, alpha):
        """
        Concrete surrogate; must be overridden.

        :param x: membrane-potential input
        :param alpha: shape parameter, may be ``NoneType``
        :return: spike after firing, values in ``[0, 1]``
        """
        raise NotImplementedError
    def forward(self, x):
        """
        :param x: membrane-potential input
        :return: spike after firing
        """
        return self.act_fun(x, self.alpha)
'''
sigmoid surrogate function.
'''
class sigmoid(torch.autograd.Function):
    """
    Heaviside spike with a sigmoid surrogate gradient.

    Forward:
    .. math::
        g(x) = \\mathrm{sigmoid}(\\alpha x) = \\frac{1}{1+e^{-\\alpha x}}
    Backward:
    .. math::
        g'(x) = \\alpha * (1 - \\mathrm{sigmoid} (\\alpha x)) \\mathrm{sigmoid} (\\alpha x)
    """
    @staticmethod
    def forward(ctx, x, alpha):
        if x.requires_grad:
            ctx.save_for_backward(x)
            ctx.alpha = alpha
        return heaviside(x)
    @staticmethod
    def backward(ctx, grad_output):
        if not ctx.needs_input_grad[0]:
            return None, None
        sig = torch.sigmoid(ctx.alpha * ctx.saved_tensors[0])
        return grad_output * sig * (1 - sig) * ctx.alpha, None
class SigmoidGrad(SurrogateFunctionBase):
    """Module wrapper applying :class:`sigmoid` as the surrogate function."""
    def __init__(self, alpha=1., requires_grad=False):
        super().__init__(alpha, requires_grad)
    @staticmethod
    def act_fun(x, alpha):
        return sigmoid.apply(x, alpha)
'''
atan surrogate function.
'''
class atan(torch.autograd.Function):
    """
    Heaviside spike with an arctan surrogate gradient.

    Forward:
    .. math::
        g(x) = \\frac{1}{\\pi} \\arctan(\\frac{\\pi}{2}\\alpha x) + \\frac{1}{2}
    Backward:
    .. math::
        g'(x) = \\frac{\\alpha}{2(1 + (\\frac{\\pi}{2}\\alpha x)^2)}
    Also returns the gradient w.r.t. ``alpha`` when it is trainable.
    """
    @staticmethod
    def forward(ctx, inputs, alpha):
        ctx.save_for_backward(inputs, alpha)
        return inputs.gt(0.).float()
    @staticmethod
    def backward(ctx, grad_output):
        saved_x, saved_alpha = ctx.saved_tensors
        common = grad_output / (1 + (saved_alpha * math.pi / 2 * saved_x).square())
        grad_x = saved_alpha / 2 * common if ctx.needs_input_grad[0] else None
        grad_alpha = (saved_x / 2 * common).sum() if ctx.needs_input_grad[1] else None
        return grad_x, grad_alpha
class AtanGrad(SurrogateFunctionBase):
    """Module wrapper applying :class:`atan` as the surrogate function."""
    def __init__(self, alpha=2., requires_grad=True):
        super().__init__(alpha, requires_grad)
    @staticmethod
    def act_fun(x, alpha):
        return atan.apply(x, alpha)
'''
gate surrogate fucntion.
'''
class gate(torch.autograd.Function):
    """
    Heaviside spike with a rectangular ("gate") surrogate gradient.

    Forward:
    .. math::
        g(x) = \\mathrm{NonzeroSign}(x) \\log (|\\alpha x| + 1)
    Backward:
    .. math::
        g'(x) = \\frac{\\alpha}{1 + |\\alpha x|} = \\frac{1}{\\frac{1}{\\alpha} + |x|}
    """
    @staticmethod
    def forward(ctx, x, alpha):
        if x.requires_grad:
            window = torch.where(x.abs() < 1. / alpha, torch.ones_like(x), torch.zeros_like(x))
            ctx.save_for_backward(window)
        return x.gt(0).float()
    @staticmethod
    def backward(ctx, grad_output):
        if not ctx.needs_input_grad[0]:
            return None, None
        (window,) = ctx.saved_tensors
        return grad_output * window, None
class GateGrad(SurrogateFunctionBase):
    """Module wrapper applying :class:`gate` as the surrogate function."""
    def __init__(self, alpha=2., requires_grad=False):
        super().__init__(alpha, requires_grad)
    @staticmethod
    def act_fun(x, alpha):
        return gate.apply(x, alpha)
'''
gatquadratic_gate surrogate function.
'''
class quadratic_gate(torch.autograd.Function):
    """
    Heaviside spike with a quadratic-gate surrogate gradient.

    Forward:
    .. math::
        g(x) =
        \\begin{cases}
        0, & x < -\\frac{1}{\\alpha} \\\\
        -\\frac{1}{2}\\alpha^2|x|x + \\alpha x + \\frac{1}{2}, & |x| \\leq \\frac{1}{\\alpha} \\\\
        1, & x > \\frac{1}{\\alpha} \\\\
        \\end{cases}
    Backward:
    .. math::
        g'(x) =
        \\begin{cases}
        0, & |x| > \\frac{1}{\\alpha} \\\\
        -\\alpha^2|x|+\\alpha, & |x| \\leq \\frac{1}{\\alpha}
        \\end{cases}
    """
    @staticmethod
    def forward(ctx, x, alpha):
        if x.requires_grad:
            # slope = alpha - alpha^2 * |x|, zeroed outside |x| <= 1/alpha
            slope = alpha - alpha * alpha * x.abs()
            slope.masked_fill_(x.abs() > 1 / alpha, 0)
            ctx.save_for_backward(slope)
        return x.gt(0.).float()
    @staticmethod
    def backward(ctx, grad_output):
        if not ctx.needs_input_grad[0]:
            return None, None
        (slope,) = ctx.saved_tensors
        return grad_output * slope, None
class QGateGrad(SurrogateFunctionBase):
    """Module wrapper applying :class:`quadratic_gate` as the surrogate function."""
    def __init__(self, alpha=2., requires_grad=False):
        super().__init__(alpha, requires_grad)
    @staticmethod
    def act_fun(x, alpha):
        return quadratic_gate.apply(x, alpha)
class relu_like(torch.autograd.Function):
    """
    Heaviside spike with a ReLU-like surrogate gradient:
    d/dx = alpha * 1[x > 0]; d/dalpha = sum(grad * relu(x)).
    """
    @staticmethod
    def forward(ctx, x, alpha):
        if x.requires_grad:
            ctx.save_for_backward(x, alpha)
        return heaviside(x)
    @staticmethod
    def backward(ctx, grad_output):
        x, alpha = ctx.saved_tensors
        grad_x = grad_output * x.gt(0.).float() * alpha if ctx.needs_input_grad[0] else None
        grad_alpha = (grad_output * F.relu(x)).sum() if ctx.needs_input_grad[1] else None
        return grad_x, grad_alpha
class RoundGrad(nn.Module):
    """Ceiling quantizer with a straight-through gradient; input clamped to [-0.5, 4.5]."""
    def __init__(self, **kwargs):
        super(RoundGrad, self).__init__()
        self.act = nn.Hardtanh(-.5, 4.5)
    def forward(self, x):
        clamped = self.act(x)
        # ceil for the forward value, identity for the gradient
        return clamped.ceil() + clamped - clamped.detach()
class ReLUGrad(SurrogateFunctionBase):
    """
    ReLU used as the surrogate gradient; mainly for testing an ANN of the
    same architecture.
    """
    def __init__(self, alpha=2., requires_grad=False):
        super().__init__(alpha, requires_grad)
    @staticmethod
    def act_fun(x, alpha):
        return relu_like.apply(x, alpha)
'''
Straight-Through (ST) Estimator
'''
class straight_through_estimator(torch.autograd.Function):
    """
    Straight-Through estimator: Heaviside forward, identity backward.
    http://arxiv.org/abs/1308.3432
    """
    @staticmethod
    def forward(ctx, inputs):
        out = heaviside(inputs)
        ctx.save_for_backward(out)
        return out
    @staticmethod
    def backward(ctx, grad_output):
        if ctx.needs_input_grad[0]:
            return grad_output
        return None
class stdp(torch.autograd.Function):
    """Spike function whose backward scales the gradient by the spike output itself."""
    @staticmethod
    def forward(ctx, inputs):
        out = inputs.gt(0.).float()
        ctx.save_for_backward(out)
        return out
    @staticmethod
    def backward(ctx, grad_output):
        (out,) = ctx.saved_tensors
        return out * grad_output
class STDPGrad(SurrogateFunctionBase):
    """Module wrapper applying :class:`stdp` as the surrogate function (alpha unused)."""
    def __init__(self, alpha=2., requires_grad=False):
        super().__init__(alpha, requires_grad)
    @staticmethod
    def act_fun(x, alpha):
        return stdp.apply(x)
class backeigate(torch.autograd.Function):
    """Heaviside forward; rectangular-window gradient of half-width 0.5."""
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.gt(0.).float()
    @staticmethod
    def backward(ctx, grad_output):
        (saved,) = ctx.saved_tensors
        window = saved.abs() < 0.5
        return grad_output.clone() * window.float()
class BackEIGateGrad(SurrogateFunctionBase):
    """Module wrapper applying :class:`backeigate` as the surrogate function (alpha unused)."""
    def __init__(self, alpha=2., requires_grad=False):
        super().__init__(alpha, requires_grad)
    @staticmethod
    def act_fun(x, alpha):
        return backeigate.apply(x)
class ei(torch.autograd.Function):
    """
    Signed (excitatory/inhibitory) spike: forward is sign(x) in {-1, 0, 1};
    backward is a rectangular-window gradient of half-width 0.5.
    """
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return torch.sign(input).float()
    @staticmethod
    def backward(ctx, grad_output):
        (saved,) = ctx.saved_tensors
        window = saved.abs() < 0.5
        return grad_output.clone() * window.float()
class EIGrad(SurrogateFunctionBase):
    """Module wrapper applying :class:`ei` as the surrogate function (alpha unused)."""
    def __init__(self, alpha=2., requires_grad=False):
        super().__init__(alpha, requires_grad)
    @staticmethod
    def act_fun(x, alpha):
        return ei.apply(x)
================================================
FILE: braincog/base/utils/__init__.py
================================================
from .criterions import UnilateralMse, MixLoss
from .visualization import plot_tsne, plot_tsne_3d, plot_confusion_matrix
from torch.autograd import Variable
import torch
__all__ = [
'UnilateralMse', 'MixLoss',
'plot_tsne', 'plot_tsne_3d', 'plot_confusion_matrix', 'drop_path'
]
def drop_path(x, drop_prob):
    """
    DropPath / stochastic depth: zero out a sample's entire feature map with
    probability ``drop_prob`` and rescale survivors by ``1 / keep_prob`` so
    the expected activation is unchanged. Modifies ``x`` in place.

    :param x: input tensor of shape (batch, C, H, W)
    :param drop_prob: probability of dropping each sample, in [0, 1)
    :return: the (possibly modified) input tensor
    """
    if drop_prob > 0.:
        keep_prob = 1. - drop_prob
        # BUG-FIX: build the Bernoulli mask on x's own device. The original
        # used torch.cuda.FloatTensor (CUDA-only, crashes on CPU) wrapped in
        # the long-deprecated autograd.Variable.
        mask = torch.empty(x.size(0), 1, 1, 1, device=x.device).bernoulli_(keep_prob)
        x.div_(keep_prob)
        x.mul_(mask)
    return x
================================================
FILE: braincog/base/utils/criterions.py
================================================
import numpy as np
import torch
import torch.nn.functional as F
class UnilateralMse(torch.nn.Module):
    """
    One-sided MSE loss: outputs are clipped at ``thresh`` before the MSE, so
    output-layer firing rates above the desired rate are not penalized.

    :param thresh: desired firing rate of the output layer
    """
    def __init__(self, thresh=1.):
        super(UnilateralMse, self).__init__()
        self.thresh = thresh
        self.loss = torch.nn.MSELoss()
    def forward(self, x, target):
        # x = nn.functional.softmax(x, dim=1)
        # BUG-FIX: torch.clip is not in-place; the original discarded the
        # clipped result, so the cap was never applied.
        x = torch.clip(x, max=self.thresh)
        if x.shape == target.shape:
            return self.loss(x, target)
        # Class-index targets: compare against a one-hot target scaled to thresh.
        return self.loss(x, torch.zeros_like(x).scatter_(1, target.view(-1, 1), self.thresh))
class MixLoss(torch.nn.Module):
    """
    Blend of an arbitrary loss with UnilateralMse:
    ``0.1 * ce_loss(x, t) + UnilateralMse(1.)(x, t)``.

    :param ce_loss: any callable loss ``(x, target) -> scalar``
    """
    def __init__(self, ce_loss):
        super(MixLoss, self).__init__()
        self.ce = ce_loss
        self.mse = UnilateralMse(1.)
    def forward(self, x, target):
        ce_term = self.ce(x, target)
        mse_term = self.mse(x, target)
        return 0.1 * ce_term + mse_term
class TetLoss(torch.nn.Module):
    """
    Temporal-efficient-training style loss: average of the wrapped loss over
    the first (time) axis of ``x``.

    :param loss_fn: per-step loss ``(logit, target) -> scalar``
    """
    def __init__(self, loss_fn):
        super(TetLoss, self).__init__()
        self.loss_fn = loss_fn
    def forward(self, x, target):
        per_step = [self.loss_fn(logit, target) for logit in x]
        return sum(per_step) / x.shape[0]
class OnehotMse(torch.nn.Module):
    """
    MSE against a one-hot encoding of integer class targets; used in SNNs
    with a voting layer.

    :param num_class: number of classes for the one-hot encoding
    """
    def __init__(self, num_class):
        super(OnehotMse, self).__init__()
        self.num_class = num_class
        self.loss_fn = torch.nn.MSELoss()
    def forward(self, x, target):
        onehot = F.one_hot(target.to(torch.int64), self.num_class).float()
        return self.loss_fn(x, onehot)
================================================
FILE: braincog/base/utils/visualization.py
================================================
# encoding: utf-8
# Author : Floyed
# Datetime : 2022/7/1 11:10
# User : Floyed
# Product : PyCharm
# Project : braincog
# File : visualization.py
# explain : add t-SNE
import os
import numpy as np
import sklearn
from sklearn.manifold import TSNE
from sklearn.metrics import confusion_matrix
import torch
import torch.nn.functional as F
from einops import rearrange
import matplotlib.pyplot as plt
import matplotlib.patheffects as PathEffects
import matplotlib
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d import proj3d
import seaborn as sns
# Random state.
RS = 20150101
def spike_rate_vis_1d(data, output_dir=''):
    """
    Show a heatmap of a (t, c) spike-rate array, with channels on the y-axis.
    :param data: tensor or array of shape (t, c)
    :param output_dir: currently unused
    :return: None
    """
    assert len(data.shape) == 2, 'Shape should be (t, c).'
    data = rearrange(data, 'i j -> j i')
    if isinstance(data, torch.Tensor):
        data = data.to('cpu').numpy()
    plt.figure(figsize=(8, 8))
    sns.heatmap(data, annot=None, cmap='YlGnBu')
    # plt.ylim(0, _max + 1)
    # NOTE(review): these axis labels look copy-pasted from the
    # confusion-matrix plot; they are misleading for a spike-rate heatmap.
    plt.xlabel('Predicted labels')
    plt.ylabel('True labels')
    plt.show()
def spike_rate_vis(data, output_dir=''):
    """
    Show a heatmap of a (t, r, c) spike array averaged over the time axis.
    :param data: tensor or array of shape (t, r, c)
    :param output_dir: currently unused
    :return: None
    """
    assert len(data.shape) == 3, 'Shape should be (t, r, c).'
    data = data.mean(axis=0)
    if isinstance(data, torch.Tensor):
        data = data.to('cpu').numpy()
    plt.figure(figsize=(8, 8))
    sns.heatmap(data, annot=None, cmap='YlGnBu')
    # plt.ylim(0, _max + 1)
    # NOTE(review): axis labels look copy-pasted from the confusion-matrix
    # plot; misleading for a spike-rate heatmap.
    plt.xlabel('Predicted labels')
    plt.ylabel('True labels')
    plt.show()
def plot_mem_distribution(data,
                          output_dir='',
                          legend='',
                          xlabel='Membrane Potential',
                          ylabel='Density',
                          **kwargs):
    """
    Plot a density histogram of membrane-potential values, discarding
    outliers beyond three standard deviations, and optionally save it.
    :param data: tensor or 1-D array of membrane potentials
    :param output_dir: image output path; if empty the figure is not saved
    :param legend: currently unused
    :param xlabel: x-axis label
    :param ylabel: y-axis label
    :param kwargs: forwarded to ``plt.hist``
    :return: None
    """
    # print(type(data), len(data))
    if isinstance(data, torch.Tensor):
        data = data.reshape(-1).to('cpu').numpy()
    # Trim values outside mean +/- 3*std before plotting.
    mean = data.mean()
    std = data.std()
    idx = np.argwhere(data < mean - 3 * std)
    data = np.delete(data, idx)
    idx = np.argwhere(data > mean + 3 * std)
    data = np.delete(data, idx)
    sns.set_style('darkgrid')
    # sns.set_palette('deep', desat=.6)
    sns.set_context("notebook", font_scale=1.5,
                    rc={"lines.linewidth": 2.5})
    # fig = plt.figure(figsize=(8, 8))
    # ax = fig.add_subplot(111, aspect='equal')
    # sns.distplot(data, bins=int(np.sqrt(data.shape[0])),
    #              hist=True, kde=False, hist_kws={'histtype': 'stepfilled'}, **kwargs)
    # print('hist begin')
    print(len(data))
    n, bins, patches = plt.hist(data,
                                density=True,
                                histtype='stepfilled',
                                alpha=0.618,
                                bins=int(np.sqrt(data.shape[0])),
                                **kwargs)
    # print('hist finished')
    # sns.kdeplot(data, color='#5294c3')
    # print('kde finished')
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    # if legend != '':
    #     plt.legend(legend)
    # ax.axis('tight')
    if output_dir != '':
        plt.savefig(output_dir, bbox_inches='tight')
        print('{} saved'.format(output_dir))
    # plt.show()
def plot_tsne(x, colors, output_dir="", num_classes=None):
    """
    Plot a 2-D t-SNE embedding and save the figure to ``output_dir``.
    :param x: feature map / spike tensor or array of shape (n_samples, n_features)
    :param colors: predicted labels, used to color the classes
    :param output_dir: image output path (including file name and extension)
    :param num_classes: number of classes; inferred from ``colors`` when None
    :return: None
    """
    if isinstance(x, torch.Tensor):
        x = x.to('cpu').numpy()
    if isinstance(colors, torch.Tensor):
        colors = colors.to('cpu').numpy()
    if num_classes is None:
        num_classes = colors.max() + 1
    x = TSNE(random_state=RS, n_components=2).fit_transform(x)
    sns.set_style('darkgrid')
    sns.set_palette('muted')
    sns.set_context("notebook", font_scale=1.5,
                    rc={"lines.linewidth": 2.5})
    palette = np.array(sns.color_palette("hls", num_classes))
    fig = plt.figure(figsize=(8, 8))
    ax = fig.add_subplot(111, aspect='equal')
    # BUG-FIX: np.int was removed in NumPy >= 1.24; the builtin int is the
    # documented equivalent.
    sc = ax.scatter(x[:, 0], x[:, 1], lw=0, s=25,
                    c=palette[colors.astype(int)])
    # plt.xlim(-25, 25)
    # plt.ylim(-25, 25)
    # ax.axis('off')
    ax.axis('tight')
    # plt.grid('off')
    plt.savefig(output_dir, facecolor=fig.get_facecolor(), bbox_inches='tight')
    # plt.show()
def plot_tsne_3d(x, colors, output_dir="", num_classes=None):
    """
    Plot a 3-D t-SNE embedding and save the figure to the output path.
    :param x: feature map / spike tensor or array of shape (n_samples, n_features)
    :param colors: predicted labels, used to color the classes
    :param output_dir: image output path (including file name and extension)
    :param num_classes: number of classes; inferred from ``colors`` when None
    :return: None
    """
    if isinstance(x, torch.Tensor):
        x = x.to('cpu').numpy()
    if isinstance(colors, torch.Tensor):
        colors = colors.to('cpu').numpy()
    if num_classes is None:
        num_classes = colors.max() + 1
    x = TSNE(random_state=RS, n_components=3, perplexity=30).fit_transform(x)
    # sns.set_style('darkgrid')
    sns.set_palette('muted')
    sns.set_context("notebook", font_scale=1.5,
                    rc={"lines.linewidth": 2.5})
    fig = plt.figure(figsize=(8, 8))
    palette = np.array(sns.color_palette("hls", num_classes))
    ax = fig.add_subplot(111, projection='3d')
    # BUG-FIX: np.int was removed in NumPy >= 1.24; use the builtin int.
    sc = ax.scatter(x[:, 0], x[:, 1], x[:, 2], lw=0, s=20, alpha=0.8,
                    c=palette[colors.astype(int)])
    # ax.set_xlabel('X')
    # ax.set_ylabel('Y')
    # ax.set_zlabel('Z')
    # ax.view_init(20, -120)
    ax.axis('tight')
    plt.savefig(output_dir, facecolor=fig.get_facecolor(), bbox_inches='tight')
    # plt.show()
def plot_confusion_matrix(logits, labels, output_dir):
    """
    Plot a row-normalized confusion matrix and save it to ``output_dir``.

    :param logits: predicted logits, shape (N, num_classes)
    :param labels: ground-truth labels, shape (N,)
    :param output_dir: output path, including file name and suffix
    :return: None
    """
    sns.set_style('darkgrid')
    sns.set_palette('Blues_r')
    sns.set_context("notebook", font_scale=1.,
                    rc={"lines.linewidth": 2.})
    preds = logits.argmax(dim=1).cpu()
    labels = labels.cpu()
    _max = labels.max()
    # Per-cell annotations become unreadable beyond ~10 classes.
    annot = bool(_max <= 10)
    conf = confusion_matrix(labels, preds)
    # Normalize each row so every true class sums to 1, rounded for display.
    normed = np.around(conf.astype('float') / conf.sum(axis=1)[:, np.newaxis], decimals=2)
    plt.figure(figsize=(8, 8))
    sns.heatmap(normed, annot=annot, cmap='Blues')
    plt.ylim(0, _max + 1)
    plt.xlabel('Predicted labels')
    plt.ylabel('True labels')
    plt.savefig(output_dir, bbox_inches='tight')
if __name__ == '__main__':
    # Manual smoke tests for the plotting helpers in this module; only the
    # membrane-potential histogram test is currently enabled.
    # Test for T-SNE
    # x = torch.randn((100, 100))
    # y = torch.randint(low=0, high=10, size=[100])
    # plot_tsne_3d(x, y, output_dir='./t-sne.eps')
    # Test for confusion matrix
    # x = torch.rand(5012, 100)
    # y = torch.randint(0, 100, (5012,))
    # plot_confusion_matrix(x, y, '')
    # Test for Mem Distribution
    x = torch.randn(100000)
    plot_mem_distribution(x, legend=['test'])
================================================
FILE: braincog/datasets/CUB2002011.py
================================================
import os
import pandas as pd
from torchvision.datasets import VisionDataset
from torchvision.datasets.folder import default_loader
from torchvision.datasets.utils import download_file_from_google_drive
class CUB2002011(VisionDataset):
    """CUB-200-2011 (Caltech-UCSD Birds) Dataset.

    Args:
        root (string): Root directory of the dataset.
        train (bool, optional): If True, creates dataset from training set, otherwise
            creates from test set.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """

    # Relative path of the extracted image folder.
    base_folder = 'CUB_200_2011/images'
    # url = 'http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz'
    # Google-Drive id / name / md5 of the official archive.
    file_id = '1hbzc_P1FuxMkcabkgn9ZKinBwW683j45'
    filename = 'CUB_200_2011.tgz'
    tgz_md5 = '97eceeb196236b17998738112f37df78'

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        super(CUB2002011, self).__init__(root, transform=transform, target_transform=target_transform)
        self.loader = default_loader
        self.train = train
        if download:
            self._download()
        # _check_integrity also loads the metadata into self.data as a side effect.
        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted. You can use download=True to download it')

    def _load_metadata(self):
        """Read the annotation text files and build ``self.data`` for the chosen split."""
        images = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'images.txt'), sep=' ',
                             names=['img_id', 'filepath'])
        image_class_labels = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'image_class_labels.txt'),
                                         sep=' ', names=['img_id', 'target'])
        train_test_split = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'train_test_split.txt'),
                                       sep=' ', names=['img_id', 'is_training_img'])
        data = images.merge(image_class_labels, on='img_id')
        self.data = data.merge(train_test_split, on='img_id')
        class_names = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'classes.txt'),
                                  sep=' ', names=['class_name'], usecols=[1])
        self.class_names = class_names['class_name'].to_list()
        # Keep only the rows belonging to the requested split.
        if self.train:
            self.data = self.data[self.data.is_training_img == 1]
        else:
            self.data = self.data[self.data.is_training_img == 0]

    def _check_integrity(self):
        """Return True when the metadata loads and every listed image file exists on disk."""
        try:
            self._load_metadata()
        except Exception:
            return False
        for index, row in self.data.iterrows():
            filepath = os.path.join(self.root, self.base_folder, row.filepath)
            if not os.path.isfile(filepath):
                print(filepath)
                return False
        return True

    def _download(self):
        """Download and extract the archive unless a valid copy is already present."""
        import tarfile
        if self._check_integrity():
            print('Files already downloaded and verified')
            return
        download_file_from_google_drive(self.file_id, self.root, self.filename, self.tgz_md5)
        # NOTE(review): extractall trusts the archive's member paths; fine for the
        # official md5-checked archive, unsafe for untrusted tarballs.
        with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar:
            tar.extractall(path=self.root)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data.iloc[idx]
        path = os.path.join(self.root, self.base_folder, sample.filepath)
        target = sample.target - 1  # Targets start at 1 by default, so shift to 0
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target
if __name__ == '__main__':
    # Quick manual check that both splits load from a local copy of the dataset.
    train_dataset = CUB2002011('./cub2011', train=True, download=False)
    test_dataset = CUB2002011('./cub2011', train=False, download=False)
================================================
FILE: braincog/datasets/ESimagenet/ES_imagenet.py
================================================
# -*- coding: utf-8 -*-
# Time : 2022/11/1 11:06
# Author : Regulus
# FileName: ES_imagenet.py
# Explain:
# Software: PyCharm
import numpy as np
import torch
import linecache
import torch.utils.data as data
class ESImagenet_Dataset(data.Dataset):
    """
    ES-ImageNet event dataset.

    Each sample is an ``.npz`` file holding positive/negative event coordinates;
    events are rasterized into a ``[2, 8, 224, 224]`` tensor
    (polarity, time step, H, W).

    :param mode: 'train' to read the training split, anything else reads 'val'
    :param data_set_path: dataset root containing ``train``/``val`` folders and
        ``trainlabel.txt`` / ``vallabel.txt``
    :param transform: optional transform applied to the rasterized tensor
    """

    def __init__(self, mode, data_set_path='/data/dvsimagenet/', transform=None):
        super().__init__()
        self.mode = mode
        self.filenames = []
        self.trainpath = data_set_path + 'train'
        self.testpath = data_set_path + 'val'
        self.traininfotxt = data_set_path + 'trainlabel.txt'
        self.testinfotxt = data_set_path + 'vallabel.txt'
        self.formats = '.npz'
        self.transform = transform
        if mode == 'train':
            self.path = self.trainpath
            infotxt = self.traininfotxt
        else:
            self.path = self.testpath
            infotxt = self.testinfotxt
        # Use a context manager so the label file is always closed
        # (the original left the handle open).
        with open(infotxt, 'r') as infofile:
            for line in infofile:
                filename, classnum, a, b = line.split()
                realname, sub = filename.split('.')
                self.filenames.append(realname + self.formats)

    def __getitem__(self, index):
        """Rasterize and return (events tensor [2, 8, 224, 224], label tensor [1])."""
        if self.mode == 'train':
            info = linecache.getline(self.traininfotxt, index + 1)
        else:
            info = linecache.getline(self.testinfotxt, index + 1)
        filename, classnum, a, b = info.split()
        realname, sub = filename.split('.')
        filename = self.path + r'/' + realname + self.formats
        classnum = int(classnum)
        a = int(a)
        b = int(b)
        # Load the archive once instead of twice, and close it afterwards.
        with np.load(filename) as archive:
            datapos = archive['pos']
            dataneg = archive['neg']
        # Offsets that center the (a, b)-sized recording on the 256x256 canvas.
        dy = (254 - b) // 2
        dx = (254 - a) // 2
        input = torch.zeros([2, 8, 256, 256])
        # Index arrays must be integral: the original cast the coordinates to
        # float64, which PyTorch rejects as tensor indices.
        x = datapos[:, 0].astype(np.int64) + dx
        y = datapos[:, 1].astype(np.int64) + dy
        t = datapos[:, 2].astype(np.int64) - 1
        input[0, t, x, y] = 1
        x = dataneg[:, 0].astype(np.int64) + dx
        y = dataneg[:, 1].astype(np.int64) + dy
        t = dataneg[:, 2].astype(np.int64) - 1
        input[1, t, x, y] = 1
        # Center-crop to 224x224 (the original's permute(0, 1, 2, 3) was a no-op).
        reshape = input[:, :, 16:240, 16:240].contiguous()
        if self.transform is not None:
            reshape = self.transform(reshape)
        label = torch.tensor([classnum])
        return reshape, label

    def __len__(self):
        return len(self.filenames)
================================================
FILE: braincog/datasets/ESimagenet/__init__.py
================================================
# -*- coding: utf-8 -*-
# Time : 2022/11/1 11:05
# Author : Regulus
# FileName: __init__.py.py
# Explain:
# Software: PyCharm
"""
from: https://github.com/lyh983012/ES-imagenet-master
"""
__all__ = ['ES_imagenet', 'reconstructed_ES_imagenet']
from . import (
ES_imagenet,
reconstructed_ES_imagenet
)
================================================
FILE: braincog/datasets/ESimagenet/reconstructed_ES_imagenet.py
================================================
# -*- coding: utf-8 -*-
# Time : 2022/11/1 11:06
# Author : Regulus
# FileName: reconstructed_ES_imagenet.py
# Explain:
# Software: PyCharm
import numpy as np
import torch
import linecache
import torch.utils.data as data
from tqdm import tqdm
class ESImagenet2D_Dataset(data.Dataset):
    """
    Reconstructed (grayscale) ES-ImageNet dataset.

    Events are rasterized per time step, then the positive-minus-negative
    frames are accumulated along a fixed per-step pixel trace to produce a
    single ``[1, 1, 224, 224]`` gray image.

    :param mode: 'train' to read the training split, anything else reads 'val'
    :param data_set_path: dataset root containing ``train``/``val`` folders and
        ``trainlabel.txt`` / ``vallabel.txt``
    :param transform: kept for interface compatibility; currently unused
    """

    def __init__(self, mode, data_set_path='/data/ESimagenet-0.18/', transform=None):
        super().__init__()
        self.mode = mode
        self.filenames = []
        self.trainpath = data_set_path + 'train'
        self.testpath = data_set_path + 'val'
        self.traininfotxt = data_set_path + 'trainlabel.txt'
        self.testinfotxt = data_set_path + 'vallabel.txt'
        self.formats = '.npz'
        self.transform = transform
        if mode == 'train':
            self.path = self.trainpath
            infotxt = self.traininfotxt
        else:
            self.path = self.testpath
            infotxt = self.testinfotxt
        # Read the label file once inside a context manager (the original opened
        # it twice and never closed either handle).
        with open(infotxt, 'r') as infofile:
            self.infolist = infofile.readlines()
        for line in self.infolist:
            filename, classnum, a, b = line.split()
            realname, sub = filename.split('.')
            self.filenames.append(realname + self.formats)

    def __getitem__(self, index):
        """Return (gray image tensor [1, 1, 224, 224], int label)."""
        info = self.infolist[index]
        filename, classnum, a, b = info.split()
        realname, sub = filename.split('.')
        filename = self.path + r'/' + realname + self.formats
        classnum = int(classnum)
        a = int(a)
        b = int(b)
        with open(filename, "rb") as f:
            archive = np.load(f)
            datapos = archive['pos']
            dataneg = archive['neg']
        # Per-time-step pixel offsets applied while accumulating the frames.
        tracex = [0, 2, 1, 0, 2, 1, 1, 2]
        tracey = [2, 1, 0, 1, 2, 0, 1, 1]
        # Offsets that center the (a, b)-sized recording on the 256x256 canvas.
        dy = (254 - b) // 2
        dx = (254 - a) // 2
        input = torch.zeros([2, 8, 256, 256])
        # Index arrays must be integral: the original cast the coordinates to
        # float64, which PyTorch rejects as tensor indices.
        # NOTE(review): advanced-indexing `+=` does not accumulate duplicate
        # (t, x, y) events — duplicates count once; confirm this is intended.
        x = datapos[:, 0].astype(np.int64) + dx
        y = datapos[:, 1].astype(np.int64) + dy
        t = datapos[:, 2].astype(np.int64) - 1
        input[0, t, x, y] += 1
        x = dataneg[:, 0].astype(np.int64) + dx
        y = dataneg[:, 1].astype(np.int64) + dy
        t = dataneg[:, 2].astype(np.int64) - 1
        input[1, t, x, y] += 1
        sum_gray_data = torch.zeros([1, 1, 256, 256])
        reshape = input[:, :, 16:240, 16:240]
        H = 224
        W = 224
        # Accumulate (positive - negative) frames, shifted along the trace.
        for step in range(8):
            ox = tracex[step]
            oy = tracey[step]
            sum_gray_data[0, 0, 2 - ox:2 - ox + H, 2 - oy:2 - oy + W] += reshape[0, step, :, :]
            sum_gray_data[0, 0, 2 - ox:2 - ox + H, 2 - oy:2 - oy + W] -= reshape[1, step, :, :]
        sum_gray_data = sum_gray_data[:, :, 1:225, 1:225]
        label = classnum
        return sum_gray_data, label

    def __len__(self):
        return len(self.filenames)
================================================
FILE: braincog/datasets/NOmniglot/NOmniglot.py
================================================
from torch.utils.data import Dataset
from braincog.datasets.NOmniglot.utils import *
class NOmniglot(Dataset):
    """
    Neuromorphic Omniglot dataset.

    On first use (``create=True``) the raw aedat4 recordings are converted into
    per-sample ``.npy`` event files and then integrated into cached frame
    files; afterwards the cached frames are indexed and served.

    :param root: dataset root directory
    :param frames_num: number of frames each event stream is integrated into
    :param train: background (train) split when True, evaluation split otherwise
    :param data_type: 'event' or 'frequency' frame integration mode
    :param transform: optional transform applied to each sample tensor
    :param target_transform: optional transform applied to each label
    :param use_npz: read compressed '.npz' frame files instead of '.npy'
    :param crop: crop frames to the sensor region containing the symbol
    :param create: allow creating missing event/frame caches on disk
    :param thread_num: worker threads used while creating the frame cache
    """

    def __init__(self, root='data/', frames_num=12, train=True, data_type='event',
                 transform=None, target_transform=None, use_npz=False, crop=True, create=True, thread_num=16):
        super().__init__()
        self.crop = crop
        self.data_type = data_type
        self.use_npz = use_npz
        self.transform = transform
        self.target_transform = target_transform
        events_npy_root = os.path.join(root, 'events_npy', 'background' if train else "evaluation")
        frames_root = os.path.join(root, f'fnum_{frames_num}_dtype_{data_type}_npz_{use_npz}',
                                   'background' if train else "evaluation")
        # Build missing caches: events first (from aedat4), then frames (from events).
        if not os.path.exists(frames_root) and create:
            if not os.path.exists(events_npy_root) and create:
                os.makedirs(events_npy_root)
                print('creating event data..')
                convert_aedat4_dir_to_events_dir(root, train)
            else:
                print(f'npy format events data root {events_npy_root}, already exists')
            os.makedirs(frames_root)
            print('creating frames data..')
            convert_events_dir_to_frames_dir(events_npy_root, frames_root, '.npy', frames_num, data_type,
                                             thread_num=thread_num, compress=use_npz)
        else:
            print(f'frames data root {frames_root} already exists.')
        # datadict maps class index -> list of frame file paths; flatten it into
        # (path, label) pairs for integer indexing.
        self.datadict, self.num_classes = list_class_files(events_npy_root, frames_root, True, use_npz=use_npz)
        self.datalist = []
        for i in self.datadict:
            self.datalist.extend([(j, i) for j in self.datadict[i]])

    def __len__(self):
        return len(self.datalist)

    def __getitem__(self, index):
        image, label = self.datalist[index]
        image, label = self.readimage(image, label)
        return image, label

    def readimage(self, image, label):
        """Load one frame file and apply crop/transforms.

        ``image`` is a file path on input and a float tensor on return.
        """
        if self.use_npz:
            image = torch.tensor(np.load(image)['arr_0']).float()
        else:
            image = torch.tensor(np.load(image)).float()
        if self.crop:
            # Keep the fixed sensor region that contains the character.
            image = image[:, :, 4:254, 54:304]
        if self.transform is not None: image = self.transform(image)
        if self.target_transform is not None: label = self.target_transform(label)
        return image, label
================================================
FILE: braincog/datasets/NOmniglot/__init__.py
================================================
__all__ = ['NOmniglot', 'nomniglot_full', 'nomniglot_nw_ks','nomniglot_pair','utils']
from . import (
NOmniglot,
nomniglot_full,
nomniglot_nw_ks,
nomniglot_pair,
utils
)
================================================
FILE: braincog/datasets/NOmniglot/nomniglot_full.py
================================================
import torch
from torch.utils.data import Dataset, DataLoader
from braincog.datasets.NOmniglot.NOmniglot import NOmniglot
class NOmniglotfull(Dataset):
    """
    Treat few-shot N-Omniglot as an ordinary classification problem.

    The original background and evaluation sets are concatenated (evaluation
    labels shifted by 964) and 3/4 of every 20-sample class is used for
    training, the remaining 1/4 for testing.
    """

    def __init__(self, root='data/', train=True, frames_num=4, data_type='event',
                 transform=None, target_transform=None, use_npz=False, crop=True, create=True):
        super().__init__()
        background = NOmniglot(root=root, train=True, frames_num=frames_num, data_type=data_type,
                               transform=transform, target_transform=target_transform,
                               use_npz=use_npz, crop=crop, create=create)
        evaluation = NOmniglot(root=root, train=False, frames_num=frames_num, data_type=data_type,
                               transform=transform, target_transform=lambda x: x + 964,
                               use_npz=use_npz, crop=crop, create=create)
        self.data = torch.utils.data.ConcatDataset([background, evaluation])
        # Every class contributes 20 consecutive samples: the first 15 form the
        # training split, the last 5 the test split.
        keep = range(15) if train else range(15, 20)
        self.id = [idx for idx in range(len(self.data)) if idx % 20 in keep]

    def __len__(self):
        return len(self.id)

    def __getitem__(self, index):
        image, label = self.data[self.id[index]]
        return image, label
if __name__ == '__main__':
    db_train = NOmniglotfull('../../data/', train=True, frames_num=4, data_type='event')
    dataloadertrain = DataLoader(db_train, batch_size=16, shuffle=True, num_workers=16, pin_memory=True)
    # NOmniglotfull yields (image, label) pairs, so the loader produces two
    # items per batch (the original unpacked four, which raises a ValueError).
    for image, label in dataloadertrain:
        print(image.shape)
================================================
FILE: braincog/datasets/NOmniglot/nomniglot_nw_ks.py
================================================
import torch
import torchvision
import numpy as np
from torch.utils.data import Dataset, DataLoader
from braincog.datasets.NOmniglot.NOmniglot import NOmniglot
class NOmniglotNWayKShot(Dataset):
    """
    N-way K-shot episodic wrapper around NOmniglot for meta learning.

    Each epoch provides ``self.length`` pre-sampled episodes; call
    :meth:`reset` to draw a fresh set of episodes. Increase or decrease the
    number of epochs to control the total number of training episodes.
    """

    def __init__(self, root, n_way, k_shot, k_query, train=True, frames_num=12, data_type='event',
                 transform=torchvision.transforms.Resize((28, 28))):
        self.dataSet = NOmniglot(root=root, train=train,
                                 frames_num=frames_num, data_type=data_type, transform=transform)
        self.n_way = n_way
        self.k_shot = k_shot
        self.k_query = k_query
        # Each class provides exactly 20 samples to draw from.
        assert (k_shot + k_query) <= 20
        self.length = 256
        self.data_cache = self.load_data_cache(self.dataSet.datadict, self.length)

    def load_data_cache(self, data_dict, length):
        """Pre-sample ``length`` episodes; only the file paths are stored."""
        episodes = []
        for _ in range(length):
            selected_cls = np.random.choice(len(data_dict), self.n_way, False)
            x_spts, y_spts, x_qrys, y_qrys = [], [], [], []
            for way_idx, cls in enumerate(selected_cls):
                # Draw k_shot support + k_query query samples without replacement.
                picked = np.random.choice(20, self.k_shot + self.k_query, False)
                x_spts.append(np.array(data_dict[cls])[picked[:self.k_shot]])
                x_qrys.append(np.array(data_dict[cls])[picked[self.k_shot:]])
                y_spts.append([way_idx for _ in range(self.k_shot)])
                y_qrys.append([way_idx for _ in range(self.k_query)])
            # Shuffle support and query sets independently.
            perm_spt = np.random.choice(self.n_way * self.k_shot, self.n_way * self.k_shot, False)
            perm_qry = np.random.choice(self.n_way * self.k_query, self.n_way * self.k_query, False)
            episodes.append([np.array(x_spts).reshape(-1)[perm_spt],
                            np.array(y_spts).reshape(-1)[perm_spt],
                            np.array(x_qrys).reshape(-1)[perm_qry],
                            np.array(y_qrys).reshape(-1)[perm_qry]])
        return episodes

    def __getitem__(self, index):
        x_spts, y_spts, x_qrys, y_qrys = self.data_cache[index]
        spt_imgs, spt_lbls, qry_imgs, qry_lbls = [], [], [], []
        for path, lbl in zip(x_spts, y_spts):
            img, lbl = self.dataSet.readimage(path, lbl)
            spt_imgs.append(img.unsqueeze(0))
            spt_lbls.append(lbl)
        for path, lbl in zip(x_qrys, y_qrys):
            img, lbl = self.dataSet.readimage(path, lbl)
            qry_imgs.append(img.unsqueeze(0))
            qry_lbls.append(lbl)
        return torch.cat(spt_imgs, dim=0), np.array(spt_lbls), torch.cat(qry_imgs, dim=0), np.array(qry_lbls)

    def reset(self):
        """Resample all cached episodes."""
        self.data_cache = self.load_data_cache(self.dataSet.datadict, self.length)

    def __len__(self):
        return len(self.data_cache)
if __name__ == "__main__":
    db_train = NOmniglotNWayKShot('./data/', n_way=5, k_shot=1, k_query=15,
                                  frames_num=4, data_type='frequency', train=True)
    dataloadertrain = DataLoader(db_train, batch_size=16, shuffle=True, num_workers=16, pin_memory=True)
    for x_spt, y_spt, x_qry, y_qry in dataloadertrain:
        print(x_spt.shape)
    # The dataset defines ``reset`` (not ``resampling``) for redrawing episodes.
    db_train.reset()
================================================
FILE: braincog/datasets/NOmniglot/nomniglot_pair.py
================================================
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from numpy.random import choice as npc
import random
import torch.nn.functional as F
from braincog.datasets.NOmniglot import NOmniglot
class NOmniglotTrainSet(Dataset):
    """
    Pair sampler for training a Siamese network on N-Omniglot.

    Odd indices yield a same-class pair (label 1.0), even indices a
    different-class pair (label 0.0).
    """

    def __init__(self, root='data/', use_frame=True, frames_num=10, data_type='event', use_npz=False, resize=None):
        super(NOmniglotTrainSet, self).__init__()
        self.resize = resize
        self.data_type = data_type
        self.use_frame = use_frame
        self.dataSet = NOmniglot(root=root, train=True, frames_num=frames_num, data_type=data_type, use_npz=use_npz)
        self.datas, self.num_classes = self.dataSet.datadict, self.dataSet.num_classes
        np.random.seed(0)

    def __len__(self):
        """Sampling upper limit; cap the iteration count externally to stop earlier."""
        return 21000000

    def __getitem__(self, index):
        if index % 2 == 1:
            # Same-class pair.
            label = 1.0
            cls_a = random.randint(0, self.num_classes - 1)
            image1 = random.choice(self.datas[cls_a])
            image2 = random.choice(self.datas[cls_a])
        else:
            # Different-class pair: redraw until the two classes differ.
            label = 0.0
            cls_a = random.randint(0, self.num_classes - 1)
            cls_b = random.randint(0, self.num_classes - 1)
            while cls_a == cls_b:
                cls_b = random.randint(0, self.num_classes - 1)
            image1 = random.choice(self.datas[cls_a])
            image2 = random.choice(self.datas[cls_b])
        if self.use_frame:
            # Both stored formats load identically from the npz archives.
            if self.data_type in ('event', 'frequency'):
                image1 = torch.tensor(np.load(image1)['arr_0']).float()
                image2 = torch.tensor(np.load(image2)['arr_0']).float()
            else:
                raise NotImplementedError
            if self.resize is not None:
                # Crop to the symbol region, then rescale.
                image1 = F.interpolate(image1[:, :, 4:254, 54:304], size=(self.resize, self.resize))
                image2 = F.interpolate(image2[:, :, 4:254, 54:304], size=(self.resize, self.resize))
        return image1, image2, torch.from_numpy(np.array([label], dtype=np.float32))
class NOmniglotTestSet(Dataset):
    '''
    Dataloader for Siamese Net.

    Produces (query, support) image pairs for way x shot episodic testing:
    index 0 of every ``way * shot`` window fixes the query image and draws a
    same-class support sample; subsequent indices draw support samples from
    other, not-yet-used classes.
    '''

    def __init__(self, root='data/', time=1000, way=20, shot=1, query=1, use_frame=True, frames_num=10, data_type='event', use_npz=True, resize=None):
        super(NOmniglotTestSet, self).__init__()
        self.resize = resize
        self.use_frame = use_frame
        self.time = time  # Sampling times
        self.way = way
        self.shot = shot
        self.query = query
        self.img1 = None  # Fix test sample while sampling support set
        self.c1 = None  # Fixed categories when sampling multiple samples
        self.c2 = None
        self.select_class = []  # selected classes
        self.select_sample = []  # selected samples
        self.data_type = data_type
        np.random.seed(0)
        self.dataSet = NOmniglot(root=root, train=False, frames_num=frames_num, data_type=data_type, use_npz=use_npz)
        self.datas, self.num_classes = self.dataSet.datadict, self.dataSet.num_classes

    def __len__(self):
        '''
        In general, the total number of test tasks is 1000.
        Since one test sample is collected at a time, way * shot support samples are used for each test
        '''
        return self.time * self.way * self.shot

    def __getitem__(self, index):
        '''
        The 0th sample of each way*shot is used for query and recorded in the selected sample
        to achieve the effect of selecting K +1
        '''
        idx = index % (self.way * self.shot)
        # generate image pair from same class
        if idx == 0:  #
            # Start of a new episode: pick the query class and query image,
            # then one same-class support sample distinct from the query.
            self.select_class = []
            self.c1 = random.randint(0, self.num_classes - 1)
            self.c2 = self.c1
            sind = random.randint(0, len(self.datas[self.c1]) - 1)
            self.select_sample.append(sind)
            self.img1 = self.datas[self.c1][sind]
            sind = random.randint(0, len(self.datas[self.c2]) - 1)
            while sind in self.select_sample:
                sind = random.randint(0, len(self.datas[self.c2]) - 1)
            img2 = self.datas[self.c1][sind]
            self.select_sample.append(sind)
            self.select_class.append(self.c1)
        # generate image pair from different class
        else:
            if index % self.shot == 0:
                # Move to a support class not yet used in this episode.
                self.c2 = random.randint(0, self.num_classes - 1)
                while self.c2 in self.select_class:  # self.c1 == c2:
                    self.c2 = random.randint(0, self.num_classes - 1)
                self.select_class.append(self.c2)
                self.select_sample = []
            # Draw a support sample from c2 not already used for this class.
            sind = random.randint(0, len(self.datas[self.c2]) - 1)
            while sind in self.select_sample:
                sind = random.randint(0, len(self.datas[self.c2]) - 1)
            img2 = self.datas[self.c2][sind]
            self.select_sample.append(sind)
        if self.use_frame:
            if self.data_type == 'event':
                img1 = torch.tensor(np.load(self.img1)['arr_0']).float()
                img2 = torch.tensor(np.load(img2)['arr_0']).float()
            elif self.data_type == 'frequency':
                img1 = torch.tensor(np.load(self.img1)['arr_0']).float()
                img2 = torch.tensor(np.load(img2)['arr_0']).float()
            else:
                raise NotImplementedError
            if self.resize is not None:
                img1 = img1[:, :, 4:254, 54:304]
                img1 = F.interpolate(img1, size=(self.resize, self.resize))
                img2 = img2[:, :, 4:254, 54:304]
                img2 = F.interpolate(img2, size=(self.resize, self.resize))
        # NOTE(review): when use_frame is False the local ``img1`` is never
        # bound and this raises NameError — confirm use_frame=True is required.
        return img1, img2
if __name__ == '__main__':
    # Manual smoke test: build train/test pair samplers and fetch one batch each.
    data_type = 'frequency'
    T = 4
    trainSet = NOmniglotTrainSet(root='data/', use_frame=True, frames_num=T, data_type=data_type, use_npz=True, resize=105)
    testSet = NOmniglotTestSet(root='data/', time=1000, way=5, shot=1, use_frame=True, frames_num=T,
                               data_type=data_type, use_npz=True, resize=105)
    trainLoader = DataLoader(trainSet, batch_size=48, shuffle=False, num_workers=4)
    testLoader = DataLoader(testSet, batch_size=5 * 1, shuffle=False, num_workers=4)
    for batch_id, (img1, img2) in enumerate(testLoader, 1):
        # img1.shape [batch, T, 2, H, W]
        print(batch_id)
        break
    for batch_id, (img1, img2, label) in enumerate(trainLoader, 1):
        # img1.shape [batch, T, 2, H, W]
        print(batch_id)
        break
================================================
FILE: braincog/datasets/NOmniglot/utils.py
================================================
import torch
import threading
import numpy as np
import pandas
import os
from dv import AedatFile
class FunctionThread(threading.Thread):
    """Thread that runs an arbitrary callable with the given arguments."""

    def __init__(self, f, *args, **kwargs):
        super().__init__()
        # Store the callable and its arguments until start() is invoked.
        self.f = f
        self.args = args
        self.kwargs = kwargs

    def run(self):
        """Invoke the stored callable; executed by ``Thread.start``."""
        self.f(*self.args, **self.kwargs)
def integrate_events_to_frames(events, height, width, frames_num=10, data_type='event'):
    """
    Integrate an event stream into ``frames_num`` two-polarity frames.

    :param events: dict with keys 't', 'x', 'y', 'p' (1-D arrays, 't' ascending)
    :param height: sensor height in pixels
    :param width: sensor width in pixels
    :param frames_num: number of frames to produce
    :param data_type: 'event' for binary frames, otherwise event-rate frames
    :return: array of shape (frames_num, 2, height, width)
    """
    frames = np.zeros(shape=[frames_num, 2, height * width])
    # j_l[i]/j_r[i] delimit the event-index range [j_l, j_r) of frame i.
    j_l = np.zeros(shape=[frames_num], dtype=int)
    j_r = np.zeros(shape=[frames_num], dtype=int)
    # Split events into equal time slices, starting from timestamp 0.
    events['t'] -= events['t'][0]
    assert events['t'][-1] > frames_num
    dt = events['t'][-1] // frames_num  # duration of each frame
    idx = np.arange(events['t'].size)
    for i in range(frames_num):
        t_l = dt * i
        t_r = t_l + dt
        mask = np.logical_and(events['t'] >= t_l, events['t'] < t_r)
        idx_masked = idx[mask]
        if len(idx_masked) == 0:
            # No events fall into this slice; mark it empty.
            j_l[i] = -1
            j_r[i] = -1
        else:
            j_l[i] = idx_masked[0]
            # The last frame absorbs the trailing remainder of the stream.
            j_r[i] = idx_masked[-1] + 1 if i < frames_num - 1 else events['t'].size
    for i in range(frames_num):
        if j_l[i] >= 0:
            x = events['x'][j_l[i]:j_r[i]]
            y = events['y'][j_l[i]:j_r[i]]
            p = events['p'][j_l[i]:j_r[i]]
            # Split by polarity: channel 0 for p == 0, channel 1 for the rest.
            mask = []
            mask.append(p == 0)
            mask.append(np.logical_not(mask[0]))
            for j in range(2):
                position = y[mask[j]] * width + x[mask[j]]
                events_number_per_pos = np.bincount(position)
                frames[i][j][np.arange(events_number_per_pos.size)] += events_number_per_pos
            if data_type == 'frequency':
                # Convert counts to rates; the last frame also covers the remainder.
                if i < frames_num - 1:
                    frames[i] /= dt
                else:
                    frames[i] /= (dt + events['t'][-1] % frames_num)
    frames = frames.astype(np.float16)
    if data_type == 'event':
        # np.bool was removed in NumPy 1.24; the builtin bool is equivalent.
        frames = (frames > 0).astype(bool)
    else:
        frames = normalize_frame(frames, 'max')
    return frames.reshape((frames_num, 2, height, width))
def normalize_frame(frames: np.ndarray or torch.Tensor, normalization: str):
    """
    Normalize each polarity channel of every frame in place.

    :param frames: array of shape (frames_num, 2, ...)
    :param normalization: 'max' (divide by channel max), 'norm'
        (zero-mean, unit-variance) or 'sum' (divide by channel sum)
    :return: the normalized ``frames`` (modified in place)
    """
    eps = 1e-5  # guards against division by zero on empty channels
    for i in range(frames.shape[0]):
        for ch in range(2):
            channel = frames[i][ch]
            if normalization == 'max':
                frames[i][ch] = channel / max(channel.max(), eps)
            elif normalization == 'norm':
                frames[i][ch] = (channel - channel.mean()) / np.sqrt(max(channel.var(), eps))
            elif normalization == 'sum':
                frames[i][ch] = channel / max(channel.sum(), eps)
            else:
                raise NotImplementedError
    return frames
def convert_events_dir_to_frames_dir(events_data_dir, frames_data_dir, suffix,
                                     frames_num=12, result_type='event', thread_num=1,
                                     compress=True):
    """
    Iterate through all event data in ``events_data_dir`` and generate frame
    data files in ``frames_data_dir``.

    :param events_data_dir: directory tree of '.npy' event files
    :param frames_data_dir: output directory for the integrated frames
    :param suffix: suffix of the input files (stripped to build output names)
    :param frames_num: number of frames per sample
    :param result_type: 'event' or 'frequency' integration mode
    :param thread_num: number of worker threads
    :param compress: write compressed '.npz' instead of '.npy'
    """
    def read_function(file_name):
        # Event files store a dict pickled inside a 0-d object array.
        return np.load(file_name, allow_pickle=True).item()

    def cvt_fun(events_file_list):
        # Convert one chunk of files; executed by each worker thread.
        for events_file in events_file_list:
            print(events_file)
            # 260 x 346 — presumably the DVS sensor resolution; confirm against the recorder.
            frames = integrate_events_to_frames(read_function(events_file), 260, 346, frames_num, result_type )
            if compress:
                frames_file = os.path.join(frames_data_dir,
                                           os.path.basename(events_file)[0: -suffix.__len__()] + '.npz')
                np.savez_compressed(frames_file, frames)
            else:
                frames_file = os.path.join(frames_data_dir,
                                           os.path.basename(events_file)[0: -suffix.__len__()] + '.npy')
                np.save(frames_file, frames)

    # Obtain the path of the all files
    events_file_list = list_all_files(events_data_dir, '.npy')
    if thread_num == 1:
        cvt_fun(events_file_list)
    else:
        # Multithreading acceleration
        thread_list = []
        block = events_file_list.__len__() // thread_num
        for i in range(thread_num - 1):
            thread_list.append(FunctionThread(cvt_fun, events_file_list[i * block: (i + 1) * block]))
            thread_list[-1].start()
            print(f'thread {i} start, processing files index: {i * block} : {(i + 1) * block}.')
        # The last thread takes the remainder of the file list.
        thread_list.append(FunctionThread(cvt_fun, events_file_list[(thread_num - 1) * block:]))
        thread_list[-1].start()
        print(
            f'thread {thread_num} start, processing files index: {(thread_num - 1) * block} : {events_file_list.__len__()}.')
        for i in range(thread_num):
            thread_list[i].join()
            print(f'thread {i} finished.')
def convert_aedat4_dir_to_events_dir(root, train):
    """
    Convert raw aedat4 recordings into per-sample '.npy' event files.

    Walks ``root/dvs_<split>/<alphabet>/character<NN>``, reads each aedat4
    recording plus its CSV of per-sample start/end timestamps, and writes 20
    event dictionaries ({'t', 'x', 'y', 'p'}) per recording to
    ``root/events_npy/<split>/<alphabet>/character<NN>``.

    :param root: dataset root directory
    :param train: background split when True, evaluation split otherwise
    """
    kind = 'background' if train else "evaluation"
    originroot = root
    root = root + '/dvs_' + kind + '/'
    alphabet_names = [a for a in os.listdir(root) if a[0] != '.']  # get folder names
    for a in range(len(alphabet_names)):
        alpha_name = alphabet_names[a]
        for b in range(len(os.listdir(os.path.join(root, alpha_name)))):
            character_id = b + 1
            character_path = alpha_name + '/character' + num2str(character_id)
            print('Parsing %s \\ character%s ...' % (alpha_name, num2str(character_id)))
            file_path = os.path.join(root, character_path)
            # Recording file ends in 'dat4' and is 11 chars long; the 8-char
            # CSV holds the 20 per-sample start/end timestamps.
            aedat4_name = [a for a in os.listdir(file_path) if a[-4:] == 'dat4' and len(a) == 11][0]
            csv_name = [a for a in os.listdir(file_path) if a[-4:] == '.csv' and len(a) == 8][0]
            number = csv_name[:4]
            new_path = originroot + '/events_npy/' + kind + '/' + alpha_name + '/character' + num2str(character_id)
            if not os.path.exists(new_path):
                os.makedirs(new_path)
            start_end_timestamp = pandas.read_csv(os.path.join(file_path, csv_name)).values
            a_timestamp, a_polarity, a_x, a_y = [], [], [], []
            with AedatFile(os.path.join(file_path, aedat4_name)) as f:  # read aedat4
                for e in f['events']:
                    a_timestamp.append(e.timestamp)
                    a_polarity.append(e.polarity)
                    a_x.append(e.x)
                    a_y.append(e.y)
            for ii in range(20):  # each file has 20 samples
                name = str(number) + '_' + num2str(ii + 1) + '.npy'
                # Slice this sample's events by its recorded start/end timestamps.
                start_index = a_timestamp.index(start_end_timestamp[ii][1])
                end_index = a_timestamp.index(start_end_timestamp[ii][2])
                tmp = {'t': np.array(a_timestamp[start_index:end_index]),
                       'x': np.array(a_x[start_index:end_index]),
                       'y': np.array(a_y[start_index:end_index]),
                       'p': np.array(a_polarity[start_index:end_index])}
                np.save(os.path.join(new_path, name), tmp)
def num2str(idx):
    """Return ``idx`` as a string, left-padded with '0' below 10 (e.g. 7 -> '07')."""
    return str(idx) if idx >= 10 else '0' + str(idx)
def list_all_files(root, suffix, getlen=False):
    '''
    List the path of all files under root, output a list.

    Directory layout is ``root/<alphabet>/character<NN>/<files>``.

    :param root: dataset root directory
    :param suffix: 4-character file suffix to keep (e.g. '.npy')
    :param getlen: when True, also return the number of character classes
    :return: list of file paths, plus the class count if ``getlen``
    '''
    file_list = []
    # Folders starting with '.' are metadata, not alphabets.
    alphabet_names = [a for a in os.listdir(root) if a[0] != '.']
    idx = 0
    for alpha_name in alphabet_names:
        for b in range(len(os.listdir(os.path.join(root, alpha_name)))):
            character_path = os.path.join(root, alpha_name, 'character' + num2str(b + 1))
            idx += 1
            # List the directory once (the original re-listed it for every file).
            for fn_example in os.listdir(character_path):
                if fn_example[-4:] == suffix:
                    file_list.append(os.path.join(character_path, fn_example))
    if getlen:
        return file_list, idx
    else:
        return file_list
def list_class_files(root, frames_kind_root, getlen=False, use_npz=False):
    '''
    Index the generated samples per character class.

    Returns a dict mapping each class index to the list of frame-file paths in
    ``frames_kind_root`` (the ``fnum_x_dtype_x_npz_x`` cache directory); the
    class layout is enumerated from the events directory ``root``.

    :param root: events directory used to enumerate classes and file names
    :param frames_kind_root: directory holding the generated frame files
    :param getlen: when True, also return the number of classes
    :param use_npz: map '.npy' names to their compressed '.npz' variants
    :return: dict {class_index: [file paths]}, plus class count if ``getlen``
    '''
    file_list = {}
    # Folders starting with '.' are metadata, not alphabets.
    alphabet_names = [a for a in os.listdir(root) if a[0] != '.']
    idx = 0
    for alpha_name in alphabet_names:
        for b in range(len(os.listdir(os.path.join(root, alpha_name)))):
            character_path = os.path.join(root, alpha_name, 'character' + num2str(b + 1))
            file_list[idx] = []
            # List the directory once (the original re-listed it for every file).
            for fn_example in os.listdir(character_path):
                if use_npz:
                    # Rewrite the '.npy' name to its compressed '.npz' variant.
                    fn_example = fn_example[:-1] + 'z'
                file_list[idx].append(os.path.join(frames_kind_root, fn_example))
            idx += 1
    if getlen:
        return file_list, idx
    else:
        return file_list
================================================
FILE: braincog/datasets/StanfordDogs.py
================================================
import os
import scipy.io
from os.path import join
from torchvision.datasets import VisionDataset
from torchvision.datasets.folder import default_loader
from torchvision.datasets.utils import download_url, list_dir
class StanfordDogs(VisionDataset):
    """Stanford Dogs Dataset.

    Args:
        root (string): Root directory of the dataset.
        train (bool, optional): If True, creates dataset from training set, otherwise
            creates from test set.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """

    download_url_prefix = 'http://vision.stanford.edu/aditya86/ImageNetDogs'

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        super(StanfordDogs, self).__init__(root, transform=transform, target_transform=target_transform)
        self.loader = default_loader
        self.train = train
        if download:
            self.download()
        # split: list of (annotation path without extension, zero-based label).
        split = self.load_split()
        self.images_folder = join(self.root, 'Images')
        self.annotations_folder = join(self.root, 'Annotation')
        self._breeds = list_dir(self.images_folder)
        self._breed_images = [(annotation + '.jpg', idx) for annotation, idx in split]
        self._flat_breed_images = self._breed_images

    def __len__(self):
        return len(self._flat_breed_images)

    def __getitem__(self, index):
        """Return the (transformed) image and its zero-based breed label."""
        image_name, target = self._flat_breed_images[index]
        image_path = join(self.images_folder, image_name)
        image = self.loader(image_path)
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return image, target

    def download(self):
        """Download and extract the images/annotation/lists archives if missing."""
        import tarfile
        if os.path.exists(join(self.root, 'Images')) and os.path.exists(join(self.root, 'Annotation')):
            # Both directories hold one folder per breed; 120 means complete.
            if len(os.listdir(join(self.root, 'Images'))) == len(os.listdir(join(self.root, 'Annotation'))) == 120:
                print('Files already downloaded and verified')
                return
        for filename in ['images', 'annotation', 'lists']:
            tar_filename = filename + '.tar'
            url = self.download_url_prefix + '/' + tar_filename
            download_url(url, self.root, tar_filename, None)
            print('Extracting downloaded file: ' + join(self.root, tar_filename))
            # NOTE(review): extractall trusts the archive's member paths; fine
            # for the official archives, unsafe for untrusted tarballs.
            with tarfile.open(join(self.root, tar_filename), 'r') as tar_file:
                tar_file.extractall(self.root)
            os.remove(join(self.root, tar_filename))

    def load_split(self):
        """Load the MATLAB split files and return (annotation name, label) pairs."""
        if self.train:
            split = scipy.io.loadmat(join(self.root, 'train_list.mat'))['annotation_list']
            labels = scipy.io.loadmat(join(self.root, 'train_list.mat'))['labels']
        else:
            split = scipy.io.loadmat(join(self.root, 'test_list.mat'))['annotation_list']
            labels = scipy.io.loadmat(join(self.root, 'test_list.mat'))['labels']
        split = [item[0][0] for item in split]
        labels = [item[0] - 1 for item in labels]  # 1-based -> 0-based
        return list(zip(split, labels))

    def stats(self):
        """Print and return the per-class sample counts."""
        counts = {}
        for index in range(len(self._flat_breed_images)):
            image_name, target_class = self._flat_breed_images[index]
            if target_class not in counts.keys():
                counts[target_class] = 1
            else:
                counts[target_class] += 1
        print("%d samples spanning %d classes (avg %f per class)" % (len(self._flat_breed_images), len(counts.keys()),
                                                                     float(len(self._flat_breed_images)) / float(
                                                                         len(counts.keys()))))
        return counts
if __name__ == '__main__':
    # Bug fix: the original referenced an undefined name ``Dogs`` and raised
    # NameError when run; the class defined above is ``StanfordDogs``.
    train_dataset = StanfordDogs('./dogs', train=True, download=False)
    test_dataset = StanfordDogs('./dogs', train=False, download=False)
================================================
FILE: braincog/datasets/TinyImageNet.py
================================================
import os
import os
import pandas as pd
import warnings
from torchvision.datasets import ImageFolder
from torchvision.datasets import VisionDataset
from torchvision.datasets.folder import default_loader
from torchvision.datasets.folder import default_loader
from torchvision.datasets.utils import extract_archive, check_integrity, download_url, verify_str_arg
class TinyImageNet(VisionDataset):
    """`tiny-imageNet <http://cs231n.stanford.edu/>`_ Dataset.

    Args:
        root (string): Root directory of the dataset.
        split (string, optional): The dataset split, supports ``train``, or ``val``.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """
    base_folder = 'tiny-imagenet-200/'
    url = 'http://cs231n.stanford.edu/tiny-imagenet-200.zip'
    filename = 'tiny-imagenet-200.zip'
    md5 = '90528d7ca1a48142e341f4ef8d21d0de'

    def __init__(self, root, split='train', transform=None, target_transform=None, download=False):
        super(TinyImageNet, self).__init__(root, transform=transform, target_transform=target_transform)
        self.dataset_path = os.path.join(root, self.base_folder)
        self.loader = default_loader
        self.split = verify_str_arg(split, "split", ("train", "val",))
        # Locate (or fetch) the archive before building the sample index.
        if self._check_integrity():
            print('Files already downloaded and verified.')
        elif download:
            self._download()
        else:
            raise RuntimeError(
                'Dataset not found. You can use download=True to download it.')
        if not os.path.isdir(self.dataset_path):
            print('Extracting...')
            extract_archive(os.path.join(root, self.filename))
        _, class_to_idx = find_classes(os.path.join(self.dataset_path, 'wnids.txt'))
        self.data = make_dataset(self.root, self.base_folder, self.split, class_to_idx)

    def _download(self):
        """Fetch the zip archive into ``self.root`` and unpack it."""
        print('Downloading...')
        download_url(self.url, root=self.root, filename=self.filename)
        print('Extracting...')
        extract_archive(os.path.join(self.root, self.filename))

    def _check_integrity(self):
        """Return True when the archive is present with the expected md5."""
        return check_integrity(os.path.join(self.root, self.filename), self.md5)

    def __getitem__(self, index):
        """Return the (image, target) pair stored at *index*."""
        img_path, target = self.data[index]
        sample = self.loader(img_path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def __len__(self):
        return len(self.data)
def find_classes(class_file):
    """Read class ids (one per line) from *class_file*.

    Returns the sorted class list and a ``{class_name: index}`` mapping.
    """
    with open(class_file) as fh:
        classes = sorted(line.strip() for line in fh)
    class_to_idx = {name: pos for pos, name in enumerate(classes)}
    return classes, class_to_idx
def make_dataset(root, base_folder, dirname, class_to_idx):
    """Collect (image path, class index) pairs for one split.

    The ``train`` split keeps one sub-directory per class; the ``val`` split
    keeps a flat ``images`` folder plus a ``val_annotations.txt`` mapping.
    """
    samples = []
    dir_path = os.path.join(root, base_folder, dirname)
    if dirname == 'train':
        for cls_name in sorted(os.listdir(dir_path)):
            cls_dir = os.path.join(dir_path, cls_name)
            if not os.path.isdir(cls_dir):
                continue
            img_dir = os.path.join(cls_dir, 'images')
            for img_name in sorted(os.listdir(img_dir)):
                samples.append((os.path.join(img_dir, img_name), class_to_idx[cls_name]))
    else:
        img_dir = os.path.join(dir_path, 'images')
        ann_file = os.path.join(dir_path, 'val_annotations.txt')
        with open(ann_file) as fh:
            # First two tab-separated columns: file name, class id.
            cls_map = {}
            for line in fh:
                fields = line.split('\t')
                cls_map[fields[0]] = fields[1]
        for img_name in sorted(os.listdir(img_dir)):
            samples.append((os.path.join(img_dir, img_name), class_to_idx[cls_map[img_name]]))
    return samples
if __name__ == '__main__':
    # Quick local smoke test; expects the archive to already be available under
    # ./tiny-imagenet (download=False raises RuntimeError otherwise).
    train_dataset = TinyImageNet('./tiny-imagenet', split='train', download=False)
    test_dataset = TinyImageNet('./tiny-imagenet', split='val', download=False)
================================================
FILE: braincog/datasets/__init__.py
================================================
from .datasets import build_transform, build_dataset, get_mnist_data, get_fashion_data, \
get_cifar10_data, get_cifar100_data, get_imnet_data, get_dvsg_data, get_dvsc10_data, \
get_NCALTECH101_data, get_NCARS_data, get_nomni_data, get_bullyingdvs_data
from .utils import rescale, dvs_channel_check_expend
from .hmdb_dvs import HMDBDVS
from .ucf101_dvs import ucf101_dvs
from .ncaltech101 import NCALTECH101
from .bullying10k import BULLYINGDVS
# Public API of braincog.datasets. The dataset classes imported above were
# missing here, so ``from braincog.datasets import *`` did not expose them;
# they (and ``is_dvs_data``) are now listed as well.
__all__ = [
    'build_transform', 'build_dataset',
    'get_mnist_data', 'get_fashion_data', 'get_cifar10_data', 'get_cifar100_data', 'get_imnet_data',
    'get_dvsg_data', 'get_dvsc10_data', 'get_NCALTECH101_data', 'get_NCARS_data', 'get_nomni_data',
    'rescale', 'dvs_channel_check_expend', 'get_bullyingdvs_data',
    'HMDBDVS', 'ucf101_dvs', 'NCALTECH101', 'BULLYINGDVS', 'is_dvs_data',
]
# Lower-case names of all event/DVS-style datasets handled by the loaders in
# this package. Fix: the original list contained 'dvsg' twice.
dvs_data = [
    'dvsg',
    'dvsc10',
    'ncaltech101',
    'ncars',
    'ucf101dvs',
    'hmdbdvs',
    'shd',
    'ntidigits',
    'nmnist',
]


def is_dvs_data(dataset):
    """Return True if *dataset* (case-insensitive) is an event/DVS dataset."""
    return dataset.lower() in dvs_data
================================================
FILE: braincog/datasets/bullying10k/__init__.py
================================================
from .bullying10k import BULLYINGDVS
================================================
FILE: braincog/datasets/bullying10k/bullying10k.py
================================================
import os
import numpy as np
from numpy.lib import recfunctions
import scipy.io as scio
from typing import Tuple, Any, Optional
from tonic.dataset import Dataset
from tonic.download_utils import extract_archive
import dv
class BULLYINGDVS(Dataset):
    """Bullying10K event (DVS) dataset.

    Samples are discovered by walking ``save_to``; a sample is either a raw
    ``.aedat4`` recording or a pre-converted ``.npy`` event array. The class
    label is taken from the grandparent directory name.
    """
    classes = ["fingerguess", "greeting", "hairgrabs", "handshake", "kicking",
               "punching", "pushing", "slapping", "strangling", "walking"]
    # Map class name -> integer label.
    class_dict = {cls: idx for idx, cls in enumerate(classes)}
    sensor_size = (346, 260, 2)
    dtype = np.dtype([("t", int), ("x", int), ("y", int), ("p", int)])
    ordering = dtype.names

    def __init__(self, save_to, transform=None, target_transform=None):
        super(BULLYINGDVS, self).__init__(
            save_to, transform=transform, target_transform=target_transform
        )
        self.aedat4 = True
        for path, dirs, files in os.walk(self.location_on_system):
            # Sort in-place so the sample order is deterministic across runs.
            dirs.sort()
            files.sort()
            for file in files:
                if file.endswith("aedat4"):
                    self.data.append(path + "/" + file)
                    self.targets.append(self.class_dict[path.split('/')[-2]])
                if file.endswith("npy"):
                    # Any .npy file switches the whole dataset to npy mode.
                    self.aedat4 = False
                    self.data.append(path + "/" + file)
                    self.targets.append(self.class_dict[path.split('/')[-2]])

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Returns:
            (events, target) where target is index of the target class.
        """
        # Bug fix: ``target`` was only assigned in the aedat4 branch, so the
        # .npy path raised NameError. Look it up once for both branches.
        target = self.targets[index]
        if self.aedat4:
            events = dv.AedatFile(self.data[index])['events']
            events = np.concatenate([event for event in events.numpy()])
        else:
            events = np.concatenate(np.load(self.data[index], allow_pickle=True))
        events = np.column_stack(
            [
                # Timestamps are shifted so each sample starts at t == 0.
                events['timestamp'] - events['timestamp'][0],
                events['x'],
                events['y'],
                events['polarity']
            ]
        )
        events = np.lib.recfunctions.unstructured_to_structured(events, self.dtype)
        if self.transform is not None:
            events = self.transform(events)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return events, target

    def __len__(self):
        return len(self.data)

    def _check_exists(self):
        # Data is user-provided; no download/integrity check is performed.
        return True
================================================
FILE: braincog/datasets/cut_mix.py
================================================
import math
import numpy as np
import random
from torch.utils.data.dataset import Dataset
from braincog.datasets.rand_aug import SaltAndPepperNoise
import numpy as np
import torch
from torch.nn import functional as F
def event_difference(x1, x2, kernel_size=3):
    """MSE between locally averaged (blurred) versions of two event tensors."""
    pad = kernel_size // 2
    smooth1 = F.avg_pool2d(x1, kernel_size=kernel_size, stride=1, padding=pad)
    smooth2 = F.avg_pool2d(x2, kernel_size=kernel_size, stride=1, padding=pad)
    return F.mse_loss(smooth1, smooth2)
def onehot(size, target):
    """Return a float32 one-hot vector of length *size* with *target* set to 1."""
    encoded = torch.zeros(size, dtype=torch.float32)
    encoded[target] = 1.0
    return encoded
def rand_bbox_time(size, rat):
    """Sample a random temporal interval covering up to ``rat`` of the steps.

    :param size: 4-tuple tensor shape (step, channel, height, width)
    :param rat: fraction of the temporal axis to cut
    :return: (t_start, t_end), both clipped to [0, step]
    :raises Exception: if *size* is not 4-dimensional
    """
    if len(size) == 4:  # step, channel, height, width
        step = size[0]
    else:
        raise Exception
    # Fix: ``np.int`` was deprecated in NumPy 1.20 and removed in 1.24;
    # the builtin ``int`` is the documented replacement.
    cut_t = int(step * rat)
    ct = np.random.randint(step)
    bbt1 = np.clip(ct - cut_t // 2, 0, step)
    bbt2 = np.clip(ct + cut_t // 2, 0, step)
    return bbt1, bbt2
def rand_bbox(size, rat):
    """Sample a random spatial box covering up to ``rat`` of the H*W area.

    :param size: 4-tuple tensor shape (step, channel, height, width)
    :param rat: target area fraction of the box
    :return: (bbx1, bby1, bbx2, bby2) corner coordinates, clipped to the frame
    :raises Exception: if *size* is not 4-dimensional
    """
    if len(size) == 4:
        W = size[2]
        H = size[3]
    else:
        raise Exception
    # sqrt so that cut_w * cut_h covers ~rat of the area
    cut_rat = np.sqrt(rat)
    # Fix: ``np.int`` was deprecated in NumPy 1.20 and removed in 1.24;
    # the builtin ``int`` is the documented replacement.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)
    # uniform
    cx = np.random.randint(W)
    cy = np.random.randint(H)
    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2
def calc_lam(x1, x2, bbt1, bbt2, bbx1, bbx2, bby1, bby2):
    """Mixing coefficient based on the event counts inside the cut box."""
    region = (slice(bbt1, bbt2), slice(None), slice(bbx1, bbx2), slice(bby1, bby2))
    # Fraction of each sample's total events that falls inside the box.
    frac1 = x1[region].sum() / x1.sum()
    frac2 = x2[region].sum() / x2.sum()
    return 1. - (frac2 / (1. - frac1 + frac2))
def rand_bbox_st(size, rat):
    """Split *rat* into a random temporal share and a spatial remainder,
    then sample both a time interval and a spatial box."""
    t_share = np.random.uniform(rat, 1.)
    s_share = rat / t_share
    t0, t1 = rand_bbox_time(size, t_share)
    x0, y0, x1, y1 = rand_bbox(size, s_share)
    return t0, t1, x0, y0, x1, y1
def spatio_mask(size, rat):
    """Low-frequency random spatial mask keeping ~``rat`` of the pixels,
    broadcast over time and the two polarity channels."""
    steps = size[0]
    real = torch.rand(2, 2)
    imag = torch.rand(2, 2)
    freq = torch.zeros(*size[-2:], dtype=torch.complex64)
    # Seed only the four corner bins of the spectrum, so the inverse FFT is
    # a smooth, low-frequency field.
    freq[[[0, -1], [-1, -1]], [[0, -1], [0, -1]]] = real + imag * 1.j
    field = torch.fft.ifftn(freq).real
    # Threshold at the rat-quantile of the field values.
    cutoff = field.flatten().sort()[0][int(np.prod(size[-2:]) * rat)]
    return (field < cutoff).unsqueeze(0).unsqueeze(0).repeat(steps, 2, 1, 1)
def temporal_mask(size, rat):
    """Boolean mask that is True on a random contiguous time interval."""
    start, stop = rand_bbox_time(size, rat)
    selected = torch.zeros(*size, dtype=torch.bool)
    selected[start:stop] = True
    return selected
def st_mask(size, rat):
    """Spatio-temporal mask: a spatial mask restricted to a random time window."""
    steps = size[0]
    t_share = np.random.uniform(rat, 1.)
    s_share = rat / t_share
    start, stop = rand_bbox_time(size, t_share)
    selected = spatio_mask(size, s_share)
    # Zero out everything outside the sampled time window.
    selected[:start] = False
    selected[stop:steps] = False
    return selected
def GMM_mask_clip(size, rat):
    """GMM-based mask clipped to a random temporal window."""
    steps = size[0]
    t_share = np.random.uniform(rat, 1.)
    s_share = rat / t_share
    start, stop = rand_bbox_time(size, t_share)
    selected = GMM_mask(size, s_share)
    # Zero out everything outside the sampled time window.
    selected[:start] = False
    selected[stop:steps] = False
    return selected
def GMM_mask(size, rat, n=None):
    """Random mask from a mixture of *n* 3-D Gaussians over (t, x, y).

    The mixture density is thresholded at the ``rat`` quantile; cells above
    the threshold are True. The result is repeated over the 2 polarity
    channels. ``n`` defaults to a random count in [2, 5).
    """
    if n is None:
        n = np.random.randint(2, 5)
    weights = torch.tensor(np.random.rand(n))
    density = torch.zeros((size[0], size[2], size[3]))
    t_axis = torch.tensor(list(range(size[0])))
    x_axis = torch.tensor(list(range(size[2])))
    y_axis = torch.tensor(list(range(size[3])))
    t, x, y = torch.meshgrid(t_axis, x_axis, y_axis, indexing='ij')
    for w in weights:
        # Random center for this Gaussian component.
        mt = np.random.randint(0, size[0])
        mx = np.random.randint(0, size[2])
        my = np.random.randint(0, size[3])
        # Random per-axis std, bounded away from zero and scaled to the grid.
        st = max(np.random.rand(), 0.1) * size[0] * 0.5
        sx = max(np.random.rand(), 0.1) * size[2] * .5
        sy = max(np.random.rand(), 0.1) * size[3] * .5
        exponent = -(((t - mt) ** 2) / (st ** 2)
                     + ((x - mx) ** 2) / (sx ** 2)
                     + ((y - my) ** 2) / (sy ** 2)) / 2
        density += w * exponent.exp()
    cutoff_idx = int(np.prod(density.shape) * rat)
    cutoff = density.flatten().sort()[0][cutoff_idx - 1]
    return (density > cutoff).unsqueeze(1).repeat(1, 2, 1, 1)
# FOR EVENT VIS
# def spatio_mask(size, rat):
# t = size[0]
# x = torch.rand(2, 2)
# y = torch.rand(2, 2)
# f = torch.zeros(*size[-2:], dtype=torch.complex64)
# # f[0:2, 0:2] = x + y * 0.j
# f[[[0, -1], [-1, -1]], [[0, -1], [0, -1]]] = x + y * 1.j
#
# f = f.unsqueeze(0).repeat(t, 1, 1)
# f[1:-2, :, :] = 0
#
# mask = torch.fft.ifftn(f).real
# # print(mask.shape)
# idx = int(np.prod(mask.shape) * 0.6)
# # print(idx)
# val = mask.flatten().sort()[0][idx]
# print(mask.unsqueeze(1).repeat(1, 2, 1, 1).shape)
# return (mask < val).unsqueeze(1).repeat(1, 2, 1, 1)
#
# def st_mask(size, rat):
# # t = size[0]
# # temporal_rat = np.random.uniform(rat, 1.)
# # wh_rat = rat / temporal_rat
# wh_rat = rat
# # bbt1, bbt2 = rand_bbox_time(size, temporal_rat)
# mask = spatio_mask(size, wh_rat)
# # mask[0:bbt1] = False
# # mask[bbt2:t] = False
# return mask
def calc_masked_lam(x1, x2, mask):
    """Mixing coefficient from the event counts selected by a boolean mask."""
    # Fraction of each sample's total events that the mask selects.
    frac1 = x1[mask].sum() / x1.sum()
    frac2 = x2[mask].sum() / x2.sum()
    return 1. - (frac2 / (1. - frac1 + frac2))
def calc_masked_lam_with_difference(x1, x2, mix, kernel_size=3):
    """Mixing coefficient from blurred-MSE distances of each source to *mix*."""
    d1 = event_difference(x1, mix, kernel_size=kernel_size)
    d2 = event_difference(x2, mix, kernel_size=kernel_size)
    # The closer x1 is to the mix (small d1), the larger the weight.
    return (d2 * d2) / (d1 * d1 + d2 * d2)
class MixUp(Dataset):
    """MixUp wrapper: blends a randomly chosen second sample into each item
    and returns a correspondingly blended one-hot label."""

    def __init__(self, dataset, num_class, num_mix=1, beta=1., prob=1.0, indices=None, noise=0.0, vis=False, **kwargs):
        # dataset: underlying dataset yielding (img, int_label)
        # num_class: size of the one-hot label vector
        # num_mix: number of mix attempts per sample
        # beta: Beta(beta, beta) parameter for the mixing coefficient
        # prob: probability of applying a mix on each attempt
        # indices: optional pool of candidate partner indices
        # noise: salt-and-pepper noise ratio applied after mixing
        # vis: return a visualization tuple instead of a training pair
        self.dataset = dataset
        self.num_class = num_class
        self.num_mix = num_mix
        self.beta = beta
        self.prob = prob
        self.indices = indices
        self.noise = noise
        self.vis = vis

    def __getitem__(self, index):
        """Return (mixed img, soft one-hot label), or (origin, img, img2) in vis mode."""
        img, lb = self.dataset[index]
        lb_onehot = onehot(self.num_class, lb)
        if self.vis:
            origin = img.clone()
        for _ in range(self.num_mix):
            r = np.random.rand(1)
            # Skip this attempt with probability (1 - prob), or if beta disables mixing.
            if self.beta <= 0 or r > self.prob:
                continue
            # generate mixed sample
            lam = np.random.beta(self.beta, self.beta)
            if self.indices is None:
                rand_index = random.choice(range(len(self)))
            else:
                rand_index = random.choice(self.indices)
            img2, lb2 = self.dataset[rand_index]
            lb2_onehot = onehot(self.num_class, lb2)
            # Convex combination of both samples and both labels.
            img = img * lam + img2 * (1. - lam)
            lb_onehot = lb_onehot * lam + lb2_onehot * (1. - lam)
            if self.noise != 0.:
                img = SaltAndPepperNoise(img, self.noise)
        if self.vis:
            # NOTE(review): if every mix attempt is skipped above, ``img2`` is
            # undefined here and this raises NameError -- presumably vis mode
            # is only used with prob=1.0; confirm.
            return origin, img, img2
        else:
            return img, lb_onehot

    def __len__(self):
        return len(self.dataset)
class CutMix(Dataset):
    """CutMix wrapper: pastes a random spatial box from a second sample into
    each item and mixes the one-hot labels by the kept-area fraction."""
    # 81.45161290322581 (epoch 584) /data/floyed/BrainCog/train/20220413-050658-resnet34-dvsc10-10-cut_mix before lam

    def __init__(self, dataset, num_class, num_mix=1, beta=1., prob=1.0, indices=None, noise=0.0, vis=False, **kwargs):
        # See MixUp for parameter meanings; noise/vis behave identically.
        self.dataset = dataset
        self.num_class = num_class
        self.num_mix = num_mix
        self.beta = beta
        self.prob = prob
        self.indices = indices
        self.noise = noise
        self.vis = vis

    def __getitem__(self, index):
        """Return (cut-mixed img, soft label), or (origin, img, img2, mask) in vis mode."""
        img, lb = self.dataset[index]
        lb_onehot = onehot(self.num_class, lb)
        if self.vis:
            origin = img.clone()
        for _ in range(self.num_mix):
            r = np.random.rand(1)
            # Skip this attempt with probability (1 - prob), or if beta disables mixing.
            if self.beta <= 0 or r > self.prob:
                continue
            # generate mixed sample
            lam = np.random.beta(self.beta, self.beta)
            if self.indices is None:
                rand_index = random.choice(range(len(self)))
            else:
                rand_index = random.choice(self.indices)
            img2, lb2 = self.dataset[rand_index]
            lb2_onehot = onehot(self.num_class, lb2)
            # shape: step, channel, height, width
            # alpha = np.random.rand()
            # if alpha < 0.333:
            bbx1, bby1, bbx2, bby2 = rand_bbox(img.shape, 1. - lam)
            # bbx1, bby1, bbx2, bby2 = 32, 0, 48, 16
            # Recompute lam from the actual (clipped) box area.
            lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (img.shape[-1] * img.shape[-2]))  # area
            # lam = calc_lam(img, img2, 0, shape[0], bbx1, bbx2, bby1, bby2)  # count
            # distance
            # mask = torch.zeros_like(img, dtype=torch.bool)
            # mask[:, :, bbx1:bbx2, bby1:bby2] = True
            # mix = img.clone()
            # mix[mask] = img2[mask]
            # lam = calc_masked_lam_with_difference(img, img2, mix, kernel_size=3)
            if self.vis:
                # Negate the pasted region so it is visually distinguishable.
                img[:, :, bbx1:bbx2, bby1:bby2] = -img2[:, :, bbx1:bbx2, bby1:bby2]
                img2 = -img2
            else:
                img[:, :, bbx1:bbx2, bby1:bby2] = img2[:, :, bbx1:bbx2, bby1:bby2]
            # elif alpha > 0.667:
            #     bbt1, bbt2 = rand_bbox_time(img.shape, 1. - lam)
            #     lam = calc_lam(img, img2, bbt1, bbt2, 0, shape[2], 0, shape[3])
            #     img[:, bbt1:bbt2, :, :] = img2[:, bbt1:bbt2, :, :]
            #     # lam = 1 - (bbt2 - bbt1) / (img.shape[-4])
            # else:
            #     bbt1, bbt2, bbx1, bby1, bbx2, bby2 = rand_bbox_st(img.shape, 1. - lam)
            #     lam = calc_lam(img, img2, bbt1, bbt2, bbx1, bbx2, bby1, bby2)
            #     img[:, bbt1:bbt2, bbx1:bbx2, bby1:bby2] = img2[:, bbt1:bbt2, bbx1:bbx2, bby1:bby2]
            #     # lam = 1 - ((bbt2 - bbt1) * (bbx2 - bbx1) * (bby2 - bby1) / (img.shape[-1] * img.shape[-2] * img.shape[-4]))
            if self.noise != 0.:
                img = SaltAndPepperNoise(img, self.noise)
            lb_onehot = lb_onehot * lam + lb2_onehot * (1. - lam)
        if self.vis:
            # NOTE(review): if every mix attempt is skipped above, ``img2`` and
            # the box coordinates are undefined here and this raises NameError
            # -- presumably vis mode is only used with prob=1.0; confirm.
            mask = torch.zeros_like(img)
            mask[:, :, bbx1:bbx2, bby1:bby2] = 1.
            return origin, img, img2, mask
        else:
            return img, lb_onehot

    def __len__(self):
        return len(self.dataset)
class EventMix(Dataset):
    """EventMix wrapper: mixes a second event sample into each item through a
    Gaussian-mixture spatio-temporal mask and relabels by event counts."""
    # 82.15725806451613 (epoch 554) /data/floyed/BrainCog/train/20220413-014843-resnet34-dvsc10-10-masked

    def __init__(self,
                 dataset,
                 num_class,
                 num_mix=1,
                 beta=1.,
                 prob=1.0,
                 indices=None,
                 noise=0.1,
                 vis=False,
                 gaussian_n=None,
                 **kwargs):
        # See MixUp for shared parameters.
        # gaussian_n: number of Gaussian components for the mask
        #             (None -> random in [2, 5), see GMM_mask).
        self.dataset = dataset
        self.num_class = num_class
        self.num_mix = num_mix
        self.beta = beta
        self.prob = prob
        self.indices = indices
        self.noise = noise
        self.vis = vis
        self.gaussian_n = gaussian_n
        print(self.prob, self.gaussian_n, self.beta)

    def __getitem__(self, index):
        """Return (event-mixed img, soft label), or (origin, img, img2, mask) in vis mode."""
        img, lb = self.dataset[index]
        lb_onehot = onehot(self.num_class, lb)
        shape = img.shape
        if self.vis:
            origin = img.clone()
        for _ in range(self.num_mix):
            r = np.random.rand(1)
            # Skip this attempt with probability (1 - prob), or if beta disables mixing.
            if self.beta <= 0 or r > self.prob:
                continue
            # generate mixed sample
            lam = np.random.beta(self.beta, self.beta)  # lam -> remain ratio
            if self.indices is None:
                rand_index = random.choice(range(len(self)))
            else:
                rand_index = random.choice(self.indices)
            img2, lb2 = self.dataset[rand_index]
            lb2_onehot = onehot(self.num_class, lb2)
            # shape: step, channel, height, width
            # alpha = np.random.rand()
            # if alpha < 0.333:
            #     mask = spatio_mask(shape, 1. - lam)
            # elif alpha > 0.667:
            #     mask = temporal_mask(shape, 1. - lam)
            # else:
            #     mask = st_mask(shape, 1. - lam)
            mask = GMM_mask(shape, 1. - lam, self.gaussian_n)
            # mask = GMM_mask_clip(shape, 1. - lam)
            # mask = torch.logical_not(mask)
            # lam = 1 - (mask.sum() / np.prod(img.shape))  # area
            # Recompute lam from the actual event counts the mask transfers.
            lam = calc_masked_lam(img, img2, mask)  # count
            img[mask] = img2[mask]  # count && mask required
            # distance
            # mix = torch.clone(img)
            # if self.vis:
            #     mix[mask] = -img2[mask]
            #     img2 = -img2
            # else:
            #     mix[mask] = img2[mask]
            # lam = calc_masked_lam_with_difference(img, img2, mix, kernel_size=3)
            # img = mix
            if self.noise != 0.:
                img = SaltAndPepperNoise(img, self.noise)
            lb_onehot = lb_onehot * lam + lb2_onehot * (1. - lam)
        if self.vis:
            # NOTE(review): if every mix attempt is skipped above, ``img2`` and
            # ``mask`` are undefined here and this raises NameError --
            # presumably vis mode is only used with prob=1.0; confirm.
            return origin, img, img2, mask
        else:
            return img, lb_onehot

    def __len__(self):
        return len(self.dataset)
if __name__ == '__main__':
    # Visual demo: sample one GMM mask and scatter-plot its positive/negative
    # channel voxels in 3-D (time, y, inverted x).
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D
    from mpl_toolkits.mplot3d import proj3d

    def get_proj(self):
        """
        Create the projection matrix from the current viewing position.
        elev stores the elevation angle in the z plane
        azim stores the azimuth angle in the (x, y) plane
        dist is the distance of the eye viewing point from the object point.
        """
        # chosen for similarity with the initial view before gh-8896
        relev, razim = np.pi * self.elev / 180, np.pi * self.azim / 180
        # EDITED TO HAVE SCALED AXIS
        xmin, xmax = np.divide(self.get_xlim3d(), self.pbaspect[0])
        ymin, ymax = np.divide(self.get_ylim3d(), self.pbaspect[1])
        zmin, zmax = np.divide(self.get_zlim3d(), self.pbaspect[2])
        # transform to uniform world coordinates 0-1, 0-1, 0-1
        worldM = proj3d.world_transformation(xmin, xmax,
                                             ymin, ymax,
                                             zmin, zmax)
        # look into the middle of the new coordinates
        R = self.pbaspect / 2
        xp = R[0] + np.cos(razim) * np.cos(relev) * self.dist
        yp = R[1] + np.sin(razim) * np.cos(relev) * self.dist
        zp = R[2] + np.sin(relev) * self.dist
        E = np.array((xp, yp, zp))
        self.eye = E
        self.vvec = R - E
        self.vvec = self.vvec / np.linalg.norm(self.vvec)
        if abs(relev) > np.pi / 2:
            # upside down
            V = np.array((0, 0, -1))
        else:
            V = np.array((0, 0, 1))
        zfront, zback = -self.dist, self.dist
        viewM = proj3d.view_transformation(E, R, V)
        projM = self._projection(zfront, zback)
        M0 = np.dot(viewM, worldM)
        M = np.dot(projM, M0)
        return M

    # NOTE(review): this monkey-patches private Axes3D internals (self.dist,
    # proj3d.world_transformation / view_transformation); these were changed or
    # removed in newer matplotlib releases -- confirm the pinned matplotlib
    # version before relying on this demo.
    Axes3D.get_proj = get_proj
    size = (100, 2, 48, 48)
    mask = GMM_mask(size, 0.3)
    print(mask.shape)
    # for i in range(100):
    #     plt.figure()
    #     plt.imshow(mask[i, 0])
    #     plt.show()
    pos_idx1 = []
    neg_idx1 = []
    # Collect (t, row, col) coordinates of True cells per polarity channel.
    for t in range(100):
        for r in range(48):
            for c in range(48):
                if mask[t, 0, r, c] > 0:
                    pos_idx1.append((t, r, c))
                if mask[t, 1, r, c] > 0:
                    neg_idx1.append((t, r, c))
    pos_t1, pos_x1, pos_y1 = np.split(np.array(pos_idx1), 3, axis=1)
    neg_t1, neg_x1, neg_y1 = np.split(np.array(neg_idx1), 3, axis=1)
    fig = plt.figure(figsize=plt.figaspect(0.5) * 1.5)
    ax = Axes3D(fig)
    ax.pbaspect = np.array([1, 1, 1])  # np.array([2.0, 1.0, 0.5])
    ax.view_init(elev=10, azim=-75)
    # ax.axis('off')
    ax.scatter(pos_t1[:, 0], pos_y1[:, 0], 48 - pos_x1[:, 0], color='red', alpha=0.1, s=2.)
    ax.scatter(neg_t1[:, 0], neg_y1[:, 0], 48 - neg_x1[:, 0], color='blue', alpha=0.1, s=2.)
    plt.show()
================================================
FILE: braincog/datasets/datasets.py
================================================
import os, warnings
import tonic
from tonic import DiskCachedDataset
import torch
import torch.nn.functional as F
import torch.utils
import torchvision.datasets as datasets
from timm.data import ImageDataset, create_loader, Mixup, FastCollateMixup, AugMixDataset
from timm.data import create_transform
from torchvision import transforms
from typing import Any, Dict, Optional, Sequence, Tuple, Union
import braincog
from braincog.datasets.NOmniglot.nomniglot_full import NOmniglotfull
from braincog.datasets.NOmniglot.nomniglot_nw_ks import NOmniglotNWayKShot
from braincog.datasets.NOmniglot.nomniglot_pair import NOmniglotTrainSet, NOmniglotTestSet
from braincog.datasets.ESimagenet.ES_imagenet import ESImagenet_Dataset
from braincog.datasets.ESimagenet.reconstructed_ES_imagenet import ESImagenet2D_Dataset
from braincog.datasets.CUB2002011 import CUB2002011
from braincog.datasets.TinyImageNet import TinyImageNet
from braincog.datasets.StanfordDogs import StanfordDogs
from braincog.datasets.bullying10k import BULLYINGDVS
from .cut_mix import CutMix, EventMix, MixUp
from .rand_aug import *
from .utils import dvs_channel_check_expend, rescale
# Per-channel statistics for DVS-CIFAR10 frames -- presumably the 2 polarity
# channels at 16 time steps, judging by the name; confirm against the loader.
DVSCIFAR10_MEAN_16 = [0.3290, 0.4507]
DVSCIFAR10_STD_16 = [1.8398, 1.6549]
# Default root directory used by all get_*_data helpers below.
DATA_DIR = '/data/datasets'
DEFAULT_CROP_PCT = 0.875
# Standard ImageNet normalization statistics.
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
# Inception-style normalization (maps pixels to [-1, 1]).
IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
# DPN-style normalization constants.
IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255)
IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3)
# CIFAR-10 normalization statistics.
CIFAR10_DEFAULT_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_DEFAULT_STD = (0.2023, 0.1994, 0.2010)
def unpack_mix_param(args):
    """Extract mix-augmentation settings from *args*, filling in defaults.

    :param args: dict of augmentation options (missing keys use defaults)
    :return: (mix_up, cut_mix, event_mix, beta, prob, num, num_classes,
              noise, gaussian_n)
    """
    # dict.get replaces the repeated ``x if k in args else d`` pattern.
    mix_up = args.get('mix_up', False)
    cut_mix = args.get('cut_mix', False)
    event_mix = args.get('event_mix', False)
    beta = args.get('beta', 1.)
    prob = args.get('prob', .5)
    num = args.get('num', 1)
    num_classes = args.get('num_classes', 10)
    noise = args.get('noise', 0.)
    gaussian_n = args.get('gaussian_n', None)
    return mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n
def build_transform(is_train, img_size):
    """Build the augmentation pipeline for static-image datasets.

    :param is_train: whether this is the training split
    :param img_size: output image size
    :return: the composed transform
    """
    resize_im = img_size > 32
    if is_train:
        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=img_size,
            is_training=True,
            color_jitter=0.4,
            auto_augment='rand-m9-mstd0.5-inc1',
            interpolation='bicubic',
            re_prob=0.25,
            re_mode='pixel',
            re_count=1,
        )
        if not resize_im:
            # replace RandomResizedCropAndInterpolation with RandomCrop for
            # small (<=32px) inputs
            transform.transforms[0] = transforms.RandomCrop(img_size, padding=4)
        return transform

    steps = []
    if resize_im:
        # to maintain same ratio w.r.t. 224 images
        crop_size = int((256 / 224) * img_size)
        steps.append(transforms.Resize(crop_size, interpolation=3))
        steps.append(transforms.CenterCrop(img_size))
    steps.append(transforms.ToTensor())
    # ImageNet statistics for large inputs, CIFAR-10 statistics otherwise.
    if img_size > 32:
        mean, std = IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
    else:
        mean, std = CIFAR10_DEFAULT_MEAN, CIFAR10_DEFAULT_STD
    steps.append(transforms.Normalize(mean, std))
    return transforms.Compose(steps)
def build_dataset(is_train, img_size, dataset, path, same_da=False):
    """Build an augmented torchvision dataset.

    :param is_train: whether this is the training split
    :param img_size: output image size
    :param dataset: dataset name ('CIFAR10' or 'CIFAR100')
    :param path: dataset root path
    :param same_da: use the eval transform for the training split as well
    :return: (dataset, number of classes)
    :raises NotImplementedError: for unknown dataset names
    """
    transform = build_transform(False if same_da else is_train, img_size)
    if dataset == 'CIFAR10':
        built = datasets.CIFAR10(
            path, train=is_train, transform=transform, download=True)
        nb_classes = 10
    elif dataset == 'CIFAR100':
        built = datasets.CIFAR100(
            path, train=is_train, transform=transform, download=True)
        nb_classes = 100
    else:
        raise NotImplementedError
    return built, nb_classes
class MNISTData(object):
    """
    Load MNIST datasets and build train/test DataLoaders.
    """

    def __init__(self,
                 data_path: str,
                 batch_size: int,
                 train_trans: Sequence[torch.nn.Module] = None,
                 test_trans: Sequence[torch.nn.Module] = None,
                 pin_memory: bool = True,
                 drop_last: bool = True,
                 shuffle: bool = True,
                 ) -> None:
        self._data_path = data_path
        self._batch_size = batch_size
        self._pin_memory = pin_memory
        self._drop_last = drop_last
        self._shuffle = shuffle
        # Compose only when transform lists were supplied.
        self._train_transform = transforms.Compose(train_trans) if train_trans else None
        self._test_transform = transforms.Compose(test_trans) if test_trans else None

    def get_data_loaders(self):
        """Create (train_loader, test_loader) with the configured transforms."""
        print('Batch size: ', self._batch_size)
        train_set = datasets.MNIST(root=self._data_path, train=True, transform=self._train_transform, download=True)
        test_set = datasets.MNIST(root=self._data_path, train=False, transform=self._test_transform, download=True)
        train_loader = torch.utils.data.DataLoader(
            train_set, batch_size=self._batch_size,
            pin_memory=self._pin_memory, drop_last=self._drop_last, shuffle=self._shuffle
        )
        test_loader = torch.utils.data.DataLoader(
            test_set, batch_size=self._batch_size,
            pin_memory=self._pin_memory, drop_last=False
        )
        return train_loader, test_loader

    def get_standard_data(self):
        """Create loaders using the standard MNIST crop/normalize pipeline."""
        mnist_mean, mnist_std = 0.1307, 0.3081
        self._train_transform = transforms.Compose([
            transforms.RandomCrop(28, padding=4),
            transforms.ToTensor(),
            transforms.Normalize((mnist_mean,), (mnist_std,)),
        ])
        self._test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((mnist_mean,), (mnist_std,)),
        ])
        return self.get_data_loaders()
def get_mnist_data(batch_size, num_workers=8, same_da=False, root=DATA_DIR, **kwargs):
    """Build MNIST train/test DataLoaders.

    http://data.pymvpa.org/datasets/mnist/

    :param batch_size: batch size
    :param num_workers: DataLoader worker count
    :param same_da: use the test-set transform for the training set as well
    :param root: dataset root directory
    :param kwargs: ``skip_norm=True`` replaces normalization with ``rescale``
    :return: (train loader, test loader, mixup_active, mixup_fn)
    """
    MNIST_MEAN = 0.1307
    MNIST_STD = 0.3081
    # NOTE: the original checked ``'root' in kwargs`` here, but ``root`` is an
    # explicit keyword parameter, so it can never appear in **kwargs -- that
    # branch was dead code and has been removed.
    if 'skip_norm' in kwargs and kwargs['skip_norm'] is True:
        train_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(rescale)
        ])
        test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(rescale)
        ])
    else:
        train_transform = transforms.Compose([transforms.RandomCrop(28, padding=4),
                                              # transforms.RandomRotation(10),
                                              transforms.ToTensor(),
                                              transforms.Normalize((MNIST_MEAN,), (MNIST_STD,))])
        test_transform = transforms.Compose([transforms.ToTensor(),
                                             transforms.Normalize((MNIST_MEAN,), (MNIST_STD,))])
    train_datasets = datasets.MNIST(
        root=root, train=True, transform=test_transform if same_da else train_transform, download=True)
    test_datasets = datasets.MNIST(
        root=root, train=False, transform=test_transform, download=True)
    train_loader = torch.utils.data.DataLoader(
        train_datasets, batch_size=batch_size,
        pin_memory=True, drop_last=True, shuffle=True, num_workers=num_workers
    )
    test_loader = torch.utils.data.DataLoader(
        test_datasets, batch_size=batch_size,
        pin_memory=True, drop_last=False, num_workers=num_workers
    )
    return train_loader, test_loader, False, None
def get_fashion_data(batch_size, num_workers=8, same_da=False, root=DATA_DIR, **kwargs):
    """Build Fashion-MNIST train/test DataLoaders.

    http://arxiv.org/abs/1708.07747

    :param batch_size: batch size
    :param same_da: use the test-set transform for the training set as well
    :return: (train loader, test loader, mixup_active, mixup_fn)
    """
    train_transform = transforms.Compose([
        transforms.RandomCrop(28, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
    ])
    test_transform = transforms.Compose([transforms.ToTensor()])
    train_set = datasets.FashionMNIST(
        root=root, train=True, transform=test_transform if same_da else train_transform, download=True)
    test_set = datasets.FashionMNIST(
        root=root, train=False, transform=test_transform, download=True)
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size,
        pin_memory=True, drop_last=True, shuffle=True, num_workers=num_workers
    )
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=batch_size,
        pin_memory=True, drop_last=False, num_workers=num_workers
    )
    return train_loader, test_loader, False, None
def get_cifar10_data(batch_size, num_workers=8, same_da=False, root=DATA_DIR, **kwargs):
    """Build CIFAR-10 train/test DataLoaders.

    https://www.cs.toronto.edu/~kriz/cifar.html

    :param batch_size: batch size
    :return: (train loader, test loader, mixup_active, mixup_fn)
    """
    train_set, _ = build_dataset(True, 32, 'CIFAR10', root, same_da)
    test_set, _ = build_dataset(False, 32, 'CIFAR10', root, same_da)
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size,
        pin_memory=True, drop_last=True, shuffle=True,
        num_workers=num_workers
    )
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=batch_size,
        pin_memory=True, drop_last=False,
        num_workers=num_workers
    )
    return train_loader, test_loader, None, None
def get_cifar100_data(batch_size, num_workers=8, same_data=False, root=DATA_DIR, *args, **kwargs):
    """Build CIFAR-100 train/test DataLoaders.

    https://www.cs.toronto.edu/~kriz/cifar.html

    :param batch_size: batch size
    :return: (train loader, test loader, mixup_active, mixup_fn)
    """
    train_set, _ = build_dataset(True, 32, 'CIFAR100', root, same_data)
    test_set, _ = build_dataset(False, 32, 'CIFAR100', root, same_data)
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size,
        pin_memory=True, drop_last=True, shuffle=True, num_workers=num_workers
    )
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=batch_size,
        pin_memory=True, drop_last=False, num_workers=num_workers
    )
    return train_loader, test_loader, False, None
def get_TinyImageNet_data(batch_size, num_workers=8, same_da=False, root=DATA_DIR, *args, **kwargs):
    """
    Build Tiny-ImageNet train/val loaders.
    :param batch_size: batch size
    :param num_workers: number of DataLoader worker processes
    :param same_da: if True, use the test-time transform for training as well
    :param root: dataset root directory (a 'TinyImageNet' subdir is appended)
    :param kwargs: may contain 'size' (crop size, default 224)
    :return: (train loader, test loader, mixup_active, mixup_fn)
    """
    crop_size = kwargs.get("size", 224)
    normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose([
        # Resize to crop_size * 8/7 (e.g. 256 for 224) before center-cropping.
        transforms.Resize(crop_size * 8 // 7),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        normalize,
    ])

    data_root = os.path.join(root, 'TinyImageNet')
    train_datasets = TinyImageNet(
        root=data_root, split="train",
        transform=test_transform if same_da else train_transform, download=True)
    test_datasets = TinyImageNet(
        root=data_root, split="val", transform=test_transform, download=True)

    train_loader = torch.utils.data.DataLoader(
        train_datasets, batch_size=batch_size,
        pin_memory=True, drop_last=True, shuffle=True, num_workers=num_workers
    )
    test_loader = torch.utils.data.DataLoader(
        test_datasets, batch_size=batch_size,
        pin_memory=True, drop_last=False, num_workers=num_workers
    )
    return train_loader, test_loader, False, None
def get_imnet_data(args, _logger, data_config, num_aug_splits, root=DATA_DIR, **kwargs):
    """
    Build ImageNet (ILSVRC2012) train/val loaders via timm's create_loader.
    http://arxiv.org/abs/1409.0575
    :param args: run options (batch size, augmentation flags, workers, ...)
    :param _logger: logger used to report missing dataset directories
    :param data_config: resolved data config (input size, interpolation, mean/std, crop_pct)
    :param num_aug_splits: number of different augmentation splits
    :param root: dataset root; expects ILSVRC2012/train and ILSVRC2012/val(idation)
    :param kwargs: unused, accepted for interface compatibility
    :return: (train loader, eval loader, mixup_active, mixup_fn)
    """
    train_dir = os.path.join(root, 'ILSVRC2012/train')
    if not os.path.exists(train_dir):
        _logger.error(
            'Training folder does not exist at: {}'.format(train_dir))
        # Hard exit: the run cannot proceed without the training data.
        exit(1)
    dataset_train = ImageDataset(train_dir)
    # Mixup/cutmix collate wiring kept for reference but currently disabled;
    # this loader always reports mixup_active=False.
    # collate_fn = None
    # mixup_fn = None
    # mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    # if mixup_active:
    #     mixup_args = dict(
    #         mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
    #         prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
    #         label_smoothing=args.smoothing, num_classes=args.num_classes)
    #     if args.prefetcher:
    #         # collate conflict (need to support deinterleaving in collate mixup)
    #         assert not num_aug_splits
    #         collate_fn = FastCollateMixup(**mixup_args)
    #     else:
    #         mixup_fn = Mixup(**mixup_args)
    # if num_aug_splits > 1:
    #     dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)

    # Fall back to the config's interpolation when augmentation is disabled
    # or no train-time interpolation was specified.
    train_interpolation = args.train_interpolation
    if args.no_aug or not train_interpolation:
        train_interpolation = data_config['interpolation']
    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        no_aug=args.no_aug,
        # re_prob=args.reprob,
        # re_mode=args.remode,
        # re_count=args.recount,
        # re_split=args.resplit,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        # vflip=arg,
        color_jitter=args.color_jitter,
        #auto_augment=args.aa,
        num_aug_splits=num_aug_splits,
        interpolation=train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        #collate_fn=collate_fn,
        pin_memory=args.pin_mem,
        use_multi_epochs_loader=args.use_multi_epochs_loader)

    eval_dir = os.path.join(root, 'ILSVRC2012/val')
    if not os.path.isdir(eval_dir):
        # Some installs name the split 'validation' instead of 'val'.
        eval_dir = os.path.join(root, 'ILSVRC2012/validation')
        if not os.path.isdir(eval_dir):
            _logger.error(
                'Validation folder does not exist at: {}'.format(eval_dir))
            exit(1)
    dataset_eval = ImageDataset(eval_dir)

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        # Evaluation can use a larger batch than training.
        batch_size=args.validation_batch_size_multiplier * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
    )
    return loader_train, loader_eval, False, None
def get_dvsg_data(batch_size, step, root=DATA_DIR, **kwargs):
    """
    Build DVS Gesture train/test loaders.
    DOI: 10.1109/CVPR.2017.781
    :param batch_size: batch size
    :param step: number of simulation time steps (event frames per sample)
    :param root: dataset root directory
    :param kwargs: may contain 'size' (default 48), rand-aug options
        ('rand_aug', 'randaug_n', 'randaug_m') and mix-up options
        (see unpack_mix_param)
    :return: (train loader, test loader, mixup_active, mixup_fn)
    """
    sensor_size = tonic.datasets.DVSGesture.sensor_size
    size = kwargs['size'] if 'size' in kwargs else 48

    # Stage 1: event-level transforms — bin the raw event stream into
    # `step` frames at dataset-read time.
    train_transform = transforms.Compose([
        # tonic.transforms.Denoise(filter_time=10000),
        # tonic.transforms.DropEvent(p=0.1),
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step),
    ])
    test_transform = transforms.Compose([
        # tonic.transforms.Denoise(filter_time=10000),
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step),
    ])

    train_dataset = tonic.datasets.DVSGesture(os.path.join(root, 'DVS/DVSGesture'),
                                              transform=train_transform, train=True)
    test_dataset = tonic.datasets.DVSGesture(os.path.join(root, 'DVS/DVSGesture'),
                                             transform=test_transform, train=False)

    # Stage 2: tensor-level transforms applied on top of the cached frames
    # (the names train_transform/test_transform are deliberately reused).
    train_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
        lambda x: dvs_channel_check_expend(x),
        transforms.RandomCrop(size, padding=size // 12),
        # transforms.RandomHorizontalFlip(),
        # transforms.RandomRotation(15)
    ])
    test_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
        lambda x: dvs_channel_check_expend(x),
    ])

    # Optional RandAugment, inserted right after the interpolate step.
    if 'rand_aug' in kwargs.keys():
        if kwargs['rand_aug'] is True:
            n = kwargs['randaug_n']
            m = kwargs['randaug_m']
            train_transform.transforms.insert(2, RandAugment(m=m, n=n))

    # if 'temporal_flatten' in kwargs.keys():
    #     if kwargs['temporal_flatten'] is True:
    #         train_transform.transforms.insert(-1, lambda x: temporal_flatten(x))
    #         test_transform.transforms.insert(-1, lambda x: temporal_flatten(x))

    # Cache the binned frames on disk so event decoding happens only once.
    train_dataset = DiskCachedDataset(train_dataset,
                                      cache_path=os.path.join(root, 'DVS/DVSGesture/train_cache_{}'.format(step)),
                                      transform=train_transform, num_copies=3)
    test_dataset = DiskCachedDataset(test_dataset,
                                     cache_path=os.path.join(root, 'DVS/DVSGesture/test_cache_{}'.format(step)),
                                     transform=test_transform, num_copies=3)

    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)
    mixup_active = cut_mix | event_mix | mix_up

    # The mix wrappers are mutually stackable; each wraps the current dataset.
    if cut_mix:
        train_dataset = CutMix(train_dataset,
                               beta=beta,
                               prob=prob,
                               num_mix=num,
                               num_class=num_classes,
                               noise=noise)

    if event_mix:
        train_dataset = EventMix(train_dataset,
                                 beta=beta,
                                 prob=prob,
                                 num_mix=num,
                                 num_class=num_classes,
                                 noise=noise,
                                 gaussian_n=gaussian_n)
    if mix_up:
        train_dataset = MixUp(train_dataset,
                              beta=beta,
                              prob=prob,
                              num_mix=num,
                              num_class=num_classes,
                              noise=noise)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size,
        pin_memory=True, drop_last=True, num_workers=8,
        shuffle=True,
    )

    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size,
        pin_memory=True, drop_last=False, num_workers=2,
        shuffle=False,
    )

    return train_loader, test_loader, mixup_active, None
def get_bullyingdvs_data(batch_size, step, root=DATA_DIR, **kwargs):
    """
    Build Bullying10K (DVS) train/test loaders.
    NeurIPS 2023
    :param batch_size: batch size
    :param step: number of simulation time steps (event frames per sample)
    :param root: dataset root directory
    :param kwargs: may contain 'size' (default 48), 'portion' (train split
        fraction, default .9), rand-aug options ('rand_aug', 'randaug_n',
        'randaug_m') and mix-up options (see unpack_mix_param)
    :return: (train loader, test loader, mixup_active, mixup_fn)
    """
    size = kwargs['size'] if 'size' in kwargs else 48
    sensor_size = BULLYINGDVS.sensor_size

    # Stage 1: bin the raw event stream into `step` frames at read time.
    train_transform = transforms.Compose([
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])
    test_transform = transforms.Compose([
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])

    # Bug fix: the train split previously pointed at the hard-coded absolute
    # path '/data/datasets/Bullying10k_processed', ignoring `root`; use the
    # configurable root so both splits read from the same location.
    train_dataset = BULLYINGDVS(os.path.join(root, 'DVS/BULLYINGDVS'), transform=train_transform)
    test_dataset = BULLYINGDVS(os.path.join(root, 'DVS/BULLYINGDVS'), transform=test_transform)

    # Stage 2: tensor-level transforms applied on top of the cached frames.
    train_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
        transforms.RandomCrop(size, padding=size // 12),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15)
    ])
    test_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
    ])

    # Optional RandAugment, inserted right after the interpolate step.
    if 'rand_aug' in kwargs.keys():
        if kwargs['rand_aug'] is True:
            n = kwargs['randaug_n']
            m = kwargs['randaug_m']
            train_transform.transforms.insert(2, RandAugment(m=m, n=n))

    # Cache the binned frames on disk so event decoding happens only once.
    train_dataset = DiskCachedDataset(train_dataset,
                                      cache_path=os.path.join(root, 'DVS/BULLYINGDVS/train_cache_{}'.format(step)),
                                      transform=train_transform)
    test_dataset = DiskCachedDataset(test_dataset,
                                     cache_path=os.path.join(root, 'DVS/BULLYINGDVS/test_cache_{}'.format(step)),
                                     transform=test_transform)

    # Split each of the 10 classes into train/test by `portion`; assumes the
    # samples are stored contiguously per class in equal counts — TODO confirm.
    num_train = len(train_dataset)
    num_per_cls = num_train // 10
    indices_train, indices_test = [], []
    portion = kwargs['portion'] if 'portion' in kwargs else .9
    for i in range(10):
        indices_train.extend(
            list(range(i * num_per_cls, round(i * num_per_cls + num_per_cls * portion))))
        indices_test.extend(
            list(range(round(i * num_per_cls + num_per_cls * portion), (i + 1) * num_per_cls)))

    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)
    mixup_active = cut_mix | event_mix | mix_up

    # The mix wrappers are stackable; each wraps the current dataset.
    if cut_mix:
        train_dataset = CutMix(train_dataset,
                               beta=beta,
                               prob=prob,
                               num_mix=num,
                               num_class=num_classes,
                               indices=indices_train,
                               noise=noise)

    if event_mix:
        train_dataset = EventMix(train_dataset,
                                 beta=beta,
                                 prob=prob,
                                 num_mix=num,
                                 num_class=num_classes,
                                 indices=indices_train,
                                 noise=noise,
                                 gaussian_n=gaussian_n)
    if mix_up:
        train_dataset = MixUp(train_dataset,
                              beta=beta,
                              prob=prob,
                              num_mix=num,
                              num_class=num_classes,
                              indices=indices_train,
                              noise=noise)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices_train),
        pin_memory=True, drop_last=True, num_workers=8
    )

    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices_test),
        pin_memory=True, drop_last=False, num_workers=2
    )

    return train_loader, test_loader, mixup_active, None
def get_dvsc10_data(batch_size, step, root=DATA_DIR, **kwargs):
    """
    Build DVS-CIFAR10 train/test loaders.
    http://journal.frontiersin.org/article/10.3389/fnins.2017.00309/full
    :param batch_size: batch size
    :param step: number of simulation time steps (event frames per sample)
    :param root: dataset root directory
    :param kwargs: may contain 'size' (default 48), 'portion' (train split
        fraction, default .9), rand-aug and mix-up options
    :return: (train loader, test loader, mixup_active, mixup_fn)
    """
    size = kwargs['size'] if 'size' in kwargs else 48
    sensor_size = tonic.datasets.CIFAR10DVS.sensor_size

    # Stage 1: bin the raw event stream into `step` frames at read time.
    train_transform = transforms.Compose([
        # tonic.transforms.Denoise(filter_time=10000),
        # tonic.transforms.DropEvent(p=0.1),
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])
    test_transform = transforms.Compose([
        # tonic.transforms.Denoise(filter_time=10000),
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])

    train_dataset = tonic.datasets.CIFAR10DVS(os.path.join(root, 'DVS/DVS_Cifar10'), transform=train_transform)
    test_dataset = tonic.datasets.CIFAR10DVS(os.path.join(root, 'DVS/DVS_Cifar10'), transform=test_transform)

    # Stage 2: tensor-level transforms applied on top of the cached frames
    # (commented entries are alternative augmentations kept for reference).
    train_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
        # lambda x: TemporalShift(x, .01),
        # lambda x: drop(x, 0.15),
        # lambda x: ShearX(x, 15),
        # lambda x: ShearY(x, 15),
        # lambda x: TranslateX(x, 0.225),
        # lambda x: TranslateY(x, 0.225),
        # lambda x: Rotate(x, 15),
        # lambda x: CutoutAbs(x, 0.25),
        # lambda x: CutoutTemporal(x, 0.25),
        # lambda x: GaussianBlur(x, 0.5),
        # lambda x: SaltAndPepperNoise(x, 0.1),
        # transforms.Normalize(DVSCIFAR10_MEAN_16, DVSCIFAR10_STD_16),
        transforms.RandomCrop(size, padding=size // 12),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15)
    ])
    test_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
    ])

    # Optional RandAugment, inserted right after the interpolate step.
    if 'rand_aug' in kwargs.keys():
        if kwargs['rand_aug'] is True:
            n = kwargs['randaug_n']
            m = kwargs['randaug_m']
            # print('randaug', m, n)
            train_transform.transforms.insert(2, RandAugment(m=m, n=n))

    # if 'temporal_flatten' in kwargs.keys():
    #     if kwargs['temporal_flatten'] is True:
    #         train_transform.transforms.insert(-1, lambda x: temporal_flatten(x))
    #         test_transform.transforms.insert(-1, lambda x: temporal_flatten(x))

    # Cache the binned frames on disk so event decoding happens only once.
    train_dataset = DiskCachedDataset(train_dataset,
                                      cache_path=os.path.join(root, 'DVS/DVS_Cifar10/train_cache_{}'.format(step)),
                                      transform=train_transform)
    test_dataset = DiskCachedDataset(test_dataset,
                                     cache_path=os.path.join(root, 'DVS/DVS_Cifar10/test_cache_{}'.format(step)),
                                     transform=test_transform)

    # DVS-CIFAR10 has no official split: carve each of the 10 classes into
    # train/test by `portion`; assumes samples are stored contiguously per
    # class in equal counts — TODO confirm ordering.
    num_train = len(train_dataset)
    num_per_cls = num_train // 10
    indices_train, indices_test = [], []
    portion = kwargs['portion'] if 'portion' in kwargs else .9
    for i in range(10):
        indices_train.extend(
            list(range(i * num_per_cls, round(i * num_per_cls + num_per_cls * portion))))
        indices_test.extend(
            list(range(round(i * num_per_cls + num_per_cls * portion), (i + 1) * num_per_cls)))

    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)
    mixup_active = cut_mix | event_mix | mix_up

    # The mix wrappers are stackable; each wraps the current dataset.
    if cut_mix:
        # print('cut_mix', beta, prob, num, num_classes)
        train_dataset = CutMix(train_dataset,
                               beta=beta,
                               prob=prob,
                               num_mix=num,
                               num_class=num_classes,
                               indices=indices_train,
                               noise=noise)

    if event_mix:
        train_dataset = EventMix(train_dataset,
                                 beta=beta,
                                 prob=prob,
                                 num_mix=num,
                                 num_class=num_classes,
                                 indices=indices_train,
                                 noise=noise,
                                 gaussian_n=gaussian_n)
    if mix_up:
        train_dataset = MixUp(train_dataset,
                              beta=beta,
                              prob=prob,
                              num_mix=num,
                              num_class=num_classes,
                              indices=indices_train,
                              noise=noise)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices_train),
        pin_memory=True, drop_last=True, num_workers=8
    )

    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices_test),
        pin_memory=True, drop_last=False, num_workers=2
    )

    return train_loader, test_loader, mixup_active, None
def get_NCALTECH101_data(batch_size, step, root=DATA_DIR, **kwargs):
    """
    Build N-Caltech101 train/test loaders.
    http://journal.frontiersin.org/Article/10.3389/fnins.2015.00437/abstract
    :param batch_size: batch size
    :param step: number of simulation time steps (event frames per sample)
    :param root: dataset root directory
    :param kwargs: may contain 'size' (default 48), 'portion' (train split
        fraction, default .9), rand-aug and mix-up options
    :return: (train loader, test loader, mixup_active, mixup_fn)
    """
    sensor_size = tonic.datasets.NCALTECH101.sensor_size
    cls_count = tonic.datasets.NCALTECH101.cls_count
    dataset_length = tonic.datasets.NCALTECH101.length
    portion = kwargs['portion'] if 'portion' in kwargs else .9
    size = kwargs['size'] if 'size' in kwargs else 48
    # print('portion', portion)

    # N-Caltech101 is heavily class-imbalanced. Build a per-sample weight
    # that is inversely proportional to its class size so the weighted
    # sampler draws classes uniformly; test samples get weight 0 so they
    # are never drawn for training. Assumes samples are stored contiguously
    # per class in `cls_count` order — TODO confirm.
    train_sample_weight = []
    train_sample_index = []
    train_count = 0
    test_sample_index = []
    idx_begin = 0
    for count in cls_count:
        sample_weight = dataset_length / count
        train_sample = round(portion * count)
        test_sample = count - train_sample
        train_count += train_sample
        train_sample_weight.extend(
            [sample_weight] * train_sample
        )
        train_sample_weight.extend(
            [0.] * test_sample
        )
        train_sample_index.extend(
            list((range(idx_begin, idx_begin + train_sample)))
        )
        test_sample_index.extend(
            list(range(idx_begin + train_sample, idx_begin + train_sample + test_sample))
        )
        idx_begin += count

    train_sampler = torch.utils.data.sampler.WeightedRandomSampler(train_sample_weight, train_count)
    test_sampler = torch.utils.data.sampler.SubsetRandomSampler(test_sample_index)

    # Stage 1: bin the raw event stream into `step` frames at read time.
    train_transform = transforms.Compose([
        # tonic.transforms.Denoise(filter_time=10000),
        # tonic.transforms.DropEvent(p=0.1),
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])
    test_transform = transforms.Compose([
        # tonic.transforms.Denoise(filter_time=10000),
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])

    train_dataset = tonic.datasets.NCALTECH101(os.path.join(root, 'DVS/NCALTECH101'), transform=train_transform)
    test_dataset = tonic.datasets.NCALTECH101(os.path.join(root, 'DVS/NCALTECH101'), transform=test_transform)

    # Stage 2: tensor-level transforms applied on top of the cached frames.
    train_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
        # lambda x: print(x.shape),
        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
        # transforms.RandomCrop(size, padding=size // 12),
        # transforms.RandomHorizontalFlip(),
        #transforms.RandomRotation(15)
    ])
    test_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
        # lambda x: temporal_flatten(x),
    ])

    # Optional RandAugment, inserted right after the interpolate step.
    if 'rand_aug' in kwargs.keys():
        if kwargs['rand_aug'] is True:
            n = kwargs['randaug_n']
            m = kwargs['randaug_m']
            train_transform.transforms.insert(2, RandAugment(m=m, n=n))

    # if 'temporal_flatten' in kwargs.keys():
    #     if kwargs['temporal_flatten'] is True:
    #         train_transform.transforms.insert(-1, lambda x: temporal_flatten(x))
    #         test_transform.transforms.insert(-1, lambda x: temporal_flatten(x))

    # Cache the binned frames on disk so event decoding happens only once.
    train_dataset = DiskCachedDataset(train_dataset,
                                      cache_path=os.path.join(root, 'DVS/NCALTECH101/train_cache_{}'.format(step)),
                                      transform=train_transform, num_copies=3)
    test_dataset = DiskCachedDataset(test_dataset,
                                     cache_path=os.path.join(root, 'DVS/NCALTECH101/test_cache_{}'.format(step)),
                                     transform=test_transform, num_copies=3)

    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)
    mixup_active = cut_mix | event_mix | mix_up

    # The mix wrappers are stackable; each wraps the current dataset.
    if cut_mix:
        train_dataset = CutMix(train_dataset,
                               beta=beta,
                               prob=prob,
                               num_mix=num,
                               num_class=num_classes,
                               indices=train_sample_index,
                               noise=noise)

    if event_mix:
        train_dataset = EventMix(train_dataset,
                                 beta=beta,
                                 prob=prob,
                                 num_mix=num,
                                 num_class=num_classes,
                                 indices=train_sample_index,
                                 noise=noise,
                                 gaussian_n=gaussian_n)

    if mix_up:
        train_dataset = MixUp(train_dataset,
                              beta=beta,
                              prob=prob,
                              num_mix=num,
                              num_class=num_classes,
                              indices=train_sample_index,
                              noise=noise)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size,
        sampler=train_sampler,
        pin_memory=True, drop_last=True, num_workers=8
    )

    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size,
        sampler=test_sampler,
        pin_memory=True, drop_last=False, num_workers=2
    )

    return train_loader, test_loader, mixup_active, None
def get_NCARS_data(batch_size, step, root=DATA_DIR, **kwargs):
    """
    Build N-Cars train/test loaders.
    https://ieeexplore.ieee.org/document/8578284/
    :param batch_size: batch size
    :param step: number of simulation time steps (event frames per sample)
    :param root: dataset root directory
    :param kwargs: may contain 'size' (default 48), rand-aug and mix-up options
    :return: (train loader, test loader, mixup_active, mixup_fn)
    """
    # NOTE(review): sensor_size is read but not passed to ToFrame below
    # (sensor_size=None) — presumably because N-Cars recordings vary in
    # resolution and the frame size is inferred per sample; confirm.
    sensor_size = tonic.datasets.NCARS.sensor_size
    size = kwargs['size'] if 'size' in kwargs else 48

    # Stage 1: bin the raw event stream into `step` frames at read time.
    train_transform = transforms.Compose([
        # tonic.transforms.Denoise(filter_time=10000),
        # tonic.transforms.DropEvent(p=0.1),
        tonic.transforms.ToFrame(sensor_size=None, n_time_bins=step),
    ])
    test_transform = transforms.Compose([
        # tonic.transforms.Denoise(filter_time=10000),
        tonic.transforms.ToFrame(sensor_size=None, n_time_bins=step),
    ])

    train_dataset = tonic.datasets.NCARS(os.path.join(root, 'DVS/NCARS'), transform=train_transform, train=True)
    test_dataset = tonic.datasets.NCARS(os.path.join(root, 'DVS/NCARS'), transform=test_transform, train=False)

    # Stage 2: tensor-level transforms applied on top of the cached frames.
    train_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
        lambda x: dvs_channel_check_expend(x),
        transforms.RandomCrop(size, padding=size // 12),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15)
    ])
    test_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
        lambda x: dvs_channel_check_expend(x),
    ])

    # Optional RandAugment, inserted right after the interpolate step.
    if 'rand_aug' in kwargs.keys():
        if kwargs['rand_aug'] is True:
            n = kwargs['randaug_n']
            m = kwargs['randaug_m']
            train_transform.transforms.insert(2, RandAugment(m=m, n=n))

    # if 'temporal_flatten' in kwargs.keys():
    #     if kwargs['temporal_flatten'] is True:
    #         train_transform.transforms.insert(-1, lambda x: temporal_flatten(x))
    #         test_transform.transforms.insert(-1, lambda x: temporal_flatten(x))

    # Cache the binned frames on disk so event decoding happens only once.
    train_dataset = DiskCachedDataset(train_dataset,
                                      cache_path=os.path.join(root, 'DVS/NCARS/train_cache_{}'.format(step)),
                                      transform=train_transform, num_copies=3)
    test_dataset = DiskCachedDataset(test_dataset,
                                     cache_path=os.path.join(root, 'DVS/NCARS/test_cache_{}'.format(step)),
                                     transform=test_transform, num_copies=3)

    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)
    mixup_active = cut_mix | event_mix | mix_up

    # The mix wrappers are stackable; each wraps the current dataset.
    if cut_mix:
        train_dataset = CutMix(train_dataset,
                               beta=beta,
                               prob=prob,
                               num_mix=num,
                               num_class=num_classes,
                               noise=noise)

    if event_mix:
        train_dataset = EventMix(train_dataset,
                                 beta=beta,
                                 prob=prob,
                                 num_mix=num,
                                 num_class=num_classes,
                                 noise=noise,
                                 gaussian_n=gaussian_n)
    if mix_up:
        train_dataset = MixUp(train_dataset,
                              beta=beta,
                              prob=prob,
                              num_mix=num,
                              num_class=num_classes,
                              noise=noise)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size,
        pin_memory=True, drop_last=True, num_workers=8,
        shuffle=True,
    )

    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size,
        pin_memory=True, drop_last=False, num_workers=2,
        shuffle=False,
    )

    return train_loader, test_loader, mixup_active, None
def get_nomni_data(batch_size, train_portion=1., root=DATA_DIR, **kwargs):
    """
    Build N-Omniglot train/test loaders.
    :param batch_size: batch size
    :param train_portion: accepted for interface compatibility (unused here)
    :param root: dataset root directory
    :param kwargs:
        data_mode: one of 'full', 'nkks', 'pair' (default 'full')
        frames_num: number of frames per sample (default 4)
        data_type: 'event' or 'frequency' (default 'event')
        n_way / k_shot / k_query: required for 'nkks' (n_way/k_shot for 'pair')
    :return: (train loader, test loader, None, None)
    :raises ValueError: if data_mode is not one of the supported modes
        (previously an unknown mode fell through to an UnboundLocalError).
    """
    data_mode = kwargs.get("data_mode", "full")
    frames_num = kwargs.get("frames_num", 4)
    data_type = kwargs.get("data_type", "event")

    train_transform = transforms.Compose([
        transforms.Resize((28, 28))])
    test_transform = transforms.Compose([
        transforms.Resize((28, 28))])
    if data_mode == "full":
        train_datasets = NOmniglotfull(root=os.path.join(root, 'DVS/NOmniglot'), train=True, frames_num=frames_num,
                                       data_type=data_type,
                                       transform=train_transform)
        test_datasets = NOmniglotfull(root=os.path.join(root, 'DVS/NOmniglot'), train=False, frames_num=frames_num,
                                      data_type=data_type,
                                      transform=test_transform)

    elif data_mode == "nkks":
        train_datasets = NOmniglotNWayKShot(os.path.join(root, 'DVS/NOmniglot'),
                                            n_way=kwargs["n_way"],
                                            k_shot=kwargs["k_shot"],
                                            k_query=kwargs["k_query"],
                                            train=True,
                                            frames_num=frames_num,
                                            data_type=data_type,
                                            transform=train_transform)
        test_datasets = NOmniglotNWayKShot(os.path.join(root, 'DVS/NOmniglot'),
                                           n_way=kwargs["n_way"],
                                           k_shot=kwargs["k_shot"],
                                           k_query=kwargs["k_query"],
                                           train=False,
                                           frames_num=frames_num,
                                           data_type=data_type,
                                           transform=test_transform)
    elif data_mode == "pair":
        train_datasets = NOmniglotTrainSet(root=os.path.join(root, 'DVS/NOmniglot'), use_frame=True,
                                           frames_num=frames_num, data_type=data_type,
                                           use_npz=False, resize=105)
        test_datasets = NOmniglotTestSet(root=os.path.join(root, 'DVS/NOmniglot'), time=2000, way=kwargs["n_way"],
                                         shot=kwargs["k_shot"], use_frame=True,
                                         frames_num=frames_num, data_type=data_type, use_npz=False, resize=105)
    else:
        # Fix: fail fast with a clear message instead of falling through
        # and crashing later with an UnboundLocalError.
        raise ValueError(
            "Unknown data_mode {!r}; expected 'full', 'nkks' or 'pair'".format(data_mode))

    train_loader = torch.utils.data.DataLoader(
        train_datasets, batch_size=batch_size, num_workers=12,
        pin_memory=True, drop_last=True, shuffle=True
    )
    test_loader = torch.utils.data.DataLoader(
        test_datasets, batch_size=batch_size, num_workers=12,
        pin_memory=True, drop_last=False
    )
    return train_loader, test_loader, None, None
def get_esimnet_data(batch_size, step, root=DATA_DIR, **kwargs):
    """
    Build ES-ImageNet train/test loaders.
    DOI: 10.3389/fnins.2021.726582
    :param batch_size: batch size
    :param step: number of simulation time steps; must be 8, or 1 when
        'reconstruct' is True
    :param root: dataset root directory
    :param kwargs: may contain 'reconstruct' (default False) and mix-up options
    :return: (train loader, test loader, mixup_active, mixup_fn)
    :raises ValueError: if step does not match the chosen reconstruct mode
    :note: no automatic download; see spikingjelly for the download and md5.
        Samplers default to DistributedSampler.
    """
    reconstruct = kwargs.get("reconstruct", False)

    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15)
    ])
    test_transform = transforms.Compose([
        lambda x: dvs_channel_check_expend(x),
    ])

    # Fix: validate `step` with explicit raises instead of `assert`, which
    # is silently stripped when Python runs with -O.
    if reconstruct:
        if step != 1:
            raise ValueError('reconstructed ES-ImageNet requires step == 1, got {}'.format(step))
        train_dataset = ESImagenet2D_Dataset(mode='train',
                                             data_set_path=os.path.join(root, 'DVS/ES-imagenet-0.18/extract/ES-imagenet-0.18/'),
                                             transform=train_transform)
        test_dataset = ESImagenet2D_Dataset(mode='test',
                                            data_set_path=os.path.join(root, 'DVS/ES-imagenet-0.18/extract/ES-imagenet-0.18/'),
                                            transform=test_transform)
    else:
        if step != 8:
            raise ValueError('ES-ImageNet requires step == 8, got {}'.format(step))
        train_dataset = ESImagenet_Dataset(mode='train',
                                           data_set_path=os.path.join(root,
                                                                      'DVS/ES-imagenet-0.18/extract/ES-imagenet-0.18/'),
                                           transform=train_transform)
        test_dataset = ESImagenet_Dataset(mode='test',
                                          data_set_path=os.path.join(root,
                                                                     'DVS/ES-imagenet-0.18/extract/ES-imagenet-0.18/'),
                                          transform=test_transform)

    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)
    mixup_active = cut_mix | event_mix | mix_up

    # The mix wrappers are stackable; each wraps the current dataset.
    if cut_mix:
        train_dataset = CutMix(train_dataset,
                               beta=beta,
                               prob=prob,
                               num_mix=num,
                               num_class=num_classes,
                               noise=noise)

    if event_mix:
        train_dataset = EventMix(train_dataset,
                                 beta=beta,
                                 prob=prob,
                                 num_mix=num,
                                 num_class=num_classes,
                                 noise=noise,
                                 gaussian_n=gaussian_n)
    if mix_up:
        train_dataset = MixUp(train_dataset,
                              beta=beta,
                              prob=prob,
                              num_mix=num,
                              num_class=num_classes,
                              noise=noise)

    # Distributed samplers shard the data across processes; shuffle must be
    # False because the sampler handles ordering.
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size,
        pin_memory=True, drop_last=True, num_workers=8,
        shuffle=False, sampler=train_sampler
    )

    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size,
        pin_memory=True, drop_last=False, num_workers=1,
        shuffle=False, sampler=test_sampler
    )

    return train_loader, test_loader, mixup_active, None
def get_nmnist_data(batch_size, step, **kwargs):
    """
    Build N-MNIST train/test loaders.
    http://journal.frontiersin.org/Article/10.3389/fnins.2015.00437/abstract
    :param batch_size: batch size
    :param step: number of simulation time steps (event frames per sample)
    :param kwargs: may contain 'size' (default 34), rand-aug and mix-up options
    :return: (train loader, test loader, mixup_active, mixup_fn)
    """
    sensor_size = tonic.datasets.NMNIST.sensor_size
    size = kwargs['size'] if 'size' in kwargs else 34

    # Stage 1: bin the raw event stream into `step` frames at read time.
    train_transform = transforms.Compose([
        # tonic.transforms.Denoise(filter_time=10000),
        # tonic.transforms.DropEvent(p=0.1),
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step),
    ])
    test_transform = transforms.Compose([
        # tonic.transforms.Denoise(filter_time=10000),
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step),
    ])

    train_dataset = tonic.datasets.NMNIST(os.path.join(DATA_DIR, 'DVS/N-MNIST'),
                                          transform=train_transform, train=True)
    test_dataset = tonic.datasets.NMNIST(os.path.join(DATA_DIR, 'DVS/N-MNIST'),
                                         transform=test_transform, train=False)

    # Stage 2: tensor-level transforms applied on top of the cached frames.
    train_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
        lambda x: dvs_channel_check_expend(x),
        # transforms.RandomCrop(size, padding=size // 12),
        # transforms.RandomHorizontalFlip(),
        # transforms.RandomRotation(15)
    ])
    test_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
        lambda x: dvs_channel_check_expend(x),
    ])

    # Optional RandAugment, inserted right after the interpolate step.
    if 'rand_aug' in kwargs.keys():
        if kwargs['rand_aug'] is True:
            n = kwargs['randaug_n']
            m = kwargs['randaug_m']
            train_transform.transforms.insert(2, RandAugment(m=m, n=n))

    # if 'temporal_flatten' in kwargs.keys():
    #     if kwargs['temporal_flatten'] is True:
    #         train_transform.transforms.insert(-1, lambda x: temporal_flatten(x))
    #         test_transform.transforms.insert(-1, lambda x: temporal_flatten(x))

    # Cache the binned frames on disk so event decoding happens only once.
    train_dataset = DiskCachedDataset(train_dataset,
                                      cache_path=os.path.join(DATA_DIR, 'DVS/N-MNIST/train_cache_{}'.format(step)),
                                      transform=train_transform, num_copies=3)
    test_dataset = DiskCachedDataset(test_dataset,
                                     cache_path=os.path.join(DATA_DIR, 'DVS/N-MNIST/test_cache_{}'.format(step)),
                                     transform=test_transform, num_copies=3)

    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)
    mixup_active = cut_mix | event_mix | mix_up

    # The mix wrappers are stackable; each wraps the current dataset.
    if cut_mix:
        train_dataset = CutMix(train_dataset,
                               beta=beta,
                               prob=prob,
                               num_mix=num,
                               num_class=num_classes,
                               noise=noise)

    if event_mix:
        train_dataset = EventMix(train_dataset,
                                 beta=beta,
                                 prob=prob,
                                 num_mix=num,
                                 num_class=num_classes,
                                 noise=noise,
                                 gaussian_n=gaussian_n)
    if mix_up:
        train_dataset = MixUp(train_dataset,
                              beta=beta,
                              prob=prob,
                              num_mix=num,
                              num_class=num_classes,
                              noise=noise)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size,
        pin_memory=True, drop_last=True, num_workers=8,
        shuffle=True,
    )

    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size,
        pin_memory=True, drop_last=False, num_workers=2,
        shuffle=False,
    )

    return train_loader, test_loader, mixup_active, None
def get_ntidigits_data(batch_size, step, **kwargs):
    """
    Build N-TIDIGITS train/test loaders (the download link in newer tonic
    versions may be broken; tonic 0.4.0 is known to work).
    https://www.frontiersin.org/articles/10.3389/fnins.2018.00023/full
    :param batch_size: batch size
    :param step: number of simulation time steps
    :return: (train loader, test loader, None, None)
    Format: (b, t, c, len) — unlike vision data, audio has c == 1 and no
    h/w dimensions; len == 64.
    """
    sensor_size = tonic.datasets.NTIDIGITS.sensor_size

    def _build_transform():
        # Bin events into `step` frames, convert to float tensor and drop
        # the singleton channel axis.
        return transforms.Compose([
            tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step),
            lambda x: torch.tensor(x, dtype=torch.float),
            lambda x: x.squeeze(1),
        ])

    data_dir = os.path.join(DATA_DIR, 'DVS/NTIDIGITS')
    train_dataset = tonic.datasets.NTIDIGITS(data_dir, transform=_build_transform(), train=True)
    test_dataset = tonic.datasets.NTIDIGITS(data_dir, transform=_build_transform(), train=False)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size,
        pin_memory=True, drop_last=True, num_workers=8,
        shuffle=True,
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size,
        pin_memory=True, drop_last=False, num_workers=2,
        shuffle=False,
    )
    return train_loader, test_loader, None, None
def get_shd_data(batch_size, step, **kwargs):
    """
    Build SHD (Spiking Heidelberg Digits) train/test loaders.
    https://ieeexplore.ieee.org/abstract/document/9311226
    :param batch_size: batch size
    :param step: number of simulation time steps
    :return: (train loader, test loader, None, None)
    Format: (b, t, c, len) — audio data has c == 1 and no h/w; len == 700.
    After the tensor transform the batch becomes (b, t, len).
    """
    sensor_size = tonic.datasets.SHD.sensor_size
    data_dir = os.path.join(DATA_DIR, 'DVS/SHD')

    # Event-level transform: bin raw events into `step` frames.
    frame_transform = transforms.Compose([
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step)
    ])
    train_dataset = tonic.datasets.SHD(data_dir, transform=frame_transform, train=True)
    test_dataset = tonic.datasets.SHD(data_dir, transform=frame_transform, train=False)

    # Tensor-level transform applied on top of the disk cache: float tensor,
    # then drop the singleton channel axis.
    squeeze_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
        lambda x: x.squeeze(1)
    ])
    train_dataset = DiskCachedDataset(train_dataset,
                                      cache_path=os.path.join(data_dir, 'train_cache_{}'.format(step)),
                                      transform=squeeze_transform, num_copies=3)
    test_dataset = DiskCachedDataset(test_dataset,
                                     cache_path=os.path.join(data_dir, 'test_cache_{}'.format(step)),
                                     transform=squeeze_transform, num_copies=3)

    # Note: the train loader keeps the last partial batch (drop_last=False).
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size,
        pin_memory=True, drop_last=False, num_workers=8,
        shuffle=True,
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size,
        pin_memory=True, drop_last=False, num_workers=2,
        shuffle=False,
    )
    return train_loader, test_loader, None, None
def get_CUB2002011_data(batch_size, num_workers=8, same_da=False, root=DATA_DIR, *args, **kwargs):
    """Build train/test loaders for the CUB-200-2011 birds dataset.

    :param batch_size: batch size for both loaders
    :param num_workers: worker processes per loader
    :param same_da: if True, train with the deterministic eval transform
    :param root: dataset root directory; 'CUB2002011' is appended
    :return: (train loader, test loader, False, None) -- mixup is not used
    """
    normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    # ImageNet-style augmentation for training ...
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ])
    # ... and a deterministic resize + centre crop for evaluation.
    test_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize
    ])
    data_root = os.path.join(root, 'CUB2002011')
    train_set = CUB2002011(
        root=data_root, train=True,
        transform=test_transform if same_da else train_transform, download=True)
    test_set = CUB2002011(
        root=data_root, train=False, transform=test_transform, download=True)
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, shuffle=True,
        pin_memory=True, drop_last=True, num_workers=num_workers)
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=batch_size,
        pin_memory=True, drop_last=False, num_workers=num_workers)
    return train_loader, test_loader, False, None
def get_StanfordCars_data(batch_size, num_workers=8, same_da=False, root=DATA_DIR, *args, **kwargs):
    """Build train/test loaders for the Stanford Cars dataset.

    :param batch_size: batch size for both loaders
    :param num_workers: worker processes per loader
    :param same_da: if True, train with the deterministic eval transform
    :param root: dataset root directory; 'StanfordCars' is appended
    :return: (train loader, test loader, False, None) -- mixup is not used
    """
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    test_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    data_root = os.path.join(root, 'StanfordCars')
    train_set = datasets.StanfordCars(
        root=data_root, split="train",
        transform=test_transform if same_da else train_transform, download=True)
    test_set = datasets.StanfordCars(
        root=data_root, split="test", transform=test_transform, download=True)
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, shuffle=True,
        pin_memory=True, drop_last=True, num_workers=num_workers)
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=batch_size,
        pin_memory=True, drop_last=False, num_workers=num_workers)
    return train_loader, test_loader, False, None
def get_StanfordDogs_data(batch_size, num_workers=8, same_da=False, root=DATA_DIR, *args, **kwargs):
    """Build train/test loaders for the Stanford Dogs dataset.

    :param batch_size: batch size for both loaders
    :param num_workers: worker processes per loader
    :param same_da: if True, train with the deterministic eval transform
    :param root: dataset root directory; 'StanfordDogs' is appended
    :return: (train loader, test loader, False, None) -- mixup is not used
    """
    normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ])
    test_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize
    ])
    data_root = os.path.join(root, 'StanfordDogs')
    train_set = StanfordDogs(
        root=data_root, train=True,
        transform=test_transform if same_da else train_transform, download=True)
    test_set = StanfordDogs(
        root=data_root, train=False, transform=test_transform, download=True)
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, shuffle=True,
        pin_memory=True, drop_last=True, num_workers=num_workers)
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=batch_size,
        pin_memory=True, drop_last=False, num_workers=num_workers)
    return train_loader, test_loader, False, None
def get_FGVCAircraft_data(batch_size, num_workers=8, same_da=False, root=DATA_DIR, *args, **kwargs):
    """Build train/test loaders for the FGVC-Aircraft dataset.

    :param batch_size: batch size for both loaders
    :param num_workers: worker processes per loader
    :param same_da: if True, train with the deterministic eval transform
    :param root: dataset root directory; 'FGVCAircraft' is appended
    :return: (train loader, test loader, False, None) -- mixup is not used
    """
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    test_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    data_root = os.path.join(root, 'FGVCAircraft')
    train_set = datasets.FGVCAircraft(
        root=data_root, split="train",
        transform=test_transform if same_da else train_transform, download=True)
    test_set = datasets.FGVCAircraft(
        root=data_root, split="test", transform=test_transform, download=True)
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, shuffle=True,
        pin_memory=True, drop_last=True, num_workers=num_workers)
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=batch_size,
        pin_memory=True, drop_last=False, num_workers=num_workers)
    return train_loader, test_loader, False, None
def get_Flowers102_data(batch_size, num_workers=8, same_da=False, root=DATA_DIR, *args, **kwargs):
    """Build train/test loaders for the Oxford Flowers-102 dataset.

    :param batch_size: batch size for both loaders
    :param num_workers: worker processes per loader
    :param same_da: if True, train with the deterministic eval transform
    :param root: dataset root directory; 'Flowers102' is appended
    :return: (train loader, test loader, False, None) -- mixup is not used
    """
    normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ])
    test_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize
    ])
    data_root = os.path.join(root, 'Flowers102')
    train_set = datasets.Flowers102(
        root=data_root, split="train",
        transform=test_transform if same_da else train_transform, download=True)
    test_set = datasets.Flowers102(
        root=data_root, split="test", transform=test_transform, download=True)
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, shuffle=True,
        pin_memory=True, drop_last=True, num_workers=num_workers)
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=batch_size,
        pin_memory=True, drop_last=False, num_workers=num_workers)
    return train_loader, test_loader, False, None
def get_UCF101DVS_data(batch_size, step, **kwargs):
    """
    Build train/test loaders for the UCF101-DVS event dataset.

    :param batch_size: batch size
    :param step: number of simulation time steps (event-frame time bins)
    :param kwargs: optional keys: ``size``, ``rand_aug`` / ``randaug_n`` /
        ``randaug_m``, plus the mix-augmentation family consumed by
        ``unpack_mix_param`` (cutmix / eventmix / mixup and their parameters)
    :return: (train loader, test loader, mixup_active, None)
    """
    size = kwargs['size'] if 'size' in kwargs else 48  # NOTE(review): currently unused below
    sensor_size = braincog.datasets.ucf101_dvs.UCF101DVS.sensor_size
    # Event-stream -> frame-stack conversion applied before disk caching.
    train_transform = transforms.Compose([
        # tonic.transforms.Denoise(filter_time=10000),
        # tonic.transforms.DropEvent(p=0.1),
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])
    test_transform = transforms.Compose([
        # tonic.transforms.Denoise(filter_time=10000),
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])
    train_dataset = braincog.datasets.ucf101_dvs.UCF101DVS(os.path.join(DATA_DIR, 'UCF101DVS'), train=True, transform=train_transform)
    test_dataset = braincog.datasets.ucf101_dvs.UCF101DVS(os.path.join(DATA_DIR, 'UCF101DVS'), train=False, transform=test_transform)
    # Post-cache transforms: convert the cached numpy frames to float tensors;
    # the commented-out lines are augmentation experiments kept for reference.
    train_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
        # lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
        # lambda x: TemporalShift(x, .01),
        # lambda x: drop(x, 0.15),
        # lambda x: ShearX(x, 15),
        # lambda x: ShearY(x, 15),
        # lambda x: TranslateX(x, 0.225),
        # lambda x: TranslateY(x, 0.225),
        # lambda x: Rotate(x, 15),
        # lambda x: CutoutAbs(x, 0.25),
        # lambda x: CutoutTemporal(x, 0.25),
        # lambda x: GaussianBlur(x, 0.5),
        # lambda x: SaltAndPepperNoise(x, 0.1),
        # transforms.Normalize(DVSCIFAR10_MEAN_16, DVSCIFAR10_STD_16),
        # transforms.RandomCrop(size, padding=size // 12),
        transforms.RandomHorizontalFlip(),
        # transforms.RandomRotation(15)
    ])
    test_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
        # lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
    ])
    if 'rand_aug' in kwargs.keys():
        if kwargs['rand_aug'] is True:
            n = kwargs['randaug_n']
            m = kwargs['randaug_m']
            # print('randaug', m, n)
            # Inserted after the tensor conversion and flip (train only).
            train_transform.transforms.insert(2, RandAugment(m=m, n=n))
    # if 'temporal_flatten' in kwargs.keys():
    #     if kwargs['temporal_flatten'] is True:
    #         train_transform.transforms.insert(-1, lambda x: temporal_flatten(x))
    #         test_transform.transforms.insert(-1, lambda x: temporal_flatten(x))
    # Cache the (slow) event-to-frame conversion on disk, keyed by `step`.
    train_dataset = DiskCachedDataset(train_dataset,
                                      cache_path=os.path.join(DATA_DIR, 'UCF101DVS/train_cache_{}'.format(step)),
                                      transform=train_transform)
    test_dataset = DiskCachedDataset(test_dataset,
                                     cache_path=os.path.join(DATA_DIR, 'UCF101DVS/test_cache_{}'.format(step)),
                                     transform=test_transform)
    # Mixing augmentations wrap the train dataset; at most one is usually
    # enabled, but they compose in this fixed order if several are requested.
    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)
    mixup_active = cut_mix | event_mix | mix_up
    if cut_mix:
        # print('cut_mix', beta, prob, num, num_classes)
        train_dataset = CutMix(train_dataset,
                               beta=beta,
                               prob=prob,
                               num_mix=num,
                               num_class=num_classes,
                               noise=noise)
    if event_mix:
        train_dataset = EventMix(train_dataset,
                                 beta=beta,
                                 prob=prob,
                                 num_mix=num,
                                 num_class=num_classes,
                                 noise=noise,
                                 gaussian_n=gaussian_n)
    if mix_up:
        train_dataset = MixUp(train_dataset,
                              beta=beta,
                              prob=prob,
                              num_mix=num,
                              num_class=num_classes,
                              noise=noise)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        pin_memory=True, drop_last=True, num_workers=8
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False,
        pin_memory=True, drop_last=False, num_workers=2
    )
    return train_loader, test_loader, mixup_active, None
def get_HMDBDVS_data(batch_size, step, **kwargs):
    """
    Build train/test loaders for the HMDB-DVS event dataset.

    The dataset has no official train/test split, so a stratified 50/50
    split is derived per class here (seeded, hence reproducible across runs);
    a weighted sampler draws only from the train portion, a subset sampler
    only from the test portion.

    :param batch_size: batch size
    :param step: number of simulation time steps (event-frame time bins)
    :param kwargs: optional keys: ``size``, ``rand_aug`` / ``randaug_n`` /
        ``randaug_m``, plus the mix-augmentation family consumed by
        ``unpack_mix_param``
    :return: (train loader, test loader, mixup_active, None)
    """
    sensor_size = braincog.datasets.hmdb_dvs.HMDBDVS.sensor_size
    train_transform = transforms.Compose([
        # tonic.transforms.Denoise(filter_time=10000),
        # tonic.transforms.DropEvent(p=0.1),
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])
    test_transform = transforms.Compose([
        # tonic.transforms.Denoise(filter_time=10000),
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])
    # Both "datasets" wrap the same underlying recordings; the samplers below
    # are what keeps the train and test portions disjoint.
    train_dataset = braincog.datasets.hmdb_dvs.HMDBDVS(os.path.join(DATA_DIR, 'HMDBDVS'), transform=train_transform)
    test_dataset = braincog.datasets.hmdb_dvs.HMDBDVS(os.path.join(DATA_DIR, 'HMDBDVS'), transform=test_transform)
    cls_count = train_dataset.cls_count
    dataset_length = train_dataset.length
    portion = .5  # fraction of each class used for training
    # portion = kwargs['portion'] if 'portion' in kwargs else .9
    size = kwargs['size'] if 'size' in kwargs else 48  # NOTE(review): currently unused below
    # print('portion', portion)
    train_sample_weight = []
    train_sample_index = []
    train_count = 0
    test_sample_index = []
    idx_begin = 0
    for count in cls_count:
        # Inverse-frequency weight so the weighted sampler balances classes.
        sample_weight = dataset_length / count
        train_sample = round(portion * count)
        test_sample = count - train_sample
        train_count += train_sample
        train_sample_weight.extend(
            [sample_weight] * train_sample
        )
        train_sample_weight.extend(
            [0.] * test_sample
        )
        # Shuffle this class's indices with a fixed seed so the split is
        # deterministic. NOTE(review): this reseeds the global `random` RNG
        # on every iteration, which also affects later callers of `random`.
        lst = list(range(idx_begin, idx_begin + train_sample + test_sample))
        random.seed(0)
        random.shuffle(lst)
        train_sample_index.extend(
            lst[:train_sample]
            # list((range(idx_begin, idx_begin + train_sample)))
        )
        test_sample_index.extend(
            lst[train_sample:train_sample + test_sample]
            # list(range(idx_begin + train_sample, idx_begin + train_sample + test_sample))
        )
        idx_begin += count
    train_sampler = torch.utils.data.sampler.WeightedRandomSampler(train_sample_weight, train_count)
    test_sampler = torch.utils.data.sampler.SubsetRandomSampler(test_sample_index)
    # Post-cache transforms: convert cached numpy frames to float tensors.
    train_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
        # lambda x: print(x.shape),
        # lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
        # transforms.RandomCrop(size, padding=size // 12),
        # transforms.RandomHorizontalFlip(),
        # transforms.RandomRotation(15)
    ])
    test_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
        # lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
        # lambda x: temporal_flatten(x),
    ])
    if 'rand_aug' in kwargs.keys():
        if kwargs['rand_aug'] is True:
            n = kwargs['randaug_n']
            m = kwargs['randaug_m']
            train_transform.transforms.insert(2, RandAugment(m=m, n=n))
    # if 'temporal_flatten' in kwargs.keys():
    #     if kwargs['temporal_flatten'] is True:
    #         train_transform.transforms.insert(-1, lambda x: temporal_flatten(x))
    #         test_transform.transforms.insert(-1, lambda x: temporal_flatten(x))
    # Cache the event-to-frame conversion on disk, keyed by `step`.
    train_dataset = DiskCachedDataset(train_dataset,
                                      cache_path=os.path.join(DATA_DIR, 'HMDBDVS/train_cache_{}'.format(step)),
                                      transform=train_transform, num_copies=3)
    test_dataset = DiskCachedDataset(test_dataset,
                                     cache_path=os.path.join(DATA_DIR, 'HMDBDVS/test_cache_{}'.format(step)),
                                     transform=test_transform, num_copies=3)
    # Mixing augmentations are restricted to the train indices of the split.
    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)
    mixup_active = cut_mix | event_mix | mix_up
    if cut_mix:
        train_dataset = CutMix(train_dataset,
                               beta=beta,
                               prob=prob,
                               num_mix=num,
                               num_class=num_classes,
                               indices=train_sample_index,
                               noise=noise)
    if event_mix:
        train_dataset = EventMix(train_dataset,
                                 beta=beta,
                                 prob=prob,
                                 num_mix=num,
                                 num_class=num_classes,
                                 indices=train_sample_index,
                                 noise=noise,
                                 gaussian_n=gaussian_n)
    if mix_up:
        train_dataset = MixUp(train_dataset,
                              beta=beta,
                              prob=prob,
                              num_mix=num,
                              num_class=num_classes,
                              indices=train_sample_index,
                              noise=noise)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size,
        sampler=train_sampler,
        pin_memory=True, drop_last=True, num_workers=8
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size,
        sampler=test_sampler,
        pin_memory=True, drop_last=False, num_workers=2
    )
    return train_loader, test_loader, mixup_active, None
================================================
FILE: braincog/datasets/gen_input_signal.py
================================================
import numpy as np
import random
import copy
# Signal-generation constants shared by the helpers in this module.
dt = 1.0 # simulation time resolution, ms
lambda_max = 0.25 * dt # maximum spike rate (spikes per time step)
================================================
FILE: braincog/datasets/hmdb_dvs/__init__.py
================================================
# encoding: utf-8
# Author : Floyed
# Datetime : 2023/1/30 20:54
# User : yu
# Product : PyCharm
# Project : BrainCog
# File : __init__.py
# explain :
from .hmdb_dvs import HMDBDVS
__all__ = [
'HMDBDVS'
]
================================================
FILE: braincog/datasets/hmdb_dvs/hmdb_dvs.py
================================================
# encoding: utf-8
# Author : Floyed
# Datetime : 2023/1/30 20:54
# User : yu
# Product : PyCharm
# Project : BrainCog
# File : hmdb_dvs.py
# explain :
import os
import numpy as np
from numpy.lib import recfunctions
import scipy.io as scio
from typing import Tuple, Any, Optional
from tonic.dataset import Dataset
from tonic.download_utils import extract_archive
class HMDBDVS(Dataset):
    """HMDB-DVS dataset: event-camera recordings of HMDB action-recognition
    clips. Events have (txyp) ordering.

    NOTE(review): the citation below is the authors' graph-classification
    paper (which introduced ASL-DVS); confirm it is the intended reference
    for the HMDB-DVS recordings distributed via the same Dropbox share.
    ::

        @inproceedings{bi2019graph,
        title={Graph-based Object Classification for Neuromorphic Vision Sensing},
        author={Bi, Y and Chadha, A and Abbas, A and and Bourtsoulatze, E and Andreopoulos, Y},
        booktitle={2019 IEEE International Conference on Computer Vision (ICCV)},
        year={2019},
        organization={IEEE}
        }

    Parameters:
        save_to (string): Location to save files to on disk.
        transform (callable, optional): A callable of transforms to apply to the data.
        target_transform (callable, optional): A callable of transforms to apply to the targets/labels.
    """
    # (width, height, polarities) of the recording sensor.
    sensor_size = (240, 180, 2)
    dtype = np.dtype([("t", int), ("x", int), ("y", int), ("p", int)])
    ordering = dtype.names
    def __init__(self, save_to, transform=None, target_transform=None):
        super(HMDBDVS, self).__init__(
            save_to, transform=transform, target_transform=target_transform
        )
        if not self._check_exists():
            raise NotImplementedError(
                'Please manually download the dataset from'
                ' https://www.dropbox.com/sh/ie75dn246cacf6n/AACoU-_zkGOAwj51lSCM0JhGa?dl=0 '
                'and extract it to {}'.format(self.location_on_system))
        # Map each class-named subdirectory to an integer label.
        classes = os.listdir(self.location_on_system)
        self.int_classes = dict(zip(classes, range(len(classes))))
        for path, dirs, files in os.walk(self.location_on_system):
            dirs.sort()
            files.sort()
            for file in files:
                if file.endswith("mat"):
                    # Skip recordings under 1 KB (empty/corrupt files).
                    fsize = os.path.getsize(path + '/' + file) / float(1024)
                    if fsize < 1:
                        # print('{} size {} K'.format(file, fsize))
                        continue
                    self.data.append(path + "/" + file)
                    self.targets.append(self.int_classes[path.split('/')[-1]])
        self.length = self.__len__()
        # Per-class sample counts; used by the loader to build a balanced split.
        self.cls_count = np.bincount(self.targets)
    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Returns:
            (events, target) where target is index of the target class.
        """
        events, target = scio.loadmat(self.data[index]), self.targets[index]
        # Assemble (t, x, y, p) columns; the y axis is flipped so the origin
        # matches the convention expected downstream.
        events = np.column_stack(
            [
                events["ts"],
                events["x"],
                self.sensor_size[1] - 1 - events["y"],
                events["pol"],
            ]
        )
        events = np.lib.recfunctions.unstructured_to_structured(events, self.dtype)
        if self.transform is not None:
            events = self.transform(events)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return events, target
    def __len__(self):
        return len(self.data)
    def _check_exists(self):
        # 6765 .mat recordings are expected after manual extraction.
        return self._folder_contains_at_least_n_files_of_type(
            6765, ".mat"
        )
================================================
FILE: braincog/datasets/ncaltech101/__init__.py
================================================
# encoding: utf-8
# Author : Floyed
# Datetime : 2023/1/30 21:26
# User : yu
# Product : PyCharm
# Project : BrainCog
# File : __init__.py.py
# explain :
from .ncaltech101 import NCALTECH101
__all__ = [
'NCALTECH101'
]
================================================
FILE: braincog/datasets/ncaltech101/ncaltech101.py
================================================
# encoding: utf-8
# Author : Floyed
# Datetime : 2023/1/30 21:28
# User : yu
# Product : PyCharm
# Project : BrainCog
# File : ncaltech101.py
# explain :
import os
import numpy as np
from tonic.io import read_mnist_file
from tonic.dataset import Dataset
from tonic.download_utils import extract_archive
class NCALTECH101(Dataset):
    """N-CALTECH101 dataset . Events have (xytp) ordering.
    ::

        @article{orchard2015converting,
        title={Converting static image datasets to spiking neuromorphic datasets using saccades},
        author={Orchard, Garrick and Jayawant, Ajinkya and Cohen, Gregory K and Thakor, Nitish},
        journal={Frontiers in neuroscience},
        volume={9},
        pages={437},
        year={2015},
        publisher={Frontiers}
        }

    Parameters:
        save_to (string): Location to save files to on disk.
        transform (callable, optional): A callable of transforms to apply to the data.
        target_transform (callable, optional): A callable of transforms to apply to the targets/labels.
    """
    url = "https://data.mendeley.com/public-files/datasets/cy6cvx3ryv/files/36b5c52a-b49d-4853-addb-a836a8883e49/file_downloaded"
    filename = "N-Caltech101-archive.zip"
    file_md5 = "66201824eabb0239c7ab992480b50ba3"
    data_filename = "N-Caltech101-archive.zip"
    folder_name = "Caltech101"
    # Per-class recording counts, in label order (101 classes including
    # BACKGROUND_Google); used by callers to build stratified splits.
    cls_count = [467,
                 435, 200, 798, 55, 800, 42, 42, 47, 54, 46,
                 33, 128, 98, 43, 85, 91, 50, 43, 123, 47,
                 59, 62, 107, 47, 69, 73, 70, 50, 51, 57,
                 67, 52, 65, 68, 75, 64, 53, 64, 85, 67,
                 67, 45, 34, 34, 51, 99, 100, 42, 54, 88,
                 80, 31, 64, 86, 114, 61, 81, 78, 41, 66,
                 43, 40, 87, 32, 76, 55, 35, 39, 47, 38,
                 45, 53, 34, 57, 82, 59, 49, 40, 63, 39,
                 84, 57, 35, 64, 45, 86, 59, 64, 35, 85,
                 49, 86, 75, 239, 37, 59, 34, 56, 39, 60]
    # Total number of recordings (with the background class included).
    # length = 8242
    length = 8709
    sensor_size = None  # all recordings are of different size
    dtype = np.dtype([("x", int), ("y", int), ("t", int), ("p", int)])
    ordering = dtype.names
    def __init__(self, save_to, transform=None, target_transform=None):
        super(NCALTECH101, self).__init__(
            save_to, transform=transform, target_transform=target_transform
        )
        # Fixed folder-name -> label mapping, so labels are stable regardless
        # of the order in which the file system lists directories.
        classes = {
            'BACKGROUND_Google': 0,
            'Faces_easy': 1,
            'Leopards': 2,
            'Motorbikes': 3,
            'accordion': 4,
            'airplanes': 5,
            'anchor': 6,
            'ant': 7,
            'barrel': 8,
            'bass': 9,
            'beaver': 10,
            'binocular': 11,
            'bonsai': 12,
            'brain': 13,
            'brontosaurus': 14,
            'buddha': 15,
            'butterfly': 16,
            'camera': 17,
            'cannon': 18,
            'car_side': 19,
            'ceiling_fan': 20,
            'cellphone': 21,
            'chair': 22,
            'chandelier': 23,
            'cougar_body': 24,
            'cougar_face': 25,
            'crab': 26,
            'crayfish': 27,
            'crocodile': 28,
            'crocodile_head': 29,
            'cup': 30,
            'dalmatian': 31,
            'dollar_bill': 32,
            'dolphin': 33,
            'dragonfly': 34,
            'electric_guitar': 35,
            'elephant': 36,
            'emu': 37,
            'euphonium': 38,
            'ewer': 39,
            'ferry': 40,
            'flamingo': 41,
            'flamingo_head': 42,
            'garfield': 43,
            'gerenuk': 44,
            'gramophone': 45,
            'grand_piano': 46,
            'hawksbill': 47,
            'headphone': 48,
            'hedgehog': 49,
            'helicopter': 50,
            'ibis': 51,
            'inline_skate': 52,
            'joshua_tree': 53,
            'kangaroo': 54,
            'ketch': 55,
            'lamp': 56,
            'laptop': 57,
            'llama': 58,
            'lobster': 59,
            'lotus': 60,
            'mandolin': 61,
            'mayfly': 62,
            'menorah': 63,
            'metronome': 64,
            'minaret': 65,
            'nautilus': 66,
            'octopus': 67,
            'okapi': 68,
            'pagoda': 69,
            'panda': 70,
            'pigeon': 71,
            'pizza': 72,
            'platypus': 73,
            'pyramid': 74,
            'revolver': 75,
            'rhino': 76,
            'rooster': 77,
            'saxophone': 78,
            'schooner': 79,
            'scissors': 80,
            'scorpion': 81,
            'sea_horse': 82,
            'snoopy': 83,
            'soccer_ball': 84,
            'stapler': 85,
            'starfish': 86,
            'stegosaurus': 87,
            'stop_sign': 88,
            'strawberry': 89,
            'sunflower': 90,
            'tick': 91,
            'trilobite': 92,
            'umbrella': 93,
            'watch': 94,
            'water_lilly': 95,
            'wheelchair': 96,
            'wild_cat': 97,
            'windsor_chair': 98,
            'wrench': 99,
            'yin_yang': 100,
        }
        # Automatic download is disabled: the archive must already be
        # extracted under <save_to>/Caltech101.
        # if not self._check_exists():
        #     self.download()
        #     extract_archive(os.path.join(self.location_on_system, self.data_filename))
        file_path = os.path.join(self.location_on_system, self.folder_name)
        for path, dirs, files in os.walk(file_path):
            dirs.sort()
            # if 'BACKGROUND_Google' in path:
            #     continue
            for file in files:
                if file.endswith("bin"):
                    self.data.append(path + "/" + file)
                    label_name = os.path.basename(path)
                    if isinstance(label_name, bytes):
                        label_name = label_name.decode()
                    self.targets.append(classes[label_name])
    def __getitem__(self, index):
        """
        Returns:
            a tuple of (events, target) where target is the index of the target class.
        """
        events = read_mnist_file(self.data[index], dtype=self.dtype)
        target = self.targets[index]
        # Shift coordinates so every recording starts at the origin
        # (recordings have varying offsets and sizes).
        events["x"] -= events["x"].min()
        events["y"] -= events["y"].min()
        if self.transform is not None:
            events = self.transform(events)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return events, target
    def __len__(self):
        return len(self.data)
    def _check_exists(self):
        # 8709 .bin recordings are expected after extraction.
        return self._is_file_present() and self._folder_contains_at_least_n_files_of_type(
            8709, ".bin"
        )
================================================
FILE: braincog/datasets/rand_aug.py
================================================
import random
import numpy as np
import torch
from torchvision import transforms
from torchvision.transforms import functional
from torchvision.transforms import InterpolationMode
def ShearX(x, v):  # magnitude range [0, 30] degrees
    """Shear along the x axis by a random angle drawn from [-v, v] degrees."""
    assert 0 <= v <= 30
    angle = np.random.uniform(0, v)
    if random.random() > 0.5:
        angle = -angle
    return functional.affine(x, angle=0, translate=[0, 0], scale=1., shear=[angle, 0])
def ShearY(x, v):  # magnitude range [0, 30] degrees
    """Shear along the y axis by a random angle drawn from [-v, v] degrees."""
    assert 0 <= v <= 30
    angle = np.random.uniform(0, v)
    if random.random() > 0.5:
        angle = -angle
    return functional.affine(x, angle=0, translate=[0, 0], scale=1., shear=[0, angle])
def TranslateX(x, v):
    """Translate the image horizontally by a random offset of up to v * width.

    :param x: image tensor laid out (..., H, W)
    :param v: maximum relative shift, in [0, 0.45]
    :return: the translated tensor
    """
    assert 0 <= v <= 0.45
    v = np.random.uniform(0, v)
    h, w = x.shape[-2::]
    # Scale by the width (last axis); the original used x.shape[-2] (height).
    v = round(w * v)
    if random.random() > 0.5:
        v = -v
    # torchvision's affine() takes translate=[horizontal, vertical]; the
    # original passed [0, v], which moved the image vertically despite the name.
    return functional.affine(x, angle=0, translate=[v, 0], scale=1., shear=[0, 0])
def TranslateY(x, v):
    """Translate the image vertically by a random offset of up to v * height.

    :param x: image tensor laid out (..., H, W)
    :param v: maximum relative shift, in [0, 0.45]
    :return: the translated tensor
    """
    assert 0 <= v <= 0.45
    v = np.random.uniform(0, v)
    h, w = x.shape[-2::]
    v = round(h * v)
    if random.random() > 0.5:
        v = -v
    # torchvision's affine() takes translate=[horizontal, vertical]; the
    # original passed [v, 0], which moved the image horizontally despite the name.
    return functional.affine(x, angle=0, translate=[0, v], scale=1., shear=[0, 0])
def Rotate(x, v):  # magnitude range [0, 30] degrees
    """Rotate by a random angle drawn from [-v, v] degrees."""
    assert 0 <= v <= 30
    deg = np.random.uniform(0, v)
    if random.random() > 0.5:
        deg = -deg
    return functional.affine(x, angle=deg, translate=[0, 0], scale=1., shear=[0, 0])
def CutoutAbs(x, v):
    """Zero out one random square patch of relative side length v.

    :param x: 4-d tensor laid out (t, c, dim0, dim1); modified in place
    :param v: patch side length relative to x.shape[-2], in [0, 0.5]
    :return: x, with the patch (clipped at the borders) set to zero
    """
    assert 0 <= v <= 0.5
    w, h = x.shape[-2::]
    v = round(v * w)
    # The original called np.random.uniform(w), which sets low=w, high=1.0
    # and samples the reversed interval (1, w]; pick the patch centre
    # uniformly over the full extent [0, dim) instead.
    x0 = np.random.uniform(0, w)
    y0 = np.random.uniform(0, h)
    x0 = round(max(0, x0 - v / 2.))
    y0 = round(max(0, y0 - v / 2.))
    x1 = min(w, x0 + v)
    y1 = min(h, y0 + v)
    x[:, :, y0:y1, x0:x1] = 0.
    return x
def CutoutTemporal(x, v):
    """Zero out a random contiguous run of time steps (up to v * T long).

    Modifies x (laid out (t, c, h, w)) in place and returns it.
    """
    assert 0 <= v <= 0.5
    frac = np.random.uniform(0, v)
    n_step = x.shape[0]
    span = round(frac * n_step)
    start = np.random.randint(n_step)
    stop = min(n_step, start + span)
    x[start:stop, :, :, :] = 0.
    return x
def TemporalShift(x, v):
    """Stochastically move single events between neighbouring time steps.

    Each unit of activation in frame i may jump to frame i + 1 (and vice
    versa) with probability v / 2 per unit.

    NOTE(review): mutates and returns *x* in place.
    """
    # TODO: the shifted amount may exceed what the origin frame actually holds
    assert 0 <= v <= 0.2
    v = v / 2.
    shape = x.shape
    # p = torch.zeros(2 * (shape[0] - 1), *shape[-3:], device=x.device)
    shift = []
    for i in range(x.shape[0] - 1):
        # Forward direction: events of frame i that move to frame i + 1.
        # Each activation level of a pixel is tested independently.
        spike = x[i].clone()
        _max = int(spike.max())
        sft = torch.zeros(shape[-3:], device=x.device)
        for j in range(_max):
            p = torch.rand_like(sft)
            sft[torch.logical_and(p < v, spike > 0.)] += 1.
            spike -= 1
        shift.append(sft)
        # Backward direction: events of frame i + 1 that move to frame i.
        spike = x[i + 1].clone()
        _max = int(spike.max())
        sft = torch.zeros(shape[-3:], device=x.device)
        for j in range(_max):
            p = torch.rand_like(sft)
            sft[torch.logical_and(p < v, spike > 0.)] += 1.
            # NOTE(review): unlike the forward pass there is no `spike -= 1`
            # here, so every level is tested against the unreduced map — confirm
            # whether that asymmetry is intended.
        shift.append(sft)
    for i in range(shape[0] - 1):
        sft_next = shift[i * 2]
        sft_pre = shift[i * 2 + 1]
        # Apply both directions, clamping at zero so counts stay non-negative.
        x[i + 1] = torch.clip(x[i + 1] + sft_next - sft_pre, 0.)
        x[i] = torch.clip(x[i] - sft_next + sft_pre, 0.)
    return x
def SpatioShift(x, v):
    """Apply a spatial translation that ramps linearly over the time axis,
    from no shift at t = 0 up to (shift_x, shift_y) at the last frame —
    emulating smooth camera/object motion across the clip."""
    # assert 0 <= v <= 0.1
    dim_a, dim_b = x.shape[-2::]
    shift_x = round(random.uniform(-v, v) * dim_a)
    shift_y = round(random.uniform(-v, v) * dim_b)
    n_step = x.shape[0]
    frames = [
        functional.affine(
            x[t],
            angle=0,
            translate=[round(shift_x * t / n_step),
                       round(shift_y * t / n_step)],
            scale=1.,
            shear=[0, 0])
        for t in range(n_step)
    ]
    return torch.stack(frames, dim=0)
def drop(x, v):
    """Randomly drop events: each unit of activation survives with
    probability 1 - u, where u ~ Uniform(0, v).

    NOTE(review): mutates *x* in place (`x -= 1.`) down to zero while
    building the survival stack; callers that keep a reference to the
    input tensor will see it altered.
    """
    assert 0 <= v <= 0.5
    v = np.random.uniform(0, v)
    _max = int(torch.max(x))
    # p starts uniform in [0, 1); slots covering an active event get +1,
    # so after thresholding at 1 + v each event survives with prob 1 - v.
    p = torch.rand((_max, *x.shape), device=x.device)
    for i in range(_max):
        p[i, x > 0] += 1.
        x -= 1.
    p = torch.where(p > 1. + v, 1., 0.)
    # Surviving events are re-summed into per-pixel counts.
    return torch.sum(p, dim=0)
def GaussianBlur(x, v):
    """Blur with a 5x5 Gaussian kernel; sigma is drawn from [0.1, max(v, 0.1)].

    Accepts any magnitude in [0, 1]: RandAugment generates v = m / 30 for the
    (GaussianBlur, 0, 1.) entry, which is below 0.1 for m < 3 and used to
    trip the original ``assert 0.1 <= v``; such magnitudes now simply use
    the minimum sigma.
    """
    assert 0 <= v <= 1.
    v = np.random.uniform(0.1, max(v, 0.1))
    return functional.gaussian_blur(x, kernel_size=[5, 5], sigma=v)
def SaltAndPepperNoise(x, v):
    """Add unit 'salt' noise: each entry gains +1 with probability ~u,
    where u ~ Uniform(0, v). Returns a new tensor; x is not modified."""
    assert 0 <= v <= 0.3
    amount = np.random.uniform(0, v)
    mask = torch.rand_like(x)
    noise = torch.where(mask > amount, 0., 1.)
    return x + noise
def Identity(x, v):
    """No-op augmentation: return x unchanged (magnitude v is ignored)."""
    return x
# Candidate (op, min_magnitude, max_magnitude) triples for RandAugment.
# The two accuracy columns in the comments record ablation results on
# DVS-CIFAR10 and N-Caltech101 respectively; entries whose accuracy fell
# below the "normal" baseline are kept commented out for reference.
#                                       DVSC10  NCAL
augment_list = [ # normal: 79.44 77.57
    # (ShearX, 0, 20),  # 75.71
    # (ShearY, 0, 20),
    # (TranslateX, 0, 0.25),  # 77.52
    # (TranslateY, 0, 0.25),
    # (Rotate, 0, 30),  # 77.02
    (CutoutAbs, 0, 0.5),  # 79.13
    (CutoutTemporal, 0, 0.5),  # 80.65
    # (TemporalShift, 0, 0.2),  # 75.30
    # (SpatioShift, 0, 0.1),  # 78.43
    (GaussianBlur, 0, 1.),  # 79.83
    # (drop, 0, 0.5),  # 74.00
    (SaltAndPepperNoise, 0, 0.3),  # 79.64
    # cutmix_normal_aug: 90.02 86.52
]
class RandAugment:
    """Apply n randomly chosen ops from ``augment_list``, each at a
    magnitude interpolated from m on a fixed 0-30 scale."""

    def __init__(self, n, m):
        self.n = n
        self.m = m  # magnitude on the fixed [0, 30] scale
        self.augment_list = augment_list

    def __call__(self, x):
        chosen = random.choices(self.augment_list, k=self.n)
        for fn, lo, hi in chosen:
            magnitude = (float(self.m) / 30) * float(hi - lo) + lo
            x = fn(x, magnitude)
        return x
================================================
FILE: braincog/datasets/scripts/testlist01.txt
================================================
ApplyEyeMakeup/v_ApplyEyeMakeup_g01_c01.avi
ApplyEyeMakeup/v_ApplyEyeMakeup_g01_c02.avi
ApplyEyeMakeup/v_ApplyEyeMakeup_g01_c03.avi
ApplyEyeMakeup/v_ApplyEyeMakeup_g01_c04.avi
ApplyEyeMakeup/v_ApplyEyeMakeup_g01_c05.avi
ApplyEyeMakeup/v_ApplyEyeMakeup_g01_c06.avi
ApplyEyeMakeup/v_ApplyEyeMakeup_g02_c01.avi
ApplyEyeMakeup/v_ApplyEyeMakeup_g02_c02.avi
ApplyEyeMakeup/v_ApplyEyeMakeup_g02_c03.avi
ApplyEyeMakeup/v_ApplyEyeMakeup_g02_c04.avi
ApplyEyeMakeup/v_ApplyEyeMakeup_g03_c01.avi
ApplyEyeMakeup/v_ApplyEyeMakeup_g03_c02.avi
ApplyEyeMakeup/v_ApplyEyeMakeup_g03_c03.avi
ApplyEyeMakeup/v_ApplyEyeMakeup_g03_c04.avi
ApplyEyeMakeup/v_ApplyEyeMakeup_g03_c05.avi
ApplyEyeMakeup/v_ApplyEyeMakeup_g03_c06.avi
ApplyEyeMakeup/v_ApplyEyeMakeup_g04_c01.avi
ApplyEyeMakeup/v_ApplyEyeMakeup_g04_c02.avi
ApplyEyeMakeup/v_ApplyEyeMakeup_g04_c03.avi
ApplyEyeMakeup/v_ApplyEyeMakeup_g04_c04.avi
ApplyEyeMakeup/v_ApplyEyeMakeup_g04_c05.avi
ApplyEyeMakeup/v_ApplyEyeMakeup_g04_c06.avi
ApplyEyeMakeup/v_ApplyEyeMakeup_g04_c07.avi
ApplyEyeMakeup/v_ApplyEyeMakeup_g05_c01.avi
ApplyEyeMakeup/v_ApplyEyeMakeup_g05_c02.avi
ApplyEyeMakeup/v_ApplyEyeMakeup_g05_c03.avi
ApplyEyeMakeup/v_ApplyEyeMakeup_g05_c04.avi
ApplyEyeMakeup/v_ApplyEyeMakeup_g05_c05.avi
ApplyEyeMakeup/v_ApplyEyeMakeup_g05_c06.avi
ApplyEyeMakeup/v_ApplyEyeMakeup_g05_c07.avi
ApplyEyeMakeup/v_ApplyEyeMakeup_g06_c01.avi
ApplyEyeMakeup/v_ApplyEyeMakeup_g06_c02.avi
ApplyEyeMakeup/v_ApplyEyeMakeup_g06_c03.avi
ApplyEyeMakeup/v_ApplyEyeMakeup_g06_c04.avi
ApplyEyeMakeup/v_ApplyEyeMakeup_g06_c05.avi
ApplyEyeMakeup/v_ApplyEyeMakeup_g06_c06.avi
ApplyEyeMakeup/v_ApplyEyeMakeup_g06_c07.avi
ApplyEyeMakeup/v_ApplyEyeMakeup_g07_c01.avi
ApplyEyeMakeup/v_ApplyEyeMakeup_g07_c02.avi
ApplyEyeMakeup/v_ApplyEyeMakeup_g07_c03.avi
ApplyEyeMakeup/v_ApplyEyeMakeup_g07_c04.avi
ApplyEyeMakeup/v_ApplyEyeMakeup_g07_c05.avi
ApplyEyeMakeup/v_ApplyEyeMakeup_g07_c06.avi
ApplyEyeMakeup/v_ApplyEyeMakeup_g07_c07.avi
ApplyLipstick/v_ApplyLipstick_g01_c01.avi
ApplyLipstick/v_ApplyLipstick_g01_c02.avi
ApplyLipstick/v_ApplyLipstick_g01_c03.avi
ApplyLipstick/v_ApplyLipstick_g01_c04.avi
ApplyLipstick/v_ApplyLipstick_g01_c05.avi
ApplyLipstick/v_ApplyLipstick_g02_c01.avi
ApplyLipstick/v_ApplyLipstick_g02_c02.avi
ApplyLipstick/v_ApplyLipstick_g02_c03.avi
ApplyLipstick/v_ApplyLipstick_g02_c04.avi
ApplyLipstick/v_ApplyLipstick_g03_c01.avi
ApplyLipstick/v_ApplyLipstick_g03_c02.avi
ApplyLipstick/v_ApplyLipstick_g03_c03.avi
ApplyLipstick/v_ApplyLipstick_g03_c04.avi
ApplyLipstick/v_ApplyLipstick_g04_c01.avi
ApplyLipstick/v_ApplyLipstick_g04_c02.avi
ApplyLipstick/v_ApplyLipstick_g04_c03.avi
ApplyLipstick/v_ApplyLipstick_g04_c04.avi
ApplyLipstick/v_ApplyLipstick_g04_c05.avi
ApplyLipstick/v_ApplyLipstick_g05_c01.avi
ApplyLipstick/v_ApplyLipstick_g05_c02.avi
ApplyLipstick/v_ApplyLipstick_g05_c03.avi
ApplyLipstick/v_ApplyLipstick_g05_c04.avi
ApplyLipstick/v_ApplyLipstick_g05_c05.avi
ApplyLipstick/v_ApplyLipstick_g06_c01.avi
ApplyLipstick/v_ApplyLipstick_g06_c02.avi
ApplyLipstick/v_ApplyLipstick_g06_c03.avi
ApplyLipstick/v_ApplyLipstick_g06_c04.avi
ApplyLipstick/v_ApplyLipstick_g06_c05.avi
ApplyLipstick/v_ApplyLipstick_g07_c01.avi
ApplyLipstick/v_ApplyLipstick_g07_c02.avi
ApplyLipstick/v_ApplyLipstick_g07_c03.avi
ApplyLipstick/v_ApplyLipstick_g07_c04.avi
Archery/v_Archery_g01_c01.avi
Archery/v_Archery_g01_c02.avi
Archery/v_Archery_g01_c03.avi
Archery/v_Archery_g01_c04.avi
Archery/v_Archery_g01_c05.avi
Archery/v_Archery_g01_c06.avi
Archery/v_Archery_g01_c07.avi
Archery/v_Archery_g02_c01.avi
Archery/v_Archery_g02_c02.avi
Archery/v_Archery_g02_c03.avi
Archery/v_Archery_g02_c04.avi
Archery/v_Archery_g02_c05.avi
Archery/v_Archery_g02_c06.avi
Archery/v_Archery_g02_c07.avi
Archery/v_Archery_g03_c01.avi
Archery/v_Archery_g03_c02.avi
Archery/v_Archery_g03_c03.avi
Archery/v_Archery_g03_c04.avi
Archery/v_Archery_g03_c05.avi
Archery/v_Archery_g04_c01.avi
Archery/v_Archery_g04_c02.avi
Archery/v_Archery_g04_c03.avi
Archery/v_Archery_g04_c04.avi
Archery/v_Archery_g04_c05.avi
Archery/v_Archery_g05_c01.avi
Archery/v_Archery_g05_c02.avi
Archery/v_Archery_g05_c03.avi
Archery/v_Archery_g05_c04.avi
Archery/v_Archery_g05_c05.avi
Archery/v_Archery_g06_c01.avi
Archery/v_Archery_g06_c02.avi
Archery/v_Archery_g06_c03.avi
Archery/v_Archery_g06_c04.avi
Archery/v_Archery_g06_c05.avi
Archery/v_Archery_g06_c06.avi
Archery/v_Archery_g07_c01.avi
Archery/v_Archery_g07_c02.avi
Archery/v_Archery_g07_c03.avi
Archery/v_Archery_g07_c04.avi
Archery/v_Archery_g07_c05.avi
Archery/v_Archery_g07_c06.avi
BabyCrawling/v_BabyCrawling_g01_c01.avi
BabyCrawling/v_BabyCrawling_g01_c02.avi
BabyCrawling/v_BabyCrawling_g01_c03.avi
BabyCrawling/v_BabyCrawling_g01_c04.avi
BabyCrawling/v_BabyCrawling_g02_c01.avi
BabyCrawling/v_BabyCrawling_g02_c02.avi
BabyCrawling/v_BabyCrawling_g02_c03.avi
BabyCrawling/v_BabyCrawling_g02_c04.avi
BabyCrawling/v_BabyCrawling_g02_c05.avi
BabyCrawling/v_BabyCrawling_g02_c06.avi
BabyCrawling/v_BabyCrawling_g03_c01.avi
BabyCrawling/v_BabyCrawling_g03_c02.avi
BabyCrawling/v_BabyCrawling_g03_c03.avi
BabyCrawling/v_BabyCrawling_g03_c04.avi
BabyCrawling/v_BabyCrawling_g04_c01.avi
BabyCrawling/v_BabyCrawling_g04_c02.avi
BabyCrawling/v_BabyCrawling_g04_c03.avi
BabyCrawling/v_BabyCrawling_g04_c04.avi
BabyCrawling/v_BabyCrawling_g05_c01.avi
BabyCrawling/v_BabyCrawling_g05_c02.avi
BabyCrawling/v_BabyCrawling_g05_c03.avi
BabyCrawling/v_BabyCrawling_g05_c04.avi
BabyCrawling/v_BabyCrawling_g05_c05.avi
BabyCrawling/v_BabyCrawling_g06_c01.avi
BabyCrawling/v_BabyCrawling_g06_c02.avi
BabyCrawling/v_BabyCrawling_g06_c03.avi
BabyCrawling/v_BabyCrawling_g06_c04.avi
BabyCrawling/v_BabyCrawling_g06_c05.avi
BabyCrawling/v_BabyCrawling_g06_c06.avi
BabyCrawling/v_BabyCrawling_g07_c01.avi
BabyCrawling/v_BabyCrawling_g07_c02.avi
BabyCrawling/v_BabyCrawling_g07_c03.avi
BabyCrawling/v_BabyCrawling_g07_c04.avi
BabyCrawling/v_BabyCrawling_g07_c05.avi
BabyCrawling/v_BabyCrawling_g07_c06.avi
BalanceBeam/v_BalanceBeam_g01_c01.avi
BalanceBeam/v_BalanceBeam_g01_c02.avi
BalanceBeam/v_BalanceBeam_g01_c03.avi
BalanceBeam/v_BalanceBeam_g01_c04.avi
BalanceBeam/v_BalanceBeam_g02_c01.avi
BalanceBeam/v_BalanceBeam_g02_c02.avi
BalanceBeam/v_BalanceBeam_g02_c03.avi
BalanceBeam/v_BalanceBeam_g02_c04.avi
BalanceBeam/v_BalanceBeam_g03_c01.avi
BalanceBeam/v_BalanceBeam_g03_c02.avi
BalanceBeam/v_BalanceBeam_g03_c03.avi
BalanceBeam/v_BalanceBeam_g03_c04.avi
BalanceBeam/v_BalanceBeam_g04_c01.avi
BalanceBeam/v_BalanceBeam_g04_c02.avi
BalanceBeam/v_BalanceBeam_g04_c03.avi
BalanceBeam/v_BalanceBeam_g04_c04.avi
BalanceBeam/v_BalanceBeam_g05_c01.avi
BalanceBeam/v_BalanceBeam_g05_c02.avi
BalanceBeam/v_BalanceBeam_g05_c03.avi
BalanceBeam/v_BalanceBeam_g05_c04.avi
BalanceBeam/v_BalanceBeam_g06_c01.avi
BalanceBeam/v_BalanceBeam_g06_c02.avi
BalanceBeam/v_BalanceBeam_g06_c03.avi
BalanceBeam/v_BalanceBeam_g06_c04.avi
BalanceBeam/v_BalanceBeam_g06_c05.avi
BalanceBeam/v_BalanceBeam_g06_c06.avi
BalanceBeam/v_BalanceBeam_g06_c07.avi
BalanceBeam/v_BalanceBeam_g07_c01.avi
BalanceBeam/v_BalanceBeam_g07_c02.avi
BalanceBeam/v_BalanceBeam_g07_c03.avi
BalanceBeam/v_BalanceBeam_g07_c04.avi
BandMarching/v_BandMarching_g01_c01.avi
BandMarching/v_BandMarching_g01_c02.avi
BandMarching/v_BandMarching_g01_c03.avi
BandMarching/v_BandMarching_g01_c04.avi
BandMarching/v_BandMarching_g01_c05.avi
BandMarching/v_BandMarching_g01_c06.avi
BandMarching/v_BandMarching_g01_c07.avi
BandMarching/v_BandMarching_g02_c01.avi
BandMarching/v_BandMarching_g02_c02.avi
BandMarching/v_BandMarching_g02_c03.avi
BandMarching/v_BandMarching_g02_c04.avi
BandMarching/v_BandMarching_g02_c05.avi
BandMarching/v_BandMarching_g02_c06.avi
BandMarching/v_BandMarching_g02_c07.avi
BandMarching/v_BandMarching_g03_c01.avi
BandMarching/v_BandMarching_g03_c02.avi
BandMarching/v_BandMarching_g03_c03.avi
BandMarching/v_BandMarching_g03_c04.avi
BandMarching/v_BandMarching_g03_c05.avi
BandMarching/v_BandMarching_g03_c06.avi
BandMarching/v_BandMarching_g03_c07.avi
BandMarching/v_BandMarching_g04_c01.avi
BandMarching/v_BandMarching_g04_c02.avi
BandMarching/v_BandMarching_g04_c03.avi
BandMarching/v_BandMarching_g04_c04.avi
BandMarching/v_BandMarching_g05_c01.avi
BandMarching/v_BandMarching_g05_c02.avi
BandMarching/v_BandMarching_g05_c03.avi
BandMarching/v_BandMarching_g05_c04.avi
BandMarching/v_BandMarching_g05_c05.avi
BandMarching/v_BandMarching_g05_c06.avi
BandMarching/v_BandMarching_g05_c07.avi
BandMarching/v_BandMarching_g06_c01.avi
BandMarching/v_BandMarching_g06_c02.avi
BandMarching/v_BandMarching_g06_c03.avi
BandMarching/v_BandMarching_g06_c04.avi
BandMarching/v_BandMarching_g07_c01.avi
BandMarching/v_BandMarching_g07_c02.avi
BandMarching/v_BandMarching_g07_c03.avi
BandMarching/v_BandMarching_g07_c04.avi
BandMarching/v_BandMarching_g07_c05.avi
BandMarching/v_BandMarching_g07_c06.avi
BandMarching/v_BandMarching_g07_c07.avi
BaseballPitch/v_BaseballPitch_g01_c01.avi
BaseballPitch/v_BaseballPitch_g01_c02.avi
BaseballPitch/v_BaseballPitch_g01_c03.avi
BaseballPitch/v_BaseballPitch_g01_c04.avi
BaseballPitch/v_BaseballPitch_g01_c05.avi
BaseballPitch/v_BaseballPitch_g01_c06.avi
BaseballPitch/v_BaseballPitch_g02_c01.avi
BaseballPitch/v_BaseballPitch_g02_c02.avi
BaseballPitch/v_BaseballPitch_g02_c03.avi
BaseballPitch/v_BaseballPitch_g02_c04.avi
BaseballPitch/v_BaseballPitch_g03_c01.avi
BaseballPitch/v_BaseballPitch_g03_c02.avi
BaseballPitch/v_BaseballPitch_g03_c03.avi
BaseballPitch/v_BaseballPitch_g03_c04.avi
BaseballPitch/v_BaseballPitch_g03_c05.avi
BaseballPitch/v_BaseballPitch_g03_c06.avi
BaseballPitch/v_BaseballPitch_g03_c07.avi
BaseballPitch/v_BaseballPitch_g04_c01.avi
BaseballPitch/v_BaseballPitch_g04_c02.avi
BaseballPitch/v_BaseballPitch_g04_c03.avi
BaseballPitch/v_BaseballPitch_g04_c04.avi
BaseballPitch/v_BaseballPitch_g04_c05.avi
BaseballPitch/v_BaseballPitch_g05_c01.avi
BaseballPitch/v_BaseballPitch_g05_c02.avi
BaseballPitch/v_BaseballPitch_g05_c03.avi
BaseballPitch/v_BaseballPitch_g05_c04.avi
BaseballPitch/v_BaseballPitch_g05_c05.avi
BaseballPitch/v_BaseballPitch_g05_c06.avi
BaseballPitch/v_BaseballPitch_g05_c07.avi
BaseballPitch/v_BaseballPitch_g06_c01.avi
BaseballPitch/v_BaseballPitch_g06_c02.avi
BaseballPitch/v_BaseballPitch_g06_c03.avi
BaseballPitch/v_BaseballPitch_g06_c04.avi
BaseballPitch/v_BaseballPitch_g06_c05.avi
BaseballPitch/v_BaseballPitch_g06_c06.avi
BaseballPitch/v_BaseballPitch_g06_c07.avi
BaseballPitch/v_BaseballPitch_g07_c01.avi
BaseballPitch/v_BaseballPitch_g07_c02.avi
BaseballPitch/v_BaseballPitch_g07_c03.avi
BaseballPitch/v_BaseballPitch_g07_c04.avi
BaseballPitch/v_BaseballPitch_g07_c05.avi
BaseballPitch/v_BaseballPitch_g07_c06.avi
BaseballPitch/v_BaseballPitch_g07_c07.avi
Basketball/v_Basketball_g01_c01.avi
Basketball/v_Basketball_g01_c02.avi
Basketball/v_Basketball_g01_c03.avi
Basketball/v_Basketball_g01_c04.avi
Basketball/v_Basketball_g01_c05.avi
Basketball/v_Basketball_g01_c06.avi
Basketball/v_Basketball_g01_c07.avi
Basketball/v_Basketball_g02_c01.avi
Basketball/v_Basketball_g02_c02.avi
Basketball/v_Basketball_g02_c03.avi
Basketball/v_Basketball_g02_c04.avi
Basketball/v_Basketball_g02_c05.avi
Basketball/v_Basketball_g02_c06.avi
Basketball/v_Basketball_g03_c01.avi
Basketball/v_Basketball_g03_c02.avi
Basketball/v_Basketball_g03_c03.avi
Basketball/v_Basketball_g03_c04.avi
Basketball/v_Basketball_g03_c05.avi
Basketball/v_Basketball_g03_c06.avi
Basketball/v_Basketball_g04_c01.avi
Basketball/v_Basketball_g04_c02.avi
Basketball/v_Basketball_g04_c03.avi
Basketball/v_Basketball_g04_c04.avi
Basketball/v_Basketball_g05_c01.avi
Basketball/v_Basketball_g05_c02.avi
Basketball/v_Basketball_g05_c03.avi
Basketball/v_Basketball_g05_c04.avi
Basketball/v_Basketball_g06_c01.avi
Basketball/v_Basketball_g06_c02.avi
Basketball/v_Basketball_g06_c03.avi
Basketball/v_Basketball_g06_c04.avi
Basketball/v_Basketball_g07_c01.avi
Basketball/v_Basketball_g07_c02.avi
Basketball/v_Basketball_g07_c03.avi
Basketball/v_Basketball_g07_c04.avi
BasketballDunk/v_BasketballDunk_g01_c01.avi
BasketballDunk/v_BasketballDunk_g01_c02.avi
BasketballDunk/v_BasketballDunk_g01_c03.avi
BasketballDunk/v_BasketballDunk_g01_c04.avi
BasketballDunk/v_BasketballDunk_g01_c05.avi
BasketballDunk/v_BasketballDunk_g01_c06.avi
BasketballDunk/v_BasketballDunk_g01_c07.avi
BasketballDunk/v_BasketballDunk_g02_c01.avi
BasketballDunk/v_BasketballDunk_g02_c02.avi
BasketballDunk/v_BasketballDunk_g02_c03.avi
BasketballDunk/v_BasketballDunk_g02_c04.avi
BasketballDunk/v_BasketballDunk_g03_c01.avi
BasketballDunk/v_BasketballDunk_g03_c02.avi
BasketballDunk/v_BasketballDunk_g03_c03.avi
BasketballDunk/v_BasketballDunk_g03_c04.avi
BasketballDunk/v_BasketballDunk_g03_c05.avi
BasketballDunk/v_BasketballDunk_g03_c06.avi
BasketballDunk/v_BasketballDunk_g04_c01.avi
BasketballDunk/v_BasketballDunk_g04_c02.avi
BasketballDunk/v_BasketballDunk_g04_c03.avi
BasketballDunk/v_BasketballDunk_g04_c04.avi
BasketballDunk/v_BasketballDunk_g05_c01.avi
BasketballDunk/v_BasketballDunk_g05_c02.avi
BasketballDunk/v_BasketballDunk_g05_c03.avi
BasketballDunk/v_BasketballDunk_g05_c04.avi
BasketballDunk/v_BasketballDunk_g05_c05.avi
BasketballDunk/v_BasketballDunk_g05_c06.avi
BasketballDunk/v_BasketballDunk_g06_c01.avi
BasketballDunk/v_BasketballDunk_g06_c02.avi
BasketballDunk/v_BasketballDunk_g06_c03.avi
BasketballDunk/v_BasketballDunk_g06_c04.avi
BasketballDunk/v_BasketballDunk_g07_c01.avi
BasketballDunk/v_BasketballDunk_g07_c02.avi
BasketballDunk/v_BasketballDunk_g07_c03.avi
BasketballDunk/v_BasketballDunk_g07_c04.avi
BasketballDunk/v_BasketballDunk_g07_c05.avi
BasketballDunk/v_BasketballDunk_g07_c06.avi
BenchPress/v_BenchPress_g01_c01.avi
BenchPress/v_BenchPress_g01_c02.avi
BenchPress/v_BenchPress_g01_c03.avi
BenchPress/v_BenchPress_g01_c04.avi
BenchPress/v_BenchPress_g01_c05.avi
BenchPress/v_BenchPress_g01_c06.avi
BenchPress/v_BenchPress_g02_c01.avi
BenchPress/v_BenchPress_g02_c02.avi
BenchPress/v_BenchPress_g02_c03.avi
BenchPress/v_BenchPress_g02_c04.avi
BenchPress/v_BenchPress_g02_c05.avi
BenchPress/v_BenchPress_g02_c06.avi
BenchPress/v_BenchPress_g02_c07.avi
BenchPress/v_BenchPress_g03_c01.avi
BenchPress/v_BenchPress_g03_c02.avi
BenchPress/v_BenchPress_g03_c03.avi
BenchPress/v_BenchPress_g03_c04.avi
BenchPress/v_BenchPress_g03_c05.avi
BenchPress/v_BenchPress_g03_c06.avi
BenchPress/v_BenchPress_g03_c07.avi
BenchPress/v_BenchPress_g04_c01.avi
BenchPress/v_BenchPress_g04_c02.avi
BenchPress/v_BenchPress_g04_c03.avi
BenchPress/v_BenchPress_g04_c04.avi
BenchPress/v_BenchPress_g04_c05.avi
BenchPress/v_BenchPress_g04_c06.avi
BenchPress/v_BenchPress_g04_c07.avi
BenchPress/v_BenchPress_g05_c01.avi
BenchPress/v_BenchPress_g05_c02.avi
BenchPress/v_BenchPress_g05_c03.avi
BenchPress/v_BenchPress_g05_c04.avi
BenchPress/v_BenchPress_g05_c05.avi
BenchPress/v_BenchPress_g05_c06.avi
BenchPress/v_BenchPress_g05_c07.avi
BenchPress/v_BenchPress_g06_c01.avi
BenchPress/v_BenchPress_g06_c02.avi
BenchPress/v_BenchPress_g06_c03.avi
BenchPress/v_BenchPress_g06_c04.avi
BenchPress/v_BenchPress_g06_c05.avi
BenchPress/v_BenchPress_g06_c06.avi
BenchPress/v_BenchPress_g06_c07.avi
BenchPress/v_BenchPress_g07_c01.avi
BenchPress/v_BenchPress_g07_c02.avi
BenchPress/v_BenchPress_g07_c03.avi
BenchPress/v_BenchPress_g07_c04.avi
BenchPress/v_BenchPress_g07_c05.avi
BenchPress/v_BenchPress_g07_c06.avi
BenchPress/v_BenchPress_g07_c07.avi
Biking/v_Biking_g01_c01.avi
Biking/v_Biking_g01_c02.avi
Biking/v_Biking_g01_c03.avi
Biking/v_Biking_g01_c04.avi
Biking/v_Biking_g02_c01.avi
Biking/v_Biking_g02_c02.avi
Biking/v_Biking_g02_c03.avi
Biking/v_Biking_g02_c04.avi
Biking/v_Biking_g02_c05.avi
Biking/v_Biking_g02_c06.avi
Biking/v_Biking_g02_c07.avi
Biking/v_Biking_g03_c01.avi
Biking/v_Biking_g03_c02.avi
Biking/v_Biking_g03_c03.avi
Biking/v_Biking_g03_c04.avi
Biking/v_Biking_g04_c01.avi
Biking/v_Biking_g04_c02.avi
Biking/v_Biking_g04_c03.avi
Biking/v_Biking_g04_c04.avi
Biking/v_Biking_g04_c05.avi
Biking/v_Biking_g05_c01.avi
Biking/v_Biking_g05_c02.avi
Biking/v_Biking_g05_c03.avi
Biking/v_Biking_g05_c04.avi
Biking/v_Biking_g05_c05.avi
Biking/v_Biking_g05_c06.avi
Biking/v_Biking_g05_c07.avi
Biking/v_Biking_g06_c01.avi
Biking/v_Biking_g06_c02.avi
Biking/v_Biking_g06_c03.avi
Biking/v_Biking_g06_c04.avi
Biking/v_Biking_g06_c05.avi
Biking/v_Biking_g07_c01.avi
Biking/v_Biking_g07_c02.avi
Biking/v_Biking_g07_c03.avi
Biking/v_Biking_g07_c04.avi
Biking/v_Biking_g07_c05.avi
Biking/v_Biking_g07_c06.avi
Billiards/v_Billiards_g01_c01.avi
Billiards/v_Billiards_g01_c02.avi
Billiards/v_Billiards_g01_c03.avi
Billiards/v_Billiards_g01_c04.avi
Billiards/v_Billiards_g01_c05.avi
Billiards/v_Billiards_g01_c06.avi
Billiards/v_Billiards_g02_c01.avi
Billiards/v_Billiards_g02_c02.avi
Billiards/v_Billiards_g02_c03.avi
Billiards/v_Billiards_g02_c04.avi
Billiards/v_Billiards_g02_c05.avi
Billiards/v_Billiards_g02_c06.avi
Billiards/v_Billiards_g02_c07.avi
Billiards/v_Billiards_g03_c01.avi
Billiards/v_Billiards_g03_c02.avi
Billiards/v_Billiards_g03_c03.avi
Billiards/v_Billiards_g03_c04.avi
Billiards/v_Billiards_g03_c05.avi
Billiards/v_Billiards_g04_c01.avi
Billiards/v_Billiards_g04_c02.avi
Billiards/v_Billiards_g04_c03.avi
Billiards/v_Billiards_g04_c04.avi
Billiards/v_Billiards_g04_c05.avi
Billiards/v_Billiards_g04_c06.avi
Billiards/v_Billiards_g04_c07.avi
Billiards/v_Billiards_g05_c01.avi
Billiards/v_Billiards_g05_c02.avi
Billiards/v_Billiards_g05_c03.avi
Billiards/v_Billiards_g05_c04.avi
Billiards/v_Billiards_g05_c05.avi
Billiards/v_Billiards_g05_c06.avi
Billiards/v_Billiards_g06_c01.avi
Billiards/v_Billiards_g06_c02.avi
Billiards/v_Billiards_g06_c03.avi
Billiards/v_Billiards_g06_c04.avi
Billiards/v_Billiards_g06_c05.avi
Billiards/v_Billiards_g07_c01.avi
Billiards/v_Billiards_g07_c02.avi
Billiards/v_Billiards_g07_c03.avi
Billiards/v_Billiards_g07_c04.avi
BlowDryHair/v_BlowDryHair_g01_c01.avi
BlowDryHair/v_BlowDryHair_g01_c02.avi
BlowDryHair/v_BlowDryHair_g01_c03.avi
BlowDryHair/v_BlowDryHair_g01_c04.avi
BlowDryHair/v_BlowDryHair_g02_c01.avi
BlowDryHair/v_BlowDryHair_g02_c02.avi
BlowDryHair/v_BlowDryHair_g02_c03.avi
BlowDryHair/v_BlowDryHair_g02_c04.avi
BlowDryHair/v_BlowDryHair_g02_c05.avi
BlowDryHair/v_BlowDryHair_g03_c01.avi
BlowDryHair/v_BlowDryHair_g03_c02.avi
BlowDryHair/v_BlowDryHair_g03_c03.avi
BlowDryHair/v_BlowDryHair_g03_c04.avi
BlowDryHair/v_BlowDryHair_g03_c05.avi
BlowDryHair/v_BlowDryHair_g04_c01.avi
BlowDryHair/v_BlowDryHair_g04_c02.avi
BlowDryHair/v_BlowDryHair_g04_c03.avi
BlowDryHair/v_BlowDryHair_g04_c04.avi
BlowDryHair/v_BlowDryHair_g04_c05.avi
BlowDryHair/v_BlowDryHair_g05_c01.avi
BlowDryHair/v_BlowDryHair_g05_c02.avi
BlowDryHair/v_BlowDryHair_g05_c03.avi
BlowDryHair/v_BlowDryHair_g05_c04.avi
BlowDryHair/v_BlowDryHair_g05_c05.avi
BlowDryHair/v_BlowDryHair_g06_c01.avi
BlowDryHair/v_BlowDryHair_g06_c02.avi
BlowDryHair/v_BlowDryHair_g06_c03.avi
BlowDryHair/v_BlowDryHair_g06_c04.avi
BlowDryHair/v_BlowDryHair_g06_c05.avi
BlowDryHair/v_BlowDryHair_g06_c06.avi
BlowDryHair/v_BlowDryHair_g06_c07.avi
BlowDryHair/v_BlowDryHair_g07_c01.avi
BlowDryHair/v_BlowDryHair_g07_c02.avi
BlowDryHair/v_BlowDryHair_g07_c03.avi
BlowDryHair/v_BlowDryHair_g07_c04.avi
BlowDryHair/v_BlowDryHair_g07_c05.avi
BlowDryHair/v_BlowDryHair_g07_c06.avi
BlowDryHair/v_BlowDryHair_g07_c07.avi
BlowingCandles/v_BlowingCandles_g01_c01.avi
BlowingCandles/v_BlowingCandles_g01_c02.avi
BlowingCandles/v_BlowingCandles_g01_c03.avi
BlowingCandles/v_BlowingCandles_g01_c04.avi
BlowingCandles/v_BlowingCandles_g02_c01.avi
BlowingCandles/v_BlowingCandles_g02_c02.avi
BlowingCandles/v_BlowingCandles_g02_c03.avi
BlowingCandles/v_BlowingCandles_g02_c04.avi
BlowingCandles/v_BlowingCandles_g03_c01.avi
BlowingCandles/v_BlowingCandles_g03_c02.avi
BlowingCandles/v_BlowingCandles_g03_c03.avi
BlowingCandles/v_BlowingCandles_g03_c04.avi
BlowingCandles/v_BlowingCandles_g04_c01.avi
BlowingCandles/v_BlowingCandles_g04_c02.avi
BlowingCandles/v_BlowingCandles_g04_c03.avi
BlowingCandles/v_BlowingCandles_g04_c04.avi
BlowingCandles/v_BlowingCandles_g04_c05.avi
BlowingCandles/v_BlowingCandles_g05_c01.avi
BlowingCandles/v_BlowingCandles_g05_c02.avi
BlowingCandles/v_BlowingCandles_g05_c03.avi
BlowingCandles/v_BlowingCandles_g05_c04.avi
BlowingCandles/v_BlowingCandles_g05_c05.avi
BlowingCandles/v_BlowingCandles_g06_c01.avi
BlowingCandles/v_BlowingCandles_g06_c02.avi
BlowingCandles/v_BlowingCandles_g06_c03.avi
BlowingCandles/v_BlowingCandles_g06_c04.avi
BlowingCandles/v_BlowingCandles_g06_c05.avi
BlowingCandles/v_BlowingCandles_g06_c06.avi
BlowingCandles/v_BlowingCandles_g06_c07.avi
BlowingCandles/v_BlowingCandles_g07_c01.avi
BlowingCandles/v_BlowingCandles_g07_c02.avi
BlowingCandles/v_BlowingCandles_g07_c03.avi
BlowingCandles/v_BlowingCandles_g07_c04.avi
BodyWeightSquats/v_BodyWeightSquats_g01_c01.avi
BodyWeightSquats/v_BodyWeightSquats_g01_c02.avi
BodyWeightSquats/v_BodyWeightSquats_g01_c03.avi
BodyWeightSquats/v_BodyWeightSquats_g01_c04.avi
BodyWeightSquats/v_BodyWeightSquats_g02_c01.avi
BodyWeightSquats/v_BodyWeightSquats_g02_c02.avi
BodyWeightSquats/v_BodyWeightSquats_g02_c03.avi
BodyWeightSquats/v_BodyWeightSquats_g02_c04.avi
BodyWeightSquats/v_BodyWeightSquats_g03_c01.avi
BodyWeightSquats/v_BodyWeightSquats_g03_c02.avi
BodyWeightSquats/v_BodyWeightSquats_g03_c03.avi
BodyWeightSquats/v_BodyWeightSquats_g03_c04.avi
BodyWeightSquats/v_BodyWeightSquats_g03_c05.avi
BodyWeightSquats/v_BodyWeightSquats_g04_c01.avi
BodyWeightSquats/v_BodyWeightSquats_g04_c02.avi
BodyWeightSquats/v_BodyWeightSquats_g04_c03.avi
BodyWeightSquats/v_BodyWeightSquats_g04_c04.avi
BodyWeightSquats/v_BodyWeightSquats_g05_c01.avi
BodyWeightSquats/v_BodyWeightSquats_g05_c02.avi
BodyWeightSquats/v_BodyWeightSquats_g05_c03.avi
BodyWeightSquats/v_BodyWeightSquats_g05_c04.avi
BodyWeightSquats/v_BodyWeightSquats_g06_c01.avi
BodyWeightSquats/v_BodyWeightSquats_g06_c02.avi
BodyWeightSquats/v_BodyWeightSquats_g06_c03.avi
BodyWeightSquats/v_BodyWeightSquats_g06_c04.avi
BodyWeightSquats/v_BodyWeightSquats_g06_c05.avi
BodyWeightSquats/v_BodyWeightSquats_g07_c01.avi
BodyWeightSquats/v_BodyWeightSquats_g07_c02.avi
BodyWeightSquats/v_BodyWeightSquats_g07_c03.avi
BodyWeightSquats/v_BodyWeightSquats_g07_c04.avi
Bowling/v_Bowling_g01_c01.avi
Bowling/v_Bowling_g01_c02.avi
Bowling/v_Bowling_g01_c03.avi
Bowling/v_Bowling_g01_c04.avi
Bowling/v_Bowling_g01_c05.avi
Bowling/v_Bowling_g01_c06.avi
Bowling/v_Bowling_g01_c07.avi
Bowling/v_Bowling_g02_c01.avi
Bowling/v_Bowling_g02_c02.avi
Bowling/v_Bowling_g02_c03.avi
Bowling/v_Bowling_g02_c04.avi
Bowling/v_Bowling_g03_c01.avi
Bowling/v_Bowling_g03_c02.avi
Bowling/v_Bowling_g03_c03.avi
Bowling/v_Bowling_g03_c04.avi
Bowling/v_Bowling_g03_c05.avi
Bowling/v_Bowling_g03_c06.avi
Bowling/v_Bowling_g03_c07.avi
Bowling/v_Bowling_g04_c01.avi
Bowling/v_Bowling_g04_c02.avi
Bowling/v_Bowling_g04_c03.avi
Bowling/v_Bowling_g04_c04.avi
Bowling/v_Bowling_g05_c01.avi
Bowling/v_Bowling_g05_c02.avi
Bowling/v_Bowling_g05_c03.avi
Bowling/v_Bowling_g05_c04.avi
Bowling/v_Bowling_g05_c05.avi
Bowling/v_Bowling_g05_c06.avi
Bowling/v_Bowling_g05_c07.avi
Bowling/v_Bowling_g06_c01.avi
Bowling/v_Bowling_g06_c02.avi
Bowling/v_Bowling_g06_c03.avi
Bowling/v_Bowling_g06_c04.avi
Bowling/v_Bowling_g06_c05.avi
Bowling/v_Bowling_g06_c06.avi
Bowling/v_Bowling_g06_c07.avi
Bowling/v_Bowling_g07_c01.avi
Bowling/v_Bowling_g07_c02.avi
Bowling/v_Bowling_g07_c03.avi
Bowling/v_Bowling_g07_c04.avi
Bowling/v_Bowling_g07_c05.avi
Bowling/v_Bowling_g07_c06.avi
Bowling/v_Bowling_g07_c07.avi
BoxingPunchingBag/v_BoxingPunchingBag_g01_c01.avi
BoxingPunchingBag/v_BoxingPunchingBag_g01_c02.avi
BoxingPunchingBag/v_BoxingPunchingBag_g01_c03.avi
BoxingPunchingBag/v_BoxingPunchingBag_g01_c04.avi
BoxingPunchingBag/v_BoxingPunchingBag_g01_c05.avi
BoxingPunchingBag/v_BoxingPunchingBag_g01_c06.avi
BoxingPunchingBag/v_BoxingPunchingBag_g01_c07.avi
BoxingPunchingBag/v_BoxingPunchingBag_g02_c01.avi
BoxingPunchingBag/v_BoxingPunchingBag_g02_c02.avi
BoxingPunchingBag/v_BoxingPunchingBag_g02_c03.avi
BoxingPunchingBag/v_BoxingPunchingBag_g02_c04.avi
BoxingPunchingBag/v_BoxingPunchingBag_g02_c05.avi
BoxingPunchingBag/v_BoxingPunchingBag_g02_c06.avi
BoxingPunchingBag/v_BoxingPunchingBag_g02_c07.avi
BoxingPunchingBag/v_BoxingPunchingBag_g03_c01.avi
BoxingPunchingBag/v_BoxingPunchingBag_g03_c02.avi
BoxingPunchingBag/v_BoxingPunchingBag_g03_c03.avi
BoxingPunchingBag/v_BoxingPunchingBag_g03_c04.avi
BoxingPunchingBag/v_BoxingPunchingBag_g03_c05.avi
BoxingPunchingBag/v_BoxingPunchingBag_g03_c06.avi
BoxingPunchingBag/v_BoxingPunchingBag_g03_c07.avi
BoxingPunchingBag/v_BoxingPunchingBag_g04_c01.avi
BoxingPunchingBag/v_BoxingPunchingBag_g04_c02.avi
BoxingPunchingBag/v_BoxingPunchingBag_g04_c03.avi
BoxingPunchingBag/v_BoxingPunchingBag_g04_c04.avi
BoxingPunchingBag/v_BoxingPunchingBag_g04_c05.avi
BoxingPunchingBag/v_BoxingPunchingBag_g04_c06.avi
BoxingPunchingBag/v_BoxingPunchingBag_g04_c07.avi
BoxingPunchingBag/v_BoxingPunchingBag_g05_c01.avi
BoxingPunchingBag/v_BoxingPunchingBag_g05_c02.avi
BoxingPunchingBag/v_BoxingPunchingBag_g05_c03.avi
BoxingPunchingBag/v_BoxingPunchingBag_g05_c04.avi
BoxingPunchingBag/v_BoxingPunchingBag_g05_c05.avi
BoxingPunchingBag/v_BoxingPunchingBag_g05_c06.avi
BoxingPunchingBag/v_BoxingPunchingBag_g05_c07.avi
BoxingPunchingBag/v_BoxingPunchingBag_g06_c01.avi
BoxingPunchingBag/v_BoxingPunchingBag_g06_c02.avi
BoxingPunchingBag/v_BoxingPunchingBag_g06_c03.avi
BoxingPunchingBag/v_BoxingPunchingBag_g06_c04.avi
BoxingPunchingBag/v_BoxingPunchingBag_g06_c05.avi
BoxingPunchingBag/v_BoxingPunchingBag_g06_c06.avi
BoxingPunchingBag/v_BoxingPunchingBag_g06_c07.avi
BoxingPunchingBag/v_BoxingPunchingBag_g07_c01.avi
BoxingPunchingBag/v_BoxingPunchingBag_g07_c02.avi
BoxingPunchingBag/v_BoxingPunchingBag_g07_c03.avi
BoxingPunchingBag/v_BoxingPunchingBag_g07_c04.avi
BoxingPunchingBag/v_BoxingPunchingBag_g07_c05.avi
BoxingPunchingBag/v_BoxingPunchingBag_g07_c06.avi
BoxingPunchingBag/v_BoxingPunchingBag_g07_c07.avi
BoxingSpeedBag/v_BoxingSpeedBag_g01_c01.avi
BoxingSpeedBag/v_BoxingSpeedBag_g01_c02.avi
BoxingSpeedBag/v_BoxingSpeedBag_g01_c03.avi
BoxingSpeedBag/v_BoxingSpeedBag_g01_c04.avi
BoxingSpeedBag/v_BoxingSpeedBag_g02_c01.avi
BoxingSpeedBag/v_BoxingSpeedBag_g02_c02.avi
BoxingSpeedBag/v_BoxingSpeedBag_g02_c03.avi
BoxingSpeedBag/v_BoxingSpeedBag_g02_c04.avi
BoxingSpeedBag/v_BoxingSpeedBag_g03_c01.avi
BoxingSpeedBag/v_BoxingSpeedBag_g03_c02.avi
BoxingSpeedBag/v_BoxingSpeedBag_g03_c03.avi
BoxingSpeedBag/v_BoxingSpeedBag_g03_c04.avi
BoxingSpeedBag/v_BoxingSpeedBag_g03_c05.avi
BoxingSpeedBag/v_BoxingSpeedBag_g04_c01.avi
BoxingSpeedBag/v_BoxingSpeedBag_g04_c02.avi
BoxingSpeedBag/v_BoxingSpeedBag_g04_c03.avi
BoxingSpeedBag/v_BoxingSpeedBag_g04_c04.avi
BoxingSpeedBag/v_BoxingSpeedBag_g04_c05.avi
BoxingSpeedBag/v_BoxingSpeedBag_g04_c06.avi
BoxingSpeedBag/v_BoxingSpeedBag_g04_c07.avi
BoxingSpeedBag/v_BoxingSpeedBag_g05_c01.avi
BoxingSpeedBag/v_BoxingSpeedBag_g05_c02.avi
BoxingSpeedBag/v_BoxingSpeedBag_g05_c03.avi
BoxingSpeedBag/v_BoxingSpeedBag_g05_c04.avi
BoxingSpeedBag/v_BoxingSpeedBag_g05_c05.avi
BoxingSpeedBag/v_BoxingSpeedBag_g06_c01.avi
BoxingSpeedBag/v_BoxingSpeedBag_g06_c02.avi
BoxingSpeedBag/v_BoxingSpeedBag_g06_c03.avi
BoxingSpeedBag/v_BoxingSpeedBag_g06_c04.avi
BoxingSpeedBag/v_BoxingSpeedBag_g06_c05.avi
BoxingSpeedBag/v_BoxingSpeedBag_g07_c01.avi
BoxingSpeedBag/v_BoxingSpeedBag_g07_c02.avi
BoxingSpeedBag/v_BoxingSpeedBag_g07_c03.avi
BoxingSpeedBag/v_BoxingSpeedBag_g07_c04.avi
BoxingSpeedBag/v_BoxingSpeedBag_g07_c05.avi
BoxingSpeedBag/v_BoxingSpeedBag_g07_c06.avi
BoxingSpeedBag/v_BoxingSpeedBag_g07_c07.avi
BreastStroke/v_BreastStroke_g01_c01.avi
BreastStroke/v_BreastStroke_g01_c02.avi
BreastStroke/v_BreastStroke_g01_c03.avi
BreastStroke/v_BreastStroke_g01_c04.avi
BreastStroke/v_BreastStroke_g02_c01.avi
BreastStroke/v_BreastStroke_g02_c02.avi
BreastStroke/v_BreastStroke_g02_c03.avi
BreastStroke/v_BreastStroke_g02_c04.avi
BreastStroke/v_BreastStroke_g03_c01.avi
BreastStroke/v_BreastStroke_g03_c02.avi
BreastStroke/v_BreastStroke_g03_c03.avi
BreastStroke/v_BreastStroke_g03_c04.avi
BreastStroke/v_BreastStroke_g04_c01.avi
BreastStroke/v_BreastStroke_g04_c02.avi
BreastStroke/v_BreastStroke_g04_c03.avi
BreastStroke/v_BreastStroke_g04_c04.avi
BreastStroke/v_BreastStroke_g05_c01.avi
BreastStroke/v_BreastStroke_g05_c02.avi
BreastStroke/v_BreastStroke_g05_c03.avi
BreastStroke/v_BreastStroke_g05_c04.avi
BreastStroke/v_BreastStroke_g06_c01.avi
BreastStroke/v_BreastStroke_g06_c02.avi
BreastStroke/v_BreastStroke_g06_c03.avi
BreastStroke/v_BreastStroke_g06_c04.avi
BreastStroke/v_BreastStroke_g07_c01.avi
BreastStroke/v_BreastStroke_g07_c02.avi
BreastStroke/v_BreastStroke_g07_c03.avi
BreastStroke/v_BreastStroke_g07_c04.avi
BrushingTeeth/v_BrushingTeeth_g01_c01.avi
BrushingTeeth/v_BrushingTeeth_g01_c02.avi
BrushingTeeth/v_BrushingTeeth_g01_c03.avi
BrushingTeeth/v_BrushingTeeth_g01_c04.avi
BrushingTeeth/v_BrushingTeeth_g02_c01.avi
BrushingTeeth/v_BrushingTeeth_g02_c02.avi
BrushingTeeth/v_BrushingTeeth_g02_c03.avi
BrushingTeeth/v_BrushingTeeth_g02_c04.avi
BrushingTeeth/v_BrushingTeeth_g02_c05.avi
BrushingTeeth/v_BrushingTeeth_g02_c06.avi
BrushingTeeth/v_BrushingTeeth_g02_c07.avi
BrushingTeeth/v_BrushingTeeth_g03_c01.avi
BrushingTeeth/v_BrushingTeeth_g03_c02.avi
BrushingTeeth/v_BrushingTeeth_g03_c03.avi
BrushingTeeth/v_BrushingTeeth_g03_c04.avi
BrushingTeeth/v_BrushingTeeth_g03_c05.avi
BrushingTeeth/v_BrushingTeeth_g04_c01.avi
BrushingTeeth/v_BrushingTeeth_g04_c02.avi
BrushingTeeth/v_BrushingTeeth_g04_c03.avi
BrushingTeeth/v_BrushingTeeth_g04_c04.avi
BrushingTeeth/v_BrushingTeeth_g05_c01.avi
BrushingTeeth/v_BrushingTeeth_g05_c02.avi
BrushingTeeth/v_BrushingTeeth_g05_c03.avi
BrushingTeeth/v_BrushingTeeth_g05_c04.avi
BrushingTeeth/v_BrushingTeeth_g05_c05.avi
BrushingTeeth/v_BrushingTeeth_g06_c01.avi
BrushingTeeth/v_BrushingTeeth_g06_c02.avi
BrushingTeeth/v_BrushingTeeth_g06_c03.avi
BrushingTeeth/v_BrushingTeeth_g06_c04.avi
BrushingTeeth/v_BrushingTeeth_g06_c05.avi
BrushingTeeth/v_BrushingTeeth_g07_c01.avi
BrushingTeeth/v_BrushingTeeth_g07_c02.avi
BrushingTeeth/v_BrushingTeeth_g07_c03.avi
BrushingTeeth/v_BrushingTeeth_g07_c04.avi
BrushingTeeth/v_BrushingTeeth_g07_c05.avi
BrushingTeeth/v_BrushingTeeth_g07_c06.avi
CleanAndJerk/v_CleanAndJerk_g01_c01.avi
CleanAndJerk/v_CleanAndJerk_g01_c02.avi
CleanAndJerk/v_CleanAndJerk_g01_c03.avi
CleanAndJerk/v_CleanAndJerk_g01_c04.avi
CleanAndJerk/v_CleanAndJerk_g01_c05.avi
CleanAndJerk/v_CleanAndJerk_g02_c01.avi
CleanAndJerk/v_CleanAndJerk_g02_c02.avi
CleanAndJerk/v_CleanAndJerk_g02_c03.avi
CleanAndJerk/v_CleanAndJerk_g02_c04.avi
CleanAndJerk/v_CleanAndJerk_g03_c01.avi
CleanAndJerk/v_CleanAndJerk_g03_c02.avi
CleanAndJerk/v_CleanAndJerk_g03_c03.avi
CleanAndJerk/v_CleanAndJerk_g03_c04.avi
CleanAndJerk/v_CleanAndJerk_g03_c05.avi
CleanAndJerk/v_CleanAndJerk_g03_c06.avi
CleanAndJerk/v_CleanAndJerk_g04_c01.avi
CleanAndJerk/v_CleanAndJerk_g04_c02.avi
CleanAndJerk/v_CleanAndJerk_g04_c03.avi
CleanAndJerk/v_CleanAndJerk_g04_c04.avi
CleanAndJerk/v_CleanAndJerk_g04_c05.avi
CleanAndJerk/v_CleanAndJerk_g05_c01.avi
CleanAndJerk/v_CleanAndJerk_g05_c02.avi
CleanAndJerk/v_CleanAndJerk_g05_c03.avi
CleanAndJerk/v_CleanAndJerk_g05_c04.avi
CleanAndJerk/v_CleanAndJerk_g06_c01.avi
CleanAndJerk/v_CleanAndJerk_g06_c02.avi
CleanAndJerk/v_CleanAndJerk_g06_c03.avi
CleanAndJerk/v_CleanAndJerk_g06_c04.avi
CleanAndJerk/v_CleanAndJerk_g07_c01.avi
CleanAndJerk/v_CleanAndJerk_g07_c02.avi
CleanAndJerk/v_CleanAndJerk_g07_c03.avi
CleanAndJerk/v_CleanAndJerk_g07_c04.avi
CleanAndJerk/v_CleanAndJerk_g07_c05.avi
CliffDiving/v_CliffDiving_g01_c01.avi
CliffDiving/v_CliffDiving_g01_c02.avi
CliffDiving/v_CliffDiving_g01_c03.avi
CliffDiving/v_CliffDiving_g01_c04.avi
CliffDiving/v_CliffDiving_g01_c05.avi
CliffDiving/v_CliffDiving_g01_c06.avi
CliffDiving/v_CliffDiving_g02_c01.avi
CliffDiving/v_CliffDiving_g02_c02.avi
CliffDiving/v_CliffDiving_g02_c03.avi
CliffDiving/v_CliffDiving_g02_c04.avi
CliffDiving/v_CliffDiving_g03_c01.avi
CliffDiving/v_CliffDiving_g03_c02.avi
CliffDiving/v_CliffDiving_g03_c03.avi
CliffDiving/v_CliffDiving_g03_c04.avi
CliffDiving/v_CliffDiving_g03_c05.avi
CliffDiving/v_CliffDiving_g04_c01.avi
CliffDiving/v_CliffDiving_g04_c02.avi
CliffDiving/v_CliffDiving_g04_c03.avi
CliffDiving/v_CliffDiving_g04_c04.avi
CliffDiving/v_CliffDiving_g05_c01.avi
CliffDiving/v_CliffDiving_g05_c02.avi
CliffDiving/v_CliffDiving_g05_c03.avi
CliffDiving/v_CliffDiving_g05_c04.avi
CliffDiving/v_CliffDiving_g05_c05.avi
CliffDiving/v_CliffDiving_g05_c06.avi
CliffDiving/v_CliffDiving_g05_c07.avi
CliffDiving/v_CliffDiving_g06_c01.avi
CliffDiving/v_CliffDiving_g06_c02.avi
CliffDiving/v_CliffDiving_g06_c03.avi
CliffDiving/v_CliffDiving_g06_c04.avi
CliffDiving/v_CliffDiving_g06_c05.avi
CliffDiving/v_CliffDiving_g06_c06.avi
CliffDiving/v_CliffDiving_g06_c07.avi
CliffDiving/v_CliffDiving_g07_c01.avi
CliffDiving/v_CliffDiving_g07_c02.avi
CliffDiving/v_CliffDiving_g07_c03.avi
CliffDiving/v_CliffDiving_g07_c04.avi
CliffDiving/v_CliffDiving_g07_c05.avi
CliffDiving/v_CliffDiving_g07_c06.avi
CricketBowling/v_CricketBowling_g01_c01.avi
CricketBowling/v_CricketBowling_g01_c02.avi
CricketBowling/v_CricketBowling_g01_c03.avi
CricketBowling/v_CricketBowling_g01_c04.avi
CricketBowling/v_CricketBowling_g01_c05.avi
CricketBowling/v_CricketBowling_g01_c06.avi
CricketBowling/v_CricketBowling_g01_c07.avi
CricketBowling/v_CricketBowling_g02_c01.avi
CricketBowling/v_CricketBowling_g02_c02.avi
CricketBowling/v_CricketBowling_g02_c03.avi
CricketBowling/v_CricketBowling_g02_c04.avi
CricketBowling/v_CricketBowling_g02_c05.avi
CricketBowling/v_CricketBowling_g02_c06.avi
CricketBowling/v_CricketBowling_g02_c07.avi
CricketBowling/v_CricketBowling_g03_c01.avi
CricketBowling/v_CricketBowling_g03_c02.avi
CricketBowling/v_CricketBowling_g03_c03.avi
CricketBowling/v_CricketBowling_g03_c04.avi
CricketBowling/v_CricketBowling_g04_c01.avi
CricketBowling/v_CricketBowling_g04_c02.avi
CricketBowling/v_CricketBowling_g04_c03.avi
CricketBowling/v_CricketBowling_g04_c04.avi
CricketBowling/v_CricketBowling_g04_c05.avi
CricketBowling/v_CricketBowling_g05_c01.avi
CricketBowling/v_CricketBowling_g05_c02.avi
CricketBowling/v_CricketBowling_g05_c03.avi
CricketBowling/v_CricketBowling_g05_c04.avi
CricketBowling/v_CricketBowling_g06_c01.avi
CricketBowling/v_CricketBowling_g06_c02.avi
CricketBowling/v_CricketBowling_g06_c03.avi
CricketBowling/v_CricketBowling_g06_c04.avi
CricketBowling/v_CricketBowling_g06_c05.avi
CricketBowling/v_CricketBowling_g07_c01.avi
CricketBowling/v_CricketBowling_g07_c02.avi
CricketBowling/v_CricketBowling_g07_c03.avi
CricketBowling/v_CricketBowling_g07_c04.avi
CricketShot/v_CricketShot_g01_c01.avi
CricketShot/v_CricketShot_g01_c02.avi
CricketShot/v_CricketShot_g01_c03.avi
CricketShot/v_CricketShot_g01_c04.avi
CricketShot/v_CricketShot_g01_c05.avi
CricketShot/v_CricketShot_g01_c06.avi
CricketShot/v_CricketShot_g01_c07.avi
CricketShot/v_CricketShot_g02_c01.avi
CricketShot/v_CricketShot_g02_c02.avi
CricketShot/v_CricketShot_g02_c03.avi
CricketShot/v_CricketShot_g02_c04.avi
CricketShot/v_CricketShot_g02_c05.avi
CricketShot/v_CricketShot_g02_c06.avi
CricketShot/v_CricketShot_g02_c07.avi
CricketShot/v_CricketShot_g03_c01.avi
CricketShot/v_CricketShot_g03_c02.avi
CricketShot/v_CricketShot_g03_c03.avi
CricketShot/v_CricketShot_g03_c04.avi
CricketShot/v_CricketShot_g03_c05.avi
CricketShot/v_CricketShot_g03_c06.avi
CricketShot/v_CricketShot_g03_c07.avi
CricketShot/v_CricketShot_g04_c01.avi
CricketShot/v_CricketShot_g04_c02.avi
CricketShot/v_CricketShot_g04_c03.avi
CricketShot/v_CricketShot_g04_c04.avi
CricketShot/v_CricketShot_g04_c05.avi
CricketShot/v_CricketShot_g04_c06.avi
CricketShot/v_CricketShot_g04_c07.avi
CricketShot/v_CricketShot_g05_c01.avi
CricketShot/v_CricketShot_g05_c02.avi
CricketShot/v_CricketShot_g05_c03.avi
CricketShot/v_CricketShot_g05_c04.avi
CricketShot/v_CricketShot_g05_c05.avi
CricketShot/v_CricketShot_g05_c06.avi
CricketShot/v_CricketShot_g05_c07.avi
CricketShot/v_CricketShot_g06_c01.avi
CricketShot/v_CricketShot_g06_c02.avi
CricketShot/v_CricketShot_g06_c03.avi
CricketShot/v_CricketShot_g06_c04.avi
CricketShot/v_CricketShot_g06_c05.avi
CricketShot/v_CricketShot_g06_c06.avi
CricketShot/v_CricketShot_g06_c07.avi
CricketShot/v_CricketShot_g07_c01.avi
CricketShot/v_CricketShot_g07_c02.avi
CricketShot/v_CricketShot_g07_c03.avi
CricketShot/v_CricketShot_g07_c04.avi
CricketShot/v_CricketShot_g07_c05.avi
CricketShot/v_CricketShot_g07_c06.avi
CricketShot/v_CricketShot_g07_c07.avi
CuttingInKitchen/v_CuttingInKitchen_g01_c01.avi
CuttingInKitchen/v_CuttingInKitchen_g01_c02.avi
CuttingInKitchen/v_CuttingInKitchen_g01_c03.avi
CuttingInKitchen/v_CuttingInKitchen_g01_c04.avi
CuttingInKitchen/v_CuttingInKitchen_g01_c05.avi
CuttingInKitchen/v_CuttingInKitchen_g02_c01.avi
CuttingInKitchen/v_CuttingInKitchen_g02_c02.avi
CuttingInKitchen/v_CuttingInKitchen_g02_c03.avi
CuttingInKitchen/v_CuttingInKitchen_g02_c04.avi
CuttingInKitchen/v_CuttingInKitchen_g03_c01.avi
CuttingInKitchen/v_CuttingInKitchen_g03_c02.avi
CuttingInKitchen/v_CuttingInKitchen_g03_c03.avi
CuttingInKitchen/v_CuttingInKitchen_g03_c04.avi
CuttingInKitchen/v_CuttingInKitchen_g04_c01.avi
CuttingInKitchen/v_CuttingInKitchen_g04_c02.avi
CuttingInKitchen/v_CuttingInKitchen_g04_c03.avi
CuttingInKitchen/v_CuttingInKitchen_g04_c04.avi
CuttingInKitchen/v_CuttingInKitchen_g04_c05.avi
CuttingInKitchen/v_CuttingInKitchen_g05_c01.avi
CuttingInKitchen/v_CuttingInKitchen_g05_c02.avi
CuttingInKitchen/v_CuttingInKitchen_g05_c03.avi
CuttingInKitchen/v_CuttingInKitchen_g05_c04.avi
CuttingInKitchen/v_CuttingInKitchen_g05_c05.avi
CuttingInKitchen/v_CuttingInKitchen_g05_c06.avi
CuttingInKitchen/v_CuttingInKitchen_g06_c01.avi
CuttingInKitchen/v_CuttingInKitchen_g06_c02.avi
CuttingInKitchen/v_CuttingInKitchen_g06_c03.avi
CuttingInKitchen/v_CuttingInKitchen_g06_c04.avi
CuttingInKitchen/v_CuttingInKitchen_g06_c05.avi
CuttingInKitchen/v_CuttingInKitchen_g07_c01.avi
CuttingInKitchen/v_CuttingInKitchen_g07_c02.avi
CuttingInKitchen/v_CuttingInKitchen_g07_c03.avi
CuttingInKitchen/v_CuttingInKitchen_g07_c04.avi
Diving/v_Diving_g01_c01.avi
Diving/v_Diving_g01_c02.avi
Diving/v_Diving_g01_c03.avi
Diving/v_Diving_g01_c04.avi
Diving/v_Diving_g01_c05.avi
Diving/v_Diving_g01_c06.avi
Diving/v_Diving_g01_c07.avi
Diving/v_Diving_g02_c01.avi
Diving/v_Diving_g02_c02.avi
Diving/v_Diving_g02_c03.avi
Diving/v_Diving_g02_c04.avi
Diving/v_Diving_g02_c05.avi
Diving/v_Diving_g02_c06.avi
Diving/v_Diving_g02_c07.avi
Diving/v_Diving_g03_c01.avi
Diving/v_Diving_g03_c02.avi
Diving/v_Diving_g03_c03.avi
Diving/v_Diving_g03_c04.avi
Diving/v_Diving_g03_c05.avi
Diving/v_Diving_g03_c06.avi
Diving/v_Diving_g03_c07.avi
Diving/v_Diving_g04_c01.avi
Diving/v_Diving_g04_c02.avi
Diving/v_Diving_g04_c03.avi
Diving/v_Diving_g04_c04.avi
Diving/v_Diving_g04_c05.avi
Diving/v_Diving_g04_c06.avi
Diving/v_Diving_g04_c07.avi
Diving/v_Diving_g05_c01.avi
Diving/v_Diving_g05_c02.avi
Diving/v_Diving_g05_c03.avi
Diving/v_Diving_g05_c04.avi
Diving/v_Diving_g05_c05.avi
Diving/v_Diving_g05_c06.avi
Diving/v_Diving_g06_c01.avi
Diving/v_Diving_g06_c02.avi
Diving/v_Diving_g06_c03.avi
Diving/v_Diving_g06_c04.avi
Diving/v_Diving_g06_c05.avi
Diving/v_Diving_g06_c06.avi
Diving/v_Diving_g06_c07.avi
Diving/v_Diving_g07_c01.avi
Diving/v_Diving_g07_c02.avi
Diving/v_Diving_g07_c03.avi
Diving/v_Diving_g07_c04.avi
Drumming/v_Drumming_g01_c01.avi
Drumming/v_Drumming_g01_c02.avi
Drumming/v_Drumming_g01_c03.avi
Drumming/v_Drumming_g01_c04.avi
Drumming/v_Drumming_g01_c05.avi
Drumming/v_Drumming_g01_c06.avi
Drumming/v_Drumming_g01_c07.avi
Drumming/v_Drumming_g02_c01.avi
Drumming/v_Drumming_g02_c02.avi
Drumming/v_Drumming_g02_c03.avi
Drumming/v_Drumming_g02_c04.avi
Drumming/v_Drumming_g02_c05.avi
Drumming/v_Drumming_g02_c06.avi
Drumming/v_Drumming_g02_c07.avi
Drumming/v_Drumming_g03_c01.avi
Drumming/v_Drumming_g03_c02.avi
Drumming/v_Drumming_g03_c03.avi
Drumming/v_Drumming_g03_c04.avi
Drumming/v_Drumming_g03_c05.avi
Drumming/v_Drumming_g04_c01.avi
Drumming/v_Drumming_g04_c02.avi
Drumming/v_Drumming_g04_c03.avi
Drumming/v_Drumming_g04_c04.avi
Drumming/v_Drumming_g04_c05.avi
Drumming/v_Drumming_g04_c06.avi
Drumming/v_Drumming_g04_c07.avi
Drumming/v_Drumming_g05_c01.avi
Drumming/v_Drumming_g05_c02.avi
Drumming/v_Drumming_g05_c03.avi
Drumming/v_Drumming_g05_c04.avi
Drumming/v_Drumming_g05_c05.avi
Drumming/v_Drumming_g05_c06.avi
Drumming/v_Drumming_g06_c01.avi
Drumming/v_Drumming_g06_c02.avi
Drumming/v_Drumming_g06_c03.avi
Drumming/v_Drumming_g06_c04.avi
Drumming/v_Drumming_g06_c05.avi
Drumming/v_Drumming_g06_c06.avi
Drumming/v_Drumming_g07_c01.avi
Drumming/v_Drumming_g07_c02.avi
Drumming/v_Drumming_g07_c03.avi
Drumming/v_Drumming_g07_c04.avi
Drumming/v_Drumming_g07_c05.avi
Drumming/v_Drumming_g07_c06.avi
Drumming/v_Drumming_g07_c07.avi
Fencing/v_Fencing_g01_c01.avi
Fencing/v_Fencing_g01_c02.avi
Fencing/v_Fencing_g01_c03.avi
Fencing/v_Fencing_g01_c04.avi
Fencing/v_Fencing_g01_c05.avi
Fencing/v_Fencing_g01_c06.avi
Fencing/v_Fencing_g02_c01.avi
Fencing/v_Fencing_g02_c02.avi
Fencing/v_Fencing_g02_c03.avi
Fencing/v_Fencing_g02_c04.avi
Fencing/v_Fencing_g02_c05.avi
Fencing/v_Fencing_g03_c01.avi
Fencing/v_Fencing_g03_c02.avi
Fencing/v_Fencing_g03_c03.avi
Fencing/v_Fencing_g03_c04.avi
Fencing/v_Fencing_g03_c05.avi
Fencing/v_Fencing_g04_c01.avi
Fencing/v_Fencing_g04_c02.avi
Fencing/v_Fencing_g04_c03.avi
Fencing/v_Fencing_g04_c04.avi
Fencing/v_Fencing_g04_c05.avi
Fencing/v_Fencing_g05_c01.avi
Fencing/v_Fencing_g05_c02.avi
Fencing/v_Fencing_g05_c03.avi
Fencing/v_Fencing_g05_c04.avi
Fencing/v_Fencing_g05_c05.avi
Fencing/v_Fencing_g06_c01.avi
Fencing/v_Fencing_g06_c02.avi
Fencing/v_Fencing_g06_c03.avi
Fencing/v_Fencing_g06_c04.avi
Fencing/v_Fencing_g07_c01.avi
Fencing/v_Fencing_g07_c02.avi
Fencing/v_Fencing_g07_c03.avi
Fencing/v_Fencing_g07_c04.avi
FieldHockeyPenalty/v_FieldHockeyPenalty_g01_c01.avi
FieldHockeyPenalty/v_FieldHockeyPenalty_g01_c02.avi
FieldHockeyPenalty/v_FieldHockeyPenalty_g01_c03.avi
FieldHockeyPenalty/v_FieldHockeyPenalty_g01_c04.avi
FieldHockeyPenalty/v_FieldHockeyPenalty_g01_c05.avi
FieldHockeyPenalty/v_FieldHockeyPenalty_g02_c01.avi
FieldHockeyPenalty/v_FieldHockeyPenalty_g02_c02.avi
FieldHockeyPenalty/v_FieldHockeyPenalty_g02_c03.avi
FieldHockeyPenalty/v_FieldHockeyPenalty_g02_c04.avi
FieldHockeyPenalty/v_FieldHockeyPenalty_g02_c05.avi
FieldHockeyPenalty/v_FieldHockeyPenalty_g02_c06.avi
FieldHockeyPenalty/v_FieldHockeyPenalty_g03_c01.avi
FieldHockeyPenalty/v_FieldHockeyPenalty_g03_c02.avi
FieldHockeyPenalty/v_FieldHockeyPenalty_g03_c03.avi
FieldHockeyPenalty/v_FieldHockeyPenalty_g03_c04.avi
FieldHockeyPenalty/v_FieldHockeyPenalty_g04_c01.avi
FieldHockeyPenalty/v_FieldHockeyPenalty_g04_c02.avi
FieldHockeyPenalty/v_FieldHockeyPenalty_g04_c03.avi
FieldHockeyPenalty/v_FieldHockeyPenalty_g04_c04.avi
FieldHockeyPenalty/v_FieldHockeyPenalty_g04_c05.avi
FieldHockeyPenalty/v_FieldHockeyPenalty_g04_c06.avi
FieldHockeyPenalty/v_FieldHockeyPenalty_g04_c07.avi
FieldHockeyPenalty/v_FieldHockeyPenalty_g05_c01.avi
FieldHockeyPenalty/v_FieldHockeyPenalty_g05_c02.avi
FieldHockeyPenalty/v_FieldHockeyPenalty_g05_c03.avi
FieldHockeyPenalty/v_FieldHockeyPenalty_g05_c04.avi
FieldHockeyPenalty/v_FieldHockeyPenalty_g05_c05.avi
FieldHockeyPenalty/v_FieldHockeyPenalty_g05_c06.avi
FieldHockeyPenalty/v_FieldHockeyPenalty_g05_c07.avi
FieldHockeyPenalty/v_FieldHockeyPenalty_g06_c01.avi
FieldHockeyPenalty/v_FieldHockeyPenalty_g06_c02.avi
FieldHockeyPenalty/v_FieldHockeyPenalty_g06_c03.avi
FieldHockeyPenalty/v_FieldHockeyPenalty_g06_c04.avi
FieldHockeyPenalty/v_FieldHockeyPenalty_g06_c05.avi
FieldHockeyPenalty/v_FieldHockeyPenalty_g06_c06.avi
FieldHockeyPenalty/v_FieldHockeyPenalty_g06_c07.avi
FieldHockeyPenalty/v_FieldHockeyPenalty_g07_c01.avi
FieldHockeyPenalty/v_FieldHockeyPenalty_g07_c02.avi
FieldHockeyPenalty/v_FieldHockeyPenalty_g07_c03.avi
FieldHockeyPenalty/v_FieldHockeyPenalty_g07_c04.avi
FloorGymnastics/v_FloorGymnastics_g01_c01.avi
FloorGymnastics/v_FloorGymnastics_g01_c02.avi
FloorGymnastics/v_FloorGymnastics_g01_c03.avi
FloorGymnastics/v_FloorGymnastics_g01_c04.avi
FloorGymnastics/v_FloorGymnastics_g01_c05.avi
FloorGymnastics/v_FloorGymnastics_g02_c01.avi
FloorGymnastics/v_FloorGymnastics_g02_c02.avi
FloorGymnastics/v_FloorGymnastics_g02_c03.avi
FloorGymnastics/v_FloorGymnastics_g02_c04.avi
FloorGymnastics/v_FloorGymnastics_g03_c01.avi
FloorGymnastics/v_FloorGymnastics_g03_c02.avi
FloorGymnastics/v_FloorGymnastics_g03_c03.avi
FloorGymnastics/v_FloorGymnastics_g03_c04.avi
FloorGymnastics/v_FloorGymnastics_g04_c01.avi
FloorGymnastics/v_FloorGymnastics_g04_c02.avi
FloorGymnastics/v_FloorGymnastics_g04_c03.avi
FloorGymnastics/v_FloorGymnastics_g04_c04.avi
FloorGymnastics/v_FloorGymnastics_g04_c05.avi
FloorGymnastics/v_FloorGymnastics_g05_c01.avi
FloorGymnastics/v_FloorGymnastics_g05_c02.avi
FloorGymnastics/v_FloorGymnastics_g05_c03.avi
FloorGymnastics/v_FloorGymnastics_g05_c04.avi
FloorGymnastics/v_FloorGymnastics_g06_c01.avi
FloorGymnastics/v_FloorGymnastics_g06_c02.avi
FloorGymnastics/v_FloorGymnastics_g06_c03.avi
FloorGymnastics/v_FloorGymnastics_g06_c04.avi
FloorGymnastics/v_FloorGymnastics_g06_c05.avi
FloorGymnastics/v_FloorGymnastics_g06_c06.avi
FloorGymnastics/v_FloorGymnastics_g06_c07.avi
FloorGymnastics/v_FloorGymnastics_g07_c01.avi
FloorGymnastics/v_FloorGymnastics_g07_c02.avi
FloorGymnastics/v_FloorGymnastics_g07_c03.avi
FloorGymnastics/v_FloorGymnastics_g07_c04.avi
FloorGymnastics/v_FloorGymnastics_g07_c05.avi
FloorGymnastics/v_FloorGymnastics_g07_c06.avi
FloorGymnastics/v_FloorGymnastics_g07_c07.avi
FrisbeeCatch/v_FrisbeeCatch_g01_c01.avi
FrisbeeCatch/v_FrisbeeCatch_g01_c02.avi
FrisbeeCatch/v_FrisbeeCatch_g01_c03.avi
FrisbeeCatch/v_FrisbeeCatch_g01_c04.avi
FrisbeeCatch/v_FrisbeeCatch_g01_c05.avi
FrisbeeCatch/v_FrisbeeCatch_g01_c06.avi
FrisbeeCatch/v_FrisbeeCatch_g02_c01.avi
FrisbeeCatch/v_FrisbeeCatch_g02_c02.avi
FrisbeeCatch/v_FrisbeeCatch_g02_c03.avi
FrisbeeCatch/v_FrisbeeCatch_g02_c04.avi
FrisbeeCatch/v_FrisbeeCatch_g02_c05.avi
FrisbeeCatch/v_FrisbeeCatch_g03_c01.avi
FrisbeeCatch/v_FrisbeeCatch_g03_c02.avi
FrisbeeCatch/v_FrisbeeCatch_g03_c03.avi
FrisbeeCatch/v_FrisbeeCatch_g03_c04.avi
FrisbeeCatch/v_FrisbeeCatch_g03_c05.avi
FrisbeeCatch/v_FrisbeeCatch_g04_c01.avi
FrisbeeCatch/v_FrisbeeCatch_g04_c02.avi
FrisbeeCatch/v_FrisbeeCatch_g04_c03.avi
FrisbeeCatch/v_FrisbeeCatch_g04_c04.avi
FrisbeeCatch/v_FrisbeeCatch_g04_c05.avi
FrisbeeCatch/v_FrisbeeCatch_g05_c01.avi
FrisbeeCatch/v_FrisbeeCatch_g05_c02.avi
FrisbeeCatch/v_FrisbeeCatch_g05_c03.avi
FrisbeeCatch/v_FrisbeeCatch_g05_c04.avi
FrisbeeCatch/v_FrisbeeCatch_g05_c05.avi
FrisbeeCatch/v_FrisbeeCatch_g06_c01.avi
FrisbeeCatch/v_FrisbeeCatch_g06_c02.avi
FrisbeeCatch/v_FrisbeeCatch_g06_c03.avi
FrisbeeCatch/v_FrisbeeCatch_g06_c04.avi
FrisbeeCatch/v_FrisbeeCatch_g06_c05.avi
FrisbeeCatch/v_FrisbeeCatch_g07_c01.avi
FrisbeeCatch/v_FrisbeeCatch_g07_c02.avi
FrisbeeCatch/v_FrisbeeCatch_g07_c03.avi
FrisbeeCatch/v_FrisbeeCatch_g07_c04.avi
FrisbeeCatch/v_FrisbeeCatch_g07_c05.avi
FrisbeeCatch/v_FrisbeeCatch_g07_c06.avi
FrontCrawl/v_FrontCrawl_g01_c01.avi
FrontCrawl/v_FrontCrawl_g01_c02.avi
FrontCrawl/v_FrontCrawl_g01_c03.avi
FrontCrawl/v_FrontCrawl_g01_c04.avi
FrontCrawl/v_FrontCrawl_g02_c01.avi
FrontCrawl/v_FrontCrawl_g02_c02.avi
FrontCrawl/v_FrontCrawl_g02_c03.avi
FrontCrawl/v_FrontCrawl_g02_c04.avi
FrontCrawl/v_FrontCrawl_g03_c01.avi
FrontCrawl/v_FrontCrawl_g03_c02.avi
FrontCrawl/v_FrontCrawl_g03_c03.avi
FrontCrawl/v_FrontCrawl_g03_c04.avi
FrontCrawl/v_FrontCrawl_g03_c05.avi
FrontCrawl/v_FrontCrawl_g03_c06.avi
FrontCrawl/v_FrontCrawl_g04_c01.avi
FrontCrawl/v_FrontCrawl_g04_c02.avi
FrontCrawl/v_FrontCrawl_g04_c03.avi
FrontCrawl/v_FrontCrawl_g04_c04.avi
FrontCrawl/v_FrontCrawl_g04_c05.avi
FrontCrawl/v_FrontCrawl_g04_c06.avi
FrontCrawl/v_FrontCrawl_g04_c07.avi
FrontCrawl/v_FrontCrawl_g05_c01.avi
FrontCrawl/v_FrontCrawl_g05_c02.avi
FrontCrawl/v_FrontCrawl_g05_c03.avi
FrontCrawl/v_FrontCrawl_g05_c04.avi
FrontCrawl/v_FrontCrawl_g06_c01.avi
FrontCrawl/v_FrontCrawl_g06_c02.avi
FrontCrawl/v_FrontCrawl_g06_c03.avi
FrontCrawl/v_FrontCrawl_g06_c04.avi
FrontCrawl/v_FrontCrawl_g06_c05.avi
FrontCrawl/v_FrontCrawl_g07_c01.avi
FrontCrawl/v_FrontCrawl_g07_c02.avi
FrontCrawl/v_FrontCrawl_g07_c03.avi
FrontCrawl/v_FrontCrawl_g07_c04.avi
FrontCrawl/v_FrontCrawl_g07_c05.avi
FrontCrawl/v_FrontCrawl_g07_c06.avi
FrontCrawl/v_FrontCrawl_g07_c07.avi
GolfSwing/v_GolfSwing_g01_c01.avi
GolfSwing/v_GolfSwing_g01_c02.avi
GolfSwing/v_GolfSwing_g01_c03.avi
GolfSwing/v_GolfSwing_g01_c04.avi
GolfSwing/v_GolfSwing_g01_c05.avi
GolfSwing/v_GolfSwing_g01_c06.avi
GolfSwing/v_GolfSwing_g02_c01.avi
GolfSwing/v_GolfSwing_g02_c02.avi
GolfSwing/v_GolfSwing_g02_c03.avi
GolfSwing/v_GolfSwing_g02_c04.avi
GolfSwing/v_GolfSwing_g03_c01.avi
GolfSwing/v_GolfSwing_g03_c02.avi
GolfSwing/v_GolfSwing_g03_c03.avi
GolfSwing/v_GolfSwing_g03_c04.avi
GolfSwing/v_GolfSwing_g03_c05.avi
GolfSwing/v_GolfSwing_g03_c06.avi
GolfSwing/v_GolfSwing_g03_c07.avi
GolfSwing/v_GolfSwing_g04_c01.avi
GolfSwing/v_GolfSwing_g04_c02.avi
GolfSwing/v_GolfSwing_g04_c03.avi
GolfSwing/v_GolfSwing_g04_c04.avi
GolfSwing/v_GolfSwing_g04_c05.avi
GolfSwing/v_GolfSwing_g04_c06.avi
GolfSwing/v_GolfSwing_g05_c01.avi
GolfSwing/v_GolfSwing_g05_c02.avi
GolfSwing/v_GolfSwing_g05_c03.avi
GolfSwing/v_GolfSwing_g05_c04.avi
GolfSwing/v_GolfSwing_g05_c05.avi
GolfSwing/v_GolfSwing_g05_c06.avi
GolfSwing/v_GolfSwing_g05_c07.avi
GolfSwing/v_GolfSwing_g06_c01.avi
GolfSwing/v_GolfSwing_g06_c02.avi
GolfSwing/v_GolfSwing_g06_c03.avi
GolfSwing/v_GolfSwing_g06_c04.avi
GolfSwing/v_GolfSwing_g07_c01.avi
GolfSwing/v_GolfSwing_g07_c02.avi
GolfSwing/v_GolfSwing_g07_c03.avi
GolfSwing/v_GolfSwing_g07_c04.avi
GolfSwing/v_GolfSwing_g07_c05.avi
Haircut/v_Haircut_g01_c01.avi
Haircut/v_Haircut_g01_c02.avi
Haircut/v_Haircut_g01_c03.avi
Haircut/v_Haircut_g01_c04.avi
Haircut/v_Haircut_g02_c01.avi
Haircut/v_Haircut_g02_c02.avi
Haircut/v_Haircut_g02_c03.avi
Haircut/v_Haircut_g02_c04.avi
Haircut/v_Haircut_g03_c01.avi
Haircut/v_Haircut_g03_c02.avi
Haircut/v_Haircut_g03_c03.avi
Haircut/v_Haircut_g03_c04.avi
Haircut/v_Haircut_g03_c05.avi
Haircut/v_Haircut_g03_c06.avi
Haircut/v_Haircut_g04_c01.avi
Haircut/v_Haircut_g04_c02.avi
Haircut/v_Haircut_g04_c03.avi
Haircut/v_Haircut_g04_c04.avi
Haircut/v_Haircut_g04_c05.avi
Haircut/v_Haircut_g05_c01.avi
Haircut/v_Haircut_g05_c02.avi
Haircut/v_Haircut_g05_c03.avi
Haircut/v_Haircut_g05_c04.avi
Haircut/v_Haircut_g06_c01.avi
Haircut/v_Haircut_g06_c02.avi
Haircut/v_Haircut_g06_c03.avi
Haircut/v_Haircut_g06_c04.avi
Haircut/v_Haircut_g07_c01.avi
Haircut/v_Haircut_g07_c02.avi
Haircut/v_Haircut_g07_c03.avi
Haircut/v_Haircut_g07_c04.avi
Haircut/v_Haircut_g07_c05.avi
Haircut/v_Haircut_g07_c06.avi
Hammering/v_Hammering_g01_c01.avi
Hammering/v_Hammering_g01_c02.avi
Hammering/v_Hammering_g01_c03.avi
Hammering/v_Hammering_g01_c04.avi
Hammering/v_Hammering_g02_c01.avi
Hammering/v_Hammering_g02_c02.avi
Hammering/v_Hammering_g02_c03.avi
Hammering/v_Hammering_g02_c04.avi
Hammering/v_Hammering_g03_c01.avi
Hammering/v_Hammering_g03_c02.avi
Hammering/v_Hammering_g03_c03.avi
Hammering/v_Hammering_g03_c04.avi
Hammering/v_Hammering_g03_c05.avi
Hammering/v_Hammering_g04_c01.avi
Hammering/v_Hammering_g04_c02.avi
Hammering/v_Hammering_g04_c03.avi
Hammering/v_Hammering_g04_c04.avi
Hammering/v_Hammering_g04_c05.avi
Hammering/v_Hammering_g05_c01.avi
Hammering/v_Hammering_g05_c02.avi
Hammering/v_Hammering_g05_c03.avi
Hammering/v_Hammering_g05_c04.avi
Hammering/v_Hammering_g06_c01.avi
Hammering/v_Hammering_g06_c02.avi
Hammering/v_Hammering_g06_c03.avi
Hammering/v_Hammering_g06_c04.avi
Hammering/v_Hammering_g06_c05.avi
Hammering/v_Hammering_g06_c06.avi
Hammering/v_Hammering_g07_c01.avi
Hammering/v_Hammering_g07_c02.avi
Hammering/v_Hammering_g07_c03.avi
Hammering/v_Hammering_g07_c04.avi
Hammering/v_Hammering_g07_c05.avi
HammerThrow/v_HammerThrow_g01_c01.avi
HammerThrow/v_HammerThrow_g01_c02.avi
HammerThrow/v_HammerThrow_g01_c03.avi
HammerThrow/v_HammerThrow_g01_c04.avi
HammerThrow/v_HammerThrow_g01_c05.avi
HammerThrow/v_HammerThrow_g01_c06.avi
HammerThrow/v_HammerThrow_g02_c01.avi
HammerThrow/v_HammerThrow_g02_c02.avi
HammerThrow/v_HammerThrow_g02_c03.avi
HammerThrow/v_HammerThrow_g02_c04.avi
HammerThrow/v_HammerThrow_g02_c05.avi
HammerThrow/v_HammerThrow_g02_c06.avi
HammerThrow/v_HammerThrow_g02_c07.avi
HammerThrow/v_HammerThrow_g03_c01.avi
HammerThrow/v_HammerThrow_g03_c02.avi
HammerThrow/v_HammerThrow_g03_c03.avi
HammerThrow/v_HammerThrow_g03_c04.avi
HammerThrow/v_HammerThrow_g03_c05.avi
HammerThrow/v_HammerThrow_g03_c06.avi
HammerThrow/v_HammerThrow_g03_c07.avi
HammerThrow/v_HammerThrow_g04_c01.avi
HammerThrow/v_HammerThrow_g04_c02.avi
HammerThrow/v_HammerThrow_g04_c03.avi
HammerThrow/v_HammerThrow_g04_c04.avi
HammerThrow/v_HammerThrow_g04_c05.avi
HammerThrow/v_HammerThrow_g04_c06.avi
HammerThrow/v_HammerThrow_g04_c07.avi
HammerThrow/v_HammerThrow_g05_c01.avi
HammerThrow/v_HammerThrow_g05_c02.avi
HammerThrow/v_HammerThrow_g05_c03.avi
HammerThrow/v_HammerThrow_g05_c04.avi
HammerThrow/v_HammerThrow_g05_c05.avi
HammerThrow/v_HammerThrow_g05_c06.avi
HammerThrow/v_HammerThrow_g06_c01.avi
HammerThrow/v_HammerThrow_g06_c02.avi
HammerThrow/v_HammerThrow_g06_c03.avi
HammerThrow/v_HammerThrow_g06_c04.avi
HammerThrow/v_HammerThrow_g06_c05.avi
HammerThrow/v_HammerThrow_g06_c06.avi
HammerThrow/v_HammerThrow_g06_c07.avi
HammerThrow/v_HammerThrow_g07_c01.avi
HammerThrow/v_HammerThrow_g07_c02.avi
HammerThrow/v_HammerThrow_g07_c03.avi
HammerThrow/v_HammerThrow_g07_c04.avi
HammerThrow/v_HammerThrow_g07_c05.avi
HandstandPushups/v_HandStandPushups_g01_c01.avi
HandstandPushups/v_HandStandPushups_g01_c02.avi
HandstandPushups/v_HandStandPushups_g01_c03.avi
HandstandPushups/v_HandStandPushups_g01_c04.avi
HandstandPushups/v_HandStandPushups_g02_c01.avi
HandstandPushups/v_HandStandPushups_g02_c02.avi
HandstandPushups/v_HandStandPushups_g02_c03.avi
HandstandPushups/v_HandStandPushups_g02_c04.avi
HandstandPushups/v_HandStandPushups_g03_c01.avi
HandstandPushups/v_HandStandPushups_g03_c02.avi
HandstandPushups/v_HandStandPushups_g03_c03.avi
HandstandPushups/v_HandStandPushups_g03_c04.avi
HandstandPushups/v_HandStandPushups_g04_c01.avi
HandstandPushups/v_HandStandPushups_g04_c02.avi
HandstandPushups/v_HandStandPushups_g04_c03.avi
HandstandPushups/v_HandStandPushups_g04_c04.avi
HandstandPushups/v_HandStandPushups_g05_c01.avi
HandstandPushups/v_HandStandPushups_g05_c02.avi
HandstandPushups/v_HandStandPushups_g05_c03.avi
HandstandPushups/v_HandStandPushups_g05_c04.avi
HandstandPushups/v_HandStandPushups_g06_c01.avi
HandstandPushups/v_HandStandPushups_g06_c02.avi
HandstandPushups/v_HandStandPushups_g06_c03.avi
HandstandPushups/v_HandStandPushups_g06_c04.avi
HandstandPushups/v_HandStandPushups_g07_c01.avi
HandstandPushups/v_HandStandPushups_g07_c02.avi
HandstandPushups/v_HandStandPushups_g07_c03.avi
HandstandPushups/v_HandStandPushups_g07_c04.avi
HandstandWalking/v_HandstandWalking_g01_c01.avi
HandstandWalking/v_HandstandWalking_g01_c02.avi
HandstandWalking/v_HandstandWalking_g01_c03.avi
HandstandWalking/v_HandstandWalking_g01_c04.avi
HandstandWalking/v_HandstandWalking_g02_c01.avi
HandstandWalking/v_HandstandWalking_g02_c02.avi
HandstandWalking/v_HandstandWalking_g02_c03.avi
HandstandWalking/v_HandstandWalking_g02_c04.avi
HandstandWalking/v_HandstandWalking_g03_c01.avi
HandstandWalking/v_HandstandWalking_g03_c02.avi
HandstandWalking/v_HandstandWalking_g03_c03.avi
HandstandWalking/v_HandstandWalking_g03_c04.avi
HandstandWalking/v_HandstandWalking_g04_c01.avi
HandstandWalking/v_HandstandWalking_g04_c02.avi
HandstandWalking/v_HandstandWalking_g04_c03.avi
HandstandWalking/v_HandstandWalking_g04_c04.avi
HandstandWalking/v_HandstandWalking_g04_c05.avi
HandstandWalking/v_HandstandWalking_g05_c01.avi
HandstandWalking/v_HandstandWalking_g05_c02.avi
HandstandWalking/v_HandstandWalking_g05_c03.avi
HandstandWalking/v_HandstandWalking_g05_c04.avi
HandstandWalking/v_HandstandWalking_g05_c05.avi
HandstandWalking/v_HandstandWalking_g05_c06.avi
HandstandWalking/v_HandstandWalking_g05_c07.avi
HandstandWalking/v_HandstandWalking_g06_c01.avi
HandstandWalking/v_HandstandWalking_g06_c02.avi
HandstandWalking/v_HandstandWalking_g06_c03.avi
HandstandWalking/v_HandstandWalking_g06_c04.avi
HandstandWalking/v_HandstandWalking_g07_c01.avi
HandstandWalking/v_HandstandWalking_g07_c02.avi
HandstandWalking/v_HandstandWalking_g07_c03.avi
HandstandWalking/v_HandstandWalking_g07_c04.avi
HandstandWalking/v_HandstandWalking_g07_c05.avi
HandstandWalking/v_HandstandWalking_g07_c06.avi
HeadMassage/v_HeadMassage_g01_c01.avi
HeadMassage/v_HeadMassage_g01_c02.avi
HeadMassage/v_HeadMassage_g01_c03.avi
HeadMassage/v_HeadMassage_g01_c04.avi
HeadMassage/v_HeadMassage_g01_c05.avi
HeadMassage/v_HeadMassage_g02_c01.avi
HeadMassage/v_HeadMassage_g02_c02.avi
HeadMassage/v_HeadMassage_g02_c03.avi
HeadMassage/v_HeadMassage_g02_c04.avi
HeadMassage/v_HeadMassage_g02_c05.avi
HeadMassage/v_HeadMassage_g02_c06.avi
HeadMassage/v_HeadMassage_g02_c07.avi
HeadMassage/v_HeadMassage_g03_c01.avi
HeadMassage/v_HeadMassage_g03_c02.avi
HeadMassage/v_HeadMassage_g03_c03.avi
HeadMassage/v_HeadMassage_g03_c04.avi
HeadMassage/v_HeadMassage_g03_c05.avi
HeadMassage/v_HeadMassage_g03_c06.avi
HeadMassage/v_HeadMassage_g03_c07.avi
HeadMassage/v_HeadMassage_g04_c01.avi
HeadMassage/v_HeadMassage_g04_c02.avi
HeadMassage/v_HeadMassage_g04_c03.avi
HeadMassage/v_HeadMassage_g04_c04.avi
HeadMassage/v_HeadMassage_g05_c01.avi
HeadMassage/v_HeadMassage_g05_c02.avi
HeadMassage/v_HeadMassage_g05_c03.avi
HeadMassage/v_HeadMassage_g05_c04.avi
HeadMassage/v_HeadMassage_g05_c05.avi
HeadMassage/v_HeadMassage_g05_c06.avi
HeadMassage/v_HeadMassage_g06_c01.avi
HeadMassage/v_HeadMassage_g06_c02.avi
HeadMassage/v_HeadMassage_g06_c03.avi
HeadMassage/v_HeadMassage_g06_c04.avi
HeadMassage/v_HeadMassage_g06_c05.avi
HeadMassage/v_HeadMassage_g06_c06.avi
HeadMassage/v_HeadMassage_g06_c07.avi
HeadMassage/v_HeadMassage_g07_c01.avi
HeadMassage/v_HeadMassage_g07_c02.avi
HeadMassage/v_HeadMassage_g07_c03.avi
HeadMassage/v_HeadMassage_g07_c04.avi
HeadMassage/v_HeadMassage_g07_c05.avi
HighJump/v_HighJump_g01_c01.avi
HighJump/v_HighJump_g01_c02.avi
HighJump/v_HighJump_g01_c03.avi
HighJump/v_HighJump_g01_c04.avi
HighJump/v_HighJump_g01_c05.avi
HighJump/v_HighJump_g02_c01.avi
HighJump/v_HighJump_g02_c02.avi
HighJump/v_HighJump_g02_c03.avi
HighJump/v_HighJump_g02_c04.avi
HighJump/v_HighJump_g02_c05.avi
HighJump/v_HighJump_g02_c06.avi
HighJump/v_HighJump_g02_c07.avi
HighJump/v_HighJump_g03_c01.avi
HighJump/v_HighJump_g03_c02.avi
HighJump/v_HighJump_g03_c03.avi
HighJump/v_HighJump_g03_c04.avi
HighJump/v_HighJump_g04_c01.avi
HighJump/v_HighJump_g04_c02.avi
HighJump/v_HighJump_g04_c03.avi
HighJump/v_HighJump_g04_c04.avi
HighJump/v_HighJump_g04_c05.avi
HighJump/v_HighJump_g04_c06.avi
HighJump/v_HighJump_g05_c01.avi
HighJump/v_HighJump_g05_c02.avi
HighJump/v_HighJump_g05_c03.avi
HighJump/v_HighJump_g05_c04.avi
HighJump/v_HighJump_g05_c05.avi
HighJump/v_HighJump_g06_c01.avi
HighJump/v_HighJump_g06_c02.avi
HighJump/v_HighJump_g06_c03.avi
HighJump/v_HighJump_g06_c04.avi
HighJump/v_HighJump_g07_c01.avi
HighJump/v_HighJump_g07_c02.avi
HighJump/v_HighJump_g07_c03.avi
HighJump/v_HighJump_g07_c04.avi
HighJump/v_HighJump_g07_c05.avi
HighJump/v_HighJump_g07_c06.avi
HorseRace/v_HorseRace_g01_c01.avi
HorseRace/v_HorseRace_g01_c02.avi
HorseRace/v_HorseRace_g01_c03.avi
HorseRace/v_HorseRace_g01_c04.avi
HorseRace/v_HorseRace_g02_c01.avi
HorseRace/v_HorseRace_g02_c02.avi
HorseRace/v_HorseRace_g02_c03.avi
HorseRace/v_HorseRace_g02_c04.avi
HorseRace/v_HorseRace_g03_c01.avi
HorseRace/v_HorseRace_g03_c02.avi
HorseRace/v_HorseRace_g03_c03.avi
HorseRace/v_HorseRace_g03_c04.avi
HorseRace/v_HorseRace_g03_c05.avi
HorseRace/v_HorseRace_g04_c01.avi
HorseRace/v_HorseRace_g04_c02.avi
HorseRace/v_HorseRace_g04_c03.avi
HorseRace/v_HorseRace_g04_c04.avi
HorseRace/v_HorseRace_g04_c05.avi
HorseRace/v_HorseRace_g04_c06.avi
HorseRace/v_HorseRace_g05_c01.avi
HorseRace/v_HorseRace_g05_c02.avi
HorseRace/v_HorseRace_g05_c03.avi
HorseRace/v_HorseRace_g05_c04.avi
HorseRace/v_HorseRace_g06_c01.avi
HorseRace/v_HorseRace_g06_c02.avi
HorseRace/v_HorseRace_g06_c03.avi
HorseRace/v_HorseRace_g06_c04.avi
HorseRace/v_HorseRace_g06_c05.avi
HorseRace/v_HorseRace_g06_c06.avi
HorseRace/v_HorseRace_g07_c01.avi
HorseRace/v_HorseRace_g07_c02.avi
HorseRace/v_HorseRace_g07_c03.avi
HorseRace/v_HorseRace_g07_c04.avi
HorseRace/v_HorseRace_g07_c05.avi
HorseRace/v_HorseRace_g07_c06.avi
HorseRiding/v_HorseRiding_g01_c01.avi
HorseRiding/v_HorseRiding_g01_c02.avi
HorseRiding/v_HorseRiding_g01_c03.avi
HorseRiding/v_HorseRiding_g01_c04.avi
HorseRiding/v_HorseRiding_g01_c05.avi
HorseRiding/v_HorseRiding_g01_c06.avi
HorseRiding/v_HorseRiding_g01_c07.avi
HorseRiding/v_HorseRiding_g02_c01.avi
HorseRiding/v_HorseRiding_g02_c02.avi
HorseRiding/v_HorseRiding_g02_c03.avi
HorseRiding/v_HorseRiding_g02_c04.avi
HorseRiding/v_HorseRiding_g02_c05.avi
HorseRiding/v_HorseRiding_g02_c06.avi
HorseRiding/v_HorseRiding_g02_c07.avi
HorseRiding/v_HorseRiding_g03_c01.avi
HorseRiding/v_HorseRiding_g03_c02.avi
HorseRiding/v_HorseRiding_g03_c03.avi
HorseRiding/v_HorseRiding_g03_c04.avi
HorseRiding/v_HorseRiding_g03_c05.avi
HorseRiding/v_HorseRiding_g03_c06.avi
HorseRiding/v_HorseRiding_g03_c07.avi
HorseRiding/v_HorseRiding_g04_c01.avi
HorseRiding/v_HorseRiding_g04_c02.avi
HorseRiding/v_HorseRiding_g04_c03.avi
HorseRiding/v_HorseRiding_g04_c04.avi
HorseRiding/v_HorseRiding_g04_c05.avi
HorseRiding/v_HorseRiding_g04_c06.avi
HorseRiding/v_HorseRiding_g04_c07.avi
HorseRiding/v_HorseRiding_g05_c01.avi
HorseRiding/v_HorseRiding_g05_c02.avi
HorseRiding/v_HorseRiding_g05_c03.avi
HorseRiding/v_HorseRiding_g05_c04.avi
HorseRiding/v_HorseRiding_g05_c05.avi
HorseRiding/v_HorseRiding_g05_c06.avi
HorseRiding/v_HorseRiding_g05_c07.avi
HorseRiding/v_HorseRiding_g06_c01.avi
HorseRiding/v_HorseRiding_g06_c02.avi
HorseRiding/v_HorseRiding_g06_c03.avi
HorseRiding/v_HorseRiding_g06_c04.avi
HorseRiding/v_HorseRiding_g06_c05.avi
HorseRiding/v_HorseRiding_g06_c06.avi
HorseRiding/v_HorseRiding_g06_c07.avi
HorseRiding/v_HorseRiding_g07_c01.avi
HorseRiding/v_HorseRiding_g07_c02.avi
HorseRiding/v_HorseRiding_g07_c03.avi
HorseRiding/v_HorseRiding_g07_c04.avi
HorseRiding/v_HorseRiding_g07_c05.avi
HorseRiding/v_HorseRiding_g07_c06.avi
HorseRiding/v_HorseRiding_g07_c07.avi
HulaHoop/v_HulaHoop_g01_c01.avi
HulaHoop/v_HulaHoop_g01_c02.avi
HulaHoop/v_HulaHoop_g01_c03.avi
HulaHoop/v_HulaHoop_g01_c04.avi
HulaHoop/v_HulaHoop_g01_c05.avi
HulaHoop/v_HulaHoop_g01_c06.avi
HulaHoop/v_HulaHoop_g01_c07.avi
HulaHoop/v_HulaHoop_g02_c01.avi
HulaHoop/v_HulaHoop_g02_c02.avi
HulaHoop/v_HulaHoop_g02_c03.avi
HulaHoop/v_HulaHoop_g02_c04.avi
HulaHoop/v_HulaHoop_g03_c01.avi
HulaHoop/v_HulaHoop_g03_c02.avi
HulaHoop/v_HulaHoop_g03_c03.avi
HulaHoop/v_HulaHoop_g03_c04.avi
HulaHoop/v_HulaHoop_g03_c05.avi
HulaHoop/v_HulaHoop_g04_c01.avi
HulaHoop/v_HulaHoop_g04_c02.avi
HulaHoop/v_HulaHoop_g04_c03.avi
HulaHoop/v_HulaHoop_g04_c04.avi
HulaHoop/v_HulaHoop_g04_c05.avi
HulaHoop/v_HulaHoop_g05_c01.avi
HulaHoop/v_HulaHoop_g05_c02.avi
HulaHoop/v_HulaHoop_g05_c03.avi
HulaHoop/v_HulaHoop_g05_c04.avi
HulaHoop/v_HulaHoop_g06_c01.avi
HulaHoop/v_HulaHoop_g06_c02.avi
HulaHoop/v_HulaHoop_g06_c03.avi
HulaHoop/v_HulaHoop_g06_c04.avi
HulaHoop/v_HulaHoop_g07_c01.avi
HulaHoop/v_HulaHoop_g07_c02.avi
HulaHoop/v_HulaHoop_g07_c03.avi
HulaHoop/v_HulaHoop_g07_c04.avi
HulaHoop/v_HulaHoop_g07_c05.avi
IceDancing/v_IceDancing_g01_c01.avi
IceDancing/v_IceDancing_g01_c02.avi
IceDancing/v_IceDancing_g01_c03.avi
IceDancing/v_IceDancing_g01_c04.avi
IceDancing/v_IceDancing_g01_c05.avi
IceDancing/v_IceDancing_g01_c06.avi
IceDancing/v_IceDancing_g01_c07.avi
IceDancing/v_IceDancing_g02_c01.avi
IceDancing/v_IceDancing_g02_c02.avi
IceDancing/v_IceDancing_g02_c03.avi
IceDancing/v_IceDancing_g02_c04.avi
IceDancing/v_IceDancing_g02_c05.avi
IceDancing/v_IceDancing_g02_c06.avi
IceDancing/v_IceDancing_g02_c07.avi
IceDancing/v_IceDancing_g03_c01.avi
IceDancing/v_IceDancing_g03_c02.avi
IceDancing/v_IceDancing_g03_c03.avi
IceDancing/v_IceDancing_g03_c04.avi
IceDancing/v_IceDancing_g03_c05.avi
IceDancing/v_IceDancing_g03_c06.avi
IceDancing/v_IceDancing_g04_c01.avi
IceDancing/v_IceDancing_g04_c02.avi
IceDancing/v_IceDancing_g04_c03.avi
IceDancing/v_IceDancing_g04_c04.avi
IceDancing/v_IceDancing_g04_c05.avi
IceDancing/v_IceDancing_g04_c06.avi
IceDancing/v_IceDancing_g04_c07.avi
IceDancing/v_IceDancing_g05_c01.avi
IceDancing/v_IceDancing_g05_c02.avi
IceDancing/v_IceDancing_g05_c03.avi
IceDancing/v_IceDancing_g05_c04.avi
IceDancing/v_IceDancing_g05_c05.avi
IceDancing/v_IceDancing_g05_c06.avi
IceDancing/v_IceDancing_g06_c01.avi
IceDancing/v_IceDancing_g06_c02.avi
IceDancing/v_IceDancing_g06_c03.avi
IceDancing/v_IceDancing_g06_c04.avi
IceDancing/v_IceDancing_g06_c05.avi
IceDancing/v_IceDancing_g06_c06.avi
IceDancing/v_IceDancing_g07_c01.avi
IceDancing/v_IceDancing_g07_c02.avi
IceDancing/v_IceDancing_g07_c03.avi
IceDancing/v_IceDancing_g07_c04.avi
IceDancing/v_IceDancing_g07_c05.avi
IceDancing/v_IceDancing_g07_c06.avi
IceDancing/v_IceDancing_g07_c07.avi
JavelinThrow/v_JavelinThrow_g01_c01.avi
JavelinThrow/v_JavelinThrow_g01_c02.avi
JavelinThrow/v_JavelinThrow_g01_c03.avi
JavelinThrow/v_JavelinThrow_g01_c04.avi
JavelinThrow/v_JavelinThrow_g02_c01.avi
JavelinThrow/v_JavelinThrow_g02_c02.avi
JavelinThrow/v_JavelinThrow_g02_c03.avi
JavelinThrow/v_JavelinThrow_g02_c04.avi
JavelinThrow/v_JavelinThrow_g03_c01.avi
JavelinThrow/v_JavelinThrow_g03_c02.avi
JavelinThrow/v_JavelinThrow_g03_c03.avi
JavelinThrow/v_JavelinThrow_g03_c04.avi
JavelinThrow/v_JavelinThrow_g04_c01.avi
JavelinThrow/v_JavelinThrow_g04_c02.avi
JavelinThrow/v_JavelinThrow_g04_c03.avi
JavelinThrow/v_JavelinThrow_g04_c04.avi
JavelinThrow/v_JavelinThrow_g05_c01.avi
JavelinThrow/v_JavelinThrow_g05_c02.avi
JavelinThrow/v_JavelinThrow_g05_c03.avi
JavelinThrow/v_JavelinThrow_g05_c04.avi
JavelinThrow/v_JavelinThrow_g05_c05.avi
JavelinThrow/v_JavelinThrow_g05_c06.avi
JavelinThrow/v_JavelinThrow_g06_c01.avi
JavelinThrow/v_JavelinThrow_g06_c02.avi
JavelinThrow/v_JavelinThrow_g06_c03.avi
JavelinThrow/v_JavelinThrow_g06_c04.avi
JavelinThrow/v_JavelinThrow_g07_c01.avi
JavelinThrow/v_JavelinThrow_g07_c02.avi
JavelinThrow/v_JavelinThrow_g07_c03.avi
JavelinThrow/v_JavelinThrow_g07_c04.avi
JavelinThrow/v_JavelinThrow_g07_c05.avi
JugglingBalls/v_JugglingBalls_g01_c01.avi
JugglingBalls/v_JugglingBalls_g01_c02.avi
JugglingBalls/v_JugglingBalls_g01_c03.avi
JugglingBalls/v_JugglingBalls_g01_c04.avi
JugglingBalls/v_JugglingBalls_g02_c01.avi
JugglingBalls/v_JugglingBalls_g02_c02.avi
JugglingBalls/v_JugglingBalls_g02_c03.avi
JugglingBalls/v_JugglingBalls_g02_c04.avi
JugglingBalls/v_JugglingBalls_g02_c05.avi
JugglingBalls/v_JugglingBalls_g02_c06.avi
JugglingBalls/v_JugglingBalls_g03_c01.avi
JugglingBalls/v_JugglingBalls_g03_c02.avi
JugglingBalls/v_JugglingBalls_g03_c03.avi
JugglingBalls/v_JugglingBalls_g03_c04.avi
JugglingBalls/v_JugglingBalls_g03_c05.avi
JugglingBalls/v_JugglingBalls_g03_c06.avi
JugglingBalls/v_JugglingBalls_g03_c07.avi
JugglingBalls/v_JugglingBalls_g04_c01.avi
JugglingBalls/v_JugglingBalls_g04_c02.avi
JugglingBalls/v_JugglingBalls_g04_c03.avi
JugglingBalls/v_JugglingBalls_g04_c04.avi
JugglingBalls/v_JugglingBalls_g04_c05.avi
JugglingBalls/v_JugglingBalls_g05_c01.avi
JugglingBalls/v_JugglingBalls_g05_c02.avi
JugglingBalls/v_JugglingBalls_g05_c03.avi
JugglingBalls/v_JugglingBalls_g05_c04.avi
JugglingBalls/v_JugglingBalls_g05_c05.avi
JugglingBalls/v_JugglingBalls_g06_c01.avi
JugglingBalls/v_JugglingBalls_g06_c02.avi
JugglingBalls/v_JugglingBalls_g06_c03.avi
JugglingBalls/v_JugglingBalls_g06_c04.avi
JugglingBalls/v_JugglingBalls_g06_c05.avi
JugglingBalls/v_JugglingBalls_g06_c06.avi
JugglingBalls/v_JugglingBalls_g07_c01.avi
JugglingBalls/v_JugglingBalls_g07_c02.avi
JugglingBalls/v_JugglingBalls_g07_c03.avi
JugglingBalls/v_JugglingBalls_g07_c04.avi
JugglingBalls/v_JugglingBalls_g07_c05.avi
JugglingBalls/v_JugglingBalls_g07_c06.avi
JugglingBalls/v_JugglingBalls_g07_c07.avi
JumpingJack/v_JumpingJack_g01_c01.avi
JumpingJack/v_JumpingJack_g01_c02.avi
JumpingJack/v_JumpingJack_g01_c03.avi
JumpingJack/v_JumpingJack_g01_c04.avi
JumpingJack/v_JumpingJack_g01_c05.avi
JumpingJack/v_JumpingJack_g01_c06.avi
JumpingJack/v_JumpingJack_g01_c07.avi
JumpingJack/v_JumpingJack_g02_c01.avi
JumpingJack/v_JumpingJack_g02_c02.avi
JumpingJack/v_JumpingJack_g02_c03.avi
JumpingJack/v_JumpingJack_g02_c04.avi
JumpingJack/v_JumpingJack_g03_c01.avi
JumpingJack/v_JumpingJack_g03_c02.avi
JumpingJack/v_JumpingJack_g03_c03.avi
JumpingJack/v_JumpingJack_g03_c04.avi
JumpingJack/v_JumpingJack_g04_c01.avi
JumpingJack/v_JumpingJack_g04_c02.avi
JumpingJack/v_JumpingJack_g04_c03.avi
JumpingJack/v_JumpingJack_g04_c04.avi
JumpingJack/v_JumpingJack_g05_c01.avi
JumpingJack/v_JumpingJack_g05_c02.avi
JumpingJack/v_JumpingJack_g05_c03.avi
JumpingJack/v_JumpingJack_g05_c04.avi
JumpingJack/v_JumpingJack_g05_c05.avi
JumpingJack/v_JumpingJack_g05_c06.avi
JumpingJack/v_JumpingJack_g06_c01.avi
JumpingJack/v_JumpingJack_g06_c02.avi
JumpingJack/v_JumpingJack_g06_c03.avi
JumpingJack/v_JumpingJack_g06_c04.avi
JumpingJack/v_JumpingJack_g06_c05.avi
JumpingJack/v_JumpingJack_g06_c06.avi
JumpingJack/v_JumpingJack_g06_c07.avi
JumpingJack/v_JumpingJack_g07_c01.avi
JumpingJack/v_JumpingJack_g07_c02.avi
JumpingJack/v_JumpingJack_g07_c03.avi
JumpingJack/v_JumpingJack_g07_c04.avi
JumpingJack/v_JumpingJack_g07_c05.avi
JumpRope/v_JumpRope_g01_c01.avi
JumpRope/v_JumpRope_g01_c02.avi
JumpRope/v_JumpRope_g01_c03.avi
JumpRope/v_JumpRope_g01_c04.avi
JumpRope/v_JumpRope_g02_c01.avi
JumpRope/v_JumpRope_g02_c02.avi
JumpRope/v_JumpRope_g02_c03.avi
JumpRope/v_JumpRope_g02_c04.avi
JumpRope/v_JumpRope_g02_c05.avi
JumpRope/v_JumpRope_g02_c06.avi
JumpRope/v_JumpRope_g02_c07.avi
JumpRope/v_JumpRope_g03_c01.avi
JumpRope/v_JumpRope_g03_c02.avi
JumpRope/v_JumpRope_g03_c03.avi
JumpRope/v_JumpRope_g03_c04.avi
JumpRope/v_JumpRope_g04_c01.avi
JumpRope/v_JumpRope_g04_c02.avi
JumpRope/v_JumpRope_g04_c03.avi
JumpRope/v_JumpRope_g04_c04.avi
JumpRope/v_JumpRope_g04_c05.avi
JumpRope/v_JumpRope_g04_c06.avi
JumpRope/v_JumpRope_g04_c07.avi
JumpRope/v_JumpRope_g05_c01.avi
JumpRope/v_JumpRope_g05_c02.avi
JumpRope/v_JumpRope_g05_c03.avi
JumpRope/v_JumpRope_g05_c04.avi
JumpRope/v_JumpRope_g05_c05.avi
JumpRope/v_JumpRope_g06_c01.avi
JumpRope/v_JumpRope_g06_c02.avi
JumpRope/v_JumpRope_g06_c03.avi
JumpRope/v_JumpRope_g06_c04.avi
JumpRope/v_JumpRope_g06_c05.avi
JumpRope/v_JumpRope_g07_c01.avi
JumpRope/v_JumpRope_g07_c02.avi
JumpRope/v_JumpRope_g07_c03.avi
JumpRope/v_JumpRope_g07_c04.avi
JumpRope/v_JumpRope_g07_c05.avi
JumpRope/v_JumpRope_g07_c06.avi
Kayaking/v_Kayaking_g01_c01.avi
Kayaking/v_Kayaking_g01_c02.avi
Kayaking/v_Kayaking_g01_c03.avi
Kayaking/v_Kayaking_g01_c04.avi
Kayaking/v_Kayaking_g01_c05.avi
Kayaking/v_Kayaking_g01_c06.avi
Kayaking/v_Kayaking_g02_c01.avi
Kayaking/v_Kayaking_g02_c02.avi
Kayaking/v_Kayaking_g02_c03.avi
Kayaking/v_Kayaking_g02_c04.avi
Kayaking/v_Kayaking_g03_c01.avi
Kayaking/v_Kayaking_g03_c02.avi
Kayaking/v_Kayaking_g03_c03.avi
Kayaking/v_Kayaking_g03_c04.avi
Kayaking/v_Kayaking_g04_c01.avi
Kayaking/v_Kayaking_g04_c02.avi
Kayaking/v_Kayaking_g04_c03.avi
Kayaking/v_Kayaking_g04_c04.avi
Kayaking/v_Kayaking_g04_c05.avi
Kayaking/v_Kayaking_g04_c06.avi
Kayaking/v_Kayaking_g04_c07.avi
Kayaking/v_Kayaking_g05_c01.avi
Kayaking/v_Kayaking_g05_c02.avi
Kayaking/v_Kayaking_g05_c03.avi
Kayaking/v_Kayaking_g05_c04.avi
Kayaking/v_Kayaking_g06_c01.avi
Kayaking/v_Kayaking_g06_c02.avi
Kayaking/v_Kayaking_g06_c03.avi
Kayaking/v_Kayaking_g06_c04.avi
Kayaking/v_Kayaking_g06_c05.avi
Kayaking/v_Kayaking_g06_c06.avi
Kayaking/v_Kayaking_g06_c07.avi
Kayaking/v_Kayaking_g07_c01.avi
Kayaking/v_Kayaking_g07_c02.avi
Kayaking/v_Kayaking_g07_c03.avi
Kayaking/v_Kayaking_g07_c04.avi
Knitting/v_Knitting_g01_c01.avi
Knitting/v_Knitting_g01_c02.avi
Knitting/v_Knitting_g01_c03.avi
Knitting/v_Knitting_g01_c04.avi
Knitting/v_Knitting_g02_c01.avi
Knitting/v_Knitting_g02_c02.avi
Knitting/v_Knitting_g02_c03.avi
Knitting/v_Knitting_g02_c04.avi
Knitting/v_Knitting_g02_c05.avi
Knitting/v_Knitting_g03_c01.avi
Knitting/v_Knitting_g03_c02.avi
Knitting/v_Knitting_g03_c03.avi
Knitting/v_Knitting_g03_c04.avi
Knitting/v_Knitting_g03_c05.avi
Knitting/v_Knitting_g04_c01.avi
Knitting/v_Knitting_g04_c02.avi
Knitting/v_Knitting_g04_c03.avi
Knitting/v_Knitting_g04_c04.avi
Knitting/v_Knitting_g04_c05.avi
Knitting/v_Knitting_g04_c06.avi
Knitting/v_Knitting_g05_c01.avi
Knitting/v_Knitting_g05_c02.avi
Knitting/v_Knitting_g05_c03.avi
Knitting/v_Knitting_g05_c04.avi
Knitting/v_Knitting_g05_c05.avi
Knitting/v_Knitting_g06_c01.avi
Knitting/v_Knitting_g06_c02.avi
Knitting/v_Knitting_g06_c03.avi
Knitting/v_Knitting_g06_c04.avi
Knitting/v_Knitting_g07_c01.avi
Knitting/v_Knitting_g07_c02.avi
Knitting/v_Knitting_g07_c03.avi
Knitting/v_Knitting_g07_c04.avi
Knitting/v_Knitting_g07_c05.avi
LongJump/v_LongJump_g01_c01.avi
LongJump/v_LongJump_g01_c02.avi
LongJump/v_LongJump_g01_c03.avi
LongJump/v_LongJump_g01_c04.avi
LongJump/v_LongJump_g01_c05.avi
LongJump/v_LongJump_g01_c06.avi
LongJump/v_LongJump_g01_c07.avi
LongJump/v_LongJump_g02_c01.avi
LongJump/v_LongJump_g02_c02.avi
LongJump/v_LongJump_g02_c03.avi
LongJump/v_LongJump_g02_c04.avi
LongJump/v_LongJump_g02_c05.avi
LongJump/v_LongJump_g03_c01.avi
LongJump/v_LongJump_g03_c02.avi
LongJump/v_LongJump_g03_c03.avi
LongJump/v_LongJump_g03_c04.avi
LongJump/v_LongJump_g03_c05.avi
LongJump/v_LongJump_g03_c06.avi
LongJump/v_LongJump_g04_c01.avi
LongJump/v_LongJump_g04_c02.avi
LongJump/v_LongJump_g04_c03.avi
LongJump/v_LongJump_g04_c04.avi
LongJump/v_LongJump_g04_c05.avi
LongJump/v_LongJump_g04_c06.avi
LongJump/v_LongJump_g04_c07.avi
LongJump/v_LongJump_g05_c01.avi
LongJump/v_LongJump_g05_c02.avi
LongJump/v_LongJump_g05_c03.avi
LongJump/v_LongJump_g05_c04.avi
LongJump/v_LongJump_g05_c05.avi
LongJump/v_LongJump_g06_c01.avi
LongJump/v_LongJump_g06_c02.avi
LongJump/v_LongJump_g06_c03.avi
LongJump/v_LongJump_g06_c04.avi
LongJump/v_LongJump_g07_c01.avi
LongJump/v_LongJump_g07_c02.avi
LongJump/v_LongJump_g07_c03.avi
LongJump/v_LongJump_g07_c04.avi
LongJump/v_LongJump_g07_c05.avi
Lunges/v_Lunges_g01_c01.avi
Lunges/v_Lunges_g01_c02.avi
Lunges/v_Lunges_g01_c03.avi
Lunges/v_Lunges_g01_c04.avi
Lunges/v_Lunges_g01_c05.avi
Lunges/v_Lunges_g01_c06.avi
Lunges/v_Lunges_g01_c07.avi
Lunges/v_Lunges_g02_c01.avi
Lunges/v_Lunges_g02_c02.avi
Lunges/v_Lunges_g02_c03.avi
Lunges/v_Lunges_g02_c04.avi
Lunges/v_Lunges_g03_c01.avi
Lunges/v_Lunges_g03_c02.avi
Lunges/v_Lunges_g03_c03.avi
Lunges/v_Lunges_g03_c04.avi
Lunges/v_Lunges_g04_c01.avi
Lunges/v_Lunges_g04_c02.avi
Lunges/v_Lunges_g04_c03.avi
Lunges/v_Lunges_g04_c04.avi
Lunges/v_Lunges_g05_c01.avi
Lunges/v_Lunges_g05_c02.avi
Lunges/v_Lunges_g05_c03.avi
Lunges/v_Lunges_g05_c04.avi
Lunges/v_Lunges_g06_c01.avi
Lunges/v_Lunges_g06_c02.avi
Lunges/v_Lunges_g06_c03.avi
Lunges/v_Lunges_g06_c04.avi
Lunges/v_Lunges_g06_c05.avi
Lunges/v_Lunges_g06_c06.avi
Lunges/v_Lunges_g06_c07.avi
Lunges/v_Lunges_g07_c01.avi
Lunges/v_Lunges_g07_c02.avi
Lunges/v_Lunges_g07_c03.avi
Lunges/v_Lunges_g07_c04.avi
Lunges/v_Lunges_g07_c05.avi
Lunges/v_Lunges_g07_c06.avi
Lunges/v_Lunges_g07_c07.avi
MilitaryParade/v_MilitaryParade_g01_c01.avi
MilitaryParade/v_MilitaryParade_g01_c02.avi
MilitaryParade/v_MilitaryParade_g01_c03.avi
MilitaryParade/v_MilitaryParade_g01_c04.avi
MilitaryParade/v_MilitaryParade_g01_c05.avi
MilitaryParade/v_MilitaryParade_g01_c06.avi
MilitaryParade/v_MilitaryParade_g01_c07.avi
MilitaryParade/v_MilitaryParade_g02_c01.avi
MilitaryParade/v_MilitaryParade_g02_c02.avi
MilitaryParade/v_MilitaryParade_g02_c03.avi
MilitaryParade/v_MilitaryParade_g02_c04.avi
MilitaryParade/v_MilitaryParade_g03_c01.avi
MilitaryParade/v_MilitaryParade_g03_c02.avi
MilitaryParade/v_MilitaryParade_g03_c03.avi
MilitaryParade/v_MilitaryParade_g03_c04.avi
MilitaryParade/v_MilitaryParade_g04_c01.avi
MilitaryParade/v_MilitaryParade_g04_c02.avi
MilitaryParade/v_MilitaryParade_g04_c03.avi
MilitaryParade/v_MilitaryParade_g04_c04.avi
MilitaryParade/v_MilitaryParade_g05_c01.avi
MilitaryParade/v_MilitaryParade_g05_c02.avi
MilitaryParade/v_MilitaryParade_g05_c03.avi
MilitaryParade/v_MilitaryParade_g05_c04.avi
MilitaryParade/v_MilitaryParade_g06_c01.avi
MilitaryParade/v_MilitaryParade_g06_c02.avi
MilitaryParade/v_MilitaryParade_g06_c03.avi
MilitaryParade/v_MilitaryParade_g06_c04.avi
MilitaryParade/v_MilitaryParade_g07_c01.avi
MilitaryParade/v_MilitaryParade_g07_c02.avi
MilitaryParade/v_MilitaryParade_g07_c03.avi
MilitaryParade/v_MilitaryParade_g07_c04.avi
MilitaryParade/v_MilitaryParade_g07_c05.avi
MilitaryParade/v_MilitaryParade_g07_c06.avi
Mixing/v_Mixing_g01_c01.avi
Mixing/v_Mixing_g01_c02.avi
Mixing/v_Mixing_g01_c03.avi
Mixing/v_Mixing_g01_c04.avi
Mixing/v_Mixing_g01_c05.avi
Mixing/v_Mixing_g01_c06.avi
Mixing/v_Mixing_g01_c07.avi
Mixing/v_Mixing_g02_c01.avi
Mixing/v_Mixing_g02_c02.avi
Mixing/v_Mixing_g02_c03.avi
Mixing/v_Mixing_g02_c04.avi
Mixing/v_Mixing_g02_c05.avi
Mixing/v_Mixing_g02_c06.avi
Mixing/v_Mixing_g03_c01.avi
Mixing/v_Mixing_g03_c02.avi
Mixing/v_Mixing_g03_c03.avi
Mixing/v_Mixing_g03_c04.avi
Mixing/v_Mixing_g03_c05.avi
Mixing/v_Mixing_g03_c06.avi
Mixing/v_Mixing_g03_c07.avi
Mixing/v_Mixing_g04_c01.avi
Mixing/v_Mixing_g04_c02.avi
Mixing/v_Mixing_g04_c03.avi
Mixing/v_Mixing_g04_c04.avi
Mixing/v_Mixing_g04_c05.avi
Mixing/v_Mixing_g04_c06.avi
Mixing/v_Mixing_g04_c07.avi
Mixing/v_Mixing_g05_c01.avi
Mixing/v_Mixing_g05_c02.avi
Mixing/v_Mixing_g05_c03.avi
Mixing/v_Mixing_g05_c04.avi
Mixing/v_Mixing_g05_c05.avi
Mixing/v_Mixing_g05_c06.avi
Mixing/v_Mixing_g05_c07.avi
Mixing/v_Mixing_g06_c01.avi
Mixing/v_Mixing_g06_c02.avi
Mixing/v_Mixing_g06_c03.avi
Mixing/v_Mixing_g06_c04.avi
Mixing/v_Mixing_g06_c05.avi
Mixing/v_Mixing_g06_c06.avi
Mixing/v_Mixing_g07_c01.avi
Mixing/v_Mixing_g07_c02.avi
Mixing/v_Mixing_g07_c03.avi
Mixing/v_Mixing_g07_c04.avi
Mixing/v_Mixing_g07_c05.avi
MoppingFloor/v_MoppingFloor_g01_c01.avi
MoppingFloor/v_MoppingFloor_g01_c02.avi
MoppingFloor/v_MoppingFloor_g01_c03.avi
MoppingFloor/v_MoppingFloor_g01_c04.avi
MoppingFloor/v_MoppingFloor_g02_c01.avi
MoppingFloor/v_MoppingFloor_g02_c02.avi
MoppingFloor/v_MoppingFloor_g02_c03.avi
MoppingFloor/v_MoppingFloor_g02_c04.avi
MoppingFloor/v_MoppingFloor_g02_c05.avi
MoppingFloor/v_MoppingFloor_g02_c06.avi
MoppingFloor/v_MoppingFloor_g03_c01.avi
MoppingFloor/v_MoppingFloor_g03_c02.avi
MoppingFloor/v_MoppingFloor_g03_c03.avi
MoppingFloor/v_MoppingFloor_g03_c04.avi
MoppingFloor/v_MoppingFloor_g04_c01.avi
MoppingFloor/v_MoppingFloor_g04_c02.avi
MoppingFloor/v_MoppingFloor_g04_c03.avi
MoppingFloor/v_MoppingFloor_g04_c04.avi
MoppingFloor/v_MoppingFloor_g04_c05.avi
MoppingFloor/v_MoppingFloor_g04_c06.avi
MoppingFloor/v_MoppingFloor_g05_c01.avi
MoppingFloor/v_MoppingFloor_g05_c02.avi
MoppingFloor/v_MoppingFloor_g05_c03.avi
MoppingFloor/v_MoppingFloor_g05_c04.avi
MoppingFloor/v_MoppingFloor_g05_c05.avi
MoppingFloor/v_MoppingFloor_g06_c01.avi
MoppingFloor/v_MoppingFloor_g06_c02.avi
MoppingFloor/v_MoppingFloor_g06_c03.avi
MoppingFloor/v_MoppingFloor_g06_c04.avi
MoppingFloor/v_MoppingFloor_g07_c01.avi
MoppingFloor/v_MoppingFloor_g07_c02.avi
MoppingFloor/v_MoppingFloor_g07_c03.avi
MoppingFloor/v_MoppingFloor_g07_c04.avi
MoppingFloor/v_MoppingFloor_g07_c05.avi
Nunchucks/v_Nunchucks_g01_c01.avi
Nunchucks/v_Nunchucks_g01_c02.avi
Nunchucks/v_Nunchucks_g01_c03.avi
Nunchucks/v_Nunchucks_g01_c04.avi
Nunchucks/v_Nunchucks_g02_c01.avi
Nunchucks/v_Nunchucks_g02_c02.avi
Nunchucks/v_Nunchucks_g02_c03.avi
Nunchucks/v_Nunchucks_g02_c04.avi
Nunchucks/v_Nunchucks_g02_c05.avi
Nunchucks/v_Nunchucks_g02_c06.avi
Nunchucks/v_Nunchucks_g03_c01.avi
Nunchucks/v_Nunchucks_g03_c02.avi
Nunchucks/v_Nunchucks_g03_c03.avi
Nunchucks/v_Nunchucks_g03_c04.avi
Nunchucks/v_Nunchucks_g03_c05.avi
Nunchucks/v_Nunchucks_g03_c06.avi
Nunchucks/v_Nunchucks_g03_c07.avi
Nunchucks/v_Nunchucks_g04_c01.avi
Nunchucks/v_Nunchucks_g04_c02.avi
Nunchucks/v_Nunchucks_g04_c03.avi
Nunchucks/v_Nunchucks_g04_c04.avi
Nunchucks/v_Nunchucks_g04_c05.avi
Nunchucks/v_Nunchucks_g04_c06.avi
Nunchucks/v_Nunchucks_g05_c01.avi
Nunchucks/v_Nunchucks_g05_c02.avi
Nunchucks/v_Nunchucks_g05_c03.avi
Nunchucks/v_Nunchucks_g05_c04.avi
Nunchucks/v_Nunchucks_g06_c01.avi
Nunchucks/v_Nunchucks_g06_c02.avi
Nunchucks/v_Nunchucks_g06_c03.avi
Nunchucks/v_Nunchucks_g06_c04.avi
Nunchucks/v_Nunchucks_g07_c01.avi
Nunchucks/v_Nunchucks_g07_c02.avi
Nunchucks/v_Nunchucks_g07_c03.avi
Nunchucks/v_Nunchucks_g07_c04.avi
ParallelBars/v_ParallelBars_g01_c01.avi
ParallelBars/v_ParallelBars_g01_c02.avi
ParallelBars/v_ParallelBars_g01_c03.avi
ParallelBars/v_ParallelBars_g01_c04.avi
ParallelBars/v_ParallelBars_g02_c01.avi
ParallelBars/v_ParallelBars_g02_c02.avi
ParallelBars/v_ParallelBars_g02_c03.avi
ParallelBars/v_ParallelBars_g02_c04.avi
ParallelBars/v_ParallelBars_g03_c01.avi
ParallelBars/v_ParallelBars_g03_c02.avi
ParallelBars/v_ParallelBars_g03_c03.avi
ParallelBars/v_ParallelBars_g03_c04.avi
ParallelBars/v_ParallelBars_g04_c01.avi
ParallelBars/v_ParallelBars_g04_c02.avi
ParallelBars/v_ParallelBars_g04_c03.avi
ParallelBars/v_ParallelBars_g04_c04.avi
ParallelBars/v_ParallelBars_g04_c05.avi
ParallelBars/v_ParallelBars_g04_c06.avi
ParallelBars/v_ParallelBars_g04_c07.avi
ParallelBars/v_ParallelBars_g05_c01.avi
ParallelBars/v_ParallelBars_g05_c02.avi
ParallelBars/v_ParallelBars_g05_c03.avi
ParallelBars/v_ParallelBars_g05_c04.avi
ParallelBars/v_ParallelBars_g05_c05.avi
ParallelBars/v_ParallelBars_g06_c01.avi
ParallelBars/v_ParallelBars_g06_c02.avi
ParallelBars/v_ParallelBars_g06_c03.avi
ParallelBars/v_ParallelBars_g06_c04.avi
ParallelBars/v_ParallelBars_g06_c05.avi
ParallelBars/v_ParallelBars_g06_c06.avi
ParallelBars/v_ParallelBars_g06_c07.avi
ParallelBars/v_ParallelBars_g07_c01.avi
ParallelBars/v_ParallelBars_g07_c02.avi
ParallelBars/v_ParallelBars_g07_c03.avi
ParallelBars/v_ParallelBars_g07_c04.avi
ParallelBars/v_ParallelBars_g07_c05.avi
ParallelBars/v_ParallelBars_g07_c06.avi
PizzaTossing/v_PizzaTossing_g01_c01.avi
PizzaTossing/v_PizzaTossing_g01_c02.avi
PizzaTossing/v_PizzaTossing_g01_c03.avi
PizzaTossing/v_PizzaTossing_g01_c04.avi
PizzaTossing/v_PizzaTossing_g02_c01.avi
PizzaTossing/v_PizzaTossing_g02_c02.avi
PizzaTossing/v_PizzaTossing_g02_c03.avi
PizzaTossing/v_PizzaTossing_g02_c04.avi
PizzaTossing/v_PizzaTossing_g02_c05.avi
PizzaTossing/v_PizzaTossing_g03_c01.avi
PizzaTossing/v_PizzaTossing_g03_c02.avi
PizzaTossing/v_PizzaTossing_g03_c03.avi
PizzaTossing/v_PizzaTossing_g03_c04.avi
PizzaTossing/v_PizzaTossing_g04_c01.avi
PizzaTossing/v_PizzaTossing_g04_c02.avi
PizzaTossing/v_PizzaTossing_g04_c03.avi
PizzaTossing/v_PizzaTossing_g04_c04.avi
PizzaTossing/v_PizzaTossing_g04_c05.avi
PizzaTossing/v_PizzaTossing_g04_c06.avi
PizzaTossing/v_PizzaTossing_g04_c07.avi
PizzaTossing/v_PizzaTossing_g05_c01.avi
PizzaTossing/v_PizzaTossing_g05_c02.avi
PizzaTossing/v_PizzaTossing_g05_c03.avi
PizzaTossing/v_PizzaTossing_g05_c04.avi
PizzaTossing/v_PizzaTossing_g06_c01.avi
PizzaTossing/v_PizzaTossing_g06_c02.avi
PizzaTossing/v_PizzaTossing_g06_c03.avi
PizzaTossing/v_PizzaTossing_g06_c04.avi
PizzaTossing/v_PizzaTossing_g06_c05.avi
PizzaTossing/v_PizzaTossing_g07_c01.avi
PizzaTossing/v_PizzaTossing_g07_c02.avi
PizzaTossing/v_PizzaTossing_g07_c03.avi
PizzaTossing/v_PizzaTossing_g07_c04.avi
PlayingCello/v_PlayingCello_g01_c01.avi
PlayingCello/v_PlayingCello_g01_c02.avi
PlayingCello/v_PlayingCello_g01_c03.avi
PlayingCello/v_PlayingCello_g01_c04.avi
PlayingCello/v_PlayingCello_g01_c05.avi
PlayingCello/v_PlayingCello_g01_c06.avi
PlayingCello/v_PlayingCello_g01_c07.avi
PlayingCello/v_PlayingCello_g02_c01.avi
PlayingCello/v_PlayingCello_g02_c02.avi
PlayingCello/v_PlayingCello_g02_c03.avi
PlayingCello/v_PlayingCello_g02_c04.avi
PlayingCello/v_PlayingCello_g02_c05.avi
PlayingCello/v_PlayingCello_g02_c06.avi
PlayingCello/v_PlayingCello_g02_c07.avi
PlayingCello/v_PlayingCello_g03_c01.avi
PlayingCello/v_PlayingCello_g03_c02.avi
PlayingCello/v_PlayingCello_g03_c03.avi
PlayingCello/v_PlayingCello_g03_c04.avi
PlayingCello/v_PlayingCello_g04_c01.avi
PlayingCello/v_PlayingCello_g04_c02.avi
PlayingCello/v_PlayingCello_g04_c03.avi
PlayingCello/v_PlayingCello_g04_c04.avi
PlayingCello/v_PlayingCello_g04_c05.avi
PlayingCello/v_PlayingCello_g04_c06.avi
PlayingCello/v_PlayingCello_g04_c07.avi
PlayingCello/v_PlayingCello_g05_c01.avi
PlayingCello/v_PlayingCello_g05_c02.avi
PlayingCello/v_PlayingCello_g05_c03.avi
PlayingCello/v_PlayingCello_g05_c04.avi
PlayingCello/v_PlayingCello_g05_c05.avi
PlayingCello/v_PlayingCello_g05_c06.avi
PlayingCello/v_PlayingCello_g05_c07.avi
PlayingCello/v_PlayingCello_g06_c01.avi
PlayingCello/v_PlayingCello_g06_c02.avi
PlayingCello/v_PlayingCello_g06_c03.avi
PlayingCello/v_PlayingCello_g06_c04.avi
PlayingCello/v_PlayingCello_g06_c05.avi
PlayingCello/v_PlayingCello_g06_c06.avi
PlayingCello/v_PlayingCello_g06_c07.avi
PlayingCello/v_PlayingCello_g07_c01.avi
PlayingCello/v_PlayingCello_g07_c02.avi
PlayingCello/v_PlayingCello_g07_c03.avi
PlayingCello/v_PlayingCello_g07_c04.avi
PlayingCello/v_PlayingCello_g07_c05.avi
PlayingDaf/v_PlayingDaf_g01_c01.avi
PlayingDaf/v_PlayingDaf_g01_c02.avi
PlayingDaf/v_PlayingDaf_g01_c03.avi
PlayingDaf/v_PlayingDaf_g01_c04.avi
PlayingDaf/v_PlayingDaf_g02_c01.avi
PlayingDaf/v_PlayingDaf_g02_c02.avi
PlayingDaf/v_PlayingDaf_g02_c03.avi
PlayingDaf/v_PlayingDaf_g02_c04.avi
PlayingDaf/v_PlayingDaf_g02_c05.avi
PlayingDaf/v_PlayingDaf_g02_c06.avi
PlayingDaf/v_PlayingDaf_g02_c07.avi
PlayingDaf/v_PlayingDaf_g03_c01.avi
PlayingDaf/v_PlayingDaf_g03_c02.avi
PlayingDaf/v_PlayingDaf_g03_c03.avi
PlayingDaf/v_PlayingDaf_g03_c04.avi
PlayingDaf/v_PlayingDaf_g04_c01.avi
PlayingDaf/v_PlayingDaf_g04_c02.avi
PlayingDaf/v_PlayingDaf_g04_c03.avi
PlayingDaf/v_PlayingDaf_g04_c04.avi
PlayingDaf/v_PlayingDaf_g04_c05.avi
PlayingDaf/v_PlayingDaf_g04_c06.avi
PlayingDaf/v_PlayingDaf_g04_c07.avi
PlayingDaf/v_PlayingDaf_g05_c01.avi
PlayingDaf/v_PlayingDaf_g05_c02.avi
PlayingDaf/v_PlayingDaf_g05_c03.avi
PlayingDaf/v_PlayingDaf_g05_c04.avi
PlayingDaf/v_PlayingDaf_g05_c05.avi
PlayingDaf/v_PlayingDaf_g05_c06.avi
PlayingDaf/v_PlayingDaf_g05_c07.avi
PlayingDaf/v_PlayingDaf_g06_c01.avi
PlayingDaf/v_PlayingDaf_g06_c02.avi
PlayingDaf/v_PlayingDaf_g06_c03.avi
PlayingDaf/v_PlayingDaf_g06_c04.avi
PlayingDaf/v_PlayingDaf_g06_c05.avi
PlayingDaf/v_PlayingDaf_g06_c06.avi
PlayingDaf/v_PlayingDaf_g06_c07.avi
PlayingDaf/v_PlayingDaf_g07_c01.avi
PlayingDaf/v_PlayingDaf_g07_c02.avi
PlayingDaf/v_PlayingDaf_g07_c03.avi
PlayingDaf/v_PlayingDaf_g07_c04.avi
PlayingDaf/v_PlayingDaf_g07_c05.avi
PlayingDhol/v_PlayingDhol_g01_c01.avi
PlayingDhol/v_PlayingDhol_g01_c02.avi
PlayingDhol/v_PlayingDhol_g01_c03.avi
PlayingDhol/v_PlayingDhol_g01_c04.avi
PlayingDhol/v_PlayingDhol_g01_c05.avi
PlayingDhol/v_PlayingDhol_g01_c06.avi
PlayingDhol/v_PlayingDhol_g01_c07.avi
PlayingDhol/v_PlayingDhol_g02_c01.avi
PlayingDhol/v_PlayingDhol_g02_c02.avi
PlayingDhol/v_PlayingDhol_g02_c03.avi
PlayingDhol/v_PlayingDhol_g02_c04.avi
PlayingDhol/v_PlayingDhol_g02_c05.avi
PlayingDhol/v_PlayingDhol_g02_c06.avi
PlayingDhol/v_PlayingDhol_g02_c07.avi
PlayingDhol/v_PlayingDhol_g03_c01.avi
PlayingDhol/v_PlayingDhol_g03_c02.avi
PlayingDhol/v_PlayingDhol_g03_c03.avi
PlayingDhol/v_PlayingDhol_g03_c04.avi
PlayingDhol/v_PlayingDhol_g03_c05.avi
PlayingDhol/v_PlayingDhol_g03_c06.avi
PlayingDhol/v_PlayingDhol_g03_c07.avi
PlayingDhol/v_PlayingDhol_g04_c01.avi
PlayingDhol/v_PlayingDhol_g04_c02.avi
PlayingDhol/v_PlayingDhol_g04_c03.avi
PlayingDhol/v_PlayingDhol_g04_c04.avi
PlayingDhol/v_PlayingDhol_g04_c05.avi
PlayingDhol/v_PlayingDhol_g04_c06.avi
PlayingDhol/v_PlayingDhol_g04_c07.avi
PlayingDhol/v_PlayingDhol_g05_c01.avi
PlayingDhol/v_PlayingDhol_g05_c02.avi
PlayingDhol/v_PlayingDhol_g05_c03.avi
PlayingDhol/v_PlayingDhol_g05_c04.avi
PlayingDhol/v_PlayingDhol_g05_c05.avi
PlayingDhol/v_PlayingDhol_g05_c06.avi
PlayingDhol/v_PlayingDhol_g05_c07.avi
PlayingDhol/v_PlayingDhol_g06_c01.avi
PlayingDhol/v_PlayingDhol_g06_c02.avi
PlayingDhol/v_PlayingDhol_g06_c03.avi
PlayingDhol/v_PlayingDhol_g06_c04.avi
PlayingDhol/v_PlayingDhol_g06_c05.avi
PlayingDhol/v_PlayingDhol_g06_c06.avi
PlayingDhol/v_PlayingDhol_g06_c07.avi
PlayingDhol/v_PlayingDhol_g07_c01.avi
PlayingDhol/v_PlayingDhol_g07_c02.avi
PlayingDhol/v_PlayingDhol_g07_c03.avi
PlayingDhol/v_PlayingDhol_g07_c04.avi
PlayingDhol/v_PlayingDhol_g07_c05.avi
PlayingDhol/v_PlayingDhol_g07_c06.avi
PlayingDhol/v_PlayingDhol_g07_c07.avi
PlayingFlute/v_PlayingFlute_g01_c01.avi
PlayingFlute/v_PlayingFlute_g01_c02.avi
PlayingFlute/v_PlayingFlute_g01_c03.avi
PlayingFlute/v_PlayingFlute_g01_c04.avi
PlayingFlute/v_PlayingFlute_g01_c05.avi
PlayingFlute/v_PlayingFlute_g01_c06.avi
PlayingFlute/v_PlayingFlute_g01_c07.avi
PlayingFlute/v_PlayingFlute_g02_c01.avi
PlayingFlute/v_PlayingFlute_g02_c02.avi
PlayingFlute/v_PlayingFlute_g02_c03.avi
PlayingFlute/v_PlayingFlute_g02_c04.avi
PlayingFlute/v_PlayingFlute_g02_c05.avi
PlayingFlute/v_PlayingFlute_g02_c06.avi
PlayingFlute/v_PlayingFlute_g02_c07.avi
PlayingFlute/v_PlayingFlute_g03_c01.avi
PlayingFlute/v_PlayingFlute_g03_c02.avi
PlayingFlute/v_PlayingFlute_g03_c03.avi
PlayingFlute/v_PlayingFlute_g03_c04.avi
PlayingFlute/v_PlayingFlute_g03_c05.avi
PlayingFlute/v_PlayingFlute_g03_c06.avi
PlayingFlute/v_PlayingFlute_g03_c07.avi
PlayingFlute/v_PlayingFlute_g04_c01.avi
PlayingFlute/v_PlayingFlute_g04_c02.avi
PlayingFlute/v_PlayingFlute_g04_c03.avi
PlayingFlute/v_PlayingFlute_g04_c04.avi
PlayingFlute/v_PlayingFlute_g04_c05.avi
PlayingFlute/v_PlayingFlute_g04_c06.avi
PlayingFlute/v_PlayingFlute_g04_c07.avi
PlayingFlute/v_PlayingFlute_g05_c01.avi
PlayingFlute/v_PlayingFlute_g05_c02.avi
PlayingFlute/v_PlayingFlute_g05_c03.avi
PlayingFlute/v_PlayingFlute_g05_c04.avi
PlayingFlute/v_PlayingFlute_g05_c05.avi
PlayingFlute/v_PlayingFlute_g05_c06.avi
PlayingFlute/v_PlayingFlute_g05_c07.avi
PlayingFlute/v_PlayingFlute_g06_c01.avi
PlayingFlute/v_PlayingFlute_g06_c02.avi
PlayingFlute/v_PlayingFlute_g06_c03.avi
PlayingFlute/v_PlayingFlute_g06_c04.avi
PlayingFlute/v_PlayingFlute_g06_c05.avi
PlayingFlute/v_PlayingFlute_g06_c06.avi
PlayingFlute/v_PlayingFlute_g07_c01.avi
PlayingFlute/v_PlayingFlute_g07_c02.avi
PlayingFlute/v_PlayingFlute_g07_c03.avi
PlayingFlute/v_PlayingFlute_g07_c04.avi
PlayingFlute/v_PlayingFlute_g07_c05.avi
PlayingFlute/v_PlayingFlute_g07_c06.avi
PlayingFlute/v_PlayingFlute_g07_c07.avi
PlayingGuitar/v_PlayingGuitar_g01_c01.avi
PlayingGuitar/v_PlayingGuitar_g01_c02.avi
PlayingGuitar/v_PlayingGuitar_g01_c03.avi
PlayingGuitar/v_PlayingGuitar_g01_c04.avi
PlayingGuitar/v_PlayingGuitar_g01_c05.avi
PlayingGuitar/v_PlayingGuitar_g01_c06.avi
PlayingGuitar/v_PlayingGuitar_g02_c01.avi
PlayingGuitar/v_PlayingGuitar_g02_c02.avi
PlayingGuitar/v_PlayingGuitar_g02_c03.avi
PlayingGuitar/v_PlayingGuitar_g02_c04.avi
PlayingGuitar/v_PlayingGuitar_g03_c01.avi
PlayingGuitar/v_PlayingGuitar_g03_c02.avi
PlayingGuitar/v_PlayingGuitar_g03_c03.avi
PlayingGuitar/v_PlayingGuitar_g03_c04.avi
PlayingGuitar/v_PlayingGuitar_g03_c05.avi
PlayingGuitar/v_PlayingGuitar_g03_c06.avi
PlayingGuitar/v_PlayingGuitar_g03_c07.avi
PlayingGuitar/v_PlayingGuitar_g04_c01.avi
PlayingGuitar/v_PlayingGuitar_g04_c02.avi
PlayingGuitar/v_PlayingGuitar_g04_c03.avi
PlayingGuitar/v_PlayingGuitar_g04_c04.avi
PlayingGuitar/v_PlayingGuitar_g04_c05.avi
PlayingGuitar/v_PlayingGuitar_g04_c06.avi
PlayingGuitar/v_PlayingGuitar_g04_c07.avi
PlayingGuitar/v_PlayingGuitar_g05_c01.avi
PlayingGuitar/v_PlayingGuitar_g05_c02.avi
PlayingGuitar/v_PlayingGuitar_g05_c03.avi
PlayingGuitar/v_PlayingGuitar_g05_c04.avi
PlayingGuitar/v_PlayingGuitar_g05_c05.avi
PlayingGuitar/v_PlayingGuitar_g06_c01.avi
PlayingGuitar/v_PlayingGuitar_g06_c02.avi
PlayingGuitar/v_PlayingGuitar_g06_c03.avi
PlayingGuitar/v_PlayingGuitar_g06_c04.avi
PlayingGuitar/v_PlayingGuitar_g06_c05.avi
PlayingGuitar/v_PlayingGuitar_g06_c06.avi
PlayingGuitar/v_PlayingGuitar_g06_c07.avi
PlayingGuitar/v_PlayingGuitar_g07_c01.avi
PlayingGuitar/v_PlayingGuitar_g07_c02.avi
PlayingGuitar/v_PlayingGuitar_g07_c03.avi
PlayingGuitar/v_PlayingGuitar_g07_c04.avi
PlayingGuitar/v_PlayingGuitar_g07_c05.avi
PlayingGuitar/v_PlayingGuitar_g07_c06.avi
PlayingGuitar/v_PlayingGuitar_g07_c07.avi
PlayingPiano/v_PlayingPiano_g01_c01.avi
PlayingPiano/v_PlayingPiano_g01_c02.avi
PlayingPiano/v_PlayingPiano_g01_c03.avi
PlayingPiano/v_PlayingPiano_g01_c04.avi
PlayingPiano/v_PlayingPiano_g02_c01.avi
PlayingPiano/v_PlayingPiano_g02_c02.avi
PlayingPiano/v_PlayingPiano_g02_c03.avi
PlayingPiano/v_PlayingPiano_g02_c04.avi
PlayingPiano/v_PlayingPiano_g03_c01.avi
PlayingPiano/v_PlayingPiano_g03_c02.avi
PlayingPiano/v_PlayingPiano_g03_c03.avi
PlayingPiano/v_PlayingPiano_g03_c04.avi
PlayingPiano/v_PlayingPiano_g04_c01.avi
PlayingPiano/v_PlayingPiano_g04_c02.avi
PlayingPiano/v_PlayingPiano_g04_c03.avi
PlayingPiano/v_PlayingPiano_g04_c04.avi
PlayingPiano/v_PlayingPiano_g05_c01.avi
PlayingPiano/v_PlayingPiano_g05_c02.avi
PlayingPiano/v_PlayingPiano_g05_c03.avi
PlayingPiano/v_PlayingPiano_g05_c04.avi
PlayingPiano/v_PlayingPiano_g06_c01.avi
PlayingPiano/v_PlayingPiano_g06_c02.avi
PlayingPiano/v_PlayingPiano_g06_c03.avi
PlayingPiano/v_PlayingPiano_g06_c04.avi
PlayingPiano/v_PlayingPiano_g07_c01.avi
PlayingPiano/v_PlayingPiano_g07_c02.avi
PlayingPiano/v_PlayingPiano_g07_c03.avi
PlayingPiano/v_PlayingPiano_g07_c04.avi
PlayingSitar/v_PlayingSitar_g01_c01.avi
PlayingSitar/v_PlayingSitar_g01_c02.avi
PlayingSitar/v_PlayingSitar_g01_c03.avi
PlayingSitar/v_PlayingSitar_g01_c04.avi
PlayingSitar/v_PlayingSitar_g02_c01.avi
PlayingSitar/v_PlayingSitar_g02_c02.avi
PlayingSitar/v_PlayingSitar_g02_c03.avi
PlayingSitar/v_PlayingSitar_g02_c04.avi
PlayingSitar/v_PlayingSitar_g02_c05.avi
PlayingSitar/v_PlayingSitar_g02_c06.avi
PlayingSitar/v_PlayingSitar_g03_c01.avi
PlayingSitar/v_PlayingSitar_g03_c02.avi
PlayingSitar/v_PlayingSitar_g03_c03.avi
PlayingSitar/v_PlayingSitar_g03_c04.avi
PlayingSitar/v_PlayingSitar_g03_c05.avi
PlayingSitar/v_PlayingSitar_g03_c06.avi
PlayingSitar/v_PlayingSitar_g03_c07.avi
PlayingSitar/v_PlayingSitar_g04_c01.avi
PlayingSitar/v_PlayingSitar_g04_c02.avi
PlayingSitar/v_PlayingSitar_g04_c03.avi
PlayingSitar/v_PlayingSitar_g04_c04.avi
PlayingSitar/v_PlayingSitar_g04_c05.avi
PlayingSitar/v_PlayingSitar_g04_c06.avi
PlayingSitar/v_PlayingSitar_g04_c07.avi
PlayingSitar/v_PlayingSitar_g05_c01.avi
PlayingSitar/v_PlayingSitar_g05_c02.avi
PlayingSitar/v_PlayingSitar_g05_c03.avi
PlayingSitar/v_PlayingSitar_g05_c04.avi
PlayingSitar/v_PlayingSitar_g05_c05.avi
PlayingSitar/v_PlayingSitar_g05_c06.avi
PlayingSitar/v_PlayingSitar_g05_c07.avi
PlayingSitar/v_PlayingSitar_g06_c01.avi
PlayingSitar/v_PlayingSitar_g06_c02.avi
PlayingSitar/v_PlayingSitar_g06_c03.avi
PlayingSitar/v_PlayingSitar_g06_c04.avi
PlayingSitar/v_PlayingSitar_g06_c05.avi
PlayingSitar/v_PlayingSitar_g06_c06.avi
PlayingSitar/v_PlayingSitar_g07_c01.avi
PlayingSitar/v_PlayingSitar_g07_c02.avi
PlayingSitar/v_PlayingSitar_g07_c03.avi
PlayingSitar/v_PlayingSitar_g07_c04.avi
PlayingSitar/v_PlayingSitar_g07_c05.avi
PlayingSitar/v_PlayingSitar_g07_c06.avi
PlayingSitar/v_PlayingSitar_g07_c07.avi
PlayingTabla/v_PlayingTabla_g01_c01.avi
PlayingTabla/v_PlayingTabla_g01_c02.avi
PlayingTabla/v_PlayingTabla_g01_c03.avi
PlayingTabla/v_PlayingTabla_g01_c04.avi
PlayingTabla/v_PlayingTabla_g02_c01.avi
PlayingTabla/v_PlayingTabla_g02_c02.avi
PlayingTabla/v_PlayingTabla_g02_c03.avi
PlayingTabla/v_PlayingTabla_g02_c04.avi
PlayingTabla/v_PlayingTabla_g03_c01.avi
PlayingTabla/v_PlayingTabla_g03_c02.avi
PlayingTabla/v_PlayingTabla_g03_c03.avi
PlayingTabla/v_PlayingTabla_g03_c04.avi
PlayingTabla/v_PlayingTabla_g03_c05.avi
PlayingTabla/v_PlayingTabla_g04_c01.avi
PlayingTabla/v_PlayingTabla_g04_c02.avi
PlayingTabla/v_PlayingTabla_g04_c03.avi
PlayingTabla/v_PlayingTabla_g04_c04.avi
PlayingTabla/v_PlayingTabla_g04_c05.avi
PlayingTabla/v_PlayingTabla_g04_c06.avi
PlayingTabla/v_PlayingTabla_g05_c01.avi
PlayingTabla/v_PlayingTabla_g05_c02.avi
PlayingTabla/v_PlayingTabla_g05_c03.avi
PlayingTabla/v_PlayingTabla_g05_c04.avi
PlayingTabla/v_PlayingTabla_g06_c01.avi
PlayingTabla/v_PlayingTabla_g06_c02.avi
PlayingTabla/v_PlayingTabla_g06_c03.avi
PlayingTabla/v_PlayingTabla_g06_c04.avi
PlayingTabla/v_PlayingTabla_g07_c01.avi
PlayingTabla/v_PlayingTabla_g07_c02.avi
PlayingTabla/v_PlayingTabla_g07_c03.avi
PlayingTabla/v_PlayingTabla_g07_c04.avi
PlayingViolin/v_PlayingViolin_g01_c01.avi
PlayingViolin/v_PlayingViolin_g01_c02.avi
PlayingViolin/v_PlayingViolin_g01_c03.avi
PlayingViolin/v_PlayingViolin_g01_c04.avi
PlayingViolin/v_PlayingViolin_g02_c01.avi
PlayingViolin/v_PlayingViolin_g02_c02.avi
PlayingViolin/v_PlayingViolin_g02_c03.avi
PlayingViolin/v_PlayingViolin_g02_c04.avi
PlayingViolin/v_PlayingViolin_g03_c01.avi
PlayingViolin/v_PlayingViolin_g03_c02.avi
PlayingViolin/v_PlayingViolin_g03_c03.avi
PlayingViolin/v_PlayingViolin_g03_c04.avi
PlayingViolin/v_PlayingViolin_g04_c01.avi
PlayingViolin/v_PlayingViolin_g04_c02.avi
PlayingViolin/v_PlayingViolin_g04_c03.avi
PlayingViolin/v_PlayingViolin_g04_c04.avi
PlayingViolin/v_PlayingViolin_g05_c01.avi
PlayingViolin/v_PlayingViolin_g05_c02.avi
PlayingViolin/v_PlayingViolin_g05_c03.avi
PlayingViolin/v_PlayingViolin_g05_c04.avi
PlayingViolin/v_PlayingViolin_g06_c01.avi
PlayingViolin/v_PlayingViolin_g06_c02.avi
PlayingViolin/v_PlayingViolin_g06_c03.avi
PlayingViolin/v_PlayingViolin_g06_c04.avi
PlayingViolin/v_PlayingViolin_g07_c01.avi
PlayingViolin/v_PlayingViolin_g07_c02.avi
PlayingViolin/v_PlayingViolin_g07_c03.avi
PlayingViolin/v_PlayingViolin_g07_c04.avi
PoleVault/v_PoleVault_g01_c01.avi
PoleVault/v_PoleVault_g01_c02.avi
PoleVault/v_PoleVault_g01_c03.avi
PoleVault/v_PoleVault_g01_c04.avi
PoleVault/v_PoleVault_g01_c05.avi
PoleVault/v_PoleVault_g02_c01.avi
PoleVault/v_PoleVault_g02_c02.avi
PoleVault/v_PoleVault_g02_c03.avi
PoleVault/v_PoleVault_g02_c04.avi
PoleVault/v_PoleVault_g02_c05.avi
PoleVault/v_PoleVault_g02_c06.avi
PoleVault/v_PoleVault_g02_c07.avi
PoleVault/v_PoleVault_g03_c01.avi
PoleVault/v_PoleVault_g03_c02.avi
PoleVault/v_PoleVault_g03_c03.avi
PoleVault/v_PoleVault_g03_c04.avi
PoleVault/v_PoleVault_g03_c05.avi
PoleVault/v_PoleVault_g03_c06.avi
PoleVault/v_PoleVault_g03_c07.avi
PoleVault/v_PoleVault_g04_c01.avi
PoleVault/v_PoleVault_g04_c02.avi
PoleVault/v_PoleVault_g04_c03.avi
PoleVault/v_PoleVault_g04_c04.avi
PoleVault/v_PoleVault_g04_c05.avi
PoleVault/v_PoleVault_g04_c06.avi
PoleVault/v_PoleVault_g04_c07.avi
PoleVault/v_PoleVault_g05_c01.avi
PoleVault/v_PoleVault_g05_c02.avi
PoleVault/v_PoleVault_g05_c03.avi
PoleVault/v_PoleVault_g05_c04.avi
PoleVault/v_PoleVault_g05_c05.avi
PoleVault/v_PoleVault_g06_c01.avi
PoleVault/v_PoleVault_g06_c02.avi
PoleVault/v_PoleVault_g06_c03.avi
PoleVault/v_PoleVault_g06_c04.avi
PoleVault/v_PoleVault_g06_c05.avi
PoleVault/v_PoleVault_g07_c01.avi
PoleVault/v_PoleVault_g07_c02.avi
PoleVault/v_PoleVault_g07_c03.avi
PoleVault/v_PoleVault_g07_c04.avi
PommelHorse/v_PommelHorse_g01_c01.avi
PommelHorse/v_PommelHorse_g01_c02.avi
PommelHorse/v_PommelHorse_g01_c03.avi
PommelHorse/v_PommelHorse_g01_c04.avi
PommelHorse/v_PommelHorse_g01_c05.avi
PommelHorse/v_PommelHorse_g01_c06.avi
PommelHorse/v_PommelHorse_g01_c07.avi
PommelHorse/v_PommelHorse_g02_c01.avi
PommelHorse/v_PommelHorse_g02_c02.avi
PommelHorse/v_PommelHorse_g02_c03.avi
PommelHorse/v_PommelHorse_g02_c04.avi
PommelHorse/v_PommelHorse_g03_c01.avi
PommelHorse/v_PommelHorse_g03_c02.avi
PommelHorse/v_PommelHorse_g03_c03.avi
PommelHorse/v_PommelHorse_g03_c04.avi
PommelHorse/v_PommelHorse_g04_c01.avi
PommelHorse/v_PommelHorse_g04_c02.avi
PommelHorse/v_PommelHorse_g04_c03.avi
PommelHorse/v_PommelHorse_g04_c04.avi
PommelHorse/v_PommelHorse_g04_c05.avi
PommelHorse/v_PommelHorse_g05_c01.avi
PommelHorse/v_PommelHorse_g05_c02.avi
PommelHorse/v_PommelHorse_g05_c03.avi
PommelHorse/v_PommelHorse_g05_c04.avi
PommelHorse/v_PommelHorse_g06_c01.avi
PommelHorse/v_PommelHorse_g06_c02.avi
PommelHorse/v_PommelHorse_g06_c03.avi
PommelHorse/v_PommelHorse_g06_c04.avi
PommelHorse/v_PommelHorse_g07_c01.avi
PommelHorse/v_PommelHorse_g07_c02.avi
PommelHorse/v_PommelHorse_g07_c03.avi
PommelHorse/v_PommelHorse_g07_c04.avi
PommelHorse/v_PommelHorse_g07_c05.avi
PommelHorse/v_PommelHorse_g07_c06.avi
PommelHorse/v_PommelHorse_g07_c07.avi
PullUps/v_PullUps_g01_c01.avi
PullUps/v_PullUps_g01_c02.avi
PullUps/v_PullUps_g01_c03.avi
PullUps/v_PullUps_g01_c04.avi
PullUps/v_PullUps_g02_c01.avi
PullUps/v_PullUps_g02_c02.avi
PullUps/v_PullUps_g02_c03.avi
PullUps/v_PullUps_g02_c04.avi
PullUps/v_PullUps_g03_c01.avi
PullUps/v_PullUps_g03_c02.avi
PullUps/v_PullUps_g03_c03.avi
PullUps/v_PullUps_g03_c04.avi
PullUps/v_PullUps_g04_c01.avi
PullUps/v_PullUps_g04_c02.avi
PullUps/v_PullUps_g04_c03.avi
PullUps/v_PullUps_g04_c04.avi
PullUps/v_PullUps_g05_c01.avi
PullUps/v_PullUps_g05_c02.avi
PullUps/v_PullUps_g05_c03.avi
PullUps/v_PullUps_g05_c04.avi
PullUps/v_PullUps_g06_c01.avi
PullUps/v_PullUps_g06_c02.avi
PullUps/v_PullUps_g06_c03.avi
PullUps/v_PullUps_g06_c04.avi
PullUps/v_PullUps_g07_c01.avi
PullUps/v_PullUps_g07_c02.avi
PullUps/v_PullUps_g07_c03.avi
PullUps/v_PullUps_g07_c04.avi
Punch/v_Punch_g01_c01.avi
Punch/v_Punch_g01_c02.avi
Punch/v_Punch_g01_c03.avi
Punch/v_Punch_g01_c04.avi
Punch/v_Punch_g01_c05.avi
Punch/v_Punch_g02_c01.avi
Punch/v_Punch_g02_c02.avi
Punch/v_Punch_g02_c03.avi
Punch/v_Punch_g02_c04.avi
Punch/v_Punch_g03_c01.avi
Punch/v_Punch_g03_c02.avi
Punch/v_Punch_g03_c03.avi
Punch/v_Punch_g03_c04.avi
Punch/v_Punch_g04_c01.avi
Punch/v_Punch_g04_c02.avi
Punch/v_Punch_g04_c03.avi
Punch/v_Punch_g04_c04.avi
Punch/v_Punch_g04_c05.avi
Punch/v_Punch_g05_c01.avi
Punch/v_Punch_g05_c02.avi
Punch/v_Punch_g05_c03.avi
Punch/v_Punch_g05_c04.avi
Punch/v_Punch_g05_c05.avi
Punch/v_Punch_g05_c06.avi
Punch/v_Punch_g05_c07.avi
Punch/v_Punch_g06_c01.avi
Punch/v_Punch_g06_c02.avi
Punch/v_Punch_g06_c03.avi
Punch/v_Punch_g06_c04.avi
Punch/v_Punch_g06_c05.avi
Punch/v_Punch_g06_c06.avi
Punch/v_Punch_g06_c07.avi
Punch/v_Punch_g07_c01.avi
Punch/v_Punch_g07_c02.avi
Punch/v_Punch_g07_c03.avi
Punch/v_Punch_g07_c04.avi
Punch/v_Punch_g07_c05.avi
Punch/v_Punch_g07_c06.avi
Punch/v_Punch_g07_c07.avi
PushUps/v_PushUps_g01_c01.avi
PushUps/v_PushUps_g01_c02.avi
PushUps/v_PushUps_g01_c03.avi
PushUps/v_PushUps_g01_c04.avi
PushUps/v_PushUps_g01_c05.avi
PushUps/v_PushUps_g02_c01.avi
PushUps/v_PushUps_g02_c02.avi
PushUps/v_PushUps_g02_c03.avi
PushUps/v_PushUps_g02_c04.avi
PushUps/v_PushUps_g03_c01.avi
PushUps/v_PushUps_g03_c02.avi
PushUps/v_PushUps_g03_c03.avi
PushUps/v_PushUps_g03_c04.avi
PushUps/v_PushUps_g04_c01.avi
PushUps/v_PushUps_g04_c02.avi
PushUps/v_PushUps_g04_c03.avi
PushUps/v_PushUps_g04_c04.avi
PushUps/v_PushUps_g04_c05.avi
PushUps/v_PushUps_g05_c01.avi
PushUps/v_PushUps_g05_c02.avi
PushUps/v_PushUps_g05_c03.avi
PushUps/v_PushUps_g05_c04.avi
PushUps/v_PushUps_g06_c01.avi
PushUps/v_PushUps_g06_c02.avi
PushUps/v_PushUps_g06_c03.avi
PushUps/v_PushUps_g06_c04.avi
PushUps/v_PushUps_g07_c01.avi
PushUps/v_PushUps_g07_c02.avi
PushUps/v_PushUps_g07_c03.avi
PushUps/v_PushUps_g07_c04.avi
Rafting/v_Rafting_g01_c01.avi
Rafting/v_Rafting_g01_c02.avi
Rafting/v_Rafting_g01_c03.avi
Rafting/v_Rafting_g01_c04.avi
Rafting/v_Rafting_g02_c01.avi
Rafting/v_Rafting_g02_c02.avi
Rafting/v_Rafting_g02_c03.avi
Rafting/v_Rafting_g02_c04.avi
Rafting/v_Rafting_g03_c01.avi
Rafting/v_Rafting_g03_c02.avi
Rafting/v_Rafting_g03_c03.avi
Rafting/v_Rafting_g03_c04.avi
Rafting/v_Rafting_g04_c01.avi
Rafting/v_Rafting_g04_c02.avi
Rafting/v_Rafting_g04_c03.avi
Rafting/v_Rafting_g04_c04.avi
Rafting/v_Rafting_g05_c01.avi
Rafting/v_Rafting_g05_c02.avi
Rafting/v_Rafting_g05_c03.avi
Rafting/v_Rafting_g05_c04.avi
Rafting/v_Rafting_g06_c01.avi
Rafting/v_Rafting_g06_c02.avi
Rafting/v_Rafting_g06_c03.avi
Rafting/v_Rafting_g06_c04.avi
Rafting/v_Rafting_g07_c01.avi
Rafting/v_Rafting_g07_c02.avi
Rafting/v_Rafting_g07_c03.avi
Rafting/v_Rafting_g07_c04.avi
RockClimbingIndoor/v_RockClimbingIndoor_g01_c01.avi
RockClimbingIndoor/v_RockClimbingIndoor_g01_c02.avi
RockClimbingIndoor/v_RockClimbingIndoor_g01_c03.avi
RockClimbingIndoor/v_RockClimbingIndoor_g01_c04.avi
RockClimbingIndoor/v_RockClimbingIndoor_g01_c05.avi
RockClimbingIndoor/v_RockClimbingIndoor_g02_c01.avi
RockClimbingIndoor/v_RockClimbingIndoor_g02_c02.avi
RockClimbingIndoor/v_RockClimbingIndoor_g02_c03.avi
RockClimbingIndoor/v_RockClimbingIndoor_g02_c04.avi
RockClimbingIndoor/v_RockClimbingIndoor_g02_c05.avi
RockClimbingIndoor/v_RockClimbingIndoor_g03_c01.avi
RockClimbingIndoor/v_RockClimbingIndoor_g03_c02.avi
RockClimbingIndoor/v_RockClimbingIndoor_g03_c03.avi
RockClimbingIndoor/v_RockClimbingIndoor_g03_c04.avi
RockClimbingIndoor/v_RockClimbingIndoor_g03_c05.avi
RockClimbingIndoor/v_RockClimbingIndoor_g03_c06.avi
RockClimbingIndoor/v_RockClimbingIndoor_g03_c07.avi
RockClimbingIndoor/v_RockClimbingIndoor_g04_c01.avi
RockClimbingIndoor/v_RockClimbingIndoor_g04_c02.avi
RockClimbingIndoor/v_RockClimbingIndoor_g04_c03.avi
RockClimbingIndoor/v_RockClimbingIndoor_g04_c04.avi
RockClimbingIndoor/v_RockClimbingIndoor_g05_c01.avi
RockClimbingIndoor/v_RockClimbingIndoor_g05_c02.avi
RockClimbingIndoor/v_RockClimbingIndoor_g05_c03.avi
RockClimbingIndoor/v_RockClimbingIndoor_g05_c04.avi
RockClimbingIndoor/v_RockClimbingIndoor_g05_c05.avi
RockClimbingIndoor/v_RockClimbingIndoor_g05_c06.avi
RockClimbingIndoor/v_RockClimbingIndoor_g06_c01.avi
RockClimbingIndoor/v_RockClimbingIndoor_g06_c02.avi
RockClimbingIndoor/v_RockClimbingIndoor_g06_c03.avi
RockClimbingIndoor/v_RockClimbingIndoor_g06_c04.avi
RockClimbingIndoor/v_RockClimbingIndoor_g06_c05.avi
RockClimbingIndoor/v_RockClimbingIndoor_g06_c06.avi
RockClimbingIndoor/v_RockClimbingIndoor_g06_c07.avi
RockClimbingIndoor/v_RockClimbingIndoor_g07_c01.avi
RockClimbingIndoor/v_RockClimbingIndoor_g07_c02.avi
RockClimbingIndoor/v_RockClimbingIndoor_g07_c03.avi
RockClimbingIndoor/v_RockClimbingIndoor_g07_c04.avi
RockClimbingIndoor/v_RockClimbingIndoor_g07_c05.avi
RockClimbingIndoor/v_RockClimbingIndoor_g07_c06.avi
RockClimbingIndoor/v_RockClimbingIndoor_g07_c07.avi
RopeClimbing/v_RopeClimbing_g01_c01.avi
RopeClimbing/v_RopeClimbing_g01_c02.avi
RopeClimbing/v_RopeClimbing_g01_c03.avi
RopeClimbing/v_RopeClimbing_g01_c04.avi
RopeClimbing/v_RopeClimbing_g02_c01.avi
RopeClimbing/v_RopeClimbing_g02_c02.avi
RopeClimbing/v_RopeClimbing_g02_c03.avi
RopeClimbing/v_RopeClimbing_g02_c04.avi
RopeClimbing/v_RopeClimbing_g02_c05.avi
RopeClimbing/v_RopeClimbing_g02_c06.avi
RopeClimbing/v_RopeClimbing_g03_c01.avi
RopeClimbing/v_RopeClimbing_g03_c02.avi
RopeClimbing/v_RopeClimbing_g03_c03.avi
RopeClimbing/v_RopeClimbing_g03_c04.avi
RopeClimbing/v_RopeClimbing_g04_c01.avi
RopeClimbing/v_RopeClimbing_g04_c02.avi
RopeClimbing/v_RopeClimbing_g04_c03.avi
RopeClimbing/v_RopeClimbing_g04_c04.avi
RopeClimbing/v_RopeClimbing_g05_c01.avi
RopeClimbing/v_RopeClimbing_g05_c02.avi
RopeClimbing/v_RopeClimbing_g05_c03.avi
RopeClimbing/v_RopeClimbing_g05_c04.avi
RopeClimbing/v_RopeClimbing_g05_c05.avi
RopeClimbing/v_RopeClimbing_g05_c06.avi
RopeClimbing/v_RopeClimbing_g05_c07.avi
RopeClimbing/v_RopeClimbing_g06_c01.avi
RopeClimbing/v_RopeClimbing_g06_c02.avi
RopeClimbing/v_RopeClimbing_g06_c03.avi
RopeClimbing/v_RopeClimbing_g06_c04.avi
RopeClimbing/v_RopeClimbing_g07_c01.avi
RopeClimbing/v_RopeClimbing_g07_c02.avi
RopeClimbing/v_RopeClimbing_g07_c03.avi
RopeClimbing/v_RopeClimbing_g07_c04.avi
RopeClimbing/v_RopeClimbing_g07_c05.avi
Rowing/v_Rowing_g01_c01.avi
Rowing/v_Rowing_g01_c02.avi
Rowing/v_Rowing_g01_c03.avi
Rowing/v_Rowing_g01_c04.avi
Rowing/v_Rowing_g02_c01.avi
Rowing/v_Rowing_g02_c02.avi
Rowing/v_Rowing_g02_c03.avi
Rowing/v_Rowing_g02_c04.avi
Rowing/v_Rowing_g02_c05.avi
Rowing/v_Rowing_g02_c06.avi
Rowing/v_Rowing_g03_c01.avi
Rowing/v_Rowing_g03_c02.avi
Rowing/v_Rowing_g03_c03.avi
Rowing/v_Rowing_g03_c04.avi
Rowing/v_Rowing_g03_c05.avi
Rowing/v_Rowing_g03_c06.avi
Rowing/v_Rowing_g03_c07.avi
Rowing/v_Rowing_g04_c01.avi
Rowing/v_Rowing_g04_c02.avi
Rowing/v_Rowing_g04_c03.avi
Rowing/v_Rowing_g04_c04.avi
Rowing/v_Rowing_g04_c05.avi
Rowing/v_Rowing_g04_c06.avi
Rowing/v_Rowing_g05_c01.avi
Rowing/v_Rowing_g05_c02.avi
Rowing/v_Rowing_g05_c03.avi
Rowing/v_Rowing_g05_c04.avi
Rowing/v_Rowing_g06_c01.avi
Rowing/v_Rowing_g06_c02.avi
Rowing/v_Rowing_g06_c03.avi
Rowing/v_Rowing_g06_c04.avi
Rowing/v_Rowing_g07_c01.avi
Rowing/v_Rowing_g07_c02.avi
Rowing/v_Rowing_g07_c03.avi
Rowing/v_Rowing_g07_c04.avi
Rowing/v_Rowing_g07_c05.avi
SalsaSpin/v_SalsaSpin_g01_c01.avi
SalsaSpin/v_SalsaSpin_g01_c02.avi
SalsaSpin/v_SalsaSpin_g01_c03.avi
SalsaSpin/v_SalsaSpin_g01_c04.avi
SalsaSpin/v_SalsaSpin_g01_c05.avi
SalsaSpin/v_SalsaSpin_g01_c06.avi
SalsaSpin/v_SalsaSpin_g01_c07.avi
SalsaSpin/v_SalsaSpin_g02_c01.avi
SalsaSpin/v_SalsaSpin_g02_c02.avi
SalsaSpin/v_SalsaSpin_g02_c03.avi
SalsaSpin/v_SalsaSpin_g02_c04.avi
SalsaSpin/v_SalsaSpin_g02_c05.avi
SalsaSpin/v_SalsaSpin_g02_c06.avi
SalsaSpin/v_SalsaSpin_g02_c07.avi
SalsaSpin/v_SalsaSpin_g03_c01.avi
SalsaSpin/v_SalsaSpin_g03_c02.avi
SalsaSpin/v_SalsaSpin_g03_c03.avi
SalsaSpin/v_SalsaSpin_g03_c04.avi
SalsaSpin/v_SalsaSpin_g03_c05.avi
SalsaSpin/v_SalsaSpin_g03_c06.avi
SalsaSpin/v_SalsaSpin_g04_c01.avi
SalsaSpin/v_SalsaSpin_g04_c02.avi
SalsaSpin/v_SalsaSpin_g04_c03.avi
SalsaSpin/v_SalsaSpin_g04_c04.avi
SalsaSpin/v_SalsaSpin_g04_c05.avi
SalsaSpin/v_SalsaSpin_g04_c06.avi
SalsaSpin/v_SalsaSpin_g05_c01.avi
SalsaSpin/v_SalsaSpin_g05_c02.avi
SalsaSpin/v_SalsaSpin_g05_c03.avi
SalsaSpin/v_SalsaSpin_g05_c04.avi
SalsaSpin/v_SalsaSpin_g05_c05.avi
SalsaSpin/v_SalsaSpin_g05_c06.avi
SalsaSpin/v_SalsaSpin_g06_c01.avi
SalsaSpin/v_SalsaSpin_g06_c02.avi
SalsaSpin/v_SalsaSpin_g06_c03.avi
SalsaSpin/v_SalsaSpin_g06_c04.avi
SalsaSpin/v_SalsaSpin_g06_c05.avi
SalsaSpin/v_SalsaSpin_g07_c01.avi
SalsaSpin/v_SalsaSpin_g07_c02.avi
SalsaSpin/v_SalsaSpin_g07_c03.avi
SalsaSpin/v_SalsaSpin_g07_c04.avi
SalsaSpin/v_SalsaSpin_g07_c05.avi
SalsaSpin/v_SalsaSpin_g07_c06.avi
ShavingBeard/v_ShavingBeard_g01_c01.avi
ShavingBeard/v_ShavingBeard_g01_c02.avi
ShavingBeard/v_ShavingBeard_g01_c03.avi
ShavingBeard/v_ShavingBeard_g01_c04.avi
ShavingBeard/v_ShavingBeard_g02_c01.avi
ShavingBeard/v_ShavingBeard_g02_c02.avi
ShavingBeard/v_ShavingBeard_g02_c03.avi
ShavingBeard/v_ShavingBeard_g02_c04.avi
ShavingBeard/v_ShavingBeard_g02_c05.avi
ShavingBeard/v_ShavingBeard_g02_c06.avi
ShavingBeard/v_ShavingBeard_g02_c07.avi
ShavingBeard/v_ShavingBeard_g03_c01.avi
ShavingBeard/v_ShavingBeard_g03_c02.avi
ShavingBeard/v_ShavingBeard_g03_c03.avi
ShavingBeard/v_ShavingBeard_g03_c04.avi
ShavingBeard/v_ShavingBeard_g03_c05.avi
ShavingBeard/v_ShavingBeard_g03_c06.avi
ShavingBeard/v_ShavingBeard_g03_c07.avi
ShavingBeard/v_ShavingBeard_g04_c01.avi
ShavingBeard/v_ShavingBeard_g04_c02.avi
ShavingBeard/v_ShavingBeard_g04_c03.avi
ShavingBeard/v_ShavingBeard_g04_c04.avi
ShavingBeard/v_ShavingBeard_g05_c01.avi
ShavingBeard/v_ShavingBeard_g05_c02.avi
ShavingBeard/v_ShavingBeard_g05_c03.avi
ShavingBeard/v_ShavingBeard_g05_c04.avi
ShavingBeard/v_ShavingBeard_g05_c05.avi
ShavingBeard/v_ShavingBeard_g05_c06.avi
ShavingBeard/v_ShavingBeard_g05_c07.avi
ShavingBeard/v_ShavingBeard_g06_c01.avi
ShavingBeard/v_ShavingBeard_g06_c02.avi
ShavingBeard/v_ShavingBeard_g06_c03.avi
ShavingBeard/v_ShavingBeard_g06_c04.avi
ShavingBeard/v_ShavingBeard_g06_c05.avi
ShavingBeard/v_ShavingBeard_g06_c06.avi
ShavingBeard/v_ShavingBeard_g06_c07.avi
ShavingBeard/v_ShavingBeard_g07_c01.avi
ShavingBeard/v_ShavingBeard_g07_c02.avi
ShavingBeard/v_ShavingBeard_g07_c03.avi
ShavingBeard/v_ShavingBeard_g07_c04.avi
ShavingBeard/v_ShavingBeard_g07_c05.avi
ShavingBeard/v_ShavingBeard_g07_c06.avi
ShavingBeard/v_ShavingBeard_g07_c07.avi
Shotput/v_Shotput_g01_c01.avi
Shotput/v_Shotput_g01_c02.avi
Shotput/v_Shotput_g01_c03.avi
Shotput/v_Shotput_g01_c04.avi
Shotput/v_Shotput_g01_c05.avi
Shotput/v_Shotput_g01_c06.avi
Shotput/v_Shotput_g01_c07.avi
Shotput/v_Shotput_g02_c01.avi
Shotput/v_Shotput_g02_c02.avi
Shotput/v_Shotput_g02_c03.avi
Shotput/v_Shotput_g02_c04.avi
Shotput/v_Shotput_g02_c05.avi
Shotput/v_Shotput_g02_c06.avi
Shotput/v_Shotput_g02_c07.avi
Shotput/v_Shotput_g03_c01.avi
Shotput/v_Shotput_g03_c02.avi
Shotput/v_Shotput_g03_c03.avi
Shotput/v_Shotput_g03_c04.avi
Shotput/v_Shotput_g03_c05.avi
Shotput/v_Shotput_g03_c06.avi
Shotput/v_Shotput_g04_c01.avi
Shotput/v_Shotput_g04_c02.avi
Shotput/v_Shotput_g04_c03.avi
Shotput/v_Shotput_g04_c04.avi
Shotput/v_Shotput_g04_c05.avi
Shotput/v_Shotput_g05_c01.avi
Shotput/v_Shotput_g05_c02.avi
Shotput/v_Shotput_g05_c03.avi
Shotput/v_Shotput_g05_c04.avi
Shotput/v_Shotput_g05_c05.avi
Shotput/v_Shotput_g05_c06.avi
Shotput/v_Shotput_g05_c07.avi
Shotput/v_Shotput_g06_c01.avi
Shotput/v_Shotput_g06_c02.avi
Shotput/v_Shotput_g06_c03.avi
Shotput/v_Shotput_g06_c04.avi
Shotput/v_Shotput_g06_c05.avi
Shotput/v_Shotput_g06_c06.avi
Shotput/v_Shotput_g06_c07.avi
Shotput/v_Shotput_g07_c01.avi
Shotput/v_Shotput_g07_c02.avi
Shotput/v_Shotput_g07_c03.avi
Shotput/v_Shotput_g07_c04.avi
Shotput/v_Shotput_g07_c05.avi
Shotput/v_Shotput_g07_c06.avi
Shotput/v_Shotput_g07_c07.avi
SkateBoarding/v_SkateBoarding_g01_c01.avi
SkateBoarding/v_SkateBoarding_g01_c02.avi
SkateBoarding/v_SkateBoarding_g01_c03.avi
SkateBoarding/v_SkateBoarding_g01_c04.avi
SkateBoarding/v_SkateBoarding_g02_c01.avi
SkateBoarding/v_SkateBoarding_g02_c02.avi
SkateBoarding/v_SkateBoarding_g02_c03.avi
SkateBoarding/v_SkateBoarding_g02_c04.avi
SkateBoarding/v_SkateBoarding_g02_c05.avi
SkateBoarding/v_SkateBoarding_g02_c06.avi
SkateBoarding/v_SkateBoarding_g03_c01.avi
SkateBoarding/v_SkateBoarding_g03_c02.avi
SkateBoarding/v_SkateBoarding_g03_c03.avi
SkateBoarding/v_SkateBoarding_g03_c04.avi
SkateBoarding/v_SkateBoarding_g04_c01.avi
SkateBoarding/v_SkateBoarding_g04_c02.avi
SkateBoarding/v_SkateBoarding_g04_c03.avi
SkateBoarding/v_SkateBoarding_g04_c04.avi
SkateBoarding/v_SkateBoarding_g04_c05.avi
SkateBoarding/v_SkateBoarding_g05_c01.avi
SkateBoarding/v_SkateBoarding_g05_c02.avi
SkateBoarding/v_SkateBoarding_g05_c03.avi
SkateBoarding/v_SkateBoarding_g05_c04.avi
SkateBoarding/v_SkateBoarding_g06_c01.avi
SkateBoarding/v_SkateBoarding_g06_c02.avi
SkateBoarding/v_SkateBoarding_g06_c03.avi
SkateBoarding/v_SkateBoarding_g06_c04.avi
SkateBoarding/v_SkateBoarding_g07_c01.avi
SkateBoarding/v_SkateBoarding_g07_c02.avi
SkateBoarding/v_SkateBoarding_g07_c03.avi
SkateBoarding/v_SkateBoarding_g07_c04.avi
SkateBoarding/v_SkateBoarding_g07_c05.avi
Skiing/v_Skiing_g01_c01.avi
Skiing/v_Skiing_g01_c02.avi
Skiing/v_Skiing_g01_c03.avi
Skiing/v_Skiing_g01_c04.avi
Skiing/v_Skiing_g01_c05.avi
Skiing/v_Skiing_g01_c06.avi
Skiing/v_Skiing_g02_c01.avi
Skiing/v_Skiing_g02_c02.avi
Skiing/v_Skiing_g02_c03.avi
Skiing/v_Skiing_g02_c04.avi
Skiing/v_Skiing_g02_c05.avi
Skiing/v_Skiing_g03_c01.avi
Skiing/v_Skiing_g03_c02.avi
Skiing/v_Skiing_g03_c03.avi
Skiing/v_Skiing_g03_c04.avi
Skiing/v_Skiing_g03_c05.avi
Skiing/v_Skiing_g03_c06.avi
Skiing/v_Skiing_g03_c07.avi
Skiing/v_Skiing_g04_c01.avi
Skiing/v_Skiing_g04_c02.avi
Skiing/v_Skiing_g04_c03.avi
Skiing/v_Skiing_g04_c04.avi
Skiing/v_Skiing_g04_c05.avi
Skiing/v_Skiing_g04_c06.avi
Skiing/v_Skiing_g04_c07.avi
Skiing/v_Skiing_g05_c01.avi
Skiing/v_Skiing_g05_c02.avi
Skiing/v_Skiing_g05_c03.avi
Skiing/v_Skiing_g05_c04.avi
Skiing/v_Skiing_g06_c01.avi
Skiing/v_Skiing_g06_c02.avi
Skiing/v_Skiing_g06_c03.avi
Skiing/v_Skiing_g06_c04.avi
Skiing/v_Skiing_g06_c05.avi
Skiing/v_Skiing_g06_c06.avi
Skiing/v_Skiing_g06_c07.avi
Skiing/v_Skiing_g07_c01.avi
Skiing/v_Skiing_g07_c02.avi
Skiing/v_Skiing_g07_c03.avi
Skiing/v_Skiing_g07_c04.avi
Skijet/v_Skijet_g01_c01.avi
Skijet/v_Skijet_g01_c02.avi
Skijet/v_Skijet_g01_c03.avi
Skijet/v_Skijet_g01_c04.avi
Skijet/v_Skijet_g02_c01.avi
Skijet/v_Skijet_g02_c02.avi
Skijet/v_Skijet_g02_c03.avi
Skijet/v_Skijet_g02_c04.avi
Skijet/v_Skijet_g03_c01.avi
Skijet/v_Skijet_g03_c02.avi
Skijet/v_Skijet_g03_c03.avi
Skijet/v_Skijet_g03_c04.avi
Skijet/v_Skijet_g04_c01.avi
Skijet/v_Skijet_g04_c02.avi
Skijet/v_Skijet_g04_c03.avi
Skijet/v_Skijet_g04_c04.avi
Skijet/v_Skijet_g05_c01.avi
Skijet/v_Skijet_g05_c02.avi
Skijet/v_Skijet_g05_c03.avi
Skijet/v_Skijet_g05_c04.avi
Skijet/v_Skijet_g06_c01.avi
Skijet/v_Skijet_g06_c02.avi
Skijet/v_Skijet_g06_c03.avi
Skijet/v_Skijet_g06_c04.avi
Skijet/v_Skijet_g07_c01.avi
Skijet/v_Skijet_g07_c02.avi
Skijet/v_Skijet_g07_c03.avi
Skijet/v_Skijet_g07_c04.avi
SkyDiving/v_SkyDiving_g01_c01.avi
SkyDiving/v_SkyDiving_g01_c02.avi
SkyDiving/v_SkyDiving_g01_c03.avi
SkyDiving/v_SkyDiving_g01_c04.avi
SkyDiving/v_SkyDiving_g02_c01.avi
SkyDiving/v_SkyDiving_g02_c02.avi
SkyDiving/v_SkyDiving_g02_c03.avi
SkyDiving/v_SkyDiving_g02_c04.avi
SkyDiving/v_SkyDiving_g03_c01.avi
SkyDiving/v_SkyDiving_g03_c02.avi
SkyDiving/v_SkyDiving_g03_c03.avi
SkyDiving/v_SkyDiving_g03_c04.avi
SkyDiving/v_SkyDiving_g03_c05.avi
SkyDiving/v_SkyDiving_g04_c01.avi
SkyDiving/v_SkyDiving_g04_c02.avi
SkyDiving/v_SkyDiving_g04_c03.avi
SkyDiving/v_SkyDiving_g04_c04.avi
SkyDiving/v_SkyDiving_g05_c01.avi
SkyDiving/v_SkyDiving_g05_c02.avi
SkyDiving/v_SkyDiving_g05_c03.avi
SkyDiving/v_SkyDiving_g05_c04.avi
SkyDiving/v_SkyDiving_g05_c05.avi
SkyDiving/v_SkyDiving_g06_c01.avi
SkyDiving/v_SkyDiving_g06_c02.avi
SkyDiving/v_SkyDiving_g06_c03.avi
SkyDiving/v_SkyDiving_g06_c04.avi
SkyDiving/v_SkyDiving_g07_c01.avi
SkyDiving/v_SkyDiving_g07_c02.avi
SkyDiving/v_SkyDiving_g07_c03.avi
SkyDiving/v_SkyDiving_g07_c04.avi
SkyDiving/v_SkyDiving_g07_c05.avi
SoccerJuggling/v_SoccerJuggling_g01_c01.avi
SoccerJuggling/v_SoccerJuggling_g01_c02.avi
SoccerJuggling/v_SoccerJuggling_g01_c03.avi
SoccerJuggling/v_SoccerJuggling_g01_c04.avi
SoccerJuggling/v_SoccerJuggling_g01_c05.avi
SoccerJuggling/v_SoccerJuggling_g02_c01.avi
SoccerJuggling/v_SoccerJuggling_g02_c02.avi
SoccerJuggling/v_SoccerJuggling_g02_c03.avi
SoccerJuggling/v_SoccerJuggling_g02_c04.avi
SoccerJuggling/v_SoccerJuggling_g02_c05.avi
SoccerJuggling/v_SoccerJuggling_g02_c06.avi
SoccerJuggling/v_SoccerJuggling_g03_c01.avi
SoccerJuggling/v_SoccerJuggling_g03_c02.avi
SoccerJuggling/v_SoccerJuggling_g03_c03.avi
SoccerJuggling/v_SoccerJuggling_g03_c04.avi
SoccerJuggling/v_SoccerJuggling_g04_c01.avi
SoccerJuggling/v_SoccerJuggling_g04_c02.avi
SoccerJuggling/v_SoccerJuggling_g04_c03.avi
SoccerJuggling/v_SoccerJuggling_g04_c04.avi
SoccerJuggling/v_SoccerJuggling_g04_c05.avi
SoccerJuggling/v_SoccerJuggling_g04_c06.avi
SoccerJuggling/v_SoccerJuggling_g05_c01.avi
SoccerJuggling/v_SoccerJuggling_g05_c02.avi
SoccerJuggling/v_SoccerJuggling_g05_c03.avi
SoccerJuggling/v_SoccerJuggling_g05_c04.avi
SoccerJuggling/v_SoccerJuggling_g05_c05.avi
SoccerJuggling/v_SoccerJuggling_g05_c06.avi
SoccerJuggling/v_SoccerJuggling_g06_c01.avi
SoccerJuggling/v_SoccerJuggling_g06_c02.avi
SoccerJuggling/v_SoccerJuggling_g06_c03.avi
SoccerJuggling/v_SoccerJuggling_g06_c04.avi
SoccerJuggling/v_SoccerJuggling_g06_c05.avi
SoccerJuggling/v_SoccerJuggling_g07_c01.avi
SoccerJuggling/v_SoccerJuggling_g07_c02.avi
SoccerJuggling/v_SoccerJuggling_g07_c03.avi
SoccerJuggling/v_SoccerJuggling_g07_c04.avi
SoccerJuggling/v_SoccerJuggling_g07_c05.avi
SoccerJuggling/v_SoccerJuggling_g07_c06.avi
SoccerJuggling/v_SoccerJuggling_g07_c07.avi
SoccerPenalty/v_SoccerPenalty_g01_c01.avi
SoccerPenalty/v_SoccerPenalty_g01_c02.avi
SoccerPenalty/v_SoccerPenalty_g01_c03.avi
SoccerPenalty/v_SoccerPenalty_g01_c04.avi
SoccerPenalty/v_SoccerPenalty_g01_c05.avi
SoccerPenalty/v_SoccerPenalty_g01_c06.avi
SoccerPenalty/v_SoccerPenalty_g02_c01.avi
SoccerPenalty/v_SoccerPenalty_g02_c02.avi
SoccerPenalty/v_SoccerPenalty_g02_c03.avi
SoccerPenalty/v_SoccerPenalty_g02_c04.avi
SoccerPenalty/v_SoccerPenalty_g02_c05.avi
SoccerPenalty/v_SoccerPenalty_g03_c01.avi
SoccerPenalty/v_SoccerPenalty_g03_c02.avi
SoccerPenalty/v_SoccerPenalty_g03_c03.avi
SoccerPenalty/v_SoccerPenalty_g03_c04.avi
SoccerPenalty/v_SoccerPenalty_g03_c05.avi
SoccerPenalty/v_SoccerPenalty_g04_c01.avi
SoccerPenalty/v_SoccerPenalty_g04_c02.avi
SoccerPenalty/v_SoccerPenalty_g04_c03.avi
SoccerPenalty/v_SoccerPenalty_g04_c04.avi
SoccerPenalty/v_SoccerPenalty_g04_c05.avi
SoccerPenalty/v_SoccerPenalty_g05_c01.avi
SoccerPenalty/v_SoccerPenalty_g05_c02.avi
SoccerPenalty/v_SoccerPenalty_g05_c03.avi
SoccerPenalty/v_SoccerPenalty_g05_c04.avi
SoccerPenalty/v_SoccerPenalty_g05_c05.avi
SoccerPenalty/v_SoccerPenalty_g05_c06.avi
SoccerPenalty/v_SoccerPenalty_g05_c07.avi
SoccerPenalty/v_SoccerPenalty_g06_c01.avi
SoccerPenalty/v_SoccerPenalty_g06_c02.avi
SoccerPenalty/v_SoccerPenalty_g06_c03.avi
SoccerPenalty/v_SoccerPenalty_g06_c04.avi
SoccerPenalty/v_SoccerPenalty_g06_c05.avi
SoccerPenalty/v_SoccerPenalty_g06_c06.avi
SoccerPenalty/v_SoccerPenalty_g06_c07.avi
SoccerPenalty/v_SoccerPenalty_g07_c01.avi
SoccerPenalty/v_SoccerPenalty_g07_c02.avi
SoccerPenalty/v_SoccerPenalty_g07_c03.avi
SoccerPenalty/v_SoccerPenalty_g07_c04.avi
SoccerPenalty/v_SoccerPenalty_g07_c05.avi
SoccerPenalty/v_SoccerPenalty_g07_c06.avi
StillRings/v_StillRings_g01_c01.avi
StillRings/v_StillRings_g01_c02.avi
StillRings/v_StillRings_g01_c03.avi
StillRings/v_StillRings_g01_c04.avi
StillRings/v_StillRings_g01_c05.avi
StillRings/v_StillRings_g02_c01.avi
StillRings/v_StillRings_g02_c02.avi
StillRings/v_StillRings_g02_c03.avi
StillRings/v_StillRings_g02_c04.avi
StillRings/v_StillRings_g03_c01.avi
StillRings/v_StillRings_g03_c02.avi
StillRings/v_StillRings_g03_c03.avi
StillRings/v_StillRings_g03_c04.avi
StillRings/v_StillRings_g03_c05.avi
StillRings/v_StillRings_g03_c06.avi
StillRings/v_StillRings_g03_c07.avi
StillRings/v_StillRings_g04_c01.avi
StillRings/v_StillRings_g04_c02.avi
StillRings/v_StillRings_g04_c03.avi
StillRings/v_StillRings_g04_c04.avi
StillRings/v_StillRings_g05_c01.avi
StillRings/v_StillRings_g05_c02.avi
StillRings/v_StillRings_g05_c03.avi
StillRings/v_StillRings_g05_c04.avi
StillRings/v_StillRings_g06_c01.avi
StillRings/v_StillRings_g06_c02.avi
StillRings/v_StillRings_g06_c03.avi
StillRings/v_StillRings_g06_c04.avi
StillRings/v_StillRings_g07_c01.avi
StillRings/v_StillRings_g07_c02.avi
StillRings/v_StillRings_g07_c03.avi
StillRings/v_StillRings_g07_c04.avi
SumoWrestling/v_SumoWrestling_g01_c01.avi
SumoWrestling/v_SumoWrestling_g01_c02.avi
SumoWrestling/v_SumoWrestling_g01_c03.avi
SumoWrestling/v_SumoWrestling_g01_c04.avi
SumoWrestling/v_SumoWrestling_g02_c01.avi
SumoWrestling/v_SumoWrestling_g02_c02.avi
SumoWrestling/v_SumoWrestling_g02_c03.avi
SumoWrestling/v_SumoWrestling_g02_c04.avi
SumoWrestling/v_SumoWrestling_g03_c01.avi
SumoWrestling/v_SumoWrestling_g03_c02.avi
SumoWrestling/v_SumoWrestling_g03_c03.avi
SumoWrestling/v_SumoWrestling_g03_c04.avi
SumoWrestling/v_SumoWrestling_g04_c01.avi
SumoWrestling/v_SumoWrestling_g04_c02.avi
SumoWrestling/v_SumoWrestling_g04_c03.avi
SumoWrestling/v_SumoWrestling_g04_c04.avi
SumoWrestling/v_SumoWrestling_g05_c01.avi
SumoWrestling/v_SumoWrestling_g05_c02.avi
SumoWrestling/v_SumoWrestling_g05_c03.avi
SumoWrestling/v_SumoWrestling_g05_c04.avi
SumoWrestling/v_SumoWrestling_g06_c01.avi
SumoWrestling/v_SumoWrestling_g06_c02.avi
SumoWrestling/v_SumoWrestling_g06_c03.avi
SumoWrestling/v_SumoWrestling_g06_c04.avi
SumoWrestling/v_SumoWrestling_g06_c05.avi
SumoWrestling/v_SumoWrestling_g06_c06.avi
SumoWrestling/v_SumoWrestling_g06_c07.avi
SumoWrestling/v_SumoWrestling_g07_c01.avi
SumoWrestling/v_SumoWrestling_g07_c02.avi
SumoWrestling/v_SumoWrestling_g07_c03.avi
SumoWrestling/v_SumoWrestling_g07_c04.avi
SumoWrestling/v_SumoWrestling_g07_c05.avi
SumoWrestling/v_SumoWrestling_g07_c06.avi
SumoWrestling/v_SumoWrestling_g07_c07.avi
Surfing/v_Surfing_g01_c01.avi
Surfing/v_Surfing_g01_c02.avi
Surfing/v_Surfing_g01_c03.avi
Surfing/v_Surfing_g01_c04.avi
Surfing/v_Surfing_g01_c05.avi
Surfing/v_Surfing_g01_c06.avi
Surfing/v_Surfing_g01_c07.avi
Surfing/v_Surfing_g02_c01.avi
Surfing/v_Surfing_g02_c02.avi
Surfing/v_Surfing_g02_c03.avi
Surfing/v_Surfing_g02_c04.avi
Surfing/v_Surfing_g02_c05.avi
Surfing/v_Surfing_g02_c06.avi
Surfing/v_Surfing_g03_c01.avi
Surfing/v_Surfing_g03_c02.avi
Surfing/v_Surfing_g03_c03.avi
Surfing/v_Surfing_g03_c04.avi
Surfing/v_Surfing_g04_c01.avi
Surfing/v_Surfing_g04_c02.avi
Surfing/v_Surfing_g04_c03.avi
Surfing/v_Surfing_g04_c04.avi
Surfing/v_Surfing_g05_c01.avi
Surfing/v_Surfing_g05_c02.avi
Surfing/v_Surfing_g05_c03.avi
Surfing/v_Surfing_g05_c04.avi
Surfing/v_Surfing_g06_c01.avi
Surfing/v_Surfing_g06_c02.avi
Surfing/v_Surfing_g06_c03.avi
Surfing/v_Surfing_g06_c04.avi
Surfing/v_Surfing_g07_c01.avi
Surfing/v_Surfing_g07_c02.avi
Surfing/v_Surfing_g07_c03.avi
Surfing/v_Surfing_g07_c04.avi
Swing/v_Swing_g01_c01.avi
Swing/v_Swing_g01_c02.avi
Swing/v_Swing_g01_c03.avi
Swing/v_Swing_g01_c04.avi
Swing/v_Swing_g01_c05.avi
Swing/v_Swing_g02_c01.avi
Swing/v_Swing_g02_c02.avi
Swing/v_Swing_g02_c03.avi
Swing/v_Swing_g02_c04.avi
Swing/v_Swing_g02_c05.avi
Swing/v_Swing_g03_c01.avi
Swing/v_Swing_g03_c02.avi
Swing/v_Swing_g03_c03.avi
Swing/v_Swing_g03_c04.avi
Swing/v_Swing_g04_c01.avi
Swing/v_Swing_g04_c02.avi
Swing/v_Swing_g04_c03.avi
Swing/v_Swing_g04_c04.avi
Swing/v_Swing_g04_c05.avi
Swing/v_Swing_g04_c06.avi
Swing/v_Swing_g04_c07.avi
Swing/v_Swing_g05_c01.avi
Swing/v_Swing_g05_c02.avi
Swing/v_Swing_g05_c03.avi
Swing/v_Swing_g05_c04.avi
Swing/v_Swing_g05_c05.avi
Swing/v_Swing_g05_c06.avi
Swing/v_Swing_g05_c07.avi
Swing/v_Swing_g06_c01.avi
Swing/v_Swing_g06_c02.avi
Swing/v_Swing_g06_c03.avi
Swing/v_Swing_g06_c04.avi
Swing/v_Swing_g06_c05.avi
Swing/v_Swing_g06_c06.avi
Swing/v_Swing_g06_c07.avi
Swing/v_Swing_g07_c01.avi
Swing/v_Swing_g07_c02.avi
Swing/v_Swing_g07_c03.avi
Swing/v_Swing_g07_c04.avi
Swing/v_Swing_g07_c05.avi
Swing/v_Swing_g07_c06.avi
Swing/v_Swing_g07_c07.avi
TableTennisShot/v_TableTennisShot_g01_c01.avi
TableTennisShot/v_TableTennisShot_g01_c02.avi
TableTennisShot/v_TableTennisShot_g01_c03.avi
TableTennisShot/v_TableTennisShot_g01_c04.avi
TableTennisShot/v_TableTennisShot_g01_c05.avi
TableTennisShot/v_TableTennisShot_g01_c06.avi
TableTennisShot/v_TableTennisShot_g02_c01.avi
TableTennisShot/v_TableTennisShot_g02_c02.avi
TableTennisShot/v_TableTennisShot_g02_c03.avi
TableTennisShot/v_TableTennisShot_g02_c04.avi
TableTennisShot/v_TableTennisShot_g03_c01.avi
TableTennisShot/v_TableTennisShot_g03_c02.avi
TableTennisShot/v_TableTennisShot_g03_c03.avi
TableTennisShot/v_TableTennisShot_g03_c04.avi
TableTennisShot/v_TableTennisShot_g03_c05.avi
TableTennisShot/v_TableTennisShot_g04_c01.avi
TableTennisShot/v_TableTennisShot_g04_c02.avi
TableTennisShot/v_TableTennisShot_g04_c03.avi
TableTennisShot/v_TableTennisShot_g04_c04.avi
TableTennisShot/v_TableTennisShot_g04_c05.avi
TableTennisShot/v_TableTennisShot_g04_c06.avi
TableTennisShot/v_TableTennisShot_g04_c07.avi
TableTennisShot/v_TableTennisShot_g05_c01.avi
TableTennisShot/v_TableTennisShot_g05_c02.avi
TableTennisShot/v_TableTennisShot_g05_c03.avi
TableTennisShot/v_TableTennisShot_g05_c04.avi
TableTennisShot/v_TableTennisShot_g05_c05.avi
TableTennisShot/v_TableTennisShot_g05_c06.avi
TableTennisShot/v_TableTennisShot_g05_c07.avi
TableTennisShot/v_TableTennisShot_g06_c01.avi
TableTennisShot/v_TableTennisShot_g06_c02.avi
TableTennisShot/v_TableTennisShot_g06_c03.avi
TableTennisShot/v_TableTennisShot_g06_c04.avi
TableTennisShot/v_TableTennisShot_g06_c05.avi
TableTennisShot/v_TableTennisShot_g06_c06.avi
TableTennisShot/v_TableTennisShot_g07_c01.avi
TableTennisShot/v_TableTennisShot_g07_c02.avi
TableTennisShot/v_TableTennisShot_g07_c03.avi
TableTennisShot/v_TableTennisShot_g07_c04.avi
TaiChi/v_TaiChi_g01_c01.avi
TaiChi/v_TaiChi_g01_c02.avi
TaiChi/v_TaiChi_g01_c03.avi
TaiChi/v_TaiChi_g01_c04.avi
TaiChi/v_TaiChi_g02_c01.avi
TaiChi/v_TaiChi_g02_c02.avi
TaiChi/v_TaiChi_g02_c03.avi
TaiChi/v_TaiChi_g02_c04.avi
TaiChi/v_TaiChi_g03_c01.avi
TaiChi/v_TaiChi_g03_c02.avi
TaiChi/v_TaiChi_g03_c03.avi
TaiChi/v_TaiChi_g03_c04.avi
TaiChi/v_TaiChi_g04_c01.avi
TaiChi/v_TaiChi_g04_c02.avi
TaiChi/v_TaiChi_g04_c03.avi
TaiChi/v_TaiChi_g04_c04.avi
TaiChi/v_TaiChi_g05_c01.avi
TaiChi/v_TaiChi_g05_c02.avi
TaiChi/v_TaiChi_g05_c03.avi
TaiChi/v_TaiChi_g05_c04.avi
TaiChi/v_TaiChi_g06_c01.avi
TaiChi/v_TaiChi_g06_c02.avi
TaiChi/v_TaiChi_g06_c03.avi
TaiChi/v_TaiChi_g06_c04.avi
TaiChi/v_TaiChi_g07_c01.avi
TaiChi/v_TaiChi_g07_c02.avi
TaiChi/v_TaiChi_g07_c03.avi
TaiChi/v_TaiChi_g07_c04.avi
TennisSwing/v_TennisSwing_g01_c01.avi
TennisSwing/v_TennisSwing_g01_c02.avi
TennisSwing/v_TennisSwing_g01_c03.avi
TennisSwing/v_TennisSwing_g01_c04.avi
TennisSwing/v_TennisSwing_g01_c05.avi
TennisSwing/v_TennisSwing_g01_c06.avi
TennisSwing/v_TennisSwing_g01_c07.avi
TennisSwing/v_TennisSwing_g02_c01.avi
TennisSwing/v_TennisSwing_g02_c02.avi
TennisSwing/v_TennisSwing_g02_c03.avi
TennisSwing/v_TennisSwing_g02_c04.avi
TennisSwing/v_TennisSwing_g02_c05.avi
TennisSwing/v_TennisSwing_g02_c06.avi
TennisSwing/v_TennisSwing_g02_c07.avi
TennisSwing/v_TennisSwing_g03_c01.avi
TennisSwing/v_TennisSwing_g03_c02.avi
TennisSwing/v_TennisSwing_g03_c03.avi
TennisSwing/v_TennisSwing_g03_c04.avi
TennisSwing/v_TennisSwing_g03_c05.avi
TennisSwing/v_TennisSwing_g03_c06.avi
TennisSwing/v_TennisSwing_g03_c07.avi
TennisSwing/v_TennisSwing_g04_c01.avi
TennisSwing/v_TennisSwing_g04_c02.avi
TennisSwing/v_TennisSwing_g04_c03.avi
TennisSwing/v_TennisSwing_g04_c04.avi
TennisSwing/v_TennisSwing_g04_c05.avi
TennisSwing/v_TennisSwing_g04_c06.avi
TennisSwing/v_TennisSwing_g04_c07.avi
TennisSwing/v_TennisSwing_g05_c01.avi
TennisSwing/v_TennisSwing_g05_c02.avi
TennisSwing/v_TennisSwing_g05_c03.avi
TennisSwing/v_TennisSwing_g05_c04.avi
TennisSwing/v_TennisSwing_g05_c05.avi
TennisSwing/v_TennisSwing_g05_c06.avi
TennisSwing/v_TennisSwing_g05_c07.avi
TennisSwing/v_TennisSwing_g06_c01.avi
TennisSwing/v_TennisSwing_g06_c02.avi
TennisSwing/v_TennisSwing_g06_c03.avi
TennisSwing/v_TennisSwing_g06_c04.avi
TennisSwing/v_TennisSwing_g06_c05.avi
TennisSwing/v_TennisSwing_g06_c06.avi
TennisSwing/v_TennisSwing_g06_c07.avi
TennisSwing/v_TennisSwing_g07_c01.avi
TennisSwing/v_TennisSwing_g07_c02.avi
TennisSwing/v_TennisSwing_g07_c03.avi
TennisSwing/v_TennisSwing_g07_c04.avi
TennisSwing/v_TennisSwing_g07_c05.avi
TennisSwing/v_TennisSwing_g07_c06.avi
TennisSwing/v_TennisSwing_g07_c07.avi
ThrowDiscus/v_ThrowDiscus_g01_c01.avi
ThrowDiscus/v_ThrowDiscus_g01_c02.avi
ThrowDiscus/v_ThrowDiscus_g01_c03.avi
ThrowDiscus/v_ThrowDiscus_g01_c04.avi
ThrowDiscus/v_ThrowDiscus_g02_c01.avi
ThrowDiscus/v_ThrowDiscus_g02_c02.avi
ThrowDiscus/v_ThrowDiscus_g02_c03.avi
ThrowDiscus/v_ThrowDiscus_g02_c04.avi
ThrowDiscus/v_ThrowDiscus_g02_c05.avi
ThrowDiscus/v_ThrowDiscus_g02_c06.avi
ThrowDiscus/v_ThrowDiscus_g02_c07.avi
ThrowDiscus/v_ThrowDiscus_g03_c01.avi
ThrowDiscus/v_ThrowDiscus_g03_c02.avi
ThrowDiscus/v_ThrowDiscus_g03_c03.avi
ThrowDiscus/v_ThrowDiscus_g03_c04.avi
ThrowDiscus/v_ThrowDiscus_g04_c01.avi
ThrowDiscus/v_ThrowDiscus_g04_c02.avi
ThrowDiscus/v_ThrowDiscus_g04_c03.avi
ThrowDiscus/v_ThrowDiscus_g04_c04.avi
ThrowDiscus/v_ThrowDiscus_g05_c01.avi
ThrowDiscus/v_ThrowDiscus_g05_c02.avi
ThrowDiscus/v_ThrowDiscus_g05_c03.avi
ThrowDiscus/v_ThrowDiscus_g05_c04.avi
ThrowDiscus/v_ThrowDiscus_g05_c05.avi
ThrowDiscus/v_ThrowDiscus_g06_c01.avi
ThrowDiscus/v_ThrowDiscus_g06_c02.avi
ThrowDiscus/v_ThrowDiscus_g06_c03.avi
ThrowDiscus/v_ThrowDiscus_g06_c04.avi
ThrowDiscus/v_ThrowDiscus_g06_c05.avi
ThrowDiscus/v_ThrowDiscus_g06_c06.avi
ThrowDiscus/v_ThrowDiscus_g06_c07.avi
ThrowDiscus/v_ThrowDiscus_g07_c01.avi
ThrowDiscus/v_ThrowDiscus_g07_c02.avi
ThrowDiscus/v_ThrowDiscus_g07_c03.avi
ThrowDiscus/v_ThrowDiscus_g07_c04.avi
ThrowDiscus/v_ThrowDiscus_g07_c05.avi
ThrowDiscus/v_ThrowDiscus_g07_c06.avi
ThrowDiscus/v_ThrowDiscus_g07_c07.avi
TrampolineJumping/v_TrampolineJumping_g01_c01.avi
TrampolineJumping/v_TrampolineJumping_g01_c02.avi
TrampolineJumping/v_TrampolineJumping_g01_c03.avi
TrampolineJumping/v_TrampolineJumping_g01_c04.avi
TrampolineJumping/v_TrampolineJumping_g02_c01.avi
TrampolineJumping/v_TrampolineJumping_g02_c02.avi
TrampolineJumping/v_TrampolineJumping_g02_c03.avi
TrampolineJumping/v_TrampolineJumping_g02_c04.avi
TrampolineJumping/v_TrampolineJumping_g02_c05.avi
TrampolineJumping/v_TrampolineJumping_g02_c06.avi
TrampolineJumping/v_TrampolineJumping_g03_c01.avi
TrampolineJumping/v_TrampolineJumping_g03_c02.avi
TrampolineJumping/v_TrampolineJumping_g03_c03.avi
TrampolineJumping/v_TrampolineJumping_g03_c04.avi
TrampolineJumping/v_TrampolineJumping_g04_c01.avi
TrampolineJumping/v_TrampolineJumping_g04_c02.avi
TrampolineJumping/v_TrampolineJumping_g04_c03.avi
TrampolineJumping/v_TrampolineJumping_g04_c04.avi
TrampolineJumping/v_TrampolineJumping_g04_c05.avi
TrampolineJumping/v_TrampolineJumping_g05_c01.avi
TrampolineJumping/v_TrampolineJumping_g05_c02.avi
TrampolineJumping/v_TrampolineJumping_g05_c03.avi
TrampolineJumping/v_TrampolineJumping_g05_c04.avi
TrampolineJumping/v_TrampolineJumping_g06_c01.avi
TrampolineJumping/v_TrampolineJumping_g06_c02.avi
TrampolineJumping/v_TrampolineJumping_g06_c03.avi
TrampolineJumping/v_TrampolineJumping_g06_c04.avi
TrampolineJumping/v_TrampolineJumping_g07_c01.avi
TrampolineJumping/v_TrampolineJumping_g07_c02.avi
TrampolineJumping/v_TrampolineJumping_g07_c03.avi
TrampolineJumping/v_TrampolineJumping_g07_c04.avi
TrampolineJumping/v_TrampolineJumping_g07_c05.avi
Typing/v_Typing_g01_c01.avi
Typing/v_Typing_g01_c02.avi
Typing/v_Typing_g01_c03.avi
Typing/v_Typing_g01_c04.avi
Typing/v_Typing_g01_c05.avi
Typing/v_Typing_g01_c06.avi
Typing/v_Typing_g01_c07.avi
Typing/v_Typing_g02_c01.avi
Typing/v_Typing_g02_c02.avi
Typing/v_Typing_g02_c03.avi
Typing/v_Typing_g02_c04.avi
Typing/v_Typing_g02_c05.avi
Typing/v_Typing_g02_c06.avi
Typing/v_Typing_g03_c01.avi
Typing/v_Typing_g03_c02.avi
Typing/v_Typing_g03_c03.avi
Typing/v_Typing_g03_c04.avi
Typing/v_Typing_g03_c05.avi
Typing/v_Typing_g03_c06.avi
Typing/v_Typing_g03_c07.avi
Typing/v_Typing_g04_c01.avi
Typing/v_Typing_g04_c02.avi
Typing/v_Typing_g04_c03.avi
Typing/v_Typing_g04_c04.avi
Typing/v_Typing_g05_c01.avi
Typing/v_Typing_g05_c02.avi
Typing/v_Typing_g05_c03.avi
Typing/v_Typing_g05_c04.avi
Typing/v_Typing_g05_c05.avi
Typing/v_Typing_g05_c06.avi
Typing/v_Typing_g06_c01.avi
Typing/v_Typing_g06_c02.avi
Typing/v_Typing_g06_c03.avi
Typing/v_Typing_g06_c04.avi
Typing/v_Typing_g06_c05.avi
Typing/v_Typing_g06_c06.avi
Typing/v_Typing_g06_c07.avi
Typing/v_Typing_g07_c01.avi
Typing/v_Typing_g07_c02.avi
Typing/v_Typing_g07_c03.avi
Typing/v_Typing_g07_c04.avi
Typing/v_Typing_g07_c05.avi
Typing/v_Typing_g07_c06.avi
UnevenBars/v_UnevenBars_g01_c01.avi
UnevenBars/v_UnevenBars_g01_c02.avi
UnevenBars/v_UnevenBars_g01_c03.avi
UnevenBars/v_UnevenBars_g01_c04.avi
UnevenBars/v_UnevenBars_g02_c01.avi
UnevenBars/v_UnevenBars_g02_c02.avi
UnevenBars/v_UnevenBars_g02_c03.avi
UnevenBars/v_UnevenBars_g02_c04.avi
UnevenBars/v_UnevenBars_g03_c01.avi
UnevenBars/v_UnevenBars_g03_c02.avi
UnevenBars/v_UnevenBars_g03_c03.avi
UnevenBars/v_UnevenBars_g03_c04.avi
UnevenBars/v_UnevenBars_g04_c01.avi
UnevenBars/v_UnevenBars_g04_c02.avi
UnevenBars/v_UnevenBars_g04_c03.avi
UnevenBars/v_UnevenBars_g04_c04.avi
UnevenBars/v_UnevenBars_g05_c01.avi
UnevenBars/v_UnevenBars_g05_c02.avi
UnevenBars/v_UnevenBars_g05_c03.avi
UnevenBars/v_UnevenBars_g05_c04.avi
UnevenBars/v_UnevenBars_g06_c01.avi
UnevenBars/v_UnevenBars_g06_c02.avi
UnevenBars/v_UnevenBars_g06_c03.avi
UnevenBars/v_UnevenBars_g06_c04.avi
UnevenBars/v_UnevenBars_g07_c01.avi
UnevenBars/v_UnevenBars_g07_c02.avi
UnevenBars/v_UnevenBars_g07_c03.avi
UnevenBars/v_UnevenBars_g07_c04.avi
VolleyballSpiking/v_VolleyballSpiking_g01_c01.avi
VolleyballSpiking/v_VolleyballSpiking_g01_c02.avi
VolleyballSpiking/v_VolleyballSpiking_g01_c03.avi
VolleyballSpiking/v_VolleyballSpiking_g01_c04.avi
VolleyballSpiking/v_VolleyballSpiking_g02_c01.avi
VolleyballSpiking/v_VolleyballSpiking_g02_c02.avi
VolleyballSpiking/v_VolleyballSpiking_g02_c03.avi
VolleyballSpiking/v_VolleyballSpiking_g02_c04.avi
VolleyballSpiking/v_VolleyballSpiking_g03_c01.avi
VolleyballSpiking/v_VolleyballSpiking_g03_c02.avi
VolleyballSpiking/v_VolleyballSpiking_g03_c03.avi
VolleyballSpiking/v_VolleyballSpiking_g03_c04.avi
VolleyballSpiking/v_VolleyballSpiking_g04_c01.avi
VolleyballSpiking/v_VolleyballSpiking_g04_c02.avi
VolleyballSpiking/v_VolleyballSpiking_g04_c03.avi
VolleyballSpiking/v_VolleyballSpiking_g04_c04.avi
VolleyballSpiking/v_VolleyballSpiking_g04_c05.avi
VolleyballSpiking/v_VolleyballSpiking_g04_c06.avi
VolleyballSpiking/v_VolleyballSpiking_g04_c07.avi
VolleyballSpiking/v_VolleyballSpiking_g05_c01.avi
VolleyballSpiking/v_VolleyballSpiking_g05_c02.avi
VolleyballSpiking/v_VolleyballSpiking_g05_c03.avi
VolleyballSpiking/v_VolleyballSpiking_g05_c04.avi
VolleyballSpiking/v_VolleyballSpiking_g05_c05.avi
VolleyballSpiking/v_VolleyballSpiking_g06_c01.avi
VolleyballSpiking/v_VolleyballSpiking_g06_c02.avi
VolleyballSpiking/v_VolleyballSpiking_g06_c03.avi
VolleyballSpiking/v_VolleyballSpiking_g06_c04.avi
VolleyballSpiking/v_VolleyballSpiking_g07_c01.avi
VolleyballSpiking/v_VolleyballSpiking_g07_c02.avi
VolleyballSpiking/v_VolleyballSpiking_g07_c03.avi
VolleyballSpiking/v_VolleyballSpiking_g07_c04.avi
VolleyballSpiking/v_VolleyballSpiking_g07_c05.avi
VolleyballSpiking/v_VolleyballSpiking_g07_c06.avi
VolleyballSpiking/v_VolleyballSpiking_g07_c07.avi
WalkingWithDog/v_WalkingWithDog_g01_c01.avi
WalkingWithDog/v_WalkingWithDog_g01_c02.avi
WalkingWithDog/v_WalkingWithDog_g01_c03.avi
WalkingWithDog/v_WalkingWithDog_g01_c04.avi
WalkingWithDog/v_WalkingWithDog_g02_c01.avi
WalkingWithDog/v_WalkingWithDog_g02_c02.avi
WalkingWithDog/v_WalkingWithDog_g02_c03.avi
WalkingWithDog/v_WalkingWithDog_g02_c04.avi
WalkingWithDog/v_WalkingWithDog_g02_c05.avi
WalkingWithDog/v_WalkingWithDog_g02_c06.avi
WalkingWithDog/v_WalkingWithDog_g03_c01.avi
WalkingWithDog/v_WalkingWithDog_g03_c02.avi
WalkingWithDog/v_WalkingWithDog_g03_c03.avi
WalkingWithDog/v_WalkingWithDog_g03_c04.avi
WalkingWithDog/v_WalkingWithDog_g03_c05.avi
WalkingWithDog/v_WalkingWithDog_g04_c01.avi
WalkingWithDog/v_WalkingWithDog_g04_c02.avi
WalkingWithDog/v_WalkingWithDog_g04_c03.avi
WalkingWithDog/v_WalkingWithDog_g04_c04.avi
WalkingWithDog/v_WalkingWithDog_g04_c05.avi
WalkingWithDog/v_WalkingWithDog_g05_c01.avi
WalkingWithDog/v_WalkingWithDog_g05_c02.avi
WalkingWithDog/v_WalkingWithDog_g05_c03.avi
WalkingWithDog/v_WalkingWithDog_g05_c04.avi
WalkingWithDog/v_WalkingWithDog_g05_c05.avi
WalkingWithDog/v_WalkingWithDog_g06_c01.avi
WalkingWithDog/v_WalkingWithDog_g06_c02.avi
WalkingWithDog/v_WalkingWithDog_g06_c03.avi
WalkingWithDog/v_WalkingWithDog_g06_c04.avi
WalkingWithDog/v_WalkingWithDog_g06_c05.avi
WalkingWithDog/v_WalkingWithDog_g07_c01.avi
WalkingWithDog/v_WalkingWithDog_g07_c02.avi
WalkingWithDog/v_WalkingWithDog_g07_c03.avi
WalkingWithDog/v_WalkingWithDog_g07_c04.avi
WalkingWithDog/v_WalkingWithDog_g07_c05.avi
WalkingWithDog/v_WalkingWithDog_g07_c06.avi
WallPushups/v_WallPushups_g01_c01.avi
WallPushups/v_WallPushups_g01_c02.avi
WallPushups/v_WallPushups_g01_c03.avi
WallPushups/v_WallPushups_g01_c04.avi
WallPushups/v_WallPushups_g02_c01.avi
WallPushups/v_WallPushups_g02_c02.avi
WallPushups/v_WallPushups_g02_c03.avi
WallPushups/v_WallPushups_g02_c04.avi
WallPushups/v_WallPushups_g03_c01.avi
WallPushups/v_WallPushups_g03_c02.avi
WallPushups/v_WallPushups_g03_c03.avi
WallPushups/v_WallPushups_g03_c04.avi
WallPushups/v_WallPushups_g03_c05.avi
WallPushups/v_WallPushups_g04_c01.avi
WallPushups/v_WallPushups_g04_c02.avi
WallPushups/v_WallPushups_g04_c03.avi
WallPushups/v_WallPushups_g04_c04.avi
WallPushups/v_WallPushups_g05_c01.avi
WallPushups/v_WallPushups_g05_c02.avi
WallPushups/v_WallPushups_g05_c03.avi
WallPushups/v_WallPushups_g05_c04.avi
WallPushups/v_WallPushups_g05_c05.avi
WallPushups/v_WallPushups_g06_c01.avi
WallPushups/v_WallPushups_g06_c02.avi
WallPushups/v_WallPushups_g06_c03.avi
WallPushups/v_WallPushups_g06_c04.avi
WallPushups/v_WallPushups_g06_c05.avi
WallPushups/v_WallPushups_g06_c06.avi
WallPushups/v_WallPushups_g06_c07.avi
WallPushups/v_WallPushups_g07_c01.avi
WallPushups/v_WallPushups_g07_c02.avi
WallPushups/v_WallPushups_g07_c03.avi
WallPushups/v_WallPushups_g07_c04.avi
WallPushups/v_WallPushups_g07_c05.avi
WallPushups/v_WallPushups_g07_c06.avi
WritingOnBoard/v_WritingOnBoard_g01_c01.avi
WritingOnBoard/v_WritingOnBoard_g01_c02.avi
WritingOnBoard/v_WritingOnBoard_g01_c03.avi
WritingOnBoard/v_WritingOnBoard_g01_c04.avi
WritingOnBoard/v_WritingOnBoard_g01_c05.avi
WritingOnBoard/v_WritingOnBoard_g01_c06.avi
WritingOnBoard/v_WritingOnBoard_g01_c07.avi
WritingOnBoard/v_WritingOnBoard_g02_c01.avi
WritingOnBoard/v_WritingOnBoard_g02_c02.avi
WritingOnBoard/v_WritingOnBoard_g02_c03.avi
WritingOnBoard/v_WritingOnBoard_g02_c04.avi
WritingOnBoard/v_WritingOnBoard_g02_c05.avi
WritingOnBoard/v_WritingOnBoard_g02_c06.avi
WritingOnBoard/v_WritingOnBoard_g02_c07.avi
WritingOnBoard/v_WritingOnBoard_g03_c01.avi
WritingOnBoard/v_WritingOnBoard_g03_c02.avi
WritingOnBoard/v_WritingOnBoard_g03_c03.avi
WritingOnBoard/v_WritingOnBoard_g03_c04.avi
WritingOnBoard/v_WritingOnBoard_g03_c05.avi
WritingOnBoard/v_WritingOnBoard_g03_c06.avi
WritingOnBoard/v_WritingOnBoard_g03_c07.avi
WritingOnBoard/v_WritingOnBoard_g04_c01.avi
WritingOnBoard/v_WritingOnBoard_g04_c02.avi
WritingOnBoard/v_WritingOnBoard_g04_c03.avi
WritingOnBoard/v_WritingOnBoard_g04_c04.avi
WritingOnBoard/v_WritingOnBoard_g05_c01.avi
WritingOnBoard/v_WritingOnBoard_g05_c02.avi
WritingOnBoard/v_WritingOnBoard_g05_c03.avi
WritingOnBoard/v_WritingOnBoard_g05_c04.avi
WritingOnBoard/v_WritingOnBoard_g05_c05.avi
WritingOnBoard/v_WritingOnBoard_g05_c06.avi
WritingOnBoard/v_WritingOnBoard_g06_c01.avi
WritingOnBoard/v_WritingOnBoard_g06_c02.avi
WritingOnBoard/v_WritingOnBoard_g06_c03.avi
WritingOnBoard/v_WritingOnBoard_g06_c04.avi
WritingOnBoard/v_WritingOnBoard_g06_c05.avi
WritingOnBoard/v_WritingOnBoard_g06_c06.avi
WritingOnBoard/v_WritingOnBoard_g06_c07.avi
WritingOnBoard/v_WritingOnBoard_g07_c01.avi
WritingOnBoard/v_WritingOnBoard_g07_c02.avi
WritingOnBoard/v_WritingOnBoard_g07_c03.avi
WritingOnBoard/v_WritingOnBoard_g07_c04.avi
WritingOnBoard/v_WritingOnBoard_g07_c05.avi
WritingOnBoard/v_WritingOnBoard_g07_c06.avi
WritingOnBoard/v_WritingOnBoard_g07_c07.avi
YoYo/v_YoYo_g01_c01.avi
YoYo/v_YoYo_g01_c02.avi
YoYo/v_YoYo_g01_c03.avi
YoYo/v_YoYo_g01_c04.avi
YoYo/v_YoYo_g01_c05.avi
YoYo/v_YoYo_g01_c06.avi
YoYo/v_YoYo_g01_c07.avi
YoYo/v_YoYo_g02_c01.avi
YoYo/v_YoYo_g02_c02.avi
YoYo/v_YoYo_g02_c03.avi
YoYo/v_YoYo_g02_c04.avi
YoYo/v_YoYo_g02_c05.avi
YoYo/v_YoYo_g03_c01.avi
YoYo/v_YoYo_g03_c02.avi
YoYo/v_YoYo_g03_c03.avi
YoYo/v_YoYo_g03_c04.avi
YoYo/v_YoYo_g03_c05.avi
YoYo/v_YoYo_g03_c06.avi
YoYo/v_YoYo_g04_c01.avi
YoYo/v_YoYo_g04_c02.avi
YoYo/v_YoYo_g04_c03.avi
YoYo/v_YoYo_g04_c04.avi
YoYo/v_YoYo_g04_c05.avi
YoYo/v_YoYo_g05_c01.avi
YoYo/v_YoYo_g05_c02.avi
YoYo/v_YoYo_g05_c03.avi
YoYo/v_YoYo_g05_c04.avi
YoYo/v_YoYo_g05_c05.avi
YoYo/v_YoYo_g06_c01.avi
YoYo/v_YoYo_g06_c02.avi
YoYo/v_YoYo_g06_c03.avi
YoYo/v_YoYo_g06_c04.avi
YoYo/v_YoYo_g07_c01.avi
YoYo/v_YoYo_g07_c02.avi
YoYo/v_YoYo_g07_c03.avi
YoYo/v_YoYo_g07_c04.avi
================================================
FILE: braincog/datasets/scripts/ucf101_dvs_preprocessing.py
================================================
# encoding: utf-8
# Author : Floyed
# Datetime : 2022/12/20 20:16
# User : Floyed
# Product : PyCharm
# Project : BrainCog
# File : ucf101_dvs_preprocessing.py
# explain :
import os
import shutil
# Split the UCF101-DVS recordings into train/ and val/ directories by
# moving the clips listed in the official test split out of train/.
ROOT_DIR = '/data/datasets/UCF101_DVS/UCF101_DVS'
train_path = os.path.join(ROOT_DIR, 'train')
val_path = os.path.join(ROOT_DIR, 'val')
val_fname = 'testlist01.txt'

cls_path = os.listdir(train_path)

# Mirror the class-directory layout of train/ under val/ so files can be
# moved one-to-one.
if not os.path.exists(val_path):
    os.mkdir(val_path)
    for cls_name in cls_path:
        os.mkdir(os.path.join(val_path, cls_name))

with open(val_fname, 'r') as f:
    for line in f:
        # The split file lists '.avi' clips; the DVS recordings are stored
        # as '.mat' files with the same stem. splitext is robust to a
        # missing trailing newline (the old slice-based code was not).
        fname = os.path.splitext(line.strip())[0] + '.mat'
        # The UCF101-DVS release spells 'Billiards' as 'Billards' in the
        # split file. str.replace returns a new string, so the result must
        # be re-assigned (the original code discarded it, leaving all
        # Billiards clips unmovable).
        fname = fname.replace('Billards', 'Billiards')
        src = os.path.join(train_path, fname)
        dst = os.path.join(val_path, fname)
        try:
            shutil.move(src, dst)
        except (FileNotFoundError, shutil.Error, OSError):
            # Best-effort: report and continue with the next clip.
            print('[Warning] Cannot find {}.'.format(src))
        else:
            # Only report a move that actually happened (the original code
            # printed this even after a failure).
            print('[Moving] {} -> {}.'.format(src, dst))
================================================
FILE: braincog/datasets/ucf101_dvs/__init__.py
================================================
# encoding: utf-8
# Author : Floyed
# Datetime : 2023/1/30 21:04
# User : yu
# Product : PyCharm
# Project : BrainCog
# File : __init__.py.py
# explain :
# Re-export the dataset class so callers can simply do
# ``from braincog.datasets.ucf101_dvs import UCF101DVS``.
from .ucf101_dvs import UCF101DVS

# Public API of this subpackage.
__all__ = [
    'UCF101DVS'
]
================================================
FILE: braincog/datasets/ucf101_dvs/ucf101_dvs.py
================================================
# encoding: utf-8
# Author : Floyed
# Datetime : 2023/1/30 21:05
# User : yu
# Product : PyCharm
# Project : BrainCog
# File : ucf51_dvs.py
# explain :
# encoding: utf-8
# Author : Floyed
# Datetime : 2022/12/20 20:47
# User : Floyed
# Product : PyCharm
# Project : tonic
# File : ucf101dvs.py
# explain :
import os
import numpy as np
from numpy.lib import recfunctions
import scipy.io as scio
from typing import Tuple, Any, Optional
from tonic.dataset import Dataset
from tonic.download_utils import extract_archive
class UCF101DVS(Dataset):
    """UCF101-DVS dataset. Events have (txyp) ordering.

    Event-camera (DVS) recordings of the UCF101 action-recognition videos,
    stored as MATLAB ``.mat`` files, one directory per action class.
    (The original docstring said "ASL-DVS"; that was a copy-paste error —
    this class loads UCF101-DVS.)
    ::

        @inproceedings{bi2019graph,
        title={Graph-based Object Classification for Neuromorphic Vision Sensing},
        author={Bi, Y and Chadha, A and Abbas, A and and Bourtsoulatze, E and Andreopoulos, Y},
        booktitle={2019 IEEE International Conference on Computer Vision (ICCV)},
        year={2019},
        organization={IEEE}
        }

    Parameters:
        save_to (string): Location to save files to on disk.
        train (bool, optional): If True read samples from the ``train``
            subdirectory, otherwise from ``val``.
        transform (callable, optional): A callable of transforms to apply to the data.
        target_transform (callable, optional): A callable of transforms to apply to the targets/labels.
    """

    # Sensor resolution as (width, height, polarities).
    sensor_size = (240, 180, 2)
    # Structured event dtype; field order defines the (t, x, y, p) ordering.
    dtype = np.dtype([("t", int), ("x", int), ("y", int), ("p", int)])
    ordering = dtype.names
    folder_name = 'UCF101DVS'

    def __init__(self, save_to, train=False, transform=None, target_transform=None):
        super(UCF101DVS, self).__init__(
            save_to, transform=transform, target_transform=target_transform
        )
        # No automatic download: the archive must be fetched manually.
        # NOTE: this check runs before the train/val subdirectory is
        # appended, so it counts .mat files over the whole dataset tree.
        if not self._check_exists():
            raise NotImplementedError(
                'Please manually download the dataset from'
                ' https://www.dropbox.com/sh/ie75dn246cacf6n/AACoU-_zkGOAwj51lSCM0JhGa?dl=0 '
                'and extract it to {}'.format(self.location_on_system))
        if train:
            self.location_on_system = os.path.join(self.location_on_system, 'train')
        else:
            self.location_on_system = os.path.join(self.location_on_system, 'val')
        # Map each class-directory name to an integer label. NOTE(review):
        # os.listdir order is OS-dependent, so labels are not guaranteed to
        # be stable across machines — confirm whether sorting is needed.
        classes = os.listdir(self.location_on_system)
        self.int_classes = dict(zip(classes, range(len(classes))))
        # Collect every .mat sample path plus its class label.
        for path, dirs, files in os.walk(self.location_on_system):
            dirs.sort()
            files.sort()
            for file in files:
                if file.endswith("mat"):
                    # Skip files under 1 KiB — presumably empty/corrupt
                    # recordings (TODO confirm).
                    fsize = os.path.getsize(path + '/' + file) / float(1024)
                    if fsize < 1:
                        # print('{} size {} K'.format(file, fsize))
                        continue
                    self.data.append(path + "/" + file)
                    self.targets.append(self.int_classes[path.split('/')[-1]])

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Returns:
            (events, target) where target is index of the target class.
        """
        events, target = scio.loadmat(self.data[index]), self.targets[index]
        # Assemble (t, x, y, p) columns; the y axis is flipped so the
        # origin matches the convention expected downstream.
        events = np.column_stack(
            [
                events["ts"],
                events["x"],
                self.sensor_size[1] - 1 - events["y"],
                events["pol"],
            ]
        )
        # Convert the plain 2-D array into the structured event dtype.
        events = np.lib.recfunctions.unstructured_to_structured(events, self.dtype)
        if self.transform is not None:
            events = self.transform(events)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return events, target

    def __len__(self):
        # Number of indexed .mat samples.
        return len(self.data)

    def _check_exists(self):
        # NOTE(review): debug print left in; 13523 is presumably the total
        # number of .mat files in the full release — confirm.
        print(self.folder_name)
        return self._folder_contains_at_least_n_files_of_type(
            13523, ".mat"
        )
================================================
FILE: braincog/datasets/utils.py
================================================
import torch
from einops import repeat
from braincog.datasets.gen_input_signal import lambda_max
def rescale(x, factor=None):
    """
    Rescale the input in place.

    :param x: input tensor (modified in place for tensor inputs)
    :param factor: scale factor; when ``None``, the global ``lambda_max``
        from :mod:`braincog.datasets.gen_input_signal` is used
    :return: the rescaled data (same object as ``x`` for tensors)
    """
    # Compare against None rather than truthiness: the original
    # ``if factor:`` silently fell back to ``lambda_max`` when an explicit
    # factor of 0 (or 0.0) was passed.
    if factor is not None:
        x *= factor
    else:
        x *= lambda_max
    return x
def dvs_channel_check_expend(x):
    """
    Ensure a DVS tensor carries both polarity channels.

    Some recordings (e.g. in N-Cars) ship with a single channel; such
    inputs are expanded by duplicating the lone channel so the result
    always has two channels.

    :param x: input tensor of shape (batch, channel, width, height)
    :return: tensor with two channels (duplicated when the input had one),
        otherwise ``x`` unchanged
    """
    # Guard clause: nothing to do when both channels are present.
    if x.shape[1] != 1:
        return x
    # Duplicate the single channel along dim 1 — equivalent to
    # einops.repeat(x, 'b c w h -> b (r c) w h', r=2) when c == 1.
    return torch.cat((x, x), dim=1)
================================================
FILE: braincog/model_zoo/NeuEvo/__init__.py
================================================
# encoding: utf-8
# Author : Floyed
# Datetime : 2022/9/1 16:43
# User : Floyed
# Product : PyCharm
# Project : BrainCog
# File : __init__.py.py
# explain :
import os
import numpy as np
from .genotypes import PRIMITIVES, Genotype
# Number of forward edges in a cell: intermediate node i (i = 0..2)
# receives 2 + i incoming edges, i.e. 2 + 3 + 4 = 9.
forward_edge_num = sum(1 for i in range(3) for n in range(2 + i))
# Number of backward (feedback) edges: 0 + 1 + 2 = 3.
backward_edge_num = sum(1 for i in range(3) for n in range(i))
# Total number of candidate operations in the search space.
num_ops = len(PRIMITIVES)
# Presumably the number of op "types" with positive/negative variants
# pairing up — TODO confirm: PRIMITIVES currently has an odd length, so
# the floor division discards one entry.
type_num = len(PRIMITIVES) // 2
# edge_num = [2, 3, 4]
# node_id: (forward) 2, 3, 4
# node_id: (backward) 3, 2
edge_num = [2, 3, 4, 1, 2]
def parse(weights, operation_set,
          op_threshold, parse_method,
          steps, reduction=False,
          back_connection=False):
    """
    Decode an architecture-weight matrix into a list of gene tuples.

    :param weights: 2-D array of scores; rows are edges (grouped per
        intermediate node), columns index into ``operation_set``.
    :param operation_set: list of operation names (e.g. ``PRIMITIVES``).
    :param op_threshold: minimum op score used by the ``threshold`` methods.
    :param parse_method: ``'darts'``, ``'bio_darts'``, or a string containing
        ``'threshold'`` combined with ``'edge'`` or ``'sparse'``.
    :param steps: number of intermediate nodes in the cell.
    :param reduction: kept for interface compatibility (unused here).
    :param back_connection: for ``'bio_darts'``, additionally parse the
        backward (feedback) edges appended after the forward ones.
    :return: list of gene tuples ``(op_name, input_node)``; the
        edge-threshold variant yields ``(op_name, node, input_node)``.
    """
    gene = []
    if parse_method == 'darts':
        n = 2
        start = 0
        for i in range(steps):
            end = start + n
            W = weights[start:end].copy()
            # Keep the two incoming edges whose strongest op scores highest.
            edges = sorted(range(i + 2), key=lambda x: -
                           max(W[x][k] for k in range(len(W[x]))))[:2]
            for j in edges:
                # BUG FIX: reset the best-op index for every edge. The
                # previous version kept ``k_best`` as a module global that
                # was never (re)initialized here, so the first call raised
                # NameError and later edges compared candidates against
                # another edge's best index.
                k_best = None
                for k in range(len(W[j])):
                    if k_best is None or W[j][k] > W[j][k_best]:
                        k_best = k
                # geno item : (operation, node idx)
                gene.append((operation_set[k_best], j))
            start = end
            n += 1
    elif parse_method == 'bio_darts':
        # Forward edges come first in ``weights``; backward edges follow.
        weights_backward = weights[forward_edge_num:]
        weights_forward = weights[:forward_edge_num]
        # forward
        n = 2
        start = 0
        for i in range(steps):
            end = start + n
            W = weights_forward[start:end].copy()
            edges = sorted(range(i + 2), key=lambda x: -
                           max(W[x][k] for k in range(len(W[x]))))[:2]
            # Take the top-scoring operation on each of the two kept edges.
            idx = np.argsort(W[edges[0]])
            gene.append((operation_set[idx[-1]], edges[0]))
            idx = np.argsort(W[edges[1]])
            gene.append((operation_set[idx[-1]], edges[1]))
            start = end
            n += 1
        if back_connection:
            # backward: node i (i >= 1) has i candidate feedback edges;
            # keep the single best one and tag its op name with '_back'.
            n = 1
            start = 0
            for i in range(1, steps):
                end = start + n
                W = weights_backward[start:end].copy()
                edges = sorted(range(i), key=lambda x: -
                               max(W[x][k] for k in range(len(W[x]))))[0]
                idx = np.argsort(W[edges])
                gene.append((operation_set[idx[-1]] + '_back', edges + 2))
                start = end
                n += 1
    elif 'threshold' in parse_method:
        n = 2
        start = 0
        for i in range(steps):
            end = start + n
            W = weights[start:end].copy()
            if 'edge' in parse_method:
                # Consider every incoming edge of node i + 2.
                edges = list(range(i + 2))
            else:  # select edges using darts methods
                edges = sorted(range(i + 2), key=lambda x: -
                               max(W[x][k] for k in range(len(W[x]))))[:2]
            for j in edges:
                if 'edge' in parse_method:  # OP_{prob > T} AND |Edge| <= 2
                    topM = sorted(enumerate(W[j]), key=lambda x: x[1])[-2:]
                    for k, v in topM:  # Get top M = 2 operations for one edge
                        if W[j][k] >= op_threshold:
                            gene.append((operation_set[k], i + 2, j))
                # max( OP_{prob > T} ) and |Edge| <= 2
                elif 'sparse' in parse_method:
                    k_best = None
                    for k in range(len(W[j])):
                        if k_best is None or W[j][k] > W[j][k_best]:
                            k_best = k
                    if W[j][k_best] >= op_threshold:
                        gene.append((operation_set[k_best], i + 2, j))
                else:
                    raise NotImplementedError(
                        "Not support parse method: {}".format(parse_method))
            start = end
            n += 1
    return gene
def parse_genotype(alphas, steps, multiplier, path=None,
                   parse_method='threshold_sparse', op_threshold=0.85):
    """
    Parse architecture parameters into a ``Genotype`` and optionally save it.

    :param alphas: pair ``(alphas_normal, alphas_reduce)`` of architecture
        weight matrices.
    :param steps: number of intermediate nodes per cell.
    :param multiplier: number of trailing node outputs concatenated into
        the cell output.
    :param path: directory to write the genotype text file into (created
        if missing); nothing is written when ``None``.
    :param parse_method: parsing strategy forwarded to :func:`parse`.
    :param op_threshold: operation-score threshold forwarded to :func:`parse`.
    """
    alphas_normal, alphas_reduce = alphas
    gene_normal = parse(alphas_normal, PRIMITIVES,
                        op_threshold, parse_method, steps)
    gene_reduce = parse(alphas_reduce, PRIMITIVES,
                        op_threshold, parse_method, steps)
    # Concatenate the outputs of the last ``multiplier`` nodes.
    concat = range(2 + steps - multiplier, steps + 2)
    # NOTE(review): the ``Genotype`` namedtuple imported from .genotypes
    # currently declares only the fields (normal, normal_concat); calling
    # it with reduce=/reduce_concat= as below would raise TypeError —
    # confirm which Genotype definition is intended.
    genotype = Genotype(
        normal=gene_normal, normal_concat=concat,
        reduce=gene_reduce, reduce_concat=concat
    )
    if path is not None:
        if not os.path.exists(path):
            os.makedirs(path)
        print('Architecture parsing....\n', genotype)
        save_path = os.path.join(
            path, parse_method + '_' + str(op_threshold) + '.txt')
        with open(save_path, "w+") as f:
            f.write(str(genotype))
        print('Save in :', save_path)
FILE: braincog/model_zoo/NeuEvo/architect.py
================================================
import torch
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
from numpy.linalg import eigvals
from braincog.model_zoo.NeuEvo.model_search import calc_weight, calc_loss
def normalize(x):
    """Standardize *x* to zero mean and unit variance (z-score)."""
    mean = np.average(x)
    std = np.std(x)
    centered = x - mean
    return centered / std
def _concat(xs):
    """Flatten every tensor in *xs* and join them into one 1-D tensor."""
    flat_views = [t.view(-1) for t in xs]
    return torch.cat(flat_views)
class Architect(object):
    """Updates the architecture parameters (alphas) of a DARTS-style model
    on validation data, and can compute the Hessian of the validation loss
    with respect to those parameters."""

    def __init__(self, model, args):
        # Hyper-parameters of the *network* weight optimizer, kept for
        # reference (not used by the arch optimizer below).
        self.network_momentum = args.momentum
        self.network_weight_decay = args.weight_decay
        self.model = model
        # AdamW over the architecture parameters only.
        self.optimizer = torch.optim.AdamW(self.model.arch_parameters(),
                                           lr=args.arch_learning_rate,
                                           betas=(args.arch_lr_gamma, 0.999),
                                           weight_decay=args.arch_weight_decay)
        # self.optimizer = torch.optim.SGD(self.model.arch_parameters(), lr=args.arch_learning_rate)
        self.hessian = None
        self.grads = None

    def step(self, input_valid, target_valid):
        """Run one architecture-optimizer step on a validation batch and
        return the two loss components reported by ``model._loss``."""
        self.optimizer.zero_grad()
        # Auxiliary regularization input derived from the current alphas.
        aux_input = torch.cat([calc_loss(self.model.alphas_normal)], dim=0)
        loss, loss1, loss2 = self.model._loss(
            input_valid, target_valid, aux_input)
        # loss = self.model._loss(input_valid, target_valid)
        loss.backward()
        self.optimizer.step()
        return loss1, loss2

    def compute_Hw(self, input_valid, target_valid):
        """Compute and cache the Hessian of the validation loss w.r.t. the
        architecture parameters."""
        self.zero_grads(self.model.parameters())
        self.zero_grads(self.model.arch_parameters())
        aux_input = torch.cat(
            [F.softmax(self.model.alphas_normal, dim=-1)], dim=0)
        loss = self.model._loss(input_valid, target_valid, aux_input)
        self.hessian = self._hessian(loss, self.model.arch_parameters())
        return self.hessian

    def zero_grads(self, parameters):
        """Detach and zero the gradients of the given parameters in place."""
        for p in parameters:
            if p.grad is not None:
                p.grad.detach_()
                p.grad.zero_()

    def compute_eigenvalues(self):
        """Eigenvalues of the cached Hessian.

        NOTE(review): ``compute_Hw`` requires (input_valid, target_valid)
        but is called here without arguments, so this method raises
        TypeError as written — confirm the intended call site/signature.
        """
        self.compute_Hw()
        return eigvals(self.hessian.cpu().data.numpy())

    def _hessian(self, outputs, inputs, out=None, allow_unused=False):
        """Build the (dense, symmetric) Hessian of ``outputs`` w.r.t. the
        flattened ``inputs`` via repeated autograd passes.

        NOTE(review): references ``self.weight_decay`` (not defined on this
        class — presumably ``self.network_weight_decay``) and
        ``self.gradient`` (no such method here) — confirm before use.
        """
        if torch.is_tensor(inputs):
            inputs = [inputs]
        else:
            inputs = list(inputs)
        # Total number of scalar architecture parameters.
        n = sum(p.numel() for p in inputs)
        if out is None:
            out = torch.tensor(torch.zeros(n, n)).type_as(outputs)
        ai = 0
        for i, inp in enumerate(inputs):
            # First-order gradient, kept in the graph for second derivatives.
            [grad] = torch.autograd.grad(outputs, inp, create_graph=True,
                                         allow_unused=allow_unused)
            grad = grad.contiguous().view(-1) + self.weight_decay * inp.view(-1)
            for j in range(inp.numel()):
                if grad[j].requires_grad:
                    # One Hessian row from this gradient component onward.
                    row = self.gradient(
                        grad[j], inputs[i:], retain_graph=True)[j:]
                else:
                    n = sum(x.numel() for x in inputs[i:]) - j
                    row = Variable(torch.zeros(n)).type_as(grad[j])
                # Fill the upper triangle, then mirror into the lower one.
                out.data[ai, ai:].add_(row.clone().type_as(out).data)
                if ai + 1 < n:
                    out.data[ai + 1:,
                             ai].add_(row.clone().type_as(out).data[1:])
                del row
                ai += 1
            del grad
        return out
================================================
FILE: braincog/model_zoo/NeuEvo/genotypes.py
================================================
from collections import namedtuple
import torch
# A genotype records a parsed cell architecture: ``normal`` is the list of
# (operation, input-node) gene tuples and ``normal_concat`` the node
# indices whose outputs are concatenated into the cell output.
Genotype = namedtuple('Genotype', 'normal normal_concat')
"""
Operation sets
"""
# Candidate operations for the search space ('_p'/'_n' suffixes mark
# positive/negative variants).
PRIMITIVES = [
    'conv_3x3_p',
    # 'max_pool_3x3',
    # 'avg_pool_3x3',
    # 'def_conv_3x3',
    # 'def_conv_5x5',
    # 'sep_conv_3x3',
    # 'sep_conv_5x5',
    # 'dil_conv_3x3',
    # 'dil_conv_5x5',
    # 'max_pool_3x3_p',
    # 'avg_pool_3x3_p',
    # NOTE(review): 'conv_3x3_p' appears twice in this list (here and at
    # the top), which skews num_ops/type_num derived from len(PRIMITIVES)
    # — confirm whether the duplicate is intentional.
    'conv_3x3_p',
    'conv_5x5_p',
    # 'conv_3x3_p_p',
    # 'sep_conv_3x3_p',
    # 'sep_conv_5x5_p',
    # 'dil_conv_3x3_p',
    # 'dil_conv_5x5_p',
    # 'def_conv_3x3_p',
    # 'def_conv_5x5_p',n
    # 'max_pool_3x3_n',
    # 'avg_pool_3x3_n',
    'conv_3x3_n',
    'conv_5x5_n',
    # 'conv_3x3_p_n',
    # 'sep_conv_3x3_n',
    # 'sep_conv_5x5_n',
    # 'dil_conv_3x3_n',
    # 'dil_conv_5x5_n',
    # 'def_conv_3x3_n',
    # 'def_conv_5x5_n',
    # 'transformer',
]
"""====== SnnMlp Archirtecture By Other Methods"""
mlp1 = Genotype(
normal=[
('mlp', 0), ('conv_3x3_p', 1), # 2
('mlp', 1), ('mlp', 0), # 3
('conv_3x3_p', 2), ('mlp', 3), # 4
('mlp_back', 2),
('conv_3x3_p_back', 2)
],
normal_concat=range(2, 5)
)
mlp2 = Genotype(
normal=[
('mlp', 0), ('conv_3x3_p', 1),
('conv_3x3_p', 2), ('mlp_p', 1),
# ('mlp_n', 1), ('conv_3x3_p', 2),
('mlp_back', 2)
],
normal_concat=range(2, 4)
)
"""====== SNN Archirtecture By Other Methods"""
dvsc10_new_skip22 = Genotype(
normal=[
('conv_3x3_p', 1), ('conv_3x3_p', 0), # 2
('conv_5x5_p', 1), ('conv_3x3_p', 2), # 3
('conv_3x3_p', 0), ('conv_3x3_p', 3), # 4
('conv_3x3_n_back', 2), ('conv_3x3_p_back', 3) # 3, 4
],
normal_concat=range(2, 5)
)
dvsc10_new_skip22 = Genotype(
normal=[
('conv_3x3_p', 1), ('conv_3x3_p', 0),
('conv_5x5_n', 1), ('conv_3x3_p', 2),
('conv_5x5_n', 0), ('conv_3x3_p', 3),
('conv_3x3_n_back', 0), ('conv_3x3_p_back', 1)
],
normal_concat=range(2, 5)
)
dvsc10_new_skip21 = Genotype(
normal=[
('conv_3x3_n', 0), ('conv_5x5_p', 1), # 2
('conv_3x3_p', 1), ('conv_5x5_p', 2), # 3
('conv_5x5_n', 2), ('conv_3x3_p', 1), # 4
# ('conv_3x3_p_back', 2), ('conv_5x5_p_back', 2)
],
normal_concat=range(2, 5)
)
dvsc10_new_skip20 = Genotype(
normal=[
('conv_5x5_p', 0), ('conv_5x5_n', 1),
('conv_3x3_n', 2), ('conv_5x5_p', 0),
('conv_3x3_p', 2), ('conv_3x3_n', 3),
('conv_3x3_p_back', 2),
('conv_5x5_p_back', 3)
],
normal_concat=range(2, 5)
)
dvsc10_new_skip19 = Genotype(
normal=[
('conv_5x5_n', 0), ('conv_3x3_p', 1),
('conv_5x5_n', 2), ('conv_5x5_n', 0),
('conv_3x3_p', 2), ('conv_5x5_p', 3),
('conv_3x3_p_back', 2),
('conv_5x5_p_back', 2)
],
normal_concat=range(2, 5)
)
dvsc10_new_skip18 = Genotype(
normal=[
('conv_5x5_p', 0), ('conv_3x3_p', 1),
('conv_5x5_p', 2), ('conv_5x5_n', 0),
('conv_3x3_p', 2), ('conv_5x5_p', 3),
('conv_5x5_n_back', 2),
('conv_3x3_p_back', 2)],
normal_concat=range(2, 5)
)
dvsc10_new_skip17 = Genotype(
normal=[
('conv_3x3_p', 1), ('conv_5x5_n', 0),
('conv_5x5_n', 2), ('conv_5x5_p', 1),
('conv_3x3_p', 2), ('avg_pool_3x3_p', 3),
('avg_pool_3x3_p_back', 2), ('conv_3x3_p_back', 2)
],
normal_concat=range(2, 5)
)
dvsc10_new_skip16 = Genotype(
normal=[
('conv_5x5_p', 0), ('conv_5x5_n', 1),
('conv_3x3_n', 2), ('avg_pool_3x3_p', 0),
('conv_3x3_p', 2), ('avg_pool_3x3_n', 3),
('conv_3x3_p_back', 2),
('conv_3x3_p_back', 3)
],
normal_concat=range(2, 5)
)
dvsc10_new_skip15 = Genotype(
normal=[
('conv_5x5_n', 0), ('conv_5x5_p', 1),
('conv_5x5_p', 2), ('conv_5x5_n', 1),
('conv_3x3_p', 2), ('conv_5x5_p', 3),
('conv_5x5_p_back', 2),
('conv_3x3_p_back', 3)
],
normal_concat=range(2, 5)
)
dvsc10_new_skip14 = Genotype(
normal=[
('conv_5x5_n', 1), ('conv_3x3_p', 0),
('conv_5x5_p', 1), ('conv_3x3_p', 2),
('conv_3x3_p', 2), ('conv_5x5_n', 1),
('conv_3x3_n_back', 2),
('conv_3x3_p_back', 3)],
normal_concat=range(2, 5)
)
dvsc10_new_skip13 = Genotype(
normal=[
('conv_5x5_n', 1), ('conv_3x3_p', 0),
('conv_5x5_n', 1), ('conv_3x3_p', 2),
('conv_3x3_p', 2), ('conv_5x5_n', 1),
('conv_3x3_n_back', 2),
('conv_3x3_p_back', 3)],
normal_concat=range(2, 5)
)
dvsc10_new_skip12 = Genotype(
normal=[
('conv_5x5_n', 0), ('conv_3x3_p', 1),
('conv_5x5_n', 2), ('conv_5x5_n', 0),
('conv_3x3_p', 2), ('conv_5x5_p', 3),
('conv_3x3_n_back', 2),
('conv_3x3_p_back', 2)
],
normal_concat=range(2, 5)
)
# dvsc10_new_skip12 = Genotype(
# normal=[
# ('conv_3x3_p', 0), ('conv_3x3_n', 1),
# ('conv_3x3_p', 1), ('conv_5x5_n', 2),
# ('conv_3x3_n', 3), ('conv_3x3_p', 0),
# ('conv_5x5_p_back', 2), ('conv_3x3_p_back', 3)
# ],
# normal_concat=range(2, 5)
# )
# Searched DVS-CIFAR10 genotypes, skip-connection study variants 11..0.
dvsc10_new_skip11 = Genotype(normal=[
    ('conv_3x3_n', 0), ('conv_5x5_n', 1),
    ('conv_5x5_p', 0), ('conv_3x3_n', 2),
    ('conv_3x3_p', 2), ('conv_5x5_n', 0),
    ('conv_3x3_n_back', 2),
    ('conv_3x3_p_back', 3)],
    normal_concat=range(2, 5)
)
dvsc10_new_skip10 = Genotype(
    normal=[
        ('conv_5x5_n', 1), ('conv_3x3_p', 0),
        ('conv_5x5_p', 2), ('conv_5x5_p', 1),
        ('conv_3x3_p', 2), ('conv_5x5_n', 1),
        ('conv_3x3_n_back', 2),
        ('conv_3x3_n_back', 2)],
    normal_concat=range(2, 5)
)
dvsc10_new_skip9 = Genotype(
    normal=[
        ('conv_5x5_p', 1), ('conv_5x5_n', 0),
        ('conv_5x5_p', 2), ('conv_5x5_n', 0),
        ('conv_3x3_p', 2), ('conv_5x5_p', 3),
        ('conv_3x3_p_back', 2),
        ('conv_5x5_n_back', 3)
    ],
    normal_concat=range(2, 5)
)
dvsc10_new_skip8 = Genotype(
    normal=[
        ('conv_5x5_n', 0), ('conv_5x5_n', 1),
        ('conv_3x3_n', 2), ('conv_5x5_p', 0),
        ('conv_3x3_p', 2), ('conv_5x5_n', 1),
        ('conv_5x5_n_back', 2),
        ('conv_3x3_p_back', 3)],
    normal_concat=range(2, 5)
)
dvsc10_new_skip7 = Genotype(
    normal=[
        ('conv_5x5_p', 0), ('conv_5x5_p', 1),
        ('conv_3x3_n', 2), ('conv_5x5_n', 0),
        ('conv_3x3_p', 2), ('conv_5x5_p', 3),
        ('conv_5x5_n_back', 2),
        ('conv_3x3_p_back', 3)],
    normal_concat=range(2, 5)
)
dvsc10_new_skip6 = Genotype(
    normal=[
        ('conv_3x3_p', 0), ('conv_5x5_n', 1),
        ('conv_5x5_p', 2), ('conv_5x5_p', 1),
        ('conv_3x3_p', 2), ('conv_5x5_n', 0),
        ('conv_3x3_n_back', 2), ('conv_3x3_n_back', 2)],
    normal_concat=range(2, 5)
)
dvsc10_new_skip5 = Genotype(
    normal=[
        ('conv_3x3_p', 0), ('conv_3x3_p', 1),
        ('conv_3x3_n', 2), ('conv_3x3_n', 0),
        ('conv_3x3_p', 2), ('conv_3x3_p', 3),
        ('conv_5x5_n_back', 2),
        ('conv_3x3_p_back', 2)],
    normal_concat=range(2, 5)
)
dvsc10_new_skip4 = Genotype(
    normal=[
        ('conv_5x5_n', 1), ('conv_5x5_p', 0),
        ('conv_3x3_p', 2), ('conv_5x5_p', 1),
        ('conv_3x3_p', 2), ('conv_5x5_n', 0),
        ('conv_3x3_p_back', 2),
        ('conv_3x3_p_back', 3)],
    normal_concat=range(2, 5)
)
dvsc10_new_skip3 = Genotype(
    normal=[
        ('conv_5x5_p', 0), ('conv_3x3_p', 1),
        ('conv_3x3_n', 2), ('conv_3x3_n', 0),
        ('conv_3x3_p', 2), ('conv_5x5_p', 3),
        ('conv_5x5_n_back', 2),
        ('conv_3x3_p_back', 3)
    ],
    normal_concat=range(2, 5)
)
dvsc10_new_skip2 = Genotype(
    normal=[
        ('avg_pool_3x3_p', 0), ('avg_pool_3x3_p', 1),
        ('avg_pool_3x3_p', 0), ('avg_pool_3x3_p', 1),
        ('conv_3x3_p', 2), ('avg_pool_3x3_n', 0),
        ('avg_pool_3x3_n_back', 2),
        ('conv_3x3_p_back', 2)],
    normal_concat=range(2, 5)
)
dvsc10_new_skip1 = Genotype(
    normal=[
        ('conv_5x5_p', 0), ('conv_3x3_n', 1),
        ('conv_3x3_n', 2), ('conv_3x3_p', 1),
        ('conv_5x5_p', 1), ('conv_3x3_p', 2),
        ('conv_3x3_p_back', 2),
        ('conv_3x3_p_back', 2)],
    normal_concat=range(2, 5)
)
dvsc10_new_skip = Genotype(
    normal=[
        ('conv_3x3_n', 1), ('conv_3x3_p', 0),
        ('conv_3x3_p', 0), ('avg_pool_3x3_p', 1),
        ('conv_3x3_p', 2), ('conv_3x3_n', 0),
        ('conv_3x3_p_back', 2),
        ('conv_3x3_p_back', 2)],
    normal_concat=range(2, 5)
)
# DVS-CIFAR10 baseline genotypes and gradient-study variants.
dvsc10_new_base0 = Genotype(
    normal=[
        ('avg_pool_3x3_p', 1), ('avg_pool_3x3_p', 0),
        ('avg_pool_3x3_n', 2), ('avg_pool_3x3_p', 1),
        ('avg_pool_3x3_n', 2), ('avg_pool_3x3_n', 3),
        ('avg_pool_3x3_n_back', 2),
        ('avg_pool_3x3_n_back', 3)],
    normal_concat=range(2, 5)
)
dvsc10_new_base1 = Genotype(
    normal=[
        ('conv_3x3_p', 1), ('conv_5x5_n', 0),
        ('conv_5x5_p', 1), ('conv_3x3_p', 0),
        ('conv_5x5_n', 1), ('conv_3x3_p', 0),
        ('avg_pool_3x3_p_back', 2),
        ('conv_3x3_p_back', 3)
    ],
    normal_concat=range(2, 5)
)
dvsc10_new_base2 = Genotype(
    normal=[
        ('conv_5x5_p', 0), ('conv_3x3_p', 1),
        ('conv_5x5_n', 1), ('avg_pool_3x3_p', 0),
        ('avg_pool_3x3_n', 3), ('conv_5x5_n', 1),
        ('avg_pool_3x3_n_back', 2),
        ('avg_pool_3x3_n_back', 2)],
    normal_concat=range(2, 5)
)
dvsc10_new_base3 = Genotype(
    normal=[
        ('avg_pool_3x3_p', 0), ('conv_5x5_p', 1),
        ('conv_3x3_p', 1), ('conv_3x3_n', 0),
        ('conv_5x5_p', 1), ('conv_3x3_n', 0),
        ('conv_3x3_p_back', 2),
        ('avg_pool_3x3_n_back', 3)],
    normal_concat=range(2, 5)
)
dvsc10_grad2 = Genotype(
    normal=[
        ('avg_pool_3x3_n', 1), ('conv_5x5_p', 0),
        ('conv_5x5_n', 1), ('conv_5x5_n', 0),
        ('conv_3x3_p', 3), ('conv_5x5_n', 1),
        ('conv_5x5_p_back', 2),
        ('conv_3x3_p_back', 2)],
    normal_concat=range(2, 5)
)
dvsc10_grad1 = Genotype(
    normal=[
        ('avg_pool_3x3_p', 1), ('conv_5x5_p', 0),
        ('avg_pool_3x3_n', 2), ('avg_pool_3x3_n', 1),
        ('avg_pool_3x3_p', 2), ('conv_5x5_n', 1),
        ('conv_5x5_p_back', 2),
        ('conv_3x3_p_back', 3)],
    normal_concat=range(2, 5))
# Genotypes searched on DVS-Gesture (dvsg) and N-Caltech101 (dvscal).
dvsg_new2 = Genotype(
    normal=[
        ('avg_pool_3x3_p', 1), ('conv_5x5_p', 0),
        ('conv_3x3_p', 1), ('conv_3x3_p', 0),
        ('conv_3x3_p', 1), ('avg_pool_3x3_p', 0),
        ('avg_pool_3x3_n_back', 2),
        ('avg_pool_3x3_n_back', 3)],
    normal_concat=range(2, 5))
dvsg_new1 = Genotype(
    normal=[
        ('avg_pool_3x3_p', 1), ('conv_5x5_p', 0),
        ('conv_3x3_p', 1), ('conv_3x3_p', 0),
        ('conv_3x3_p', 1), ('avg_pool_3x3_p', 0),
        ('avg_pool_3x3_n_back', 2),
        ('conv_5x5_n_back', 3)],
    normal_concat=range(2, 5))
dvscal_new1 = Genotype(
    normal=[
        ('conv_5x5_n', 0), ('conv_5x5_n', 1),
        ('conv_5x5_n', 1), ('conv_5x5_p', 0),
        ('avg_pool_3x3_p', 1), ('conv_5x5_p', 0),
        ('avg_pool_3x3_n_back', 2),
        ('avg_pool_3x3_n_back', 2)],
    normal_concat=range(2, 5)
)
# DVS-CIFAR10 genotypes from successive search runs (newest first).
dvsc10_new8 = Genotype(
    normal=[('conv_5x5_p', 0), ('conv_5x5_p', 1),
            ('conv_3x3_p', 0), ('conv_5x5_n', 1),
            ('conv_5x5_p', 0), ('conv_5x5_n', 1),
            ('avg_pool_3x3_n_back', 2),
            ('avg_pool_3x3_n_back', 3)],
    normal_concat=range(2, 5)
)
dvsc10_new7 = Genotype(
    normal=[
        ('conv_5x5_p', 0), ('conv_5x5_p', 1),
        ('conv_3x3_p', 0), ('conv_5x5_n', 1),
        ('conv_5x5_p', 0), ('conv_5x5_n', 1),
        ('conv_3x3_n_back', 2),
        ('avg_pool_3x3_n_back', 2)],
    normal_concat=range(2, 5))
dvsc10_new6 = Genotype(
    normal=[
        ('conv_3x3_p', 1), ('conv_3x3_p', 0),
        ('conv_3x3_p', 0), ('conv_3x3_p', 1),
        ('conv_3x3_p', 0), ('avg_pool_3x3_p', 1),
        ('avg_pool_3x3_n_back', 2),
        ('avg_pool_3x3_n_back', 2)],
    normal_concat=range(2, 5))
dvsc10_new5 = Genotype(
    normal=[
        ('conv_5x5_p', 1), ('conv_3x3_p', 0),
        ('conv_3x3_p', 0), ('conv_5x5_p', 1),
        ('conv_3x3_p', 0), ('avg_pool_3x3_p', 1),
        ('avg_pool_3x3_n_back', 2),
        ('avg_pool_3x3_n_back', 2)],
    normal_concat=range(2, 5))
dvsc10_new4 = Genotype(
    normal=[
        ('conv_3x3_n', 1), ('conv_3x3_p', 0),
        ('conv_5x5_p', 1), ('conv_5x5_p', 0),
        ('conv_5x5_p', 1), ('conv_5x5_p', 0),
        ('avg_pool_3x3_p_back', 2),
        ('avg_pool_3x3_n_back', 2)],
    normal_concat=range(2, 5),
)
dvsc10_new3 = Genotype(
    normal=[
        ('avg_pool_3x3_p', 0), ('conv_3x3_n', 1),
        ('conv_3x3_n', 1), ('conv_3x3_n', 0),
        ('avg_pool_3x3_p', 2), ('conv_3x3_n', 1),
        ('avg_pool_3x3_n_back', 2),
        # NOTE(review): last entry lacks the '_back' suffix, unlike every
        # sibling genotype — confirm this is intentional.
        ('avg_pool_3x3_p', 2)],
    normal_concat=range(2, 5),
)
dvsc10_new2 = Genotype(normal=[
    ('conv_3x3_p', 0), ('conv_3x3_n', 1),
    ('conv_3x3_n', 1), ('avg_pool_3x3_p', 0),
    ('avg_pool_3x3_p', 2), ('conv_3x3_n', 1),
    ('avg_pool_3x3_n_back', 2),
    ('conv_3x3_p_back', 2)],
    normal_concat=range(2, 5),
)
dvsc10_new1 = Genotype(
    normal=[
        ('conv_3x3_p', 1), ('avg_pool_3x3_p', 0),
        ('avg_pool_3x3_p', 0), ('conv_3x3_n', 1),
        ('conv_3x3_p', 0), ('conv_3x3_p', 1),
        ('conv_3x3_p_back', 2),
        ('conv_3x3_n_back', 2)],
    normal_concat=range(2, 5)
)
dvsc10_new0 = Genotype(
    normal=[
        ('conv_3x3_p', 1), ('avg_pool_3x3_p', 0),
        ('avg_pool_3x3_p', 2), ('conv_3x3_n', 1),
        ('conv_3x3_p', 0), ('conv_3x3_p', 3),
        ('conv_3x3_p_back', 2),
        ('conv_3x3_n_back', 3)],
    normal_concat=range(2, 5)
)
# Genotypes searched on (static-frame) CIFAR.
cifar_new_skip1 = Genotype(
    normal=[
        ('conv_5x5_n', 0), ('conv_5x5_p', 1),
        ('avg_pool_3x3_p', 0), ('avg_pool_3x3_n', 2),
        ('avg_pool_3x3_p', 2), ('conv_5x5_p', 0),
        ('avg_pool_3x3_n_back', 2),
        ('avg_pool_3x3_p_back', 3)
    ],
    normal_concat=range(2, 5))
cifar_new1 = Genotype(
    normal=[
        ('avg_pool_3x3_p', 1), ('avg_pool_3x3_p', 0),
        ('conv_3x3_n', 0), ('avg_pool_3x3_p', 1),
        ('avg_pool_3x3_p', 2), ('conv_3x3_p', 0),
        ('avg_pool_3x3_n_back', 2),
        ('conv_3x3_p_back', 2)],
    normal_concat=range(2, 5)
)
cifar_new2 = Genotype(
    normal=[
        ('conv_3x3_n', 0), ('avg_pool_3x3_p', 1),
        ('conv_3x3_p', 0), ('avg_pool_3x3_p', 1),
        ('conv_3x3_p', 2), ('conv_3x3_n', 0),
        ('conv_3x3_n_back', 2),
        ('conv_3x3_p_back', 2)],
    normal_concat=range(2, 5),
)
cifar_new0 = Genotype(
    normal=[
        ('avg_pool_3x3_p', 1), ('avg_pool_3x3_n', 0),  # 2, 3
        ('conv_3x3_n', 0), ('avg_pool_3x3_p', 1),  # 4, 5
        ('conv_3x3_p', 2), ('conv_3x3_n', 3),  # 6 , 7
        ('avg_pool_3x3_n_back', 2),
        ('conv_3x3_p_back', 1)],
    normal_concat=range(2, 5)
)
================================================
FILE: braincog/model_zoo/NeuEvo/model.py
================================================
from functools import partial
from typing import List, Type
from braincog.model_zoo.NeuEvo.operations import *
from braincog.model_zoo.NeuEvo.genotypes import Genotype
from braincog.base.utils import drop_path
from timm.models import register_model
from braincog.base.node.node import *
from braincog.base.connection.layer import *
from braincog.model_zoo.base_module import BaseModule
class MlpCell(BaseModule):
    """A NeuEvo cell operating on flat (vector) features.

    Builds the operation graph described by ``genotype`` on top of a linear
    stem ``input_dim -> C``, optionally wiring backward (feedback) edges for
    the genotype entries whose op name ends in ``'_back'``.
    """

    def __init__(
            self,
            genotype: Genotype,
            C: int,
            input_dim: int,
            output_dim: int,
            encode_type: str = 'direct',
            activation_fn: Type[nn.Module] = LIFNode,
            squash_output: bool = False,
            back_connection: bool = True,
            step: int = 10,
            **kwargs
    ):
        """
        :param genotype: searched cell description (forward then backward edges)
        :param C: hidden width of every op in the cell
        :param input_dim: dimensionality of the raw input vector
        :param output_dim: size of the linear readout; <= 0 disables it
        :param encode_type: spike encoder name passed to BaseModule
        :param activation_fn: node (neuron) class used as activation
        :param squash_output: if no readout, apply Tanh instead of Identity
        :param back_connection: enable the genotype's feedback edges
        :param step: number of simulation time steps
        """
        super(MlpCell, self).__init__(
            step=step,
            encode_type=encode_type,
            layer_by_layer=True
        )
        # print(activation_fn, step)
        self.act_fun = partial(activation_fn, step=step, layer_by_layer=self.layer_by_layer, **kwargs)
        self.back_connection = back_connection
        op_names, indices = zip(*genotype.normal)
        concat = genotype.normal_concat
        self._compile(C, op_names, indices, concat)
        # Linear stem projecting the input into the cell's working width.
        self.feature = nn.Sequential(
            nn.Linear(input_dim, C),
            self.act_fun(),
        )
        if output_dim > 0:
            self.output_fn = nn.Linear(self.multiplier * C, output_dim)
        elif squash_output:
            self.output_fn = nn.Tanh()
        else:
            self.output_fn = nn.Identity()

    def _compile(self, C, op_names, indices, concat):
        """Instantiate forward and backward ops from the genotype lists.

        Assumes every '_back' entry follows all forward entries in
        ``op_names`` — TODO confirm against the genotype generator.
        """
        assert len(op_names) == len(indices)
        # self._steps = len(op_names) // 2
        self._concat = concat
        self.multiplier = len(concat)
        self._ops = nn.ModuleList()
        self._ops_back = nn.ModuleList()
        back_begin_index = 0
        # Forward ops: everything before the first '_back' entry.
        for i, (name, index) in enumerate(zip(op_names, indices)):
            # print(name, index)
            if '_back' in name:
                back_begin_index = i
                break
            op = OPS_Mlp[name](C, act_fun=self.act_fun)
            self._ops += [op]
        if self.back_connection:
            # Backward ops reuse the forward op types (suffix stripped).
            for name, index in zip(op_names[back_begin_index:], indices[back_begin_index:]):
                op = OPS_Mlp[name.replace('_back', '')](
                    C, act_fun=self.act_fun)
                self._ops_back += [op]
        if self.back_connection:
            self._indices_forward = indices[:back_begin_index]
            self._indices_backward = indices[back_begin_index:]
        else:
            # NOTE(review): with back_connection=False but a genotype that still
            # contains '_back' entries, _indices_forward keeps the backward rows
            # while self._ops does not — confirm only back-free genotypes are
            # used in that configuration.
            self._indices_backward = []
            self._indices_forward = indices
        self._steps = len(self._indices_forward) // 2

    def _forward_once(self, s0, s1, drop_prob):
        """Run the cell DAG once; returns a (T, B, C*multiplier) tensor."""
        states = [s0, s1]
        for i in range(self._steps):
            # Each intermediate node sums two incoming edges.
            h1 = states[self._indices_forward[2 * i]]
            h2 = states[self._indices_forward[2 * i + 1]]
            op1 = self._ops[2 * i]
            op2 = self._ops[2 * i + 1]
            h1 = op1(h1)
            h2 = op2(h2)
            if self.training and drop_prob > 0.:
                # Identity edges are never dropped.
                if not isinstance(op1, Identity):
                    h1 = drop_path(h1, drop_prob)
                if not isinstance(op2, Identity):
                    h2 = drop_path(h2, drop_prob)
            s = h1 + h2
            if self.back_connection:
                if i != 0:
                    # Feed the fresh node output back into an earlier state.
                    s_back = self._ops_back[i - 1](s)
                    states[self._indices_backward[i - 1]
                           ] = states[self._indices_backward[i - 1]] + s_back
            states += [s]
        outputs = []
        for i in self._concat:
            outputs.append(rearrange(states[i], '(t b) c -> t b c', t=self.step))
        outputs = torch.cat(outputs, dim=2)  # T, B, C
        return outputs

    def forward(self, inputs):
        inputs = self.encoder(inputs)
        self.reset()
        if self.layer_by_layer:
            x = self.feature(inputs)
            # The stem output seeds both cell inputs; no drop-path at inference.
            x = self._forward_once(x, x, 0.)
            x = self.output_fn(x)
            # Average over the time dimension.
            x = x.mean(0)
        else:
            raise NotImplementedError
        return x
class Cell(nn.Module):
    """Evaluation-time NeuEvo convolutional cell built from a fixed genotype.

    A reduction cell degenerates to a single FactorizedReduce; a normal cell
    runs the genotype's forward DAG with optional feedback edges.
    """

    def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev, act_fun, back_connection):
        # print(C_prev_prev, C_prev, C, reduction)
        super(Cell, self).__init__()
        self.act_fun = act_fun
        self.back_connection = back_connection
        self.reduction = reduction
        if reduction:
            # Reduction cells bypass the DAG entirely (see forward()).
            self.fun = FactorizedReduce(
                C_prev, C * 3, act_fun=act_fun
            )
            self.multiplier = 3
        else:
            if reduction_prev:
                self.preprocess0 = FactorizedReduce(
                    C_prev_prev, C, act_fun=act_fun)
            else:
                self.preprocess0 = ReLUConvBN(
                    C_prev_prev, C, 1, 1, 0, act_fun=act_fun)
            self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, act_fun=act_fun)
        op_names, indices = zip(*genotype.normal)
        concat = genotype.normal_concat
        self._compile(C, op_names, indices, concat, reduction)

    def _compile(self, C, op_names, indices, concat, reduction):
        """Instantiate forward/backward ops; '_back' entries must come last
        in ``op_names`` — TODO confirm against the genotype format."""
        assert len(op_names) == len(indices)
        # self._steps = len(op_names) // 2
        self._concat = concat
        self.multiplier = len(concat)
        self._ops = nn.ModuleList()
        self._ops_back = nn.ModuleList()
        back_begin_index = 0
        for i, (name, index) in enumerate(zip(op_names, indices)):
            # print(name, index)
            if '_back' in name:
                back_begin_index = i
                break
            # Only edges fed by the two cell inputs are strided in reduction cells.
            stride = 2 if reduction and index < 2 else 1
            op = OPS[name](C, stride, True, act_fun=self.act_fun)
            self._ops += [op]
        if self.back_connection:
            for name, index in zip(op_names[back_begin_index:], indices[back_begin_index:]):
                op = OPS[name.replace('_back', '')](
                    C, 1, True, act_fun=self.act_fun)
                self._ops_back += [op]
        if self.back_connection:
            self._indices_forward = indices[:back_begin_index]
            self._indices_backward = indices[back_begin_index:]
        else:
            # NOTE(review): if the genotype still contains '_back' entries here,
            # _indices_forward retains them while self._ops does not — verify.
            self._indices_backward = []
            self._indices_forward = indices
        self._steps = len(self._indices_forward) // 2

    def forward(self, s0, s1, drop_prob):
        if self.reduction:
            return self.fun(s1)
        s0 = self.preprocess0(s0)
        s1 = self.preprocess1(s1)
        states = [s0, s1]
        for i in range(self._steps):
            # Each node sums exactly two incoming edges.
            h1 = states[self._indices_forward[2 * i]]
            h2 = states[self._indices_forward[2 * i + 1]]
            op1 = self._ops[2 * i]
            op2 = self._ops[2 * i + 1]
            h1 = op1(h1)
            h2 = op2(h2)
            if self.training and drop_prob > 0.:
                # Drop-path regularization; skip connections are exempt.
                if not isinstance(op1, Identity):
                    h1 = drop_path(h1, drop_prob)
                if not isinstance(op2, Identity):
                    h2 = drop_path(h2, drop_prob)
            s = h1 + h2
            if self.back_connection:
                if i != 0:
                    # Feedback edge: add this node's output to an earlier state.
                    s_back = self._ops_back[i - 1](s)
                    states[self._indices_backward[i - 1]
                           ] = states[self._indices_backward[i - 1]] + s_back
            states += [s]
        outputs = torch.cat([states[i]
                             for i in self._concat], dim=1)  # N, C, H, W
        return outputs
        # return self.node(outputs)
class DCOCell(nn.Module):
    """Cell variant for genotypes expressed as (op, to, from) triples.

    Operations are stored in a nested ModuleDict keyed by destination node,
    then by source node, so one edge may carry several ops.
    """

    def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev, act_fun):
        super(DCOCell, self).__init__()
        self.act_fun = act_fun
        if reduction_prev:
            self.preprocess0 = FactorizedReduce(C_prev_prev, C)
        else:
            self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
        self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)
        if reduction:
            op_names, tos, froms = zip(*genotype.reduce)
        else:
            op_names, tos, froms = zip(*genotype.normal)
        self._compile(C, op_names, tos, froms, reduction)

    def _compile(self, C, op_names, tos, froms, reduction):
        """Build the nested {to: {from: [ops]}} ModuleDict from the triples."""
        self._ops = nn.ModuleDict()
        for name_i, to_i, from_i in zip(op_names, tos, froms):
            # Only edges leaving the two cell inputs are strided when reducing.
            stride = 2 if reduction and from_i < 2 else 1
            op = OPS[name_i](C, stride, True, act_fun=self.act_fun)
            if str(to_i) in self._ops.keys():
                if str(from_i) in self._ops[str(to_i)]:
                    self._ops[str(to_i)][str(from_i)] += [op]
                else:
                    self._ops[str(to_i)][str(from_i)] = nn.ModuleList()
                    self._ops[str(to_i)][str(from_i)] += [op]
            else:
                self._ops[str(to_i)] = nn.ModuleDict()
                self._ops[str(to_i)][str(from_i)] = nn.ModuleList()
                self._ops[str(to_i)][str(from_i)] += [op]
        # TODO: Some intermediate node maybe no selected during search.
        self.multiplier = len(self._ops)

    def forward(self, s0, s1, drop_prob):
        s0 = self.preprocess0(s0)
        s1 = self.preprocess1(s1)
        states = {}
        states['0'] = s0
        states['1'] = s1
        # get all the operations in current intermediate node
        for to_i, ops in self._ops.items():
            h = []
            for from_i, op_i in ops.items():
                # each edge may no more than one operation
                if from_i not in states:
                    # print('Exist the isolate node, which id is {}, we need ignore it!'.format(from_i))
                    continue
                h += [sum([op(states[from_i])
                           for op in op_i if from_i in states])]
            out = sum(h)
            if self.training and drop_prob > 0:
                out = drop_path(out, drop_prob)
            states[to_i] = out
        # Concatenate every intermediate node (skip the two inputs) on channels.
        outputs = torch.cat([v for v in states.values()][2:], dim=1)
        # return outputs
        return outputs
class AuxiliaryHeadCIFAR(nn.Module):
    """Auxiliary classifier attached to an intermediate cell output.

    Uses the configurable spiking/ReLU node ``act_fun`` in place of the
    plain ReLUs of the original DARTS head.
    """

    def __init__(self, C, num_classes, act_fun):
        """assuming inputs size 8x8"""
        super(AuxiliaryHeadCIFAR, self).__init__()
        self.act_fun = act_fun
        self.features = nn.Sequential(
            # nn.ReLU(inplace=True),
            self.act_fun(),
            # image size = 2 x 2
            nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False),
            nn.Conv2d(C, 128, 1, bias=False),
            nn.BatchNorm2d(128),
            # nn.ReLU(inplace=True),
            self.act_fun(),
            nn.Conv2d(128, 768, 2, bias=False),
            nn.BatchNorm2d(768),
            # nn.ReLU(inplace=True)
            self.act_fun()
        )
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, x):
        x = self.features(x)
        # Flatten to (N, 768) before the linear readout.
        x = self.classifier(x.view(x.size(0), -1))
        return x
class AuxiliaryHeadImageNet(nn.Module):
    """Auxiliary classifier head used during ImageNet training.

    Assumes a 14x14 input feature map (as in upstream DARTS).
    """

    def __init__(self, C, num_classes):
        super(AuxiliaryHeadImageNet, self).__init__()
        stages = [
            nn.ReLU(inplace=True),
            nn.AvgPool2d(5, stride=2, padding=0, count_include_pad=False),
            nn.Conv2d(C, 128, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 768, 2, bias=False),
            # A BatchNorm2d(768) was omitted upstream (acknowledged typo);
            # it stays out for consistency with the published experiments.
            nn.ReLU(inplace=True),
        ]
        self.features = nn.Sequential(*stages)
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, x):
        feats = self.features(x)
        return self.classifier(feats.view(feats.size(0), -1))
@register_model
class NetworkCIFAR(BaseModule):
    """Evaluation network for CIFAR/DVS datasets built from a fixed genotype.

    Stacks ``layers`` cells (Cell or DCOCell depending on ``parse_method``),
    doubling channels at the indices in ``reduce_idx``.
    """

    def __init__(self,
                 C,
                 num_classes,
                 layers,
                 auxiliary,
                 genotype,
                 parse_method='darts',
                 step=1,
                 node_type='ReLUNode',
                 **kwargs):
        """
        :param C: base channel count
        :param num_classes: number of output classes
        :param layers: number of stacked cells
        :param auxiliary: attach an auxiliary head at 2/3 depth
        :param genotype: searched architecture description
        :param parse_method: 'darts' -> Cell, anything else -> DCOCell
        :param step: simulation time steps
        :param node_type: neuron class name (resolved with eval) or class
        """
        super(NetworkCIFAR, self).__init__(
            step=step,
            num_classes=num_classes,
            **kwargs
        )
        if isinstance(node_type, str):
            # NOTE(review): eval on a config string — callers control node_type.
            self.act_fun = eval(node_type)
        else:
            self.act_fun = node_type
        self.act_fun = partial(self.act_fun, **kwargs)
        if 'back_connection' in kwargs.keys():
            self.back_connection = kwargs['back_connection']
        else:
            self.back_connection = False
        self.spike_output = kwargs['spike_output'] if 'spike_output' in kwargs else True
        self.dataset = kwargs['dataset']
        if self.layer_by_layer:
            # Time folded into the batch dim: keep the batch axis on flatten.
            self.flatten = nn.Flatten(start_dim=1)
        else:
            self.flatten = nn.Flatten()
        self._layers = layers
        self._auxiliary = auxiliary
        self.drop_path_prob = 0
        stem_multiplier = 3
        C_curr = stem_multiplier * C
        if self.dataset == 'dvsg' or self.dataset == 'dvsc10' or self.dataset == 'NCALTECH101':
            # Event data: two input polarities.
            self.stem = nn.Sequential(
                nn.Conv2d(2 * self.init_channel_mul, C_curr, 3, padding=1, bias=False),
                nn.BatchNorm2d(C_curr),
            )
            # self.reduce_idx = [
            #     layers // 4,
            #     layers // 2,
            #     3 * layers // 4
            # ]
            self.reduce_idx = [1, 3, 5, 7]
        else:
            # RGB data: three input channels.
            self.stem = nn.Sequential(
                nn.Conv2d(3 * self.init_channel_mul, C_curr, 3, padding=1, bias=False),
                nn.BatchNorm2d(C_curr),
            )
            self.reduce_idx = [layers // 4,
                               layers // 2,
                               3 * layers // 4]
        C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
        self.cells = nn.ModuleList()
        reduction_prev = False
        for i in range(layers):
            if i in self.reduce_idx:
                C_curr *= 2
                reduction = True
            else:
                reduction = False
            if parse_method == 'darts':
                cell = Cell(genotype, C_prev_prev, C_prev, C_curr,
                            reduction, reduction_prev,
                            act_fun=self.act_fun, back_connection=self.back_connection)
            else:
                cell = DCOCell(genotype, C_prev_prev, C_prev, C_curr,
                               reduction, reduction_prev, act_fun=self.act_fun)
            reduction_prev = reduction
            self.cells += [cell]
            C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr
            if i == 2 * layers // 3:
                # Remember the width at 2/3 depth for the auxiliary head.
                C_to_auxiliary = C_prev
        if auxiliary:
            self.auxiliary_head = AuxiliaryHeadCIFAR(
                C_to_auxiliary, num_classes, act_fun=self.act_fun)
        self.global_pooling = nn.Sequential(
            self.act_fun(), nn.AdaptiveAvgPool2d(1))
        if self.spike_output:
            # Over-provision 10 neurons per class and vote over them.
            self.classifier = nn.Sequential(
                nn.Linear(C_prev, 10 * num_classes),
                self.act_fun())
            self.vote = VotingLayer(10)
        else:
            self.classifier = nn.Linear(C_prev, num_classes)
            self.vote = nn.Identity()
        # self.classifier = nn.Linear(C_prev, num_classes)
        # self.vote = nn.Identity()

    def forward(self, inputs):
        logits_aux = None
        inputs = self.encoder(inputs)
        if not self.layer_by_layer:
            # Step-by-step simulation: average the logits over time.
            outputs = []
            output_aux = []
            self.reset()
            for t in range(self.step):
                x = inputs[t]
                s0 = s1 = self.stem(x)
                for i, cell in enumerate(self.cells):
                    s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
                    # print(s0.shape, s1.shape)
                    # if i == 2 * self._layers // 3:
                    #     if self._auxiliary and self.training:
                    #         logits_aux = self.auxiliary_head(s1)
                out = self.global_pooling(s1)
                out = self.classifier(self.flatten(out))
                logits = self.vote(out)
                outputs.append(logits)
                output_aux.append(logits_aux)
            return sum(outputs) / len(outputs)
            # logits_aux if logits_aux is None else (sum(output_aux) / len(output_aux))
        else:
            # Layer-by-layer mode: time is folded into the batch dimension.
            s0 = s1 = self.stem(inputs)
            for i, cell in enumerate(self.cells):
                s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
                if i == 2 * self._layers // 3:
                    if self._auxiliary and self.training:
                        logits_aux = self.auxiliary_head(s1)
            out = self.global_pooling(s1)
            out = self.classifier(self.flatten(out))
            # Unfold time and average before voting.
            out = rearrange(out, '(t b) c -> t b c', t=self.step).mean(0)
            logits = self.vote(out)
            return logits
@register_model
class NetworkImageNet(BaseModule):
    """Evaluation network for ImageNet built from a fixed genotype.

    Two strided stems downsample the input before the cell stack; channels
    double at 1/3 and 2/3 depth.
    """

    def __init__(self,
                 C,
                 num_classes,
                 layers,
                 auxiliary,
                 genotype,
                 step=1,
                 node_type='ReLUNode',
                 **kwargs):
        super(NetworkImageNet, self).__init__(
            step=step,
            num_classes=num_classes,
            **kwargs)
        if isinstance(node_type, str):
            # NOTE(review): eval on a config string — callers control node_type.
            self.act_fun = eval(node_type)
        else:
            self.act_fun = node_type
        self.act_fun = partial(self.act_fun, **kwargs)
        if 'back_connection' in kwargs.keys():
            self.back_connection = kwargs['back_connection']
        else:
            self.back_connection = False
        self.spike_output = kwargs['spike_output'] if 'spike_output' in kwargs else True
        if self.layer_by_layer:
            self.flatten = nn.Flatten(start_dim=1)
        else:
            self.flatten = nn.Flatten()
        self._layers = layers
        self._auxiliary = auxiliary
        self.drop_path_prob = 0
        self.stem0 = nn.Sequential(
            nn.Conv2d(3, C // 2, kernel_size=3,
                      stride=2, padding=1, bias=False),
            nn.BatchNorm2d(C // 2),
            # nn.ReLU(inplace=True),
            self.act_fun(),
            nn.Conv2d(C // 2, C, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(C),
        )
        self.stem1 = nn.Sequential(
            # nn.ReLU(inplace=True),
            self.act_fun(),
            nn.Conv2d(C, C, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(C),
        )
        C_prev_prev, C_prev, C_curr = C, C, C
        self.cells = nn.ModuleList()
        reduction_prev = True
        for i in range(layers):
            if i in [layers // 3, 2 * layers // 3]:
                C_curr *= 2
                reduction = True
            else:
                reduction = False
            cell = Cell(genotype, C_prev_prev, C_prev,
                        C_curr, reduction, reduction_prev,
                        act_fun=self.act_fun, back_connection=self.back_connection)
            reduction_prev = reduction
            self.cells += [cell]
            C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr
        self.global_pooling = nn.AvgPool2d(7)
        self.classifier = nn.Linear(C_prev, num_classes)

    def forward(self, inputs):
        # Simulate `step` times on the same input and average the logits.
        outputs = []
        self.reset()
        for t in range(self.step):
            s0 = self.stem0(inputs)
            s1 = self.stem1(s0)
            for i, cell in enumerate(self.cells):
                s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
            out = self.global_pooling(s1)
            logits = self.classifier(self.flatten(out))
            outputs.append(logits)
        return sum(outputs) / len(outputs)
if __name__ == '__main__':
    # Smoke test: build an MlpCell from the 'mlp2' genotype and run a
    # random batch through it.
    from braincog.model_zoo.NeuEvo.genotypes import mlp2
    cell = MlpCell(mlp2, C=128, input_dim=17, output_dim=-1)
    x = torch.rand(4, 17)
    out = cell(x)
    print(out)
    print(out.shape)
================================================
FILE: braincog/model_zoo/NeuEvo/model_search.py
================================================
from functools import partial
from braincog.model_zoo.NeuEvo.operations import *
from torch.autograd import Variable
from braincog.model_zoo.NeuEvo.genotypes import PRIMITIVES
from braincog.model_zoo.NeuEvo.genotypes import Genotype
from . import parse
from braincog.base.connection.layer import VotingLayer
from braincog.base.node.node import *
from braincog.model_zoo.base_module import BaseModule
from . import forward_edge_num
from . import edge_num
def calc_weight(x):
    """Combine the two alpha banks into one edge-weight tensor.

    ``x`` has shape (2, k, num_ops); each bank is split along dim 0 into the
    groups given by the module-level ``edge_num``, softmax-normalised over the
    whole group, and the two banks are summed group-wise.
    """
    bank_a = torch.split(x[0], edge_num, dim=0)
    bank_b = torch.split(x[1], edge_num, dim=0)
    merged = [
        torch.softmax(a.view(-1), dim=-1).view(a.shape)
        + torch.softmax(b.view(-1), dim=-1).view(b.shape)
        for a, b in zip(bank_a, bank_b)
    ]
    return torch.cat(merged, dim=0)
def calc_loss(x):
    """Group-wise softmax difference of the two alpha banks.

    Same grouping as :func:`calc_weight`, but the second bank's softmax is
    subtracted instead of added.
    """
    bank_a = torch.split(x[0], edge_num, dim=0)
    bank_b = torch.split(x[1], edge_num, dim=0)
    diffs = [
        torch.softmax(a.view(-1), dim=-1).view(a.shape)
        - torch.softmax(b.view(-1), dim=-1).view(b.shape)
        for a, b in zip(bank_a, bank_b)
    ]
    return torch.cat(diffs, dim=0)
class darts_fun(torch.autograd.Function):
    """Scale a feature map by an architecture weight with a custom gradient.

    The weight's gradient is the negated mean magnitude of the feature map
    (halved when the map contains both signs); note it deliberately ignores
    grad_output — NOTE(review): confirm this matches the search objective.
    """

    @staticmethod
    def forward(ctx, inputs, weights):  # feature map / arch weight
        ctx.save_for_backward(inputs, weights)
        return inputs * weights

    @staticmethod
    def backward(ctx, grad_output):  # error signal
        inputs, weights = ctx.saved_tensors
        grad_inputs = grad_output * weights if ctx.needs_input_grad[0] else None
        grad_weights = None
        if ctx.needs_input_grad[1]:
            magnitude = torch.abs(inputs)
            # Halve when the feature map carries both positive and negative
            # activity (thresholds guard against float noise around zero).
            if torch.min(inputs) < -1e-12 and torch.max(inputs) > 1e-12:
                magnitude = magnitude / 2.
            grad_weights = -magnitude.mean()
        return grad_inputs, grad_weights
class MixedOp(nn.Module):
    """Weighted mixture of every candidate operation on one edge.

    Each primitive's output is multiplied by its architecture weight through
    the custom ``darts_fun`` autograd function, then summed.
    """

    def __init__(self, C, stride, act_fun):
        super(MixedOp, self).__init__()
        self._ops = nn.ModuleList()
        for primitive in PRIMITIVES:
            op = OPS[primitive](C, stride, False, act_fun)
            if 'pool' in primitive:
                # Follow DARTS: pair pooling ops with a non-affine BatchNorm.
                op = nn.Sequential(op, nn.BatchNorm2d(C, affine=False))
            self._ops.append(op)
        self.multiply = darts_fun.apply

    def forward(self, x, weights):
        feature_map = []
        for i, op in enumerate(self._ops):
            res = op(x)
            feature_map.append(res)
        # Architecture-weighted sum over all candidate outputs.
        return sum(self.multiply(mp, w) for w, mp in zip(weights, feature_map))
class Cell(nn.Module):
    """Search-time supernet cell with mixed ops on every edge.

    Reduction cells collapse to a single FactorizedReduce; normal cells hold
    one MixedOp per forward edge plus (optionally) one per feedback edge.
    """

    def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev, act_fun, back_connection):
        super(Cell, self).__init__()
        self.reduction = reduction
        self.back_connection = back_connection
        if reduction:
            # Reduction cells skip the DAG entirely (see forward()).
            self.fun = FactorizedReduce(
                C_prev, C * multiplier, affine=True, act_fun=act_fun, positive=1
            )
        else:
            if reduction_prev:
                self.preprocess0 = FactorizedReduce(
                    C_prev_prev, C, affine=False, act_fun=act_fun, positive=1)
            else:
                self.preprocess0 = ReLUConvBN(
                    C_prev_prev, C, 1, 1, 0, affine=False, act_fun=act_fun, positive=1)
            self.preprocess1 = ReLUConvBN(
                C_prev, C, 1, 1, 0, affine=False, act_fun=act_fun, positive=1)
        self._steps = steps
        self._multiplier = multiplier
        self._ops = nn.ModuleList()
        # Forward edges: node i receives from the 2 inputs and all earlier nodes.
        for i in range(self._steps):
            for j in range(2 + i):
                stride = 2 if reduction and j < 2 else 1
                op = MixedOp(C, stride, act_fun)
                self._ops.append(op)
        if self.back_connection:
            # Backward edges: node i feeds back to intermediate nodes before it.
            self._ops_back = nn.ModuleList()
            for i in range(self._steps):
                for j in range(i):
                    op = MixedOp(C, 1, act_fun)
                    self._ops_back.append(op)

    def forward(self, s0, s1, weights):
        if self.reduction:
            return self.fun(s1)
        s0 = self.preprocess0(s0)
        s1 = self.preprocess1(s1)
        states = [s0, s1]
        offset = 0
        offset_back = 0
        # The flat weight tensor packs forward edges first, then backward.
        weights_forward = weights[:forward_edge_num]
        weights_backward = weights[forward_edge_num:]
        for i in range(self._steps):
            s = sum(self._ops[offset + j](h, weights_forward[offset + j])
                    for j, h in enumerate(states))
            offset += len(states)
            if self.back_connection:
                # Feed the new node's output back into earlier intermediate
                # nodes (indices >= 2; the two cell inputs are never targets).
                for j in range(2, len(states)):
                    # print(j, len(states), offset_back, len(self._ops_back))
                    states[j] = states[j] + \
                        self._ops_back[offset_back](
                            s, weights_backward[offset_back])
                    offset_back += 1
            states.append(s)
        outputs = torch.cat(states[-self._multiplier:], dim=1)
        return outputs
class Network(BaseModule):
    """DARTS-style search supernet over NeuEvo cells.

    Holds the architecture parameters (``alphas_normal``, shape
    (2, k, num_ops)) alongside the cell weights, and can export the current
    best architecture via :meth:`genotype`.
    """

    def __init__(self, C, num_classes, layers, criterion, steps=4, multiplier=4, stem_multiplier=3,
                 parse_method='bio_darts', op_threshold=None, step=1, node_type='ReLUNode', **kwargs):
        super().__init__(
            step=step,
            encode_type='direct',
            **kwargs
        )
        # NOTE(review): eval on a config string — callers control node_type.
        self.act_fun = eval(node_type)
        self.act_fun = partial(self.act_fun, **kwargs)
        self._C = C
        self._num_classes = num_classes
        self._layers = layers
        self._criterion = criterion
        self._steps = steps
        self._multiplier = multiplier
        self.parse_method = parse_method
        self.op_threshold = op_threshold
        # Per-time-step firing-rate accumulators (see get_fire_per_step).
        self.fire_rate_per_step = [0.] * self.step
        self.forward_step = 0
        self.record_fire_rate = False
        if 'back_connection' in kwargs.keys():
            self.back_connection = kwargs['back_connection']
        else:
            self.back_connection = False
        self.dataset = kwargs['dataset']
        self.spike_output = kwargs['spike_output'] if 'spike_output' in kwargs else True
        C_curr = stem_multiplier * C
        if self.dataset == 'dvsg' or self.dataset == 'dvsc10' or self.dataset == 'NCALTECH101':
            # Event data: two input polarities.
            self.stem = nn.Sequential(
                nn.Conv2d(2 * self.init_channel_mul, C_curr, 3, padding=1, bias=False),
                nn.BatchNorm2d(C_curr),
            )
            self.reduce_idx = [layers // 3,
                               2 * layers // 3]
        else:
            self.stem = nn.Sequential(
                nn.Conv2d(3 * self.init_channel_mul, C_curr, 3, padding=1, bias=False),
                nn.BatchNorm2d(C_curr),
            )
            self.reduce_idx = [1, 3, 5]
        C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
        self.cells = nn.ModuleList()
        reduction_prev = False
        for i in range(layers):
            if i in self.reduce_idx:
                C_curr *= 2
                reduction = True
            else:
                reduction = False
            cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev, self.act_fun,
                        self.back_connection)
            reduction_prev = reduction
            self.cells += [cell]
            C_prev_prev, C_prev = C_prev, multiplier * C_curr
        self.global_pooling = nn.Sequential(
            self.act_fun(), nn.AdaptiveAvgPool2d(1))
        if self.spike_output:
            # Over-provision 10 neurons per class and vote over them.
            self.classifier = nn.Sequential(
                nn.Linear(C_prev, 10 * num_classes),
                self.act_fun())
            self.vote = VotingLayer(10)
        else:
            self.classifier = nn.Linear(C_prev, num_classes)
            self.vote = nn.Identity()
        self._initialize_alphas()

    def new(self):
        """Return a fresh Network sharing this model's alpha values (CUDA)."""
        model_new = Network(self._C, self._num_classes,
                            self._layers, self._criterion).cuda()
        for x, y in zip(model_new.arch_parameters(), self.arch_parameters()):
            x.data.copy_(y.data)
        return model_new

    def forward(self, inputs):
        inputs = self.encoder(inputs)
        self.reset()
        if not self.training:
            self.fire_rate.clear()
        outputs = []
        for t in range(self.step):
            x = inputs[t]
            s0 = s1 = self.stem(x)
            for i, cell in enumerate(self.cells):
                if not cell.reduction:
                    # Recompute edge weights from the current alphas.
                    weights = calc_weight(self.alphas_normal)
                    s0, s1 = s1, cell(s0, s1, weights)
                else:
                    # Reduction cells ignore architecture weights.
                    s0, s1 = s1, cell(s0, s1, None)
            out = self.global_pooling(s1)
            out = self.classifier(out.view(out.size(0), -1))
            logits = self.vote(out)
            outputs.append(logits)
        # print(self.get_fire_rate_avg(), self.fire_rate_per_step, len(self.fire_rate_per_step))
        if self.record_fire_rate:
            self.forward_step += 1
        # Average logits over simulation time.
        return sum(outputs) / len(outputs)

    def reset_fire_rate_record(self):
        """Clear the accumulated firing-rate statistics."""
        self.fire_rate_per_step = [0.] * self.step
        self.forward_step = 0

    def get_fire_per_step(self):
        """Average accumulated firing rate per time step over recorded forwards."""
        return [x / self.forward_step for x in self.fire_rate_per_step]

    def _loss(self, input1, target1, input2):
        logits = self(input1)
        return self._criterion(logits, target1, input2)
    # def _loss(self, input1, target1):
    #     logits = self(input1)
    #     return self._criterion(logits, target1)

    def _initialize_alphas(self):
        """Create the (2, k, num_ops) architecture parameters on CUDA."""
        # k = 2 + 3 + 4 + 5 = 14
        k = sum(1 for i in range(self._steps) for n in range(2 + i))
        if self.back_connection:
            # Extra rows for the backward (feedback) edges.
            k += sum(1 for i in range(self._steps) for n in range(i))
        num_ops = len(PRIMITIVES)
        self.alphas_normal = Variable(
            0.5 * torch.randn(2, k, num_ops).cuda(), requires_grad=True)
        # init the history
        self.alphas_normal_history = {}
        mm = 0
        last_id = 1
        node_id = 0
        for i in range(k):
            for j in range(num_ops):
                self.alphas_normal_history['edge: {}, op: {}'.format(
                    (node_id, mm), PRIMITIVES[j])] = []
            # Walk (node_id, mm) through the triangular edge layout.
            if mm == last_id:
                mm = 0
                last_id += 1
                node_id += 1
            else:
                mm += 1

    def arch_parameters(self):
        return [self.alphas_normal]

    def genotype(self):
        """Parse the current alphas into a discrete Genotype."""
        # alphas_normal
        gene_normal = parse(calc_weight(self.alphas_normal).data.cpu().numpy(),
                            PRIMITIVES, self.op_threshold, self.parse_method,
                            self._steps, reduction=False, back_connection=self.back_connection)
        concat = range(2 + self._steps - self._multiplier, self._steps + 2)
        genotype = Genotype(
            normal=gene_normal, normal_concat=concat,
        )
        return genotype

    def states(self):
        """Snapshot of the search state for checkpointing."""
        return {
            'alphas_normal': self.alphas_normal,
            'alphas_normal_history': self.alphas_normal_history,
            'criterion': self._criterion
        }

    def restore(self, states):
        """Restore alphas and their history from a states() snapshot."""
        self.alphas_normal = states['alphas_normal']
        self.alphas_normal_history = states['alphas_normal_history']

    def update_history(self):
        """Append the current combined edge weights to the history dict."""
        mm = 0
        last_id = 1
        node_id = 0
        weights1 = calc_weight(self.alphas_normal).data.cpu().numpy()
        k, num_ops = weights1.shape
        for i in range(k):
            for j in range(num_ops):
                self.alphas_normal_history['edge: {}, op: {}'.format((node_id, mm), PRIMITIVES[j])].append(
                    float(weights1[i][j]))
            # Same (node_id, mm) walk as in _initialize_alphas.
            if mm == last_id:
                mm = 0
                last_id += 1
                node_id += 1
            else:
                mm += 1
================================================
FILE: braincog/model_zoo/NeuEvo/operations.py
================================================
import numpy as np
import torch
import torch.nn as nn
from torch.nn import *
import torch.nn.functional as F
from torch import einsum
from einops import rearrange
from braincog.model_zoo.base_module import DeformConvPack
from braincog.model_zoo.base_module import BaseLinearModule
# from mmcv.ops import ModulatedDeformConv2dPack
def si_relu(x, positive):
    """Sign-selective ReLU.

    :param x: input tensor
    :param positive: 1 keeps only positive values, -1 only negative values,
        0 is the identity
    :raises ValueError: for any other ``positive`` value
    """
    if positive == 0:
        return x
    if positive == 1:
        return torch.where(x > 0., x, torch.zeros_like(x))
    if positive == -1:
        return torch.where(x < 0., x, torch.zeros_like(x))
    raise ValueError
class SiReLU(nn.Module):
    """Module wrapper around :func:`si_relu`.

    :param positive: sign selector forwarded to ``si_relu``
        (1 = positive part, -1 = negative part, 0 = identity).
    """

    def __init__(self, positive=0):
        super().__init__()
        self.positive = positive

    def forward(self, x):
        return si_relu(x, self.positive)
def weight_init(m):
    """Xavier-initialise Conv2d weights (gain 0.1) and zero the bias, if present.

    Non-``Conv2d`` modules are ignored, so this can be passed to
    ``nn.Module.apply``.
    """
    if isinstance(m, nn.Conv2d):
        # xavier_normal/constant are deprecated aliases; use the in-place
        # trailing-underscore variants.
        torch.nn.init.xavier_normal_(m.weight.data, gain=0.1)
        # Guard: convs created with bias=False have m.bias is None.
        if m.bias is not None:
            torch.nn.init.constant_(m.bias.data, 0.)
# MLP-cell candidate operations, keyed by primitive name.
# Each factory takes (C, act_fun).  Suffix convention: no suffix -> positive=0
# (unsigned), ``_p`` -> positive=1 (passes positives only), ``_n`` ->
# positive=-1 (passes negatives only); see si_relu.
OPS_Mlp = {
    'mlp': lambda C, act_fun:
    SiMLP(C, C, act_fun=act_fun, positive=0),
    'mlp_p': lambda C, act_fun:
    SiMLP(C, C, act_fun=act_fun, positive=1),
    'mlp_n': lambda C, act_fun:
    SiMLP(C, C, act_fun=act_fun, positive=-1),
    'skip_connect': lambda C, act_fun:
    Identity(positive=0),
    'skip_connect_p': lambda C, act_fun:
    Identity(positive=1),
    'skip_connect_n': lambda C, act_fun:
    Identity(positive=-1),
}
# Convolutional-cell candidate operations, keyed by primitive name.
# Each factory takes (C, stride, affine, act_fun).  Suffix convention:
# no suffix -> positive=0 (unsigned), ``_p`` -> positive=1, ``_n`` ->
# positive=-1 (sign-gated via si_relu / SiReLU).
OPS = {
    # -------- unsigned (positive=0) variants --------
    'avg_pool_3x3': lambda C, stride, affine, act_fun: nn.AvgPool2d(3, stride=stride, padding=1,
                                                                    count_include_pad=False),
    'conv_3x3': lambda C, stride, affine, act_fun:
    ReLUConvBN(C_in=C, C_out=C, kernel_size=3, padding=1, stride=stride, affine=affine, act_fun=act_fun, positive=0),
    'conv_5x5': lambda C, stride, affine, act_fun:
    ReLUConvBN(C_in=C, C_out=C, kernel_size=5, padding=2, stride=stride, affine=affine, act_fun=act_fun, positive=0),
    'max_pool_3x3': lambda C, stride, affine, act_fun: nn.MaxPool2d(3, stride=stride, padding=1),
    # stride > 1 needs a spatial reduction even on the skip path.
    'skip_connect': lambda C, stride, affine, act_fun:
    Identity(positive=0) if stride == 1 else FactorizedReduce(C, C, affine=affine, act_fun=act_fun),
    'sep_conv_3x3': lambda C, stride, affine, act_fun:
    SepConv(C, C, 3, stride, 1, affine=affine, act_fun=act_fun, positive=0),
    'sep_conv_5x5': lambda C, stride, affine, act_fun:
    SepConv(C, C, 5, stride, 2, affine=affine, act_fun=act_fun, positive=0),
    'sep_conv_7x7': lambda C, stride, affine, act_fun:
    SepConv(C, C, 7, stride, 3, affine=affine, act_fun=act_fun, positive=0),
    'dil_conv_3x3': lambda C, stride, affine, act_fun:
    DilConv(C, C, 3, stride, 2, 2, affine=affine, act_fun=act_fun, positive=0),
    'dil_conv_5x5': lambda C, stride, affine, act_fun:
    DilConv(C, C, 5, stride, 4, 2, affine=affine, act_fun=act_fun, positive=0),
    'def_conv_3x3': lambda C, stride, affine, act_fun:
    DeformConv(C, C, 3, stride, 1, affine=affine, act_fun=act_fun, positive=0),
    'def_conv_5x5': lambda C, stride, affine, act_fun:
    DeformConv(C, C, 5, stride, 2, affine=affine, act_fun=act_fun, positive=0),
    # -------- positive-gated (positive=1) variants --------
    # Pooling has no ``positive`` argument, so the gate is appended as SiReLU.
    'avg_pool_3x3_p': lambda C, stride, affine, act_fun: nn.Sequential(
        nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False),
        SiReLU(positive=1)
    ),
    'max_pool_3x3_p': lambda C, stride, affine, act_fun: nn.Sequential(
        nn.MaxPool2d(3, stride=stride, padding=1),
        SiReLU(positive=1)
    ),
    'conv_3x3_p': lambda C, stride, affine, act_fun:
    ReLUConvBN(C_in=C, C_out=C, kernel_size=3, padding=1, stride=stride, affine=affine, act_fun=act_fun, positive=1),
    'conv_5x5_p': lambda C, stride, affine, act_fun:
    ReLUConvBN(C_in=C, C_out=C, kernel_size=5, padding=2, stride=stride, affine=affine, act_fun=act_fun, positive=1),
    'skip_connect_p': lambda C, stride, affine, act_fun:
    Identity(positive=1) if stride == 1 else FactorizedReduce(C, C, affine=affine, act_fun=act_fun, positive=1),
    'sep_conv_3x3_p': lambda C, stride, affine, act_fun:
    SepConv(C, C, 3, stride, 1, affine=affine, act_fun=act_fun, positive=1),
    'sep_conv_5x5_p': lambda C, stride, affine, act_fun:
    SepConv(C, C, 5, stride, 2, affine=affine, act_fun=act_fun, positive=1),
    'sep_conv_7x7_p': lambda C, stride, affine, act_fun:
    SepConv(C, C, 7, stride, 3, affine=affine, act_fun=act_fun, positive=1),
    'dil_conv_3x3_p': lambda C, stride, affine, act_fun:
    DilConv(C, C, 3, stride, 2, 2, affine=affine, act_fun=act_fun, positive=1),
    'dil_conv_5x5_p': lambda C, stride, affine, act_fun:
    DilConv(C, C, 5, stride, 4, 2, affine=affine, act_fun=act_fun, positive=1),
    'def_conv_3x3_p': lambda C, stride, affine, act_fun:
    DeformConv(C, C, 3, stride, 1, affine=affine, act_fun=act_fun, positive=1),
    'def_conv_5x5_p': lambda C, stride, affine, act_fun:
    DeformConv(C, C, 5, stride, 2, affine=affine, act_fun=act_fun, positive=1),
    # -------- negative-gated (positive=-1) variants --------
    'avg_pool_3x3_n': lambda C, stride, affine, act_fun: nn.Sequential(
        nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False),
        SiReLU(positive=-1)
    ),
    'max_pool_3x3_n': lambda C, stride, affine, act_fun: nn.Sequential(
        nn.MaxPool2d(3, stride=stride, padding=1),
        SiReLU(positive=-1)
    ),
    'conv_3x3_n': lambda C, stride, affine, act_fun:
    ReLUConvBN(C_in=C, C_out=C, kernel_size=3, padding=1, stride=stride, affine=affine, act_fun=act_fun, positive=-1),
    'conv_5x5_n': lambda C, stride, affine, act_fun:
    ReLUConvBN(C_in=C, C_out=C, kernel_size=5, padding=2, stride=stride, affine=affine, act_fun=act_fun, positive=-1),
    'skip_connect_n': lambda C, stride, affine, act_fun:
    Identity(positive=-1) if stride == 1 else FactorizedReduce(C, C, affine=affine, act_fun=act_fun, positive=-1),
    'sep_conv_3x3_n': lambda C, stride, affine, act_fun:
    SepConv(C, C, 3, stride, 1, affine=affine, act_fun=act_fun, positive=-1),
    'sep_conv_5x5_n': lambda C, stride, affine, act_fun:
    SepConv(C, C, 5, stride, 2, affine=affine, act_fun=act_fun, positive=-1),
    'sep_conv_7x7_n': lambda C, stride, affine, act_fun:
    SepConv(C, C, 7, stride, 3, affine=affine, act_fun=act_fun, positive=-1),
    'dil_conv_3x3_n': lambda C, stride, affine, act_fun:
    DilConv(C, C, 3, stride, 2, 2, affine=affine, act_fun=act_fun, positive=-1),
    'dil_conv_5x5_n': lambda C, stride, affine, act_fun:
    DilConv(C, C, 5, stride, 4, 2, affine=affine, act_fun=act_fun, positive=-1),
    'def_conv_3x3_n': lambda C, stride, affine, act_fun:
    DeformConv(C, C, 3, stride, 1, affine=affine, act_fun=act_fun, positive=-1),
    'def_conv_5x5_n': lambda C, stride, affine, act_fun:
    DeformConv(C, C, 5, stride, 2, affine=affine, act_fun=act_fun, positive=-1),
    # -------- misc --------
    # Factorised 7x7 conv as a 1x7 followed by a 7x1.
    'conv_7x1_1x7': lambda C, stride, affine, act_fun: nn.Sequential(
        # nn.ReLU(inplace=False),
        act_fun(),
        nn.Conv2d(C, C, (1, 7), stride=(1, stride),
                  padding=(0, 3), bias=False),
        nn.Conv2d(C, C, (7, 1), stride=(stride, 1),
                  padding=(3, 0), bias=False),
        nn.BatchNorm2d(C, affine=affine)
    ),
    # Transformer is only usable at stride 1; otherwise fall back to reduction.
    'transformer': lambda C, stride, affine, act_fun:
    FactorizedReduce(
        C, C, affine=affine, act_fun=act_fun) if stride != 1 else TransformerEncoderLayer(C),
}
class SiMLP(nn.Module):
    """Linear layer + activation applied after sign-gating the input via ``si_relu``."""

    def __init__(self, c_in, c_out, act_fun=nn.ReLU, positive=0, *args, **kwargs):
        super().__init__()
        self.op = nn.Sequential(
            nn.Linear(c_in, c_out, bias=True),
            act_fun()
        )
        # Sign mode for the input gate; see si_relu.
        self.positive = positive

    def forward(self, x):
        gated = si_relu(x, self.positive)
        return self.op(gated)
class ReLUConvBN(nn.Module):
    """Conv2d -> BatchNorm2d, then sign gating on the output.

    Despite the historical name, no ReLU runs before the convolution; the
    only non-linearity is the ``si_relu`` gate applied to the BN output.
    """

    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True, act_fun=nn.ReLU, positive=0):
        super().__init__()
        self.op = nn.Sequential(
            nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False),
            nn.BatchNorm2d(C_out, affine=affine),
        )
        self.positive = positive

    def forward(self, x):
        return si_relu(self.op(x), self.positive)
class DilConv(nn.Module):
    """act_fun -> depthwise dilated conv -> pointwise conv -> BN, then sign gating."""

    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True, act_fun=nn.ReLU, positive=0):
        super().__init__()
        self.op = nn.Sequential(
            act_fun(),
            # Depthwise (groups=C_in) dilated convolution.
            nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding,
                      dilation=dilation, groups=C_in, bias=False),
            # Pointwise projection to C_out channels.
            nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_out, affine=affine),
        )
        self.positive = positive

    def forward(self, x):
        return si_relu(self.op(x), self.positive)
class SepConv(nn.Module):
    """Two stacked depthwise-separable convolutions with BN, then sign gating.

    Only the first depthwise conv applies the stride; the second runs at
    stride 1 and the final pointwise conv projects to ``C_out``.
    """

    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True, act_fun=nn.ReLU, positive=0):
        super().__init__()
        self.op = nn.Sequential(
            act_fun(),
            nn.Conv2d(C_in, C_in, kernel_size=kernel_size,
                      stride=stride, padding=padding, groups=C_in, bias=False),
            nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_in, affine=affine),
            nn.ReLU(inplace=False),
            nn.Conv2d(C_in, C_in, kernel_size=kernel_size,
                      stride=1, padding=padding, groups=C_in, bias=False),
            nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_out, affine=affine),
        )
        self.positive = positive

    def forward(self, x):
        return si_relu(self.op(x), self.positive)
class Identity(nn.Module):
    """Pass-through op; a non-zero ``positive`` sign-gates the input via si_relu."""

    def __init__(self, positive=0):
        super().__init__()
        self.positive = positive

    def forward(self, x):
        return si_relu(x, self.positive)
class Zero(nn.Module):
    """Emit zeros shaped like the (optionally strided) input.

    Uses ``mul(0.)`` rather than ``zeros_like`` so the output stays connected
    to the autograd graph of the input.
    """

    def __init__(self, stride):
        super().__init__()
        self.stride = stride

    def forward(self, x):
        # Input layout is N * C * H * W; stride subsamples the spatial dims.
        out = x if self.stride == 1 else x[:, :, ::self.stride, ::self.stride]
        return out.mul(0.)
class FactorizedReduce(nn.Module):
    """Halve the spatial resolution with two offset stride-2 convs, concat, BN, gate.

    ``conv_2`` sees the input shifted by one pixel so the two branches sample
    complementary spatial positions; their outputs each carry ``C_out // 2``
    channels.
    """

    def __init__(self, C_in, C_out, affine=True, act_fun=nn.ReLU, positive=0):
        super().__init__()
        assert C_out % 2 == 0
        self.activation = act_fun()
        self.conv_1 = nn.Conv2d(C_in, C_out // 2, 3,
                                stride=2, padding=1, bias=False)
        self.conv_2 = nn.Conv2d(C_in, C_out // 2, 3,
                                stride=2, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(C_out, affine=affine)
        self.positive = positive

    def forward(self, x):
        x = self.activation(x)
        halves = [self.conv_1(x), self.conv_2(x[:, :, 1:, 1:])]
        out = self.bn(torch.cat(halves, dim=1))
        return si_relu(out, self.positive)
class DeformConv(nn.Module):
    """act_fun -> deformable conv (``DeformConvPack``) -> BN, then sign gating."""

    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True, act_fun=nn.ReLU, positive=0):
        super().__init__()
        self.op = nn.Sequential(
            act_fun(),
            DeformConvPack(C_in, C_out, kernel_size=kernel_size,
                           stride=stride, padding=padding, bias=True),
            nn.BatchNorm2d(C_out, affine=affine),
        )
        self.positive = positive

    def forward(self, x):
        return si_relu(self.op(x), self.positive)
class Attention(Module):
"""
Obtained from: github.com:rwightman/pytorch-image-models
"""
def __init__(self, dim, num_heads=4, attention_dropout=0.1, projection_dropout=0.1):
super().__init__()
self.num_heads = num_heads
head_dim = dim // self.num_heads
self.scale = head_dim ** -0.5
self.qkv = Linear(dim, dim * 3, bias=False)
self.attn_drop = Dropout(attention_dropout)
self.proj = Linear(dim, dim)
self.proj_drop = Dropout(projection_dropout)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C //
self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class TransformerEncoderLayer(Module):
    """
    Pre-norm transformer encoder layer operating on (B, D, H, W) feature maps.

    Inspired by torch.nn.TransformerEncoderLayer and rwightman's timm package.
    The forward flattens the spatial grid into a token sequence, applies
    self-attention and an MLP (both with stochastic depth), then restores
    the grid.
    """

    def __init__(self, d_model, nhead=4, dim_feedforward=256, dropout=0.1,
                 attention_dropout=0.1, drop_path_rate=0.1):
        super(TransformerEncoderLayer, self).__init__()
        self.pre_norm = LayerNorm(d_model)
        self.self_attn = Attention(dim=d_model, num_heads=nhead,
                                   attention_dropout=attention_dropout, projection_dropout=dropout)
        # NOTE(review): the ``dim_feedforward`` argument is silently overridden
        # here, so the MLP hidden width is always ``d_model`` — confirm intended.
        dim_feedforward = d_model
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout1 = Dropout(dropout)
        self.norm1 = LayerNorm(d_model)
        self.linear2 = Linear(dim_feedforward, d_model)
        self.dropout2 = Dropout(dropout)
        # ``Identity`` here is this file's sign-gated op; with its default
        # positive=0 it acts as a plain pass-through.
        self.drop_path = DropPath(
            drop_path_rate) if drop_path_rate > 0 else Identity()
        self.activation = F.gelu

    def forward(self, src: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # Flatten (B, D, H, W) -> (B, H*W, D); remember the width to undo it.
        c = src.shape[-1]
        src = rearrange(src, 'b d r c -> b (r c) d')
        # Pre-norm attention block with stochastic-depth residual.
        src = src + self.drop_path(self.self_attn(self.pre_norm(src)))
        src = self.norm1(src)
        src2 = self.linear2(self.dropout1(self.activation(self.linear1(src))))
        src = src + self.drop_path(self.dropout2(src2))
        # Restore the spatial grid.
        src = rearrange(src, 'b (r c) d -> b d r c', c=c)
        return src
def drop_path(x, drop_prob: float = 0., training: bool = False):
    """
    Obtained from: github.com:rwightman/pytorch-image-models
    Drop paths (Stochastic Depth) per sample, for use in the main path of
    residual blocks.  Each sample in the batch is independently zeroed with
    probability ``drop_prob``; survivors are rescaled by ``1 / keep_prob`` so
    the expected value is unchanged.  A no-op at eval time or when
    ``drop_prob`` is 0.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample; broadcast over all remaining dims.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = torch.rand(mask_shape, dtype=x.dtype, device=x.device).add_(keep_prob).floor_()
    return x.div(keep_prob) * mask
class DropPath(Module):
    """
    Obtained from: github.com:rwightman/pytorch-image-models
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    :param drop_prob: per-sample drop probability; ``None`` is treated as 0.
        (previously ``None`` crashed inside :func:`drop_path` with
        ``1 - None`` when the module was in training mode)
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # Coerce the ``None`` default to 0. so training mode does not raise.
        return drop_path(x, self.drop_prob or 0., self.training)
================================================
FILE: braincog/model_zoo/NeuEvo/others.py
================================================
# encoding: utf-8
# Author : Floyed
# Datetime : 2023/5/22 13:32
# User : yu
# Product : PyCharm
# Project : BrainCog
# File : others.py
# explain :
from functools import partial
import torch
import torch.nn as nn
from copy import deepcopy
from timm.models import register_model
from braincog.base.node.node import *
from braincog.base.connection.layer import WSConv2d
from braincog.datasets import is_dvs_data
from braincog.model_zoo.base_module import BaseModule, BaseConvModule
@register_model
class CIFARNet_Wu(BaseModule):
    """
    Spiking CNN for CIFAR / DVS data in the style of Wu et al.

    :param num_classes: number of output classes
    :param node_type: spiking neuron class; partially applied with ``kwargs``
        (plus ``step``) when it subclasses :class:`BaseNode`
    :param step: number of simulation time steps
    :param encode_type: input encoding used by the base ``Encoder``
    """

    def __init__(
            self, num_classes=10,
            node_type=LIFNode,
            step=4,
            encode_type='direct',
            *args,
            **kwargs,
    ):
        super().__init__(step, encode_type, *args, **kwargs)
        # 'dataset' is a required kwarg: it decides the input channel count.
        self.dataset = kwargs['dataset']
        self.node = node_type
        if issubclass(self.node, BaseNode):
            self.node = partial(self.node, **kwargs, step=step)
        channels = 32
        if not is_dvs_data(self.dataset):
            init_channel = 3
            out_size = 2 ** 2
        else:
            init_channel = 2
            out_size = 3 ** 2
        # NOTE(review): ``out_size`` is computed but never used; the fc layer
        # hard-codes an 8x8 spatial size instead — confirm the intended input size.
        self.feature = nn.Sequential(
            BaseConvModule(init_channel, channels, node=self.node),
            BaseConvModule(channels, channels * 2, node=self.node),
            nn.AvgPool2d(2, 2),
            self.node(),
            BaseConvModule(channels * 2, channels * 4, node=self.node),
            nn.AvgPool2d(2, 2),
            # self.node(),
            BaseConvModule(channels * 4, channels * 8, node=self.node),
            BaseConvModule(channels * 8, channels * 4, node=self.node),
            nn.Flatten(),
        )
        self.fc = nn.Sequential(
            nn.Linear(channels * 4 * 8 * 8, channels * 8, bias=False),
            self.node(),
            nn.Linear(channels * 8, channels * 4, bias=False),
            self.node(),
            nn.Linear(channels * 4, num_classes, bias=False)
        )

    def forward(self, inputs):
        """Encode the input over time, run each step, and average the outputs."""
        inputs = self.encoder(inputs).contiguous()
        self.reset()  # clear membrane potentials between samples
        outputs = []
        for t in range(self.step):
            x = inputs[t]
            x = self.feature(x)
            x = self.fc(x)
            outputs.append(x)
        return sum(outputs) / len(outputs)
@register_model
class CIFARNet_Fang(BaseModule):
    """
    Spiking CNN for CIFAR in the style of Fang et al.

    :param num_classes: number of output classes
    :param node_type: spiking neuron class; partially applied with ``kwargs``
        (plus ``step``) when it subclasses :class:`BaseNode`
    :param step: number of simulation time steps
    :param encode_type: input encoding used by the base ``Encoder``

    NOTE(review): the classifier head ends with ``channels`` (32) outputs, not
    ``num_classes`` — confirm whether a voting/readout layer is expected
    downstream.
    """

    def __init__(
            self, num_classes=10,
            node_type=LIFNode,
            step=4,
            encode_type='direct',
            *args,
            **kwargs,
    ):
        super().__init__(step, encode_type, *args, **kwargs)
        # 'dataset' is a required kwarg: it decides the input channel count.
        self.dataset = kwargs['dataset']
        self.node = node_type
        if issubclass(self.node, BaseNode):
            self.node = partial(self.node, **kwargs, step=step)
        channels = 32
        if not is_dvs_data(self.dataset):
            init_channel = 3
        else:
            init_channel = 2
        self.feature = nn.Sequential(
            BaseConvModule(init_channel, channels, node=self.node),
            BaseConvModule(channels, channels, node=self.node),
            BaseConvModule(channels, channels, node=self.node),
            nn.MaxPool2d(2, 2),
            BaseConvModule(channels, channels, node=self.node),
            BaseConvModule(channels, channels, node=self.node),
            BaseConvModule(channels, channels, node=self.node),
            nn.MaxPool2d(2, 2),
            nn.Flatten(),
        )
        self.fc = nn.Sequential(
            nn.Linear(channels * 8 * 8, channels * 8, bias=False),
            self.node(),
            nn.Linear(channels * 8, channels, bias=False),
        )

    def forward(self, inputs):
        """Encode the input over time, run each step, and average the outputs."""
        inputs = self.encoder(inputs).contiguous()
        self.reset()  # clear membrane potentials between samples
        outputs = []
        for t in range(self.step):
            x = inputs[t]
            x = self.feature(x)
            x = self.fc(x)
            outputs.append(x)
        return sum(outputs) / len(outputs)
@register_model
class DVS_CIFARNet_Fang(BaseModule):
    """
    Spiking CNN for DVS-CIFAR data in the style of Fang et al.

    :param num_classes: number of output classes
    :param node_type: spiking neuron class; partially applied with ``kwargs``
        (plus ``step``) when it subclasses :class:`BaseNode`
    :param step: number of simulation time steps (10 for DVS streams)
    :param encode_type: input encoding used by the base ``Encoder``

    NOTE(review): the classifier head ends with ``channels`` (128) outputs,
    not ``num_classes`` — confirm whether a voting/readout layer is expected
    downstream.
    """

    def __init__(
            self, num_classes=10,
            node_type=LIFNode,
            step=10,
            encode_type='direct',
            *args,
            **kwargs,
    ):
        super().__init__(step, encode_type, *args, **kwargs)
        # 'dataset' is a required kwarg: it decides the input channel count.
        self.dataset = kwargs['dataset']
        self.node = node_type
        if issubclass(self.node, BaseNode):
            self.node = partial(self.node, **kwargs, step=step)
        channels = 128
        if not is_dvs_data(self.dataset):
            init_channel = 3
        else:
            init_channel = 2
        self.feature = nn.Sequential(
            BaseConvModule(init_channel, channels, node=self.node),
            nn.MaxPool2d(2, 2),
            BaseConvModule(channels, channels, node=self.node),
            nn.MaxPool2d(2, 2),
            BaseConvModule(channels, channels, node=self.node),
            nn.MaxPool2d(2, 2),
            BaseConvModule(channels, channels, node=self.node),
            nn.MaxPool2d(2, 2),
            nn.Flatten(),
        )
        self.fc = nn.Sequential(
            nn.Linear(channels * 8 * 8, channels * 4, bias=False),
            self.node(),
            nn.Linear(channels * 4, channels, bias=False),
        )

    def forward(self, inputs):
        """Encode the input over time, run each step, and average the outputs."""
        inputs = self.encoder(inputs).contiguous()
        self.reset()  # clear membrane potentials between samples
        outputs = []
        for t in range(self.step):
            x = inputs[t]
            x = self.feature(x)
            x = self.fc(x)
            outputs.append(x)
        return sum(outputs) / len(outputs)
================================================
FILE: braincog/model_zoo/__init__.py
================================================
__all__ = ['convnet', 'resnet', 'base_module', 'glsnn', 'qsnn', 'resnet19_snn']
from . import (
convnet,
resnet,
base_module,
glsnn,
qsnn,
resnet19_snn
)
================================================
FILE: braincog/model_zoo/backeinet.py
================================================
import numpy as np
from timm.models import register_model
from braincog.model_zoo.base_module import BaseModule, BaseConvModule, BaseLinearModule
from braincog.base.node.node import *
from braincog.base.connection.layer import *
class MNISTNet(BaseModule):
    """
    BackEI spiking network for MNIST / Fashion-MNIST.

    :param step: number of simulation time steps
    :param encode_type: input encoding used by the base ``Encoder``
    :param if_back: enable feedback connections in ``BackEINode``
    :param if_ei: enable excitatory/inhibitory modulation in ``BackEINode``
    :param data: dataset name, ``'mnist'`` or ``'fashion'``
    :raises ValueError: if ``data`` names an unsupported dataset (previously an
        unknown value fell through both ``if`` branches and crashed later with
        an ``AttributeError`` on the undefined ``cfg_*`` attributes)
    """

    def __init__(self, step=20, encode_type='rate', if_back=True, if_ei=True, data='mnist', *args, **kwargs):
        super().__init__(step, encode_type, *args, **kwargs)
        self.if_back = if_back
        self.if_ei = if_ei
        # cfg_conv entries: (in_ch, out_ch, kernel, stride, padding)
        if data == 'mnist':
            self.cfg_conv = ((1, 15, 5, 1, 0), (15, 40, 5, 1, 0))
            self.cfg_fc = (300, 10)
            self.cfg_kernel = (24, 8, 4)
            cfg_backei = 2
        elif data == 'fashion':
            self.cfg_conv = ((1, 32, 5, 1, 2), (32, 64, 5, 1, 2))
            self.cfg_fc = (1024, 10)
            self.cfg_kernel = (28, 14, 7)
            cfg_backei = 1
        else:
            raise ValueError("data must be 'mnist' or 'fashion', got {!r}".format(data))
        self.feature = nn.Sequential(
            nn.Conv2d(self.cfg_conv[0][0], self.cfg_conv[0][1], self.cfg_conv[0][2], self.cfg_conv[0][3],
                      self.cfg_conv[0][4]),
            BackEINode(channel=self.cfg_conv[0][1], if_back=self.if_back, if_ei=self.if_ei, cfg_backei=cfg_backei),
            nn.AvgPool2d(2),
            nn.Conv2d(self.cfg_conv[1][0], self.cfg_conv[1][1], self.cfg_conv[1][2], self.cfg_conv[1][3],
                      self.cfg_conv[1][4]),
            BackEINode(channel=self.cfg_conv[1][1], if_back=self.if_back, if_ei=self.if_ei, cfg_backei=cfg_backei),
            nn.AvgPool2d(2),
            nn.Flatten(),
            nn.Linear(self.cfg_kernel[2] * self.cfg_kernel[2] * self.cfg_conv[1][1], self.cfg_fc[0]),
            BackEINode(if_back=False, if_ei=False),
            nn.Linear(self.cfg_fc[0], self.cfg_fc[1]),
            BackEINode(if_back=False, if_ei=False)
        )

    def forward(self, inputs):
        """Encode the input over time, run each step, and average the outputs."""
        inputs = self.encoder(inputs)
        self.reset()  # clear membrane potentials between samples
        if not self.training:
            self.fire_rate.clear()
        outputs = []
        step = self.step
        for t in range(step):
            x = inputs[t]
            x = self.feature(x)
            outputs.append(x)
        return sum(outputs) / len(outputs)
class CIFARNet(BaseModule):
    """
    BackEI spiking network for CIFAR-10 (3-channel 32x32 input, 10 classes).

    :param step: number of simulation time steps
    :param encode_type: input encoding used by the base ``Encoder``
    :param if_back: enable feedback connections in ``BackEINode``
    :param if_ei: enable excitatory/inhibitory modulation in ``BackEINode``
    """

    def __init__(self, step=20, encode_type='rate', if_back=True, if_ei=True, *args, **kwargs):
        super().__init__(step, encode_type, *args, **kwargs)
        self.if_back = if_back
        self.if_ei = if_ei
        self.feature = nn.Sequential(
            nn.Conv2d(3, 128, 3, 1, 1),
            BackEINode(channel=128, if_back=self.if_back, if_ei=self.if_ei, cfg_backei=1),
            nn.Dropout(0.5),
            nn.AvgPool2d(2),
            nn.Conv2d(128, 256, 3, 1, 1),
            BackEINode(channel=256, if_back=self.if_back, if_ei=self.if_ei, cfg_backei=1),
            nn.Dropout(0.5),
            nn.AvgPool2d(2),
            nn.Conv2d(256, 512, 3, 1, 1),
            BackEINode(channel=512, if_back=self.if_back, if_ei=self.if_ei, cfg_backei=1),
            nn.Dropout(0.5),
            nn.AvgPool2d(2),
            nn.Flatten(),
            # 32x32 input pooled three times -> 4x4 spatial size here.
            nn.Linear(4 * 4 * 512, 1024),
            BackEINode(if_back=False, if_ei=False),
            nn.Dropout(0.5),
            nn.Linear(1024, 10),
            BackEINode(if_back=False, if_ei=False)
        )

    def forward(self, inputs):
        """Encode the input over time, run each step, and average the outputs."""
        inputs = self.encoder(inputs)
        self.reset()  # clear membrane potentials between samples
        if not self.training:
            self.fire_rate.clear()
        outputs = []
        step = self.step
        for t in range(step):
            x = inputs[t]
            x = self.feature(x)
            outputs.append(x)
        return sum(outputs) / len(outputs)
================================================
FILE: braincog/model_zoo/base_module.py
================================================
from functools import partial
from torchvision.ops import DeformConv2d
from braincog.base.node.node import *
from braincog.base.connection.layer import *
from braincog.base.encoder.encoder import *
class BaseLinearModule(nn.Module):
    """
    Linear (fully-connected) module followed by a spiking neuron.

    :param in_features: input feature size
    :param out_features: output feature size
    :param bias: whether the linear layer has a bias, default ``True``
    :param node: neuron type, default ``LIFNode``
    :param args:
    :param kwargs: may contain ``groups`` to split the layer into independent
        parallel linear groups applied along the time axis
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias=True,
                 node=LIFNode,
                 *args,
                 **kwargs):
        super().__init__()
        if node is None:
            raise TypeError
        self.groups = kwargs['groups'] if 'groups' in kwargs else 1
        if self.groups == 1:
            self.fc = nn.Linear(in_features=in_features,
                                out_features=out_features, bias=bias)
        else:
            # One independent linear layer per group.
            self.fc = nn.ModuleList()
            for i in range(self.groups):
                self.fc.append(nn.Linear(
                    in_features=in_features,
                    out_features=out_features,
                    bias=bias
                ))
        self.node = partial(node, **kwargs)()

    def forward(self, x):
        """Apply the linear layer(s), then the spiking neuron."""
        if self.groups == 1:  # (t b) c
            outputs = self.fc(x)
        else:  # b (c t)
            # Split the flattened (c t) axis, apply each group's layer, re-flatten.
            x = rearrange(x, 'b (c t) -> t b c', t=self.groups)
            outputs = []
            for i in range(self.groups):
                outputs.append(self.fc[i](x[i]))
            outputs = torch.stack(outputs)  # t b c
            outputs = rearrange(outputs, 't b c -> b (c t)')
        return self.node(outputs)
class BaseConvModule(nn.Module):
    """
    SNN convolution module: Conv2d -> BatchNorm2d -> spiking neuron.

    :param in_channels: number of input channels
    :param out_channels: number of output channels
    :param kernel_size: kernel size
    :param stride: stride
    :param padding: padding
    :param bias: whether the convolution has a bias
    :param node: neuron type
    :param kwargs: may contain ``groups``; channel counts are multiplied by it
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size=(3, 3),
                 stride=(1, 1),
                 padding=(1, 1),
                 bias=False,
                 node=PLIFNode,
                 **kwargs):
        super().__init__()
        if node is None:
            raise TypeError
        self.groups = kwargs['groups'] if 'groups' in kwargs else 1
        self.conv = nn.Conv2d(in_channels=in_channels * self.groups,
                              out_channels=out_channels * self.groups,
                              kernel_size=kernel_size,
                              padding=padding,
                              stride=stride,
                              bias=bias)
        self.bn = nn.BatchNorm2d(out_channels * self.groups)
        self.node = partial(node, **kwargs)()
        self.activation = nn.Identity()

    def forward(self, x):
        """Convolve, normalise, then spike."""
        # origin_shape = x.shape
        # if len(origin_shape) > 4:
        #     x = x.reshape(np.prod(origin_shape[0:-3]), *origin_shape[-3:])
        x = self.conv(x)
        x = self.bn(x)
        # if len(origin_shape) > 4:
        #     x = x.reshape(*origin_shape[0:-3], *x.shape[-3:])
        x = self.node(x)
        return x
class BaseModule(nn.Module, abc.ABC):
    """
    Abstract base class for SNNs; every SNN inherits from it to get the
    shared simulation utilities below.

    :param step: number of simulation time steps
    :param encode_type: input-encoding type passed to ``Encoder``
    :param layer_by_layer: whether to run the forward pass layer-wise
    :param temporal_flatten: whether to fold the time dimension into channels
    :param args:
    :param kwargs:
    """

    def __init__(self,
                 step,
                 encode_type,
                 layer_by_layer=False,
                 temporal_flatten=False,
                 *args,
                 **kwargs):
        super(BaseModule, self).__init__()
        self.step = step
        # print(kwargs['layer_by_layer'])
        self.layer_by_layer = layer_by_layer
        self.temporal_flatten = temporal_flatten
        # The encoder always sees the full step count, even when the module
        # itself collapses time into channels below.
        encode_step = self.step
        if temporal_flatten is True:
            self.init_channel_mul = self.step
            self.step = 1
        else:  # origin
            self.init_channel_mul = 1
        self.encoder = Encoder(encode_step, encode_type, temporal_flatten=self.temporal_flatten, layer_by_layer=self.layer_by_layer, **kwargs)
        self.kwargs = kwargs
        self.warm_up = False
        self.fire_rate = []

    def reset(self):
        """
        Reset the membrane potential of every neuron in the model.

        :return:
        """
        for mod in self.modules():
            if hasattr(mod, 'n_reset'):
                mod.n_reset()

    def set_attr(self, attr, val):
        """
        Set an attribute on every neuron that has it.

        :param attr: attribute name
        :param val: value to assign
        :return:
        """
        for mod in self.modules():
            if isinstance(mod, BaseNode):
                if hasattr(mod, attr):
                    setattr(mod, attr, val)
                else:
                    # NOTE(review): this ValueError is constructed but never
                    # raised — a missing attribute currently fails silently.
                    ValueError('{} do not has {}'.format(self, attr))

    def get_threshold(self):
        """
        Collect the firing threshold of every neuron.

        :return: list of thresholds
        """
        outputs = []
        for mod in self.modules():
            if isinstance(mod, BaseNode):
                thresh = (mod.get_thres())
                outputs.append(thresh)
        return outputs

    def get_fp(self, temporal_info=False):
        """
        Collect the spike feature maps of every neuron.

        :param temporal_info: if True keep the per-step maps; otherwise
            average over the time dimension
        :return: list of feature maps
        """
        outputs = []
        for mod in self.modules():
            if isinstance(mod, BaseNode):
                if temporal_info:
                    outputs.append(mod.feature_map)  # [l,[t,[b,w,h]]]
                else:
                    outputs.append(sum(mod.feature_map) / len(mod.feature_map))
        return outputs

    def get_mem(self, temporal_info=False):
        """
        Collect the membrane potentials of every neuron.

        :param temporal_info: if True keep the per-step values; otherwise
            average over the time dimension
        :return: list of membrane potentials
        """
        outputs = []
        for mod in self.modules():
            if isinstance(mod, BaseNode):
                if temporal_info:
                    outputs.append(mod.mem_collect)  # [l,[t,[b,w,h]]]
                else:
                    outputs.append(sum(mod.mem_collect) / len(mod.mem_collect))
        return outputs

    def get_fire_rate(self, requires_grad=False):
        """
        Compute the firing rate of every neuron.

        :param requires_grad: if ``False`` (default) gradients are detached
        :return: stacked per-neuron firing rates
        """
        outputs = []
        fp = self.get_attr('feature_map')
        for f in fp:
            if requires_grad is False:
                # NOTE(review): an empty feature map returns immediately,
                # discarding rates already collected — confirm intended.
                if len(f) == 0:
                    return torch.tensor([0.])
                outputs.append(((sum(f) / len(f)).detach() > 0.).float().mean())
            else:
                outputs.append(((sum(f) / len(f)) > 0.).float().mean())
        if len(outputs) == 0:
            return torch.tensor([0.])
        return torch.stack(outputs)

    def get_tot_spike(self):
        """
        Total number of spikes emitted by the network, per batch sample.

        :return:
        """
        tot_spike = 0
        batch_size = 1
        fp = self.get_attr('feature_map')
        for f in fp:
            if len(f) == 0:
                break
            tot_spike += sum(f).sum()
            batch_size = f[0].shape[0]
        return tot_spike / batch_size

    def get_spike_info(self):
        """
        Collect per-layer spike statistics, mainly for plotting.

        :return: (avg, var, spike-count histogram per step, avg per step)
        """
        spike_feature_list = self.get_fp(temporal_info=True)
        avg, var, spike = [], [], []
        avg_per_step = []
        for spike_feature in spike_feature_list:
            avg_list = []
            for spike_t in spike_feature:
                avg_list.append(float(spike_t.mean()))
            avg_per_step.append(avg_list)
            # Summing over time gives, per neuron, how many steps it fired.
            spike_feature = sum(spike_feature)
            num = np.prod(spike_feature.shape)
            avg.append(float(spike_feature.sum()))
            var.append(float(spike_feature.std()))
            lst = []
            for t in range(self.step + 1):
                lst.append(float((spike_feature == t).sum() / num))
            spike.append(lst)
        return avg, var, spike, avg_per_step

    def set_requires_fp(self, flag):
        # Toggle feature-map recording on every module that supports it.
        for mod in self.modules():
            if hasattr(mod, 'requires_fp'):
                mod.requires_fp = flag

    def set_requires_mem(self, flag):
        # Toggle membrane-potential recording on every module that supports it.
        for mod in self.modules():
            if hasattr(mod, 'requires_mem'):
                mod.requires_mem = flag

    def get_attr(self, attr):
        """
        Collect a named attribute from every module that has it.

        :param attr: attribute name
        :return: list of attribute values
        """
        outputs = []
        for mod in self.modules():
            if hasattr(mod, attr):
                outputs.append(getattr(mod, attr))
        return outputs

    @staticmethod
    def forward(self, inputs):
        # NOTE(review): declared @staticmethod yet takes ``self``; subclasses
        # override this — it likely should be an @abc.abstractmethod instead.
        pass
class DeformConvPack(nn.Module):
    """
    Modulated deformable convolution with a learned offset/mask branch.

    ``conv_offset`` predicts, per output location, 2 offsets plus 1
    modulation scalar for each kernel element (hence ``3 * kh * kw``
    channels).  Offsets are bounded via sigmoid to roughly the kernel's
    receptive field; the mask is squashed into (0, 1).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 padding,
                 stride,
                 bias,
                 *args,
                 **kwargs):
        super(DeformConvPack, self).__init__()
        self.in_channels = in_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        if isinstance(self.kernel_size, tuple) or isinstance(self.kernel_size, list):
            self.receptive_field = self.kernel_size[0]
        else:
            # Normalise a scalar kernel size to a (k, k) tuple.
            self.receptive_field = self.kernel_size
            self.kernel_size = (self.kernel_size, self.kernel_size)
        # Scale factor for the bounded offsets produced in forward().
        self.receptive_field = 4 * (self.receptive_field // 2)
        self.conv_offset = nn.Conv2d(
            self.in_channels,
            3 * self.kernel_size[0] * self.kernel_size[1],
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            bias=True)
        self.deform_conv = DeformConv2d(in_channels=in_channels,
                                        out_channels=out_channels,
                                        kernel_size=kernel_size,
                                        padding=padding,
                                        stride=stride,
                                        bias=bias)
        self.init_weights()

    def init_weights(self):
        """Zero the offset branch so the op starts as a regular convolution."""
        if hasattr(self, 'conv_offset'):
            self.conv_offset.weight.data.zero_()
            self.conv_offset.bias.data.zero_()

    def forward(self, x):
        """Predict offsets/mask from ``x`` and apply the deformable conv."""
        out = self.conv_offset(x)
        o1, o2, mask = torch.chunk(out, 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)
        # Bounded offsets: sigmoid centred at 0, scaled by the receptive field.
        offset = self.receptive_field * (torch.sigmoid(offset) - 0.5)
        mask = torch.sigmoid(mask)
        return self.deform_conv(x, offset, mask)
================================================
FILE: braincog/model_zoo/bdmsnn.py
================================================
import torch
from torch import nn
from braincog.base.node.node import IFNode, SimHHNode
from braincog.base.learningrule.STDP import STDP, MutliInputSTDP
from braincog.base.connection.CustomLinear import CustomLinear
from braincog.base.brainarea.basalganglia import basalganglia
import pygame
from pygame.locals import *
from collections import deque
from random import randint
#os.environ["SDL_VIDEODRIVER"] = "dummy"
class BDMSNN(nn.Module):
def __init__(self, num_state, num_action, weight_exc, weight_inh, node_type):
"""
定义BDM-SNN网络
:param num_state: 状态个数
:param num_action: 动作个数
:param weight_exc: 兴奋性连接权重
:param weight_inh: 抑制性连接权重
"""
super().__init__()
# parameters
BG = basalganglia(num_state, num_action, weight_exc, weight_inh, node_type)
dm_connection = BG.getweight()
dm_mask = BG.getmask()
# input-dlpfc
con_matrix9 = torch.eye((num_state), dtype=torch.float)
dm_connection.append(CustomLinear(weight_exc * con_matrix9, con_matrix9))
dm_mask.append(con_matrix9)
# gpi-th
con_matrix10 = torch.eye((num_action), dtype=torch.float)
dm_mask.append(con_matrix10)
dm_connection.append(CustomLinear(weight_inh * con_matrix10, con_matrix10))
# th-pm
dm_mask.append(con_matrix10)
dm_connection.append(CustomLinear(weight_exc * con_matrix10, con_matrix10))
# dlpfc-th
con_matrix11 = torch.ones((num_state, num_action), dtype=torch.float)
dm_mask.append(con_matrix11)
dm_connection.append(CustomLinear(0.2 * weight_exc * con_matrix11, con_matrix11))
# pm-pm
con_matrix3 = torch.ones((num_action, num_action), dtype=torch.float)
con_matrix4 = torch.eye((num_action), dtype=torch.float)
con_matrix5 = con_matrix3 - con_matrix4
con_matrix5 = con_matrix5
dm_mask.append(con_matrix5)
dm_connection.append(CustomLinear(5 * weight_inh * con_matrix5, con_matrix5))
# dlpfc thalamus pm +bg
self.weight_exc = weight_exc
self.num_subDM = 8
self.connection = dm_connection
self.mask = dm_mask
self.node = BG.node
self.node_type = node_type
if self.node_type == "hh":
self.node.extend([SimHHNode() for i in range(self.num_subDM - BG.num_subBG)])
self.node[6].g_Na = torch.tensor(12)
self.node[6].g_K = torch.tensor(3.6)
self.node[6].g_L = torch.tensor(0.03)
if self.node_type == "lif":
self.node.extend([IFNode() for i in range(self.num_subDM - BG.num_subBG)])
self.learning_rule = BG.learning_rule
self.learning_rule.append(MutliInputSTDP(self.node[5], [self.connection[10], self.connection[12]])) # gpi-丘脑
self.learning_rule.append(MutliInputSTDP(self.node[6], [self.connection[11], self.connection[13]])) # pm
self.learning_rule.append(STDP(self.node[7], self.connection[9]))
out_shape=[self.connection[0].weight.shape[1],self.connection[1].weight.shape[1],self.connection[2].weight.shape[1],self.connection[4].weight.shape[1],self.connection[3].weight.shape[1],self.connection[10].weight.shape[1],self.connection[11].weight.shape[1],self.connection[9].weight.shape[1]]
self.out = []
self.dw = []
for i in range(self.num_subDM):
self.out.append(torch.zeros((out_shape[i]), dtype=torch.float))
self.dw.append(torch.zeros((out_shape[i]), dtype=torch.float))
def forward(self, input):
"""
根据输入得到网络的输出
:param input: 输入
:return: 网络的输出
"""
self.out[7] = self.node[7](self.connection[9](input))
self.out[0], self.dw[0] = self.learning_rule[0](self.out[7])
self.out[1], self.dw[1] = self.learning_rule[1](self.out[7])
self.out[2], self.dw[2] = self.learning_rule[2](self.out[7], self.out[3])
self.out[3], self.dw[3] = self.learning_rule[3](self.out[1], self.out[2])
self.out[4], self.dw[4] = self.learning_rule[4](self.out[0], self.out[3], self.out[2])
self.out[5], self.dw[5] = self.learning_rule[5](self.out[4], self.out[7])
self.out[6], self.dw[6] = self.learning_rule[6](self.out[5], self.out[6])
br = ["StrD1", "StrD2", "STN", "Gpe", "Gpi", "thalamus", "PM", "DLPFC"]
for i in range(self.num_subDM):
if torch.max(self.out[i]) > 0 and self.node_type == "hh":
self.node[i].n_reset()
print("every areas:", br[i], self.out[i])
return self.out[6], self.dw
def UpdateWeight(self, i, s, num_action, dw):
    """
    Update the weights of connection group ``i`` for state ``s``.

    :param i: index of the connection group to update
    :param s: current state (row index into the weight matrix)
    :param num_action: actions per state; columns ``s*num_action`` and
        ``s*num_action + 1`` are the two candidate actions of state ``s``
    :param dw: weight update (e.g. an STDP trace)
    :return: None
    """
    if self.node_type == "hh":
        # Scaled-down update for HH neurons; then normalize the two action
        # weights of state s so their max is 1 (epsilon avoids division by
        # zero) and rescale the whole row to the excitatory weight level.
        self.connection[i].update(0.2 * self.weight_exc * dw)
        self.connection[i].weight.data[s, [s * num_action, s * num_action + 1]] /= (self.connection[i].weight.data[s, [s * num_action, s * num_action + 1]].float().max() + 1e-12)
        self.connection[i].weight.data[s, :] = self.connection[i].weight.data[s, :] * self.weight_exc
    if self.node_type == "lif":
        # Standardize the update over the two candidate actions of state s,
        # mask it to the allowed synapses, apply it, then renormalize.
        dw_mean = dw[s, [s * num_action, s * num_action + 1]].mean()
        dw_std = dw[s, [s * num_action, s * num_action + 1]].std()
        dw[s, [s * num_action, s * num_action + 1]] = (dw[s, [s * num_action,s * num_action + 1]] - dw_mean) / dw_std
        dw[s, :] = dw[s, :] * self.mask[i][s, :]
        self.connection[i].update(dw)
        self.connection[i].weight.data[s, [s * num_action, s * num_action + 1]] /= (self.connection[i].weight.data[s, [s * num_action, s * num_action + 1]].float().max() + 1e-12)
    # Sign constraints: these connection groups are excitatory (clamped >= 0) ...
    if i in [0, 1, 2, 6, 7, 11, 12]:
        self.connection[i].weight.data = torch.clamp(self.connection[i].weight.data, 0, None)
    # ... and these are inhibitory (clamped <= 0).
    if i in [3, 4, 5, 8, 10]:
        self.connection[i].weight.data = torch.clamp(self.connection[i].weight.data, None, 0)
def reset(self):
    """Clear neuron membrane state and all learning-rule traces."""
    for neuron in self.node[:self.num_subDM]:
        neuron.n_reset()
    for rule in self.learning_rule:
        rule.reset()
def getweight(self):
    """
    Return the network's connection list (weights included).

    :return: list of connection modules
    """
    return self.connection
================================================
FILE: braincog/model_zoo/convnet.py
================================================
import abc
from functools import partial
from torch.nn import functional as F
import torchvision
from timm.models import register_model
from braincog.base.node.node import *
from braincog.base.connection.layer import *
from braincog.base.encoder.encoder import *
from braincog.model_zoo.base_module import BaseModule, BaseConvModule, BaseLinearModule
class BaseConvNet(BaseModule, abc.ABC):
    """
    Skeleton for layer-stacked spiking convolutional networks.

    Subclasses implement ``_create_feature`` and ``_create_fc``.

    :param step: number of simulation time steps
    :param input_channels: number of input channels
    :param num_classes: number of output classes
    :param encode_type: input encoding strategy name
    :param spike_output: if True, the head emits spikes through a 10-way voting layer
    :param out_channels: per-stage channel counts (mutated in place: the head size is appended)
    :param block_depth: number of conv modules per stage
    :param node_list: neuron constructors, one per stage (mutated in place when spike_output is False)
    """

    def __init__(self,
                 step,
                 input_channels,
                 num_classes,
                 encode_type,
                 spike_output: bool,
                 out_channels: list,
                 block_depth: list,
                 node_list: list,
                 *args,
                 **kwargs):
        super().__init__(step, encode_type, *args, **kwargs)
        self.num_cls = num_classes
        self.spike_output = spike_output
        self.groups = kwargs['n_groups'] if 'n_groups' in kwargs else 1
        if not spike_output:
            # Analog head: plain identity output layer.
            node_list.append(nn.Identity)
            out_channels.append(self.num_cls)
            self.vote = nn.Identity()
        else:
            # Spiking head: 10 output neurons vote per class.
            out_channels.append(10 * self.num_cls)
            self.vote = VotingLayer(10)
        # Each stage needs exactly one neuron constructor.
        if len(node_list) != len(out_channels):
            raise ValueError
        self.input_channels = input_channels
        self.out_channels = out_channels
        self.block_depth = block_depth
        self.node_list = node_list
        self.feature = self._create_feature()
        self.fc = self._create_fc()
        if self.layer_by_layer:
            self.flatten = nn.Flatten(start_dim=1)
        else:
            self.flatten = nn.Flatten()

    # BUGFIX: these were declared ``@staticmethod`` yet took ``self`` and were
    # invoked as instance methods; since the class already inherits abc.ABC,
    # declare them as proper abstract instance methods instead.
    @abc.abstractmethod
    def _create_feature(self):
        """Build and return the convolutional feature extractor (subclass hook)."""
        raise NotImplementedError

    @abc.abstractmethod
    def _create_fc(self):
        """Build and return the classifier head (subclass hook)."""
        raise NotImplementedError

    def forward(self, inputs):
        """
        Encode the input, then run step-by-step or layer-by-layer.

        :param inputs: input batch (images or event frames)
        :return: class logits averaged over time steps
        """
        inputs = self.encoder(inputs)
        self.reset()
        if not self.training:
            self.fire_rate.clear()
        if not self.layer_by_layer:
            outputs = []
            # Warm-up mode runs a single step only.
            step = 1 if self.warm_up else self.step
            for t in range(step):
                x = inputs[t]
                x = self.feature(x)
                x = self.flatten(x)
                x = self.fc(x)
                x = self.vote(x)
                outputs.append(x)
            return sum(outputs) / len(outputs)
        else:
            x = self.feature(inputs)
            x = self.flatten(x)
            x = self.fc(x)
            # Time steps are folded into the batch (or channel) dim; average them out.
            if self.groups == 1:
                x = rearrange(x, '(t b) c -> t b c', t=self.step).mean(0)
            else:
                x = rearrange(x, 'b (c t) -> t b c', t=self.step).mean(0)
            x = self.vote(x)
            return x
class MNISTConvNet(BaseConvNet):
    """
    Spiking conv net for 28x28 MNIST-style input.

    :param block_depth: number of conv modules in each feature stage
    (remaining parameters are forwarded to :class:`BaseConvNet`)
    """

    def __init__(self,
                 step,
                 input_channels,
                 num_classes,
                 encode_type,
                 block_depth,
                 spike_output: bool,
                 out_channels: list,
                 node_list: list,
                 *args,
                 **kwargs):
        # Spatial size of the feature map; halved after each pooling stage.
        self.feature_size = 28
        super().__init__(step,
                         input_channels,
                         num_classes,
                         encode_type,
                         spike_output,
                         out_channels,
                         block_depth,
                         node_list,
                         *args,
                         **kwargs)

    def _create_feature(self):
        """Build the conv stages; the last two node_list entries belong to the fc head."""
        feature_depth = len(self.node_list) - 2
        feature = [BaseConvModule(
            self.input_channels, self.out_channels[0], node=self.node_list[0])]
        if self.block_depth[0] != 1:
            # BUGFIX: the old ``[module] * k`` repeated ONE module instance,
            # silently sharing weights; create distinct modules instead.
            feature.extend(
                BaseConvModule(self.out_channels[0], self.out_channels[0], node=self.node_list[0])
                for _ in range(self.block_depth[0] - 1)
            )
        feature.append(nn.AvgPool2d(2))
        self.feature_size = self.feature_size // 2
        for i in range(1, feature_depth):
            feature.append(BaseConvModule(
                self.out_channels[i - 1], self.out_channels[i], node=self.node_list[i]))
            if self.block_depth[i] != 1:
                # BUGFIX: the repeat count used block_depth[0] instead of
                # block_depth[i]; also avoid module aliasing (see above).
                feature.extend(
                    BaseConvModule(self.out_channels[i], self.out_channels[i], node=self.node_list[i])
                    for _ in range(self.block_depth[i] - 1)
                )
            feature.append(nn.AvgPool2d(2))
            feature.append(self.node_list[0]())
            self.feature_size = self.feature_size // 2
        return nn.Sequential(*feature)

    def _create_fc(self):
        """Two-layer spiking classifier with dropout."""
        fc = nn.Sequential(
            NDropout(.5),
            BaseLinearModule(self.out_channels[-3] * self.feature_size * self.feature_size, self.out_channels[-2],
                             node=self.node_list[-2]),
            NDropout(.5),
            BaseLinearModule(
                self.out_channels[-2], self.out_channels[-1], node=self.node_list[-1])
        )
        return fc
class CifarConvNet(BaseConvNet):
    """
    Spiking conv net for CIFAR-style (and DVS) input; global-average-pooled head.

    Parameters are forwarded to :class:`BaseConvNet`.
    """

    def __init__(self,
                 step,
                 input_channels,
                 num_classes,
                 encode_type,
                 spike_output: bool,
                 out_channels: list,
                 node_list: list,
                 block_depth: list,
                 *args,
                 **kwargs):
        super().__init__(step,
                         input_channels,
                         num_classes,
                         encode_type,
                         spike_output,
                         out_channels,
                         block_depth,
                         node_list,
                         *args,
                         **kwargs)

    def _create_feature(self):
        """Build the conv stages; the last node_list entry belongs to the fc head."""
        feature_depth = len(self.node_list) - 1
        feature = [BaseConvModule(
            self.input_channels * self.init_channel_mul, self.out_channels[0], node=self.node_list[0], groups=self.groups)]
        if self.block_depth[0] != 1:
            # BUGFIX: ``[module] * k`` repeated ONE module instance, silently
            # sharing weights between the repeats; build distinct modules.
            feature.extend(
                BaseConvModule(self.out_channels[0], self.out_channels[0], node=self.node_list[0], groups=self.groups)
                for _ in range(self.block_depth[0] - 1)
            )
        feature.append(nn.AvgPool2d(2))
        for i in range(1, feature_depth - 1):
            feature.append(BaseConvModule(
                self.out_channels[i - 1], self.out_channels[i], node=self.node_list[i], groups=self.groups))
            if self.block_depth[i] != 1:
                # BUGFIX: distinct modules instead of aliased repeats.
                feature.extend(
                    BaseConvModule(self.out_channels[i], self.out_channels[i], node=self.node_list[i], groups=self.groups)
                    for _ in range(self.block_depth[i] - 1)
                )
            feature.append(nn.AvgPool2d(2))
        feature.append(BaseConvModule(
            self.out_channels[-3], self.out_channels[-2], node=self.node_list[-2], groups=self.groups))
        if self.block_depth[feature_depth - 1] != 1:
            # BUGFIX: distinct modules instead of aliased repeats.
            feature.extend(
                BaseConvModule(self.out_channels[-2], self.out_channels[-2], node=self.node_list[-2], groups=self.groups)
                for _ in range(self.block_depth[feature_depth - 1] - 1)
            )
        feature.append(nn.AdaptiveAvgPool2d((1, 1)))
        return nn.Sequential(*feature)

    def _create_fc(self):
        """Single spiking linear classifier on the pooled features."""
        fc = nn.Sequential(
            BaseLinearModule(
                self.out_channels[-2], self.out_channels[-1], node=self.node_list[-1], groups=self.groups)
        )
        return fc
@register_model
def mnist_convnet(step,
                  encode_type,
                  spike_output: bool,
                  node_type,
                  *args,
                  **kwargs):
    """Factory for the MNIST spiking conv net (registered with timm)."""
    out_channels = [128, 128, 2048]
    block_depth = [1, 1]
    node_cls = partial(node_type, step=step, **kwargs)
    # A spiking head needs one extra neuron layer for the voting output.
    n_nodes = len(out_channels) + 1 if spike_output else len(out_channels)
    node_list = [node_cls for _ in range(n_nodes)]
    return MNISTConvNet(step=step,
                        input_channels=1,
                        encode_type=encode_type,
                        block_depth=block_depth,
                        node_list=node_list,
                        out_channels=out_channels,
                        spike_output=spike_output,
                        **kwargs)
@register_model
def cifar_convnet(step,
                  encode_type,
                  spike_output: bool,
                  node_type,
                  *args,
                  **kwargs):
    """Factory for the CIFAR spiking conv net (registered with timm)."""
    out_channels = [256, 256, 512, 1024]
    block_depth = [2, 2, 2, 2]
    node_cls = partial(node_type, step=step, **kwargs)
    # A spiking head needs one extra neuron layer for the voting output.
    n_nodes = len(out_channels) + 1 if spike_output else len(out_channels)
    node_list = [node_cls for _ in range(n_nodes)]
    return CifarConvNet(step=step,
                        input_channels=3,
                        encode_type=encode_type,
                        node_list=node_list,
                        block_depth=block_depth,
                        out_channels=out_channels,
                        spike_output=spike_output,
                        **kwargs)
@register_model
def dvs_convnet(step,
                encode_type,
                spike_output: bool,
                node_type,
                num_classes,
                *args,
                **kwargs):
    """Factory for the DVS (event-camera) spiking conv net (registered with timm)."""
    out_channels = [128, 256, 256, 512, 512]
    block_depth = [2, 1, 2, 1, 2]
    node_cls = partial(node_type, step=step, **kwargs)
    # A spiking head needs one extra neuron layer for the voting output.
    n_nodes = len(out_channels) + 1 if spike_output else len(out_channels)
    node_list = [node_cls for _ in range(n_nodes)]
    return CifarConvNet(step=step,
                        input_channels=2,
                        num_classes=num_classes,
                        encode_type=encode_type,
                        node_list=node_list,
                        block_depth=block_depth,
                        out_channels=out_channels,
                        spike_output=spike_output,
                        **kwargs)
================================================
FILE: braincog/model_zoo/fc_snn.py
================================================
from functools import partial
from torch.nn import functional as F
import torchvision
from timm.models import register_model
from braincog.base.node.node import *
from braincog.base.connection.layer import *
from braincog.base.encoder.encoder import *
from braincog.model_zoo.base_module import BaseModule, BaseConvModule, BaseLinearModule
from braincog.datasets import is_dvs_data
class STSC_Attention(nn.Module):
    """
    Spatio-temporal synaptic attention gate.

    Produces a gating factor ``D = 1 - sigmoid(f(x))`` per time step and
    channel, where ``f`` is a temporal Conv1d followed by a ReLU and a
    channel-mixing linear layer.
    """

    def __init__(self, n_channel: int, dimension: int = 2, time_rf: int = 4, reduction: int = 2):
        super().__init__()
        assert dimension in (2, 4), 'dimension must be 4 or 2'
        self.dimension = dimension
        if self.dimension == 4:
            # Spatial dims are squeezed by global average pooling first.
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.time_padding = (time_rf - 1) // 2
        self.n_channels = n_channel
        squeezed = n_channel // reduction
        self.recv_T = nn.Conv1d(n_channel, squeezed, kernel_size=time_rf,
                                padding=self.time_padding, groups=1, bias=True)
        self.recv_C = nn.Sequential(
            nn.ReLU(),
            nn.Linear(squeezed, n_channel, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x_seq: torch.Tensor):
        """Return the attention gate for ``x_seq`` of shape [T, B, N] or [T, B, C, H, W]."""
        assert x_seq.dim() in (3, 5), ValueError(
            f'expected 3D or 5D input with shape [T, B, N] or [T, B, C, H, W], but got input with shape {x_seq.shape}')
        batch_first = x_seq.transpose(0, 1)           # [B, T, ...]
        channel_first = batch_first.transpose(1, 2)   # [B, N, T] or [B, C, T, H, W]
        if self.dimension == 2:
            h = self.recv_T(channel_first)
            h = self.recv_C(h.transpose(1, 2))
            D = (1 - self.sigmoid(h)).transpose(0, 1)
        elif self.dimension == 4:
            # Collapse H, W by average pooling before the temporal conv.
            pooled = self.avg_pool(batch_first).view(
                [batch_first.shape[0], batch_first.shape[1], batch_first.shape[2]])
            h = self.recv_T(pooled.transpose(1, 2))
            h = self.recv_C(h.transpose(1, 2))
            D = (1 - self.sigmoid(h)).transpose(0, 1)
        return D
class STSC_Temporal_Conv(nn.Module):
    """
    Depthwise convolution along the time axis (one kernel per channel).

    Accepts [T, B, N] (dimension=2, Conv1d) or [T, B, C, H, W]
    (dimension=4, Conv3d convolving over time only).
    """

    def __init__(self, channels: int, dimension: int = 2, time_rf: int = 2):
        super().__init__()
        assert dimension in (2, 4), 'dimension must be 4 or 2'
        self.dimension = dimension
        pad = (time_rf - 1) // 2
        self.time_padding = pad
        if dimension == 4:
            # One 3-D kernel per channel; spatial extent of the kernel is 1x1.
            self.conv = nn.Conv3d(channels, channels, kernel_size=(time_rf, 1, 1),
                                  padding=(pad, 0, 0), groups=channels, bias=False)
        else:
            self.conv = nn.Conv1d(channels, channels, kernel_size=time_rf,
                                  padding=pad, groups=channels, bias=False)

    def forward(self, x_seq: torch.Tensor):
        """Filter ``x_seq`` along time and return a tensor of the same layout."""
        assert x_seq.dim() in (3, 5), ValueError(
            f'expected 3D or 5D input with shape [T, B, N] or [T, B, C, H, W], but got input with shape {x_seq.shape}')
        # Bring channels before time, convolve, then restore [T, B, ...].
        y = x_seq.transpose(0, 1).transpose(1, 2)
        y = self.conv(y)
        return y.transpose(1, 2).transpose(0, 1)
class STSC(nn.Module):
    """
    Spatio-Temporal Synaptic Connection block.

    Optionally filters the input along time (``use_filter``) and optionally
    multiplies it by an attention gate computed from the raw input
    (``use_gate``); with both flags off it is an identity.
    """

    def __init__(self, in_channel: int, dimension: int = 2, time_rf_conv: int = 5, time_rf_at: int = 3, use_gate=True,
                 use_filter=True, reduction: int = 1):
        super().__init__()
        assert dimension in (2, 4), 'dimension must be 4 or 2'
        self.dimension = dimension
        self.time_rf_conv = time_rf_conv
        self.time_rf_at = time_rf_at
        if use_filter:
            self.temporal_conv = STSC_Temporal_Conv(in_channel, time_rf=time_rf_conv, dimension=dimension)
        if use_gate:
            self.spatio_temporal_attention = STSC_Attention(in_channel, time_rf=time_rf_at, reduction=reduction,
                                                            dimension=dimension)
        self.use_gate = use_gate
        self.use_filter = use_filter

    def forward(self, x_seq: torch.Tensor):
        """Apply filtering and/or gating to ``x_seq`` ([T, B, N] or [T, B, C, H, W])."""
        assert x_seq.dim() in (3, 5), ValueError(
            f'expected 3D or 5D input with shape [T, B, N] or [T, B, C, H, W], but got input with shape {x_seq.shape}')
        filtered = self.temporal_conv(x_seq) if self.use_filter else x_seq
        if not self.use_gate:
            return filtered
        # The gate is always computed from the unfiltered input.
        gate = self.spatio_temporal_attention(x_seq)
        if self.dimension == 2:
            return filtered * gate
        # 5-D case: broadcast the [T, B, C] gate over the spatial dims.
        return filtered * gate[:, :, :, None, None]
@register_model
class SHD_SNN(BaseModule):
    """
    Baseline SNN for the SHD dataset: Input-128FC-128FC-100FC-Voting-20.
    STSC is a module that enhances temporal information; see
    https://www.frontiersin.org/articles/10.3389/fnins.2022.1079357.
    Without the STSC module, accuracy is around 78%.
    """
    def __init__(self,
                 num_classes=20,
                 step=15,
                 node_type=LIFNode,
                 encode_type='direct',
                 *args,
                 **kwargs):
        super().__init__(step, encode_type, *args, **kwargs)
        self.n_preact = kwargs['n_preact'] if 'n_preact' in kwargs else False
        self.num_classes = num_classes
        self.tet_loss = kwargs['tet_loss'] if 'tet_loss' in kwargs else False
        self.node = node_type
        if issubclass(self.node, BaseNode):
            # Bind simulation kwargs and step count into the neuron constructor.
            self.node = partial(self.node, **kwargs, step=step)
        self.dataset = kwargs['dataset']
        # 700 input channels = SHD cochlea channels.
        self.ts_conv = STSC(700, dimension=2, time_rf_conv=5, time_rf_at=3, use_gate=True, use_filter=True)
        # NOTE(review): self.node may already be a partial with **kwargs bound;
        # wrapping it again re-binds the same kwargs — confirm this is intended.
        self.fc = nn.Sequential(
            nn.Linear(700, 128),
            partial(self.node, **kwargs)(),
            nn.Linear(128, 128),
            partial(self.node, **kwargs)(),
            nn.Linear(128, 100),
            partial(self.node, **kwargs)(),
            VotingLayer(5)
        )
    def forward(self, inputs):
        """
        :param inputs: encoded input batch
        :return: class logits averaged over the time dimension
        """
        inputs = self.encoder(inputs)
        self.reset()
        if self.layer_by_layer:
            # Unfold time out of the batch dim so STSC can filter along T,
            # then fold it back for the fully connected stack.
            inputs = rearrange(inputs, '(t b) c -> t b c', t=self.step)
            inputs = self.ts_conv(inputs)
            x = rearrange(inputs, 't b c -> (t b) c', t=self.step)
            x = self.fc(x)
            x = rearrange(x, '(t b) c -> t b c', t=self.step).mean(0)
            return x
        else:
            outputs = []
            inputs = self.ts_conv(inputs)
            for t in range(self.step):
                x = inputs[t]
                x = self.fc(x)
                outputs.append(x)
            return sum(outputs) / len(outputs)
================================================
FILE: braincog/model_zoo/glsnn.py
================================================
import abc
from functools import partial
from timm.models import register_model
from braincog.base.node.node import *
from braincog.base.connection.layer import *
from braincog.base.encoder.encoder import *
from braincog.model_zoo.base_module import BaseModule, BaseLinearModule, BaseConvModule
from braincog.utils import rand_ortho, mse
from torch import autograd
class BaseGLSNN(BaseModule):
    """
    The fully connected model of the GLSNN
    :param input_size: the shape of the input
    :param hidden_sizes: list, the number of neurons of each layer in the hidden layers
    :param output_size: the number of the output layers
    :param opt: config object; must provide step, encode_type and lr_target
    """
    def __init__(self, input_size=784, hidden_sizes=[800] * 3, output_size=10, opt=None):
        # NOTE(review): the default opt=None crashes on opt.step below — a
        # populated config object is effectively required.
        super().__init__(step=opt.step, encode_type=opt.encode_type)
        network_sizes = [input_size] + hidden_sizes + [output_size]
        feedforward = []
        for ind in range(len(network_sizes) - 1):
            feedforward.append(
                BaseLinearModule(in_features=network_sizes[ind], out_features=network_sizes[ind + 1], node=LIFNode))
        self.ff = nn.ModuleList(feedforward)
        # Feedback connections project the output error to hidden layers
        # 1 .. n-2; the last hidden layer gets its target via autograd instead.
        feedback = []
        for ind in range(1, len(network_sizes) - 2):
            feedback.append(nn.Linear(network_sizes[-1], network_sizes[ind]))
        self.fb = nn.ModuleList(feedback)
        # Random-orthogonal initialization for every linear layer.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                out_, in_ = m.weight.shape
                m.weight.data = torch.Tensor(rand_ortho((out_, in_), np.sqrt(6. / (out_ + in_))))
                m.bias.data.zero_()
        self.step = opt.step
        self.lr_target = opt.lr_target
    def forward(self, x):
        """
        process the information in the forward manner
        :param x: the input
        :return: list of per-layer rate-coded activations (spike counts
                 averaged over self.step); element 0 is the input itself
        """
        self.reset()
        x = x.view(x.shape[0], 784)
        sumspikes = [0] * (len(self.ff) + 1)
        sumspikes[0] = x
        for ind, mod in enumerate(self.ff):
            for t in range(self.step):
                spike = mod(sumspikes[ind])
                sumspikes[ind + 1] += spike
            sumspikes[ind + 1] = sumspikes[ind + 1] / self.step
        return sumspikes
    def feedback(self, ff_value, y_label):
        """
        process information in the feedback manner and get target
        :param ff_value: the feedforward value of each layer
        :param y_label: the label of the corresponding input
        :return: (list of layer targets, deepest layer first; output cost)
        """
        fb_value = []
        cost = mse(ff_value[-1], y_label)
        P = ff_value[-1]
        # Target for the last hidden layer: one gradient step on the cost.
        h_ = ff_value[-2] - self.lr_target * torch.autograd.grad(cost, ff_value[-2], retain_graph=True)[0]
        fb_value.append(h_)
        # Remaining targets come from direct feedback of the output error.
        for i in range(len(self.fb) - 1, -1, -1):
            h = ff_value[i + 1]
            h_ = h - self.fb[i](P - y_label)
            fb_value.append(h_)
        return fb_value, cost
    def set_gradient(self, x, y):
        """
        get the corresponding update of each layer
        """
        ff_value = self.forward(x)
        fb_value, cost = self.feedback(ff_value, y)
        ff_value = ff_value[1:]
        len_ff = len(self.ff)
        for idx, layer in enumerate(self.ff):
            if idx == len_ff - 1:
                # The output layer is trained on the global cost directly.
                layer.fc.weight.grad, layer.fc.bias.grad = autograd.grad(cost, layer.fc.parameters())
            else:
                # Hidden layers minimize a local MSE to their feedback target;
                # fb_value is stored deepest-first, hence the reversed index.
                in1 = ff_value[idx]
                in2 = fb_value[len(fb_value) - 1 - idx]
                loss_local = mse(in1, in2.detach())
                layer.fc.weight.grad, layer.fc.bias.grad = autograd.grad(loss_local, layer.fc.parameters())
        return ff_value, cost
    def forward_parameters(self):
        """Return all parameters of the feedforward pathway."""
        res = []
        for layer in self.ff:
            res += layer.parameters()
        return res
    def feedback_parameters(self):
        """Return all parameters of the feedback pathway."""
        res = []
        for layer in self.fb:
            res += layer.parameters()
        return res
if __name__ == '__main__':
    # NOTE(review): BaseGLSNN's ``opt`` defaults to None and __init__ reads
    # ``opt.step``, so this smoke test raises AttributeError as written —
    # confirm it is meant to run with a populated opt/config object.
    net = BaseGLSNN()
    print(net)
================================================
FILE: braincog/model_zoo/linearNet.py
================================================
import torch.nn.functional as F
from braincog.base.strategy.surrogate import *
from braincog.base.node.node import IFNode
from braincog.base.learningrule.STDP import STDP, MutliInputSTDP
class droDMTrainNet(nn.Module):
    """
    Drosophila Training network: compound eye-KC-MBON
    """
    def __init__(self, connection):
        """
        Build the training network from the given connections.
        :param connection: list of connection modules of the network
        """
        super().__init__()
        trace_stdp = 0.99  # STDP trace decay constant
        self.num_subMB = 3  # neuron groups: visual, KC, MBON
        self.node = [IFNode() for i in range(self.num_subMB)]
        self.connection = connection
        self.learning_rule = []
        # One plasticity rule per stage; MBON receives KC input plus its own
        # recurrent output, hence the multi-input STDP rule.
        self.learning_rule.append(STDP(self.node[0], self.connection[0], trace_stdp))
        self.learning_rule.append(STDP(self.node[1], self.connection[1], trace_stdp))
        self.learning_rule.append(MutliInputSTDP(self.node[2], [self.connection[2], self.connection[3]], trace_stdp))
        # Output buffers, sized by the fan-out of each connection.
        self.out_vis = torch.zeros((self.connection[0].weight.shape[1]), dtype=torch.float)
        self.out_KC = torch.zeros((self.connection[1].weight.shape[1]), dtype=torch.float)
        self.out_MBON = torch.zeros((self.connection[2].weight.shape[1]), dtype=torch.float)
    def forward(self, input):
        """
        Compute the network output for one input step.
        :param input: input current
        :return: MBON output, plus the STDP updates of the vis->KC and
                 KC->MBON connections
        """
        # NOTE(review): learning_rule[0] is never invoked; the visual stage is
        # driven directly through node[0] — confirm this is intentional.
        self.out_vis = self.node[0](self.connection[0](input))
        self.out_KC, dw_kc = self.learning_rule[1](self.out_vis)
        self.out_MBON, dw_mbon = self.learning_rule[2](self.out_KC, self.out_MBON)
        return self.out_MBON, dw_kc[0], dw_mbon[0]
    def UpdateWeight(self, i, dw):
        """
        Update the weights of connection group ``i``.
        :param i: index of the connection to update
        :param dw: weight increment
        :return: None
        """
        self.connection[i].update(dw)
        # L1-normalize each row to keep the total synaptic drive bounded.
        self.connection[i].weight.data = F.normalize(self.connection[i].weight.data.float(), p=1, dim=1)
    def reset(self):
        """
        Reset neuron state and learning-rule traces.
        :return: None
        """
        for i in range(self.num_subMB):
            self.node[i].n_reset()
        for i in range(len(self.learning_rule)):
            self.learning_rule[i].reset()
    def getweight(self):
        """
        Return the network's connections (weights included).
        :return: list of connection modules
        """
        return self.connection
================================================
FILE: braincog/model_zoo/nonlinearNet.py
================================================
import torch.nn.functional as F
from braincog.base.strategy.surrogate import *
from braincog.base.node.node import IFNode
from braincog.base.learningrule.STDP import STDP, MutliInputSTDP
class droDMTestNet(nn.Module):
    """
    Drosophila Testing Network: compound eye-KC-MBON DA-GABA-MB
    """
    def __init__(self, connection):
        """
        Build the testing network from the given connections.
        :param connection: list of connection modules of the network
        """
        super().__init__()
        trace_stdp = 0.99  # STDP trace decay constant
        self.num_subMB = 5  # neuron groups: visual, KC, MBON, APL, DA
        self.node = [IFNode() for i in range(self.num_subMB)]
        self.connection = connection
        self.learning_rule = []
        # Multi-input STDP rules wire each area to all of its afferents.
        self.learning_rule.append(STDP(self.node[0], self.connection[0], trace_stdp))
        self.learning_rule.append(MutliInputSTDP(self.node[1], [self.connection[1], self.connection[5]], trace_stdp))
        self.learning_rule.append(MutliInputSTDP(self.node[2], [self.connection[2], self.connection[3], self.connection[9]], trace_stdp))
        self.learning_rule.append(MutliInputSTDP(self.node[3], [self.connection[4], self.connection[6]], trace_stdp))
        self.learning_rule.append(MutliInputSTDP(self.node[4], [self.connection[7], self.connection[8]], trace_stdp))
        # Output buffers, sized by the fan-out of each connection.
        self.out_vis = torch.zeros((self.connection[0].weight.shape[1]), dtype=torch.float)
        self.out_KC = torch.zeros((self.connection[1].weight.shape[1]), dtype=torch.float)
        self.out_MBON = torch.zeros((self.connection[2].weight.shape[1]), dtype=torch.float)
        self.out_APL = torch.zeros((self.connection[4].weight.shape[1]), dtype=torch.float)
        self.out_DA = torch.zeros((self.connection[7].weight.shape[1]), dtype=torch.float)
    def forward(self, input, input_da):
        """
        Compute the network output for one input step.
        :param input: input current to the visual stage
        :param input_da: external drive to the dopaminergic (DA) stage
        :return: MBON output, plus the STDP updates of the APL->KC and
                 KC->APL connections
        """
        # Each area reads the other areas' previous-step outputs, giving the
        # recurrent loops (KC<->APL, DA feedback) a one-step delay.
        self.out_vis = self.node[0](self.connection[0](input))
        self.out_KC, dw_kc = self.learning_rule[1](self.out_vis, self.out_APL)
        self.out_MBON, dw_mbon = self.learning_rule[2](self.out_KC, self.out_MBON, self.out_DA)
        self.out_APL, dw_apl = self.learning_rule[3](self.out_KC, self.out_DA)
        self.out_DA, dw_da = self.learning_rule[4](self.out_APL, input_da)
        return self.out_MBON, dw_kc[1], dw_apl[0]
    def UpdateWeight(self, i, dw):
        """
        Update the weights of connection group ``i``.
        :param i: index of the connection to update
        :param dw: weight increment
        :return: None
        """
        self.connection[i].update(dw)
        # L1-normalize each column to keep the total synaptic drive bounded.
        self.connection[i].weight.data = F.normalize(self.connection[i].weight.data.float(), p=1, dim=0)
    def reset(self):
        """
        Reset neuron state and learning-rule traces.
        :return: None
        """
        for i in range(self.num_subMB):
            self.node[i].n_reset()
        for i in range(len(self.learning_rule)):
            self.learning_rule[i].reset()
    def getweight(self):
        """
        Return the network's connections (weights included).
        :return: list of connection modules
        """
        return self.connection
================================================
FILE: braincog/model_zoo/qsnn.py
================================================
import numpy as np
from scipy.linalg import orth
from scipy.special import expit
from scipy.signal import fftconvolve
import torch
from torch.nn import Parameter
import torch.nn as nn
from braincog.datasets.gen_input_signal import lambda_max, dt
from braincog.base.encoder import QSEncoder
gamma = 0.1  # soft-plus gain (used only by the commented-out rate function below)
beta = 1.0  # soft-plus slope (unused by the active torch.sigmoid rate)
theta = 3.0  # soft-plus threshold (unused by the active torch.sigmoid rate)
# kernel parameters
tau_s = 4.0  # synaptic time constant
tau_L = 10.0  # leak time constant
# conductance parameters
g_B = 0.6  # basal conductance
g_A = 0.05  # apical conductance
g_L = 1.0 / tau_L  # leak conductance
g_D = g_B  # dendritic conductance in output layer
k_D = g_D / (g_L + g_D)  # dendrite-to-soma transfer ratio
STEPS = int(50 / dt)  # simulation steps per 50 ms window (dt from gen_input_signal)
SLEN = 20  # spike time length
# --- sigmoid function --- #
def sigma(x):
    """Logistic sigmoid rate function shared by both layer types."""
    return torch.sigmoid(x)
# def sigma(x):
# return gamma * np.log(1+np.exp(beta*(x-theta)))
def deriv_sigma(x):
    """Derivative of the sigmoid, sigma'(x) = sigma(x) * (1 - sigma(x)).

    Computes sigma(x) once instead of twice (the original evaluated it for
    each factor).
    """
    s = sigma(x)
    return s * (1.0 - s)
# kernel parameters
tau_s = 4.0  # synaptic time constant (NOTE(review): re-declares the same value defined above)
tau_L = 10.0  # leak time constant (duplicate of the earlier definition)
# --- kernel function --- #
mem = STEPS  # kernel memory length, in simulation steps
def kappa(x):
    """Exponentially decaying synaptic kernel at time offset ``x`` (uses module-level ``tau_s``)."""
    return np.exp(-x / tau_s)
def get_kappas(n):
    """Return the kernel samples kappa(1), ..., kappa(n) as a NumPy array."""
    return np.array([kappa(i) for i in range(1, n + 1)])
# Bi-phasic PSP kernel: a decaying exponential over the first half of the
# window, followed by its negated time-mirror over the second half.
kappas = get_kappas(mem // 2)  # initialize kappas array
kernel = np.zeros(mem)
kernel[:mem // 2] = kappas[:]
kernel[mem // 2:] = -np.flipud(kappas)[:]
W_MIN = -1.0  # weight lower bound
W_MAX = 1.0  # weight upper bound
class Net(nn.Module):
    """
    Two-compartment spiking neural network (basal/apical dendrite model).
    """
    def __init__(self, net_size):
        # net_size = [input, hidden..., output]; every hidden layer also
        # knows the output size for top-down feedback.
        super().__init__()
        self.input_size = net_size[0]
        self.hidden_layers = nn.ModuleList([Hidden_layer(net_size[i], net_size[i + 1], net_size[-1]) for i in range(len(net_size) - 2)])
        self.out_layer = Output_layer(net_size[-2], net_size[-1])
        # NOTE(review): .cuda() requires a GPU at construction time — confirm.
        self.kernel = torch.from_numpy(kernel[:, np.newaxis]).cuda()
        self.qs_code = QSEncoder
    def update_state(self, input_, label, test):
        # One simulation step: hidden layers in order, each receiving the
        # previous layer's rate plus the output layer's rate as feedback,
        # then the output layer itself.
        if len(self.hidden_layers) > 1:
            self.hidden_layers[0].update_state(input_, self.out_layer.spike_rate, test=test)
            for i in range(len(self.hidden_layers) - 2):
                self.hidden_layers[i + 1].update_state(self.hidden_layers[i].spike_rate, self.out_layer.spike_rate, test=test)
            self.hidden_layers[-1].update_state(self.hidden_layers[-2].spike_rate, self.out_layer.spike_rate, test=test)
        else:
            self.hidden_layers[0].update_state(input_, self.out_layer.spike_rate, test=test)
        self.out_layer.update_state(self.hidden_layers[-1].spike_rate, label, test=test)
    def routine(self,
                input_,
                input_delta,
                image_ori,
                image_ori_delta,
                shift,
                label,
                test=False,
                noise=False,
                noise_rate=None):
        """
        Run the network's information-processing loop on one sample.
        :param input_: input image
        :param input_delta: perturbed input image, used to compute phase
        :param image_ori: original image
        :param image_ori_delta: original perturbed image
        :param shift: whether to invert the background
        :param label: classification label of the input
        :param test: whether this is the test phase
        :param noise: whether to add noise
        :param noise_rate: noise ratio
        """
        encoder = self.qs_code(lambda_max, STEPS, SLEN, shift, noise, noise_rate)
        input_ = encoder(input_, input_delta, image_ori, image_ori_delta)
        input_ = torch.from_numpy(input_).to(self.kernel.device)
        # Convolve the encoded spikes with the PSP kernel to get potentials.
        psp = torch.mm(input_, self.kernel).abs().float()
        for i in range(STEPS):
            self.update_state(psp, label, test=test)
    def update_weight(self, lr, t, beta, eps):
        """Adam-style update: output layer first, then hidden layers deepest-first."""
        self.out_layer.update_weight(lr, t, beta, eps)
        if len(self.hidden_layers) > 1:
            self.hidden_layers[-1].update_weight(self.out_layer.delta, lr, t, beta, eps)
            for i in range(len(self.hidden_layers) - 1):
                self.hidden_layers[-(i + 2)].update_weight(self.hidden_layers[-(i + 1)].delta, lr, t, beta, eps)
        else:
            self.hidden_layers[0].update_weight(self.out_layer.delta, lr, t, beta, eps)
    def predict(self,
                input_,
                input_delta,
                image_ori,
                image_ori_delta,
                shift,
                noise,
                noise_rate=0):
        """Run a test-phase pass and return the argmax class of the output rate."""
        self.routine(input_,
                     input_delta,
                     image_ori=image_ori,
                     image_ori_delta=image_ori_delta,
                     shift=shift,
                     label=None,
                     test=True,
                     noise=noise,
                     noise_rate=noise_rate)
        pred = torch.argmax(self.out_layer.spike_rate.flatten())
        return pred
class Hidden_layer(nn.Module):
    """
    Hidden two-compartment layer: basal (feedforward) input plus apical feedback.
    """
    def __init__(self, input_size, neu_num, fb_neus):
        # fb_neus is the feedback (output-layer) size; NOTE(review): it is
        # currently unused in this implementation — confirm.
        super().__init__()
        self.basal_linear = nn.Linear(input_size, neu_num)
        nn.init.uniform_(self.basal_linear.weight, -0.1, 0.1)
        nn.init.uniform_(self.basal_linear.bias, -0.1, 0.1)
        # Compartment potentials.
        self.soma_V = 0.0
        self.basal_V = 0.0
        # Adam moment estimates for the weights (m, v) ...
        self.m = 0.0
        self.v = 0.0
        self.m_hat = 0.0
        self.v_hat = 0.0
        # ... and for the biases (m_b, v_b).
        self.m_b = 0.0
        self.v_b = 0.0
        self.m_b_hat = 0.0
        self.v_b_hat = 0.0
        # backprop
        self.delta = 0.0
    def update_state(self, basal_input, apical_input, test):
        # NOTE(review): apical_input and test are unused here — confirm.
        self.basal_input = basal_input.T  # [1, 781]
        self.basal_V = self.basal_linear(basal_input.T)
        # Leaky integration of the soma toward the basal drive.
        self.soma_V = self.soma_V + 1 / tau_L * (-self.soma_V + g_B / g_L * (self.basal_V - self.soma_V)) * dt
        self.spike_rate = lambda_max * sigma(self.soma_V)
    def update_weight(self, delta_, lr, t, beta, eps):
        """Adam-style update driven by the backpropagated error ``delta_``."""
        weight_dot = lambda_max * k_D * delta_ * deriv_sigma(k_D * self.basal_V)  # [1, 500]
        # Error to pass down to the previous layer.
        self.delta = torch.mm(weight_dot, self.basal_linear.weight.data)  # [500, 784] x [1, 500]
        # Outer product of the local error with the stored input.
        weight_delta = weight_dot[:, :, None] * self.basal_input[:, None, :]
        bias_delta = weight_dot
        # Adam first/second moments with bias correction (weights ...).
        self.m = beta[0] * self.m + (1 - beta[0]) * weight_delta
        self.v = beta[1] * self.v + (1 - beta[1]) * torch.square(weight_delta)
        self.m_hat = self.m / (1 - beta[0] ** t)
        self.v_hat = self.v / (1 - beta[1] ** t)
        # (... and biases).
        self.m_b = beta[0] * self.m_b + (1 - beta[0]) * bias_delta
        self.v_b = beta[1] * self.v_b + (1 - beta[1]) * torch.square(bias_delta)
        self.m_b_hat = self.m_b / (1 - beta[0] ** t)
        self.v_b_hat = self.v_b / (1 - beta[1] ** t)
        # update weight
        weight_delta = lr * self.m_hat / (torch.sqrt(self.v_hat) + eps)
        bias_delta = lr * self.m_b_hat / (torch.sqrt(self.v_b_hat) + eps)
        self.basal_linear.weight.data.sub_(weight_delta.mean(0))
        self.basal_linear.bias.data.sub_(bias_delta.mean(0))
class Output_layer(nn.Module):
    """
    Output two-compartment layer; during training the soma is nudged toward
    the teaching current ``I``.
    """
    def __init__(self, input_size, neu_num):
        super().__init__()
        self.basal_linear = nn.Linear(input_size, neu_num)
        nn.init.uniform_(self.basal_linear.weight, -0.1, 0.1)
        nn.init.uniform_(self.basal_linear.bias, -0.1, 0.1)
        # Compartment potentials and firing rate.
        self.soma_V = 0.0
        self.basal_V = 0.0
        self.spike_rate = 0.0
        # Adam moment estimates for the weights (m, v) ...
        self.m = 0.0
        self.v = 0.0
        self.m_hat = 0.0
        self.v_hat = 0.0
        # ... and for the biases (m_b, v_b).
        self.m_b = 0.0
        self.v_b = 0.0
        self.m_b_hat = 0.0
        self.v_b_hat = 0.0
        # backprop
        self.delta = 0.0
    def update_state(self, basal_input, I, test):
        self.basal_input = basal_input
        self.basal_V = self.basal_linear(basal_input)
        if test:
            # Free dynamics: leak plus basal drive only.
            self.soma_V = self.soma_V + 1 / tau_L * (-self.soma_V + g_B / g_L * (self.basal_V - self.soma_V)) * dt
        else:
            # Teaching current I nudges the soma toward the target.
            self.soma_V = self.soma_V + 1 / tau_L * (-self.soma_V + g_B / g_L * (self.basal_V - self.soma_V) +
                                                     I - self.soma_V) * dt
        self.spike_rate = lambda_max * sigma(self.soma_V)
    def update_weight(self, lr, t, beta, eps):
        """Adam-style update from the mismatch between the basal prediction and the somatic activity."""
        weight_dot = lambda_max * k_D * (sigma(k_D * self.basal_V) - sigma(self.soma_V)) * deriv_sigma(k_D * self.basal_V)
        # Error to pass down to the last hidden layer.
        self.delta = torch.mm(weight_dot, self.basal_linear.weight.data)  # [1, 500]
        bias_delta = weight_dot  # [1, 10]
        weight_delta = weight_dot[:, :, None] * self.basal_input[:, None, :]  # [1, 10, 500]
        # Adam first/second moments with bias correction (weights ...).
        self.m = beta[0] * self.m + (1 - beta[0]) * weight_delta
        self.v = beta[1] * self.v + (1 - beta[1]) * torch.square(weight_delta)
        self.m_hat = self.m / (1 - beta[0] ** t)
        self.v_hat = self.v / (1 - beta[1] ** t)
        # (... and biases).
        self.m_b = beta[0] * self.m_b + (1 - beta[0]) * bias_delta
        self.v_b = beta[1] * self.v_b + (1 - beta[1]) * torch.square(bias_delta)
        self.m_b_hat = self.m_b / (1 - beta[0] ** t)
        self.v_b_hat = self.v_b / (1 - beta[1] ** t)
        # update weight
        weight_delta = lr * self.m_hat / (torch.sqrt(self.v_hat) + eps)
        bias_delta = lr * self.m_b_hat / (torch.sqrt(self.v_b_hat) + eps)
        self.basal_linear.weight.data.sub_(weight_delta.mean(0))
        self.basal_linear.bias.data.sub_(bias_delta.mean(0))
================================================
FILE: braincog/model_zoo/resnet.py
================================================
'''
Deep Residual Learning for Image Recognition
https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
'''
import os
import sys
from functools import partial
from timm.models import register_model
from timm.models.layers import trunc_normal_, DropPath
from braincog.model_zoo.base_module import *
from braincog.base.node.node import *
__all__ = [
'ResNet',
'resnet18',
'resnet34_half',
'resnet34',
'resnet50_half',
'resnet50',
'resnet101',
'resnet152',
'resnext50_32x4d',
'resnext101_32x8d',
'wide_resnet50_2',
'wide_resnet101_2',
]
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    '''3x3 convolution, bias-free; padding equals dilation so spatial size is
    preserved at stride 1.'''
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, dilation=dilation, groups=groups,
                     bias=False)
def conv1x1(in_planes, out_planes, stride=1):
    '''1x1 convolution, bias-free (channel projection).'''
    return nn.Conv2d(in_planes, out_planes, kernel_size=1,
                     stride=stride, bias=False)
class BasicBlock(nn.Module):
    """
    Pre-activation residual basic block for spiking ResNets
    (BN -> node -> conv3x3 -> BN -> node -> conv3x3, plus shortcut).

    :param inplanes: number of input channels
    :param planes: number of internal/output channels
    :param stride: stride of the first convolution
    :param downsample: optional module applied to the shortcut path
    :param groups: grouped convolution (must be 1 here)
    :param base_width: base channel width (must be 64 here)
    :param dilation: dilated convolution (must be 1 here)
    :param norm_layer: normalization layer factory, defaults to BatchNorm2d
    :param node: spiking neuron type, defaults to ``LIFNode``
    """
    expansion = 1
    __constants__ = ['downsample']

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 groups=1,
                 base_width=64,
                 dilation=1,
                 norm_layer=None,
                 node=LIFNode):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError(
                'BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError(
                'Dilation > 1 not supported in BasicBlock')
        # conv1 (and downsample, when given) perform the spatial downsampling.
        self.bn1 = norm_layer(inplanes)
        self.node1 = node()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.node2 = node()
        self.bn2 = norm_layer(planes)
        self.conv2 = conv3x3(planes, planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)
        y = self.conv1(self.node1(self.bn1(x)))
        y = self.conv2(self.node2(self.bn2(y)))
        y += shortcut
        return y
class Bottleneck(nn.Module):
    """Pre-activation residual bottleneck block for the spiking ResNet.

    Layout: three BN -> node -> conv stages (1x1 reduce, 3x3, 1x1 expand),
    added to the (optionally downsampled) identity shortcut.

    :param inplanes: number of input channels
    :param planes: bottleneck width; output has ``planes * expansion`` channels
    :param stride: stride of the 3x3 convolution
    :param downsample: optional module applied to the shortcut path
    :param groups: convolution groups
    :param base_width: base width per group
    :param dilation: dilation of the 3x3 convolution
    :param norm_layer: normalization layer, default ``nn.BatchNorm2d``
    :param node: neuron class, default ``torch.nn.Identity``
    """
    expansion = 4
    __constants__ = ['downsample']

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 groups=1,
                 base_width=64,
                 dilation=1,
                 norm_layer=None,
                 node=torch.nn.Identity):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # conv2 (and the shortcut, when present) carry the spatial stride.
        self.bn1 = norm_layer(inplanes)
        self.conv1 = conv1x1(inplanes, width)
        self.bn2 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn3 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.downsample = downsample
        self.stride = stride
        self.node1 = node()
        self.node2 = node()
        self.node3 = node()

    def forward(self, x):
        y = self.conv1(self.node1(self.bn1(x)))
        y = self.conv2(self.node2(self.bn2(y)))
        y = self.conv3(self.node3(self.bn3(y)))
        if self.downsample is not None:
            return y + self.downsample(x)
        return y + x
class ResNet(BaseModule):
    """Spiking ResNet backbone built from pre-activation residual blocks.

    :param block: residual block class (``BasicBlock`` or ``Bottleneck``)
    :param layers: number of blocks in each of the four stages
    :param inplanes: stem width; stage widths are 1x, 2x, 4x and 8x of it
    :param num_classes: number of output classes
    :param zero_init_residual: zero-init the last BN weight of each residual branch
    :param groups: number of convolution groups
    :param width_per_group: base width per group
    :param replace_stride_with_dilation: three flags replacing stride-2 with dilation
    :param norm_layer: normalization layer, default ``nn.BatchNorm2d``
    :param step: number of simulation time steps, default ``8``
    :param encode_type: input encoding, default ``'direct'``
    :param spike_output: if True, decode with a spiking neuron plus a voting layer
    :param args: forwarded to :class:`BaseModule`
    :param kwargs: must contain ``node_type`` (neuron class) and ``dataset``;
        may contain ``reconstruct`` for the ``esimnet`` dataset
    :raises NotImplementedError: if ``kwargs['dataset']`` is not supported
    """

    def __init__(self,
                 block,
                 layers,
                 inplanes=64,
                 num_classes=10,
                 zero_init_residual=False,
                 groups=1,
                 width_per_group=64,
                 replace_stride_with_dilation=None,
                 norm_layer=None,
                 step=8,
                 encode_type='direct',
                 spike_output=False,
                 *args,
                 **kwargs):
        super().__init__(
            step,
            encode_type,
            *args,
            **kwargs
        )
        self.spike_output = spike_output
        self.num_classes = num_classes
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer
        self.inplanes = inplanes
        # Per-stage widths: 1x, 2x, 4x, 8x the stem width.
        self.interplanes = [
            self.inplanes, self.inplanes * 2, self.inplanes * 4,
            self.inplanes * 8
        ]
        self.dilation = 1
        # Neuron class; bind the construction kwargs once so later calls to
        # self.node() need no arguments.
        self.node = kwargs['node_type']
        if issubclass(self.node, BaseNode):
            self.node = partial(self.node, **kwargs)
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError('replace_stride_with_dilation should be None '
                             'or a 3-element tuple, got {}'.format(
                                 replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.static_data = False
        self.dataset = kwargs['dataset']
        # Stem: input channel count and geometry depend on the dataset.
        # self.init_channel_mul is provided by BaseModule -- presumably a
        # temporal/encoding channel multiplier; TODO confirm.
        if self.dataset in ('dvsg', 'dvsc10', 'NCALTECH101', 'NCARS', 'DVSG'):
            # Event (DVS) data: two polarity channels, small 3x3 stem.
            self.conv1 = nn.Conv2d(2 * self.init_channel_mul,
                                   self.inplanes,
                                   kernel_size=3,
                                   padding=1,
                                   bias=False)
        elif self.dataset == 'imnet':
            # ImageNet-style RGB frames: classic 7x7/stride-2 stem.
            self.conv1 = nn.Conv2d(3 * self.init_channel_mul,
                                   self.inplanes,
                                   kernel_size=7,
                                   stride=2,
                                   padding=3,
                                   bias=False)
            self.static_data = True
        elif self.dataset == 'esimnet':
            # ES-ImageNet: 1 channel when reconstructed frames are used,
            # otherwise 2 event-polarity channels.
            reconstruct = kwargs.get('reconstruct', False)
            if reconstruct:
                self.conv1 = nn.Conv2d(1 * self.init_channel_mul,
                                       self.inplanes,
                                       kernel_size=7,
                                       stride=2,
                                       padding=3,
                                       bias=False)
                self.static_data = True
            else:
                self.conv1 = nn.Conv2d(2 * self.init_channel_mul,
                                       self.inplanes,
                                       kernel_size=7,
                                       stride=2,
                                       padding=3,
                                       bias=False)
                self.static_data = True
        elif self.dataset == 'cifar10' or self.dataset == 'cifar100':
            self.conv1 = nn.Conv2d(3 * self.init_channel_mul,
                                   self.inplanes,
                                   kernel_size=3,
                                   padding=1,
                                   bias=False)
            self.static_data = True
        else:
            # Fail fast instead of leaving self.conv1 undefined and crashing
            # later with an opaque AttributeError in forward().
            raise NotImplementedError(
                'Unsupported dataset: {}'.format(self.dataset))
        self.layer1 = self._make_layer(
            block, self.interplanes[0], layers[0], node=self.node)
        self.layer2 = self._make_layer(block,
                                       self.interplanes[1],
                                       layers[1],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[0], node=self.node)
        self.layer3 = self._make_layer(block,
                                       self.interplanes[2],
                                       layers[2],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[1], node=self.node)
        self.layer4 = self._make_layer(block,
                                       self.interplanes[3],
                                       layers[3],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[2], node=self.node)
        self.bn1 = norm_layer(self.inplanes)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        if self.spike_output:
            # Spiking head: 10 output neurons per class, decoded by voting.
            self.fc = nn.Linear(
                self.interplanes[3] * block.expansion, num_classes * 10)
            self.node2 = self.node()
            self.vote = VotingLayer(10)
        else:
            self.fc = nn.Linear(
                self.interplanes[3] * block.expansion, num_classes
            )
            self.node2 = nn.Identity()
            self.vote = nn.Identity()
        self.warm_up = False
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight,
                                        mode='fan_out',
                                        nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False, node=torch.nn.Identity):
        """Build one residual stage of ``blocks`` blocks.

        The first block may change resolution/width and then gets a
        pre-activation shortcut (BN -> node -> 1x1 conv); the remaining blocks
        keep the identity shortcut.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            # Both block types use the same pre-activation shortcut; the two
            # previously duplicated branches are merged here.
            if block not in (BasicBlock, Bottleneck):
                raise NotImplementedError
            downsample = nn.Sequential(
                norm_layer(self.inplanes),
                self.node(),
                conv1x1(self.inplanes, planes * block.expansion, stride),
            )
        layers = [block(self.inplanes, planes, stride, downsample, self.groups,
                        self.base_width, previous_dilation, norm_layer, node=node)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(self.inplanes,
                      planes,
                      groups=self.groups,
                      base_width=self.base_width,
                      dilation=self.dilation,
                      norm_layer=norm_layer, node=node))
        return nn.Sequential(*layers)

    def forward(self, inputs):
        """Encode the input, then run either the fused (layer-by-layer) graph
        or an explicit loop over time steps; returns class logits."""
        inputs = self.encoder(inputs)
        self.reset()
        if self.layer_by_layer:
            # Time steps are folded into the batch dimension; unfold and
            # average over time before decoding.
            x = self.conv1(inputs)
            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            x = self.layer4(x)
            x = self.bn1(x)
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)
            x = rearrange(x, '(t b) c -> t b c', t=self.step).mean(0)
            x = self.node2(x)
            x = self.vote(x)
            return x
        else:
            outputs = []
            # During warm-up only a single time step is simulated.
            step = 1 if self.warm_up else self.step
            for t in range(step):
                x = inputs[t]
                x = self.conv1(x)
                x = self.layer1(x)
                x = self.layer2(x)
                x = self.layer3(x)
                x = self.layer4(x)
                x = self.bn1(x)
                x = self.avgpool(x)
                x = torch.flatten(x, 1)
                x = self.fc(x)
                x = self.node2(x)
                x = self.vote(x)
                outputs.append(x)
            return sum(outputs) / len(outputs)
def _resnet(arch, block, layers, pretrained=False, **kwargs):
    """Instantiate a spiking ResNet; pretrained weights are not available."""
    net = ResNet(block, layers, **kwargs)
    # Loading a pretrained state_dict is not implemented for the SNN variant.
    if pretrained:
        raise NotImplementedError
    return net
@register_model
def resnet9(pretrained=False, **kwargs):
    """SNN ResNet-9: BasicBlock with stage depths [1, 1, 1, 1]."""
    stage_depths = [1, 1, 1, 1]
    return _resnet('resnet9', BasicBlock, stage_depths, pretrained, **kwargs)
@register_model
def resnet18(pretrained=False, **kwargs):
    """SNN ResNet-18: BasicBlock with stage depths [2, 2, 2, 2]."""
    stage_depths = [2, 2, 2, 2]
    return _resnet('resnet18', BasicBlock, stage_depths, pretrained, **kwargs)
@register_model
def resnet34_half(pretrained=False, **kwargs):
    """SNN ResNet-34 at half width (stem width 32 instead of 64)."""
    kwargs['inplanes'] = 32
    return _resnet('resnet34_half', BasicBlock, [3, 4, 6, 3],
                   pretrained, **kwargs)
@register_model
def resnet34(pretrained=False, **kwargs):
    """SNN ResNet-34: BasicBlock with stage depths [3, 4, 6, 3]."""
    stage_depths = [3, 4, 6, 3]
    return _resnet('resnet34', BasicBlock, stage_depths, pretrained, **kwargs)
@register_model
def resnet50_half(pretrained=False, **kwargs):
    """SNN ResNet-50 at half width (stem width 32 instead of 64)."""
    kwargs['inplanes'] = 32
    return _resnet('resnet50_half', Bottleneck, [3, 4, 6, 3],
                   pretrained, **kwargs)
@register_model
def resnet50(pretrained=False, **kwargs):
    """SNN ResNet-50: Bottleneck with stage depths [3, 4, 6, 3]."""
    stage_depths = [3, 4, 6, 3]
    return _resnet('resnet50', Bottleneck, stage_depths, pretrained, **kwargs)
@register_model
def resnet101(pretrained=False, **kwargs):
    """SNN ResNet-101: Bottleneck with stage depths [3, 4, 23, 3]."""
    stage_depths = [3, 4, 23, 3]
    return _resnet('resnet101', Bottleneck, stage_depths, pretrained, **kwargs)
@register_model
def resnet152(pretrained=False, **kwargs):
    """SNN ResNet-152: Bottleneck with stage depths [3, 8, 36, 3]."""
    stage_depths = [3, 8, 36, 3]
    return _resnet('resnet152', Bottleneck, stage_depths, pretrained, **kwargs)
@register_model
def resnext50_32x4d(pretrained=False, **kwargs):
    """SNN ResNeXt-50 (32 groups, width 4 per group)."""
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 4
    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
                   pretrained, **kwargs)
@register_model
def resnext101_32x8d(pretrained=False, **kwargs):
    """SNN ResNeXt-101 (32 groups, width 8 per group)."""
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 8
    return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
                   pretrained, **kwargs)
@register_model
def wide_resnet50_2(pretrained=False, **kwargs):
    """SNN Wide-ResNet-50-2 (double the per-group base width)."""
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
                   pretrained, **kwargs)
@register_model
def wide_resnet101_2(pretrained=False, **kwargs):
    """SNN Wide-ResNet-101-2 (double the per-group base width)."""
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
                   pretrained, **kwargs)
if __name__ == '__main__':
    # Smoke test: count FLOPs/params with thop, then run one forward pass.
    # NOTE(review): ResNet.__init__ reads kwargs['node_type'] and
    # kwargs['dataset'] unconditionally, so this invocation raises KeyError
    # as written -- confirm the intended kwargs before relying on this block.
    net = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=1000)
    image_h, image_w = 224, 224
    from thop import profile
    from thop import clever_format
    flops, params = profile(net,
                            inputs=(torch.randn(1, 3, image_h, image_w),),
                            verbose=False)
    flops, params = clever_format([flops, params], '%.3f')
    out = net(torch.autograd.Variable(torch.randn(3, 3, image_h, image_w)))
    print(f'1111, flops: {flops}, params: {params},out_shape: {out.shape}')
================================================
FILE: braincog/model_zoo/resnet19_snn.py
================================================
# encoding: utf-8
# Author : Floyed
# Datetime : 2022/7/26 19:33
# User : Floyed
# Product : PyCharm
# Project : braincog
# File : resnet19_snn.py
# explain :
import os
import sys
from functools import partial
import numpy as np
from timm.models import register_model
from timm.models.layers import trunc_normal_, DropPath
from braincog.model_zoo.base_module import *
from braincog.base.node.node import *
from braincog.base.connection.layer import *
from braincog.datasets import is_dvs_data
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Bias-free 3x3 convolution; padding equals dilation so spatial size is
    preserved at stride 1."""
    return nn.Conv2d(
        in_planes, out_planes,
        kernel_size=3, stride=stride, padding=dilation,
        dilation=dilation, groups=groups, bias=False)
def conv1x1(in_planes, out_planes, stride=1):
    """Bias-free 1x1 (pointwise) convolution."""
    return nn.Conv2d(in_planes, out_planes,
                     kernel_size=1, stride=stride, bias=False)
class BasicBlock(nn.Module):
    """Post-activation basic block for the tdBN ResNet-19.

    Layout: conv3x3 -> bn -> node -> conv3x3 -> bn, summed with the
    (optionally downsampled) shortcut and passed through a final neuron.
    The second BN is scaled by ``sqrt(0.5)`` so the two summed branches keep
    unit variance (tdBN).
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 node=LIFNode, base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = ThresholdDependentBatchNorm2d
        # conv1 (and the shortcut, when present) carry the spatial stride;
        # groups/base_width/dilation are accepted for interface parity only.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(num_features=planes, alpha=1.)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(num_features=planes, alpha=np.sqrt(.5))
        self.downsample = downsample
        self.stride = stride
        self.node1 = node()
        self.node2 = node()

    def forward(self, x):
        y = self.node1(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        shortcut = self.downsample(x) if self.downsample is not None else x
        return self.node2(y + shortcut)
class ResNet(BaseModule):
    """ResNet-19-style spiking network with threshold-dependent BatchNorm.

    Only the layer-by-layer execution mode (time folded into the batch
    dimension) is supported, because tdBN normalizes over the full
    ``(step * batch)`` dimension.

    :param block: residual block class (``BasicBlock``)
    :param layers: number of blocks in each of the three stages
    :param num_classes: number of output classes
    :param zero_init_residual: zero-init the last BN weight of each branch
    :param groups: number of convolution groups
    :param width_per_group: base width per group, default ``128``
    :param replace_stride_with_dilation: per-stage stride-to-dilation flags
    :param norm_layer: normalization layer, default ``ThresholdDependentBatchNorm2d``
    :param step: number of simulation time steps, default ``4``
    :param encode_type: input encoding, default ``'direct'``
    :param node_type: neuron class, default ``LIFNode``
    :param kwargs: must contain ``dataset``; may contain ``sum_output``
    :raises ValueError: if the model is not configured layer-by-layer
    """

    def __init__(self, block, layers, num_classes=10, zero_init_residual=False, groups=1, width_per_group=128,
                 replace_stride_with_dilation=None, norm_layer=None, step=4, encode_type='direct', node_type=LIFNode,
                 *args, **kwargs):
        # Fix: the BaseModule constructor was previously invoked twice in a
        # row with identical arguments; call it exactly once.
        super().__init__(step, encode_type, *args, **kwargs)
        if not self.layer_by_layer:
            raise ValueError('ResNet-SNN only support for layer-wise mode, because of tdBN')
        # Neuron class; bind construction kwargs and the step count once.
        self.node = node_type
        if issubclass(self.node, BaseNode):
            self.node = partial(self.node, **kwargs, step=step)
        self.dataset = kwargs['dataset']
        # DVS event data has 2 polarity channels; static frames have 3 (RGB).
        if is_dvs_data(self.dataset):
            data_channel = 2
        else:
            data_channel = 3
        if norm_layer is None:
            norm_layer = ThresholdDependentBatchNorm2d
        # tdBN needs the step count to unfold the (step * batch) dimension.
        self._norm_layer = partial(norm_layer, step=step)
        # Average logits over time steps unless explicitly disabled.
        self.sum_output = kwargs.get('sum_output', True)
        self.inplanes = 128
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(data_channel, self.inplanes, kernel_size=3, stride=1, padding=1,
                               bias=False)
        self.bn1 = self._norm_layer(num_features=self.inplanes, alpha=np.sqrt(.5))
        self.layer1 = self._make_layer(block, 128, layers[0])
        self.layer2 = self._make_layer(block, 256, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 512, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(512 * block.expansion, 256)
        self.fc2 = nn.Linear(256, num_classes)
        self.node1 = self.node()
        self.node2 = self.node()
        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)
                elif isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one residual stage; the shortcut always carries a tdBN, plus
        a strided 1x1 conv when the shape changes."""
        norm_layer = self._norm_layer
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(num_features=planes * block.expansion, alpha=np.sqrt(.5)),
            )
        else:
            downsample = nn.Sequential(
                norm_layer(num_features=planes * block.expansion, alpha=np.sqrt(.5)),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride=stride, downsample=downsample, groups=self.groups,
                            base_width=self.base_width, norm_layer=norm_layer, node=self.node))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer, node=self.node))
        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        """Fused forward pass; time steps are folded into the batch dim."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.node1(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = self.node2(x)
        x = self.fc2(x)
        # Unfold time: either average over steps or return the full sequence.
        if self.sum_output:
            x = rearrange(x, '(t b) c -> b c t', t=self.step).mean(-1)
        else:
            x = rearrange(x, '(t b) c -> t b c ', t=self.step)
        return x

    def forward(self, inputs):
        """Encode the input, reset neuron state, and run the fused forward."""
        inputs = self.encoder(inputs)
        self.reset()
        return self._forward_impl(inputs)
def _resnet(arch, block, layers, pretrained, progress, norm=ThresholdDependentBatchNorm2d, **kwargs):
    """Build a tdBN spiking ResNet; pretrained weights are not available."""
    # Bind the tdBN configuration from kwargs once, then hand the bound norm
    # to the model constructor.
    bound_norm = partial(norm, layer_by_layer=kwargs['layer_by_layer'],
                         threshold=kwargs['threshold'])
    net = ResNet(block, layers, norm_layer=bound_norm, **kwargs)
    if pretrained:
        raise NotImplementedError
    return net
@register_model
def resnet19(pretrained=False, progress=True, norm=ThresholdDependentBatchNorm2d, **kwargs):
    """SNN ResNet-19: three BasicBlock stages with depths [3, 3, 2]."""
    stage_depths = [3, 3, 2]
    return _resnet('resnet19', BasicBlock, stage_depths, pretrained, progress,
                   norm=norm, **kwargs)
if __name__ == '__main__':
    # Smoke test copied from resnet.py: count FLOPs/params, run a forward pass.
    # NOTE(review): this ResNet requires kwargs['dataset'] and the
    # layer-by-layer configuration, so the call below raises as written --
    # confirm the intended arguments before relying on this block.
    net = ResNet(BasicBlock, [3, 4, 6, 3], num_classes=1000)
    image_h, image_w = 224, 224
    from thop import profile
    from thop import clever_format
    flops, params = profile(net,
                            inputs=(torch.randn(1, 3, image_h, image_w),),
                            verbose=False)
    flops, params = clever_format([flops, params], '%.3f')
    out = net(torch.autograd.Variable(torch.randn(3, 3, image_h, image_w)))
    print(f'1111, flops: {flops}, params: {params},out_shape: {out.shape}')
================================================
FILE: braincog/model_zoo/rsnn.py
================================================
import torch
from torch import nn
from braincog.base.node.node import IFNode
from braincog.base.learningrule.STDP import STDP,MutliInputSTDP
from braincog.base.connection.CustomLinear import CustomLinear
from collections import deque
from random import randint
class RSNN(nn.Module):
    """Single-layer spiking network trained with reward-modulated STDP.

    Maps ``num_state`` input neurons to ``num_action`` output neurons through
    one fully connected ``CustomLinear``. STDP proposes weight changes that
    are accumulated externally in ``weight_trace`` and applied, scaled by a
    scalar reward, via :meth:`UpdateWeight`.
    """
    def __init__(self,num_state,num_action):
        super().__init__()
        # parameters
        rsnn_mask=[]
        rsnn_con=[]
        # Dense (all-ones) connectivity mask from state to action neurons.
        con_matrix1 = torch.ones((num_state,num_action), dtype=torch.float)
        rsnn_mask.append(con_matrix1)
        rsnn_con.append(CustomLinear(torch.randn(num_state,num_action), con_matrix1))
        # Two neuron populations: input (state) and output (action).
        self.num_subR=2
        self.connection = rsnn_con
        self.mask=rsnn_mask
        self.node = [IFNode() for i in range(self.num_subR)]
        self.learning_rule = []
        # STDP driven by the output population through the single connection.
        self.learning_rule.append(MutliInputSTDP(self.node[1], [self.connection[0]]))
        # Eligibility trace of proposed weight updates, consumed by UpdateWeight.
        self.weight_trace = torch.zeros(con_matrix1.shape, dtype=torch.float)
        self.out_in = torch.zeros((num_state), dtype=torch.float)
        self.out = torch.zeros((self.connection[0].weight.size()[1]), dtype=torch.float)
        self.dw = torch.zeros((self.connection[0].weight.size()), dtype=torch.float)

    def forward(self, input):
        """Run one time step; returns (output spikes, proposed weight delta)."""
        input=torch.tensor(input, dtype=torch.float)
        self.out_in=self.node[0](input)
        self.out,self.dw = self.learning_rule[0](self.out_in)
        return self.out,self.dw

    def UpdateWeight(self,reward):
        """Apply the eligibility trace scaled by a scalar ``reward``, then
        min-max renormalize each output column into [0, 0.5]."""
        # Positive traces are scaled by reward; negative traces are sign-flipped
        # before scaling. NOTE(review): with positive reward both branches then
        # push weights in the positive direction -- confirm this is intended.
        self.weight_trace[self.weight_trace>0]=self.weight_trace[self.weight_trace>0]*reward
        self.weight_trace[self.weight_trace < 0] = -1*self.weight_trace[self.weight_trace < 0] * reward
        self.connection[0].update(self.weight_trace)
        # Per-column min-max normalization. NOTE(review): divides by zero when
        # a column is constant -- confirm inputs guarantee spread.
        for i in range(self.connection[0].weight.size()[1]):
            self.connection[0].weight.data[:, i] = (self.connection[0].weight.data[:, i] - torch.min(self.connection[0].weight.data[:, i])) / (torch.max(self.connection[0].weight.data[:, i]) - torch.min(self.connection[0].weight.data[:, i]))
        self.connection[0].weight.data= self.connection[0].weight.data * 0.5

    def reset(self):
        """Reset neuron membrane state and learning-rule traces."""
        for i in range(self.num_subR):
            self.node[i].n_reset()
        for i in range(len(self.learning_rule)):
            self.learning_rule[i].reset()

    def getweight(self):
        """Return the connection list; weights live in ``connection[0].weight``."""
        return self.connection
================================================
FILE: braincog/model_zoo/sew_resnet.py
================================================
import torch
import torch.nn as nn
from copy import deepcopy
try:
from torchvision.models.utils import load_state_dict_from_url
except ImportError:
from torchvision._internally_replaced_utils import load_state_dict_from_url
from braincog.base.node import *
from braincog.model_zoo.base_module import *
from braincog.datasets import is_dvs_data
from timm.models import register_model
# Public API of this module.
__all__ = ['SEWResNet', 'sew_resnet18', 'sew_resnet34', 'sew_resnet50', 'sew_resnet101',
           'sew_resnet152', 'sew_resnext50_32x4d', 'sew_resnext101_32x8d',
           'sew_wide_resnet50_2', 'sew_wide_resnet101_2']

# Torchvision checkpoints of the matching ANN ResNets.
# NOTE(review): not referenced in this chunk -- presumably consumed by a
# pretrained-loading path elsewhere; verify before removing.
model_urls = {
    "resnet18": "https://download.pytorch.org/models/resnet18-f37072fd.pth",
    "resnet34": "https://download.pytorch.org/models/resnet34-b627a593.pth",
    "resnet50": "https://download.pytorch.org/models/resnet50-0676ba61.pth",
    "resnet101": "https://download.pytorch.org/models/resnet101-63fe2227.pth",
    "resnet152": "https://download.pytorch.org/models/resnet152-394f9c45.pth",
    "resnext50_32x4d": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
    "resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
    "wide_resnet50_2": "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
    "wide_resnet101_2": "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
}
# modified by https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
def sew_function(x: torch.Tensor, y: torch.Tensor, cnf: str):
    """Element-wise connect function of SEW-ResNet.

    ``'ADD'`` returns ``x + y``, ``'AND'`` returns ``x * y`` and ``'IAND'``
    returns ``x * (1 - y)``; any other value raises ``NotImplementedError``.
    Note that ``'IAND'`` is not commutative in its operands.
    """
    ops = {
        'ADD': lambda a, b: a + b,
        'AND': lambda a, b: a * b,
        'IAND': lambda a, b: a * (1. - b),
    }
    if cnf not in ops:
        raise NotImplementedError
    return ops[cnf](x, y)
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Bias-free 3x3 convolution; padding equals dilation so spatial size is
    preserved at stride 1."""
    return nn.Conv2d(
        in_planes, out_planes,
        kernel_size=3, stride=stride, padding=dilation,
        dilation=dilation, groups=groups, bias=False)
def conv1x1(in_planes, out_planes, stride=1):
    """Bias-free 1x1 (pointwise) convolution."""
    return nn.Conv2d(in_planes, out_planes,
                     kernel_size=1, stride=stride, bias=False)
class BasicBlock(nn.Module):
    """SEW-ResNet basic block.

    Both the residual branch and the (optional) downsampled shortcut end in
    a spiking neuron, so the connect function ``cnf`` (ADD / AND / IAND,
    see :func:`sew_function`) combines two spike tensors.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None, cnf: str = None, node: callable = None, **kwargs):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # conv1 (and the shortcut, when present) carry the spatial stride.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.node1 = node()
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.node2 = node()
        self.downsample = downsample
        if downsample is not None:
            # The shortcut gets its own neuron so both operands of the
            # connect function are spike tensors.
            self.downsample_sn = node()
        self.stride = stride
        self.cnf = cnf

    def forward(self, x):
        y = self.node1(self.bn1(self.conv1(x)))
        y = self.node2(self.bn2(self.conv2(y)))
        if self.downsample is None:
            shortcut = x
        else:
            shortcut = self.downsample_sn(self.downsample(x))
        # Operand order matters for IAND: this computes shortcut * (1 - y).
        return sew_function(shortcut, y, self.cnf)

    def extra_repr(self) -> str:
        return super().extra_repr() + f'cnf={self.cnf}'
class Bottleneck(nn.Module):
    """SEW-ResNet bottleneck block.

    Uses the ResNet V1.5 layout (the 3x3 convolution carries the stride, per
    https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch).
    Branch outputs are spike tensors fused by the connect function ``cnf``
    (ADD / AND / IAND, see :func:`sew_function`).
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None, cnf: str = None, node: callable = None, **kwargs):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # conv2 (and the shortcut, when present) carry the spatial stride.
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.node1 = node()
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.node2 = node()
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.node3 = node()
        self.downsample = downsample
        if downsample is not None:
            # The shortcut gets its own neuron so both operands of the
            # connect function are spike tensors.
            self.downsample_sn = node()
        self.stride = stride
        self.cnf = cnf

    def forward(self, x):
        y = self.node1(self.bn1(self.conv1(x)))
        y = self.node2(self.bn2(self.conv2(y)))
        y = self.node3(self.bn3(self.conv3(y)))
        if self.downsample is None:
            shortcut = x
        else:
            shortcut = self.downsample_sn(self.downsample(x))
        # Operand order matters for IAND: this computes y * (1 - shortcut).
        return sew_function(y, shortcut, self.cnf)

    def extra_repr(self) -> str:
        return super().extra_repr() + f'cnf={self.cnf}'
class SEWResNet(BaseModule):
    """Spike-Element-Wise (SEW) ResNet.

    Residual branches are fused with the element-wise connect function
    ``cnf`` (ADD / AND / IAND) instead of plain addition.

    ``kwargs`` must contain ``node_type`` (neuron class) and ``dataset``
    (selects 2 input channels for DVS data, 3 otherwise); optional
    ``once`` (single plain forward pass without encoding) and
    ``sum_output`` (average logits over time steps, default True).
    """
    def __init__(self, block, layers, num_classes=1000, step=8,encode_type="direct",zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, cnf: str = None, *args,**kwargs):
        super().__init__(
            step,
            encode_type,
            *args,
            **kwargs
        )
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer
        self.num_classes = num_classes
        # Neuron class; bind the construction kwargs and step count once so
        # self.node() needs no arguments.
        self.node = kwargs['node_type']
        if issubclass(self.node, BaseNode):
            self.node = partial(self.node, **kwargs, step=step)
        # 'once': route forward() to the single-pass path (_forward_once).
        self.once=kwargs["once"] if "once"in kwargs else False
        # 'sum_output': average per-step logits instead of returning them all.
        self.sum_output=kwargs["sum_output"] if "sum_output"in kwargs else True
        self.dataset = kwargs['dataset']
        # DVS event data has 2 polarity channels; static frames have 3 (RGB).
        if not is_dvs_data(self.dataset):
            init_channel = 3
        else:
            init_channel = 2
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(init_channel, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.node1 = self.node()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], cnf=cnf, node=self.node, **kwargs)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0], cnf=cnf, node=self.node, **kwargs)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1], cnf=cnf, node=self.node, **kwargs)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2], cnf=cnf, node=self.node, **kwargs)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False, cnf: str=None, node: callable = None, **kwargs):
        # Build one residual stage; the first block may downsample via a
        # strided 1x1 conv + norm shortcut.
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer, cnf, node, **kwargs))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer, cnf=cnf, node=node, **kwargs))
        return nn.Sequential(*layers)

    def _forward_impl(self, inputs):
        # See note [TorchScript super()]
        inputs = self.encoder(inputs)
        self.reset()
        if self.layer_by_layer:
            # Time steps are folded into the batch dimension; unfold at the end.
            x = self.conv1(inputs)
            x = self.bn1(x)
            x = self.node1(x)
            x = self.maxpool(x)
            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            x = self.layer4(x)
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)
            x = rearrange(x, '(t b) c -> t b c', t=self.step)
            #print(x)
            if self.sum_output:x=x.mean(0)
            return x
        else:
            # Step-by-step simulation: run the whole network once per time step.
            outputs=[]
            for t in range(self.step):
                x = inputs[t]
                x = self.conv1(x)
                x = self.bn1(x)
                x = self.node1(x)
                x = self.maxpool(x)
                x = self.layer1(x)
                x = self.layer2(x)
                x = self.layer3(x)
                x = self.layer4(x)
                x = self.avgpool(x)
                x = torch.flatten(x, 1)
                x = self.fc(x)
                outputs.append(x)
            if not self.sum_output:return outputs
            return sum(outputs) / len(outputs)

    def _forward_once(self,x):
        # Single pass without temporal encoding/reset (used when self.once).
        # inputs = self.encoder(inputs)
        # x = inputs[t]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.node1(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

    def forward(self, x):
        # Dispatch between the single-pass and the temporal forward paths.
        if self.once:return self._forward_once(x)
        return self._forward_impl(x)
class SEWResNet19(BaseModule):
def __init__(self, block, layers, num_classes=1000, step=8,encode_type="direct",zero_init_residual=False,
groups=1, width_per_group=64, replace_stride_with_dilation=None,
norm_layer=None, cnf: str = None, *args,**kwargs):
super().__init__(
step,
encode_type,
*args,
**kwargs
)
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self._norm_layer = norm_layer
self.num_classes = num_classes
self.node = kwargs['node_type']
if issubclass(self.node, BaseNode):
self.node = partial(self.node, **kwargs, step=step)
self.once=kwargs["once"] if "once"in kwargs else False
self.sum_output=kwargs["sum_output"] if "sum_output"in kwargs else True
self.dataset = kwargs['dataset']
if not is_dvs_data(self.dataset):
init_channel = 3
else:
init_channel = 2
self.inplanes = 64
self.dilation = 1
if replace_stride_with_dilation is None:
# each element in the tuple indicates if we should replace
# the 2x2 stride with a dilated convolution instead
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError("replace_stride_with_dilation should be None "
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
self.groups = groups
self.base_width = width_per_group
self.conv1 = nn.Conv2d(init_channel, self.inplanes, kernel_size=3, stride=1, padding=1,
bias=False)
self.bn1 = norm_layer(self.inplanes)
self.node1 = self.node()
self.layer1 = self._make_layer(block, 128, layers[0], cnf=cnf, node=self.node, **kwargs)
self.layer2 = self._make_layer(block, 256, layers[1], stride=2,
dilate=replace_stride_with_dilation[0], cnf=cnf, node=self.node, **kwargs)
self.layer3 = self._make_layer(block, 512, layers[2], stride=2,
dilate=replace_stride_with_dilation[1], cnf=cnf, node=self.node, **kwargs)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc1 = nn.Linear(512 * block.expansion, 256)
self.fc2 = nn.Linear(256, num_classes)
self.node2 = self.node()
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
nn.init.constant_(m.bn3.weight, 0)
elif isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight, 0)
def _make_layer(self, block, planes, blocks, stride=1, dilate=False, cnf: str = None, node: callable = None, **kwargs):
    """Build one residual stage of ``blocks`` stacked ``block`` modules.

    The first block may downsample (stride / 1x1-conv shortcut); the rest
    keep the resolution. ``self.inplanes`` is advanced as a side effect so
    consecutive calls chain correctly.
    """
    norm_layer = self._norm_layer
    prev_dilation = self.dilation
    if dilate:
        # Trade stride for dilation, keeping the receptive field.
        self.dilation *= stride
        stride = 1

    shortcut = None
    if stride != 1 or self.inplanes != planes * block.expansion:
        # Projection shortcut to match shape when stride/width changes.
        shortcut = nn.Sequential(
            conv1x1(self.inplanes, planes * block.expansion, stride),
            norm_layer(planes * block.expansion),
        )

    stage = [block(self.inplanes, planes, stride, shortcut, self.groups,
                   self.base_width, prev_dilation, norm_layer, cnf, node, **kwargs)]
    self.inplanes = planes * block.expansion
    stage.extend(
        block(self.inplanes, planes, groups=self.groups,
              base_width=self.base_width, dilation=self.dilation,
              norm_layer=norm_layer, cnf=cnf, node=node, **kwargs)
        for _ in range(1, blocks)
    )
    return nn.Sequential(*stage)
def _forward_impl(self, inputs):
    """Multi-step forward pass.

    Encodes the input, resets neuron state, then either runs the whole
    (t*b) batch through the backbone at once (``layer_by_layer``) or loops
    over the ``self.step`` time slices. When ``sum_output`` is set the
    per-step logits are averaged; otherwise the raw per-step outputs are
    returned.
    """
    inputs = self.encoder(inputs)
    self.reset()  # clear membrane potentials before a new simulation

    def backbone(feat):
        # conv stem -> three residual stages -> pooled spiking classifier head
        feat = self.node1(self.bn1(self.conv1(feat)))
        feat = self.layer3(self.layer2(self.layer1(feat)))
        feat = torch.flatten(self.avgpool(feat), 1)
        return self.fc2(self.node2(self.fc1(feat)))

    if self.layer_by_layer:
        out = rearrange(backbone(inputs), '(t b) c -> t b c', t=self.step)
        return out.mean(0) if self.sum_output else out

    step_outs = [backbone(inputs[t]) for t in range(self.step)]
    if not self.sum_output:
        return step_outs
    return sum(step_outs) / len(step_outs)
def _forward_once(self, x):
    """Single-pass forward (used when ``self.once`` is set): run the input
    through the backbone exactly once, without encoding or time steps.

    Bug fix: the original body referenced ``self.maxpool``, ``self.layer4``
    and ``self.fc``, none of which this network creates (its ``__init__``
    builds conv1/bn1/node1, layer1..layer3, avgpool and the fc1/node2/fc2
    head), so calling it always raised ``AttributeError``. The path below
    mirrors the per-step path of ``_forward_impl``.
    """
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.node1(x)
    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.avgpool(x)
    x = torch.flatten(x, 1)
    x = self.fc1(x)
    x = self.node2(x)
    x = self.fc2(x)
    return x
def forward(self, x):
    """Entry point: single-pass forward when ``self.once`` is set,
    otherwise the full multi-step simulation."""
    return self._forward_once(x) if self.once else self._forward_impl(x)
class SEWResNetCifar(BaseModule):
    """Spike-element-wise (SEW) ResNet variant for CIFAR-sized inputs.

    Three residual stages (128/256/512 base channels), a 3x3 stride-1 stem
    and a single linear classifier. Input channels are 3 for static images
    and 2 for DVS (polarity) data.

    Bug fix: ``_forward_once`` used to call ``self.layer4``, which this
    class never creates (only ``layer1``..``layer3``), so running with
    ``once=True`` always raised ``AttributeError``.

    :param block: residual block class (``BasicBlock`` or ``Bottleneck``)
    :param layers: number of blocks per stage, length 3
    :param num_classes: number of output classes
    :param step: number of simulation time steps
    :param encode_type: input encoder name
    :param zero_init_residual: zero-init the last BN of each residual branch
    :param cnf: name of the spike-element-wise connecting function
    :param kwargs: must contain ``node_type`` and ``dataset``; may contain
        ``once`` and ``sum_output``
    """

    def __init__(self, block, layers, num_classes=1000, step=8, encode_type="direct", zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, cnf: str = None, *args, **kwargs):
        super().__init__(
            step,
            encode_type,
            *args,
            **kwargs
        )
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer
        self.num_classes = num_classes

        # Neuron type, partially applied so layers can instantiate it.
        self.node = kwargs['node_type']
        if issubclass(self.node, BaseNode):
            self.node = partial(self.node, **kwargs, step=step)

        # `once`: single-pass forward; `sum_output`: average logits over steps.
        self.once = kwargs["once"] if "once" in kwargs else False
        self.sum_output = kwargs["sum_output"] if "sum_output" in kwargs else True

        self.dataset = kwargs['dataset']
        if not is_dvs_data(self.dataset):
            init_channel = 3  # RGB frames
        else:
            init_channel = 2  # DVS polarity channels

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # Each element indicates whether to replace the stride-2
            # downsampling of a stage with a dilated convolution instead.
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(init_channel, self.inplanes, kernel_size=3, stride=1, padding=1,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.node1 = self.node()
        # NOTE(review): maxpool is only used on the `once` path, not in
        # _forward_impl -- confirm this asymmetry is intentional.
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 128, layers[0], cnf=cnf, node=self.node, **kwargs)
        self.layer2 = self._make_layer(block, 256, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0], cnf=cnf, node=self.node, **kwargs)
        self.layer3 = self._make_layer(block, 512, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1], cnf=cnf, node=self.node, **kwargs)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch, so the branch
        # starts as identity. Improves accuracy by ~0.2-0.3% according to
        # https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False, cnf: str = None, node: callable = None, **kwargs):
        """Build one residual stage; advances ``self.inplanes`` as a side effect."""
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            # Trade stride for dilation, keeping the receptive field.
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            # Projection shortcut to match shapes across the stage boundary.
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer, cnf, node, **kwargs))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer, cnf=cnf, node=node, **kwargs))
        return nn.Sequential(*layers)

    def _forward_impl(self, inputs):
        """Multi-step forward: encode, reset state, then run the backbone
        either over the fused (t*b) batch or step by step; average logits
        over steps when ``sum_output`` is set."""
        inputs = self.encoder(inputs)
        self.reset()  # clear membrane potentials before a new simulation
        if self.layer_by_layer:
            x = self.conv1(inputs)
            x = self.bn1(x)
            x = self.node1(x)
            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)
            x = rearrange(x, '(t b) c -> t b c', t=self.step)
            if self.sum_output:
                x = x.mean(0)
            return x
        else:
            outputs = []
            for t in range(self.step):
                x = inputs[t]
                x = self.conv1(x)
                x = self.bn1(x)
                x = self.node1(x)
                x = self.layer1(x)
                x = self.layer2(x)
                x = self.layer3(x)
                x = self.avgpool(x)
                x = torch.flatten(x, 1)
                x = self.fc(x)
                outputs.append(x)
            if not self.sum_output:
                return outputs
            return sum(outputs) / len(outputs)

    def _forward_once(self, x):
        """Single-pass forward (``once=True``): one trip through the backbone,
        no encoding and no time loop.

        Bug fix: removed the call to ``self.layer4`` -- this CIFAR variant
        only builds layer1..layer3, so the original always raised
        ``AttributeError`` here.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.node1(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

    def forward(self, x):
        """Entry point: dispatch on ``self.once``."""
        if self.once:
            return self._forward_once(x)
        return self._forward_impl(x)
def _sew_resnet(arch, block, layers, pretrained, progress, cnf, **kwargs):
    """Construct a ``SEWResNet`` and, when requested, load the pretrained
    weights registered for ``arch`` in ``model_urls``.

    :param arch: key into ``model_urls`` for the pretrained checkpoint
    :param block: residual block class
    :param layers: per-stage block counts
    :param pretrained: load pretrained weights when True
    :param progress: show a download progress bar
    :param cnf: spike-element-wise connecting function name
    :param kwargs: forwarded to ``SEWResNet``
    """
    net = SEWResNet(block, layers, cnf=cnf, **kwargs)
    if not pretrained:
        return net
    weights = load_state_dict_from_url(model_urls[arch], progress=progress)
    net.load_state_dict(weights)
    return net
@register_model
def sew_resnet19(pretrained=False, progress=True, cnf: str = None, **kwargs):
    """Spike-element-wise (SEW) ResNet-19 (three stages of [3, 3, 2] BasicBlocks).

    Based on `"Deep Residual Learning in Spiking Neural Networks" <https://arxiv.org/abs/2102.04159>`_.

    :param pretrained: unused here; kept for signature compatibility with the other factories
    :type pretrained: bool
    :param progress: unused here; kept for signature compatibility
    :type progress: bool
    :param cnf: the name of the spike-element-wise connecting function
    :type cnf: str
    :param kwargs: forwarded to ``SEWResNet19`` (expects ``node_type``, ``dataset``, ...)
    :type kwargs: dict
    :return: spiking ResNet-19
    :rtype: torch.nn.Module
    """
    return SEWResNet19( BasicBlock, [3,3, 2], cnf=cnf, **kwargs)
@register_model
def sew_resnet18(pretrained=False, progress=True, cnf: str = None, **kwargs):
    """Spike-element-wise (SEW) ResNet-18.

    Based on `"Deep Residual Learning in Spiking Neural Networks" <https://arxiv.org/abs/2102.04159>`_,
    using the ResNet-18 layout from `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_.

    :param pretrained: if True, load the ImageNet-pretrained ResNet-18 checkpoint
    :type pretrained: bool
    :param progress: if True, display a download progress bar
    :type progress: bool
    :param cnf: the name of the spike-element-wise connecting function
    :type cnf: str
    :param kwargs: forwarded to ``SEWResNet`` (expects ``node_type``, ``dataset``, ...)
    :type kwargs: dict
    :return: spiking ResNet-18
    :rtype: torch.nn.Module
    """
    return _sew_resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, cnf, **kwargs)
@register_model
def sew_resnet20(pretrained=False, progress=True, cnf: str = None, **kwargs):
    """Spike-element-wise (SEW) ResNet-20, the CIFAR variant with [3, 3, 3] BasicBlocks.

    Based on `"Deep Residual Learning in Spiking Neural Networks" <https://arxiv.org/abs/2102.04159>`_.

    :param pretrained: unused here; kept for signature compatibility with the other factories
    :type pretrained: bool
    :param progress: unused here; kept for signature compatibility
    :type progress: bool
    :param cnf: the name of the spike-element-wise connecting function
    :type cnf: str
    :param kwargs: forwarded to ``SEWResNetCifar`` (expects ``node_type``, ``dataset``, ...)
    :type kwargs: dict
    :return: spiking ResNet-20
    :rtype: torch.nn.Module
    """
    return SEWResNetCifar( BasicBlock, [3,3,3], cnf=cnf, **kwargs)
@register_model
def sew_resnet32(pretrained=False, progress=True, cnf: str = None, **kwargs):
    """Spike-element-wise (SEW) ResNet-32, the CIFAR variant with [5, 5, 5] BasicBlocks.

    Based on `"Deep Residual Learning in Spiking Neural Networks" <https://arxiv.org/abs/2102.04159>`_.

    :param pretrained: unused here; kept for signature compatibility with the other factories
    :type pretrained: bool
    :param progress: unused here; kept for signature compatibility
    :type progress: bool
    :param cnf: the name of the spike-element-wise connecting function
    :type cnf: str
    :param kwargs: forwarded to ``SEWResNetCifar`` (expects ``node_type``, ``dataset``, ...)
    :type kwargs: dict
    :return: spiking ResNet-32
    :rtype: torch.nn.Module
    """
    return SEWResNetCifar( BasicBlock, [5,5,5], cnf=cnf, **kwargs)
@register_model
def sew_resnet44(pretrained=False, progress=True, cnf: str = None, **kwargs):
    """Spike-element-wise (SEW) ResNet-44, the CIFAR variant with [7, 7, 7] BasicBlocks.

    Based on `"Deep Residual Learning in Spiking Neural Networks" <https://arxiv.org/abs/2102.04159>`_.

    :param pretrained: unused here; kept for signature compatibility with the other factories
    :type pretrained: bool
    :param progress: unused here; kept for signature compatibility
    :type progress: bool
    :param cnf: the name of the spike-element-wise connecting function
    :type cnf: str
    :param kwargs: forwarded to ``SEWResNetCifar`` (expects ``node_type``, ``dataset``, ...)
    :type kwargs: dict
    :return: spiking ResNet-44
    :rtype: torch.nn.Module
    """
    return SEWResNetCifar( BasicBlock, [7,7,7], cnf=cnf, **kwargs)
@register_model
def sew_resnet56(pretrained=False, progress=True, cnf: str = None, **kwargs):
    """Spike-element-wise (SEW) ResNet-56, the CIFAR variant with [9, 9, 9] BasicBlocks.

    Based on `"Deep Residual Learning in Spiking Neural Networks" <https://arxiv.org/abs/2102.04159>`_.

    :param pretrained: unused here; kept for signature compatibility with the other factories
    :type pretrained: bool
    :param progress: unused here; kept for signature compatibility
    :type progress: bool
    :param cnf: the name of the spike-element-wise connecting function
    :type cnf: str
    :param kwargs: forwarded to ``SEWResNetCifar`` (expects ``node_type``, ``dataset``, ...)
    :type kwargs: dict
    :return: spiking ResNet-56
    :rtype: torch.nn.Module
    """
    return SEWResNetCifar( BasicBlock, [9,9,9], cnf=cnf, **kwargs)
@register_model
def sew_resnet34(pretrained=False, progress=True, cnf: str = None, **kwargs):
    """Spike-element-wise (SEW) ResNet-34.

    Based on `"Deep Residual Learning in Spiking Neural Networks" <https://arxiv.org/abs/2102.04159>`_,
    using the ResNet-34 layout from `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_.

    :param pretrained: if True, load the ImageNet-pretrained ResNet-34 checkpoint
    :type pretrained: bool
    :param progress: if True, display a download progress bar
    :type progress: bool
    :param cnf: the name of the spike-element-wise connecting function
    :type cnf: str
    :param kwargs: forwarded to ``SEWResNet`` (expects ``node_type``, ``dataset``, ...)
    :type kwargs: dict
    :return: spiking ResNet-34
    :rtype: torch.nn.Module
    """
    return _sew_resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, cnf, **kwargs)
@register_model
def sew_resnet50(pretrained=False, progress=True, cnf: str = None, **kwargs):
    """Spike-element-wise (SEW) ResNet-50.

    Based on `"Deep Residual Learning in Spiking Neural Networks" <https://arxiv.org/abs/2102.04159>`_,
    using the ResNet-50 layout from `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_.

    :param pretrained: if True, load the ImageNet-pretrained ResNet-50 checkpoint
    :type pretrained: bool
    :param progress: if True, display a download progress bar
    :type progress: bool
    :param cnf: the name of the spike-element-wise connecting function
    :type cnf: str
    :param kwargs: forwarded to ``SEWResNet`` (expects ``node_type``, ``dataset``, ...)
    :type kwargs: dict
    :return: spiking ResNet-50
    :rtype: torch.nn.Module
    """
    return _sew_resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, cnf, **kwargs)
@register_model
def sew_resnet101(pretrained=False, progress=True, cnf: str = None, **kwargs):
    """Spike-element-wise (SEW) ResNet-101.

    Based on `"Deep Residual Learning in Spiking Neural Networks" <https://arxiv.org/abs/2102.04159>`_,
    using the ResNet-101 layout from `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_.

    :param pretrained: if True, load the ImageNet-pretrained ResNet-101 checkpoint
    :type pretrained: bool
    :param progress: if True, display a download progress bar
    :type progress: bool
    :param cnf: the name of the spike-element-wise connecting function
    :type cnf: str
    :param kwargs: forwarded to ``SEWResNet`` (expects ``node_type``, ``dataset``, ...)
    :type kwargs: dict
    :return: spiking ResNet-101
    :rtype: torch.nn.Module
    """
    return _sew_resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, cnf, **kwargs)
@register_model
def sew_resnet152(pretrained=False, progress=True, cnf: str = None, **kwargs):
    """Spike-element-wise (SEW) ResNet-152.

    Based on `"Deep Residual Learning in Spiking Neural Networks" <https://arxiv.org/abs/2102.04159>`_,
    using the ResNet-152 layout from `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_.

    :param pretrained: if True, load the ImageNet-pretrained ResNet-152 checkpoint
    :type pretrained: bool
    :param progress: if True, display a download progress bar
    :type progress: bool
    :param cnf: the name of the spike-element-wise connecting function
    :type cnf: str
    :param kwargs: forwarded to ``SEWResNet`` (expects ``node_type``, ``dataset``, ...)
    :type kwargs: dict
    :return: spiking ResNet-152
    :rtype: torch.nn.Module
    """
    return _sew_resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, cnf, **kwargs)
@register_model
def sew_resnext50_32x4d(pretrained=False, progress=True, cnf: str = None, **kwargs):
    """Spike-element-wise (SEW) ResNeXt-50 32x4d.

    Based on `"Deep Residual Learning in Spiking Neural Networks" <https://arxiv.org/abs/2102.04159>`_,
    using the ResNeXt-50 32x4d layout from `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_.

    :param pretrained: if True, load the ImageNet-pretrained ResNeXt-50 32x4d checkpoint
    :type pretrained: bool
    :param progress: if True, display a download progress bar
    :type progress: bool
    :param cnf: the name of the spike-element-wise connecting function
    :type cnf: str
    :param kwargs: forwarded to ``SEWResNet``; ``groups``/``width_per_group`` are overridden here
    :type kwargs: dict
    :return: spiking ResNeXt-50 32x4d
    :rtype: torch.nn.Module
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 4
    return _sew_resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], pretrained, progress, cnf, **kwargs)
@register_model
def sew_resnext34_32x4d(pretrained=False, progress=True, cnf: str = None, **kwargs):
    """Spike-element-wise (SEW) ResNeXt-34 32x4d (BasicBlock layout [3, 4, 6, 3]
    with grouped convolutions).

    Based on `"Deep Residual Learning in Spiking Neural Networks" <https://arxiv.org/abs/2102.04159>`_ and
    `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_.

    NOTE(review): this combines ``BasicBlock`` with ``groups=32``/``width_per_group=4``;
    torchvision-style BasicBlocks reject non-default groups -- confirm the
    project's BasicBlock supports them.

    :param pretrained: if True, attempt to load the ``resnext34_32x4d`` checkpoint
    :type pretrained: bool
    :param progress: if True, display a download progress bar
    :type progress: bool
    :param cnf: the name of the spike-element-wise connecting function
    :type cnf: str
    :param kwargs: forwarded to ``SEWResNet``; ``groups``/``width_per_group`` are overridden here
    :type kwargs: dict
    :return: spiking ResNeXt-34 32x4d
    :rtype: torch.nn.Module
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 4
    return _sew_resnet('resnext34_32x4d', BasicBlock, [3, 4, 6, 3], pretrained, progress, cnf, **kwargs)
@register_model
def sew_resnext101_32x8d(pretrained=False, progress=True, cnf: str = None, node: callable=None, **kwargs):
    """Spike-element-wise (SEW) ResNeXt-101 32x8d.

    Based on `"Deep Residual Learning in Spiking Neural Networks" <https://arxiv.org/abs/2102.04159>`_ and
    `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_.

    :param pretrained: if True, load the ImageNet-pretrained ResNeXt-101 32x8d checkpoint
    :type pretrained: bool
    :param progress: if True, display a download progress bar
    :type progress: bool
    :param cnf: the name of the spike-element-wise connecting function
    :type cnf: str
    :param node: unused; kept for backward compatibility. The neuron type is
        taken from ``kwargs['node_type']`` like in the other factories.
    :type node: callable
    :param kwargs: forwarded to ``SEWResNet``; ``groups``/``width_per_group`` are overridden here
    :type kwargs: dict
    :return: spiking ResNeXt-101 32x8d
    :rtype: torch.nn.Module
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 8
    # Bug fix: the original passed `node` as a 7th positional argument, but
    # _sew_resnet(arch, block, layers, pretrained, progress, cnf, **kwargs)
    # takes only 6 positionals, so every call raised TypeError.
    return _sew_resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], pretrained, progress, cnf, **kwargs)
@register_model
def sew_wide_resnet50_2(pretrained=False, progress=True, cnf: str = None, **kwargs):
    """Spike-element-wise (SEW) Wide ResNet-50-2.

    Based on `"Deep Residual Learning in Spiking Neural Networks" <https://arxiv.org/abs/2102.04159>`_,
    using the Wide ResNet-50-2 layout from `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_.

    The model is the same as ResNet except for the bottleneck number of channels,
    which is twice as large in every block. The number of channels in the outer
    1x1 convolutions is unchanged, e.g. the last block in ResNet-50 has
    2048-512-2048 channels, and in Wide ResNet-50-2 has 2048-1024-2048.

    :param pretrained: if True, load the ImageNet-pretrained Wide ResNet-50-2 checkpoint
    :type pretrained: bool
    :param progress: if True, display a download progress bar
    :type progress: bool
    :param cnf: the name of the spike-element-wise connecting function
    :type cnf: str
    :param kwargs: forwarded to ``SEWResNet``; ``width_per_group`` is overridden here
    :type kwargs: dict
    :return: spiking Wide ResNet-50-2
    :rtype: torch.nn.Module
    """
    kwargs['width_per_group'] = 64 * 2
    return _sew_resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], pretrained, progress, cnf, **kwargs)
@register_model
def sew_wide_resnet101_2(pretrained=False, progress=True, cnf: str = None, **kwargs):
    """Spike-element-wise (SEW) Wide ResNet-101-2.

    Based on `"Deep Residual Learning in Spiking Neural Networks" <https://arxiv.org/abs/2102.04159>`_,
    using the Wide ResNet-101-2 layout from `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_.

    The model is the same as ResNet except for the bottleneck number of channels,
    which is twice as large in every block. The number of channels in the outer
    1x1 convolutions is unchanged.

    :param pretrained: if True, load the ImageNet-pretrained Wide ResNet-101-2 checkpoint
    :type pretrained: bool
    :param progress: if True, display a download progress bar
    :type progress: bool
    :param cnf: the name of the spike-element-wise connecting function
    :type cnf: str
    :param kwargs: forwarded to ``SEWResNet``; ``width_per_group`` is overridden here
    :type kwargs: dict
    :return: spiking Wide ResNet-101-2
    :rtype: torch.nn.Module
    """
    kwargs['width_per_group'] = 64 * 2
    return _sew_resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], pretrained, progress, cnf, **kwargs)
================================================
FILE: braincog/model_zoo/vgg_snn.py
================================================
# encoding: utf-8
# Author : Floyed
# Datetime : 2022/7/26 18:56
# User : Floyed
# Product : PyCharm
# Project : BrainCog
# File : vgg_snn.py
# explain :
from functools import partial
from torch.nn import functional as F
import torchvision
from timm.models import register_model
from braincog.base.node.node import *
from braincog.base.connection.layer import *
from braincog.base.encoder.encoder import *
from braincog.model_zoo.base_module import BaseModule, BaseConvModule, BaseLinearModule
from braincog.datasets import is_dvs_data
@register_model
class SNN7_tiny(BaseModule):
    """Tiny seven-convolution spiking network for static (non-DVS) datasets.

    :param num_classes: number of output classes
    :param step: number of simulation time steps
    :param node_type: spiking neuron class; partially applied with ``kwargs``
    :param encode_type: input encoder name, e.g. ``'direct'``
    """

    def __init__(self,
                 num_classes=10,
                 step=8,
                 node_type=LIFNode,
                 encode_type='direct',
                 *args,
                 **kwargs):
        super().__init__(step, encode_type, *args, **kwargs)
        self.num_classes = num_classes

        self.node = node_type
        if issubclass(self.node, BaseNode):
            self.node = partial(self.node, **kwargs, step=step)

        self.dataset = kwargs['dataset']
        assert not is_dvs_data(self.dataset), 'SNN7_tiny only support static datasets now'

        # All convolutions share a 3x3 kernel with padding 1.
        conv = partial(BaseConvModule, kernel_size=(3, 3), padding=(1, 1), node=self.node)
        self.feature = nn.Sequential(
            conv(3, 16),
            conv(16, 64),
            nn.MaxPool2d(2),
            conv(64, 128),
            conv(128, 128),
            nn.MaxPool2d(2),
            conv(128, 256),
            conv(256, 256),
            nn.MaxPool2d(2),
            conv(256, 512),
        )
        # NOTE(review): 512 * 4 * 4 assumes 32x32 inputs (three 2x poolings) -- confirm.
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(512 * 4 * 4, self.num_classes),
        )

    def forward(self, inputs):
        """Encode, reset neuron state, then average logits over time steps."""
        inputs = self.encoder(inputs)
        self.reset()

        if self.layer_by_layer:
            out = self.fc(self.feature(inputs))
            return rearrange(out, '(t b) c -> t b c', t=self.step).mean(0)

        step_outs = [self.fc(self.feature(inputs[t])) for t in range(self.step)]
        return sum(step_outs) / len(step_outs)
@register_model
class SNN5(BaseModule):
    """Five-convolution spiking network; handles both static (3-channel) and
    DVS (2-channel) inputs.

    :param num_classes: number of output classes
    :param step: number of simulation time steps
    :param node_type: spiking neuron class; partially applied with ``kwargs``
    :param encode_type: input encoder name, e.g. ``'direct'``
    """

    def __init__(self,
                 num_classes=10,
                 step=8,
                 node_type=LIFNode,
                 encode_type='direct',
                 *args,
                 **kwargs):
        super().__init__(step, encode_type, *args, **kwargs)
        self.n_preact = kwargs['n_preact'] if 'n_preact' in kwargs else False
        self.num_classes = num_classes

        self.node = node_type
        if issubclass(self.node, BaseNode):
            self.node = partial(self.node, **kwargs, step=step)

        self.dataset = kwargs['dataset']
        # DVS streams carry two polarity channels; RGB frames carry three.
        in_channels = 3 if not is_dvs_data(self.dataset) else 2

        conv = partial(BaseConvModule, node=self.node, n_preact=self.n_preact)
        self.feature = nn.Sequential(
            conv(in_channels, 16, kernel_size=(3, 3), padding=(1, 1)),
            conv(16, 64, kernel_size=(5, 5), padding=(2, 2)),
            nn.AvgPool2d(2),
            conv(64, 128, kernel_size=(5, 5), padding=(2, 2)),
            nn.AvgPool2d(2),
            conv(128, 256, kernel_size=(3, 3), padding=(1, 1)),
            nn.AvgPool2d(2),
            conv(256, 512, kernel_size=(3, 3), padding=(1, 1)),
            nn.AvgPool2d(2),
        )
        # NOTE(review): 512 * 3 * 3 assumes 48x48 inputs (four 2x poolings) -- confirm.
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(512 * 3 * 3, self.num_classes),
        )

    def forward(self, inputs):
        """Encode, reset neuron state, then average logits over time steps."""
        inputs = self.encoder(inputs)
        self.reset()

        if self.layer_by_layer:
            out = self.fc(self.feature(inputs))
            return rearrange(out, '(t b) c -> t b c', t=self.step).mean(0)

        step_outs = [self.fc(self.feature(inputs[t])) for t in range(self.step)]
        return sum(step_outs) / len(step_outs)
@register_model
class VGG_SNN(BaseModule):
    """VGG-style spiking network for DVS (event-based) datasets only.

    :param num_classes: number of output classes
    :param step: number of simulation time steps
    :param node_type: spiking neuron class; partially applied with ``kwargs``
    :param encode_type: input encoder name, e.g. ``'direct'``
    :raises NotImplementedError: if ``kwargs['dataset']`` is not a DVS dataset
    """

    def __init__(self,
                 num_classes=10,
                 step=8,
                 node_type=LIFNode,
                 encode_type='direct',
                 *args,
                 **kwargs):
        super().__init__(step, encode_type, *args, **kwargs)
        self.n_preact = kwargs['n_preact'] if 'n_preact' in kwargs else False
        self.num_classes = num_classes

        self.node = node_type
        if issubclass(self.node, BaseNode):
            self.node = partial(self.node, **kwargs, step=step)

        self.dataset = kwargs['dataset']
        if not is_dvs_data(self.dataset):
            raise NotImplementedError('VGG-SNN model is only for DVS data, but current datasets is {}'.format(self.dataset))

        # All convolutions share a 3x3 kernel with padding 1; two polarity channels in.
        conv = partial(BaseConvModule, kernel_size=(3, 3), padding=(1, 1), node=self.node)
        self.feature = nn.Sequential(
            conv(2, 64),
            conv(64, 128),
            nn.AvgPool2d(2),
            conv(128, 256),
            conv(256, 256),
            nn.AvgPool2d(2),
            conv(256, 512),
            conv(512, 512),
            nn.AvgPool2d(2),
            conv(512, 512),
            conv(512, 512),
            nn.AvgPool2d(2),
        )
        # NOTE(review): 512 * 3 * 3 assumes 48x48 inputs (four 2x poolings) -- confirm.
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(512 * 3 * 3, self.num_classes),
        )

    def forward(self, inputs):
        """Encode, reset neuron state, then average logits over time steps."""
        inputs = self.encoder(inputs)
        self.reset()

        if self.layer_by_layer:
            out = self.fc(self.feature(inputs))
            return rearrange(out, '(t b) c -> t b c', t=self.step).mean(0)

        step_outs = [self.fc(self.feature(inputs[t])) for t in range(self.step)]
        return sum(step_outs) / len(step_outs)
================================================
FILE: braincog/utils.py
================================================
import os
import random
import math
import csv
import numpy as np
import torch
from torch import nn
from PIL import Image
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
def setup_seed(seed):
    """Seed every RNG used by BrainCog (Python, NumPy, CPU and all GPU torch
    generators), disable hash randomization, and force deterministic cuDNN.

    :param seed: seed value
    :return: None
    """
    os.environ['PYTHONHASHSEED'] = str(seed)  # disable hash randomization
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # every visible GPU
    # Deterministic cuDNN kernels; disable auto-tuning for reproducibility.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
def random_gradient(model: nn.Module, sigma: float):
    """Add zero-mean Gaussian noise (std ``sigma``) to every existing gradient.

    Parameters whose ``.grad`` is ``None`` (e.g. frozen or not yet used in a
    backward pass) are left untouched.

    :param model: model whose parameter gradients are perturbed
    :param sigma: standard deviation of the noise
    :return: None
    """
    for p in model.parameters():
        if p.grad is not None:
            p.grad = p.grad + torch.randn_like(p) * sigma
class AverageMeter(object):
    """Track a weighted running average of a scalar (e.g. loss or accuracy)."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Clear all accumulated statistics.
        self.sum = 0
        self.cnt = 0
        self.avg = 0

    def update(self, val, n=1):
        """Accumulate ``val`` observed ``n`` times and refresh the average."""
        self.cnt += n
        self.sum += val * n
        self.avg = self.sum / self.cnt
class TensorGather(object):
    """Accumulate tensors by concatenating them along dim 0 across updates."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Drop everything gathered so far.
        self.gather = None

    def update(self, val):
        """Append ``val`` to the gathered tensor (dim 0)."""
        if self.gather is None:
            self.gather = val
        else:
            self.gather = torch.cat([self.gather, val], dim=0)
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracies (in percent) for the requested ``topk`` values.

    :param output: logits of shape (N, C)
    :param target: ground-truth class indices of shape (N,)
    :param topk: tuple of k values to evaluate
    :return: list of scalar tensors, one accuracy per k
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # k highest-scoring classes per sample, transposed to (maxk, N).
    top_idx = output.topk(maxk, 1, True, True)[1].t()
    correct = top_idx.eq(target.view(1, -1).expand_as(top_idx))

    # For each k, a sample counts as correct if the target is in its top-k.
    return [
        correct[:k].reshape(-1).float().sum(0).mul_(100.0 / batch_size)
        for k in topk
    ]
def mse(x, y):
    """Sum of squared errors over the last dimension, averaged over the rest."""
    diff = x - y
    return diff.pow(2).sum(-1, keepdim=True).mean()
def rand_ortho(shape, irange):
    """Random (semi-)orthogonal matrix: SVD-orthogonalize a matrix drawn
    uniformly from (-irange, irange).

    :param shape: 2-D shape of the result
    :param irange: half-width of the uniform sampling interval
    :return: numpy array of the given shape with orthonormal rows/columns
    """
    sample = 2 * irange * np.random.rand(*shape) - irange
    u, _, vt = np.linalg.svd(sample, full_matrices=True)
    # Rectangular identity reconciles the shapes of U and V^T.
    return u @ np.eye(u.shape[1], vt.shape[0]) @ vt
def adjust_surrogate_coeff(epoch, tot_epochs):
    """Anneal the surrogate-gradient temperature log-linearly from 1e-3 to 1e1
    over the training run; returns CUDA tensors ``(t, k)``.

    :param epoch: current epoch index
    :param tot_epochs: total number of epochs
    :return: tuple of two 1-element CUDA float tensors
    """
    t_min, t_max = 1e-3, 1e1
    k_min = math.log(t_min) / math.log(10)
    k_max = math.log(t_max) / math.log(10)
    exponent = k_min + (k_max - k_min) / tot_epochs * epoch
    t = torch.tensor([math.pow(10, exponent)]).float().cuda()
    k = torch.tensor([1]).float().cuda()
    # NOTE(review): k is always exactly 1 here, so this branch never fires;
    # kept as-is to preserve the original behaviour.
    if k < 1:
        k = 1 / t
    return t, k
def save_feature_map(x, dir=''):
    """Save every channel of every layer's feature map as a jpg heat map.

    File names follow ``{layer}_{batch}_{channel}_{spatial_size}.jpg``.

    Bug fix: the original drew every image onto the same implicit pyplot
    figure and never closed it, so figures (and their AxesImages) accumulated
    across calls -- a memory leak for large batches. Each image now gets its
    own figure which is closed after saving.

    :param x: iterable of feature-map tensors, each indexed [batch, channel, H, W]
    :param dir: output directory (note: shadows the builtin ``dir``; name kept
        for backward compatibility with existing callers)
    :return: None
    """
    for idx, layer in enumerate(x):
        layer = layer.cpu()
        for batch in range(layer.shape[0]):
            for channel in range(layer.shape[1]):
                fname = '{}_{}_{}_{}.jpg'.format(
                    idx, batch, channel, layer.shape[-1])
                fp = layer[batch, channel]
                fig = plt.figure()
                plt.tight_layout()
                plt.axis('off')
                plt.imshow(fp, cmap='inferno')
                plt.savefig(os.path.join(dir, fname),
                            bbox_inches='tight', pad_inches=0)
                plt.close(fig)  # release the figure to avoid leaking memory
def save_spike_info(fname, epoch, batch_idx, step, avg, var, spike, avg_per_step):
    """Append per-layer spike statistics to a CSV file, writing the header
    row first if the file does not exist yet.

    Bug fix: the original opened the file without ever closing it, leaking a
    file handle on every call (and risking unflushed rows); a context manager
    now guarantees the file is closed.

    :param fname: output csv file name
    :param epoch: epoch
    :param batch_idx: batch index
    :param step: number of simulation time steps
    :param avg: per-layer average firing rate
    :param var: per-layer firing-rate variance
    :param spike: per-layer spike-count histogram (step + 1 bins)
    :param avg_per_step: per-layer average spikes for each time step
    :return: None
    """
    new_file = not os.path.exists(fname)
    with open(fname, mode='w' if new_file else 'a', encoding='utf8', newline='') as f:
        writer = csv.writer(f)
        if new_file:
            head = ['epoch', 'batch', 'layer', 'avg', 'var']
            head.extend('st_{}'.format(i) for i in range(step + 1))  # spike times
            head.extend('as_{}'.format(i) for i in range(step))      # avg spike per time
            writer.writerow(head)
        for layer in range(len(avg)):
            lst = [epoch, batch_idx, layer, avg[layer], var[layer]]
            lst.extend(spike[layer])
            lst.extend(avg_per_step[layer])
            writer.writerow([str(x) for x in lst])
def calc_aurc(confidences, labels):
    """Compute AURC and excess AURC (E-AURC) for selective prediction.

    Samples are ranked from most to least confident; the risk/coverage curve
    is the cumulative error rate as coverage grows, and AURC is its mean.
    E-AURC subtracts the AURC of an oracle that defers all errors last.
    Uses CUDA tensors internally (``.cuda()``), matching the original.

    :param confidences: per-class confidence scores, shape (N, C)
    :param labels: ground-truth class indices, shape (N,)
    :return: tuple (aurc, eaurc) of scalar tensors
    """
    preds = torch.argmax(confidences, dim=1)
    top_conf = torch.max(confidences, dim=1)[0]
    n = len(labels)

    # Sort by ascending confidence, then flip -> most confident first.
    order = torch.argsort(top_conf)
    labels = labels[order].flip(dims=[0])
    preds = preds[order].flip(dims=[0])
    confidences = confidences[order].flip(dims=[0])

    errors = labels != preds
    risk_cov = torch.divide(torch.cumsum(errors, dim=0).float(),
                            torch.arange(1, n + 1).cuda())
    aurc = torch.mean(risk_cov)

    # Oracle baseline: all n_err errors placed at the lowest-coverage end.
    n_err = torch.sum(errors)
    idx = torch.arange(1, n_err + 1).cuda()
    opt_aurc = (1. / n) * torch.sum(torch.divide(idx.float(), n - n_err + idx))
    return aurc, aurc - opt_aurc
================================================
FILE: docs/Makefile
================================================
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
# Location of the .rst/.md sources and of the generated output.
SOURCEDIR = source
BUILDDIR = build

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
================================================
FILE: docs/make.bat
================================================
@ECHO OFF

pushd %~dp0

REM Command file for Sphinx documentation

REM Default to the sphinx-build on PATH unless the caller overrides it.
if "%SPHINXBUILD%" == "" (
	set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=source
set BUILDDIR=build

REM Probe for sphinx-build; errorlevel 9009 means "command not found".
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
	echo.
	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
	echo.installed, then set the SPHINXBUILD environment variable to point
	echo.to the full path of the 'sphinx-build' executable. Alternatively you
	echo.may add the Sphinx directory to PATH.
	echo.
	echo.If you don't have Sphinx installed, grab it from
	echo.https://www.sphinx-doc.org/
	exit /b 1
)

if "%1" == "" goto help

%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end

:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%

:end
popd
================================================
FILE: docs/source/conf.py
================================================
# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html

# -- Path setup --------------------------------------------------------------

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import os
import sys
import warnings

# Silence third-party deprecation noise during the docs build.
warnings.filterwarnings("ignore")

# NOTE(review): this puts the `braincog` package directory itself on sys.path;
# for `import braincog` (as autodoc usually needs) the repository root
# ('../..') is the conventional entry -- confirm which form this build relies on.
sys.path.insert(0, os.path.abspath('../../braincog'))

# -- Project information -----------------------------------------------------

project = 'braincog'
copyright = '2022, Brain-Inspired-Cognitive-Intelligence-Engine(BrainCog)'
author = 'Brain-Inspired-Cognitive-Intelligence-Engine'

# The full version, including alpha/beta/rc tags
release = '0.2.7.11'

# -- General configuration ---------------------------------------------------

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
    'sphinx.ext.autodoc',
    'sphinx.ext.napoleon',
    'sphinx.ext.doctest',
    'sphinx.ext.intersphinx',
    'sphinx.ext.todo',
    'sphinx.ext.coverage',
    'sphinx.ext.mathjax',
    'recommonmark',
    # 'sphinx_markdown_tables',
]

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']

# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = 'zh_CN'

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = []

# -- Options for HTML output -------------------------------------------------

# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'sphinx_rtd_theme'

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
FILE: docs/source/examples/Brain_Cognitive_Function_Simulation/drosophila.md
================================================
# Drosophila-inspired decision-making SNN
## Run
"drosophila.py" includes the training phase and testing phase.
```shell
python drosophila.py
```
* Training Phase
green-upright T is safe and blue-inverted T is dangerous
* Testing Phase
For linear pathway and nonlinear pathway, choose between blue-upright T and green-inverted T, and count the PI values under different color intensity
## Results
The following picture shows the linear (a) and nonlinear (b) pathways, the training and testing phases (c), and the PI values on different color intensities (d).

================================================
FILE: docs/source/examples/Brain_Cognitive_Function_Simulation/index.rst
================================================
Brain_Cognitive_Function_Simulation
======================================
.. toctree::
:maxdepth: 2
drosophila
================================================
FILE: docs/source/examples/Decision_Making/BDM_SNN.md
================================================
# Brain-inspired Decision-Making Spiking Neural Network
## Run
"BDM-SNN.py" includes the multi-brain regions coordinated decision-making spiking neural network with LIF neurons.
"BDM-SNN-hh.py" includes the BDM-SNN with simplified HH neurons.
"BDM-SNN-UAV.py" includes the BDM-SNN applied to the UAV (DJI Tello talent), users need to define the reinforcement learning task.
```shell
python BDM-SNN.py
python BDM-SNN-hh.py
python BDM-SNN-UAV.py
```
## Results
"BDM-SNN.py" and "BDM-SNN-hh.py" have been verified on Flappy Bird game. BDM-SNN could stably pass the pipeline on the first try.

================================================
FILE: docs/source/examples/Decision_Making/RL.md
================================================
# PL-SDQN
This repository contains code from our paper [**Solving the Spike Feature Information Vanishing Problem in Spiking Deep Q Network with Potential Based Normalization**]. If you use our code or refer to this project, please cite this paper.
## Requirements
* numpy
* scipy
* pytorch >= 1.7.0
* torchvision
* gym
* atari-py
* opencv-python
* tianshou
## Train
```python
python ./sdqn/main.py
```
================================================
FILE: docs/source/examples/Decision_Making/index.rst
================================================
Decision Making
======================================
.. toctree::
:maxdepth: 2
RL
BDM_SNN
================================================
FILE: docs/source/examples/Knowledge_Representation_and_Reasoning/CKRGSNN.md
================================================
# Commonsense Knowledge Representation SNN
(https://arxiv.org/abs/2207.05561)
This repository contains code from our paper [**Brain-inspired Graph Spiking Neural Networks for Commonsense Knowledge Representation and Reasoning**] preprint in: https://arxiv.org/abs/2207.05561 . If you use our code or refer to this project, please cite this paper.
## Requirements
* python=3.8
* numpy
* scipy
* turicreate
* pytorch >= 1.7.0
* torchvision
## Dataset
ConceptNet: https://github.com/commonsense/conceptnet5
## Run
```shell
python main.py
```
This module selects core knowledge in ConceptNet to form the sub_Concept.csv file as the input of the model. The input current, spike trains during the learning process and the network weight distribution after the learning are shown in the Results folder.
================================================
FILE: docs/source/examples/Knowledge_Representation_and_Reasoning/CRSNN.md
================================================
# Causal Reasoning SNN
(https://doi.org/10.1109/IJCNN52387.2021.9534102)
This repository contains code from our paper [**A Brain-Inspired Causal Reasoning Model Based on Spiking Neural Networks
**] published in 2021 International Joint Conference on Neural Networks (IJCNN). https://ieeexplore.ieee.org/abstract/document/9534102. If you use our code or refer to this project, please cite this paper.
## Requirements
* numpy
* scipy
* pytorch >= 1.7.0
* torchvision
## Run
```shell
python main.py
```
This module builds an example of a brain-like causal inference spiking neural network model. The input causal graph is shown in figure causal_graph.png. The input current, spike trains during the learning process and the network weight distribution after the learning are shown in the Results folder.
================================================
FILE: docs/source/examples/Knowledge_Representation_and_Reasoning/SPSNN.md
================================================
# Sequence Production SNN
(https://www.frontiersin.org/articles/10.3389/fncom.2021.612041/full)
This repository contains code from our paper [**Brain Inspired Sequences Production by Spiking Neural Networks With Reward-Modulated STDP**] published in Frontiers in Computational Neuroscience. https://www.frontiersin.org/articles/10.3389/fncom.2021.612041/full. If you use our code or refer to this project, please cite this paper.
## Requirements
* numpy
* scipy
* pytorch >= 1.7.0
* torchvision
## Run
```shell
python main.py file
```
This module builds a sequence Production spiking neural network model, realizing the memory and reconstruction functions for arbitrary symbol sequences. The input causal graph is shown in figure causal_graph.png. The input current, spike trains during the learning process and the network weight distribution after the learning are shown in the Results folder.
================================================
FILE: docs/source/examples/Knowledge_Representation_and_Reasoning/index.rst
================================================
Knowledge Representation and Reasoning
=========================================
.. toctree::
:maxdepth: 2
musicMemory
SPSNN
CRSNN
CKRGSNN
================================================
FILE: docs/source/examples/Knowledge_Representation_and_Reasoning/musicMemory.md
================================================
# Music Memory
数据集:http://www.piano-midi.de/
自行下载数据集,数据使用方法见task下的示例。
================================================
FILE: docs/source/examples/Multi-scale_Brain_Structure_Simulation/Corticothalamic_minicolumn.md
================================================
# Corticothalamic minicolumn
## Description
The anatomical data is saved in the "tool" package. The **main.py** creates the network of minicolumns depending on the anatomical data.
A file named **"fire.csv"** will be generated to record the firing result of neurons in each time step.
## Requirements
* numpy
* scipy
* pytorch >= 1.7.0
```shell
python main.py
```
================================================
FILE: docs/source/examples/Multi-scale_Brain_Structure_Simulation/HumanBrain.md
================================================
# Human Brain Simulation
## Description
Human Brain Simulation is a large scale brain modeling framework depending on braincog framework.
## Requirements:
* numpy >= 1.21.2
* scipy >= 1.8.0
* h5py >= 3.6.0
* torch >= 1.10
* torchvision >= 0.12.0
* torchaudio >= 0.11.0
* timm >= 0.5.4
* matplotlib >= 3.5.1
* einops >= 0.4.1
* thop >= 0.0.31
* pyyaml >= 6.0
* loris >= 0.5.3
* pandas >= 1.4.2
* tonic (special)
* pandas >= 1.4.2
## Example:
```shell
cd ~/examples/Multi-scale Brain Structure Simulation/HumanBrain/
python brainSimHum.py
```
## Parameters:
To simulate the models (both human and macaque brain), the parameters of the neuron number in each region and the connectome power between regions can be set flexibly in the main function (nsz and asz) of the .py files.
================================================
FILE: docs/source/examples/Multi-scale_Brain_Structure_Simulation/Human_PFC.md
================================================
# Human PFC
## Input:
* 程序输入六层皮质柱的电生理数据,数据文件名中的数字表示神经元数量。程序默认有背景电流输入。其中data+数字命名的是随机输入刺激的文件,输入图片刺激的文件分别有人类参数的和小鼠参数的。分别都支持四种形状的图片文件,圆形、正方形、三角形和星形。
链接:https://drive.google.com/drive/folders/1AVc2aNTxkcsGAPlq1SuWtatGzyQRPCmp?usp=sharing
## output
* 程序生成的各个神经元放电时间点记录的数据文件
## application:
* 程序中可以修改每次PFC模型放电环境的情况
================================================
FILE: docs/source/examples/Multi-scale_Brain_Structure_Simulation/MacaqueBrain.md
================================================
# Macaque Brain Simulation
## Description
Macaque Brain Simulation is a large scale brain modeling framework depending on braincog framework.
## Requirements:
* numpy >= 1.21.2
* scipy >= 1.8.0
* h5py >= 3.6.0
* torch >= 1.10
* torchvision >= 0.12.0
* torchaudio >= 0.11.0
* timm >= 0.5.4
* matplotlib >= 3.5.1
* einops >= 0.4.1
* thop >= 0.0.31
* pyyaml >= 6.0
* loris >= 0.5.3
* pandas >= 1.4.2
* tonic (special)
* pandas >= 1.4.2
## Example:
```shell
cd ~/examples/Multi-scale Brain Structure Simulation/MacaqueBrain/
python brainSimMaq.py
```
## Parameters:
To simulate the models (both human and macaque brain), the parameters of the neuron number in each region and the connectome power between regions can be set flexibly in the main function (nsz and asz) of the .py files.
================================================
FILE: docs/source/examples/Multi-scale_Brain_Structure_Simulation/index.rst
================================================
Multi-scale Brain Structure Simulation
=========================================
.. toctree::
:maxdepth: 2
MacaqueBrain
HumanBrain
mouse_brain
Human_PFC
Corticothalamic_minicolumn
================================================
FILE: docs/source/examples/Multi-scale_Brain_Structure_Simulation/mouse_brain.md
================================================
# Mouse Brain
## Input:
* 程序输入213个脑区之间的连接权重的表格。放在谷歌网盘上面,名称是'W_213.xlsx'。
链接:[https://drive.google.com/drive/folders/1snPbpiVBpVuRgRYcl4AG4v49NgKKwBtA?usp=sharing](https://drive.google.com/drive/folders/1snPbpiVBpVuRgRYcl4AG4v49NgKKwBtA?usp=sharing)
## output
* 程序生成的各个神经元放电时间点记录的数据mat文件,数据点的数量大需要用画图软件显示结果。
## application:
* 程序中可以修改神经元的数量模拟时间的情况
================================================
FILE: docs/source/examples/Perception_and_Learning/Conversion.md
================================================
# Conversion Method
Training a deep spiking neural network with ANN-SNN conversion.
Replace ReLU and MaxPooling in the PyTorch model so that the original ANN can be converted into an SNN able to finish complex tasks.
## Results
```shell
python CIFAR10_VGG16.py
python converted_CIFAR10.py
```
You should first run the `CIFAR10_VGG16.py` to get a well-trained ANN.
Then `converted_CIFAR10.py` can be used to run the snn inference process.
================================================
FILE: docs/source/examples/Perception_and_Learning/MultisensoryIntegration.md
================================================
# Multisensory Integration DEMO
In `MultisensoryIntegrationDEMO_AM.py` and `MultisensoryIntegrationDEMO_IM.py`, we implement the SNNs based multisensory integration framework. To load the dataset, preprocess it and get the weights with the function `get_concept_datase_dic_and_initial_weights_lst()`. We use `IMNet` or `AMNet` to describe the structure of the IM/AM model. For the presynaptic neuron, we use the function `convert_vec_into_spike_trains()` to generate the spike trains.
While for postsynaptic neuron, we use the function `reducing_tol_binarycode()` to get the multisensory integrated output for each concept. And *tol* is the only parameter.
In `measure_and_visualization.py`, we will measure and visualize the results.
## Multisensory Dataset
When implementing the model in braincog, we use the famous multisensory dataset--BBSR.
Some examples are as follows:
| Concept | Visual | Somatic | Audiation | Taste | Smell |
| --------- | ----------- | --------- | ----------- | -------- | -------- |
| advantage | 0.213333333 | 0.032 | 0 | 0 | 0 |
| arm | 2.5111112 | 2.2733334 | 0.133333286 | 0.233333 | 0.4 |
| ball | 1.9580246 | 2.3111112 | 0.523809429 | 0.185185 | 0.111111 |
| baseball | 2.2714286 | 2.6071428 | 0.352040714 | 0.071429 | 0.392857 |
| bee | 2.795698933 | 2.4129034 | 2.096774286 | 0.290323 | 0.419355 |
| beer | 1.4866666 | 2.2533334 | 0.190476286 | 5.8 | 4.6 |
| bird | 2.7632184 | 2.027586 | 3.064039286 | 1.068966 | 0.517241 |
| car | 2.521839133 | 2.9517244 | 2.216748857 | 0 | 2.206897 |
| foot | 2.664444533 | 2.58 | 0.380952429 | 0.433333 | 3 |
| honey | 1.757142867 | 2.3214286 | 0.015306143 | 5.642857 | 4.535714 |
## How to Run
To get the multisensory integrated vectors:
```
cd examples/MultisensoryIntegration/code
python MultisensoryIntegrationDEMO_AM.py
python MultisensoryIntegrationDEMO_IM.py
```
To measure and analysis the vectors:
```
cd examples/MultisensoryIntegration/code
python measure_and_visualization.py
```
================================================
FILE: docs/source/examples/Perception_and_Learning/QSNN.md
================================================
# Quantum superposition inspired spiking neural network
This repository contains code from our paper [**Quantum superposition inspired spiking neural network**] published in iScience. https://doi.org/10.1016/j.isci.2021.102880. If you use our code or refer to this project, please cite this paper.
## Requirements
* numpy
* scipy
* pytorch >= 1.7.0
* torchvision
## Data preparation
First modify the ```DATA_DIR='path/to/datasets'``` in ```examples/QSNN/main.py``` to the root directory of your MNIST datasets.
## Train
```shell
python ./main.py
```
================================================
FILE: docs/source/examples/Perception_and_Learning/UnsupervisedSTDP.md
================================================
# Unsupervised STDP
This is an example of training an unsupervised STDP-based spiking neural network. We used an STB-STDP algorithm to train the SNN, along with multiple adaptive mechanisms.
## How to run
python codef.py
## Result
We train the model on Mnist and FashionMNIST, and the best accuracy for MNIST is 97.9%, for FashionMNIST is 87.0%.
================================================
FILE: docs/source/examples/Perception_and_Learning/img_cls/bp.md
================================================
# Script for training high-performance SNNs based on back propagation
This is an example of training high-performance SNNs using the braincog.
It is able to train high performance SNNs on CIFAR10, DVS-CIFAR10, ImageNet and other datasets, and reach the advanced level.
## Install braincog
```shell
git clone https://github.com/xxx/Brain-Cog.git
cd braincog
python setup.py install --user
```
## Examples of training
```shell
cd examples/img_cls/bp
python main.py --model dvs_convnet --node-type LIFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun QGate --device 0
```
## Benchmark
We provide a benchmark of SNNs trained with braincog and the corresponding scripts.
This provides an open, fair platform for comparison of subsequent SNNs on classification tasks.
**Note**: The results may vary due to random seeding and software version issues.
### CIFAR10
| ID | Dataset | Node-type | Config | Model | Batch Size | Accuracy | Script |
|:----|:-------:|:----------:|:------:|:-----------:|:----------:|:------------:|:-------------------------------------------------------------------------------------------------------------------------------------------|
| 1 | CIFAR10 | IF+Atan | - | convnet | 128 | 95.54 | ```python main.py --model cifar_convnet --node-type IFNode --dataset cifar10 --step 4 --batch-size 128 --act-fun AtanGrad --device 0``` |
| 1 | CIFAR10 | LIF+Atan | - | convnet | 128 | 91.92 | ```python main.py --model cifar_convnet --node-type LIFNode --dataset cifar10 --step 4 --batch-size 128 --act-fun AtanGrad --device 0``` |
| 1 | CIFAR10 | PLIF+Atan | - | convert | 128 | 93.32 | ```python main.py --model cifar_convnet --node-type PLIFNode --dataset cifar10 --step 4 --batch-size 128 --act-fun AtanGrad --device 0``` |
| 1 | CIFAR10 | IF+Atan | - | resnet18 | 128 | 89.76/89.80 | ```python main.py --model resnet18 --node-type IFNode --dataset cifar10 --step 4 --batch-size 128 --act-fun AtanGrad --device 0``` |
| 1 | CIFAR10 | LIF+Atan | - | resnet18 | 128 | 89.93/89.88 | ```python main.py --model resnet18 --node-type LIFNode --dataset cifar10 --step 4 --batch-size 128 --act-fun AtanGrad --device 0``` |
| 1 | CIFAR10 | PLIF+Atan | - | resnet18 | 128 | 92.64/ 90.65 | ```python main.py --model resnet18 --node-type PLIFNode --dataset cifar10 --step 4 --batch-size 128 --act-fun AtanGrad --device 0``` |
| 1 | CIFAR10 | IF+QGate | - | dvs_convnet | 128 | 95.73 | ```python main.py --model cifar_convnet --node-type IFNode --dataset cifar10 --step 4 --batch-size 128 --act-fun QGateGrad --device 0``` |
| 1 | CIFAR10 | LIF+QGate | - | dvs_convnet | 128 | 96.04 | ```python main.py --model cifar_convnet --node-type LIFNode --dataset cifar10 --step 4 --batch-size 128 --act-fun QGateGrad --device 0``` |
| 1 | CIFAR10 | PLIF+QGate | - | dvs_convnet | 128 | 96.04/95.84 | ```python main.py --model cifar_convnet --node-type PLIFNode --dataset cifar10 --step 4 --batch-size 128 --act-fun QGateGrad --device 0``` |
| 1 | CIFAR10 | IF+QGate | - | resnet18 | 128 | 89.19 | ```python main.py --model resnet18 --node-type IFNode --dataset cifar10 --step 4 --batch-size 128 --act-fun QGateGrad --device 0``` |
| 1 | CIFAR10 | LIF+QGate | - | resnet18 | 128 | 90.95/90.68 | ```python main.py --model resnet18 --node-type LIFNode --dataset cifar10 --step 4 --batch-size 128 --act-fun QGateGrad --device 0``` |
| 1 | CIFAR10 | PLIF+QGate | - | resnet18 | 128 | 90.97/91.02 | ```python main.py --model resnet18 --node-type PLIFNode --dataset cifar10 --step 4 --batch-size 128 --act-fun QGateGrad --device 0``` |
### CIFAR100
| ID | Dataset | Node-type | Config | Model | Batch Size | Accuracy | Script |
|:----|:--------:|:----------:|:------:|:-----------:|:----------:|:--------:|:--------------------------------------------------------------------------------------------------------------------------------------------------------------|
| 1 | CIFAR100 | IF+Atan | - | dvs_convnet | 128 | 76.52 | ```python main.py --num-classes 100 --model cifar_convnet --node-type IFNode --dataset cifar100 --step 4 --batch-size 128 --act-fun AtanGrad --device 0``` |
| 1 | CIFAR100 | LIF+Atan | - | dvs_convnet | 128 | 71.89 | ```python main.py --num-classes 100 --model cifar_convnet --node-type LIFNode --dataset cifar100 --step 4 --batch-size 128 --act-fun AtanGrad --device 0``` |
| 1 | CIFAR100 | PLIF+Atan | - | dvs_convnet | 128 | 72.82 | ```python main.py --num-classes 100 --model cifar_convnet --node-type PLIFNode --dataset cifar100 --step 4 --batch-size 128 --act-fun AtanGrad --device 0``` |
| 1 | CIFAR100 | IF+Atan | - | resnet18 | 128 | 62.47 | ```python main.py --num-classes 100 --model resnet18 --node-type IFNode --dataset cifar100 --step 4 --batch-size 128 --act-fun AtanGrad --device 0``` |
| 1 | CIFAR100 | LIF+Atan | - | resnet18 | 128 | 62.63 | ```python main.py --num-classes 100 --model resnet18 --node-type LIFNode --dataset cifar100 --step 4 --batch-size 128 --act-fun AtanGrad --device 0``` |
| 1 | CIFAR100 | PLIF+Atan | - | resnet18 | 128 | 62.71 | ```python main.py --num-classes 100 --model resnet18 --node-type PLIFNode --dataset cifar100 --step 4 --batch-size 128 --act-fun AtanGrad --device 0``` |
| 1 | CIFAR100 | IF+QGate | - | dvs_convnet | 128 | 76.44 | ```python main.py --num-classes 100 --model cifar_convnet --node-type IFNode --dataset cifar100 --step 4 --batch-size 128 --act-fun QGateGrad --device 0``` |
| 1 | CIFAR100 | LIF+QGate | - | dvs_convnet | 128 | 77.73 | ```python main.py --num-classes 100 --model cifar_convnet --node-type LIFNode --dataset cifar100 --step 4 --batch-size 128 --act-fun QGateGrad --device 0``` |
| 1 | CIFAR100 | PLIF+QGate | - | dvs_convnet | 128 | 77.25 | ```python main.py --num-classes 100 --model cifar_convnet --node-type PLIFNode --dataset cifar100 --step 4 --batch-size 128 --act-fun QGateGrad --device 0``` |
| 1 | CIFAR100 | IF+QGate | - | resnet18 | 128 | 60.01 | ```python main.py --num-classes 100 --model resnet18 --node-type IFNode --dataset cifar100 --step 4 --batch-size 128 --act-fun QGateGrad --device 0``` |
| 1 | CIFAR100 | LIF+QGate | - | resnet18 | 128 | 61.33 | ```python main.py --num-classes 100 --model resnet18 --node-type LIFNode --dataset cifar100 --step 4 --batch-size 128 --act-fun QGateGrad --device 0``` |
| 1 | CIFAR100 | PLIF+QGate | - | resnet18 | 128 | 62.32 | ```python main.py --num-classes 100 --model resnet18 --node-type PLIFNode --dataset cifar100 --step 4 --batch-size 128 --act-fun QGateGrad --device 0``` |
### DVS-CIFAR10
| ID | Dataset | Node-type | Config | Model | Batch Size | FLOPS | Accuracy | Script |
|:----|:-----------:|:----------:|:------:|:-----------:|:----------:|:-----:|:-----------:|:-----------------------------------------------------------------------------------------------------------------------------------------|
| 1 | DVS-CIFAR10 | IF+Atan | - | dvs_convnet | 128 | 7503 | 65.90 | ```python main.py --model dvs_convnet --node-type IFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun AtanGrad --device 0``` |
| 1 | DVS-CIFAR10 | LIF+Atan | - | dvs_convnet | 128 | 7503 | 82.10 | ```python main.py --model dvs_convnet --node-type LIFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun AtanGrad --device 0``` |
| 1 | DVS-CIFAR10 | PLIF+Atan | - | dvs_convnet | 128 | 7503 | 81.90 | ```python main.py --model dvs_convnet --node-type PLIFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun AtanGrad --device 0``` |
| 1 | DVS-CIFAR10 | IF+Atan | - | resnet18 | 128 | 3149 | 69.10 | ```python main.py --model resnet18 --node-type IFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun AtanGrad --device 0``` |
| 1 | DVS-CIFAR10 | LIF+Atan | - | resnet18 | 128 | 3149 | 78.50 | ```python main.py --model resnet18 --node-type LIFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun AtanGrad --device 0``` |
| 1 | DVS-CIFAR10 | PLIF+Atan | - | resnet18 | 128 | 3149 | 77.70 | ```python main.py --model resnet18 --node-type PLIFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun AtanGrad --device 0``` |
| 1 | DVS-CIFAR10 | IF+QGate | - | dvs_convnet | 128 | 7503 | 68.30 | ```python main.py --model dvs_convnet --node-type IFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun QGateGrad --device 0``` |
| 1 | DVS-CIFAR10 | LIF+QGate | - | dvs_convnet | 128 | 7503 | 82.60/82.90 | ```python main.py --model dvs_convnet --node-type LIFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun QGateGrad --device 0``` |
| 1 | DVS-CIFAR10 | PLIF+QGate | - | dvs_convnet | 128 | 7503 | 83.20 | ```python main.py --model dvs_convnet --node-type PLIFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun QGateGrad --device 0``` |
| 1 | DVS-CIFAR10 | IF+QGate | - | resnet18 | 128 | 3149 | 65.70/66.80 | ```python main.py --model resnet18 --node-type IFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun QGateGrad --device 0``` |
| 1 | DVS-CIFAR10 | LIF+QGate | - | resnet18 | 128 | 3149 | 79.00/79.40 | ```python main.py --model resnet18 --node-type LIFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun QGateGrad --device 0``` |
| 1 | DVS-CIFAR10 | PLIF+QGate | - | resnet18 | 128 | 3149 | 78.10/78.20 | ```python main.py --model resnet18 --node-type PLIFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun QGateGrad --device 0``` |
### DVS-Gesture
| ID | Dataset | Node-type | Config | Model | Batch Size | Accuracy | Script |
|:----|:-------:|:----------:|:------:|:-----------:|:----------:|:-----------:|:--------------------------------------------------------------------------------------------------------------------------------------------------------|
| 1 | DVS-G | IF+Atan | - | dvs_convnet | 128 | 64.77 | ```python main.py --num-classes 11 --model dvs_convnet --node-type IFNode --dataset dvsg --step 10 --batch-size 128 --act-fun AtanGrad --device 0``` |
| 1 | DVS-G | LIF+Atan | - | dvs_convnet | 128 | 91.28 | ```python main.py --num-classes 11 --model dvs_convnet --node-type LIFNode --dataset dvsg --step 10 --batch-size 128 --act-fun AtanGrad --device 0``` |
| 1 | DVS-G | PLIF+Atan | - | dvs_convnet | 128 | 91.67 | ```python main.py --num-classes 11 --model dvs_convnet --node-type PLIFNode --dataset dvsg --step 10 --batch-size 128 --act-fun AtanGrad --device 0``` |
| 1 | DVS-G | IF+Atan | - | resnet18 | 128 | 63.25 | ```python main.py --num-classes 11 --model resnet18 --node-type IFNode --dataset dvsg --step 10 --batch-size 128 --act-fun AtanGrad --device 0``` |
| 1 | DVS-G | LIF+Atan | - | resnet18 | 128 | 91.29 | ```python main.py --num-classes 11 --model resnet18 --node-type LIFNode --dataset dvsg --step 10 --batch-size 128 --act-fun AtanGrad --device 0``` |
| 1 | DVS-G | PLIF+Atan | - | resnet18 | 128 | 90.15 | ```python main.py --num-classes 11 --model resnet18 --node-type PLIFNode --dataset dvsg --step 10 --batch-size 128 --act-fun AtanGrad --device 0``` |
| 1 | DVS-G | IF+QGate | - | dvs_convnet | 128 | 48.48 | ```python main.py --num-classes 11 --model dvs_convnet --node-type IFNode --dataset dvsg --step 10 --batch-size 128 --act-fun QGateGrad --device 0``` |
| 1 | DVS-G | LIF+QGate | - | dvs_convnet | 128 | 92.05/92.42 | ```python main.py --num-classes 11 --model dvs_convnet --node-type LIFNode --dataset dvsg --step 10 --batch-size 128 --act-fun QGateGrad --device 0``` |
| 1 | DVS-G | PLIF+QGate | - | dvs_convnet | 128 | 91.28 | ```python main.py --num-classes 11 --model dvs_convnet --node-type PLIFNode --dataset dvsg --step 10 --batch-size 128 --act-fun QGateGrad --device 0``` |
| 1 | DVS-G | IF+QGate | - | resnet18 | 128 | 57.95 | ```python main.py --num-classes 11 --model resnet18 --node-type IFNode --dataset dvsg --step 10 --batch-size 128 --act-fun QGateGrad --device 0``` |
| 1 | DVS-G | LIF+QGate | - | resnet18 | 128 | 90.91 | ```python main.py --num-classes 11 --model resnet18 --node-type LIFNode --dataset dvsg --step 10 --batch-size 128 --act-fun QGateGrad --device 0``` |
| 1 | DVS-G | PLIF+QGate | - | resnet18 | 128 | 92.42 | ```python main.py --num-classes 11 --model resnet18 --node-type PLIFNode --dataset dvsg --step 10 --batch-size 128 --act-fun QGateGrad --device 0``` |
### NCALTECH101
| ID | Dataset | Node-type | Config | Model | Batch Size | Accuracy | Script |
|:----|:-----------:|:----------:|:------:|:-----------:|:----------:|:-----------:|:----------------------------------------------------------------------------------------------------------------------------------------------------------------|
| 1 | NCALTECH101 | IF+QGate | - | dvs_convnet | 128 | 23.09/51.15 | ```python main.py --num-classes 100 --model dvs_convnet --node-type IFNode --dataset NCALTECH101 --step 10 --batch-size 128 --act-fun QGateGrad --device 0``` |
| 1 | NCALTECH101 | LIF+QGate | - | dvs_convnet | 128 | 72.78/75.09 | ```python main.py --num-classes 100 --model dvs_convnet --node-type LIFNode --dataset NCALTECH101 --step 10 --batch-size 128 --act-fun QGateGrad --device 0``` |
| 1 | NCALTECH101 | PLIF+QGate | - | dvs_convnet | 128 | 74.61/76.79 | ```python main.py --num-classes 100 --model dvs_convnet --node-type PLIFNode --dataset NCALTECH101 --step 10 --batch-size 128 --act-fun QGateGrad --device 0``` |
| 1 | NCALTECH101 | IF+QGate | -/mix | resnet18 | 128 | 61.24/60.87 | ```python main.py --num-classes 100 --model resnet18 --node-type IFNode --dataset NCALTECH101 --step 10 --batch-size 128 --act-fun QGateGrad --device 0``` |
| 1 | NCALTECH101 | LIF+QGate | -/mix | resnet18 | 128 | 66.22/70.84 | ```python main.py --num-classes 100 --model resnet18 --node-type LIFNode --dataset NCALTECH101 --step 10 --batch-size 128 --act-fun QGateGrad --device 0``` |
| 1 | NCALTECH101 | PLIF+QGate | -/mix | resnet18 | 128 | 69.62/69.87 | ```python main.py --num-classes 100 --model resnet18 --node-type PLIFNode --dataset NCALTECH101 --step 10 --batch-size 128 --act-fun QGateGrad --device 0``` |
Note:
1. resnet18 is used here by adding a maximum pooling after the initial convolution layer.
However, in the final version of braincog, we remove this pooling layer.
2. mix refers to the use of EventMix as a data augmentation method.
3. We will continue to add other results.
### Citation
If you find this package helpful, please consider citing it:
```BibTex
@software{name,
author = {xxx},
title = {braincog: xxx},
month = jul,
year = 2022,
note = {{Documentation available under
https://xxx.readthedocs.io}},
publisher = {Zenodo},
version = {xxx},
doi = {xxx},
url = {xxx}
}
```
================================================
FILE: docs/source/examples/Perception_and_Learning/img_cls/glsnn.md
================================================
# SNN with global feedback connections
Training a deep spiking neural network with global
feedback connections and local optimization learning rules. This implementation differs slightly from our original paper.
## Results
```shell
python cls_glsnn.py
```
We train the model for 100 epochs, and the best accuracy for MNIST is 98.23\%, for FashionMNIST is 89.68\%.

================================================
FILE: docs/source/examples/Perception_and_Learning/img_cls/index.rst
================================================
Examples for Image Classification
=================================
.. toctree::
:maxdepth: 2
bp
glsnn
================================================
FILE: docs/source/examples/Perception_and_Learning/index.rst
================================================
Perception and Learning
=================================
.. toctree::
:maxdepth: 2
img_cls/index
Conversion
UnsupervisedSTDP
QSNN
MultisensoryIntegration
================================================
FILE: docs/source/examples/Social_Cognition/Mirror_Test.md
================================================
# Mirror Test
The mirror_test.py implements the core code of the Multi-Robots Mirror Self-Recognition Test in "Toward Robot Self-Consciousness (II): Brain-Inspired Robot Bodily Self Model for Self-Recognition".
The experiment is: three robots with identical appearance move their arms randomly in front of the mirror at the same time.
In the training stage, according to the spiking time difference of neurons in IPLM and IPLV, the robot learns the correlations between self-generated actions and visual feedbacks in motion by learning with spike timing dependent plasticity (STDP) mechanism.
In the test stage, the robot can predict the visual feedback generated by its arm movement according to the training results. With the InsulaNet, the robot can identify which mirror image belongs to it.
In the result, Motion Detection shows the results of visual detection, and Motion Prediction shows the visual feedback generated by itself. The red line in the figure indicates that the robot determines that the corresponding mirror belongs to itself.
Differences from the original article:
Since there is no motion error under the simulation conditions, the theta_threshold is set to zero.
### Citation
If you find this package helpful, please consider citing it:
```BibTex
@article{zeng2018toward,
title={Toward robot self-consciousness (ii): brain-inspired robot bodily self model for self-recognition},
author={Zeng, Yi and Zhao, Yuxuan and Bai, Jun and Xu, Bo},
journal={Cognitive Computation},
volume={10},
number={2},
pages={307--320},
year={2018},
publisher={Springer}
}
```
================================================
FILE: docs/source/examples/Social_Cognition/ToM.md
================================================
# ToM
## Requirements
* numpy
* scipy
* pytorch >= 1.7.0
* torchvision
* pygame
## Run
### Train
* the file to be run: main_both.py
* args:
* the path to save net_NPC: --save_net_N
* the path to save net_a: --save_net_a
* time steps: --T
```bash
python main_both.py --save_net_N=net_NPC.pth --save_net_a=net_agent.pth --episodes=45 --trajectories=30 --T=50 --mode=train --task=both
```
### Test
```bash
python main_ToM.py --save_net_N=net_NPC.pth --save_net_a=net_agent.pth --episodes=45 --trajectories=30 --T=50 --mode=test --task=both
```
================================================
FILE: docs/source/examples/Social_Cognition/index.rst
================================================
Social Cognition
======================================
.. toctree::
:maxdepth: 2
ToM
Mirror_Test
================================================
FILE: docs/source/examples/index.rst
================================================
Examples
=================================
.. toctree::
:maxdepth: 2
Perception_and_Learning/index
Brain_Cognitive_Function_Simulation/index
Decision_Making/index
Knowledge_Representation_and_Reasoning/index
Social_Cognition/index
Multi-scale_Brain_Structure_Simulation/index
================================================
FILE: docs/source/index.rst
================================================
.. braincog documentation master file, created by
sphinx-quickstart on Sun Apr 10 21:02:06 2022.
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
Welcome to braincog's documentation!
====================================
.. toctree::
:maxdepth: 2
tutorial/index
braincog
examples/index
Indices and tables
==================
* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`
================================================
FILE: docs/source/modules.rst
================================================
braincog
========
.. toctree::
:maxdepth: 4
braincog
================================================
FILE: docs/source/setup.rst
================================================
setup module
============
.. automodule:: setup
:members:
:undoc-members:
:show-inheritance:
================================================
FILE: docs.md
================================================
# Sphinx 文档教程
## 安装
131 braincog 环境已经装好了
```shell
pip install sphinx sphinx-rtd-theme recommonmark
```
## 配置
已经配置好了, 直接用就行了
```shell
sphinx-quickstart
```
## 编译
### braincog 之中的, 编译在Brain Docs之中
1. 重新从 repo 中抓取 ```rst``` 文本
```shell
cd braincog/docs
rm -rf ./source/braincog*rst
sphinx-apidoc -o ./source/ ../braincog -f
```
2. 编译 html
```shell
make clean
make html
```
### Examples 的编译
1. 在 ```braincog/docs/source/index.rst``` 中, ```img_cls/Tutorial``` 后面一行添加 ``xxx/Tutorial``.
2. 然后在 ``Brain/docs/source`` 下面添加 ``xxx.md`` 文件, 要和上面的 ``xxx`` 同名.
3. 用 [Markdown](https://markdown.com.cn/basic-syntax/) 语法, 编写教程, 怎么用, 效果是啥.
4. 编译html
```shell
make clean
make html
```
## 查看
编译好的文件可以在 ```braincog/docs/build/html``` 中查看.
## 上传
在130服务器上面:
```shell
sudo cp braincog/docs/build/html/* /var/www/html
```
就可以更新文档了, 并在 [172.18.116.130](http://172.18.116.130/index.html) 中看到.
================================================
FILE: documents/Data_engine.md
================================================
# BrainCog Data Engine
In addition to the static datasets, BrainCog supports the commonly used neuromorphic
datasets, such as DVSGesture, DVSCIFAR10, NCALTECH101, ES-ImageNet.
Also, the neuromorphic dataset N-Omniglot for few-shot learning is also integrated into
BrainCog.
**[DVSGesture](https://openaccess.thecvf.com/content_cvpr_2017/papers/Amir_A_Low_Power_CVPR_2017_paper.pdf)**
This dataset contains 11 hand gestures from 29 subjects under 3 illumination conditions recorded using a DVS128.
**[DVSCIFAR10](https://www.frontiersin.org/articles/10.3389/fnins.2017.00309/full)**
This dataset converts 10,000 frame-based images in the CIFAR10 dataset into 10,000 event streams using a dynamic vision sensor.
**[NCALTECH101](https://www.frontiersin.org/articles/10.3389/fnins.2015.00437/full)**
The NCaltech101 dataset is captured by mounting the ATIS sensor on a motorized pan-tilt unit and having the sensor move while it views Caltech101 examples on an LCD monitor.
The "Faces" class has been removed from N-Caltech101, leaving 100 object classes plus a background class.
**[ES-ImageNet](https://arxiv.org/abs/2110.12211)**
The dataset is converted with Omnidirectional Discrete Gradient (ODG) from 1,300,000 frame-based images in the ImageNet dataset into event-stream samples, covering 1,000 categories.
**[N-Omniglot](https://www.nature.com/articles/s41597-022-01851-z)**
This dataset contains 1,623 categories of handwritten characters, with only 20 samples per class.
The dataset is acquired with the DVS acquisition platform to shoot videos (generated from the original Omniglot dataset) played on the monitor, and use the Robotic Process Automation (RPA) software to collect the data automatically.
You can easily use them in the braincog/datasets folder, taking DVSCIFAR10 as an example
```python
loader_train, loader_eval,_,_ = get_dvsc10_data(batch_size=128,step=10)
```
================================================
FILE: documents/Lectures.md
================================================
# Lectures
- [BrainCog Talk] Beginning BrainCog Lecture 32. Structural Modeling and Neural Activity Simulation of Mammalian Corticothalamic Functional Minicolumn on BrainCog [[English Version](https://www.youtube.com/watch?v=xwcu3yHe4FQ&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=30), [Chinese Version](https://www.bilibili.com/video/BV1aN411v7BG/?spm_id_from=333.788&vd_source=ffca9a0cf41b21082e79f7f6ad9a5301)]
- [BrainCog Talk] Beginning BrainCog Lecture 31. A Brain-inspired Theory of Mind Spiking Neural Network Improves Multi-agent Cooperation [[English Version](https://www.youtube.com/watch?v=xwcu3yHe4FQ&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=30), [Chinese Version](https://www.bilibili.com/video/BV11m4y1N7Bh/?spm_id_from=333.788&vd_source=ffca9a0cf41b21082e79f7f6ad9a5301)]
- [BrainCog Talk] Beginning BrainCog Lecture 30. FPGA-based Spiking Neural Networks Acceleration Using BrainCog [[English Version](https://www.youtube.com/watch?v=xwcu3yHe4FQ&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=30), [Chinese Version](https://www.bilibili.com/video/BV1mF411Q7Rr/?spm_id_from=333.788&vd_source=ac84fc93f3c82d1cd96648079077afd3)]
- [BrainCog Talk] Beginning BrainCog Lecture 29. Using BrainCog to Realize Robot's Intention Prediction for Users [[English Version](https://www.youtube.com/watch?v=vlwRpq-HwDo&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=29), [Chinese Version](https://www.bilibili.com/video/BV1oz4y1H7QL/?spm_id_from=333.788&vd_source=ac84fc93f3c82d1cd96648079077afd3)]
- [BrainCog Talk] Beginning BrainCog Lecture 28. Reward-modulated brain-inspired spiking neural network to empower group for nature-inspired self-organized obstacle avoidance implemented by Braincog [[English Version](https://www.youtube.com/watch?v=2X2KkG0KXo4&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=28), [Chinese Version](https://www.bilibili.com/video/BV16h411N7tG/?spm_id_from=333.788&vd_source=ac84fc93f3c82d1cd96648079077afd3)]
- [BrainCog Talk] Beginning BrainCog Lecture 27. Drosophila-inspired Linear and Nonlinear Decision-Making Spiking Neural Network implemented by Braincog [[English Version](https://www.youtube.com/watch?v=j49q33dMbYk&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=27), [Chinese Version](https://www.bilibili.com/video/BV1Ch4y147i6/?spm_id_from=333.788&vd_source=ac84fc93f3c82d1cd96648079077afd3)]
- [BrainCog Talk] Beginning BrainCog Lecture 26. Spiking CapsNet based on BrainCog [[English Version](https://www.youtube.com/watch?v=Gn1PMMFnD6M&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=26), [Chinese Version](https://www.bilibili.com/video/BV1M24y1N7Re/?spm_id_from=333.788&vd_source=ac84fc93f3c82d1cd96648079077afd3)]
- [BrainCog Talk] Beginning BrainCog Lecture 25. SNNs with adaptive self-feedback and balanced excitatory–inhibitory neurons based on BrainCog [[English Version](https://www.youtube.com/watch?v=pqeagXtYk8w&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=25), [Chinese Version](https://www.bilibili.com/video/BV18g4y1j7WA/?spm_id_from=333.788&vd_source=ac84fc93f3c82d1cd96648079077afd3)]
- [BrainCog Talk] Beginning BrainCog Lecture 24. DVS Augmentation Strategies based on BrainCog [[English Version](https://www.youtube.com/watch?v=aok8B34DNVs&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=24), [Chinese Version](https://www.bilibili.com/video/BV1324y1L7QP/?spm_id_from=333.788&vd_source=ac84fc93f3c82d1cd96648079077afd3)]
- [BrainCog Talk] Beginning BrainCog Lecture 23. The Implement of Object Detection and Semantic Segmentation Based on SNNs with Braincog [[English Version](https://www.youtube.com/watch?v=2kh7rNwQkkY&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=23), [Chinese Version](https://www.bilibili.com/video/BV1XT411r7ak/?spm_id_from=333.788&vd_source=ffca9a0cf41b21082e79f7f6ad9a5301)]
- [BrainCog Talk] Beginning BrainCog Lecture 22. BrainCog Data Engine: Spatio-temporal Sequence Data N-Omniglot and Its Applications [[English Version](https://www.youtube.com/watch?v=J5qBIX9f8Ns&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=22), [Chinese Version](https://www.bilibili.com/video/BV1uT411a733/?spm_id_from=333.788&vd_source=ffca9a0cf41b21082e79f7f6ad9a5301)]
- [BrainCog Talk] Beginning BrainCog Lecture 21. Dynamic structural development for SNNs incorporating constraints, pruning and regeneration based on BrainCog [[English Version](https://www.youtube.com/watch?v=OccvvdyiOgY&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=21), [Chinese Version](https://www.bilibili.com/video/BV1Zv4y1Y72h/?spm_id_from=333.788&vd_source=ffca9a0cf41b21082e79f7f6ad9a5301)]
- [BrainCog Talk] Beginning BrainCog Lecture 20. Developmental Plasticity-inspired Adaptive Pruning for SNNs based on BrainCog [[English Version](https://www.youtube.com/watch?v=ckrSW_2vqeE&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=20), [Chinese Version](https://www.bilibili.com/video/BV1iD4y1P7XE/?spm_id_from=333.788&vd_source=ffca9a0cf41b21082e79f7f6ad9a5301)]
- [BrainCog Talk] Beginning BrainCog Lecture 19. Multi-brain areas Coordinated Brain-inspired Affective Empathy Spiking Neural Network Based on Braincog [[English Version](https://www.youtube.com/watch?v=dB-hQ-5RJ6U&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=19), [Chinese Version](https://www.bilibili.com/video/BV16Y411q7b2/?spm_id_from=333.788&vd_source=ffca9a0cf41b21082e79f7f6ad9a5301)]
- [BrainCog Talk] Beginning BrainCog Lecture 18. Application of the Prefrontal Cortex Column Model in Working Memory Task with BrainCog [[English Version](https://www.youtube.com/watch?v=6xzURpSFkxk&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=18), [Chinese Version](https://www.bilibili.com/video/BV1As4y1t7wF/?spm_id_from=333.788&vd_source=ffca9a0cf41b21082e79f7f6ad9a5301)]
- [BrainCog Talk] Beginning BrainCog Lecture 17. A Brain-inspired Theory of Mind Model Based on BrainCog for Reducing Other Agents’ Safety Risks [[English Version](https://www.youtube.com/watch?v=16Csw03bTjY&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=17), [Chinese Version](https://www.bilibili.com/video/BV12A411f74Q/?spm_id_from=333.788&vd_source=ac84fc93f3c82d1cd96648079077afd3)]
- [BrainCog Talk] Beginning BrainCog Lecture 16. Brain-inspired Bodily Self-perception Model Based on BrainCog [[English Version](https://www.youtube.com/watch?v=mV_iMsDEQsg&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=16), [Chinese Version](https://www.bilibili.com/video/BV1pK411i7Jk/?spm_id_from=333.788&vd_source=ac84fc93f3c82d1cd96648079077afd3)]
- [BrainCog Talk] Beginning BrainCog Lecture 15. SNN-based Music Memory and Generation Based on BrainCog [[English Version](https://www.youtube.com/watch?v=c0Tcs1B5xho&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=15), [Chinese Version](https://www.bilibili.com/video/BV1M3411X7C2/?spm_id_from=333.788&vd_source=ac84fc93f3c82d1cd96648079077afd3)]
- [BrainCog Talk] Beginning BrainCog Lecture 14. The Implement of Multisensory Concept Learning Framework Based on SNNs with Braincog [[English Version](https://www.youtube.com/watch?v=c9UfOCGzFPQ&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=14), [Chinese Version](https://www.bilibili.com/video/BV1Ae411P7tY/?spm_id_from=333.788&vd_source=ac84fc93f3c82d1cd96648079077afd3)]
- [BrainCog Talk] Beginning BrainCog Lecture 13. Symbolic Representation and Reasoning SNN Based on Braincog [[English Version](https://www.youtube.com/watch?v=1iosjPkOBRo&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=13), [Chinese Version](https://www.bilibili.com/video/BV1DW4y1p7kg/?spm_id_from=333.788&vd_source=ac84fc93f3c82d1cd96648079077afd3)]
- [BrainCog Talk] Beginning BrainCog Lecture 12. Unsupervised STDP-based Spiking Neural Networks Based on BrainCog [[English Version](https://www.youtube.com/watch?v=pzPJ1XOEB9U&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=12), [Chinese Version](https://www.bilibili.com/video/BV1MR4y1Z7Be/?spm_id_from=333.788&vd_source=ac84fc93f3c82d1cd96648079077afd3)]
- [BrainCog Talk] Beginning BrainCog Lecture 11. Backpropagation with Spatiotemporal Adjustment for Training Deep Spiking Neural Networks through BrainCog [[English Version](https://www.youtube.com/watch?v=Fm85BjQszng&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=11), [Chinese Version](https://www.bilibili.com/video/BV1EK411Z71F/?spm_id_from=333.788&vd_source=ac84fc93f3c82d1cd96648079077afd3)]
- [BrainCog Talk] Beginning BrainCog Lecture 10. Multi-brain Areas Coordinated Brain-inspired Decision-Making Spiking Neural Network Based on Braincog [[English Version](https://www.youtube.com/watch?v=uVCcZHzzN3U&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=10), [Chinese Version](https://www.bilibili.com/video/BV1jD4y1x7PU/?spm_id_from=333.788&vd_source=ac84fc93f3c82d1cd96648079077afd3)]
- [BrainCog Talk] Beginning BrainCog Lecture 9. Spiking Neural Networks with Global Feedback Connections Based on BrainCog [[English Version](https://www.youtube.com/watch?v=g_qelwoQsD8&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=9), [Chinese Version](https://www.bilibili.com/video/BV1qv4y1D74Y/?spm_id_from=333.788&vd_source=ac84fc93f3c82d1cd96648079077afd3)]
- [BrainCog Talk] Beginning BrainCog Lecture 8. Converting Artificial Neural Network to Spiking Neural Network through BrainCog [[English Version](https://www.youtube.com/watch?v=cxiKyQ7F8UE&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=8), [Chinese Version](https://www.bilibili.com/video/BV17e4y147H6/?spm_id_from=333.788&vd_source=ac84fc93f3c82d1cd96648079077afd3)]
- [BrainCog Talk] Beginning BrainCog Lecture 7. Implementing Quantum Superposition Inspired Spatio-temporal Spike Encoding through BrainCog [[English Version](https://www.youtube.com/watch?v=5T-2Yyr9a0s&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=7), [Chinese Version](https://www.bilibili.com/video/BV1BG41177Z6/?spm_id_from=333.788&vd_source=ac84fc93f3c82d1cd96648079077afd3)]
- [BrainCog Talk] Beginning BrainCog Lecture 6. Implementing spiking deep Q network through Braincog [[English Version](https://www.youtube.com/watch?v=jCsOBtiN-q0&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=6), [Chinese Version](https://www.bilibili.com/video/BV1XN4y1c7Kn/?spm_id_from=333.337.search-card.all.click&vd_source=ffca9a0cf41b21082e79f7f6ad9a5301)]
- [BrainCog Talk] Beginning BrainCog Lecture 5. Advanced BrainCog System Functions [[English Version](https://www.youtube.com/watch?v=VJBORFl6dTQ&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=5), [Chinese Version](https://www.bilibili.com/video/BV1tT411N7c9/?spm_id_from=333.337.search-card.all.click&vd_source=ffca9a0cf41b21082e79f7f6ad9a5301)]
- [BrainCog Talk] Beginning BrainCog Lecture 4. Creating Cognitive SNNs for Brain Areas [[English Version](https://www.youtube.com/watch?v=gYxG1D1b4Zo&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=4), [Chinese Version](https://www.bilibili.com/video/BV19d4y1679Y/?spm_id_from=333.788&vd_source=ffca9a0cf41b21082e79f7f6ad9a5301)]
- [BrainCog Talk] Beginning BrainCog Lecture 3. Creating SNNs Easily and Quickly [[English Version](https://www.youtube.com/watch?v=k3byUIp4O24&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=3), [Chinese Version](https://www.bilibili.com/video/BV1Be4y1874W/?spm_id_from=333.788&vd_source=ffca9a0cf41b21082e79f7f6ad9a5301)]
- [BrainCog Talk] Beginning BrainCog Lecture 2. Computational Modeling of Spiking Neurons [[English Version](https://www.youtube.com/watch?v=5jPsPkyFTY8&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=2), [Chinese Version](https://www.bilibili.com/video/BV16K411f7vQ/?spm_id_from=333.788&vd_source=ffca9a0cf41b21082e79f7f6ad9a5301)]
- [BrainCog Talk] Beginning BrainCog Lecture 1. Installing and Deploying BrainCog platform [[English Version](https://www.youtube.com/watch?v=XkHq-MbKo20&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=1), [Chinese Version](https://www.bilibili.com/video/BV1AW4y1b7v1/?spm_id_from=333.788&vd_source=ac84fc93f3c82d1cd96648079077afd3)]
================================================
FILE: documents/Pub_brain_inspired_AI.md
================================================
# Publications Using BrainCog
## Brain Inspired AI
### Perception and Learning
| Papers | Codes | Publisher |
| ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------- |
| [BackEISNN: A deep spiking neural network with adaptive self-feedback and balanced excitatory–inhibitory neurons](https://www.sciencedirect.com/science/article/pii/S0893608022002520) | [https://github.com/BrainCog-X/Brain-Cog/blob/main/examples/Perception_and_Learning/img_cls/bp/main_backei.py](https://github.com/BrainCog-X/Brain-Cog/blob/main/examples/Perception_and_Learning/img_cls/bp/main_backei.py) | Neural Networks |
| [Efficient and Accurate Conversion of Spiking Neural Network with Burst Spikes](https://www.ijcai.org/proceedings/2022/345) | [https://github.com/BrainCog-X/Brain-Cog/blob/main/examples/Perception_and_Learning/Conversion/converted_CIFAR10.py](https://github.com/BrainCog-X/Brain-Cog/blob/main/examples/Perception_and_Learning/Conversion/converted_CIFAR10.py) | IJCAI |
| [Spike Calibration: Fast and Accurate Conversion of Spiking Neural Network for Object Detection and Segmentation](https://arxiv.org/abs/2207.02702) | [https://github.com/BrainCog-X/Brain-Cog/blob/main/examples/Perception_and_Learning/Conversion/converted_CIFAR10.py](https://github.com/BrainCog-X/Brain-Cog/blob/main/examples/Perception_and_Learning/Conversion/converted_CIFAR10.py) | Arxiv |
| [An unsupervised STDP-based spiking neural network inspired by biologically plausible learning rules and connections](https://www.sciencedirect.com/science/article/pii/S0893608023003301) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/UnsupervisedSTDP](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/UnsupervisedSTDP) | Neural Networks |
| [Multisensory Concept Learning Framework Based on Spiking Neural Networks](https://www.frontiersin.org/articles/10.3389/fnsys.2022.845177/full) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/MultisensoryIntegration](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/MultisensoryIntegration) | Frontiers in Systems Neuroscience |
| [Backpropagation with biologically plausible spatiotemporal adjustment for training deep spiking neural networks](https://www.sciencedirect.com/science/article/pii/S2666389922001192) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/img_cls/bp](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/img_cls/bp) | Patterns |
| [Quantum superposition inspired spiking neural network](https://www.cell.com/iscience/fulltext/S2589-0042(21)00848-8) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/QSNN](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/QSNN) | iScience |
| [GLSNN: A Multi-Layer Spiking Neural Network Based on Global Feedback Alignment and Local STDP Plasticity](https://www.frontiersin.org/articles/10.3389/fncom.2020.576841/full) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/img_cls/glsnn](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/img_cls/glsnn) | Frontiers in Neuroscience |
| [Spiking CapsNet: A spiking neural network with a biologically plausible routing rule between capsules](https://www.sciencedirect.com/science/article/pii/S002002552200843X) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/img_cls/spiking_capsnet](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/img_cls/spiking_capsnet) | Information Sciences |
| [N-Omniglot, a large-scale neuromorphic dataset for spatio-temporal sparse few-shot learning](https://www.nature.com/articles/s41597-022-01851-z) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/braincog/datasets/NOmniglot](https://github.com/BrainCog-X/Brain-Cog/tree/main/braincog/datasets/NOmniglot) | Scientific Data |
| [EventMix: An efficient data augmentation strategy for event-based learning](https://www.sciencedirect.com/science/article/pii/S0020025523007557) | [https://github.com/BrainCog-X/Brain-Cog/blob/main/braincog/datasets/cut_mix.py](https://github.com/BrainCog-X/Brain-Cog/blob/main/braincog/datasets/cut_mix.py) | Information Sciences |
| [Challenging deep learning models with image distortion based on the abutting grating illusion](https://www.sciencedirect.com/science/article/pii/S2666389923000260) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/IllusionPerception/AbuttingGratingIllusion](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/IllusionPerception/AbuttingGratingIllusion) | Patterns |
| [MSAT: Biologically Inspired Multi-Stage Adaptive Threshold for Conversion of Spiking Neural Networks](https://arxiv.org/abs/2303.13080) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/Conversion/msat_conversion](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/Conversion/msat_conversion) | Arxiv |
| [Improving the Performance of Spiking Neural Networks on Event-based Datasets with Knowledge Transfer](https://arxiv.org/abs/2303.13077) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/img_cls/transfer_for_dvs](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/img_cls/transfer_for_dvs) | Arxiv |
| [Improving Stability and Performance of Spiking Neural Networks Through Enhancing Temporal Consistency](https://www.sciencedirect.com/science/article/abs/pii/S0031320324008458) | [https://github.com/Brain-Cog-Lab/ETC](https://github.com/Brain-Cog-Lab/ETC) | Pattern Recognition |
| [Spiking Generative Adversarial Network with Attention Scoring Decoding](https://www.sciencedirect.com/science/article/abs/pii/S0893608024003472) | [https://github.com/Brain-Cog-Lab/sgad](https://github.com/Brain-Cog-Lab/sgad) | Neural Networks |
| [Temporal Knowledge Sharing Enables Spiking Neural Network Learning from Past and Future](https://ieeexplore.ieee.org/document/10462632) | [https://github.com/Brain-Cog-Lab/TKS](https://github.com/Brain-Cog-Lab/TKS) | IEEE Transactions on Artificial Intelligence |
| [TIM: An Efficient Temporal Interaction Module for Spiking Transformer](https://www.ijcai.org/proceedings/2024/0347.pdf) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/TIM](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/TIM) | IJCAI2024 |
| [Exploiting Nonlinear Dendritic Adaptive Computation in Training Deep Spiking Neural Networks](https://www.sciencedirect.com/science/article/pii/S0893608023006202) | [https://github.com/BrainCog-X/Brain-Cog/blob/main/braincog/base/node/node.py#L1412](https://github.com/BrainCog-X/Brain-Cog/blob/main/braincog/base/node/node.py#L1412) | Neural Networks |
| [Are Conventional SNNs Really Efficient? A Perspective from Network Quantization](https://openaccess.thecvf.com/content/CVPR2024/papers/Shen_Are_Conventional_SNNs_Really_Efficient_A_Perspective_from_Network_Quantization_CVPR_2024_paper.pdf) | | 2024 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) |
| [CACE-Net: Co-guidance Attention and Contrastive Enhancement for Effective Audio-Visual Event Localization](https://dl.acm.org/doi/10.1145/3664647.3681503) | [https://github.com/Brain-Cog-Lab/CACE-Net](https://github.com/Brain-Cog-Lab/CACE-Net) | Proceedings of the 32nd ACM International Conference on Multimedia, 2024: 985-993 |
| [Parallel Spiking Unit for Efficient Training of Spiking Neural Networks](https://arxiv.org/abs/2402.00449) | [https://github.com/Brain-Cog-Lab/PSU](https://github.com/Brain-Cog-Lab/PSU) | IJCNN 2024 |
| [Spiking Neural Networks with Consistent Mapping Relations Allow High-Accuracy Inference](https://www.sciencedirect.com/science/article/abs/pii/S0020025524007369) | [https://github.com/Brain-Cog-Lab/casc](https://github.com/Brain-Cog-Lab/casc) | Information Sciences |
| [Directly training temporal Spiking Neural Network with sparse surrogate gradient](https://www.sciencedirect.com/science/article/abs/pii/S0893608024004234) | [https://github.com/Brain-Cog-Lab/msg](https://github.com/Brain-Cog-Lab/msg) | Neural Networks |
### Knowledge Representation and Reasoning
| Papers | Codes | Publisher |
| ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------- |
| [A Brain-Inspired Causal Reasoning Model Based on Spiking Neural Networks](https://ieeexplore.ieee.org/abstract/document/9534102) | https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Knowledge_Representation_and_Reasoning/CRSNN | IJCNN2021 |
| [Brain Inspired Sequences Production by Spiking Neural Networks With Reward-Modulated STDP](https://www.frontiersin.org/articles/10.3389/fncom.2021.612041/full) | https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Knowledge_Representation_and_Reasoning/SPSNN | Frontiers in Computational Neuroscience |
| [Temporal-Sequential Learning With a Brain-Inspired Spiking Neural Network and Its Application to Musical Memory](https://www.frontiersin.org/articles/10.3389/fncom.2020.00051/full) | https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Knowledge_Representation_and_Reasoning/musicMemory | Frontiers in Computational Neuroscience |
| [Stylistic Composition of Melodies Based on a Brain-Inspired Spiking Neural Network](https://www.frontiersin.org/articles/10.3389/fnsys.2021.639484/full) | https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/affective_empathy/BRP-SNN | Frontiers in Neurorobotics |
| [Brain-inspired Graph Spiking Neural Networks for Commonsense Knowledge Representation and Reasoning](https://arxiv.org/abs/2207.05561 ) | https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Knowledge_Representation_and_Reasoning/CKRGSNN | Arxiv |
| [An Efficient Knowledge Transfer Strategy for Spiking Neural Networks from Static to Event Domain](https://arxiv.org/abs/2303.13077) | [https://github.com/Brain-Cog-Lab/Transfer-for-DVS](https://github.com/Brain-Cog-Lab/Transfer-for-DVS) | Proceedings of the AAAI Conference on Artificial Intelligence, 2024, 38(1): 512-520 |
### Decision Making
| Papers | Codes | Publisher |
| ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------- | -------------------------- |
| [Nature-inspired self-organizing collision avoidance for drone swarm based on reward-modulated spiking neural network](https://www.cell.com/patterns/fulltext/S2666-3899(22)00236-7) | https://github.com/BrainCog-X/Brain-Cog/blob/main/examples/decision_making/swarm/Collision-Avoidance.py | Cell Patterns |
| [Solving the spike feature information vanishing problem in spiking deep Q network with potential based normalization](https://www.frontiersin.org/articles/10.3389/fnins.2022.953368/full) | https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/decision_making/RL/sdqn | Frontiers in Neuroscience |
| [A Brain-Inspired Decision-Making Spiking Neural Network and Its Application in Unmanned Aerial Vehicle](https://www.frontiersin.org/articles/10.3389/fnbot.2018.00056/full ) | https://github.com/BrainCog-X/Brain-Cog/blob/main/examples/decision_making/BDM-SNN/BDM-SNN-hh.py | Frontiers in Neurorobotics |
| [Multi-compartment Neuron and Population Encoding improved Spiking Neural Network for Deep Distributional Reinforcement Learning](https://arxiv.org/abs/2301.07275) | https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/decision_making/RL/mcs-fqf | Arxiv |
### Motor Control
| Papers | Codes | Publisher |
| ------ | ------------------------------------------------------------------------------------ | --------- |
| | https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/MotorControl/experimental | |
### Social Cognition
| Papers | Codes | Publisher |
| ------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------- |
| [Toward Robot Self-Consciousness (II): Brain-Inspired Robot Bodily Self Model for Self-Recognition](https://link.springer.com/article/10.1007/s12559-017-9505-1 ) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/mirror_test](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/mirror_test) | Cognitive Computation |
| [A brain-inspired intention prediction model and its applications to humanoid robot](https://www.frontiersin.org/articles/10.3389/fnins.2022.1009237/full) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/Intention_Prediction](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/Intention_Prediction) | Frontiers in Neuroscience |
| [A Brain-Inspired Theory of Mind Spiking Neural Network for Reducing Safety Risks of Other Agents](https://www.frontiersin.org/articles/10.3389/fnins.2022.753900/full) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/ToM](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/ToM) | Frontiers in Neuroscience |
| [Brain-Inspired Affective Empathy Computational Model and Its Application on Altruistic Rescue Task](https://www.frontiersin.org/articles/10.3389/fncom.2022.784967/full) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/affective_empathy/BAE-SNN](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/affective_empathy/BAE-SNN) | Frontiers in Computational Neuroscience |
| [A brain-inspired robot pain model based on a spiking neural network](https://www.frontiersin.org/articles/10.3389/fnbot.2022.1025338/full) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/affective_empathy/BRP-SNN](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/affective_empathy/BRP-SNN) | Frontiers in Neurorobotics |
| [Brain-Inspired Theory of Mind Spiking Neural Network Elevates Multi-Agent Cooperation and Competition](https://papers.ssrn.com/sol3/papers.cfm?abstract_id=4271099) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/MAToM-SNN ](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/MAToM-SNN ) | SSRN |
### Development and Evolution
| Papers | Codes | Publisher |
| ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ----------------------------------------------- |
| [Developmental Plasticity-inspired Adaptive Pruning for Deep Spiking and Artificial Neural Networks](https://arxiv.org/abs/2211.12714) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Structural_Development/DPAP](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Structural_Development/DPAP) | Arxiv |
| [Adaptive Sparse Structure Development with Pruning and Regeneration for Spiking Neural Networks](https://arxiv.org/abs/2211.12219) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Structural_Development/SD-SNN](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Structural_Development/SD-SNN) | Arxiv |
| [Emergence of Brain-inspired Small-world Spiking Neural Network through Neuroevolution](https://arxiv.org/abs/2304.10749) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Structural_Development/ELSM](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Structural_Development/ELSM) | Arxiv |
| [Enhancing Efficient Continual Learning with Dynamic Structure Development of Spiking Neural Networks](https://arxiv.org/abs/2308.04749) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Structural_Development/DSD-SNN](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Structural_Development/DSD-SNN) | IJCAI |
| [Brain-inspired Evolutionary Neural Architecture Search for Spiking Neural Networks](https://ieeexplore.ieee.org/abstract/document/10542732) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Structure_Evolution/EB-NAS](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Structure_Evolution/EB-NAS) | IEEE Transactions on Artificial Intelligence |
| [Adaptive Structure Evolution and Biologically Plausible Synaptic Plasticity for Recurrent Spiking Neural Networks](https://www.nature.com/articles/s41598-023-43488-x) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Structure_Evolution/Adaptive_lsm](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Structure_Evolution/Adaptive_lsm) | Scientific Reports |
| [Brain-inspired Neural Circuit Evolution for Spiking Neural Networks](https://www.pnas.org/doi/10.1073/pnas.2218173120) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/braincog/model_zoo/NeuEvo](https://github.com/BrainCog-X/Brain-Cog/tree/main/braincog/model_zoo/NeuEvo) | Proceedings of the National Academy of Sciences |
### Safety and Security
| Papers | Codes | Publisher |
| ----------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------- | --------- |
| [DPSNN](https://arxiv.org/abs/2205.12718) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Snn_safety/DPSNN](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Snn_safety/DPSNN) | Arxiv |
### Dataset
| Papers | Codes | Publisher |
| -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ----------- |
| [Bullying10K: A Large-scale Neuromorphic Dataset Towards Privacy-preserving Bullying Recognition](https://proceedings.neurips.cc/paper_files/paper/2023/file/05ffe69463062b7f9fb506c8351ffdd7-Paper-Datasets_and_Benchmarks.pdf) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/braincog/datasets/bullying10k](https://github.com/BrainCog-X/Brain-Cog/tree/main/braincog/datasets/bullying10k) | NeurIPS 2023 |
================================================
FILE: documents/Pub_brain_simulation.md
================================================
# Publications Using BrainCog
## Brain Simulation
### Function
| Papers | Codes | Publisher |
|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------|
| [A neural algorithm for Drosophila linear and nonlinear decision-making](https://www.nature.com/articles/s41598-020-75628-y) | [https://github.com/BrainCog-X/Brain-Cog/blob/main/examples/Brain_Cognitive_Function_Simulation/drosophila/drosophila.py](https://github.com/BrainCog-X/Brain-Cog/blob/main/examples/Brain_Cognitive_Function_Simulation/drosophila/drosophila.py) | Scientific Reports |
### Structure
| Papers | Codes | Publisher |
|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|----------------------------------|
| Corticothalamic minicolumn | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Multiscale_Brain_Structure_Simulation/CorticothalamicColumn](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Multiscale_Brain_Structure_Simulation/CorticothalamicColumn) | |
| Human Brain | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Multiscale_Brain_Structure_Simulation/HumanBrain](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Multiscale_Brain_Structure_Simulation/HumanBrain) | |
| Macaque Brain | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Multiscale_Brain_Structure_Simulation/MacaqueBrain](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Multiscale_Brain_Structure_Simulation/MacaqueBrain) | |
| Mouse Brain | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Multiscale_Brain_Structure_Simulation/Mouse_brain ](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Multiscale_Brain_Structure_Simulation/Mouse_brain ) | |
| [Comparison Between Human and Rodent Neurons for Persistent Activity Performance: A Biologically Plausible Computational Investigation](https://www.frontiersin.org/articles/10.3389/fnsys.2021.628839/full) | https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Multiscale_Brain_Structure_Simulation/Human_PFC_Model | Frontiers in System Neuroscience |
================================================
FILE: documents/Pub_sh_codesign.md
================================================
# Publications Using BrainCog
## Software-Hardware Co-design
### Hardware Acceleration
| Papers | Codes | Publisher |
| ------------------------------------------------------------ | ---------------------------------------- | ------------------------------------------------------------ |
| [FireFly: A High-Throughput and Reconfigurable Hardware Accelerator for Spiking Neural Networks](https://ieeexplore.ieee.org/abstract/document/10143752) | https://github.com/adamgallas/FireFly-v1 | IEEE Transactions on Very Large Scale Integration (VLSI) Systems |
| [FireFly v2: Advancing Hardware Support for High-Performance Spiking Neural Network With a Spatiotemporal FPGA Accelerator](https://ieeexplore.ieee.org/abstract/document/10478105) | https://github.com/adamgallas/FireFly-v2 | IEEE Transactions on Computer-Aided Design of Integrated Circuits and Systems |
| [Revealing Untapped DSP Optimization Potentials for FPGA-Based Systolic Matrix Engines](https://ieeexplore.ieee.org/abstract/document/10705564) | https://github.com/adamgallas/SpinalDLA | 2024 34th International Conference on Field-Programmable Logic and Applications (FPL) |
| [FireFly-S: Exploiting Dual-Side Sparsity for Spiking Neural Networks Acceleration With Reconfigurable Spatial Architecture](https://ieeexplore.ieee.org/document/10754657) | | IEEE Transactions on Circuits and Systems I: Regular Papers |
| Pushing up to the Limit of Memory Bandwidth and Capacity Utilization for Efficient LLM Decoding on Embedded FPGA | | 2025 Design, Automation & Test in Europe Conference & Exhibition (DATE) |
================================================
FILE: documents/Publication.md
================================================
# Publications Using BrainCog
## 2024
| Papers | Codes | Publisher |
| ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------- |
| [Developmental Plasticity-inspired Adaptive Pruning for Deep Spiking and Artificial Neural Networks](https://ieeexplore.ieee.org/document/10691937) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Structural_Development/DPAP](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Structural_Development/DPAP) | IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI) |
| [Multi-compartment Neuron and Population Encoding improved Spiking Neural Network for Deep Distributional Reinforcement Learning](https://www.sciencedirect.com/science/article/abs/pii/S089360802400827X?via%3Dihub) | https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/decision_making/RL/mcs-fqf | Neural Networks |
| [Improving Stability and Performance of Spiking Neural Networks Through Enhancing Temporal Consistency](https://www.sciencedirect.com/science/article/abs/pii/S0031320324008458) | [https://github.com/Brain-Cog-Lab/ETC](https://github.com/Brain-Cog-Lab/ETC) | Pattern Recognition |
| [Spiking Generative Adversarial Network with Attention Scoring Decoding](https://www.sciencedirect.com/science/article/abs/pii/S0893608024003472) | [https://github.com/Brain-Cog-Lab/sgad](https://github.com/Brain-Cog-Lab/sgad) | Neural Networks |
| [Temporal Knowledge Sharing Enables Spiking Neural Network Learning from Past and Future](https://ieeexplore.ieee.org/document/10462632) | [https://github.com/Brain-Cog-Lab/TKS](https://github.com/Brain-Cog-Lab/TKS) | IEEE Transactions on Artificial Intelligence (TAI) |
| [TIM: An Efficient Temporal Interaction Module for Spiking Transformer](https://www.ijcai.org/proceedings/2024/0347.pdf) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/TIM](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/TIM) | IJCAI2024 |
| [Are Conventional SNNs Really Efficient? A Perspective from Network Quantization](https://openaccess.thecvf.com/content/CVPR2024/papers/Shen_Are_Conventional_SNNs_Really_Efficient_A_Perspective_from_Network_Quantization_CVPR_2024_paper.pdf) | | 2024 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) |
| [CACE-Net: Co-guidance Attention and Contrastive Enhancement for Effective Audio-Visual Event Localization](https://dl.acm.org/doi/10.1145/3664647.3681503) | [https://github.com/Brain-Cog-Lab/CACE-Net](https://github.com/Brain-Cog-Lab/CACE-Net) | Proceedings of the 32nd ACM International Conference on Multimedia (ACM MM) |
| [Parallel Spiking Unit for Efficient Training of Spiking Neural Networks](https://ieeexplore.ieee.org/document/10650207) | [https://github.com/Brain-Cog-Lab/PSU](https://github.com/Brain-Cog-Lab/PSU) | IJCNN 2024 |
| [Spiking Neural Networks with Consistent Mapping Relations Allow High-Accuracy Inference](https://www.sciencedirect.com/science/article/abs/pii/S0020025524007369) | [https://github.com/Brain-Cog-Lab/casc](https://github.com/Brain-Cog-Lab/casc) | Information Sciences |
| [Directly training temporal Spiking Neural Network with sparse surrogate gradient](https://www.sciencedirect.com/science/article/abs/pii/S0893608024004234) | [https://github.com/Brain-Cog-Lab/msg](https://github.com/Brain-Cog-Lab/msg) | Neural Networks |
| [Brain-inspired Evolutionary Neural Architecture Search for Spiking Neural Networks](https://ieeexplore.ieee.org/abstract/document/10542732) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Structure_Evolution/EB-NAS](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Structure_Evolution/EB-NAS) | IEEE Transactions on Artificial Intelligence (TAI) |
| [MSAT: Biologically Inspired Multi-Stage Adaptive Threshold for Conversion of Spiking Neural Networks](https://link.springer.com/article/10.1007/s00521-024-09529-w) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/Conversion/msat_conversion](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/Conversion/msat_conversion) | Neural Computing and Applications |
| [An Efficient Knowledge Transfer Strategy for Spiking Neural Networks from Static to Event Domain](https://ojs.aaai.org/index.php/AAAI/article/view/27806/27643) | [https://github.com/Brain-Cog-Lab/Transfer-for-DVS](https://github.com/Brain-Cog-Lab/Transfer-for-DVS) | Proceedings of the AAAI Conference on Artificial Intelligence (AAAI) |
## 2023
| Papers | Codes | Publisher |
| -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------ |
| [An unsupervised STDP-based spiking neural network inspired by biologically plausible learning rules and connections](https://www.sciencedirect.com/science/article/pii/S0893608023003301) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/UnsupervisedSTDP](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/UnsupervisedSTDP) | Neural Networks |
| [EventMix: An efficient data augmentation strategy for event-based learning](https://www.sciencedirect.com/science/article/pii/S0020025523007557) | [https://github.com/BrainCog-X/Brain-Cog/blob/main/braincog/datasets/cut_mix.py](https://github.com/BrainCog-X/Brain-Cog/blob/main/braincog/datasets/cut_mix.py) | Information Sciences |
| [Challenging deep learning models with image distortion based on the abutting grating illusion](https://www.sciencedirect.com/science/article/pii/S2666389923000260) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/IllusionPerception/AbuttingGratingIllusion](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/IllusionPerception/AbuttingGratingIllusion) | Patterns |
| [Improving the Performance of Spiking Neural Networks on Event-based Datasets with Knowledge Transfer](https://arxiv.org/abs/2303.13077) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/img_cls/transfer_for_dvs](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/img_cls/transfer_for_dvs) | Arxiv |
| [Exploiting Nonlinear Dendritic Adaptive Computation in Training Deep Spiking Neural Networks](https://www.sciencedirect.com/science/article/pii/S0893608023006202) | [https://github.com/BrainCog-X/Brain-Cog/blob/main/braincog/base/node/node.py#L1412](https://github.com/BrainCog-X/Brain-Cog/blob/main/braincog/base/node/node.py#L1412) | Neural Networks |
| [Emergence of Brain-inspired Small-world Spiking Neural Network through Neuroevolution](https://www.sciencedirect.com/science/article/pii/S258900422400066X) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Structural_Development/ELSM](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Structural_Development/ELSM) | iScience |
| [Enhancing Efficient Continual Learning with Dynamic Structure Development of Spiking Neural Networks](https://www.ijcai.org/proceedings/2023/0334.pdf) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Structural_Development/DSD-SNN](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Structural_Development/DSD-SNN) | IJCAI |
| [Adaptive Structure Evolution and Biologically Plausible Synaptic Plasticity for Recurrent Spiking Neural Networks](https://www.nature.com/articles/s41598-023-43488-x) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Structure_Evolution/Adaptive_lsm](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Structure_Evolution/Adaptive_lsm) | Scientific Reports |
| [Brain-inspired Neural Circuit Evolution for Spiking Neural Networks](https://www.pnas.org/doi/10.1073/pnas.2218173120) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/braincog/model_zoo/NeuEvo](https://github.com/BrainCog-X/Brain-Cog/tree/main/braincog/model_zoo/NeuEvo) | Proceedings of the National Academy of Sciences (PNAS) |
| [Bullying10K: A Large-scale Neuromorphic Dataset Towards Privacy-preserving Bullying Recognition](https://proceedings.neurips.cc/paper_files/paper/2023/file/05ffe69463062b7f9fb506c8351ffdd7-Paper-Datasets_and_Benchmarks.pdf) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/braincog/datasets/bullying10k](https://github.com/BrainCog-X/Brain-Cog/tree/main/braincog/datasets/bullying10k) | NeurIPS 2023 |
## 2022
| Papers | Codes | Publisher |
| ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------- |
| [BackEISNN: A deep spiking neural network with adaptive self-feedback and balanced excitatory–inhibitory neurons](https://www.sciencedirect.com/science/article/pii/S0893608022002520) | [https://github.com/BrainCog-X/Brain-Cog/blob/main/examples/Perception_and_Learning/img_cls/bp/main_backei.py](https://github.com/BrainCog-X/Brain-Cog/blob/main/examples/Perception_and_Learning/img_cls/bp/main_backei.py) | Neural Networks |
| [Efficient and Accurate Conversion of Spiking Neural Network with Burst Spikes](https://www.ijcai.org/proceedings/2022/345) | [https://github.com/BrainCog-X/Brain-Cog/blob/main/examples/Perception_and_Learning/Conversion/converted_CIFAR10.py](https://github.com/BrainCog-X/Brain-Cog/blob/main/examples/Perception_and_Learning/Conversion/converted_CIFAR10.py) | IJCAI |
| [Spike Calibration: Fast and Accurate Conversion of Spiking Neural Network for Object Detection and Segmentation](https://arxiv.org/abs/2207.02702) | [https://github.com/BrainCog-X/Brain-Cog/blob/main/examples/Perception_and_Learning/Conversion/converted_CIFAR10.py](https://github.com/BrainCog-X/Brain-Cog/blob/main/examples/Perception_and_Learning/Conversion/converted_CIFAR10.py) | Arxiv |
| [Multisensory Concept Learning Framework Based on Spiking Neural Networks](https://www.frontiersin.org/articles/10.3389/fnsys.2022.845177/full) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/MultisensoryIntegration](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/MultisensoryIntegration) | Frontiers in Systems Neuroscience |
| [Backpropagation with biologically plausible spatiotemporal adjustment for training deep spiking neural networks](https://www.sciencedirect.com/science/article/pii/S2666389922001192) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/img_cls/bp](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/img_cls/bp) | Patterns |
| [Spiking CapsNet: A spiking neural network with a biologically plausible routing rule between capsules](https://www.sciencedirect.com/science/article/pii/S002002552200843X) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/img_cls/spiking_capsnet](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/img_cls/spiking_capsnet) | Information Sciences |
| [N-Omniglot, a large-scale neuromorphic dataset for spatio-temporal sparse few-shot learning](https://www.nature.com/articles/s41597-022-01851-z) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/braincog/datasets/NOmniglot](https://github.com/BrainCog-X/Brain-Cog/tree/main/braincog/datasets/NOmniglot) | Scientific Data |
| [Brain-inspired Graph Spiking Neural Networks for Commonsense Knowledge Representation and Reasoning](https://arxiv.org/abs/2207.05561 ) | https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Knowledge_Representation_and_Reasoning/CKRGSNN | Arxiv |
| [Nature-inspired self-organizing collision avoidance for drone swarm based on reward-modulated spiking neural network](https://www.cell.com/patterns/fulltext/S2666-3899(22)00236-7) | https://github.com/BrainCog-X/Brain-Cog/blob/main/examples/decision_making/swarm/Collision-Avoidance.py | Cell Patterns |
| [Solving the spike feature information vanishing problem in spiking deep Q network with potential based normalization](https://www.frontiersin.org/articles/10.3389/fnins.2022.953368/full) | https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/decision_making/RL/sdqn | Frontiers in Neuroscience |
| [A brain-inspired intention prediction model and its applications to humanoid robot](https://www.frontiersin.org/articles/10.3389/fnins.2022.1009237/full) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/Intention_Prediction](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/Intention_Prediction) | Frontiers in Neuroscience |
| [A Brain-Inspired Theory of Mind Spiking Neural Network for Reducing Safety Risks of Other Agents](https://www.frontiersin.org/articles/10.3389/fnins.2022.753900/full) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/ToM](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/ToM) | Frontiers in Neuroscience |
| [Brain-Inspired Affective Empathy Computational Model and Its Application on Altruistic Rescue Task](https://www.frontiersin.org/articles/10.3389/fncom.2022.784967/full) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/affective_empathy/BAE-SNN](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/affective_empathy/BAE-SNN) | Frontiers in Computational Neuroscience |
| [A brain-inspired robot pain model based on a spiking neural network](https://www.frontiersin.org/articles/10.3389/fnbot.2022.1025338/full) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/affective_empathy/BRP-SNN](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/affective_empathy/BRP-SNN) | Frontiers in Neurorobotics |
| [Brain-Inspired Theory of Mind Spiking Neural Network Elevates Multi-Agent Cooperation and Competition](https://papers.ssrn.com/sol3/papers.cfm?abstract_id=4271099) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/MAToM-SNN ](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/MAToM-SNN ) | SSRN |
| [Adaptive Sparse Structure Development with Pruning and Regeneration for Spiking Neural Networks](https://www.sciencedirect.com/science/article/abs/pii/S0020025524013951) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Structural_Development/SD-SNN](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Structural_Development/SD-SNN) | Information Sciences |
| [DPSNN](https://arxiv.org/abs/2205.12718) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Snn_safety/DPSNN](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Snn_safety/DPSNN) | Arxiv |
## 2021
| Papers | Codes | Publisher |
| ---------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------- |
| [Quantum superposition inspired spiking neural network](https://www.cell.com/iscience/fulltext/S2589-0042(21)00848-8) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/QSNN](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/QSNN) | iScience |
| [A Brain-Inspired Causal Reasoning Model Based on Spiking Neural Networks](https://ieeexplore.ieee.org/abstract/document/9534102) | https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Knowledge_Representation_and_Reasoning/CRSNN | IJCNN2021 |
| [Brain Inspired Sequences Production by Spiking Neural Networks With Reward-Modulated STDP](https://www.frontiersin.org/articles/10.3389/fncom.2021.612041/full) | https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Knowledge_Representation_and_Reasoning/SPSNN | Frontiers in Computational Neuroscience |
| [Stylistic Composition of Melodies Based on a Brain-Inspired Spiking Neural Network](https://www.frontiersin.org/articles/10.3389/fnsys.2021.639484/full) | https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/affective_empathy/BRP-SNN | Frontiers in Neurorobotics |
## 2020
| Papers | Codes | Publisher |
| ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------- |
| [GLSNN: A Multi-Layer Spiking Neural Network Based on Global Feedback Alignment and Local STDP Plasticity](https://www.frontiersin.org/articles/10.3389/fncom.2020.576841/full) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/img_cls/glsnn](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/img_cls/glsnn) | Frontiers in Neuroscience |
| [Temporal-Sequential Learning With a Brain-Inspired Spiking Neural Network and Its Application to Musical Memory](https://www.frontiersin.org/articles/10.3389/fncom.2020.00051/full) | https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Knowledge_Representation_and_Reasoning/musicMemory | Frontiers in Computational Neuroscience |
## 2018
| Papers | Codes | Publisher |
| ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -------------------------- |
| [A Brain-Inspired Decision-Making Spiking Neural Network and Its Application in Unmanned Aerial Vehicle](https://www.frontiersin.org/articles/10.3389/fnbot.2018.00056/full ) | https://github.com/BrainCog-X/Brain-Cog/blob/main/examples/decision_making/BDM-SNN/BDM-SNN-hh.py | Frontiers in Neurorobotics |
| [Toward Robot Self-Consciousness (II): Brain-Inspired Robot Bodily Self Model for Self-Recognition](https://link.springer.com/article/10.1007/s12559-017-9505-1 ) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/mirror_test](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/mirror_test) | Cognitive Computation |
================================================
FILE: documents/Tutorial.md
================================================
# Tutorial
- How to Install BrainCog [[English Version](http://www.brain-cog.network/docs/tutorial/1_installation.html), [Chinese Version](https://mp.weixin.qq.com/s/fX-S3fKKDfR3NV4ISAHrZg)]
- Overall Introduction of BrainCog [[English Version](http://www.brain-cog.network/docs/tutorial/2_Overall%20introduction%20of%20BrainCog.html), [Chinese Version](https://mp.weixin.qq.com/s/ywgQ5ydQxr6W7d_Y_XqC6w)]
- Spiking Neuron Modeling with BrainCog [[English Version](http://www.brain-cog.network/docs/tutorial/3_Tutorial%20of%20Spiking%20Neuron%20Modeling.html), [Chinese Version](https://mp.weixin.qq.com/s/pCdlbkrdnMNHj7wcwi9D2Q)]
- Building Efficient Spiking Neural Networks with BrainCog [[English Version](http://www.brain-cog.network/docs/tutorial/4_Building%20Efficient%20Spiking%20Neural%20Networks%20with%20BrainCog.html), [Chinese Version](https://mp.weixin.qq.com/s/NIz7MSAOJQ79m97hkP3HVg)]
- Building Cognitive Networks with BrainCog [[English Version](http://www.brain-cog.network/docs/tutorial/5_How%20to%20build%20a%20cognitive%20network.html), [Chinese Version](https://mp.weixin.qq.com/s/K0pY6V8TupJgYXH4WvS9jg)]
- Implementing Deep Reinforcement Learning SNNs with BrainCog [[English Version](http://www.brain-cog.network/docs/tutorial/6_SDQN.html), [Chinese Version](https://mp.weixin.qq.com/s/Zt78vj_sKn5jffEyeh86Cw)]
- Quantum Superposition State Inspired Encoding With BrainCog [[English Version](http://www.brain-cog.network/docs/tutorial/7_QSSNN.html), [Chinese Version](https://mp.weixin.qq.com/s/YVNkwDwuF9FG-YqQyd3KTg)]
- Converting ANNs to SNNs with BrainCog [[English Version](http://www.brain-cog.network/docs/tutorial/8_conversion.html), [Chinese Version](https://mp.weixin.qq.com/s/4g8WBoa4SOcb_VQ24wO-Xw)]
- SNNs with Global Feedback Connections Based on BrainCog [[English Version](http://www.brain-cog.network/docs/tutorial/9_GLSNN.html), [Chinese Version](https://mp.weixin.qq.com/s/-AS0BihlXdFwCt1hgzFx3g)]
- Multi-brain Areas Coordinated Brain-inspired Decision-Making SNNs Based on BrainCog [[English Version](http://www.brain-cog.network/docs/tutorial/10_BDMSNN.html), [Chinese Version](https://mp.weixin.qq.com/s/dltOGjhUZ9yTzSssIvkhtA)]
- Backpropagation with Spatiotemporal Adjustment for Training SNNs Based on BrainCog [[English Version](http://www.brain-cog.network/docs/tutorial/11_biobp.html), [Chinese Version](https://mp.weixin.qq.com/s/6ZtaWtiY9K96UhVnvhfP7w)]
- Unsupervised STDP Based SNNs with Multiple Adaptive Mechanisms Based on BrainCog [[English Version](http://www.brain-cog.network/docs/tutorial/12_unsup.html), [Chinese Version](https://mp.weixin.qq.com/s/v_svhQ0N3JAYo1l8NQ1W1w)]
- Symbolic Representation and Reasoning SNNs Based on BrainCog [[English Version](http://www.brain-cog.network/docs/tutorial/13_krr.html), [Chinese Version](https://mp.weixin.qq.com/s/UXc5QDbg8tUKVU3L6dhvxw)]
- Multisensory Integration Based on BrainCog [[English Version](http://www.brain-cog.network/docs/tutorial/14_multisensory.html), [Chinese Version](https://mp.weixin.qq.com/s/s9gXF5NkPUGXkcFAFtsCMw)]
- SNN-based Music Memory and Generation Based on BrainCog [[English Version](http://www.brain-cog.network/docs/tutorial/15_musicSNN.html), [Chinese Version](https://mp.weixin.qq.com/s/y-t9rFEXYSmI_-rnVk_usw)]
- Brain-inspired Bodily Self-perception Model Based on BrainCog [[English Version](http://www.brain-cog.network/docs/tutorial/16_self.html), [Chinese Version](https://mp.weixin.qq.com/s/RLbS3GmhE8ScyZTKJKl_SQ)]
- A Brain-inspired Theory of Mind Model Based on BrainCog for Reducing Other Agents’ Safety Risks [[English Version](http://www.brain-cog.network/docs/tutorial/17_tom.html), [Chinese Version](https://mp.weixin.qq.com/s/xzFF6IK86W4h2CMs7jt7Xw)]
- Application of the Prefrontal Cortex Column Model in Working Memory Task with BrainCog [[English Version](http://www.brain-cog.network/docs/tutorial/17_tom.html), [Chinese Version](https://mp.weixin.qq.com/s/E6pb5W0Q4noK72NGf4SNGQ)]
- Multi-brain areas Coordinated Brain-inspired Affective Empathy Spiking Neural Network Based on Braincog [[English Version](http://www.brain-cog.network/docs/tutorial/17_tom.html), [Chinese Version](https://mp.weixin.qq.com/s/xUnal4I5tM-Rl49fu0ZRVw)]
- Developmental Plasticity-inspired Adaptive Pruning for SNNs based on BrainCog [[English Version](http://www.brain-cog.network/docs/tutorial/17_tom.html), [Chinese Version](https://mp.weixin.qq.com/s/-F0XcIovI1WY6YyoO5P58Q)]
- Dynamic structural development for SNNs incorporating constraints, pruning and regeneration based on BrainCog [[English Version](http://www.brain-cog.network/docs/tutorial/17_tom.html), [Chinese Version](https://mp.weixin.qq.com/s/n1EK6lddriS0dnSgfdsa8w)]
- BrainCog Data Engine: Spatio-temporal Sequence Data N-Omniglot and Its Applications [[English Version](http://www.brain-cog.network/docs/tutorial/17_tom.html), [Chinese Version](https://mp.weixin.qq.com/s/VIzDb2kn_4LUyUh0qc2j4w)]
- The Implement of Object Detection and Semantic Segmentation Based on SNNs with Braincog [[English Version](http://www.brain-cog.network/docs/tutorial/17_tom.html), [Chinese Version](https://mp.weixin.qq.com/s/fyf_cbzH_i1_ZEFWTSH81Q)]
- DVS Augmentation Strategies based on BrainCog [[English Version](http://www.brain-cog.network/docs/tutorial/17_tom.html), [Chinese Version](https://mp.weixin.qq.com/s/BNMZ_oQK9Foui2zcnKWYgg)]
- SNNs with adaptive self-feedback and balanced excitatory–inhibitory neurons based on BrainCog [[English Version](http://www.brain-cog.network/docs/tutorial/17_tom.html), [Chinese Version](https://mp.weixin.qq.com/s/TFEucKDQHB69wChEL_Y3Cw)]
- Spiking CapsNet based on BrainCog [[English Version](http://www.brain-cog.network/docs/tutorial/17_tom.html), [Chinese Version](https://mp.weixin.qq.com/s/VIzDb2kn_4LUyUh0qc2j4w)]
- Drosophila-inspired Linear and Nonlinear Decision-Making Spiking Neural Network implemented by Braincog [[English Version](http://www.brain-cog.network/docs/tutorial/17_tom.html), [Chinese Version](https://mp.weixin.qq.com/s/VIzDb2kn_4LUyUh0qc2j4w)]
- Reward-modulated brain-inspired spiking neural network to empower group for nature-inspired self-organized obstacle avoidance implemented by Braincog [[English Version](http://www.brain-cog.network/docs/tutorial/17_tom.html), [Chinese Version](https://mp.weixin.qq.com/s/VIzDb2kn_4LUyUh0qc2j4w)]
- Using BrainCog to Realize Robot's Intention Prediction for Users [[English Version](http://www.brain-cog.network/docs/tutorial/17_tom.html), [Chinese Version](https://mp.weixin.qq.com/s/VIzDb2kn_4LUyUh0qc2j4w)]
- FPGA-based Spiking Neural Networks Acceleration Using BrainCog [[English Version](http://www.brain-cog.network/docs/tutorial/17_tom.html), [Chinese Version](https://mp.weixin.qq.com/s/VIzDb2kn_4LUyUh0qc2j4w)]
================================================
FILE: examples/Brain_Cognitive_Function_Simulation/drosophila/README.md
================================================
# Drosophila-inspired decision-making SNN
## Run
The drosophila.py script implements the core code of the Drosophila-inspired linear and non-linear decision-making described in the paper entitled "A Neural Algorithm for linear and non-linear Decision-making inspired by Drosophila".
The experiments include a training phase and a testing phase:
* Training Phase
Train the linear and the nonlinear network with a reward-modulated spiking neural network: the green upright T is safe and the blue inverted T is dangerous
* Testing Phase
For both the linear and the nonlinear pathway, choose between a blue upright T and a green inverted T, and count the PI (preference index) values under different color intensities
## Results
The following picture shows the linear (a) and nonlinear (b) pathways, the training and testing phases (c), and the PI values on different color intensities (d).

Differences from the original article: an improved reward-modulated STDP learning rule.
## Citation
If you find this package helpful, please consider citing the following papers:
```BibTex
@article{zhao2020neural,
title={A neural algorithm for Drosophila linear and nonlinear decision-making},
author={Zhao, Feifei and Zeng, Yi and Guo, Aike and Su, Haifeng and Xu, Bo},
journal={Scientific Reports},
volume={10},
number={1},
pages={1--16},
year={2020},
publisher={Nature Publishing Group}
}
@misc{https://doi.org/10.48550/arxiv.2207.08533,
doi = {10.48550/ARXIV.2207.08533},
url = {https://arxiv.org/abs/2207.08533},
author = {Zeng, Yi and Zhao, Dongcheng and Zhao, Feifei and Shen, Guobin and Dong, Yiting and Lu, Enmeng and Zhang, Qian and Sun, Yinqian and Liang, Qian and Zhao, Yuxuan and Zhao, Zhuoya and Fang, Hongjian and Wang, Yuwei and Li, Yang and Liu, Xin and Du, Chengcheng and Kong, Qingqun and Ruan, Zizhe and Bi, Weida},
title = {BrainCog: A Spiking Neural Network based Brain-inspired Cognitive Intelligence Engine for Brain-inspired AI and Brain Simulation},
publisher = {arXiv},
year = {2022},
}
```
================================================
FILE: examples/Brain_Cognitive_Function_Simulation/drosophila/drosophila.py
================================================
import numpy as np
import torch,os,sys
from torch import nn
from torch.nn import Parameter
import abc
import math
from abc import ABC
import numpy as np
import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
import matplotlib.pyplot as plt
from braincog.base.strategy.surrogate import *
from braincog.base.node.node import IFNode
from braincog.base.learningrule.STDP import STDP,MutliInputSTDP
from braincog.base.connection.CustomLinear import CustomLinear
from braincog.model_zoo.nonlinearNet import droDMTestNet
from braincog.model_zoo.linearNet import droDMTrainNet
import copy
if __name__ == "__main__":
    # ------------------------------------------------------------------
    # Build the training network (mushroom-body-like decision circuit).
    # ------------------------------------------------------------------
    num_state = 5     # input/visual/KC population size
    num_action = 2    # number of MBON output neurons (actions)
    weight_exc = 0.5  # excitatory connection weight
    weight_inh = -0.05  # inhibitory connection weight
    trace_decay = 0.8   # decay factor of the eligibility trace (R-STDP)
    mb_connection = []
    # input -> visual: one-to-one excitatory mapping
    con_matrix0 = torch.eye(num_state, dtype=torch.float)
    mb_connection.append(CustomLinear(weight_exc * con_matrix0, con_matrix0))
    # visual -> KC (Kenyon cells): one-to-one excitatory mapping
    con_matrix1 = torch.eye(num_state, dtype=torch.float)
    mb_connection.append(CustomLinear(weight_exc * con_matrix1, con_matrix1))
    # KC -> MBON: all-to-all, plastic (trained with reward-modulated STDP)
    con_matrix2 = torch.ones((num_state, num_action), dtype=torch.float)
    mb_connection.append(CustomLinear(weight_exc * con_matrix2, con_matrix2))
    # MBON -> MBON: mutual lateral inhibition without self-connections
    con_matrix3 = torch.ones((num_action, num_action), dtype=torch.float)
    con_matrix4 = torch.eye(num_action, dtype=torch.float)
    con_matrix5 = con_matrix3 - con_matrix4  # removed redundant self-assignment
    mb_connection.append(CustomLinear(weight_inh * con_matrix5, con_matrix5))
    MB = droDMTrainNet(mb_connection)
    # eligibility trace of the KC->MBON weights
    weight_trace_mbon = torch.zeros(con_matrix2.shape, dtype=torch.float)

    # ------------------------------------------------------------------
    # Training: learn that the green upright T (GT) is safe and the blue
    # inverted T (Bt) is punished.
    # ------------------------------------------------------------------
    # learning GT
    #                  R  G    B    T    t
    GT = torch.tensor([0, 0.8, 0, 1.0, 0])
    Bt = torch.tensor([0, 0, 0.8, 0, 1.0])
    input = GT - Bt  # input GT
    input[input < 0] = 0
    for i_train in range(20):
        GT_out, dwkc, dwmbon = MB(input)
        print("stdp:", dwkc, dwmbon)
        # visual -> KC plasticity: plain STDP
        MB.UpdateWeight(1, dwkc)
        # KC -> MBON plasticity: reward-modulated STDP via eligibility trace
        weight_trace_mbon *= trace_decay
        weight_trace_mbon += dwmbon
        if max(GT_out) > 0:
            # punish action 0 when the network responds to the safe cue
            r = torch.ones((num_state, num_action), dtype=torch.float)
            p_action = torch.tensor([0])
            r[:, p_action] = -1
            dw_mbon = r * weight_trace_mbon
            MB.UpdateWeight(2, dw_mbon)
        print("output:", GT_out)
    MB.reset()
    weight_trace_mbon = torch.zeros(con_matrix2.shape, dtype=torch.float)
    # learning Bt
    GT = torch.tensor([0, 0.8, 0, 1.0, 0])
    Bt = torch.tensor([0, 0, 0.8, 0, 1.0])
    input = Bt - GT  # input Bt
    input[input < 0] = 0
    for i_train in range(20):
        GT_out, dwkc, dwmbon = MB(input)
        # vis-kc STDP
        MB.UpdateWeight(1, dwkc)
        # kc-mbon R-STDP
        weight_trace_mbon *= trace_decay
        weight_trace_mbon += dwmbon
        if max(GT_out) > 0:
            # punish action 1 for the dangerous cue
            r = torch.ones((num_state, num_action), dtype=torch.float)
            p_action = torch.tensor([1])
            r[:, p_action] = -1
            dw_mbon = r * weight_trace_mbon
            MB.UpdateWeight(2, dw_mbon)
    train_weight = MB.getweight()
    for i in range(len(train_weight)):
        print("weight after learning:", train_weight[i].weight.data)
    print("end training")

    # ------------------------------------------------------------------
    # Linear pathway test: conflict decision making (Gt vs BT) while the
    # color intensity varies; PI = (t1 - t2) / (t1 + t2).
    # ------------------------------------------------------------------
    test_num = 12
    t1 = torch.zeros(test_num, dtype=torch.float)
    t2 = torch.zeros(test_num, dtype=torch.float)
    for c in range(t1.shape[0]):
        MB_test = droDMTrainNet(copy.deepcopy(train_weight))
        MB_test.reset()
        Gt = torch.tensor([0, (c * 0.1), 0, 0, 0.5])
        BT = torch.tensor([0, 0, (c * 0.1), 0.5, 0])
        input = Gt - BT  # input Gt
        input[input < 0] = 0
        count = torch.zeros(num_action, dtype=torch.float)
        for i_train in range(500):
            GT_out, dwkc, dwmbon = MB_test(input)
            count += GT_out
        t1[c] = count[0]
        t2[c] = count[1]
    p1 = (t1 - t2) / (t1 + t2)  # preference index of the linear pathway
    print(t1, t2, p1)
    for i in range(len(train_weight)):
        print("weight after learning:", train_weight[i].weight.data)

    # ------------------------------------------------------------------
    # Build the test network and verify green upright T vs blue inverted T
    # under different color intensities (non-linear pathway with APL/DA).
    # ------------------------------------------------------------------
    weight_inh_test = -0.3
    num_apl = 2
    num_da = 1
    da_mb_connection = train_weight
    # kc-apl
    con_matrix6 = torch.ones((num_state, num_apl), dtype=torch.float)
    da_mb_connection.append(CustomLinear(weight_exc * con_matrix6, con_matrix6))
    # apl-kc
    con_matrix7 = torch.ones((num_apl, num_state), dtype=torch.float)
    da_mb_connection.append(CustomLinear(weight_inh_test * con_matrix7, con_matrix7))
    # da-apl
    con_matrix8 = torch.ones((num_da, num_apl), dtype=torch.float)
    da_mb_connection.append(CustomLinear(weight_inh_test * con_matrix8, con_matrix8))
    # apl-da
    con_matrix9 = torch.ones((num_apl, num_da), dtype=torch.float)
    da_mb_connection.append(CustomLinear(weight_inh_test * con_matrix9, con_matrix9))
    # input-da
    con_matrix10 = torch.ones(num_da, dtype=torch.float)
    da_mb_connection.append(CustomLinear(weight_exc * con_matrix10, con_matrix10))
    # da-mbon
    con_matrix11 = torch.ones((num_da, num_action), dtype=torch.float)
    da_mb_connection.append(CustomLinear(weight_exc * con_matrix11, con_matrix11))
    # connection indices: 0 input-vis, 1 vis-kc, 2 kc-mbon, 3 mbon-mbon,
    # 4 kc-apl, 5 apl-kc, 6 da-apl, 7 apl-da, 8 input-da
    t1 = torch.zeros(test_num, dtype=torch.float)
    t2 = torch.zeros(test_num, dtype=torch.float)
    for c in range(t1.shape[0]):
        DA_MB_test = droDMTestNet(copy.deepcopy(da_mb_connection))
        DA_MB_test.reset()
        Gt = torch.tensor([0, (c * 0.1), 0, 0, 0.5])
        BT = torch.tensor([0, 0, (c * 0.1), 0.5, 0])
        input = Gt - BT  # input Gt
        input[input < 0] = 0
        count = torch.zeros(num_action, dtype=torch.float)
        for i_train in range(500):
            # deliver brief dopamine pulses at the start of the test run
            if i_train < 10 and i_train % 2 == 0:
                input_da = torch.tensor([0.5])
            else:
                input_da = torch.tensor([0.0])
            GT_out, dwkc, dwapl = DA_MB_test(input, input_da)
            DA_MB_test.UpdateWeight(5, dwkc)
            DA_MB_test.UpdateWeight(4, dwapl)
            count += GT_out
        t1[c] = count[0]
        t2[c] = count[1]
    p2 = (t1 - t2) / (t1 + t2)  # preference index of the non-linear pathway
    print(t1, t2, p2)
    # removed dead assignment "MB_test = MB.getweight()" (value was never used)
    for i in range(len(train_weight)):
        print("weight after learning:", train_weight[i].weight.data)
    # plot PI curves of both pathways against color intensity
    x = torch.arange(0, test_num)
    x = x * 0.1
    plt.figure()
    A, = plt.plot(x, p1, label="linear")
    B, = plt.plot(x, p2, label="non-linear")
    font1 = {'family': 'Times New Roman', 'weight': 'normal', 'size': 15, }
    plt.xlabel("color intensity", font1)
    plt.ylabel("PI", font1)
    plt.legend(handles=[A, B], prop=font1)
    plt.show()
================================================
FILE: examples/Embodied_Cognition/RHI/RHI_Test.py
================================================
import numpy as np
import torch,os,sys
from torch import nn
from torch.nn import Parameter
import abc
import math
from abc import ABC
import numpy as np
import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
import matplotlib.pyplot as plt
from braincog.base.strategy.surrogate import *
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import random
import gc
from braincog.base.node.node import IzhNodeMU
import objgraph
from pympler import tracker
class CustomLinear(nn.Module):
    """Element-wise weighted connection with an optional plasticity mask.

    ``weight`` multiplies the input element-wise (a Hadamard product, not a
    matrix multiply).  ``mask``, when given, gates which entries of a weight
    update are applied.
    """

    def __init__(self, weight, mask=None):
        super().__init__()
        self.weight = nn.Parameter(weight, requires_grad=True)
        self.mask = mask

    def forward(self, x: torch.Tensor):
        # Hadamard product; output shape follows broadcasting of x
        # against self.weight.
        return x.mul(self.weight)

    def update(self, dw):
        """Add the (optionally masked) weight change ``dw`` to the weights.

        NOTE(review): when a mask is set, ``dw`` itself is modified in
        place, matching the original behavior.
        """
        with torch.no_grad():
            masked = dw if self.mask is None else dw.mul_(self.mask)
            self.weight.data += masked
class M1Net(nn.Module):
    """Primary motor cortex (M1) population of Izhikevich neurons.

    Converts a normalized input vector into per-neuron firing rates, either
    via a precomputed current->rate lookup table (``TrickID == 1``) or by
    simulating the Izhikevich dynamics for ``Simulation_time`` steps.

    Relies on module-level globals defined in ``__main__``: ``param_*``,
    ``num_neuron``, ``I_max``, ``TrickID``, ``spike_num_list``,
    ``Simulation_time``.
    """

    def __init__(self, connection):
        super().__init__()
        self.node = []
        self.node.append(IzhNodeMU(threshold=param_threshold, a=param_a, b=param_b,
                                   c=param_c, d=param_d, mem=param_mem, u=param_u,
                                   dt=param_dt))
        self.connection = connection

    def forward(self, input):
        """Return the firing rate of each M1 neuron for ``input`` in [0, 1]."""
        input_n = input * I_max          # scale to an input current
        input_r = torch.round(input_n)   # integer current used as table index
        Spike = torch.zeros(num_neuron, dtype=torch.float)
        self.node[0].n_reset()
        if TrickID == 1:
            # Fast path: per-neuron table lookup of the firing rate.
            # Was ``range(num_AI)``; M1 has ``num_neuron`` units (the two are
            # equal here) -- use num_neuron for consistency with VNet.
            for i in range(num_neuron):
                input_ri = int(input_r[i].item())
                FR_i = spike_num_list[input_ri]
                Spike[i] = FR_i
            FR_n = Spike
        else:
            # Slow path: simulate the Izhikevich neurons and count spikes.
            for t in range(Simulation_time):
                self.out = self.node[0](input_r)
                n_Spike = self.node[0].spike
                Spike = Spike + n_Spike
            FR_n = Spike / Simulation_time
        return FR_n

    def reset(self):
        """Reset the state of all neuron groups."""
        for i in range(len(self.node)):
            self.node[i].n_reset()
class VNet(nn.Module):
    """Visual area (V) population of Izhikevich neurons.

    Maps a normalized visual input vector to per-neuron firing rates, using
    the precomputed current->rate table when ``TrickID == 1`` and a step-by-
    step simulation otherwise.  Depends on globals defined in ``__main__``
    (``param_*``, ``num_neuron``, ``I_max``, ``TrickID``, ``spike_num_list``,
    ``Simulation_time``).
    """

    def __init__(self, connection):
        super().__init__()
        self.node = [
            IzhNodeMU(threshold=param_threshold, a=param_a, b=param_b, c=param_c,
                      d=param_d, mem=param_mem, u=param_u, dt=param_dt)
        ]
        self.connection = connection

    def forward(self, input):
        """Return the firing rates of the visual population."""
        currents = torch.round(input * I_max)
        Spike = torch.zeros(num_neuron, dtype=torch.float)
        self.node[0].n_reset()
        if TrickID == 1:
            # lookup-table path: one table read per neuron
            for idx in range(num_neuron):
                Spike[idx] = spike_num_list[int(currents[idx].item())]
            FR_n = Spike
        else:
            # simulation path: accumulate spikes over Simulation_time steps
            for _ in range(Simulation_time):
                self.out = self.node[0](currents)
                Spike = Spike + self.node[0].spike
            FR_n = Spike / Simulation_time
        return FR_n

    def reset(self):
        for grp in self.node:
            grp.n_reset()
class S1Net(nn.Module):
    """Primary somatosensory area (S1) driven by afferent firing rates.

    The afferent rates ``FR`` are mixed through the connection weights,
    squashed with tanh, and leak-integrated into the running input with gain
    ``C``; the updated input is then mapped to firing rates.  Depends on
    globals defined in ``__main__``.
    """

    def __init__(self, connection):
        super().__init__()
        self.node = [
            IzhNodeMU(threshold=param_threshold, a=param_a, b=param_b, c=param_c,
                      d=param_d, mem=param_mem, u=param_u, dt=param_dt)
        ]
        self.connection = connection

    def forward(self, input, FR, C):
        """Return ``(firing rates, updated input)`` for the S1 population."""
        # weighted sum of afferent firing rates (one or several sources)
        drive = torch.zeros(num_neuron, dtype=torch.float)
        if len(FR.shape) == 1:
            drive = FR * self.connection[0].weight
        else:
            for src in range(FR.shape[0]):
                drive = drive + FR[src] * self.connection[src].weight
        # squash and clip negatives
        sf = torch.tanh(drive)
        sf = torch.where(sf < 0, 0, sf)
        # leaky integration of the input toward the squashed drive
        input_n = -C * (input - sf) + input
        input_n = torch.where(input_n < 0, 0, input_n)
        currents = torch.round(input_n * I_max)
        Spike = torch.zeros(num_S1, dtype=torch.float)
        self.node[0].n_reset()
        if TrickID == 1:
            # lookup-table path
            for idx in range(num_neuron):
                Spike[idx] = spike_num_list[int(currents[idx].item())]
            FR_n = Spike
        else:
            # simulation path
            for _ in range(Simulation_time):
                self.out = self.node[0](currents)
                Spike = Spike + self.node[0].spike
            FR_n = Spike / Simulation_time
        return FR_n, input_n

    def reset(self):
        for grp in self.node:
            grp.n_reset()
class EBANet(nn.Module):
    """Extrastriate body area (EBA) driven by visual firing rates.

    Identical computation to S1Net: mix afferent rates through the weights,
    squash with tanh, leak-integrate the input with gain ``C``, then convert
    the updated input to firing rates.  Depends on globals defined in
    ``__main__``.
    """

    def __init__(self, connection):
        super().__init__()
        self.node = [
            IzhNodeMU(threshold=param_threshold, a=param_a, b=param_b, c=param_c,
                      d=param_d, mem=param_mem, u=param_u, dt=param_dt)
        ]
        self.connection = connection

    def forward(self, input, FR, C):
        """Return ``(firing rates, updated input)`` for the EBA population."""
        summed = torch.zeros(num_neuron, dtype=torch.float)
        if len(FR.shape) == 1:
            summed = FR * self.connection[0].weight
        else:
            for k in range(FR.shape[0]):
                summed = summed + FR[k] * self.connection[k].weight
        sf = torch.tanh(summed)
        sf = torch.where(sf < 0, 0, sf)
        # integrate the input toward the squashed drive, clip at zero
        input_n = -C * (input - sf) + input
        input_n = torch.where(input_n < 0, 0, input_n)
        scaled = torch.round(input_n * I_max)
        Spike = torch.zeros(num_S1, dtype=torch.float)
        self.node[0].n_reset()
        if TrickID == 1:
            for k in range(num_neuron):
                Spike[k] = spike_num_list[int(scaled[k].item())]
            FR_n = Spike
        else:
            for _ in range(Simulation_time):
                self.out = self.node[0](scaled)
                Spike = Spike + self.node[0].spike
            FR_n = Spike / Simulation_time
        return FR_n, input_n

    def reset(self):
        for grp in self.node:
            grp.n_reset()
class TPJNet(nn.Module):
    """Temporoparietal junction (TPJ): integrates S1 and EBA firing rates.

    Same rate computation as the other area networks; additionally exposes
    ``UpdateWeight`` which *replaces* (does not accumulate) a connection's
    weight tensor.  Depends on globals defined in ``__main__``.
    """

    def __init__(self, connection):
        super().__init__()
        self.node = [
            IzhNodeMU(threshold=param_threshold, a=param_a, b=param_b, c=param_c,
                      d=param_d, mem=param_mem, u=param_u, dt=param_dt)
        ]
        self.connection = connection

    def forward(self, input, FR, C):
        """Return ``(firing rates, updated input)`` for the TPJ population."""
        agg = torch.zeros(num_neuron, dtype=torch.float)
        if len(FR.shape) == 1:
            agg = FR * self.connection[0].weight
        else:
            for src in range(FR.shape[0]):
                agg = agg + FR[src] * self.connection[src].weight
        sf = torch.tanh(agg)
        sf = torch.where(sf < 0, 0, sf)
        input_n = -C * (input - sf) + input
        input_n = torch.where(input_n < 0, 0, input_n)
        currents = torch.round(input_n * I_max)
        Spike = torch.zeros(num_S1, dtype=torch.float)
        self.node[0].n_reset()
        if TrickID == 1:
            for idx in range(num_neuron):
                Spike[idx] = spike_num_list[int(currents[idx].item())]
            FR_n = Spike
        else:
            for _ in range(Simulation_time):
                self.out = self.node[0](currents)
                Spike = Spike + self.node[0].spike
            FR_n = Spike / Simulation_time
        return FR_n, input_n

    def reset(self):
        for grp in self.node:
            grp.n_reset()

    def UpdateWeight(self, i, W):
        # replace (not accumulate) the i-th connection weight
        self.connection[i].weight.data = W
class AINet(nn.Module):
    """Anterior insula (AI): integrates S1, TPJ and EBA firing rates.

    Same rate computation as the other area networks; ``UpdateWeight``
    *accumulates* a weight change into the selected connection.  Depends on
    globals defined in ``__main__``.
    """

    def __init__(self, connection):
        super().__init__()
        self.node = [
            IzhNodeMU(threshold=param_threshold, a=param_a, b=param_b, c=param_c,
                      d=param_d, mem=param_mem, u=param_u, dt=param_dt)
        ]
        self.connection = connection

    def forward(self, input, FR, C):
        """Return ``(firing rates, updated input)`` for the AI population."""
        total = torch.zeros(num_neuron, dtype=torch.float)
        if len(FR.shape) == 1:
            total = FR * self.connection[0].weight
        else:
            for src in range(FR.shape[0]):
                total = total + FR[src] * self.connection[src].weight
        sf = torch.tanh(total)
        sf = torch.where(sf < 0, 0, sf)
        input_n = -C * (input - sf) + input
        input_n = torch.where(input_n < 0, 0, input_n)
        currents = torch.round(input_n * I_max)
        Spike = torch.zeros(num_S1, dtype=torch.float)
        self.node[0].n_reset()
        if TrickID == 1:
            for idx in range(num_neuron):
                Spike[idx] = spike_num_list[int(currents[idx].item())]
            FR_n = Spike
        else:
            for _ in range(Simulation_time):
                self.out = self.node[0](currents)
                Spike = Spike + self.node[0].spike
            FR_n = Spike / Simulation_time
        return FR_n, input_n

    def reset(self):
        for grp in self.node:
            grp.n_reset()

    def UpdateWeight(self, i, W):
        # accumulate the weight change into the i-th connection
        self.connection[i].weight.data = self.connection[i].weight.data + W
def DeltaWeight(Pre, Pre_n, Post, Post_n):
    """Hebbian-like weight change from old/new pre- and post-synaptic rates.

    Combines a joint-activity term, a postsynaptic-change term and a
    presynaptic-change term with fixed coefficients.
    """
    ALPHA, BETA, GAMMA = -0.0035, 0.35, -0.55
    joint = Pre_n * Post_n                 # co-activity
    post_change = Pre_n * (Post_n - Post)  # change on the post side
    pre_change = (Pre_n - Pre) * Post_n    # change on the pre side
    return ALPHA * joint + BETA * post_change + GAMMA * pre_change
if __name__ == "__main__":
    # ------------------------------------------------------------------
    # Rubber-hand-illusion test: measure proprioceptive drift as a
    # function of the visuo-proprioceptive disparity.
    # Each neuron represents unique motion information (such as an angle).
    # ------------------------------------------------------------------
    # number of neurons per brain area
    num_neuron = 9
    num_M1 = num_neuron
    num_S1 = num_neuron
    num_TPJ = num_neuron
    num_V = num_neuron
    num_EBA = num_neuron
    num_AI = num_neuron
    Init_Weight = 1.
    # Izhikevich neuron parameters
    param_threshold = 30.
    param_a = 0.02
    param_b = -0.1
    param_c = -55.
    param_d = 18.
    param_mem = -70.
    param_u = 0.
    param_dt = 1.
    Simulation_time = 1000
    I_max = 1000
    # When TrickID is set to 1, the mapping from input current to firing
    # rate is obtained directly by loading Izh.npy, which significantly
    # reduces the program running time.
    TrickID = 1
    if TrickID == 1:
        spike_num_list = np.load('Izh.npy')
        spike_num_list = spike_num_list / I_max
    ##############################
    # M1
    ##############################
    # M1_Input-M1
    M1_connection = []
    con_matrix0 = torch.ones(num_M1, dtype=torch.float) * Init_Weight
    M1_connection.append(CustomLinear(con_matrix0))
    M1 = M1Net(M1_connection)
    ##############################
    # V
    ##############################
    # V_Input-V
    V_connection = []
    con_matrix3 = torch.ones(num_V, dtype=torch.float) * Init_Weight
    V_connection.append(CustomLinear(con_matrix3))
    V = VNet(V_connection)
    ##############################
    # S1
    ##############################
    # M1-S1
    S1_connection = []
    con_matrix1 = torch.ones(num_S1, dtype=torch.float) * Init_Weight
    S1_connection.append(CustomLinear(con_matrix1))
    S1 = S1Net(S1_connection)
    ##############################
    # EBA
    ##############################
    # V-EBA
    EBA_connection = []
    con_matrix4 = torch.ones(num_EBA, dtype=torch.float) * Init_Weight
    EBA_connection.append(CustomLinear(con_matrix4))
    EBA = EBANet(EBA_connection)
    ##############################
    # TPJ
    ##############################
    # S1-TPJ, EBA-TPJ
    TPJ_connection = []
    # S1-TPJ
    con_matrix2 = torch.ones(num_TPJ, dtype=torch.float) * Init_Weight * 150
    TPJ_connection.append(CustomLinear(con_matrix2))
    # EBA-TPJ
    con_matrix5 = torch.ones(num_TPJ, dtype=torch.float) * Init_Weight * 150
    TPJ_connection.append(CustomLinear(con_matrix5))
    TPJ = TPJNet(TPJ_connection)
    ##############################
    # AI
    ##############################
    # S1-AI, TPJ-AI, EBA-AI
    AI_connection = []
    # S1-AI
    con_matrix6 = torch.ones(num_AI, dtype=torch.float) * Init_Weight
    AI_connection.append(CustomLinear(con_matrix6))
    # TPJ-AI
    con_matrix7 = torch.ones(num_AI, dtype=torch.float) * Init_Weight
    AI_connection.append(CustomLinear(con_matrix7))
    # EBA-AI
    con_matrix8 = torch.ones(num_AI, dtype=torch.float) * Init_Weight
    AI_connection.append(CustomLinear(con_matrix8))
    AI = AINet(AI_connection)
    # load the trained S1->AI and EBA->AI weights
    AI.connection[0].weight.data = torch.from_numpy(np.load('W_S1_AI.npy'))
    AI.connection[2].weight.data = torch.from_numpy(np.load('W_EBA_AI.npy'))
    ##############################
    # Coding
    ##############################
    # Gaussian population coding: Coding[i] is the activation pattern for
    # stimulus value listJ[i].
    S = 1  # Gaussian tuning width (sigma)
    ISI = 1
    JMax = int((num_neuron - 1) / 2)
    listJ = list(range(-JMax, JMax + 1))
    Coding = torch.zeros([num_neuron, num_neuron], dtype=torch.float)
    for i in range(len(listJ)):
        e = float(listJ[i])
        # removed unused local ``listY``
        for j in range(len(listJ)):
            x = float(listJ[j])
            Coding[i][j] = math.exp(-(x - e) ** 2 / (2 * S ** 2))
    print(AI.connection[0].weight.data)  # dW_S1AI
    print(AI.connection[2].weight.data)  # dW_EBAAI
    ##############################
    # Test
    ##############################
    Time = 300
    CT = 100
    Motion_Start = 1
    Motion_End = Motion_Start + CT
    Vision_Start = Motion_End
    Vision_End = Vision_Start + CT
    # per-area integration gains
    CM1 = 0.04
    CV = 0.04
    CS1 = 0.04
    CEBA = 0.04
    CTPJ = 0.01
    CAI = 0.15
    Result_List = []
    Veridical_hand = int((num_neuron - 1) / 2)
    for Disparity in range(-JMax, JMax + 1):
        # reset per-disparity state
        M1_input = torch.zeros(num_M1, dtype=torch.float)
        V_input = torch.zeros(num_V, dtype=torch.float)
        S1_input = torch.zeros(num_S1, dtype=torch.float)
        TPJ_input = torch.zeros(num_TPJ, dtype=torch.float)
        EBA_input = torch.zeros(num_EBA, dtype=torch.float)
        AI_input = torch.zeros(num_AI, dtype=torch.float)
        FR_M1 = torch.zeros(num_M1, dtype=torch.float)
        FR_V = torch.zeros(num_V, dtype=torch.float)
        FR_S1 = torch.zeros(num_S1, dtype=torch.float)
        FR_EBA = torch.zeros(num_EBA, dtype=torch.float)
        FR_TPJ = torch.zeros(num_TPJ, dtype=torch.float)
        FR_AI = torch.zeros(num_AI, dtype=torch.float)
        FR_AI_List = torch.zeros([Time, num_AI], dtype=torch.float)
        with torch.no_grad():
            for t in range(1, Time + 1):
                S_M1 = torch.zeros(num_M1, dtype=torch.float)
                S_V = torch.zeros(num_V, dtype=torch.float)
                # proprioceptive (motion) stimulus window
                if t >= Motion_Start and t <= Motion_End:
                    S_M1 = Coding[Veridical_hand]
                    M1_input = (1 - (1 - CM1) ** t) * S_M1
                else:
                    M1_input = S_M1
                # visual stimulus window (shifted by Disparity)
                if t >= Vision_Start and t <= Vision_End:
                    S_V = Coding[Veridical_hand + Disparity]
                    V_input = (1 - (1 - CV) ** (t - CT)) * S_V
                else:
                    V_input = S_V
                FR_M1_n = M1(M1_input)
                FR_V_n = V(V_input)
                [FR_S1_n, S1_input_n] = S1(S1_input, FR_M1_n, CS1)
                [FR_EBA_n, EBA_input_n] = EBA(EBA_input, FR_V_n, CEBA)
                FR_Input_TPJ_n = torch.stack((FR_S1_n, FR_EBA_n), 0)
                [FR_TPJ_n, TPJ_input_n] = TPJ(TPJ_input, FR_Input_TPJ_n, CTPJ)
                FR_Input_AI = torch.stack((FR_S1_n, FR_TPJ_n, FR_EBA_n), 0)
                [FR_AI_n, AI_input_n] = AI(AI_input, FR_Input_AI, CAI)
                FR_AI_List[t - 1] = FR_AI_n
                # roll state forward for the next step
                FR_M1 = FR_M1_n
                FR_V = FR_V_n
                FR_S1 = FR_S1_n
                FR_EBA = FR_EBA_n
                FR_TPJ = FR_TPJ_n
                FR_AI = FR_AI_n
                S1_input = S1_input_n
                TPJ_input = TPJ_input_n
                EBA_input = EBA_input_n
                AI_input = AI_input_n
        print('Test Time End')
        # estimated hand position = AI neuron with the largest peak rate
        Estimated_hand = torch.max(torch.max(FR_AI_List, 0)[0], 0)[1].item()
        Proprioceptive_drift = Estimated_hand - Veridical_hand
        R = [Disparity, Proprioceptive_drift]
        Result_List.append(R)
        print(R)
        print(torch.max(FR_AI_List, 0)[0])
        print("----------------------------")
    print(Result_List)
    X = [r[0] for r in Result_List]
    Y = [r[1] for r in Result_List]
    # cubic fit of drift vs. disparity; renamed from ``S`` to avoid
    # shadowing the Gaussian sigma above
    coeffs = np.polyfit(X, Y, 3)
    xn = np.linspace(-(num_neuron - 1) / 2, (num_neuron - 1) / 2, 1000)
    yn = np.poly1d(coeffs)
    plt.plot(xn, yn(xn), X, Y, 'o')
    plt.show()
================================================
FILE: examples/Embodied_Cognition/RHI/RHI_Train.py
================================================
import numpy as np
import torch,os,sys
from torch import nn
from torch.nn import Parameter
import abc
import math
from abc import ABC
import numpy as np
import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
import matplotlib.pyplot as plt
from braincog.base.strategy.surrogate import *
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import random
import gc
from braincog.base.node.node import IzhNodeMU
import objgraph
from pympler import tracker
class CustomLinear(nn.Module):
    """Element-wise weighted connection with an optional plasticity mask.

    The stored ``weight`` multiplies the input element-wise; ``mask``,
    when given, gates which entries of a weight update are applied.
    """

    def __init__(self, weight, mask=None):
        super().__init__()
        self.weight = nn.Parameter(weight, requires_grad=True)
        self.mask = mask

    def forward(self, x: torch.Tensor):
        # Hadamard product of input and weight (not a matrix multiply).
        return x.mul(self.weight)

    def update(self, dw):
        """Add the (optionally masked) change ``dw`` to the weights.

        NOTE(review): with a mask set, ``dw`` is modified in place,
        matching the original behavior.
        """
        with torch.no_grad():
            delta = dw if self.mask is None else dw.mul_(self.mask)
            self.weight.data += delta
class M1Net(nn.Module):
    """Primary motor cortex (M1) population of Izhikevich neurons (training).

    Converts a normalized input vector into per-neuron firing rates, either
    via the precomputed current->rate lookup table (``TrickID == 1``) or by
    simulating the Izhikevich dynamics for ``Simulation_time`` steps.

    Relies on module-level globals defined in ``__main__``: ``param_*``,
    ``num_neuron``, ``I_max``, ``TrickID``, ``spike_num_list``,
    ``Simulation_time``.
    """

    def __init__(self, connection):
        super().__init__()
        self.node = []
        self.node.append(IzhNodeMU(threshold=param_threshold, a=param_a, b=param_b,
                                   c=param_c, d=param_d, mem=param_mem, u=param_u,
                                   dt=param_dt))
        self.connection = connection

    def forward(self, input):
        """Return the firing rate of each M1 neuron for ``input`` in [0, 1]."""
        input_n = input * I_max          # scale to an input current
        input_r = torch.round(input_n)   # integer current used as table index
        Spike = torch.zeros(num_neuron, dtype=torch.float)
        self.node[0].n_reset()
        if TrickID == 1:
            # Fast path: per-neuron table lookup.  Was ``range(num_AI)``;
            # M1 has ``num_neuron`` units (equal values) -- use num_neuron
            # for consistency with VNet.
            for i in range(num_neuron):
                input_ri = int(input_r[i].item())
                FR_i = spike_num_list[input_ri]
                Spike[i] = FR_i
            FR_n = Spike
        else:
            # Slow path: simulate the Izhikevich neurons and count spikes.
            for t in range(Simulation_time):
                self.out = self.node[0](input_r)
                n_Spike = self.node[0].spike
                Spike = Spike + n_Spike
            FR_n = Spike / Simulation_time
        return FR_n

    def reset(self):
        """Reset the state of all neuron groups."""
        for i in range(len(self.node)):
            self.node[i].n_reset()
class VNet(nn.Module):
    """Visual area (V) population of Izhikevich neurons (training version).

    Maps a normalized visual input to per-neuron firing rates, via the
    precomputed current->rate table when ``TrickID == 1`` or a step-by-step
    simulation otherwise.  Depends on globals defined in ``__main__``.
    """

    def __init__(self, connection):
        super().__init__()
        self.node = [
            IzhNodeMU(threshold=param_threshold, a=param_a, b=param_b, c=param_c,
                      d=param_d, mem=param_mem, u=param_u, dt=param_dt)
        ]
        self.connection = connection

    def forward(self, input):
        """Return the firing rates of the visual population."""
        currents = torch.round(input * I_max)
        Spike = torch.zeros(num_neuron, dtype=torch.float)
        self.node[0].n_reset()
        if TrickID == 1:
            # lookup-table path
            for idx in range(num_neuron):
                Spike[idx] = spike_num_list[int(currents[idx].item())]
            FR_n = Spike
        else:
            # simulation path
            for _ in range(Simulation_time):
                self.out = self.node[0](currents)
                Spike = Spike + self.node[0].spike
            FR_n = Spike / Simulation_time
        return FR_n

    def reset(self):
        for grp in self.node:
            grp.n_reset()
class S1Net(nn.Module):
    """S1 population with lateral-inhibition bookkeeping (training version).

    In addition to the rate computation, the forward pass thresholds the
    updated input into a binary spike vector and, when any unit fires,
    increments ``Fired`` and updates the lateral-inhibition weights
    ``W_LatInh``.  Depends on globals defined in ``__main__`` (``param_*``,
    ``num_neuron``, ``num_S1``, ``I_max``, ``TrickID``, ``spike_num_list``,
    ``Simulation_time``, ``fire_threshold``).
    """

    def __init__(self, connection):
        super().__init__()
        self.node = []
        self.node.append(IzhNodeMU(threshold=param_threshold, a=param_a, b=param_b,
                                   c=param_c, d=param_d, mem=param_mem, u=param_u,
                                   dt=param_dt))
        self.connection = connection

    def forward(self, input, FR, C, Fired, W_LatInh):
        """Return ``(rates, updated input, updated Fired, updated W_LatInh)``."""
        # weighted sum of afferent firing rates (one or several sources)
        FR_W = torch.zeros(num_neuron, dtype=torch.float)
        if len(FR.shape) == 1:
            FR_W = FR * self.connection[0].weight
        else:
            for i in range(FR.shape[0]):
                FR_Wi = FR[i] * self.connection[i].weight
                FR_W = FR_W + FR_Wi
        # squash and clip negatives
        sf = torch.tanh(FR_W)
        sf = torch.where(sf < 0, 0, sf)
        # leaky integration of the input toward the squashed drive
        input_n = -C * (input - sf) + input
        input_n = torch.where(input_n < 0, 0, input_n)
        input = input_n * I_max
        input_r = torch.round(input)
        Spike = torch.zeros(num_S1, dtype=torch.float)
        self.node[0].n_reset()
        if TrickID == 1:
            # lookup-table path
            for i in range(num_neuron):
                input_ri = int(input_r[i].item())
                FR_i = spike_num_list[input_ri]
                Spike[i] = FR_i
            FR_n = Spike
        else:
            # simulation path
            for t in range(Simulation_time):
                self.out = self.node[0](input_r)
                n_Spike = self.node[0].spike
                Spike = Spike + n_Spike
            FR_n = Spike / Simulation_time
        # binarize the updated input against the firing threshold
        S = input_n
        S = torch.where(input_n >= fire_threshold, 1, S)
        S = torch.where(input_n < fire_threshold, 0, S)
        if torch.sum(S) > 0:
            Fired = Fired + 1  # stray trailing semicolon removed
            # NOTE(review): torch.acos(S) maps a spike (1) to 0 and silence
            # (0) to pi/2; assumed intended -- confirm against the training
            # procedure.  ``Fired`` is presumably a tensor (torch.exp).
            W_LatInh = torch.tanh(W_LatInh - 2 * torch.acos(S) * torch.exp(Fired) - 1) + 1
        return FR_n, input_n, Fired, W_LatInh

    def reset(self):
        for i in range(len(self.node)):
            self.node[i].n_reset()
class EBANet(nn.Module):
    """EBA population with lateral-inhibition bookkeeping (training version).

    Same rate computation as the test-time networks, plus the
    (Fired, W_LatInh) update performed when the thresholded input contains
    at least one spike.  Depends on globals defined in ``__main__``.
    """

    def __init__(self, connection):
        super().__init__()
        self.node = [
            IzhNodeMU(threshold=param_threshold, a=param_a, b=param_b, c=param_c,
                      d=param_d, mem=param_mem, u=param_u, dt=param_dt)
        ]
        self.connection = connection

    def forward(self, input, FR, C, Fired, W_LatInh):
        """Return ``(rates, updated input, updated Fired, updated W_LatInh)``."""
        drive = torch.zeros(num_neuron, dtype=torch.float)
        if len(FR.shape) == 1:
            drive = FR * self.connection[0].weight
        else:
            for src in range(FR.shape[0]):
                drive = drive + FR[src] * self.connection[src].weight
        sf = torch.tanh(drive)
        sf = torch.where(sf < 0, 0, sf)
        # integrate the input toward the squashed drive, clip at zero
        input_n = -C * (input - sf) + input
        input_n = torch.where(input_n < 0, 0, input_n)
        currents = torch.round(input_n * I_max)
        Spike = torch.zeros(num_S1, dtype=torch.float)
        self.node[0].n_reset()
        if TrickID == 1:
            for idx in range(num_neuron):
                Spike[idx] = spike_num_list[int(currents[idx].item())]
            FR_n = Spike
        else:
            for _ in range(Simulation_time):
                self.out = self.node[0](currents)
                Spike = Spike + self.node[0].spike
            FR_n = Spike / Simulation_time
        # binary spike vector from the updated input
        S = torch.where(input_n >= fire_threshold, 1, input_n)
        S = torch.where(input_n < fire_threshold, 0, S)
        if torch.sum(S) > 0:
            Fired = Fired + 1
            # lateral inhibition update driven by the silent units
            W_LatInh = torch.tanh(W_LatInh - 2 * torch.acos(S) * torch.exp(Fired) - 1) + 1
        return FR_n, input_n, Fired, W_LatInh

    def reset(self):
        for grp in self.node:
            grp.n_reset()
class TPJNet(nn.Module):
    """TPJ population (training version): integrates S1 and EBA rates.

    Same rate computation as the other area networks; ``UpdateWeight``
    *replaces* (does not accumulate) a connection's weight tensor.
    Depends on globals defined in ``__main__``.
    """

    def __init__(self, connection):
        super().__init__()
        self.node = [
            IzhNodeMU(threshold=param_threshold, a=param_a, b=param_b, c=param_c,
                      d=param_d, mem=param_mem, u=param_u, dt=param_dt)
        ]
        self.connection = connection

    def forward(self, input, FR, C):
        """Return ``(firing rates, updated input)`` for the TPJ population."""
        agg = torch.zeros(num_neuron, dtype=torch.float)
        if len(FR.shape) == 1:
            agg = FR * self.connection[0].weight
        else:
            for src in range(FR.shape[0]):
                agg = agg + FR[src] * self.connection[src].weight
        sf = torch.tanh(agg)
        sf = torch.where(sf < 0, 0, sf)
        input_n = -C * (input - sf) + input
        input_n = torch.where(input_n < 0, 0, input_n)
        currents = torch.round(input_n * I_max)
        Spike = torch.zeros(num_S1, dtype=torch.float)
        self.node[0].n_reset()
        if TrickID == 1:
            for idx in range(num_neuron):
                Spike[idx] = spike_num_list[int(currents[idx].item())]
            FR_n = Spike
        else:
            for _ in range(Simulation_time):
                self.out = self.node[0](currents)
                Spike = Spike + self.node[0].spike
            FR_n = Spike / Simulation_time
        return FR_n, input_n

    def reset(self):
        for grp in self.node:
            grp.n_reset()

    def UpdateWeight(self, i, W):
        # replace (not accumulate) the i-th connection weight
        self.connection[i].weight.data = W
class AINet(nn.Module):
    """Anterior Insula (AI) population of Izhikevich neurons.

    Integrates firing rates from S1, TPJ and EBA; its afferent weights are the
    quantities learned in the training loop.  Relies on the same module-level
    globals as the other area networks.
    """
    def __init__(self, connection):
        super().__init__()
        # One shared Izhikevich neuron group for the whole population.
        self.node = []
        self.node.append(IzhNodeMU(threshold=param_threshold, a=param_a, b=param_b, c=param_c, d=param_d, mem=param_mem, u=param_u, dt=param_dt))
        # Afferent connections: 0 = S1->AI, 1 = TPJ->AI, 2 = EBA->AI.
        self.connection = connection

    def forward(self, input, FR, C):
        """One update step; returns (new firing rate, new input state)."""
        # Weighted sum of afferent firing rates (vector or stacked rows).
        FR_W = torch.zeros(num_neuron, dtype=torch.float)
        if len(FR.shape) == 1:
            FR_W = FR * self.connection[0].weight
        else:
            for i in range(FR.shape[0]):
                FR_Wi = FR[i] * self.connection[i].weight
                FR_W = FR_W + FR_Wi
        # Squash to [0, 1) and clamp negatives.
        sf = torch.tanh(FR_W)
        sf = torch.where(sf < 0, 0, sf)
        # Leaky integration toward the squashed drive.
        input_n = -C * (input - sf) + input
        input_n = torch.where(input_n < 0, 0, input_n)
        input = input_n * I_max
        input_r = torch.round(input)
        Spike = torch.zeros(num_S1, dtype=torch.float)
        self.node[0].n_reset()
        if TrickID == 1:
            # Fast path: precomputed current -> rate lookup table.
            for i in range(num_neuron):
                input_ri = int(input_r[i].item())
                FR_i = spike_num_list[input_ri]
                Spike[i] = FR_i
            FR_n = Spike
        else:
            # Slow path: simulate the Izhikevich dynamics.
            for t in range(Simulation_time):
                self.out = self.node[0](input_r)
                n_Spike = self.node[0].spike
                Spike = Spike + n_Spike
            FR_n = Spike / Simulation_time
        return FR_n, input_n

    def reset(self):
        # Clear membrane state of all neuron groups.
        for i in range(len(self.node)):
            self.node[i].n_reset()

    def UpdateWeight(self, i, W, WIn):
        # Accumulate the learned delta W, gated by the lateral-inhibition
        # weight WIn (unlike TPJNet.UpdateWeight, which overwrites).
        self.connection[i].weight.data = self.connection[i].weight.data + W * WIn
def DeltaWeight(Pre, Pre_n, Post, Post_n):
    """Composite plasticity rule combining a Hebbian term with pre/post
    rate-change terms.

    :param Pre: previous presynaptic firing rate
    :param Pre_n: current presynaptic firing rate
    :param Post: previous postsynaptic firing rate
    :param Post_n: current postsynaptic firing rate
    :return: weight increment dW
    """
    hebb_term = -0.0035 * (Pre_n * Post_n)          # anti-Hebbian co-activity
    post_term = 0.35 * (Pre_n * (Post_n - Post))    # postsynaptic rate change
    pre_term = -0.55 * ((Pre_n - Pre) * Post_n)     # presynaptic rate change
    return hebb_term + post_term + pre_term
if __name__=="__main__":
    """
    Set the number of neurons, and each neuron represents unique motion information (such as angle)
    """
    # number of neurons
    num_neuron = 9
    num_M1 = num_neuron
    num_S1 = num_neuron
    num_TPJ = num_neuron
    num_V = num_neuron
    num_EBA = num_neuron
    num_AI = num_neuron
    Init_Weight = 1.
    # Izhikevich neuron parameters shared by every brain area.
    param_threshold = 30.
    param_a = 0.02
    param_b = -0.1
    param_c = -55.
    param_d = 18.
    param_mem = -70.
    param_u = 0.
    param_dt = 1.
    # Length (steps) of one inner Izhikevich simulation and the current scale.
    Simulation_time = 1000
    I_max = 1000
    # When the TrickID is set to 1, it means that the mapping relationship from input current
    # to firing rate is obtained directly by loading the Izh.npy,
    # which can significantly reduce the program running time
    TrickID = 1
    if TrickID == 1:
        spike_num_list = np.load('Izh.npy')
        spike_num_list = spike_num_list / I_max
    ##############################
    # M1
    ##############################
    # M1_Input-M1
    M1_connection = []
    con_matrix0 = torch.ones(num_M1, dtype=torch.float) * Init_Weight
    M1_connection.append(CustomLinear(con_matrix0))
    M1 = M1Net(M1_connection)
    ##############################
    # V
    ##############################
    # V_Input-V
    V_connection = []
    con_matrix3 = torch.ones(num_V, dtype=torch.float) * Init_Weight
    V_connection.append(CustomLinear(con_matrix3))
    V = VNet(V_connection)
    ##############################
    # S1
    ##############################
    # M1-S1
    S1_connection = []
    con_matrix1 = torch.ones(num_S1, dtype=torch.float) * Init_Weight
    S1_connection.append(CustomLinear(con_matrix1))
    S1 = S1Net(S1_connection)
    ##############################
    # EBA
    ##############################
    # V-EBA
    EBA_connection = []
    con_matrix4 = torch.ones(num_EBA, dtype=torch.float) * Init_Weight
    EBA_connection.append(CustomLinear(con_matrix4))
    EBA = EBANet(EBA_connection)
    ##############################
    # TPJ
    ##############################
    # S1-TPJ, EBA-TPJ
    TPJ_connection = []
    # S1-TPJ
    con_matrix2 = torch.ones(num_TPJ, dtype=torch.float) * Init_Weight * 150
    TPJ_connection.append(CustomLinear(con_matrix2))
    # EBA-TPJ
    con_matrix5 = torch.ones(num_TPJ, dtype=torch.float) * Init_Weight * 150
    TPJ_connection.append(CustomLinear(con_matrix5))
    TPJ = TPJNet(TPJ_connection)
    ##############################
    # AI
    ##############################
    # S1-AI, TPJ-AI, EBA-AI
    AI_connection = []
    # S1-AI
    con_matrix6 = torch.ones(num_AI, dtype=torch.float) * Init_Weight
    AI_connection.append(CustomLinear(con_matrix6))
    # TPJ-AI
    con_matrix7 = torch.ones(num_AI, dtype=torch.float) * Init_Weight
    AI_connection.append(CustomLinear(con_matrix7))
    # EBA-AI
    con_matrix8 = torch.ones(num_AI, dtype=torch.float) * Init_Weight
    AI_connection.append(CustomLinear(con_matrix8))
    AI = AINet(AI_connection)
    ##############################
    # Coding
    ##############################
    # Gaussian population coding: row k of Coding is a bell curve centered on
    # neuron k, width S.
    S = 1
    ISI = 1
    JMax = int((num_neuron - 1) / 2)
    listJ = list(range(-JMax, JMax + 1))
    Coding = torch.zeros([num_neuron, num_neuron], dtype=torch.float)
    for i in range(len(listJ)):
        e = float(listJ[i])
        listY = []
        for j in range(len(listJ)):
            x = float(listJ[j])
            y = math.exp(-(x - e) ** 2 / (2 * S ** 2))
            Coding[i][j] = y
    ##############################
    # Train
    ##############################
    # MoveNum trials per motion pattern; each trial lasts Time steps with the
    # motor (M1) stimulus presented first, then the visual (V) stimulus.
    MoveNum = 25
    Time = 300
    CT = 100
    Motion_Start = 1
    Motion_End = Motion_Start + CT
    Vision_Start = Motion_End
    Vision_End = Vision_Start + CT
    # Per-area integration constants.
    CM1 = 0.04
    CV = 0.04
    CS1 = 0.04
    CEBA = 0.04
    CTPJ = 0.01
    CAI = 0.15
    for k in range(num_neuron):
        for i in range(MoveNum):
            print(i)
            # Per-trial state: area inputs, firing rates, weight-delta
            # accumulators and lateral-inhibition bookkeeping.
            M1_input = torch.zeros(num_M1, dtype=torch.float)
            V_input = torch.zeros(num_V, dtype=torch.float)
            S1_input = torch.zeros(num_S1, dtype=torch.float)
            TPJ_input = torch.zeros(num_TPJ, dtype=torch.float)
            EBA_input = torch.zeros(num_EBA, dtype=torch.float)
            AI_input = torch.zeros(num_AI, dtype=torch.float)
            FR_M1 = torch.zeros(num_M1, dtype=torch.float)
            FR_V = torch.zeros(num_V, dtype=torch.float)
            FR_S1 = torch.zeros(num_S1, dtype=torch.float)
            FR_EBA = torch.zeros(num_EBA, dtype=torch.float)
            FR_TPJ = torch.zeros(num_TPJ, dtype=torch.float)
            FR_AI = torch.zeros(num_AI, dtype=torch.float)
            dW_S1TPJ = torch.zeros(num_M1, dtype=torch.float)
            dW_EBATPJ = torch.zeros(num_M1, dtype=torch.float)
            dW_S1AI = torch.zeros(num_M1, dtype=torch.float)
            dW_EBAAI = torch.zeros(num_M1, dtype=torch.float)
            fire_threshold = 0.7
            W_LatInh_Init = torch.ones(num_neuron, dtype=torch.float) * Init_Weight
            W_LatInh_S1_AI = W_LatInh_Init
            W_LatInh_EBA_AI = W_LatInh_Init
            Fired_S1 = torch.zeros(num_S1, dtype=torch.float)
            Fired_EBA = torch.zeros(num_EBA, dtype=torch.float)
            with torch.no_grad():
                for t in range(1, Time + 1):
                    S_M1 = torch.zeros(num_M1, dtype=torch.float)
                    S_V = torch.zeros(num_V, dtype=torch.float)
                    # Motor stimulus ramps up during the first CT steps.
                    if t >= Motion_Start and t <= Motion_End:
                        S_M1 = Coding[k]
                        M1_input = (1 - (1 - CM1) ** t) * S_M1
                    else:
                        M1_input = S_M1
                    # Visual stimulus ramps up during the following CT steps.
                    if t >= Vision_Start and t <= Vision_End:
                        S_V = Coding[k]
                        V_input = (1 - (1 - CV) ** (t - CT)) * S_V
                    else:
                        V_input = S_V
                    # Propagate one step through the area hierarchy.
                    FR_M1_n = M1(M1_input)
                    FR_V_n = V(V_input)
                    [FR_S1_n, S1_input_n, Fired_S1, W_LatInh_S1_AI] = S1(S1_input, FR_M1_n, CS1, Fired_S1, W_LatInh_S1_AI)
                    [FR_EBA_n, EBA_input_n, Fired_EBA, W_LatInh_EBA_AI] = EBA(EBA_input, FR_V_n, CEBA, Fired_EBA, W_LatInh_EBA_AI)
                    FR_Input_TPJ_n = torch.stack((FR_S1_n, FR_EBA_n), 0)
                    [FR_TPJ_n, TPJ_input_n] = TPJ(TPJ_input, FR_Input_TPJ_n, CTPJ)
                    FR_Input_AI = torch.stack((FR_S1_n, FR_TPJ_n, FR_EBA_n), 0)
                    [FR_AI_n, AI_input_n] = AI(AI_input, FR_Input_AI, CAI)
                    # Update weights
                    # S1-AI
                    ddW_S1AI = DeltaWeight(FR_S1, FR_S1_n, FR_AI, FR_AI_n)
                    dW_S1AI = dW_S1AI + ddW_S1AI
                    # EBA-AI
                    ddW_EBAAI = DeltaWeight(FR_EBA, FR_EBA_n, FR_AI, FR_AI_n)
                    dW_EBAAI = dW_EBAAI + ddW_EBAAI
                    # Roll the state forward for the next step.
                    FR_M1 = FR_M1_n
                    FR_V = FR_V_n
                    FR_S1 = FR_S1_n
                    FR_EBA = FR_EBA_n
                    FR_TPJ = FR_TPJ_n
                    FR_AI = FR_AI_n
                    S1_input = S1_input_n
                    TPJ_input = TPJ_input_n
                    EBA_input = EBA_input_n
                    AI_input = AI_input_n
            # Apply the accumulated deltas, gated by lateral inhibition.
            AI.UpdateWeight(0, dW_S1AI, W_LatInh_S1_AI)
            AI.UpdateWeight(2, dW_EBAAI, W_LatInh_EBA_AI)
            print(AI.connection[0].weight.data)  # dW_S1AI
            print(AI.connection[2].weight.data)  # dW_EBAAI
            M1.reset()
            V.reset()
            S1.reset()
            EBA.reset()
            TPJ.reset()
            AI.reset()
    # Persist the learned S1->AI and EBA->AI weights.
    np.save('W_S1_AI.npy', AI.connection[0].weight.data)
    np.save('W_EBA_AI.npy', AI.connection[2].weight.data)
    print('Training End')
================================================
FILE: examples/Embodied_Cognition/RHI/ReadMe.md
================================================
================================================
FILE: examples/Hardware_acceleration/README.md
================================================
## FireFly: A High-Throughput Hardware Accelerator for Spiking Neural Networks
### Demo of Deploying SNNs on FPGA platform
This is an example of deploying an SNN model on Xilinx Zynq Ultrascale FPGA based on Braincog.
### Requirements
- Xilinx Zynq Ultrascale FPGA evaluation board Ultra96v2 or ZCU104.
- PYNQ images for the chosen evaluation boards. You can download the latest pre-compiled images from the [PYNQ website](http://www.pynq.io/board.html), or you can compile a new one following the [PYNQ Tutorial](https://pynq.readthedocs.io/en/latest/). Install the PYNQ image to the SD card, and boot the evaluation board in SD mode.
### Examples
Clone the project to fetch the necessary bitstream files and pre-processed SNN models, copy all the files to the Ultra96v2 or ZCU104 board.
```shell
git clone https://github.com/adamgallas/firefly_v1_cifar_test
```
Open a terminal in Ultra96v2 or ZCU104. Install einops on Ultra96v2 or ZCU104.
```shell
cd firefly_v1_common
pip install einops-0.6.0-py3-none-any.whl
```
Run CIFAR10 classification test on Ultra96v2:
```shell
python ultra96_test.py
```
Run CIFAR10 classification test on ZCU104:
```shell
python zcu104_test.py
```
### Citation
If you find this work helpful, please consider citing it:
```BibTex
@article{li2023firefly,
title={FireFly: A High-Throughput Hardware Accelerator for Spiking Neural Networks With Efficient DSP and Memory Optimization},
author={Li, Jindong and Shen, Guobin and Zhao, Dongcheng and Zhang, Qian and Zeng, Yi},
journal={IEEE Transactions on Very Large Scale Integration (VLSI) Systems},
year={2023},
publisher={IEEE}
}
```
================================================
FILE: examples/Hardware_acceleration/firefly_v1_schedule_on_pynq.py
================================================
import numpy as np
import tqdm
from standalone_utils import *
import math
import time
import ctypes as ct
class FireFlyV1ConvSchedule:
    """One conv (or conv-as-linear) layer schedule for the FireFly v1 FPGA accelerator.

    Packs the layer geometry, weights and biases into the accelerator's 32-bit
    configuration registers and drives them over a memory-mapped control
    interface (``ctrl_io``).  Register field layouts mirror the hardware
    design; the bit offsets below are taken directly from this code — confirm
    against the RTL before changing any of them.
    """
    def __init__(
            self,
            ctrl_io,
            allocate_method,
            input_buffer_addr,
            output_buffer_addr,
            weight_data,
            bias_data,
            parallel_channel=16,
            kernel_size=3,
            input_channels=64,
            output_channels=128,
            width=32,
            height=32,
            enable_pooling=False,
            direct_adapt=False,
            winner_takes_all=False,
            final_conv=False,
            time_step=8,
            threshold=64,
            max_cnt=2048
    ):
        self.ctrl_io = ctrl_io
        self.input_buffer_addr = np.uint32(input_buffer_addr)
        self.output_buffer_addr = np.uint32(output_buffer_addr)
        # DMA-visible buffers for weights (int8) and biases (int16);
        # flush() pushes the CPU-side writes out to physical memory.
        self.weight_buffer = allocate_method(shape=weight_data.size, dtype=np.int8)
        self.bias_buffer = allocate_method(shape=bias_data.size, dtype=np.int16)
        self.weight_buffer_addr = np.uint32(self.weight_buffer.device_address)
        self.bias_buffer_addr = np.uint32(self.bias_buffer.device_address)
        self.weight_buffer[:] = np.ascontiguousarray(weight_data.flatten())
        self.bias_buffer[:] = np.ascontiguousarray(bias_data.flatten())
        self.weight_buffer.flush()
        self.bias_buffer.flush()
        # Poll budget for run_all() before declaring a timeout.
        self.max_cnt = max_cnt
        self.parallel_channel = np.uint32(parallel_channel)
        self.kernel_size = np.uint32(kernel_size)
        self.input_channels = np.uint32(input_channels)
        self.output_channels = np.uint32(output_channels)
        self.width = np.uint32(width)
        self.height = np.uint32(height)
        self.enable_pooling = np.uint32(enable_pooling)
        self.direct_adapt = np.uint32(direct_adapt)
        self.winner_takes_all = np.uint32(winner_takes_all)
        self.time_step = np.uint32(time_step)
        self.threshold = np.int32(threshold)
        # 2x2 pooling halves each spatial dimension (shift by the flag).
        self.out_width = np.uint32(width >> enable_pooling)
        self.out_height = np.uint32(height >> enable_pooling)
        # Hardware counters are "count minus one".
        self.numOfIFMs = np.uint32(input_channels / parallel_channel - 1)
        self.numOfOFMs = np.uint32(output_channels / parallel_channel - 1)
        self.numOfTimeSteps = np.uint32(time_step - 1)
        self.numOfTimeStepIFMs = np.uint32((input_channels / parallel_channel) * time_step - 1)
        self.numOfTimeStepOFMs = np.uint32((output_channels / parallel_channel) * time_step - 1)
        self.weightsLength = np.uint32(input_channels - 1)
        if direct_adapt:
            # Linear layer mapped onto the conv engine: channels are padded
            # to a multiple of kernel_size^2 groups and the output is 1x1.
            factor = kernel_size * kernel_size
            padded_length = math.ceil(input_channels / parallel_channel / factor)
            self.numOfIFMs = np.uint32(padded_length - 1)
            self.numOfTimeStepIFMs = np.uint32(padded_length * time_step - 1)
            self.weightsLength = np.uint32(padded_length * parallel_channel - 1)
            self.out_width = np.uint32(1)
            self.out_height = np.uint32(1)
        # DMA transfer lengths in bytes (8 spikes per byte).
        self.mm2s_fix_len = np.uint32(self.width * self.height * self.time_step * self.input_channels / 8)
        self.s2mm_fix_len = np.uint32(self.out_width * self.out_height * self.parallel_channel / 8)
        self.bias_len = np.uint32(self.output_channels * 2)
        self.weight_len = np.uint32(self.output_channels * self.input_channels * self.kernel_size * self.kernel_size)
        # Byte strides inside the output buffer.
        self.stride_of_channel = np.uint32(self.out_width * self.out_height * self.parallel_channel / 8)
        self.stride_of_time_step = np.uint32(self.out_width * self.out_height * self.output_channels / 8)
        if direct_adapt:
            self.mm2s_fix_len = np.uint32(self.time_step * self.input_channels / 8)
            self.weight_len = np.uint32(self.output_channels * self.input_channels)
            self.stride_of_channel = np.uint32(2 * self.parallel_channel / 8)
            self.stride_of_time_step = np.uint32(2 * self.output_channels / 8)
        if final_conv:
            # Last conv before flatten: round the flattened channel count up
            # so the following linear layer's stride matches its padding.
            self.flatten_channel = np.uint32(self.out_width * self.out_height * self.output_channels)
            factor = kernel_size * kernel_size * 8 * 4
            round_channel = int(math.ceil(self.flatten_channel / factor) * factor)
            self.stride_of_time_step = np.uint32(round_channel / 8)
        # Pack the configuration registers (little-endian byte strings).
        self.configReg_0x00 = np.uint32(((self.time_step - 1) << 16) + (self.numOfOFMs << 8) + self.numOfIFMs).tobytes()
        self.configReg_0x04 = np.uint32((self.numOfTimeStepOFMs << 12) + self.numOfTimeStepIFMs).tobytes()
        self.configReg_0x08 = np.uint32((self.threshold << 14) + self.weightsLength).tobytes()
        self.configReg_0x0c = np.uint32((self.winner_takes_all << 30) + (self.direct_adapt << 29) + (
                self.enable_pooling << 28) + ((self.height - 1) << 16) + self.width - 1).tobytes()
        self.configReg_0x20 = np.uint32(self.stride_of_time_step).tobytes()
        self.configReg_0x24 = np.uint32(self.stride_of_channel).tobytes()
        self.configReg_0x28 = np.uint32(self.input_buffer_addr).tobytes()
        self.configReg_0x2c = np.uint32(self.output_buffer_addr).tobytes()
        self.configReg_0x30 = np.uint32(self.mm2s_fix_len).tobytes()
        self.configReg_0x34 = np.uint32(self.s2mm_fix_len).tobytes()
        self.configReg_0x38 = np.uint32(self.weight_buffer_addr).tobytes()
        self.configReg_0x3c = np.uint32(self.weight_len).tobytes()
        self.configReg_0x40 = np.uint32(self.bias_buffer_addr).tobytes()
        self.configReg_0x44 = np.uint32(self.bias_len).tobytes()
        # Command words: load parameters, then start the input/output streams.
        self.paramCmd = np.uint32(0x00010000).tobytes()
        self.inOutCmd = np.uint32(0x00000101).tobytes()

    def gen_cmd(self):
        """Return all config registers plus command words as a flat uint32 array."""
        cmd_list = []
        cmd_list.append(np.frombuffer(self.configReg_0x00, dtype=np.uint32))
        cmd_list.append(np.frombuffer(self.configReg_0x04, dtype=np.uint32))
        cmd_list.append(np.frombuffer(self.configReg_0x08, dtype=np.uint32))
        cmd_list.append(np.frombuffer(self.configReg_0x0c, dtype=np.uint32))
        cmd_list.append(np.frombuffer(self.configReg_0x20, dtype=np.uint32))
        cmd_list.append(np.frombuffer(self.configReg_0x24, dtype=np.uint32))
        cmd_list.append(np.frombuffer(self.configReg_0x28, dtype=np.uint32))
        cmd_list.append(np.frombuffer(self.configReg_0x2c, dtype=np.uint32))
        cmd_list.append(np.frombuffer(self.configReg_0x30, dtype=np.uint32))
        cmd_list.append(np.frombuffer(self.configReg_0x34, dtype=np.uint32))
        cmd_list.append(np.frombuffer(self.configReg_0x38, dtype=np.uint32))
        cmd_list.append(np.frombuffer(self.configReg_0x3c, dtype=np.uint32))
        cmd_list.append(np.frombuffer(self.configReg_0x40, dtype=np.uint32))
        cmd_list.append(np.frombuffer(self.configReg_0x44, dtype=np.uint32))
        cmd_list.append(np.frombuffer(self.paramCmd, dtype=np.uint32))
        cmd_list.append(np.frombuffer(self.inOutCmd, dtype=np.uint32))
        return np.array(cmd_list).flatten()

    def send_config(self):
        """Write every configuration register over the control interface."""
        self.ctrl_io.write(0x00, self.configReg_0x00)
        self.ctrl_io.write(0x04, self.configReg_0x04)
        self.ctrl_io.write(0x08, self.configReg_0x08)
        self.ctrl_io.write(0x0c, self.configReg_0x0c)
        self.ctrl_io.write(0x20, self.configReg_0x20)
        self.ctrl_io.write(0x24, self.configReg_0x24)
        self.ctrl_io.write(0x28, self.configReg_0x28)
        self.ctrl_io.write(0x2c, self.configReg_0x2c)
        self.ctrl_io.write(0x30, self.configReg_0x30)
        self.ctrl_io.write(0x34, self.configReg_0x34)
        self.ctrl_io.write(0x38, self.configReg_0x38)
        self.ctrl_io.write(0x3c, self.configReg_0x3c)
        self.ctrl_io.write(0x40, self.configReg_0x40)
        self.ctrl_io.write(0x44, self.configReg_0x44)

    def begin_schedule_non_blocking(self):
        """Kick off the layer without waiting for completion.

        NOTE(review): paramCmd is written twice — presumably required by the
        command FIFO (also done in run_all); confirm against the RTL.
        """
        self.ctrl_io.write(0x48, self.paramCmd)
        self.ctrl_io.write(0x48, self.paramCmd)
        self.ctrl_io.write(0x48, self.inOutCmd)

    def begin_schedule_blocking(self):
        """Start the layer and busy-wait on the done register (0x18)."""
        self.begin_schedule_non_blocking()
        while self.ctrl_io.read(0x18, length=8) == 0:
            continue
        self.clear_schedule()

    def clear_schedule(self):
        # Acknowledge completion by clearing the done register.
        self.ctrl_io.write(0x18, 0)

    def run_all(self):
        """Configure, start, and busy-wait (with a poll-count timeout) for one layer."""
        self.ctrl_io.write(0x00, self.configReg_0x00)
        self.ctrl_io.write(0x04, self.configReg_0x04)
        self.ctrl_io.write(0x08, self.configReg_0x08)
        self.ctrl_io.write(0x0c, self.configReg_0x0c)
        self.ctrl_io.write(0x20, self.configReg_0x20)
        self.ctrl_io.write(0x24, self.configReg_0x24)
        self.ctrl_io.write(0x28, self.configReg_0x28)
        self.ctrl_io.write(0x2c, self.configReg_0x2c)
        self.ctrl_io.write(0x30, self.configReg_0x30)
        self.ctrl_io.write(0x34, self.configReg_0x34)
        self.ctrl_io.write(0x38, self.configReg_0x38)
        self.ctrl_io.write(0x3c, self.configReg_0x3c)
        self.ctrl_io.write(0x40, self.configReg_0x40)
        self.ctrl_io.write(0x44, self.configReg_0x44)
        self.ctrl_io.write(0x48, self.paramCmd)
        self.ctrl_io.write(0x48, self.paramCmd)
        self.ctrl_io.write(0x48, self.inOutCmd)
        cnt = 0
        while self.ctrl_io.read(0x18, length=8) == 0:
            cnt = cnt + 1
            if cnt > self.max_cnt:
                print("timeout, abort!")
                break
        end = time.time()  # NOTE(review): unused; left for timing hooks
        self.ctrl_io.write(0x18, 0)

    def read_status(self):
        """Print the four status bytes of register 0x50 (busy/input/output/param)."""
        status = np.uint32(self.ctrl_io.read(0x50, length=4)).tobytes()
        busy_status = status[0]
        input_status = status[1]
        output_status = status[2]
        param_status = status[3]
        print("busy_status", busy_status)
        print("input_status", input_status)
        print("output_status", output_status)
        print("param_status", param_status)
def create_schedule(model_config_list: list,
                    ctrl_io,
                    allocate_method,
                    buffer_0,
                    buffer_1,
                    image_height,
                    image_width,
                    time_step=4,
                    parallel_channel=16
                    ):
    """Build a FireFlyV1ConvSchedule per layer in *model_config_list*.

    Layers ping-pong between buffer_0 and buffer_1 (input/output addresses are
    swapped after every layer), so the returned ``input_buffer_addr`` is the
    buffer holding the final layer's output.

    :return: (schedule_list, final output buffer address)
    """
    schedule_list = []
    curr_image_height = image_height
    curr_image_width = image_width
    input_buffer_addr = buffer_0.device_address
    output_buffer_addr = buffer_1.device_address
    for config in model_config_list:
        if config["layer_type"] == "conv+IFNode":
            schedule = FireFlyV1ConvSchedule(
                ctrl_io=ctrl_io,
                allocate_method=allocate_method,
                input_buffer_addr=input_buffer_addr,
                output_buffer_addr=output_buffer_addr,
                weight_data=conv_weight_channel_tiling(parallel_channel, config["weight"]),
                bias_data=config["bias"].flatten(),
                parallel_channel=parallel_channel,
                kernel_size=3,
                input_channels=config["input_channel"],
                output_channels=config["output_channel"],
                width=curr_image_width,
                height=curr_image_height,
                enable_pooling=False,
                time_step=time_step,
                threshold=config["threshold"],
                final_conv="flatten" in config
            )
            schedule_list.append(schedule)
            # Ping-pong the buffers for the next layer.
            input_buffer_addr, output_buffer_addr = output_buffer_addr, input_buffer_addr
        elif config["layer_type"] == "conv+IFNode+maxpool":
            schedule = FireFlyV1ConvSchedule(
                ctrl_io=ctrl_io,
                allocate_method=allocate_method,
                input_buffer_addr=input_buffer_addr,
                output_buffer_addr=output_buffer_addr,
                weight_data=conv_weight_channel_tiling(parallel_channel, config["weight"]),
                bias_data=config["bias"].flatten(),
                parallel_channel=parallel_channel,
                kernel_size=3,
                input_channels=config["input_channel"],
                output_channels=config["output_channel"],
                width=curr_image_width,
                height=curr_image_height,
                enable_pooling=True,
                time_step=time_step,
                threshold=config["threshold"],
                final_conv="flatten" in config
            )
            schedule_list.append(schedule)
            input_buffer_addr, output_buffer_addr = output_buffer_addr, input_buffer_addr
            # 2x2 max pooling halves the spatial resolution.
            curr_image_height = int(curr_image_height / 2)
            curr_image_width = int(curr_image_width / 2)
        elif config["layer_type"].__contains__("linear"):
            input_channel = config["input_channel"]
            weight = config["weight"]
            if "weight_reshape" in config:
                # First linear after flatten: reorder the weight columns to
                # match the hardware's channel-tiled spike layout and pad the
                # input channels up to a multiple of 9*8*4.
                factor = 9 * 8 * 4
                round_channel = int(math.ceil(input_channel / factor) * factor)
                weight = rearrange(weight, "o (i p h w) -> o (i h w p)", p=parallel_channel,
                                   h=curr_image_height, w=curr_image_width)
                weight = np.pad(weight, ((0, 0), (0, round_channel - input_channel)), mode="constant")
                input_channel = round_channel
            weight = linear_weight_channel_tiling(parallel_channel, weight)
            schedule = FireFlyV1ConvSchedule(
                ctrl_io=ctrl_io,
                allocate_method=allocate_method,
                input_buffer_addr=input_buffer_addr,
                output_buffer_addr=output_buffer_addr,
                weight_data=weight,
                bias_data=config["bias"].flatten(),
                parallel_channel=parallel_channel,
                kernel_size=3,
                input_channels=input_channel,
                output_channels=config["output_channel"],
                width=3,
                height=3,
                enable_pooling=False,
                time_step=time_step,
                threshold=config["threshold"],
                direct_adapt=config["direct_adapt"],
                winner_takes_all=config["winner_take_all"]
            )
            schedule_list.append(schedule)
            input_buffer_addr, output_buffer_addr = output_buffer_addr, input_buffer_addr
    return schedule_list, input_buffer_addr
def schedule_run_all(schedule_list):
    """Run every layer schedule in order, blocking on each one."""
    for layer_schedule in schedule_list:
        layer_schedule.run_all()
def gen_cmd_array(schedule_list):
    """Stack each schedule's flattened command words into one 2-D array."""
    return np.array([sched.gen_cmd().flatten() for sched in schedule_list])
def init_firefly_c_lib(path, schedule_list):
    """Load the native scheduling library and marshal the command words for it.

    :param path: path to the shared library exposing ``firefly_v1_schedule``
    :param schedule_list: list of FireFlyV1ConvSchedule objects
    :return: (C function, uint32** pointing at one command row per layer,
              c_uint8 layer count)
    """
    cmd_arr = gen_cmd_array(schedule_list)
    lib = ct.CDLL(path)
    sche = lib.firefly_v1_schedule
    u32Ptr = ct.POINTER(ct.c_uint32)
    u32PtrPtr = ct.POINTER(u32Ptr)
    # View the numpy command matrix as C storage, then build an array of
    # per-row pointers so the C side receives a uint32_t**.
    ct_arr = np.ctypeslib.as_ctypes(cmd_arr)
    u32PtrArr = u32Ptr * ct_arr._length_
    ct_ptr = ct.cast(u32PtrArr(*(ct.cast(row, u32Ptr) for row in ct_arr)), u32PtrPtr)
    sche_len = ct.c_uint8(cmd_arr.shape[0])
    return sche, ct_ptr, sche_len
def init_firefly_c_lib_with_time(path, schedule_list):
    """Same as init_firefly_c_lib but binds the timing variant of the C entry point.

    :return: (C function ``firefly_v1_schedule_time_it``, uint32** command
              rows, c_uint8 layer count)
    """
    cmd_arr = gen_cmd_array(schedule_list)
    lib = ct.CDLL(path)
    sche = lib.firefly_v1_schedule_time_it
    u32Ptr = ct.POINTER(ct.c_uint32)
    u32PtrPtr = ct.POINTER(u32Ptr)
    # Build a uint32_t** over the rows of the command matrix (see
    # init_firefly_c_lib for details).
    ct_arr = np.ctypeslib.as_ctypes(cmd_arr)
    u32PtrArr = u32Ptr * ct_arr._length_
    ct_ptr = ct.cast(u32PtrArr(*(ct.cast(row, u32Ptr) for row in ct_arr)), u32PtrPtr)
    sche_len = ct.c_uint8(cmd_arr.shape[0])
    return sche, ct_ptr, sche_len
def firefly_v1_simulate(model_config_list, x):
    """Pure-numpy reference simulation of the FireFly layer pipeline.

    Feeds *x* through each configured layer in order; unknown layer types are
    skipped.  Returns the final layer's output (a tuple for ``linear+WTA``).
    """
    for cfg in model_config_list:
        kind = cfg["layer_type"]
        if kind == "input_quant_stub":
            x = np_quantize_prepare(x, cfg["scale"], cfg["zero_point"])
        elif kind == "encoder+conv+IFNode":
            x = direct_coding(x, cfg["weight"], cfg["bias"], cfg["time_step"], cfg["threshold"])
        elif kind == "conv+IFNode":
            x = conv_ifnode_forward(x, cfg["weight"], cfg["bias"], cfg["threshold"])
        elif kind == "conv+IFNode+maxpool":
            x = conv_ifnode_maxpool_forward(x, cfg["weight"], cfg["bias"], cfg["threshold"])
        elif kind == "linear+WTA":
            x = linear_wta_forward(x, cfg["weight"], cfg["bias"])
        elif kind == "linear+IFNode":
            x = linear_ifnode_forward(x, cfg["weight"], cfg["bias"], cfg["threshold"])
    return x
def evaluate_simulate(model_config_list, sample):
    """Accuracy of the numpy simulation over a (images, targets) sample pair.

    Expects the final layer to be ``linear+WTA`` so firefly_v1_simulate
    returns a (logits, winner_index) tuple.
    """
    images, targets = sample[0], sample[1]
    hits = 0
    for image, target in tqdm.tqdm(zip(images, targets), total=len(images)):
        batch = np.expand_dims(image.numpy(), axis=0)
        _, prediction = firefly_v1_simulate(model_config_list, batch)
        hits += prediction == target.item()
    return hits / len(images)
================================================
FILE: examples/Hardware_acceleration/standalone_utils.py
================================================
import math
import numpy as np
from einops import rearrange
def get_im2col_indices(x_shape, field_height, field_width, padding=1, stride=1):
    """Compute the (channel, row, col) gather indices for an im2col transform.

    :param x_shape: input shape (N, C, H, W)
    :param field_height: kernel height
    :param field_width: kernel width
    :param padding: symmetric zero padding applied to H and W
    :param stride: convolution stride
    :return: (k, i, j) index arrays; k has shape (C*fh*fw, 1), i and j have
             shape (C*fh*fw, out_h*out_w)
    """
    N, C, H, W = x_shape
    assert (H + 2 * padding - field_height) % stride == 0
    # BUG FIX: the width check previously used field_height, which wrongly
    # rejected valid non-square kernels when stride > 1.
    assert (W + 2 * padding - field_width) % stride == 0
    out_height = int((H + 2 * padding - field_height) / stride + 1)
    out_width = int((W + 2 * padding - field_width) / stride + 1)
    # Row indices: kernel-local offsets (i0) plus window origins (i1).
    i0 = np.repeat(np.arange(field_height), field_width)
    i0 = np.tile(i0, C)
    i1 = stride * np.repeat(np.arange(out_height), out_width)
    # Column indices, same decomposition.
    j0 = np.tile(np.arange(field_width), field_height * C)
    j1 = stride * np.tile(np.arange(out_width), out_height)
    i = i0.reshape(-1, 1) + i1.reshape(1, -1)
    j = j0.reshape(-1, 1) + j1.reshape(1, -1)
    # Channel index for every kernel element.
    k = np.repeat(np.arange(C), field_height * field_width).reshape(-1, 1)
    return k, i, j
def im2col_indices(x, field_height, field_width, padding=1, stride=1):
    """Gather the sliding windows of *x* (N, C, H, W) into im2col columns.

    :return: array of shape (C*field_height*field_width, N*out_h*out_w)
    """
    pad = padding
    padded = np.pad(x, ((0, 0), (0, 0), (pad, pad), (pad, pad)), mode='constant')
    k, i, j = get_im2col_indices(x.shape, field_height, field_width, padding, stride)
    patches = padded[:, k, i, j]
    channels = x.shape[1]
    return patches.transpose(1, 2, 0).reshape(field_height * field_width * channels, -1)
def max_pool_forward_reshape(x, pool_param):
    """Max pooling via a reshape trick; requires square windows with stride == size
    and spatial dims divisible by the window size."""
    N, C, H, W = x.shape
    ph, pw = pool_param['pool_height'], pool_param['pool_width']
    assert ph == pw == pool_param['stride'], 'Invalid pool params'
    assert H % ph == 0
    assert W % ph == 0
    windows = x.reshape(N, C, int(H / ph), ph, int(W / pw), pw)
    return windows.max(axis=3).max(axis=4)
def max_pool_forward_fast(x, pool_param):
    """Dispatch to the fast reshape pool when the geometry allows, else im2col."""
    _, _, H, W = x.shape
    ph, pw = pool_param['pool_height'], pool_param['pool_width']
    reshape_ok = (ph == pw == pool_param['stride']) and H % ph == 0 and W % pw == 0
    if reshape_ok:
        return max_pool_forward_reshape(x, pool_param)
    return max_pool_forward_im2col(x, pool_param)
def max_pool_forward_im2col(x, pool_param):
    """General max pooling via im2col (handles any window/stride combination)."""
    N, C, H, W = x.shape
    ph, pw = pool_param['pool_height'], pool_param['pool_width']
    stride = pool_param['stride']
    assert (H - ph) % stride == 0, 'Invalid height'
    assert (W - pw) % stride == 0, 'Invalid width'
    out_h = int((H - ph) / stride + 1)
    out_w = int((W - pw) / stride + 1)
    # Treat every (n, c) plane as its own single-channel image.
    cols = im2col_indices(x.reshape(N * C, 1, H, W), ph, pw, padding=0, stride=stride)
    winners = np.argmax(cols, axis=0)
    pooled = cols[winners, np.arange(cols.shape[1])]
    return pooled.reshape(out_h, out_w, N, C).transpose(2, 3, 0, 1)
def conv_forward_fast(x, w, b, pad=1, stride=1):
    """Integer convolution via im2col + matrix multiply.

    :param x: input (N, C, H, W); kept in its incoming integer dtype
    :param w: filters (F, C, fh, fw), cast to int32
    :param b: per-filter bias, cast to int16 (matches the hardware bias width)
    :return: output (N, F, out_h, out_w)

    Fix: the original pre-allocated a zeroed output array only to read its
    shape and immediately discard it; the shape is now used directly.
    """
    N, C, H, W = x.shape
    w = w.astype(np.int32)
    b = b.astype(np.int16)
    num_filters, _, filter_height, filter_width = w.shape
    out_height = int((H + 2 * pad - filter_height) / stride + 1)
    out_width = int((W + 2 * pad - filter_width) / stride + 1)
    x_cols = im2col_indices(x, filter_height, filter_width, pad, stride)
    res = w.reshape((num_filters, -1)).dot(x_cols) + b.reshape(-1, 1)
    out = res.reshape(num_filters, out_height, out_width, N)
    return out.transpose(3, 0, 1, 2)
def spike_map_pack_to_bytes_array(spike_map, parallel):
    """Pack a binary spike map (T, C, H, W) into a byte stream for DMA.

    Channels are grouped *parallel* at a time and packed LSB-first, matching
    the accelerator's channel-tiled memory layout.
    """
    buf_in = rearrange(spike_map, 't (c p) h w->t c h w p', p=parallel)
    buf_in = np.packbits(buf_in.flatten(), bitorder='little')
    return buf_in
def bytes_array_split_to_spike_map(buf_in, time_step, parallel, H, W):
    """Inverse of spike_map_pack_to_bytes_array: bytes -> (T, C, H, W) spikes."""
    unpacked = np.unpackbits(buf_in, bitorder='little')
    unpacked = rearrange(unpacked, '(t c h w p)->t (c p) h w', t=time_step, p=parallel, h=H, w=W)
    return unpacked
def preprocess(model_config_list, x, parallel):
    """Quantize an input image, run the direct-coding first layer on the host,
    and pack the resulting spikes for DMA upload.

    :return: (spike map from the encoder layer, packed byte stream)
    """
    quant_cfg = model_config_list[0]
    layer_cfg = model_config_list[1]
    encoded = np_quantize_prepare(x, quant_cfg["scale"], quant_cfg["zero_point"])
    encoded = np.expand_dims(encoded, axis=0)
    firefly_in = direct_coding(encoded, layer_cfg["weight"], layer_cfg["bias"],
                               layer_cfg["time_step"], layer_cfg["threshold"])
    return firefly_in, spike_map_pack_to_bytes_array(firefly_in, parallel)
def integrate_and_fire(y, threshold):
    """Leakless integrate-and-fire over the leading (time) axis of *y*.

    The membrane resets to zero on spike (not threshold-subtraction), matching
    the FireFly hardware neuron.  Returns a bool spike array of y's shape.
    """
    out = np.zeros(y.shape, dtype=bool)
    membrane = np.zeros(y.shape[1:], dtype=np.int32)
    for step, frame in enumerate(y):
        membrane = membrane + frame
        fired = membrane > threshold
        out[step] = fired
        membrane[fired] = 0
    return out
def direct_coding(x, w, b, time_step, threshold):
    """Direct coding: repeat the frame over *time_step* steps, then conv + IF."""
    repeated = x.repeat(time_step, axis=0)
    return conv_ifnode_forward(repeated, w, b, threshold)
def conv_ifnode_forward(x, w, b, threshold):
    """Convolution followed by integrate-and-fire spiking."""
    return integrate_and_fire(conv_forward_fast(x, w, b), threshold)
def conv_ifnode_maxpool_forward(x, w, b, threshold):
    """Convolution + integrate-and-fire, then 2x2 stride-2 max pooling."""
    spikes = conv_ifnode_forward(x, w, b, threshold)
    pool_cfg = {'pool_height': 2, 'pool_width': 2, 'stride': 2}
    return max_pool_forward_fast(spikes, pool_cfg)
def linear_wta_forward(x, w, b):
    """Linear layer followed by winner-takes-all over time-summed outputs.

    :param x: spikes (T, ...) — flattened per time step, zero-padded up to w's
              input width
    :return: (per-time-step outputs, index of the winning output neuron)
    """
    flat = x.astype(np.int32).reshape([x.shape[0], -1])
    flat = np.pad(flat, ((0, 0), (0, w.shape[1] - flat.shape[1])), 'constant')
    out = flat @ w.astype(np.int32).T + b.astype(np.int32)
    winner = out.sum(axis=0).argmax()
    return out, winner
def linear_ifnode_forward(x, w, b, threshold):
    """Linear layer followed by integrate-and-fire spiking.

    Input spikes (T, ...) are flattened per time step and zero-padded up to
    w's input width before the matrix multiply.
    """
    flat = x.astype(np.int32).reshape([x.shape[0], -1])
    flat = np.pad(flat, ((0, 0), (0, w.shape[1] - flat.shape[1])), 'constant')
    pre_activation = flat @ w.astype(np.int32).T + b.astype(np.int32)
    return integrate_and_fire(pre_activation, threshold)
def pad_conv_weight_round_to_parallel(parallel, weight, pad_output_channel_only=False):
    """Zero-pad conv weight (O, I, kh, kw) channels up to multiples of *parallel*."""
    pad_out = -weight.shape[0] % parallel
    pad_in = 0 if pad_output_channel_only else -weight.shape[1] % parallel
    return np.pad(weight, ((0, pad_out), (0, pad_in), (0, 0), (0, 0)), 'constant')
def pad_linear_weight_round_to_parallel(parallel, weight):
    """Zero-pad linear weight (O, I) output rows up to a multiple of *parallel*."""
    pad_out = -weight.shape[0] % parallel
    return np.pad(weight, ((0, pad_out), (0, 0)), 'constant')
def pad_linear_weight_round_to_factor(weight, factor):
    """Zero-pad linear weight (O, I) input columns up to a multiple of *factor*."""
    pad_in = -weight.shape[1] % factor
    return np.pad(weight, ((0, 0), (0, pad_in)), 'constant')
def pad_bias_round_to_parallel(parallel, bias, pad_value=0):
    """Pad a 1-D bias with *pad_value* up to a multiple of *parallel*."""
    pad_len = -bias.shape[0] % parallel
    return np.pad(bias, (0, pad_len), 'constant', constant_values=pad_value)
def np_quantize_per_tensor(x, scale, zero_point):
    """Affine-quantize a float array to int8: round(x/scale + zero_point), clipped."""
    info = np.iinfo(np.int8)
    quantized = np.round(x / scale + zero_point)
    return np.clip(quantized, info.min, info.max).astype(np.int8)
def np_quantize_prepare(x, scale, zero_point):
    """Quantize to int8, then shift so the zero point maps to 0."""
    return np_quantize_per_tensor(x, scale, zero_point) - zero_point
def conv_weight_channel_tiling(parallel, weight):
    """Retile conv weight (O, I, kh, kw) into (..., ip, op) blocks of size
    *parallel* x *parallel*, the accelerator's channel-tiled layout."""
    return rearrange(weight, '(o op) (i ip) kr kc -> (o i kr kc) ip op', op=parallel, ip=parallel)
def linear_weight_channel_tiling(parallel, weight):
    """Retile linear weight (O, I) into (..., ip, op) blocks of size
    *parallel* x *parallel* (linear analogue of conv_weight_channel_tiling)."""
    return rearrange(weight, '(o op) (i ip) -> (o i) ip op', op=parallel, ip=parallel)
def conv_to_linear_weight_tiling(parallel, h, w, weight):
    """Retile a flattened-conv linear weight into the accelerator layout.

    Fix: the rearranged result was computed but never returned, so the
    function always yielded None.
    """
    return rearrange(weight, '(o op) (i ip h w)-> (o i h w) ip op', op=parallel, ip=parallel, h=h, w=w)
def init_input_buffer(input_spikes,
                      parallel=16,
                      stride_of_channel=8 * 1024,
                      stride_of_time_step=512 * 1024):
    """Pack a (T, C, H, W) binary spike tensor into the accelerator's strided input layout.

    ``parallel`` channels are interleaved per spatial position and bit-packed
    little-endian; each channel group starts ``stride_of_channel`` bytes apart
    and each time step ``stride_of_time_step`` bytes apart. C must be a
    multiple of ``parallel`` and ``h*w*parallel`` a multiple of 8.

    :param input_spikes: 0/1 spike array of shape (t, c, h, w)
    :return: flat uint8 buffer of ``stride_of_time_step * t`` bytes
    """
    t, c, h, w = input_spikes.shape
    # Group channels into blocks of `parallel` and interleave them per pixel.
    input_spikes_rearrange = rearrange(input_spikes, 't (c p) h w -> t c (h w p)', p=parallel)
    # 8 spikes per byte, LSB first.
    pack_spikes = np.packbits(input_spikes_rearrange, axis=-1, bitorder='little')
    input_buffer = np.zeros(stride_of_time_step * t, dtype=np.uint8)
    length = int(h * w * parallel / 8)  # packed bytes per (time step, channel group)
    for i in range(t):
        for j in range(int(c / parallel)):
            # Each slice lands at its fixed strided address; the gap between
            # slices is left zero-filled.
            addr = i * stride_of_time_step + j * stride_of_channel
            input_buffer[addr:addr + length] = pack_spikes[i, j]
    return input_buffer
def get_from_output_buffer(output_buffer,
                           t, c, h, w,
                           parallel=16,
                           stride_of_channel=8 * 1024,
                           stride_of_time_step=512 * 1024):
    """Inverse of ``init_input_buffer``: unpack the accelerator's strided output into (t, c, h, w) bools.

    Reads one bit-packed slice per (time step, channel group) from the same
    strided addresses that ``init_input_buffer`` writes, unpacks the bits
    (little-endian) and undoes the per-pixel channel interleaving.

    :param output_buffer: flat uint8 buffer produced by the accelerator
    :param t: number of time steps
    :param c: number of channels (multiple of ``parallel``)
    :param h: feature-map height
    :param w: feature-map width
    :return: boolean spike array of shape (t, c, h, w)
    """
    ret = []
    length = int(h * w * parallel / 8)  # packed bytes per (time step, channel group)
    for i in range(t):
        for j in range(int(c / parallel)):
            addr = i * stride_of_time_step + j * stride_of_channel
            ret.append(output_buffer[addr:addr + length])
    ret = np.array(ret)
    ret = np.unpackbits(ret, axis=-1, bitorder='little')
    # Undo the 't (c p) h w -> t c (h w p)' interleave used when packing.
    ret = rearrange(ret, '(t c) (h w p) -> t (c p) h w', p=parallel, t=t, h=h, w=w)
    return ret.astype(bool)
def get_output_index(buffer, parallel):
    """Decode the predicted class index from the first result beat of the output buffer.

    The class bits live in a fixed byte window of the first ``parallel`` bytes
    (bytes 12:14 for 16-channel builds, 24:28 for 32-channel builds); the
    index of the highest set bit position found first (little-endian argmax)
    is the prediction.

    :param buffer: uint8 output buffer
    :param parallel: hardware channel parallelism (16 or 32; anything else returns 0)
    :return: predicted class index
    """
    head = buffer[:parallel]
    if parallel == 16:
        bits = np.unpackbits(head[12:14], bitorder='little')
    elif parallel == 32:
        bits = np.unpackbits(head[24:28], bitorder='little')
    else:
        return 0
    return bits.argmax()
def save_model_config_list(model_config_list, path):
    """Persist the model configuration list to *path* as a (pickled) .npy file."""
    np.save(path, model_config_list)
def load_model_config_list(path):
    """Load a model configuration list written by ``save_model_config_list``.

    ``allow_pickle=True`` is required because the list holds Python objects;
    only load files from a trusted source.
    """
    return np.load(path, allow_pickle=True)
================================================
FILE: examples/Hardware_acceleration/ultra96_test.py
================================================
from standalone_utils import *
from firefly_v1_schedule_on_pynq import *
from pynq import PL
from pynq import Overlay
from pynq import allocate
from pynq import MMIO
import numpy as np
import time
from einops import rearrange
# ---- board setup -------------------------------------------------------
# Program the PL with the FireFly-v1 bitstream built for the Ultra96
# (16-channel parallelism) and load the exported CIFAR-10 test vectors.
ol = Overlay('firefly_v1_ultra96_bitstream/sys_wrapper.bit')
image = np.load("firefly_v1_cifar10_data/image.npy")
target = np.load("firefly_v1_cifar10_data/target.npy")
model_config_list = load_model_config_list("firefly_v1_cifar10_data/snn7_cifar10_x16.npy")
# Two 8 MiB physically-contiguous buffers shared with the accelerator DMA.
input_buffer = allocate(shape=(1<<23), dtype=np.uint8)
output_buffer = allocate(shape=(1<<23), dtype=np.uint8)
# Memory-mapped accelerator control registers.
ctrl_io = MMIO(0x0400000000, 0x400)
# Build the per-layer command schedule for the network over 4 time steps.
schedule_list,output_addr=create_schedule(
    model_config_list=model_config_list,
    ctrl_io=ctrl_io,
    allocate_method=allocate,
    buffer_0=input_buffer,
    buffer_1=output_buffer,
    image_height=32,
    image_width=32,
    time_step=4,
    parallel_channel=16
)
# Also compile the same schedule into the C runner library for comparison.
sche, ct_ptr, sche_len = init_firefly_c_lib_with_time("firefly_v1_common/firefly_v1_lib.so", schedule_list)
print(" ----------------- initialize finish!")
# ---- pass 1: schedule driven from Python --------------------------------
print(" ----------------- python schedule begin")
err_cnt = 0
for i in range(len(image)):
    image_0 = image[i]
    target_0 = target[i]
    print(i)
    start = time.time()
    # Encode the image into the accelerator's packed spike format.
    firefly_in, input_packed = preprocess(model_config_list, image_0, 16)
    input_buffer[:input_packed.size]=input_packed
    input_buffer.flush()  # push CPU cache to DRAM so the DMA sees the data
    end = time.time()
    elapsed = round((end - start) * 1000000)
    # print("preprocess:", elapsed, "us")
    # NOTE(review): the copy + flush below repeats the two lines above
    # verbatim — looks redundant; confirm whether the double write is
    # intentional (e.g. a warm-up) before removing it.
    input_buffer[:input_packed.size]=input_packed
    input_buffer.flush()
    start = time.time()
    schedule_run_all(schedule_list)  # execute every layer command in order
    end = time.time()
    elapsed = round((end - start) * 1000000)
    print("snn inference:", elapsed, "us")
    output_buffer.invalidate()  # drop stale CPU cache before reading results
    test_out = output_buffer[:16]
    # Bytes 12:14 of the first beat hold the class bits (same decode as
    # get_output_index with parallel=16); argmax -> predicted label.
    test_index = np.unpackbits(test_out[12:14],bitorder='little').argmax()
    print("test result:", test_index,"gold result", target_0)
    if test_index != target_0:
        err_cnt = err_cnt + 1
print("accuray: " , 1 - err_cnt/len(image))  # (sic: "accuray" typo kept — runtime string)
print(" ----------------- python schedule finish")
# ---- pass 2: identical inference driven by the compiled C scheduler -----
print(" ----------------- c schedule begin")
err_cnt = 0
for i in range(len(image)):
    image_0 = image[i]
    target_0 = target[i]
    print(i)
    start = time.time()
    firefly_in, input_packed = preprocess(model_config_list, image_0, 16)
    input_buffer[:input_packed.size]=input_packed
    input_buffer.flush()
    end = time.time()
    elapsed = round((end - start) * 1000000)
    # print("preprocess:", elapsed, "us")
    input_buffer[:input_packed.size]=input_packed
    input_buffer.flush()
    start = time.time()
    sche(ct_ptr, sche_len)  # single call into the C schedule runner
    end = time.time()
    elapsed = round((end - start) * 1000000)
    # print("clib inference call:", elapsed, "us")
    output_buffer.invalidate()
    test_out = output_buffer[:16]
    test_index = np.unpackbits(test_out[12:14],bitorder='little').argmax()
    print("test result:", test_index,"gold result", target_0)
    if test_index != target_0:
        err_cnt = err_cnt + 1
print("accuray: " , 1 - err_cnt/len(image))
print(" ----------------- c schedule finish")
================================================
FILE: examples/Hardware_acceleration/zcu104_test.py
================================================
from standalone_utils import *
from firefly_v1_schedule_on_pynq import *
from pynq import PL
from pynq import Overlay
from pynq import allocate
from pynq import MMIO
import numpy as np
import time
from einops import rearrange
# Same flow as ultra96_test.py, but for the ZCU104 build with 32-channel
# parallelism (x32 model export, 32-byte result beat, bits at bytes 24:28).
ol = Overlay('firefly_v1_zcu104_bitstream/sys_wrapper.bit')
image = np.load("firefly_v1_cifar10_data/image.npy")
target = np.load("firefly_v1_cifar10_data/target.npy")
model_config_list = load_model_config_list("firefly_v1_cifar10_data/snn7_cifar10_x32.npy")
# 8 MiB DMA-visible input/output buffers.
input_buffer = allocate(shape=(1<<23), dtype=np.uint8)
output_buffer = allocate(shape=(1<<23), dtype=np.uint8)
# Memory-mapped accelerator control registers.
ctrl_io = MMIO(0x0400000000, 0x400)
# Build the per-layer command schedule (4 time steps, 32 parallel channels).
schedule_list,output_addr=create_schedule(
    model_config_list=model_config_list,
    ctrl_io=ctrl_io,
    allocate_method=allocate,
    buffer_0=input_buffer,
    buffer_1=output_buffer,
    image_height=32,
    image_width=32,
    time_step=4,
    parallel_channel=32
)
# Compiled C runner for the same schedule, used in the second pass.
sche, ct_ptr, sche_len = init_firefly_c_lib_with_time("firefly_v1_common/firefly_v1_lib.so", schedule_list)
print(" ----------------- initialize finish!")
# ---- pass 1: schedule driven from Python --------------------------------
print(" ----------------- python schedule begin")
err_cnt = 0
for i in range(len(image)):
    image_0 = image[i]
    target_0 = target[i]
    print(i)
    start = time.time()
    firefly_in, input_packed = preprocess(model_config_list, image_0, 32)
    input_buffer[:input_packed.size]=input_packed
    input_buffer.flush()  # push CPU cache to DRAM for the DMA
    end = time.time()
    elapsed = round((end - start) * 1000000)
    # print("preprocess:", elapsed, "us")
    # NOTE(review): duplicate of the copy + flush above — confirm whether
    # the double write is intentional before removing it.
    input_buffer[:input_packed.size]=input_packed
    input_buffer.flush()
    start = time.time()
    schedule_run_all(schedule_list)
    end = time.time()
    elapsed = round((end - start) * 1000000)
    print("snn inference:", elapsed, "us")
    output_buffer.invalidate()  # drop stale CPU cache before reading
    test_out = output_buffer[:32]
    # Class bits live in bytes 24:28 for the 32-channel build.
    test_index = np.unpackbits(test_out[24:28],bitorder='little').argmax()
    print("test result:", test_index,"gold result", target_0)
    if test_index != target_0:
        err_cnt = err_cnt + 1
print("accuray: " , 1 - err_cnt/len(image))  # (sic: "accuray" typo kept — runtime string)
print(" ----------------- python schedule finish")
# ---- pass 2: identical inference via the compiled C scheduler -----------
print(" ----------------- c schedule begin")
err_cnt = 0
for i in range(len(image)):
    image_0 = image[i]
    target_0 = target[i]
    print(i)
    start = time.time()
    firefly_in, input_packed = preprocess(model_config_list, image_0, 32)
    input_buffer[:input_packed.size]=input_packed
    input_buffer.flush()
    end = time.time()
    elapsed = round((end - start) * 1000000)
    # print("preprocess:", elapsed, "us")
    input_buffer[:input_packed.size]=input_packed
    input_buffer.flush()
    start = time.time()
    sche(ct_ptr, sche_len)
    end = time.time()
    elapsed = round((end - start) * 1000000)
    # print("clib inference call:", elapsed, "us")
    output_buffer.invalidate()
    test_out = output_buffer[:32]
    test_index = np.unpackbits(test_out[24:28],bitorder='little').argmax()
    print("test result:", test_index,"gold result", target_0)
    if test_index != target_0:
        err_cnt = err_cnt + 1
print("accuray: " , 1 - err_cnt/len(image))
print(" ----------------- c schedule finish")
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/CKRGSNN/README.md
================================================
# Commonsense Knowledge Representation SNN
(https://arxiv.org/abs/2207.05561)
This repository contains code from our paper [**Brain-inspired Graph Spiking Neural Networks for Commonsense Knowledge Representation and Reasoning**](https://arxiv.org/abs/2207.05561). If you use our code or refer to this project, please cite this paper.
## Requirements
* python=3.8
* numpy
* scipy
* turicreate
* pytorch >= 1.7.0
* torchvision
## Dataset
ConceptNet: https://github.com/commonsense/conceptnet5
## Run
```shell
python main.py
```
This module selects core knowledge in ConceptNet to form the sub_Concept.csv file as the input of the model. The input current, spike trains during the learning process and the network weight distribution after the learning are shown in the Results folder.
### Citation
If you find this package helpful, please consider citing the following papers:
```BibTex
@article{KRRfang2022,
title = {Brain-inspired Graph Spiking Neural Networks for Commonsense Knowledge Representation and Reasoning},
author = { Fang, Hongjian and Zeng, Yi and Tang, Jianbo and Wang, Yuwei and Liang, Yao and Liu, Xin},
journal = {arXiv preprint arXiv:2207.05561},
year = {2022}
}
@misc{https://doi.org/10.48550/arxiv.2207.08533,
doi = {10.48550/ARXIV.2207.08533},
url = {https://arxiv.org/abs/2207.08533},
author = {Zeng, Yi and Zhao, Dongcheng and Zhao, Feifei and Shen, Guobin and Dong, Yiting and Lu, Enmeng and Zhang, Qian and Sun, Yinqian and Liang, Qian and Zhao, Yuxuan and Zhao, Zhuoya and Fang, Hongjian and Wang, Yuwei and Li, Yang and Liu, Xin and Du, Chengcheng and Kong, Qingqun and Ruan, Zizhe and Bi, Weida},
title = {BrainCog: A Spiking Neural Network based Brain-inspired Cognitive Intelligence Engine for Brain-inspired AI and Brain Simulation},
publisher = {arXiv},
year = {2022},
}
```
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/CKRGSNN/main.py
================================================
import time
import numpy as np
import os
import warnings
import scipy.io as scio
import math
from matplotlib import pyplot as plt
import torch
from braincog.base.node.node import *
import turicreate as tc
from braincog.base.brainarea.BrainArea import *
from braincog.utils import *
warnings.filterwarnings('ignore')
np.set_printoptions(threshold=np.inf)
class CKRNet(BrainArea):
    """Commonsense Knowledge Representation network.

    A single LIF population driven by two linear projections — the learned
    recurrent weights ``w1`` and the external-input mapping ``w2`` — whose
    recurrent weights are trained with a multi-input STDP rule.
    """

    def __init__(self, w1, w2):
        """Build the network from the recurrent (*w1*) and input (*w2*) weight matrices."""
        super().__init__()
        lif = LIFNode(threshold=16, tau=15)
        self.node = [lif]
        self.connection = [CustomLinear(w1), CustomLinear(w2)]
        self.stdp = [MutliInputSTDP(lif, [self.connection[0], self.connection[1]], decay=0.83)]
        # Spike output of the population from the previous step, shape (1, n_neurons).
        self.x1 = torch.zeros(1, w2.shape[0])

    def forward(self, x):
        """Advance one time step with external input spike train *x*.

        Returns the population's spike output and the STDP weight deltas.
        """
        self.x1, dw = self.stdp[0](self.x1, x)
        return self.x1, dw

    def reset(self):
        """Zero the stored spike state in place."""
        self.x1 *= 0
def S_bound(S):
    """Clamp the weight matrix *S* in place and return it.

    Relies on module-level globals defined in ``__main__``:
    ``synapse_bound`` (symmetric cap on every weight),
    ``inner_bound_E`` / ``inner_bound_R`` (tighter upper caps on weights
    *within* a single entity / relation population), and
    ``Index_E`` / ``Index_R`` / ``N_entity`` / ``N_relation`` (per-population
    neuron index arrays and counts).
    """
    # Global symmetric bound on all synapses (in-place masked assignment on S).
    S[S > synapse_bound] = synapse_bound
    S[S < -synapse_bound] = -synapse_bound
    # Tighter upper bound for synapses within each entity population.
    # Fancy indexing yields copies, so the clamped sub-block has to be
    # written back into S explicitly.
    for i in range(N_entity):
        temp1 = S[Index_E[i], :]
        temp2 = temp1[:, Index_E[i]]
        temp2[temp2 > inner_bound_E] = inner_bound_E
        temp1[:, Index_E[i]] = temp2
        S[Index_E[i], :] = temp1
    # Same for synapses within each relation population.
    for i in range(N_relation):
        temp1 = S[Index_R[i], :]
        temp2 = temp1[:, Index_R[i]]
        temp2[temp2 > inner_bound_R] = inner_bound_R
        temp1[:, Index_R[i]] = temp2
        S[Index_R[i], :] = temp1
    return S
if __name__ == "__main__":
    print(os.getcwd())
    # Load the (Relation, Head, Tail, Weight) triples of the ConceptNet subset.
    KG = tc.SFrame.read_csv('./sub_Conceptnet.csv')
    # Collect the distinct relations and entities appearing in the triples.
    Set_R = set()
    Set_E = set()
    for i in range(KG.shape[0]):
        Set_R.add(KG[i]['Relation'])
        Set_E.add(KG[i]['Head'])
        Set_E.add(KG[i]['Tail'])
    List_E = sorted(Set_E)
    List_R = list(Set_R)
    List_R.sort()
    # ---- network hyper-parameters ----
    I_syn = 5
    tau_m = 30
    I_t = 3  # Time duration of stimulus current (steps)
    I_P = 150  # Strength of input current
    A_P = 0.009  # learning-rate scale applied to the STDP weight deltas
    certainty = 0.2  # baseline fraction of I_P added to the noisy stimulus
    synapse_bound = 1  # The bound of all synapses
    inner_bound_E = 0.6  # The bound of entity-population inner synapses
    inner_bound_R = 0.3  # The bound of relation-population inner synapses
    Ce = 20  # neurons per entity population
    Cr = 100  # neurons per relation population
    N_entity = len(List_E)
    N_relation = len(List_R)
    total_neurons = Ce * N_entity + Cr * N_relation
    KG_No = KG.shape[0]
    trail_time = 40  # simulation steps allotted to each triple (one "trial")
    runtime = KG_No * trail_time
    print('N_entity=', N_entity)
    print('N_relation=', N_relation)
    print('KG_No=', KG_No)
    print('runtime=', runtime)
    print('total_neurons=', total_neurons)
    S = np.zeros((total_neurons, total_neurons), dtype=float)  # Initial recurrent weights
    S = torch.tensor(S, dtype=torch.float32)
    E = np.identity((total_neurons), dtype=float)  # identity mapping for the external current
    E = torch.tensor(E, dtype=torch.float32)
    I_stimu = np.zeros((total_neurons, runtime))  # injected current per neuron per step
    ADJ = np.zeros((total_neurons, runtime))  # record the firing condition
    # Neuron index ranges per population: entity populations first, then relations.
    Index_E = []
    Index_R = []
    for i in range(N_entity):
        Index_E.append(np.arange(i * Ce, i * Ce + Ce))
    for i in range(N_relation):
        Index_R.append(np.arange(N_entity * Ce + i * Cr, N_entity * Ce + i * Cr + Cr))
    # Build the stimulus: within each trial window, drive the Head, Relation
    # and Tail populations with noisy currents at step offsets 10, 15 and 20.
    for i in range(KG_No):
        Head = KG[i]['Head']
        Rela = KG[i]['Relation']
        Tail = KG[i]['Tail']
        Weig = KG[i]['Weight']  # NOTE(review): read but never used below
        # print(List_E.index(Head))
        # print(List_R.index(Rela))
        # print(List_E.index(Rela))
        # print(Index_R[List_R.index(Rela)])
        I_stimu[Index_E[List_E.index(Head)], 10 + i * trail_time: 10 + I_t + i * trail_time] = certainty * I_P + I_P * np.random.rand(Ce, I_t)
        I_stimu[Index_R[List_R.index(Rela)], 15 + i * trail_time: 15 + I_t + i * trail_time] = certainty * I_P + I_P * np.random.rand(Cr, I_t)
        I_stimu[Index_E[List_E.index(Tail)], 20 + i * trail_time: 20 + I_t + i * trail_time] = certainty * I_P + I_P * np.random.rand(Ce, I_t)
    CKRGSNN = CKRNet(S, E)
    # ---- simulate and learn ----
    for t in range(runtime):
        I_input = torch.tensor(I_stimu[:, t].reshape(1, total_neurons), dtype=torch.float32)
        x, dw = CKRGSNN(I_input)
        S += A_P * dw[1]  # STDP update of the recurrent weights
        # S_bound clamps S in place and returns the same tensor, so the
        # added difference is zero; the += form only avoids rebinding S.
        S += S_bound(S) - S
        # NOTE(review): x is a torch tensor (assigned from the net's output);
        # confirm its shape broadcasts into the (total_neurons,) column.
        ADJ[:, t] = x
        print(t, 'step in >>', runtime)
    # ---- save figures and raw data ----
    img_I = plt.matshow(I_stimu)
    plt.savefig("I_stimu1.jpg", dpi=500, bbox_inches='tight')
    img_ADJ = plt.matshow(ADJ)
    plt.savefig("ADJ1.jpg", dpi=500, bbox_inches='tight')
    img_S = plt.matshow(S)
    plt.colorbar()
    plt.savefig("S1.jpg", dpi=500, bbox_inches='tight')
    plt.show()
    S = np.mat(S)
    dataNew = './data_save.mat'
    scio.savemat(dataNew, {'I_stimu': I_stimu, 'ADJ': ADJ, 'Weight': S})
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/CKRGSNN/sub_Conceptnet.csv
================================================
Relation,Head,Tail,Weight
antonym,ab_extra,ab_intra,1.0
antonym,ab_intra,ab_extra,1.0
antonym,abactinal,actinal,1.0
antonym,abandon,acquire,1.0
antonym,abandon,arrogate,1.0
antonym,abandon,embrace,1.0
antonym,abandon,engage,1.0
antonym,abandon,gain,1.0
antonym,abandon,join,1.0
antonym,abandon,maintain,1.0
antonym,abandon,retain,1.0
antonym,abandon,unite,1.0
atlocation,clock,department_store,1.0
atlocation,clock,desk,4.472
atlocation,clock,house,2.828
atlocation,clock,office,2.0
atlocation,crisps,table,1.0
atlocation,crisps,vending_machines,1.0
atlocation,crockery,cupboard,1.0
atlocation,crocs,united_states,0.5
atlocation,fungus,damp_spot,1.0
atlocation,fungus,damp_warm_place,1.0
atlocation,fungus,damp_wood,2.0
atlocation,fungus,damp_woods,1.0
atlocation,fungus,dank_place,1.0
atlocation,fungus,dark,1.0
atlocation,fungus,dark_and_dank_place,1.0
atlocation,fungus,dark_damp_area,2.828
atlocation,fungus,dark_damp_place,2.0
atlocation,carnival_rides,fairgrounds,1.0
atlocation,carousel,carnival,2.828
atlocation,carpet,at_hotel,1.0
capableof,adult,drive_car,1.0
capableof,adult,drive_train,1.0
capableof,adult,explain_rules_to_child,1.0
capableof,adult,feed_and_take_care_of_itself,1.0
capableof,adult,gift_knowledge_to_child,1.0
capableof,adult,hand_toy_to_child,1.0
capableof,adult,help_child,3.464
capableof,adult,keep_property,1.0
capableof,adults,act_like_infants,1.0
capableof,adults,care_for_babies,1.0
capableof,adults,carry_infants,1.0
capableof,adults,count,2.0
capableof,adults,demand_respect_from_children,1.0
capableof,adults,dress_themselves,2.828
capableof,adults,drink_beer,2.0
capableof,adults,drive_vehicle,1.0
capableof,adults,eat_sushi,2.0
capableof,adults,fail_to_pass_test,1.0
causes,competing_against,testing_yourself_against_another_person,1.0
causes,competing_against,try_hardest,1.0
causes,computing_sum,get_total_amount,1.0
causes,computing_sum,getting_answer,4.899
causes,computing_sum,getting_right_answer,1.0
causes,computing_sum,getting_total,1.0
causes,computing_sum,having_answer,2.0
causes,computing_sum,having_total,1.0
causes,computing_sum,headache,3.464
causes,computing_sum,insight,1.0
causes,computing_sum,know_total,2.0
causes,computing_sum,knowing_total,1.0
causes,computing_sum,may_get_wrong_answer,1.0
causes,computing_sum,number,1.0
causes,computing_sum,obtaining_total,1.0
causes,computing_sum,reaching_total,1.0
causes,computing_sum,receiving_total,1.0
desires,person,have_day_off,1.0
desires,person,have_diamonds,2.0
desires,person,have_easy,1.0
desires,person,have_enough_food,2.0
desires,person,have_enough_to_eat,1.0
desires,person,have_everyone_happy_with,2.0
desires,person,have_everything,1.0
desires,person,have_fast_internet_acess,1.0
desires,person,have_femily,1.0
desires,person,have_firm_body,1.0
desires,person,have_fond_memories,1.0
desires,person,have_fortune,1.0
desires,person,have_free_time,1.0
desires,person,have_friends,1.0
desires,person,have_fulfilling_life,1.0
desires,person,have_fun_in_life,2.0
desires,person,have_fun_on_weekends,1.0
desires,person,have_fun_weekend,2.0
desires,person,have_future,1.0
desires,person,have_good_bones,1.0
desires,person,have_good_day,2.828
desires,person,have_good_eyesight,1.0
desires,person,have_good_feelings,1.0
desires,person,have_good_friends,2.828
desires,person,have_good_life,1.0
desires,person,have_good_memories,2.0
desires,person,have_good_memory,1.0
desires,person,have_good_relationships_with_others,1.0
desires,person,have_good_skin,1.0
desires,person,have_great_sex,2.0
desires,person,have_happy_childhood,1.0
desires,person,have_happy_family,1.0
desires,person,have_healthy_life,1.0
desires,person,have_healthy_sex_life,1.0
desires,person,have_home_to_live_in,2.828
desires,person,have_hot_water,1.0
desires,person,have_influence,1.0
desires,person,have_inner_peace,2.828
desires,person,have_interesting_job,1.0
desires,person,have_large_vocabulary,1.0
desires,person,have_lasting_friendships,1.0
desires,person,have_long_live,2.0
desires,person,have_loving_family,1.0
desires,person,have_many_good_friends,2.0
desires,person,have_meaningful_life,2.0
desires,person,have_minimum_necessities_of_life,1.0
desires,person,have_money_to_buy_chocolate,2.0
desires,person,have_more,1.0
hascontext,bridge,card_games,1.0
hascontext,bridge,communication,1.0
hascontext,bridge,computing,1.0
hascontext,bridge,music,1.0
hascontext,bridge,wrestling,1.0
hascontext,bridge_and_tunnel,new_york_city,1.0
hascontext,bridge_and_tunnel,pejorative,1.0
hascontext,bridge_and_tunnel,slang,1.0
hassubevent,maintain_good_health,avoid_guns,1.0
hassubevent,maintain_good_health,avoid_hate,1.0
hassubevent,maintain_good_health,avoid_heavily_processed_foods,2.0
hassubevent,maintain_good_health,avoid_highly_processed_food,1.0
hassubevent,maintain_good_health,avoid_illegal_drugs,1.0
hassubevent,maintain_good_health,avoid_losing_sleep,1.0
hassubevent,maintain_good_health,avoid_marijuana,1.0
hassubevent,maintain_good_health,avoid_much_sunlight,1.0
hassubevent,maintain_good_health,avoid_poison,1.0
hassubevent,maintain_good_health,avoid_racism,1.0
hassubevent,maintain_good_health,avoid_smoking_marijuana,1.0
hassubevent,maintain_good_health,avoid_smoking_tobacco,1.0
hassubevent,maintain_good_health,avoid_tobacco,1.0
hassubevent,maintain_good_health,avoid_unpleasant_people,1.0
hassubevent,maintain_good_health,avoid_war,1.0
hassubevent,maintain_good_health,become_vegan,1.0
hassubevent,maintain_good_health,calm,1.0
hassubevent,maintain_good_health,care_about_people,1.0
hassubevent,maintain_good_health,chew_food_well,1.0
hassubevent,maintain_good_health,determine_nutritional_needs,1.0
hassubevent,maintain_good_health,do_exercise,1.0
hassubevent,maintain_good_health,dress_warmly_in_cold_weather,1.0
hassubevent,maintain_good_health,drink_very_little_alcoholic_beverages,1.0
hassubevent,maintain_good_health,drive_safely,1.0
hassubevent,maintain_good_health,eat_apple_day,1.0
hassubevent,maintain_good_health,eat_enough_and_exercise,1.0
hassubevent,maintain_good_health,eat_foods_containing_fiber,1.0
hassubevent,maintain_good_health,excersize,1.0
hassubevent,maintain_good_health,exercise,1.0
hassubevent,maintain_good_health,have_fullfilled_sex_life,1.0
hassubevent,maintain_good_health,have_spouse,1.0
hassubevent,maintain_good_health,leave_cat_alone,1.0
hassubevent,maintain_good_health,live_healthy_lifestyle,1.0
hassubevent,maintain_good_health,loving,1.0
hassubevent,maintain_good_health,monitor_health_often,1.0
hassubevent,maintain_good_health,not_eat_too_much_sna,1.0
isa,antidiarrheal,medicine,2.0
isa,antidiarrheal_therapy,drug_therapy,1.0
isa,antidiuretic,medicine,2.0
isa,antidiuretic_agent,medicine,1.0
isa,antido,artificial_language,2.0
isa,antidorcas,mammal_genus,2.0
isa,antidoron,food,0.5
isa,antidote,neutralizer,1.0
isa,antidote,remedy,2.0
isa,antiemetic,medicine,1.0
isa,antiemetic,medicine,2.0
isa,antiemetic_therapy,drug_therapy,1.0
isa,antiepileptic,anticonvulsant,1.0
isa,antiestablishmentarianism,doctrine,2.0
isa,antietam,national_cemetery_in_maryland,1.0
isa,antifeminist,bigot,2.0
isa,antiferromagnetism,magnetism,2.0
isa,antifibrinolytic_agent,medicine,1.0
isa,antiflatulent,agent,2.0
isa,antifouling_paint,paint,2.0
isa,antifreeze,automotive_fluid,1.0
isa,antifreeze,liquid,2.0
isa,antifungal,agent,2.0
isa,antigen,antigen,1.0
isa,antigen,carbohydrate,1.0
isa,antigen,tangible_thing,1.0
isa,antigen,substance,2.0
isa,antigenes,person,0.5
isa,antigenic_determinant,site,2.0
isa,antigone,play,0.5
isa,antigonia,fish_genus,2.0
isa,antigonia,fish,0.5
isa,antigorite,serpentine,1.0
isa,antigorite,serpentine,1.0
isa,gardens,often_in_yards,1.0
isa,gardens,places_where_plants_grow,1.0
isa,gardens,pleasant_outdoor_place,1.0
isa,gardens,pleasant_places_to,1.0
isa,garding,town,0.5
isa,gardnerian_wicca,wicca,1.0
isa,gardon,river,0.5
isa,garfield,fictional_character,0.5
isa,garfield,station,0.5
isa,garfield,station,0.5
isa,garfish,fish,0.5
isa,gargamel,comics_character,0.5
isa,garganey,bird,0.5
isa,garlic,food_ingredient,1.0
isa,garlic,flavorer,2.0
isa,garlic,alliaceous_plant,2.0
isa,tire_iron,rigid_portable_object,1.0
isa,tire_iron,shaped_thing,1.0
isa,tire_iron,hand_tool,2.0
isa,tire_iron,lever,2.0
isa,tire_pump,useful_tool,1.0
isa,tire_pump,mechanical_pump,1.0
isa,tire_rotation,axis_constrained_rotation,1.0
isa,tire_sealant,automotive_product,1.0
isa,tire_sealant,sealant,1.0
isa,garlic_chive,alliaceous_plant,2.0
isa,garlic_clove,bulb,1.0
partof,paragraph,textual_document,1.0
partof,paragraph,text,2.0
partof,victoria,zambezi,2.0
partof,victoria,zambia,2.0
partof,victoria,zimbabwe,2.0
partof,victoria_land,antarctica,2.0
partof,vidalia,georgia,2.0
partof,video,television,2.0
partof,vienna,austria,2.0
partof,vienne,poitou_charentes,0.5
partof,vienne,france,2.0
partof,vientiane,laos,2.0
partof,vieques,puerto_rico,2.0
partof,vietnam,indochina,2.0
partof,viewfinder,camera,1.0
partof,vigo,galicia,0.5
partof,vigo,vigo,0.5
partof,vigo,galicia,0.5
partof,villa,asturias,0.5
partof,villahermosa,tabasco,0.5
partof,villahermosa,mexico,2.0
partof,villarreal,valencian_community,0.5
partof,vilnius,dzūkija,0.5
partof,vilnius,lithuania,0.5
partof,vilnius,lithuania,2.0
partof,visual_purple,rod,2.0
partof,visual_signal,visual_communication,2.0
partof,viti_levu,fiji_islands,2.0
partof,viña_del_mar,valparaíso,0.5
partof,vladivostok,russia,2.0
partof,vocabulary,language,2.0
partof,vocal_cord,larynx,2.0
partof,voider,body_armor,2.0
partof,volapük,international_auxiliary_language,0.5
partof,volcanic_crater,volcano,2.0
partof,volcano,south_park,0.5
partof,volcano_islands,japan,2.0
partof,volcano_islands,pacific,2.0
partof,volga,russia,2.0
partof,volgograd,russia,2.0
partof,volkhov,russia,2.0
partof,parthenon,athens,2.0
receivesaction,carpet,found_on_ground,1.0
receivesaction,carpet,used_as_floor_covering,1.0
receivesaction,carpeted_floors,found_in_many_kinds_of_buildings,1.0
receivesaction,carpeting,used_in_place_of_hardwood_floors,1.0
receivesaction,carpets,bought_at_carpet_stores,1.0
receivesaction,cartoons,animated,1.0
receivesaction,case,tried_in_appeals_court,1.0
receivesaction,cases,heard_in_court_of_law,1.0
receivesaction,cash,denominated_in_dollars,1.0
receivesaction,cash,earned,1.0
receivesaction,cash,measured_in_dollars_and_cents,1.0
receivesaction,castanets,bound_together_with_leather_string,1.0
receivesaction,castanets,used_in_form_of_dance,1.0
receivesaction,casual_describes_clothing,worn_for_comfort_and_function,1.0
receivesaction,cat,attracted_to_parakeets,1.0
receivesaction,cats,thought_to_hate_dogs,1.0
receivesaction,cats_and_dogs,treated_badly,1.0
receivesaction,cats_purr_when,contented,1.0
receivesaction,cattle,fed_in_feed_lots,1.0
receivesaction,cauldron,steeped_in_magical_tradition_and_mystery,1.0
receivesaction,cauliflower_and_broccoli,combined_into_one_super_vegetable,1.0
receivesaction,cavitron,used_in_brain_surgery,1.0
receivesaction,cds,bought_in_stores,1.0
receivesaction,cds_usually,made_from_plastic,1.0
receivesaction,cedar,used_as_shingles_on_houses,1.0
receivesaction,ceilings,painted_with_brush,1.0
receivesaction,ceilings_have_color_which,painted_onto,1.0
receivesaction,celebrity,associated_with_autographs,1.0
receivesaction,celebrity,associated_with_desire,1.0
receivesaction,celebrity,associated_with_fame,1.0
receivesaction,celebrity,associated_with_fans,1.0
relatedto,penis,tarse,1.0
relatedto,penis_pump,cock_pump,1.0
relatedto,penis_worm,priapulid,1.0
relatedto,penised,bedicked,1.0
relatedto,penitence,compunction,2.0
relatedto,penitence,remorse,1.0
relatedto,penitence,repentance,1.0
relatedto,penitence,repentance,2.0
relatedto,penitent,penaunt,1.0
relatedto,penitentiary,penitential,2.0
relatedto,penitentiary,jail,1.0
relatedto,penrose_diagram,carter_penrose_diagram,1.0
relatedto,penrose_process,penrose_mechanism,1.0
relatedto,penrose_staircase,penrose_stairs,1.0
relatedto,penrose_staircase,penrose_steps,1.0
relatedto,penrose_stairs,penrose_staircase,1.0
relatedto,penrose_stairs,penrose_steps,1.0
relatedto,penrose_steps,penrose_staircase,1.0
relatedto,penrose_steps,penrose_stairs,1.0
relatedto,penrose_triangle,penrose_triangle,0.5
relatedto,pensacola,pensacola,0.5
relatedto,pension,pension,0.5
relatedto,pension,hotel,1.0
relatedto,penstemon_linarioides,narrow_leaf_penstemon,2.0
relatedto,penstemon_newberryi,mountain_pride,2.0
relatedto,penstemon_palmeri,balloon_flower,2.0
relatedto,penstemon_rupicola,rock_penstemon,2.0
relatedto,penstemon_serrulatus,cascade_penstemon,2.0
relatedto,penstock,sluice,2.0
relatedto,penstock,sluicegate,2.0
relatedto,pent,shut_up,2.0
relatedto,pent_up,repressed,2.0
relatedto,penta,quinque,1.0
relatedto,pentaborane,pentaborane,0.5
relatedto,pentabromodiphenyl_ether,pentabromodiphenyl_ether,0.5
relatedto,pentabromodiphenyl_ether,pentabromodiphenyl_oxide,1.0
relatedto,pentacene,pentacene,0.5
relatedto,pentachloronitrobenzene,pentachloronitrobenzene,0.5
relatedto,pentagon,pentagon,0.5
relatedto,pentagonal,pentangular,2.0
relatedto,pentagram,pentacle,1.0
relatedto,pentagram,pentalpha,1.0
relatedto,pentagram,pentangle,1.0
relatedto,pentagram,pentacle,2.0
relatedto,pentagraph,pentagraph,0.5
relatedto,pentail,pen_tailed_treeshrew,1.0
relatedto,pentalpha,pentagram,1.0
relatedto,pentalpha,pentangle,1.0
relatedto,pentamethylbenzene,pentamethylbenzene,0.5
relatedto,pentamethylenetetrazol,pentylenetetrazol,2.0
relatedto,pentamidine,pentamidine,0.5
relatedto,pentanal,pentanal,0.5
relatedto,pentanal,pentanaldehyde,1.0
relatedto,pentanal,valeraldehyde,1.0
relatedto,pentane,pentane,0.5
relatedto,pentangle,pentacle,2.0
relatedto,pentanoate,valerate,1.0
relatedto,pentanoic_acid,valeric_acid,2.0
relatedto,pentanol,amyl_alcohol,1.0
relatedto,pentanol,pentyl_alcohol,1.0
relatedto,pentastomid,tongue_worm,2.0
relatedto,pentastomida,pentastomida,0.5
relatedto,pentateuch,books_of_moses,1.0
relatedto,pentateuch,law,1.0
relatedto,pentateuch,torah,1.0
relatedto,pentateuch,torah,2.0
relatedto,pentatone,pentatonic_scale,2.0
relatedto,pentatonic_scale,pentatonic_scale,0.5
relatedto,pentazocine,pentazocine,0.5
relatedto,pentazole,pentazole,0.5
relatedto,pentecost,pentecost,0.5
relatedto,pentecost,feast_of_weeks,1.0
relatedto,pentecost,shavuos,1.0
relatedto,pentecost,shavuot,1.0
relatedto,pentecost,whit,1.0
relatedto,pentecost,whit_sunday,1.0
relatedto,pentecost,whitsun,1.0
relatedto,pentecost,whitsunday,1.0
relatedto,pentecost,shavous,2.0
relatedto,pentecostal,pentecostalist,2.0
relatedto,pentecostalism,pentecostalism,0.5
relatedto,pentel,pentel,0.5
relatedto,pentelic,pentelican,1.0
relatedto,pentene,pentene,0.5
usedfor,gourmet_shop,buy_weird_food,1.0
usedfor,gourmet_shop,buying_imported_foods,1.0
usedfor,gourmet_shop,customers_who_knowlegeable_about_cooking,1.0
usedfor,gourmet_shop,fine_foods,1.0
usedfor,gourmet_shop,hard_to_find_foods,1.0
usedfor,gourmet_shop,icky_foods,1.0
usedfor,lip,pleasure,1.0
usedfor,lip,pouring,1.0
usedfor,vessel,containing_highway_for_blood_flow,1.0
usedfor,vessel,float,1.0
usedfor,vessel,hold_flowers,1.0
usedfor,vessel,moving,1.0
usedfor,vessel,moving_people_around,1.0
usedfor,vessel,navigate,1.0
usedfor,vessel,piloting,1.0
usedfor,vessel,sailing_on_ocean,1.0
usedfor,vessel,ship_goods,1.0
usedfor,vessel,shipping,2.0
usedfor,vessel,store_liquid,1.0
usedfor,vessel,storing_liquids,2.0
usedfor,vessel,transporting_things,1.0
usedfor,vessel,usually_staying_above_water,1.0
usedfor,veterinarians,sick_animals,1.0
usedfor,vibrator,entertain_yourself,1.0
usedfor,vibrator,get_off,1.0
usedfor,vibrator,increase_pleasure_during_sex,1.0
usedfor,vibrator,sexually_stimulate_yourself_or_else,1.0
usedfor,viewing_video,having_fun,1.0
usedfor,viewing_video,having_good_time,1.0
usedfor,viewing_video,learning,2.828
usedfor,viewing_video,learning_language,1.0
usedfor,viewing_video,learning_new,1.0
usedfor,viewing_video,relaxation,1.0
usedfor,viewing_video,relaxing,2.0
usedfor,viewing_video,reviewing_video,1.0
usedfor,viewing_video,spending_time_with_grandson,1.0
usedfor,viewing_video,watching_film,1.0
usedfor,viewing_video,watching_home_movies,1.0
usedfor,viewing_video,watching_memories,1.0
usedfor,viewing_video,watching_movie_star,1.0
usedfor,village,learning,1.0
usedfor,village,people_to_live_in,1.0
usedfor,village,playing,1.0
usedfor,village,raise_child,1.0
usedfor,village,sedentary_living,1.0
usedfor,viola,play_song,1.0
usedfor,viola,playing,1.0
usedfor,viola,playing_music,2.0
usedfor,viola,playing_sissy_music,1.0
usedfor,viola,sing_song,1.0
usedfor,violence,kill,1.0
usedfor,violence,terrorism,1.0
usedfor,violin,create_music,1.0
usedfor,violin,creating,2.0
usedfor,violin,creating_art,1.0
usedfor,violin,entertaining,1.0
usedfor,violin,entertainment,1.0
usedfor,violin,fu_n,1.0
usedfor,violin,making_lovely_music,1.0
usedfor,violin,mkae_annoying_noises,1.0
usedfor,violin,music,1.0
usedfor,violin,play_music,2.828
usedfor,violin,playing,1.0
usedfor,violin,playing_music,5.657
usedfor,violin,playing_music_on_stringed_instrument,1.0
usedfor,visa_card,acquiring_debt,1.0
usedfor,visa_card,adults_not_kids,1.0
usedfor,visiting_museum,entertainment,1.0
usedfor,visiting_museum,feeling_young,1.0
usedfor,visiting_museum,finding_out_about_past,1.0
usedfor,visiting_museum,finding_out_about_world,1.0
usedfor,visiting_museum,fun,2.0
usedfor,visiting_museum,getting_ideas_from_past_experiences,1.0
usedfor,visiting_museum,having_fun,1.0
usedfor,visiting_museum,having_fun_day,1.0
usedfor,visiting_museum,learing_about_past,1.0
usedfor,visiting_museum,learning_about_culture,1.0
usedfor,visiting_museum,learning_about_history,1.0
usedfor,visiting_museum,learning_about_other_places,1.0
usedfor,visiting_museum,learning_history,2.0
usedfor,lip,protect,1.0
usedfor,lip,speak,2.0
usedfor,lip,suck,2.0
usedfor,lip,whistling,2.0
usedfor,lips,communicate,2.828
usedfor,lips,flapping,1.0
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/CRSNN/README.md
================================================
# Causal Reasoning SNN
(https://doi.org/10.1109/IJCNN52387.2021.9534102)
This repository contains code from our paper [**A Brain-Inspired Causal Reasoning Model Based on Spiking Neural Networks**](https://ieeexplore.ieee.org/abstract/document/9534102), published in the 2021 International Joint Conference on Neural Networks (IJCNN). If you use our code or refer to this project, please cite this paper.
## Requirements
* numpy
* scipy
* pytorch >= 1.7.0
* torchvision
## Run
```shell
python main.py
```
This module builds an example of a brain-like causal inference spiking neural network model. The input causal graph is shown in figure causal_graph.png. The input current, spike trains during the learning process and the network weight distribution after the learning are shown in the Results folder.
### Citation
If you find this package helpful, please consider citing the following papers:
```BibTex
@inproceedings{fang2021CRSNN,
title={A Brain-Inspired Causal Reasoning Model Based on Spiking Neural Networks},
author={Fang, Hongjian and Zeng, Yi},
booktitle={2021 International Joint Conference on Neural Networks (IJCNN)},
pages={1--5},
year={2021},
organization={IEEE}
}
@misc{https://doi.org/10.48550/arxiv.2207.08533,
doi = {10.48550/ARXIV.2207.08533},
url = {https://arxiv.org/abs/2207.08533},
author = {Zeng, Yi and Zhao, Dongcheng and Zhao, Feifei and Shen, Guobin and Dong, Yiting and Lu, Enmeng and Zhang, Qian and Sun, Yinqian and Liang, Qian and Zhao, Yuxuan and Zhao, Zhuoya and Fang, Hongjian and Wang, Yuwei and Li, Yang and Liu, Xin and Du, Chengcheng and Kong, Qingqun and Ruan, Zizhe and Bi, Weida},
title = {BrainCog: A Spiking Neural Network based Brain-inspired Cognitive Intelligence Engine for Brain-inspired AI and Brain Simulation},
publisher = {arXiv},
year = {2022},
}
```
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/CRSNN/main.py
================================================
import time
import numpy as np
import os
import warnings
import math
from matplotlib import pyplot as plt
import torch
from braincog.base.node.node import *
from braincog.base.brainarea.BrainArea import *
from braincog.utils import *
warnings.filterwarnings('ignore')
np.set_printoptions(threshold=np.inf)
class CRNet(BrainArea):
    """
    Network structure class: CRNet (Causal Reasoning Net). Defines the network
    topology; inherits from the BrainArea base class.

    :param threshold: firing threshold the neuron membrane potential must reach
    :param tau: membrane time constant controlling potential decay
    :param decay: STDP decay constant controlling how the STDP effect fades over time
    :param w1: recurrent connection weights inside the network
    :param w2: connection from the external input current to each neuron
    """

    def __init__(self, w1, w2):
        """
        Build one LIF population, two linear connections (recurrent + external
        input) and a multi-input STDP rule over both connections.
        """
        super().__init__()
        self.node = [LIFNode(threshold=16, tau=15)]
        self.connection = [CustomLinear(w1), CustomLinear(w2)]
        self.stdp = []
        self.stdp.append(MutliInputSTDP(self.node[0], [self.connection[0], self.connection[1]], decay=0.8))
        # x1 holds the spike output of the previous time step (fed back as recurrent input)
        self.x1 = torch.zeros(1, w2.shape[0])

    def forward(self, x):
        """
        One simulation time step: propagate input, produce spikes and the STDP
        weight change.

        :param x: external input current for this time step
        :return: (x1, dw1) -- the spike pattern after this step and the weight
            change produced by STDP in this step
        """
        # NOTE(review): arguments are (external input, recurrent spikes) here,
        # the reverse of the SPSNN variant -- confirm against MutliInputSTDP.
        self.x1, dw1 = self.stdp[0](x, self.x1)
        return self.x1, dw1

    def reset(self):
        # Clear the stored spike state in place (keeps tensor identity).
        self.x1 *= 0
def S_bound(S):
    """
    Clamp the network weight matrix in place and return it.

    Keeps every synapse inside [-synapse_bound, synapse_bound] so the whole
    network stays weakly connected, and additionally caps the weights *inside*
    each neuron population (entities E1..E5, relations R1/R2) so a population
    cannot keep re-exciting itself.

    Reads module-level globals: synapse_bound, inner_bound_E, inner_bound_R
    and the population index arrays E1_index..E5_index, R1_index, R2_index.

    The original version repeated the same clamp block per population (and
    duplicated the E4 block twice); this rewrite performs the identical
    clamping with one loop.

    :param S: square weight matrix (numpy array or torch tensor)
    :return: S, clamped in place
    """
    # Global weak-connection clamp (in place).
    S[S > synapse_bound] = synapse_bound
    S[S < -synapse_bound] = -synapse_bound
    # Per-population inner clamp: entity groups use inner_bound_E,
    # relation groups use inner_bound_R.
    populations = [
        (E1_index, inner_bound_E),
        (E2_index, inner_bound_E),
        (E3_index, inner_bound_E),
        (E4_index, inner_bound_E),
        (E5_index, inner_bound_E),
        (R1_index, inner_bound_R),
        (R2_index, inner_bound_R),
    ]
    for idx, bound in populations:
        # Fancy indexing returns copies, so clamp the sub-block and write back.
        rows = S[idx, :]
        inner = rows[:, idx]
        inner[inner > bound] = bound
        rows[:, idx] = inner
        S[idx, :] = rows
    return S
if __name__ == "__main__":
# ---- Neuron / simulation parameters ----
Cr = 200  # neurons per relation population
Ce = 50  # neurons per entity population
total_time = 2500  # Runtime in ms
tau = 100  # time constant of STDP
stdpwin = 25  # STDP windows in ms
thresh = 30  # Judge if the neurons fire or not
abs_T = 25  # The length of the ABS (absolute refractory period)
Reset = 0  # Reset Potential
I_syn = 5
tau_m = 30
Rm = 10
N_entity = 5  # number of entity populations (E1..E5)
N_relation = 2  # number of relation populations (R1, R2)
I_t = 5  # Duration of Current
I_P = 25  # Strength of input current
certainty = 0.5  # baseline fraction of the input current
A_P = 0.01  # scaling applied to each STDP weight change
synapse_bound = 0.2  # bound on every synapse (keeps the net weakly connected)
inner_bound_E = 0.08  # bound on entity-population inner synapses
inner_bound_R = 0.06  # bound on relation-population inner synapses
total_neurons = Ce * N_entity + Cr * N_relation
# NOTE(review): tau, stdpwin, thresh, abs_T, I_syn, tau_m and Rm are never read
# below -- presumably leftovers from an earlier version; confirm before removing.
# The block comment originally here was copy-pasted from the SPSNN example
# (it even called this the "SPSNN main function" and described that script's
# parameters), so it was replaced by the accurate per-parameter notes above.
# Initial Neural Network: index ranges of the five entity populations (E1..E5,
# Ce neurons each) and the two relation populations (R1, R2, Cr neurons each).
E1_index = np.linspace(0, Ce - 1, Ce, dtype=int)
E2_index = np.linspace(Ce, 2 * Ce - 1, Ce, dtype=int)
E3_index = np.linspace(2 * Ce, 3 * Ce - 1, Ce, dtype=int)
E4_index = np.linspace(3 * Ce, 4 * Ce - 1, Ce, dtype=int)
E5_index = np.linspace(4 * Ce, 5 * Ce - 1, Ce, dtype=int)
R1_index = np.linspace(5 * Ce, 5 * Ce + Cr - 1, Cr, dtype=int)
R2_index = np.linspace(5 * Ce + Cr, 5 * Ce + 2 * Cr - 1, Cr, dtype=int)
Ne = total_neurons
# NOTE(review): Reset is 0, so v is all zeros; v and firings are never read below.
v = Reset * np.zeros(Ne)
firings = []  # spike timings
ADJ = np.zeros((total_time, Ne))  # record the firing condition per time step
abs_Ne = np.zeros(Ne)  # maintain the ABS (refractory state) of every neuron
I_stimu = np.zeros((Ne, total_time), dtype=float)  # external current schedule
If_Memory = np.zeros((total_neurons), dtype=bool)
If_Memory[:] = True
# Pre-set synapses: random weak weights inside every population, zero elsewhere.
S = np.zeros((total_neurons, total_neurons), dtype=float)  # Initial Weights
S = S - np.diag(S)  # zero the diagonal (no self-connections)
E = np.identity(total_neurons, dtype=float)  # one external-current line per neuron
W_set_inner = 0.7  # base strength of intra-population weights
# Entity populations get full strength, relation populations a quarter of it.
# (Replaces seven identical copy-pasted assignment blocks with one loop.)
for idx, scale in [(E1_index, 1.0), (E2_index, 1.0), (E3_index, 1.0),
                   (E4_index, 1.0), (E5_index, 1.0),
                   (R1_index, 0.25), (R2_index, 0.25)]:
    rows = S[idx, :]
    rows[:, idx] = scale * W_set_inner * np.random.rand(len(idx), len(idx))
    S[idx, :] = rows
"""
对于因果图中的因果关系,给予网络不同神经元组输入电流刺激,使其建立连接
"""
i = 1
time = np.linspace(11 + i * 100, 10 + I_t + i * 100, I_t, dtype=int)
temp = I_stimu[E1_index, :]
temp[:, time] = certainty * I_P + I_P * np.random.rand(Ce, I_t)
I_stimu[E1_index, :] = temp
time = np.linspace(21 + i * 100, 20 + I_t + i * 100, I_t, dtype=int)
temp = I_stimu[R1_index, :]
temp[:, time] = certainty * I_P + I_P * np.random.rand(Cr, I_t)
I_stimu[R1_index, :] = temp
time = np.linspace(31 + i * 100, 30 + I_t + i * 100, I_t, dtype=int)
temp = I_stimu[E2_index, :]
temp[:, time] = certainty * I_P + I_P * np.random.rand(Ce, I_t)
I_stimu[E2_index, :] = temp
i = 3
time = np.linspace(11 + i * 100, 10 + I_t + i * 100, I_t, dtype=int)
temp = I_stimu[E2_index, :]
temp[:, time] = certainty * I_P + I_P * np.random.rand(Ce, I_t)
I_stimu[E2_index, :] = temp
time = np.linspace(21 + i * 100, 20 + I_t + i * 100, I_t, dtype=int)
temp = I_stimu[R2_index, :]
temp[:, time] = certainty * I_P + I_P * np.random.rand(Cr, I_t)
I_stimu[R2_index, :] = temp
time = np.linspace(31 + i * 100, 30 + I_t + i * 100, I_t, dtype=int)
temp = I_stimu[E1_index, :]
temp[:, time] = certainty * I_P + I_P * np.random.rand(Ce, I_t)
I_stimu[E1_index, :] = temp
i = 6
time = np.linspace(11 + i * 100, 10 + I_t + i * 100, I_t, dtype=int)
temp = I_stimu[E2_index, :]
temp[:, time] = certainty * I_P + I_P * np.random.rand(Ce, I_t)
I_stimu[E2_index, :] = temp
time = np.linspace(21 + i * 100, 20 + I_t + i * 100, I_t, dtype=int)
temp = I_stimu[R1_index, :]
temp[:, time] = certainty * I_P + I_P * np.random.rand(Cr, I_t)
I_stimu[R1_index, :] = temp
time = np.linspace(31 + i * 100, 30 + I_t + i * 100, I_t, dtype=int)
temp = I_stimu[E3_index, :]
temp[:, time] = certainty * I_P + I_P * np.random.rand(Ce, I_t)
I_stimu[E3_index, :] = temp
i = 8
time = np.linspace(11 + i * 100, 10 + I_t + i * 100, I_t, dtype=int)
temp = I_stimu[E3_index, :]
temp[:, time] = certainty * I_P + I_P * np.random.rand(Ce, I_t)
I_stimu[E3_index, :] = temp
time = np.linspace(21 + i * 100, 20 + I_t + i * 100, I_t, dtype=int)
temp = I_stimu[R2_index, :]
temp[:, time] = certainty * I_P + I_P * np.random.rand(Cr, I_t)
I_stimu[R2_index, :] = temp
time = np.linspace(31 + i * 100, 30 + I_t + i * 100, I_t, dtype=int)
temp = I_stimu[E2_index, :]
temp[:, time] = certainty * I_P + I_P * np.random.rand(Ce, I_t)
I_stimu[E2_index, :] = temp
i = 11
time = np.linspace(11 + i * 100, 10 + I_t + i * 100, I_t, dtype=int)
temp = I_stimu[E2_index, :]
temp[:, time] = certainty * I_P + I_P * np.random.rand(Ce, I_t)
I_stimu[E2_index, :] = temp
time = np.linspace(21 + i * 100, 20 + I_t + i * 100, I_t, dtype=int)
temp = I_stimu[R1_index, :]
temp[:, time] = certainty * I_P + I_P * np.random.rand(Cr, I_t)
I_stimu[R1_index, :] = temp
time = np.linspace(31 + i * 100, 30 + I_t + i * 100, I_t, dtype=int)
temp = I_stimu[E4_index, :]
temp[:, time] = certainty * I_P + I_P * np.random.rand(Ce, I_t)
I_stimu[E4_index, :] = temp
i = 13
time = np.linspace(11 + i * 100, 10 + I_t + i * 100, I_t, dtype=int)
temp = I_stimu[E4_index, :]
temp[:, time] = certainty * I_P + I_P * np.random.rand(Ce, I_t)
I_stimu[E4_index, :] = temp
time = np.linspace(21 + i * 100, 20 + I_t + i * 100, I_t, dtype=int)
temp = I_stimu[R2_index, :]
temp[:, time] = certainty * I_P + I_P * np.random.rand(Cr, I_t)
I_stimu[R2_index, :] = temp
time = np.linspace(31 + i * 100, 30 + I_t + i * 100, I_t, dtype=int)
temp = I_stimu[E2_index, :]
temp[:, time] = certainty * I_P + I_P * np.random.rand(Ce, I_t)
I_stimu[E2_index, :] = temp
i = 16
time = np.linspace(11 + i * 100, 10 + I_t + i * 100, I_t, dtype=int)
temp = I_stimu[E3_index, :]
temp[:, time] = certainty * I_P + I_P * np.random.rand(Ce, I_t)
I_stimu[E3_index, :] = temp
time = np.linspace(21 + i * 100, 20 + I_t + i * 100, I_t, dtype=int)
temp = I_stimu[R1_index, :]
temp[:, time] = certainty * I_P + I_P * np.random.rand(Cr, I_t)
I_stimu[R1_index, :] = temp
time = np.linspace(31 + i * 100, 30 + I_t + i * 100, I_t, dtype=int)
temp = I_stimu[E5_index, :]
temp[:, time] = certainty * I_P + I_P * np.random.rand(Ce, I_t)
I_stimu[E5_index, :] = temp
i = 18
time = np.linspace(11 + i * 100, 10 + I_t + i * 100, I_t, dtype=int)
temp = I_stimu[E5_index, :]
temp[:, time] = certainty * I_P + I_P * np.random.rand(Ce, I_t)
I_stimu[E5_index, :] = temp
time = np.linspace(21 + i * 100, 20 + I_t + i * 100, I_t, dtype=int)
temp = I_stimu[R2_index, :]
temp[:, time] = certainty * I_P + I_P * np.random.rand(Cr, I_t)
I_stimu[R2_index, :] = temp
time = np.linspace(31 + i * 100, 30 + I_t + i * 100, I_t, dtype=int)
temp = I_stimu[E3_index, :]
temp[:, time] = certainty * I_P + I_P * np.random.rand(Ce, I_t)
I_stimu[E3_index, :] = temp
i = 21
time = np.linspace(11 + i * 100, 10 + I_t + i * 100, I_t, dtype=int)
temp = I_stimu[E4_index, :]
temp[:, time] = certainty * I_P + I_P * np.random.rand(Ce, I_t)
I_stimu[E4_index, :] = temp
time = np.linspace(21 + i * 100, 20 + I_t + i * 100, I_t, dtype=int)
temp = I_stimu[R1_index, :]
temp[:, time] = certainty * I_P + I_P * np.random.rand(Cr, I_t)
I_stimu[R1_index, :] = temp
time = np.linspace(31 + i * 100, 30 + I_t + i * 100, I_t, dtype=int)
temp = I_stimu[E5_index, :]
temp[:, time] = certainty * I_P + I_P * np.random.rand(Ce, I_t)
I_stimu[E5_index, :] = temp
i = 23
time = np.linspace(11 + i * 100, 10 + I_t + i * 100, I_t, dtype=int)
temp = I_stimu[E5_index, :]
temp[:, time] = certainty * I_P + I_P * np.random.rand(Ce, I_t)
I_stimu[E5_index, :] = temp
time = np.linspace(21 + i * 100, 20 + I_t + i * 100, I_t, dtype=int)
temp = I_stimu[R2_index, :]
temp[:, time] = certainty * I_P + I_P * np.random.rand(Cr, I_t)
I_stimu[R2_index, :] = temp
time = np.linspace(31 + i * 100, 30 + I_t + i * 100, I_t, dtype=int)
temp = I_stimu[E4_index, :]
temp[:, time] = certainty * I_P + I_P * np.random.rand(Ce, I_t)
I_stimu[E4_index, :] = temp
# Run the simulation: feed the scheduled currents, apply the STDP weight
# update and clamp the weights after every step, then plot the results.
S = torch.tensor(S, dtype=torch.float32)
E = torch.tensor(E, dtype=torch.float32)
CRSNN = CRNet(S, E)
for t in range(total_time):
    I_input = torch.tensor(I_stimu[:, t].reshape(1, total_neurons), dtype=torch.float32)
    x, dw = CRSNN(I_input)
    # In-place update keeps S the same tensor object shared with the network's
    # CustomLinear connection.
    S += A_P * dw[1]
    # S_bound clamps S in place; the original `S += S_bound(S) - S` merely
    # added zero around that in-place clamp, so the wrapper is dropped.
    S_bound(S)
    ADJ[t] = x
plt.matshow(I_stimu)
plt.matshow(ADJ.transpose())
plt.matshow(S)
plt.colorbar()
plt.show()
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/SPSNN/README.md
================================================
# Sequence Production SNN
This repository contains code from our paper [**Brain Inspired Sequences Production by Spiking Neural Networks With Reward-Modulated STDP**] published in Frontiers in Computational Neuroscience. This module builds a Sequence Production spiking neural network model, realizing the memory and reconstruction functions for arbitrary symbol sequences. The input causal graph is shown in figure causal_graph.png. The input current, spike trains during the learning process and the network weight distribution after the learning are shown in the Results folder.
## Requirements
* numpy
* scipy
* pytorch >= 1.7.0
* torchvision
## Run
```shell
python main.py
```
### Citation
If you find this package helpful, please consider citing the following papers:
```BibTex
@article{fang2021spsnn,
title = {Brain inspired sequences production by spiking neural networks with reward-modulated stdp},
author = {Fang, Hongjian and Zeng, Yi and Zhao, Feifei},
journal = {Frontiers in Computational Neuroscience},
volume = {15},
pages = {8},
year = {2021},
publisher = {Frontiers}
}
@misc{https://doi.org/10.48550/arxiv.2207.08533,
doi = {10.48550/ARXIV.2207.08533},
url = {https://arxiv.org/abs/2207.08533},
author = {Zeng, Yi and Zhao, Dongcheng and Zhao, Feifei and Shen, Guobin and Dong, Yiting and Lu, Enmeng and Zhang, Qian and Sun, Yinqian and Liang, Qian and Zhao, Yuxuan and Zhao, Zhuoya and Fang, Hongjian and Wang, Yuwei and Li, Yang and Liu, Xin and Du, Chengcheng and Kong, Qingqun and Ruan, Zizhe and Bi, Weida},
title = {BrainCog: A Spiking Neural Network based Brain-inspired Cognitive Intelligence Engine for Brain-inspired AI and Brain Simulation},
publisher = {arXiv},
year = {2022},
}
```
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/SPSNN/main.py
================================================
import time
import numpy as np
import os
import warnings
import math
from matplotlib import pyplot as plt
import torch
from braincog.base.node.node import *
from braincog.base.brainarea.BrainArea import *
from braincog.utils import *
warnings.filterwarnings('ignore')
np.set_printoptions(threshold=np.inf)
class SPNet(BrainArea):
    """
    Network structure class: SPNet (Sequence Production Net). Defines the
    network topology; inherits from the BrainArea base class.

    :param threshold: firing threshold the neuron membrane potential must reach
    :param tau: membrane time constant controlling potential decay
    :param decay: STDP decay constant controlling how the STDP effect fades over time
    :param w1: recurrent connection weights inside the network
    :param w2: connection from the external input current to each neuron
    """

    def __init__(self, w1, w2 ):
        """
        Build one LIF population, two linear connections (recurrent + external
        input) and a multi-input STDP rule over both connections.
        """
        super().__init__()
        self.node = [LIFNode(threshold=10,tau=15) ]
        self.connection = [CustomLinear(w1), CustomLinear(w2) ]
        self.stdp = []
        self.stdp.append(MutliInputSTDP(self.node[0], [self.connection[0], self.connection[1]],decay=0.745))
        # x1 holds the spike output of the previous time step (recurrent input)
        self.x1 = torch.zeros(1, w2.shape[0])

    def forward(self, x):
        """
        One simulation time step: propagate input, produce spikes and the STDP
        weight change.

        :param x: external input current for this time step
        :return: (x1, dw1) -- the spike pattern after this step and the weight
            change produced by STDP in this step
        """
        # NOTE(review): arguments are (recurrent spikes, external input) here,
        # the reverse of the CRSNN variant -- confirm against MutliInputSTDP.
        self.x1, dw1 = self.stdp[0]( self.x1,x)
        return self.x1, dw1

    def reset(self):
        # Clear the stored spike state in place.
        self.x1 *= 0
def S_bound(S):
    """
    Clamp the network weight matrix in place and return it.

    Keeps every synapse inside [-synapse_bound, synapse_bound] so the whole
    network stays weakly connected, and additionally caps the weights *inside*
    each stimulated neuron group so a group cannot keep re-exciting itself.

    Reads module-level globals: synapse_bound, inner_bound and the group index
    arrays l1_stimu..l4_stimu.  The original repeated the same clamp block per
    group; this rewrite performs the identical clamping with one loop.

    :param S: square weight matrix (numpy array or torch tensor)
    :return: S, clamped in place
    """
    # Global weak-connection clamp (in place).
    S[S > synapse_bound] = synapse_bound
    S[S < -synapse_bound] = -synapse_bound
    # Per-group inner clamp.
    for idx in (l1_stimu, l2_stimu, l3_stimu, l4_stimu):
        # Fancy indexing returns copies, so clamp the sub-block and write back.
        rows = S[idx, :]
        inner = rows[:, idx]
        inner[inner > inner_bound] = inner_bound
        rows[:, idx] = inner
        S[idx, :] = rows
    return S
if __name__ == "__main__":
## Neurons Parameter  (trailing semicolons removed; comments added)
C = 50  # number of neurons encoding one symbol
runtime = 1000  # Runtime in ms
thresh = 30  # Judge if the neurons fire or not
I_syn = 5
tau_m = 30
Rm = 10
learning_times = 3  # how many times the training sequence is presented
Sym_size = 6  # number of distinct symbols per layer
I_t = 5  # Time duration of stimu current
I_P = 130  # Strength of input current
A_P = 0.024  # scaling applied to each STDP weight change
certainty = 0.35  # baseline fraction of the input current
synapse_bound = 10  # bound on every synapse (keeps the net weakly connected)
inner_bound = 1  # bound on population inner synapses
# NOTE(review): thresh, I_syn, tau_m and Rm are never read below -- presumably
# leftovers from an earlier version; confirm before removing.
# SPSNN main script: build the network, present the training sequence
# `learning_times` times, then cue recall with the first group at t = 700 ms.
# (The original Chinese block comment describing these variables was converted
# into the English notes here.)
# Initial Neural Network: memory part (Net1) and action part (Net2) layouts.
Net1 = [C, Sym_size * C, Sym_size * C, Sym_size * C, 1]  # memory
Net2 = [Sym_size, Sym_size, Sym_size]  # action
current_end = 0
total_neurons = int(sum(Net1) + sum(Net2))
index1 = np.linspace(1, sum(Net1), sum(Net1), dtype=int) - 1  # memory indices
index2 = np.linspace(sum(Net1) + 1, total_neurons, sum(Net2), dtype=int) - 1  # action indices
Ne = total_neurons
firings = []  # spike timings
ADJ = np.zeros((runtime, Ne))  # record the firing condition per time step
abs_Ne = np.zeros(Ne)  # maintain the ABS of every neuron
I_stimu = np.zeros((Ne, runtime))
P = np.zeros((runtime, 3))  # potential of neuron 1
# logical vector: True where the neuron belongs to the memory part
If_Memory = np.zeros(total_neurons, dtype=bool)
If_Memory[index1] = True
# logical vector: True where the neuron belongs to the action part
If_Action = np.zeros(total_neurons, dtype=bool)
If_Action[index2] = True
# Pre-set synapses
S = np.zeros((total_neurons, total_neurons), dtype=float)  # Initial Weights
S = S - np.diag(S)  # zero the diagonal (no self-connections)
E = np.identity(total_neurons, dtype=float)  # one external-current line per neuron
W_r2a = 0.3  # memory-to-action projection weight
# Memory to Action: each memory group projects onto its action neuron.
for i in range(C, sum(Net1) - 2):
    S[int(index2[int(i / C) - 1]), i] = W_r2a
# Learning Process: index ranges of the populations stimulated in sequence
# (start group, then the three symbols of seq, then the end neuron).
I_stimu = np.zeros((Ne, runtime))
seq = np.array([6, 3, 4])
l1_stimu = np.arange(0, C)
l2_stimu = np.arange(C + (seq[0] - 1) * C, C + (seq[0]) * C)
l3_stimu = np.arange(C + 6 * C + (seq[1] - 1) * C, C + 6 * C + (seq[1]) * C)
l4_stimu = np.arange(C + 12 * C + (seq[2] - 1) * C, C + 12 * C + (seq[2]) * C)
l5_stimu = Ne - 1
# (a stray, side-effect-free `np.linspace(...)` expression was removed here)
# Present the training sequence once per 100 ms epoch: the start group, the
# three symbol groups and the end neuron receive noisy pulses 15 ms apart.
# (An unused `temp = I_stimu[l1_stimu,:]` copy was removed.)
for i in range(learning_times):
    I_stimu[l1_stimu, 10 + i * 100:10 + I_t + i * 100] = certainty * I_P + I_P * np.random.rand(C, I_t)
    I_stimu[l2_stimu, 25 + i * 100:25 + I_t + i * 100] = certainty * I_P + I_P * np.random.rand(C, I_t)
    I_stimu[l3_stimu, 40 + i * 100:40 + I_t + i * 100] = certainty * I_P + I_P * np.random.rand(C, I_t)
    I_stimu[l4_stimu, 55 + i * 100:55 + I_t + i * 100] = certainty * I_P + I_P * np.random.rand(C, I_t)
    I_stimu[l5_stimu, 70 + i * 100:70 + I_t + i * 100] = certainty * I_P + I_P * np.random.rand(1, I_t)
# Recall test: after learning, cue only the first group at t = 700 ms.
I_stimu[l1_stimu, 700:700 + I_t] = I_P * np.random.rand(C, I_t)
# Run the simulation: feed the scheduled currents, apply the STDP weight
# update and clamp the weights after every step, then plot the results.
S = torch.tensor(S, dtype=torch.float32)
E = torch.tensor(E, dtype=torch.float32)
SPSNN = SPNet(S, E)
for t in range(runtime):
    I_input = torch.tensor(I_stimu[:, t].reshape(1, total_neurons), dtype=torch.float32)
    x, dw = SPSNN(I_input)
    # In-place update keeps S the same tensor object shared with the network.
    S += A_P * dw[1]
    # S_bound clamps S in place; the original `S += S_bound(S) - S` merely
    # added zero around that in-place clamp, so the wrapper is dropped.
    S_bound(S)
    ADJ[t] = x
plt.matshow(I_stimu)
plt.matshow(ADJ)
plt.matshow(S)
plt.colorbar()
plt.show()
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/Areas/apac.py
================================================
'''
Created on 2016.7.7
@author: liangqian
'''
from Modal.note import Note
from Modal.cluster import Cluster
from conf.conf import configs
from Modal.pitch import Pitch
class APAC():
    """
    Anterior primary auditory cortex: encodes musical notes as Pitch objects.
    """

    def __init__(self):
        """Start with an empty list of encoded notes."""
        self.notes = []

    def encodingNote(self, NoteID):
        """
        Encode NoteID as a Pitch, remember it in self.notes and return it.

        The note name is looked up in configs.notesMap by numeric frequency.
        """
        pitch = Pitch()
        pitch.frequence = int(NoteID)
        pitch.name = configs.notesMap.get(pitch.frequence)
        self.notes.append(pitch)
        return pitch

    def encodingMIDINote(self, p):
        """Fill in the name of an existing pitch from its frequency."""
        p.name = configs.notesMap.get(p.frequence)
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/Areas/cortex.py
================================================
'''
Created on 2016.7.6
@author: liangqian
'''
from Areas.pfc import PFC
from Areas.pac import PAC
from conf.conf import *
from Modal.synapse import Synapse
import random
import numpy as np
from Areas.apac import APAC
from Areas.pac import Music_Sequence_Mem
class Cortex():
'''
This class is used to control areas in the cortex, just cortex controlling
'''
def __init__(self, neutype, dt):
    """
    Controller for the cortical areas (music sequence memory + PFC).

    :param neutype: neuron model type forwarded to the sub-areas
    :param dt: simulation time step
    """
    # NOTE(review): other methods of this class also use self.apac, self.ips
    # and self.asm, which are not created here -- confirm they are assigned
    # elsewhere before those methods are called.
    self.neutype = neutype
    self.msm = Music_Sequence_Mem(neutype)
    self.pfc = PFC(self.neutype)
    self.dt = dt
def addSubGoalToPFC(self, goalname):
    """Register a new sub-goal (e.g. a piece title) in the prefrontal cortex."""
    self.pfc.addNewSubGoal(goalname)
def addComposerToPFC(self, composername):
    """Register a new composer group in the prefrontal cortex."""
    self.pfc.addNewComposer(composername)
def addGenreToPFC(self, genrename):
    """Register a new genre group in the prefrontal cortex."""
    self.pfc.addNewGenre(genrename)
def musicSequenceMemroyInit(self):
    """
    Create the first sequence-memory layer in the music sequence memory.

    NOTE(review): the method name misspells "Memory" ("Memroy"); it is kept
    unchanged for backward compatibility with existing callers.
    """
    self.msm.createActionSequenceMem(1, self.neutype)
def rememberANote(self, goalname, noteName, order):
    """
    Store one note of a melody at position `order` and bind it to the goal.

    :param goalname: name of the goal (piece) group in self.ips.goals
    :param noteName: note identifier forwarded to APAC encoding
    :param order: 1-based position of the note in the sequence
    :return: dict of spike/neuron info for visualization (only filled when
        configs.flag_experiments is False)
    """
    # NOTE(review): uses self.apac and self.ips, which __init__ does not
    # create -- confirm they are assigned elsewhere before this is called.
    note = self.apac.encodingNote(noteName)
    # Grow the first sequence layer until it has a group for this position.
    if (order > len(self.msm.sequenceLayers.get(1).groups)):
        self.msm.sequenceLayers.get(1).addNewGroups(GroupID=order, layerID=1, neunum=128)
    # Each note owns a 5 ms simulation window: [(order-1)*5, order*5).
    tt = np.arange((order - 1) * 5, order * 5, self.dt)
    for t in tt:
        self.ips.doRemebering(goalname, self.dt, t)
        self.msm.doRemembering_note_only(note, order, self.dt, t)
    self.msm.doConnectToGoal(self.ips.goals.groups.get(goalname), order)
    dic = {}
    if (configs.flag_experiments == False):
        dic["GoalSpike"] = self.ips.goals.groups.get(goalname).writeSpikeInfoToJson()
        dic["MSMSpike"] = self.msm.sequenceLayers.get(1).groups.get(order).writeSpikeInfoToJson()
        dic["Neuron"] = self.msm.sequenceLayers.get(1).groups.get(order).writeSelfInfoToJson("MSM")
        dic["GroupNum"] = order
    return dic
def connectKeyAndNotesUsingKSModel(self,keys, Notes):
    """
    Connect key (tonality) neurons to note neurons as prior knowledge,
    intended to follow the Krumhansl-Schmuckler pitch-profile idea.

    NOTE(review): this definition is shadowed by an identical redefinition of
    the same name later in this class; only the later one takes effect.

    :param keys: key area whose groups hold 12 tone neurons each
    :param Notes: note group whose neurons (after the first) map onto the
        12 pitch classes cyclically
    """
    noteNeurons = Notes.neurons
    for tone in keys.groups.values():
        tneurons = tone.neurons
        # Skip the first note neuron; map the rest onto the 12 pitch classes.
        for nindex, noteneu in enumerate(noteNeurons[1:]):
            i = nindex%12
            syn = Synapse(tneurons[i],noteneu) # from tone to pitch
            if(tneurons[i].importance < 0):
                # Unimportant pitch class for this key: strong inhibition.
                syn.excitability = 0
                syn.weight = -100 # should use the KS pitch profile here; to be revised
            else:
                syn.excitability = 1
                syn.weight = 20+random.uniform(20,30)
            syn.type = 3
            noteneu.synapses.append(syn)
            noteneu.pre_neurons.append(tneurons[i])
            # A reverse pitch-to-tone synapse was tried and left disabled:
            # syn1 = Synapse(noteneu, tneurons[i]) # from pitch to tone
            # syn1.excitability = 1
            # syn1.type = 3
            # tneurons[i].synapses.append(syn1)
            # tneurons[i].pre_neurons.append(noteneu)
def connectKeyAndNotes(self,keys, Notes):
    """
    Connect key (tonality) neurons to note neurons with fixed weights:
    inhibitory for unimportant pitch classes, randomly strong excitatory
    otherwise.

    :param keys: key area whose groups hold 12 tone neurons each
    :param Notes: note group whose neurons (after the first) map onto the
        12 pitch classes cyclically
    """
    noteNeurons = Notes.neurons
    for tone in keys.groups.values():
        tneurons = tone.neurons
        # Skip the first note neuron; map the rest onto the 12 pitch classes.
        for nindex, noteneu in enumerate(noteNeurons[1:]):
            i = nindex%12
            syn = Synapse(tneurons[i],noteneu)
            if(tneurons[i].importance < 0):
                # Unimportant pitch class for this key: strong inhibition.
                syn.excitability = 0
                syn.weight = -100
            else:
                syn.excitability = 1
                syn.weight = 20+random.uniform(20,30)
            syn.type = 3
            noteneu.synapses.append(syn)
            noteneu.pre_neurons.append(tneurons[i])
def connectKeyAndNotesUsingKSModel(self,keys, Notes):
    """
    Connect key (tonality) neurons to note neurons as prior knowledge,
    intended to follow the Krumhansl-Schmuckler pitch-profile idea.

    NOTE(review): this is a byte-identical duplicate of an earlier method of
    the same name (this later definition is the effective one); consider
    removing one of the two.

    :param keys: key area whose groups hold 12 tone neurons each
    :param Notes: note group whose neurons (after the first) map onto the
        12 pitch classes cyclically
    """
    noteNeurons = Notes.neurons
    for tone in keys.groups.values():
        tneurons = tone.neurons
        # Skip the first note neuron; map the rest onto the 12 pitch classes.
        for nindex, noteneu in enumerate(noteNeurons[1:]):
            i = nindex%12
            syn = Synapse(tneurons[i],noteneu) # from tone to pitch
            if(tneurons[i].importance < 0):
                # Unimportant pitch class for this key: strong inhibition.
                syn.excitability = 0
                syn.weight = -100 # should use the KS pitch profile here; to be revised
            else:
                syn.excitability = 1
                syn.weight = 20+random.uniform(20,30)
            syn.type = 3
            noteneu.synapses.append(syn)
            noteneu.pre_neurons.append(tneurons[i])
            # A reverse pitch-to-tone synapse was tried and left disabled:
            # syn1 = Synapse(noteneu, tneurons[i]) # from pitch to tone
            # syn1.excitability = 1
            # syn1.type = 3
            # tneurons[i].synapses.append(syn1)
            # tneurons[i].pre_neurons.append(noteneu)
def rememberANoteWithKnowledge(self,goalname, composername, genrename, emo, keyName, trackIndex, noteIndex, tinterval,order, xmldata):
    """
    Store one note together with its context knowledge (title, composer,
    genre, key and mode) and wire the new note group to all of those groups.

    :param goalname: piece title group in self.pfc.goals
    :param composername: composer group in self.pfc.composers
    :param genrename: genre group in self.pfc.genres
    :param emo: emotion label -- not used here (amygdala code is disabled)
    :param keyName: key identifier; mode is derived as keyName % 2
    :param trackIndex: instrument track index into msm.sequenceLayers
    :param noteIndex: note identifier within the track
    :param tinterval: note duration forwarded to the memory area
    :param order: 1-based position of the note in the sequence
    :param xmldata: unused here -- TODO confirm it can be dropped
    :return: dict of spike info for visualization
    """
    instrumentTrack = self.msm.sequenceLayers.get(trackIndex)
    # Grow note ("N") and tempo ("T") groups until this position exists, and
    # bias the new note group with the key prior (KS model).
    if (order > len(instrumentTrack.get("N").groups)):
        instrumentTrack.get("N").addNewGroups(GroupID=order, layerID=trackIndex, neunum=129)
        instrumentTrack.get("T").addNewGroups(GroupID=order, layerID=trackIndex, neunum=64)
        self.connectKeyAndNotesUsingKSModel(self.pfc.keys,instrumentTrack.get("N").groups.get(order))# prior knowledge
    # Each note owns a 5 ms simulation window: [(order-1)*5, order*5).
    tt = np.arange((order - 1) * 5, order * 5, self.dt)
    for t in tt:
        self.pfc.doRemebering(goalname, self.dt, t)
        self.pfc.doRememberingComposer(composername, self.dt, t)
        self.pfc.doRememberingGenre(genrename, self.dt, t)
        self.pfc.doRememberingKey(keyName,self.dt,t)
        self.pfc.doRememberingMode(keyName%2, self.dt,t)
        self.msm.doRemembering(trackIndex, noteIndex, order, self.dt, t, tinterval)
    # Bind the note/tempo groups to every context group that was co-active.
    self.msm.doConnectToTitle(self.pfc.goals.groups.get(goalname), instrumentTrack, order)
    self.msm.doConnectToComposer(self.pfc.composers.groups.get(composername), instrumentTrack, order)
    self.msm.doConnectToGenre(self.pfc.genres.groups.get(genrename), instrumentTrack, order)
    self.msm.doConnectToKey(self.pfc.keys.groups.get(keyName), instrumentTrack, order, noteIndex)
    self.msm.doConnectToMode(self.pfc.modes.groups.get(keyName%2), keyName, instrumentTrack, order, noteIndex)
    dic = {}
    dic["GoalSpike"] = self.pfc.goals.groups.get(goalname).writeSpikeInfoToJson()
    dic["ComposerSpike"] = self.pfc.composers.groups.get(composername).writeSpikeInfoToJson()
    #dic["EmotionSpike"] = self.amy.emos.groups.get(emo).writeSpikeInfoToJson()
    dic["KeySpike"] = self.pfc.keys.groups.get(keyName).writeSpikeInfoToJson()
    dic["ModeSpike"] = self.pfc.modes.groups.get(keyName%2).writeSpikeInfoToJson()
    dic["MSMNSpike"] = instrumentTrack.get("N").groups.get(order).writeSpikeInfoToJson()
    dic["MSMTSpike"] = instrumentTrack.get("T").groups.get(order).writeSpikeInfoToJson()
    return(dic)
def rememberANoteandTempo(self, goalname, composername, genrename, trackIndex, noteIndex, order, tinterval):
    """
    Store one note plus its tempo, bind them to title/composer/genre, and
    optionally build a node/edge graph of the new groups for visualization.

    :param goalname: piece title group in self.pfc.goals
    :param composername: composer group in self.pfc.composers
    :param genrename: genre group in self.pfc.genres
    :param trackIndex: instrument track index into msm.sequenceLayers
    :param noteIndex: note identifier within the track
    :param order: 1-based position of the note in the sequence
    :param tinterval: note duration forwarded to the memory area
    :return: (dic, ngraph) -- spike info and a {"node": [...], "edge": [...]}
        graph; both are only filled when configs.RunTimeState == 1
    """
    instrumentTrack = self.msm.sequenceLayers.get(trackIndex)
    # Grow note ("N") and tempo ("T") groups until this position exists.
    if (order > len(instrumentTrack.get("N").groups)):
        instrumentTrack.get("N").addNewGroups(GroupID=order, layerID=trackIndex, neunum=129)
        instrumentTrack.get("T").addNewGroups(GroupID=order, layerID=trackIndex, neunum=64)
    # Each note owns a 5 ms simulation window: [(order-1)*5, order*5).
    tt = np.arange((order - 1) * 5, order * 5, self.dt)
    for t in tt:
        self.pfc.doRemebering(goalname, self.dt, t)
        self.pfc.doRememberingComposer(composername, self.dt, t)
        self.pfc.doRememberingGenre(genrename, self.dt, t)
        self.msm.doRemembering(trackIndex, noteIndex, order, self.dt, t, tinterval)
    self.msm.doConnectToTitle(self.pfc.goals.groups.get(goalname), instrumentTrack, order)
    self.msm.doConnectToComposer(self.pfc.composers.groups.get(composername), instrumentTrack, order)
    self.msm.doConnectToGenre(self.pfc.genres.groups.get(genrename), instrumentTrack, order)
    dic = {}
    ngraph = {}
    if (configs.RunTimeState == 1):
        dic["GoalSpike"] = self.pfc.goals.groups.get(goalname).writeSpikeInfoToJson()
        dic["ComposerSpike"] = self.pfc.composers.groups.get(composername).writeSpikeInfoToJson()
        dic["MSMSpike"] = instrumentTrack.get("N").groups.get(order).writeSpikeInfoToJson()
        dic["MSMTSpike"] = instrumentTrack.get("T").groups.get(order).writeSpikeInfoToJson()
        temp = {}
        temp[1] = instrumentTrack.get("N").groups.get(order).writeSelfInfoToJson("NMSM")
        temp[2] = instrumentTrack.get("T").groups.get(order).writeSelfInfoToJson("TMSM")
        # dic["GroupNum"] = order
        # Build a flat node/edge graph from both groups' neuron dumps.
        Nodes = []
        Edges = []
        for key, td in temp.items():
            nlist = td.get("Neuron")
            for n in nlist:
                # if(len(n.get('synapses')) > 0):
                d = {}
                # Node id: area_track_group_index.
                d['id'] = n.get('area') + '_' + str(n.get('TrackID')) + '_' + str(n.get('GroupID')) + '_' + str(
                    n.get('Index'))
                d['area'] = n.get('area')
                if (n.get('area') == 'NMSM'):
                    # Note node: label with the note name (Index offset by 2).
                    d['label'] = configs.notesMap.get(n.get('Index') - 2)
                else:
                    # Tempo node: label with the duration in ms.
                    d['label'] = str(n.get('Index') * 60)
                Nodes.append(d)
                synlist = n.get('synapses')
                for syn in synlist:
                    e = {}
                    # e['id'] = syn.get('Sarea') + '_'+str(syn.get('SgroupID'))+'_'+str(syn.get('Sindex')) + '_' +syn.get('Tarea') + '_'+str(syn.get('TgroupID')) + '_'+str(syn.get('Tindex'))
                    e['weight'] = str(syn.get('weight'))
                    e['source'] = syn.get('Sarea') + '_' + str(syn.get('StrackID')) + '_' + str(
                        syn.get('SgroupID')) + '_' + str(syn.get('Sindex'))
                    e['target'] = syn.get('Tarea') + '_' + str(syn.get('TtrackID')) + '_' + str(
                        syn.get('TgroupID')) + '_' + str(syn.get('Tindex'))
                    Edges.append(e)
        ngraph["node"] = Nodes
        ngraph["edge"] = Edges
    #print(dic)
    return dic, ngraph
def actionSequenceMemoryInit(self):
    """
    Create the first action-sequence-memory layer with 16 units.

    NOTE(review): uses self.asm, which __init__ does not create -- confirm it
    is assigned elsewhere before this is called.
    """
    self.asm.createActionSequenceMem(1, self.neutype, 16)
def recallMusicPFC(self, goalName):
    """Recall a stored piece by its title through the PFC goal population."""
    # Switch both areas into test (recall) mode before querying.
    self.pfc.setTestStates()
    self.msm.setTestStates()
    return self.pfc.doRecalling2(goalName, self.msm)
def recallMusicByEpisode(self, episodeNotes):  # using time window search episode
    """
    Recall a stored piece from a fragment (episode) of its notes, using a
    time-window search over the sequence memory.

    :param episodeNotes: the note fragment to search for
    :return: recall result produced by msm.recallByEpisode2
    """
    self.pfc.setTestStates()
    self.msm.setTestStates()
    # Debug helper left from development (prints pre-active neurons per group):
    # sl = self.msm.sequenceLayers.get(1)
    # for index,group in sl.groups.items():
    #     strs = "group_"+str(index)+":"
    #     for n in group.neurons:
    #         if(n.preActive == True):
    #             strs +=" neu_Index:"+str(n.index)+","
    #     print(strs)
    result = self.msm.recallByEpisode2(episodeNotes, self.pfc.goals)
    return result
def generateEx_Nihilo(self, firstNote, durations, length):
    '''
    Generate the main melody (a single track) ex nihilo from a seed note.

    For each position i, the memory area is driven for a 5 ms window and the
    note ('N') and tempo ('T') for that position are chosen as the pre-active
    neuron with the highest firing rate; if none fired, a random pre-active
    neuron is stimulated and used instead.

    Fixes over the original: the max-rate winner is now actually used (the
    original checked `dic.get(...)`, which was never set because the in-loop
    assignment was commented out, so the max-rate branch was unreachable and
    the 'T' branch even read the loop-leftover `neu` instead of `maxneu`);
    the unguarded debug `print(maxneu.I)` was removed.

    :param firstNote: seed note used to drive the first group
    :param durations: note-duration information forwarded to the memory area
    :param length: number of notes to generate
    :return: {1: [{'N': note_selectivity, 'T': tempo_selectivity}, ...]}
    '''
    self.pfc.setTestStates()
    self.msm.setTestStates()
    result = {}
    track1 = []
    for i in range(length):
        dic = {}
        tt = np.arange(i * 5, (i + 1) * 5, self.dt)
        for t in tt:
            self.pfc.inhibiteGoals(self.dt, t)
            self.msm.generateEx_Nihilo(firstNote, durations, i, self.dt, t)
        layer = self.msm.sequenceLayers.get(1)
        dic['N'] = self._pick_selectivity(layer.get("N").groups.get(i + 1).neurons, tt)
        dic['T'] = self._pick_selectivity(layer.get("T").groups.get(i + 1).neurons, tt)
        track1.append(dic)
    result[1] = track1
    return result

def _pick_selectivity(self, neurons, tt):
    '''
    Return the selectivity of the pre-active neuron with the highest firing
    rate; if no pre-active neuron fired, stimulate a random pre-active one
    over the window tt and return its selectivity.
    '''
    candidates = []
    maxneu = None
    maxrate = 0
    for neu in neurons:
        if (neu.preActive == True):
            candidates.append(neu)
            if (len(neu.spiketime) > maxrate):
                maxrate = len(neu.spiketime)
                maxneu = neu
    if (maxneu is not None):
        return maxneu.selectivity
    # No candidate fired: drive a random pre-active neuron and take it.
    neu = candidates[random.randint(0, len(candidates) - 1)]
    neu.I = 20
    for t in tt:
        neu.update_normal(self.dt, t)
    return neu.selectivity
def generateEx_Nihilo2(self, firstNote, durations, length):
    '''
    Generate the main melody (one track) ex nihilo from a seed note.

    Variant of generateEx_Nihilo: for each position it takes the selectivity
    of the last spiking pre-active neuron instead of the max-rate one; if no
    neuron fired, a random pre-active neuron is stimulated and used.

    :param firstNote: seed note used to drive the first group
    :param durations: note-duration information forwarded to the memory area
    :param length: number of notes to generate
    :return: {1: [{'N': note_selectivity, 'T': tempo_selectivity}, ...]}
    '''
    self.pfc.setTestStates()
    self.msm.setTestStates()
    result = {}
    track1 = []
    for i in range(length):
        dic = {}
        # Each position owns a 5 ms simulation window.
        tt = np.arange(i * 5, (i + 1) * 5, self.dt)
        for t in tt:
            self.pfc.inhibiteGoals(self.dt, t)
            self.msm.generateEx_Nihilo(firstNote, durations, i, self.dt, t)
        # Choose the note: any spiking pre-active neuron sets dic['N'].
        panneu = []
        for ni, neu in enumerate(self.msm.sequenceLayers.get(1).get("N").groups.get(i + 1).neurons):
            if (neu.preActive == True):
                panneu.append(neu)
                if (len(neu.spiketime) > 0):
                    dic['N'] = neu.selectivity
        if (dic.get('N') == None):
            # None fired: drive a random pre-active neuron and use it.
            j = random.randint(0, len(panneu) - 1)
            neu = panneu[j]
            neu.I = 20
            for t in tt:
                neu.update_normal(self.dt, t)
            dic['N'] = neu.selectivity
        # Choose the tempo the same way.
        patneu = []
        for neu in self.msm.sequenceLayers.get(1).get("T").groups.get(i + 1).neurons:
            if (neu.preActive == True): patneu.append(neu)
            if (len(neu.spiketime) > 0):
                dic['T'] = neu.selectivity
        if (dic.get('T') == None):
            j = random.randint(0, len(patneu) - 1)
            neu = patneu[j]
            neu.I = 20
            for t in tt:
                neu.update_normal(self.dt, t)
            dic['T'] = neu.selectivity
        track1.append(dic)
    result[1] = track1
    return result
    def generateEx_NihiloAccordingToGenre(self, genreName, firstNote, durations, length):
        '''
        Generate a single-track melody ex nihilo while keeping the genre's
        PFC neurons active, biasing generation towards that genre.

        :param genreName: genre whose PFC group is stimulated (title-cased first)
        :param firstNote: seed pitches used to trigger the following notes
        :param durations: seed durations matching ``firstNote``
        :param length: number of notes to generate
        :return: dict {1: track}; track is a list of {'N': pitch, 'T': duration}
        '''
        genreName = genreName.title()
        self.pfc.setTestStates()
        self.msm.setTestStates()
        result = {}
        track1 = []
        for i in range(length):
            dic = {}
            tt = np.arange(i * 5, (i + 1) * 5, self.dt)
            for t in tt:
                # inhibit competing context, keep only the genre active
                self.pfc.inhibiteGoals(self.dt, t)
                self.pfc.inhibitComposers(self.dt, t)
                self.pfc.doRememberingGenre(genreName, self.dt, t)
                self.msm.generateEx_Nihilo(firstNote, durations, i, self.dt, t)
            # any pre-activated pitch neuron that spiked sets dic['N']
            # (later spiking neurons overwrite earlier ones)
            panneu = []
            for ni, neu in enumerate(self.msm.sequenceLayers.get(1).get("N").groups.get(i + 1).neurons):
                if (neu.preActive == True):
                    panneu.append(neu)
                    if (len(neu.spiketime) > 0):
                        dic['N'] = neu.selectivity
            if (dic.get('N') == None):
                # nothing spiked: force a randomly chosen pre-activated neuron
                j = random.randint(0, len(panneu) - 1)
                neu = panneu[j]
                neu.I = 20
                for t in tt:
                    neu.update_normal(self.dt, t)
                dic['N'] = neu.selectivity
            # same procedure for the duration (tempo) neurons
            patneu = []
            for neu in self.msm.sequenceLayers.get(1).get("T").groups.get(i + 1).neurons:
                if (neu.preActive == True): patneu.append(neu)
                if (len(neu.spiketime) > 0):
                    dic['T'] = neu.selectivity
            if (dic.get('T') == None):
                j = random.randint(0, len(patneu) - 1)
                neu = patneu[j]
                neu.I = 20
                for t in tt:
                    neu.update_normal(self.dt, t)
                dic['T'] = neu.selectivity
            track1.append(dic)
        result[1] = track1
        return result
def generateEx_NihiloAccordingToComposer(self, composerName, firstNote, durations, length):
composerName = composerName.title()
self.pfc.setTestStates()
self.msm.setTestStates()
result = {}
track1 = []
for i in range(length):
dic = {}
tt = np.arange(i * 5, (i + 1) * 5, self.dt)
for t in tt:
self.pfc.inhibiteGoals(self.dt, t)
self.pfc.doRememberingComposer(composerName, self.dt, t)
self.msm.generateEx_Nihilo(firstNote, durations, i, self.dt, t)
panneu = []
maxrate = 0.0
maxneu = None
for ni, neu in enumerate(self.msm.sequenceLayers.get(1).get("N").groups.get(i + 1).neurons):
if (neu.preActive == True):
panneu.append(neu)
if (len(neu.spiketime) > 0):
# dic['N'] = neu.selectivity
#print(str(neu.selectivity) + ":" + str(len(neu.spiketime)))
if (len(neu.spiketime) > maxrate):
maxrate = len(neu.spiketime)
maxneu = neu
if (dic.get('N') == None):
j = random.randint(0, len(panneu) - 1)
neu = panneu[j]
neu.I = 20
for t in tt:
neu.update_normal(self.dt, t)
dic['N'] = neu.selectivity
else: # chose the neuron which has the max firing rate
dic['N'] = maxneu.selectivity
patneu = []
maxrate = 0.0
maxneu = None
for neu in self.msm.sequenceLayers.get(1).get("T").groups.get(i + 1).neurons:
if (neu.preActive == True): patneu.append(neu)
if (len(neu.spiketime) > 0):
if (len(neu.spiketime) > maxrate):
maxrate = len(neu.spiketime)
maxneu = neu
if (dic.get('T') == None):
j = random.randint(0, len(patneu) - 1)
neu = patneu[j]
neu.I = 20
for t in tt:
neu.update_normal(self.dt, t)
dic['T'] = neu.selectivity
else:
dic['T'] = neu.selectivity
track1.append(dic)
result[1] = track1
#print(result)
return result
def generateMelodyWithKey(self, key, firstNotes, durations, length):
result = {}
self.pfc.setTestStates()
self.msm.setTestStates()
row,col = firstNotes.shape
for i in range(length):
time = np.arange(i*5,(i+1)*5,self.dt)
for t in time:
self.pfc.inhibiteGoals(self.dt,t)
self.pfc.inhibitComposers(self.dt,t)
self.pfc.inhibitGenres(self.dt,t)
self.pfc.doRememberingKey(key,self.dt,t)
self.msm.generateMelodyWithTone(firstNotes[:,i] if i < col else None, durations[:,i] if durations else None,key,i+1,self.dt,t)
# find the max firing rate
for j,sl in self.msm.sequenceLayers.items():
if j > 4: break;
#print("***********************this is part " + str(j) + "****************************")
dic = {}
if i < col:
dic["N"] = firstNotes[j-1][i]
dic["T"] = durations[j-1][i] if durations else 1.0
else:
maxrate = 0
for nn in sl.get("N").groups.get(i+1).neurons:
if nn.preActive:
# print("------order: "+str(i+1)+"-------")
# print(nn.selectivity)
# print(nn.I)
# print(nn.I_upper)
# print(nn.I_lower)
# print(len(nn.spiketime))
if len(nn.spiketime) > maxrate:
maxrate = len(nn.spiketime)
#print("top1:"+str(nn.selectivity))
dic["N"] = nn.selectivity
if durations is not None:
dic["T"] = 1
else:
maxrate = 0
for tn in sl.get("T").groups.get(i + 1).neurons:
if tn.preActive:
if len(tn.spiketime) > maxrate:
maxrate = len(tn.spiketime)
#print(tn.selectivity)
dic["N"] = nn.selectivity
if dic.get("T") is None:
dic["T"] = 1.0
if result.get(j) is None:
part = []
part.append(dic)
result[j] = part
else:
result.get(j).append(dic)
return result
    def recallActionIPS(self, goalName):
        '''Recall the action sequence associated with a goal (title) from asm.'''
        self.pfc.setTestStates()
        # NOTE(review): setTestStates() is called twice on self.pfc — the
        # second call looks like a copy-paste slip and was probably meant for
        # another area (e.g. self.asm); confirm before changing.
        self.pfc.setTestStates()
        self.pfc.doRecalling(goalName, self.asm)
    def generate2TrackMusic(self, firstNotes, durations, lengths):
        '''
        Generate multi-track music: one melody per entry of ``firstNotes``.

        :param firstNotes: dict {track-index: seed pitches}
        :param durations: dict {track-index: seed durations}
        :param lengths: sequence of note counts, indexed by track-index - 1
        :return: dict {track-index: [{'N': pitch, 'T': duration}, ...]}
        '''
        self.pfc.setTestStates()
        self.msm.setTestStates()
        result = {}
        for k, notes in firstNotes.items():
            track1 = []
            for i in range(lengths[k - 1]):
                dic = {}
                tt = np.arange(i * 5, (i + 1) * 5, self.dt)
                for t in tt:
                    self.pfc.inhibiteGoals(self.dt, t)
                    # self.msm.generateSimgleTrackNotes(j+1, firstNotes[j], durations[j], i, self.dt, t)
                    self.msm.generateSimgleTrackNotes(k, notes, durations.get(k), i, self.dt, t)
                # any pre-activated pitch neuron that spiked sets dic['N']
                # (later spiking neurons overwrite earlier ones)
                panneu = []
                for ni, neu in enumerate(self.msm.sequenceLayers.get(k).get("N").groups.get(i + 1).neurons):
                    if (neu.preActive == True):
                        panneu.append(neu)
                        if (len(neu.spiketime) > 0):
                            dic['N'] = neu.selectivity
                if (dic.get('N') == None):
                    # nothing spiked: force a randomly chosen pre-activated neuron
                    j = random.randint(0, len(panneu) - 1)
                    neu = panneu[j]
                    neu.I = 20
                    for t in tt:
                        neu.update_normal(self.dt, t)
                    dic['N'] = neu.selectivity
                # same procedure for the duration (tempo) neurons
                patneu = []
                for neu in self.msm.sequenceLayers.get(k).get("T").groups.get(i + 1).neurons:
                    if (neu.preActive == True): patneu.append(neu)
                    if (len(neu.spiketime) > 0):
                        dic['T'] = neu.selectivity
                if (dic.get('T') == None):
                    j = random.randint(0, len(patneu) - 1)
                    neu = patneu[j]
                    neu.I = 20
                    for t in tt:
                        neu.update_normal(self.dt, t)
                    dic['T'] = neu.selectivity
                track1.append(dic)
            result[k] = track1
        print(result)
        return result
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/Areas/pac.py
================================================
'''
Primary Auditory Area
'''
from braincog.base.brainarea.BrainArea import BrainArea
from Modal.sequencememory import SequenceMemory
from Modal.notesequencelayer import NoteSequenceLayer
from Modal.temposequencelayer import TempoSequenceLayer
from Modal.synapse import Synapse
import numpy as np
import math
from conf.conf import *
class PAC(BrainArea,SequenceMemory):
    '''
    Primary Auditory Cortex sequence memory: stores note/tempo sequences and
    regenerates them ("the planum polare, anterior to PAC, as well as in the
    left planum temporale, posterior to PAC").
    '''

    def __init__(self, neutype):
        '''
        Constructor.

        :param neutype: neuron model name forwarded to SequenceMemory
        '''
        SequenceMemory.__init__(self, neutype)

    def forward(self, x):
        # Required by the BrainArea interface; PAC is driven explicitly
        # through doRemembering/generate* calls instead.
        pass

    def createActionSequenceMem(self, layernum, neutype):
        '''Create one track: a note layer ("N") plus a tempo layer ("T").'''
        sl = NoteSequenceLayer(neutype)
        tl = TempoSequenceLayer(neutype)
        instrumentTrack = {}
        instrumentTrack["N"] = sl
        instrumentTrack["T"] = tl
        self.sequenceLayers[layernum] = instrumentTrack
        print(len(self.sequenceLayers))

    def doRemembering_note_only(self, note, order, dt, t):
        '''Store a single note (pitch only) at position ``order`` of track 1.'''
        # remember note
        sl = self.sequenceLayers.get(1)
        sgroup = sl.groups.get(order)
        dt = 0.1  # NOTE(review): overrides the dt argument — confirm intended
        for n in sgroup.neurons:
            n.I_ext = note.frequence
            n.computeFilterCurrent()
            n.update(dt, t, 'Learn')

    def doRemembering(self, trackIndex, noteIndex, order, dt, t, tinterval=0):
        '''Store note ``noteIndex`` and tempo ``tinterval`` at position ``order``.'''
        # remember note
        iTrack = self.sequenceLayers.get(trackIndex)
        sl = iTrack.get("N")
        sgroup = sl.groups.get(order)
        dt = 0.1  # NOTE(review): overrides the dt argument — confirm intended
        for n in sgroup.neurons:
            n.I_ext = noteIndex
            n.computeFilterCurrent()
            n.update(dt, t, 'Learn')
        # remember tempo
        tl = iTrack.get("T")
        tgroup = tl.groups.get(order)
        dt = 0.1
        for n in tgroup.neurons:
            n.I_ext = tinterval
            n.computeFilterCurrent()
            n.update(dt, t, 'Learn')

    def doConnectToTitle(self, title, track, order):
        # Associate every layer of the track with the title layer.
        for sl in track.values():
            self.doConnecting(title, sl, order)

    def doConnectToComposer(self, composer, track, order):
        for sl in track.values():
            self.doConnecting(composer, sl, order)

    def doConnectToGenre(self, genre, track, order):
        for sl in track.values():
            self.doConnecting(genre, sl, order)

    def generateEx_Nihilo(self, firstNote, durations, order, dt, t):
        '''
        One simulation step of free generation: replay the seed while
        ``order`` is inside it, otherwise let the network produce the
        next note/tempo from its learned connections.
        '''
        ns = self.sequenceLayers.get(1).get("N")
        ts = self.sequenceLayers.get(1).get("T")
        nneurons = ns.groups.get(order + 1).neurons
        tneurons = ts.groups.get(order + 1).neurons
        # firstNotes specify the beginning notes to trigger the following notes
        if (order < len(firstNote)):  # beginning notes
            i = firstNote[order]
            nneu = nneurons[i + 1]
            nneu.I = 20
            nneu.update_normal(dt, t)
            # durations are quantized in 0.125 steps before indexing the tempo neurons
            d = int(durations[order] / 0.125) - 1
            tneu = tneurons[d]
            tneu.I = 20
            tneu.update_normal(dt, t)
        else:  # generate next note
            for nn in nneurons:
                nn.updateCurrentOfLowerAndUpperLayer(t)
                nn.update(dt, t, 'test')
            for tn in tneurons:
                tn.updateCurrentOfLowerAndUpperLayer(t)
                tn.update(dt, t, 'test')

    def generateSimgleTrackNotes(self, trackIndex, firstNote, durations, order, dt, t):
        '''Same as generateEx_Nihilo, but for an arbitrary track index.'''
        ns = self.sequenceLayers.get(trackIndex).get("N")
        ts = self.sequenceLayers.get(trackIndex).get("T")
        nneurons = ns.groups.get(order + 1).neurons
        tneurons = ts.groups.get(order + 1).neurons
        # firstNotes specify the beginning notes to trigger the following notes
        if (order < len(firstNote)):  # beginning notes
            i = firstNote[order]
            nneu = nneurons[i + 1]
            nneu.I = 20
            nneu.update_normal(dt, t)
            d = int(durations[order] / 0.125) - 1
            tneu = tneurons[d]
            tneu.I = 20
            tneu.update_normal(dt, t)
        else:  # generate next note
            for nn in nneurons:
                nn.updateCurrentOfLowerAndUpperLayer(t)
                nn.update(dt, t, 'test')
                # if(neu.spike == True):
                #     print(neu.selectivity)
            for tn in tneurons:
                tn.updateCurrentOfLowerAndUpperLayer(t)
                tn.update(dt, t, 'test')
class Music_Sequence_Mem(SequenceMemory):
    '''
    Sequence memory for whole pieces of music ("the planum polare, anterior
    to PAC, as well as in the left planum temporale, posterior to PAC").
    Each track is a note layer ("N") plus a tempo layer ("T"); pieces can be
    recalled from a goal (title) or from a short episode of notes.
    '''

    def __init__(self, neutype):
        '''
        Constructor.

        :param neutype: neuron model name forwarded to SequenceMemory
        '''
        SequenceMemory.__init__(self, neutype)

    def createActionSequenceMem(self, layernum, neutype):
        '''Create one track: a note layer ("N") plus a tempo layer ("T").'''
        sl = NoteSequenceLayer(neutype)
        tl = TempoSequenceLayer(neutype)
        instrumentTrack = {}
        instrumentTrack["N"] = sl
        instrumentTrack["T"] = tl
        self.sequenceLayers[layernum] = instrumentTrack
        print(len(self.sequenceLayers))

    def doRemembering_note_only(self, note, order, dt, t):
        '''Store a single note (pitch only) at position ``order`` of track 1.'''
        # remember note
        sl = self.sequenceLayers.get(1)
        sgroup = sl.groups.get(order)
        dt = 0.1  # NOTE(review): overrides the dt argument — confirm intended
        for n in sgroup.neurons:
            n.I_ext = note.frequence
            n.computeFilterCurrent()
            n.update(dt, t, 'Learn')

    def doRemembering(self, trackIndex, noteIndex, order, dt, t, tinterval=0):
        '''Store note ``noteIndex`` and tempo ``tinterval`` at position ``order``.'''
        # remember note
        iTrack = self.sequenceLayers.get(trackIndex)
        sl = iTrack.get("N")
        sgroup = sl.groups.get(order)
        dt = 0.1  # NOTE(review): overrides the dt argument — confirm intended
        for n in sgroup.neurons:
            n.I_ext = noteIndex
            n.computeFilterCurrent()
            n.update(dt, t, 'Learn')
        # remember tempo
        tl = iTrack.get("T")
        tgroup = tl.groups.get(order)
        dt = 0.1
        for n in tgroup.neurons:
            n.I_ext = tinterval
            n.computeFilterCurrent()
            n.update(dt, t, 'Learn')

    def recallByEpisode(self, episodeNotes, goals):
        '''
        Locate an episode (list of note frequencies) inside the stored
        sequences and identify the goal (song title) it belongs to by
        comparing average firing rates of the goal groups.
        '''
        dt = 0.1
        sl = self.sequenceLayers.get(1)
        firstresult = {}
        firstNote = episodeNotes[0]
        tt = np.arange(0, 5, dt)
        # find first note and activate goal neurons
        for t in tt:
            for id, group in sl.groups.items():
                neuid = firstNote - 15;  # map frequency to a neuron index
                if (group.neurons[neuid - 1].preActive == False): continue
                group.neurons[neuid - 1].I_ext = 20
                group.neurons[neuid - 1].updateCurrentOfLowerAndUpperLayer(t)
                group.neurons[neuid - 1].update(dt, t, 'test')
                if (group.neurons[neuid - 1].spike == True):
                    firstresult[id] = 1
            for name, g in goals.groups.items():
                for neu in g.neurons:
                    neu.updateCurrentOfLowerAndUpperLayer(t)
                    neu.update(dt, t)
        # find rest episode notes
        restResult = {}
        # for i in range(1,len(episodeNotes)):
        goalchecked = {}
        for groupID, value in firstresult.items():
            tmp = {}
            for i in range(1, len(episodeNotes)):
                fre = episodeNotes[i]
                tt = np.arange(i * 5, (i + 1) * 5, dt)
                g = sl.groups.get(groupID + i)
                for t in tt:
                    # update memory
                    neuid = fre - 15;
                    if (g.neurons[neuid - 1].preActive == False): continue
                    g.neurons[neuid - 1].I_ext = 20
                    g.neurons[neuid - 1].updateCurrentOfLowerAndUpperLayer(t)
                    if (g.neurons[neuid - 1].I_lower == 0): break
                    g.neurons[neuid - 1].update(dt, t, 'test')
                    if (g.neurons[neuid - 1].spike == True):
                        tmp[i] = g.id
                    # update goals' neurons
                    for name, gg in goals.groups.items():
                        for neu in gg.neurons:
                            neu.updateCurrentOfLowerAndUpperLayer(t)
                            neu.update(dt, t)
                            # if(neu.spike == True):print(name)
            # find the goal group with the highest average firing rate
            maxFiringRate = 0
            maxGoalName = ''
            maxGoal = {}  # an episode may be mapped to more than one songs
            for name, gg in goals.groups.items():
                if (goalchecked.get(name) == None):
                    averageFiringRate = 0
                    for neu in gg.neurons:
                        averageFiringRate = averageFiringRate + len(neu.spiketime)
                    averageFiringRate = float(averageFiringRate) / float(len(gg.neurons))
                    gg.averageFiringRate = averageFiringRate
                    if (averageFiringRate > maxFiringRate):
                        maxFiringRate = averageFiringRate
                        maxGoalName = name
            maxGoal[maxGoalName] = 1
            goalchecked[maxGoalName] = 1
            # goals tied at the maximum rate are all kept
            for name, gg in goals.groups.items():
                if (gg.averageFiringRate == maxFiringRate and goalchecked.get(name) == None):
                    maxGoal[name] = 1
                    goalchecked[name] = 1
            tmp['goal'] = maxGoal
            restResult[groupID] = tmp
        print(restResult)
        # keep only start positions where the whole episode matched
        episodeResult = []
        for key, value in firstresult.items():
            tmp = {}
            tmp[0] = key
            dic = restResult.get(key)
            for i in range(1, len(episodeNotes)):
                if (dic.get(i) != None):
                    tmp[i] = dic.get(i)
            if (len(tmp) == len(episodeNotes)):
                tmp['goal'] = dic.get('goal')
                episodeResult.append(tmp)
        print(episodeResult)
        for res in episodeResult:
            msmgroupID = res.get(len(episodeNotes) - 1) + 1
            for gname, value in res.get('goal').items():
                # NOTE(review): gg is computed but never used — the recall
                # continuation here looks unfinished; see recallByEpisode2.
                gg = goals.groups.get(gname)

    def recallByEpisode2(self, episodeNotes, goals):
        '''
        Improved episode recall: locate the episode, identify candidate goal
        titles, then drive each goal to regenerate the rest of the song.

        :return: list of {'goal': name, 'rest': {order: {'N':..., 'T':...}}}
        '''
        dt = 0.1
        sl = self.sequenceLayers.get(1).get("N")
        tl = self.sequenceLayers.get(1).get("T")
        print(len(sl.groups))
        firstresult = {}
        firstNote = episodeNotes[0]
        tt = np.arange(0, 5, dt)
        # find first note and activate goal neurons
        for t in tt:
            for id, group in sl.groups.items():
                neuid = firstNote;
                if (group.neurons[neuid + 1].preActive == False): continue
                # group.neurons[neuid+1].I_ext = 20
                # group.neurons[neuid+1].updateCurrentOfLowerAndUpperLayer(t)
                group.neurons[neuid + 1].I = 20
                group.neurons[neuid + 1].update(dt, t, 'test')
                if (group.neurons[neuid + 1].spike == True):
                    firstresult[id] = 1
        # find rest episode notes
        restResult = {}
        # for i in range(1,len(episodeNotes)):
        goalchecked = {}
        for groupID in firstresult.keys():
            # replay the first note at this candidate position
            sl.setTestStates()
            tmp = {}
            tt = np.arange(0, 5, dt)
            g = sl.groups.get(groupID)
            neuid = firstNote
            # g.neurons[neuid+1].I_ext = 20
            g.neurons[neuid + 1].I = 20
            for t in tt:
                # g.neurons[neuid+1].updateCurrentOfLowerAndUpperLayer(t)
                g.neurons[neuid + 1].update(dt, t, 'test')
            for i in range(1, len(episodeNotes)):
                fre = episodeNotes[i]
                tt = np.arange(i * 5, (i + 1) * 5, dt)
                if (groupID + i > len((sl.groups))): continue
                g = sl.groups.get(groupID + i)
                for t in tt:
                    # update memory
                    neuid = fre;
                    if (g.neurons[neuid + 1].preActive == False): continue
                    g.neurons[neuid + 1].I_ext = 20
                    g.neurons[neuid + 1].updateCurrentOfLowerAndUpperLayer(t)
                    if (g.neurons[neuid + 1].I_lower == 0): break
                    g.neurons[neuid + 1].update(dt, t, 'test')
                    if (g.neurons[neuid + 1].spike == True):
                        tmp[i] = g.id
            restResult[groupID] = tmp
        # print(restResult)
        # keep only start positions where the whole episode matched
        episodeResult = []
        for key in firstresult.keys():
            tmp = {}
            tmp[0] = key
            dic = restResult.get(key)
            for i in range(1, len(episodeNotes)):
                if (dic.get(i) != None):
                    tmp[i] = dic.get(i)
            if (len(tmp) == len(episodeNotes)):
                episodeResult.append(tmp)
        # print(episodeResult)
        # begin remembering
        finalResult = []
        for i, res in enumerate(episodeResult):
            # NOTE(review): the loop variable i is shadowed by the inner
            # enumerate(episodeNotes) loop below — confirm intended.
            goals.setTestStates()
            self.setTestStates()
            for i, fre in enumerate(episodeNotes):
                tt = np.arange(i * 5, (i + 1) * 5, dt)
                neuid = fre
                g = sl.groups.get(res.get(i))
                for t in tt:
                    # g.neurons[neuid+1].I_ext = 20
                    # g.neurons[neuid+1].updateCurrentOfLowerAndUpperLayer(t)
                    g.neurons[neuid + 1].I = 20
                    g.neurons[neuid + 1].update(dt, t, 'test')
                    for name, gg in goals.groups.items():
                        for neu in gg.neurons:
                            neu.updateCurrentOfLowerAndUpperLayer(t)
                            neu.update(dt, t)
            # find goal
            maxFiringRate = 0
            maxGoalName = ''
            maxGoal = {}  # an episode may be mapped to more than one songs
            for name, gg in goals.groups.items():
                averageFiringRate = 0
                for neu in gg.neurons:
                    averageFiringRate = averageFiringRate + len(neu.spiketime)
                averageFiringRate = float(averageFiringRate) / float(len(gg.neurons))
                gg.averageFiringRate = averageFiringRate
                if (averageFiringRate > maxFiringRate):
                    maxFiringRate = averageFiringRate
                    maxGoalName = name
            maxGoal[maxGoalName] = 1
            for name, gg in goals.groups.items():
                if (gg.averageFiringRate == maxFiringRate and goalchecked.get(name) == None):
                    maxGoal[name] = 1
            # print(maxGoal)
            # recall the rest song
            for goalname, value in maxGoal.items():
                # reset State
                goals.setTestStates()
                nextGroupId = res.get(len(episodeNotes) - 1) + 1
                for i in range(nextGroupId, len(sl.groups) + 1):
                    sl.groups.get(i).setTestStates()
                    tl.groups.get(i).setTestStates()
                restSpikeResult = {}
                count = 0
                gg = goals.groups.get(goalname)
                for i in range(nextGroupId, len(sl.groups) + 1):
                    order = len(episodeNotes) + count
                    if (order == 3):
                        print("debug")
                    tt = np.arange(order * 5, (order + 1) * 5, dt)
                    msmgroup = sl.groups.get(i)
                    msmtgroup = tl.groups.get(i)
                    tdic = {}
                    for t in tt:
                        # drive the goal group; the sequence memory follows
                        for n in gg.neurons:
                            # n.updateCurrentOfLowerAndUpperLayer(t)
                            n.I = 30
                            n.update_normal(dt, t)
                        for neu in msmgroup.neurons:
                            neu.updateCurrentOfLowerAndUpperLayer(t)
                            neu.update(dt, t, 'test')
                            if (neu.spike == True and restSpikeResult.get(order) == None):
                                # restSpikeResult[int(order)] = neu.selectivity
                                tdic["N"] = neu.selectivity
                        for neu in msmtgroup.neurons:
                            neu.updateCurrentOfLowerAndUpperLayer(t)
                            neu.update(dt, t, 'test')
                            if (neu.spike == True and restSpikeResult.get(order) == None):
                                # restSpikeResult[int(order)] = neu.selectivity
                                tdic["T"] = neu.selectivity
                    if (tdic):
                        restSpikeResult[int(order)] = tdic
                    count += 1
                # print(restSpikeResult)
                dic = {}
                dic['goal'] = goalname
                dic['rest'] = restSpikeResult
                finalResult.append(dic)
        return (finalResult)

    def doConnectToTitle(self, title, track, order):
        # Associate every layer of the track with the title layer.
        for sl in track.values():
            self.doConnecting(title, sl, order)

    def doConnectToComposer(self, composer, track, order):
        for sl in track.values():
            self.doConnecting(composer, sl, order)

    def doConnectToGenre(self, genre, track, order):
        for sl in track.values():
            self.doConnecting(genre, sl, order)

    def doConnectToEmotion(self, emo, track, order):
        for sl in track.values():
            self.doConnecting(emo, sl, order)

    def doConnectToKey(self, key, track, order, noteIndex):
        '''
        Connect the note at ``order`` to the matching key neuron when their
        spikes coincide often enough inside this group's time window.
        '''
        group = track.get("N").groups.get(order)
        if group == None: return
        tb = (order - 1) * group.timeWindow
        te = (order) * group.timeWindow
        sp1_goal = {}
        sp2 = []
        kneu = key.neurons[noteIndex % 12]  # one key neuron per pitch class
        sp = []
        for st in kneu.spiketime:
            if (st < te and st >= tb):
                sp.append(st)
        sp1_goal[kneu.index] = sp
        n = group.neurons[noteIndex + 1]
        if (len(n.spiketime) > 0):
            for index, sp in sp1_goal.items():
                # count near-coincident spike pairs
                temp = 0
                for sp1 in n.spiketime:  # spike times of group
                    for sp2 in sp:
                        if (abs(sp1 - sp2) <= n.timeWindow):
                            temp += 1
                if (temp >= 2):  # super threshold, create a new synapse between goal and neurons of sequence group
                    # syn = Synapse(goal.neurons[index - 1], n)
                    # syn.type = 2
                    # syn.weight = 30
                    # n.pre_neurons.append(goal.neurons[index - 1])
                    # n.synapses.append(syn)
                    # add reverse synapse to neurons of the goal
                    syn2 = Synapse(n, kneu)
                    syn2.type = 3
                    syn2.weight = 0
                    kneu.synapses.append(syn2)

    def doConnectToMode(self, mode, keyName, track, order, noteIndex):
        # only connect the scale degrees of the matching mode
        # (this is a neuron-to-neuron connection)
        noteScales = np.where((configs.keyscales.get(keyName % 2)[keyName // 2]) == noteIndex % 12)[0][0]
        group = track.get("N").groups.get(order)
        if (group == None): return
        tb = (order - 1) * group.timeWindow
        te = (order) * group.timeWindow
        sp1_goal = {}
        sp2 = []
        modeneu = mode.neurons[noteScales]
        sp = []
        for st in modeneu.spiketime:
            if (st < te and st >= tb):
                sp.append(st)
        sp1_goal[modeneu.index] = sp
        n = group.neurons[noteIndex + 1]
        if (len(n.spiketime) > 0):
            for index, sp in sp1_goal.items():
                # count near-coincident spike pairs
                temp = 0
                for sp1 in n.spiketime:  # spike times of group
                    for sp2 in sp:
                        if (abs(sp1 - sp2) <= n.timeWindow):
                            temp += 1
                if (temp >= 2):  # super threshold, create a new synapse between goal and neurons of sequence group
                    # syn = Synapse(goal.neurons[index - 1], n)
                    # syn.type = 2
                    # syn.weight = 30
                    # n.pre_neurons.append(goal.neurons[index - 1])
                    # n.synapses.append(syn)
                    # add reverse synapse to neurons of the goal
                    syn2 = Synapse(n, modeneu)
                    syn2.type = 3
                    syn2.weight = 0
                    modeneu.synapses.append(syn2)

    def generateEx_Nihilo(self, firstNote, durations, order, dt, t):
        '''
        One simulation step of free generation on track 1: replay the seed
        while ``order`` is inside it, otherwise let the network produce the
        next note/tempo.
        '''
        ns = self.sequenceLayers.get(1).get("N")
        ts = self.sequenceLayers.get(1).get("T")
        nneurons = ns.groups.get(order + 1).neurons
        tneurons = ts.groups.get(order + 1).neurons
        # firstNotes specify the beginning notes to trigger the following notes
        if (order < len(firstNote)):  # beginning notes
            i = firstNote[order]
            nneu = nneurons[i + 1]
            nneu.I = 20
            nneu.update_normal(dt, t)
            # durations are quantized in 0.125 steps before indexing tempo neurons
            d = int(durations[order] / 0.125) - 1
            tneu = tneurons[d]
            tneu.I = 20
            tneu.update_normal(dt, t)
        else:  # generate next note
            for nn in nneurons:
                nn.updateCurrentOfLowerAndUpperLayer(t)
                nn.update(dt, t, 'test')
            for tn in tneurons:
                tn.updateCurrentOfLowerAndUpperLayer(t)
                tn.update(dt, t, 'test')

    def generateMelodyWithTone(self, firstNote, duration, tone, order, dt, t):
        '''
        One simulation step of multi-part generation (at most 4 parts):
        clamp the given seed pitch/duration when provided, otherwise let
        each part's network evolve freely.
        '''
        for i, part in self.sequenceLayers.items():
            if i > 4: break
            ns = part.get("N")
            ts = None
            if duration is not None:
                ts = part.get("T")  # if no duration is supplied, quarter notes are generated
            # pitches updating
            if firstNote is not None:
                for nneu in ns.groups.get(order).neurons:
                    nneu.I_ext = firstNote[i - 1]
                    nneu.computeFilterCurrent()
                    nneu.update_normal(dt, t)
            else:
                for nneu in ns.groups.get(order).neurons:
                    nneu.updateCurrentOfLowerAndUpperLayer(t)
                    nneu.update_normal(dt, t)
            # durations updating
            if ts is not None:
                if duration is not None:
                    for tneu in ts.groups.get(order).neurons:
                        tneu.I_ext = duration[i - 1]
                        tneu.computeFilterCurrent()
                        tneu.update_normal(dt, t)
                else:
                    # NOTE(review): ts is only non-None when duration is not
                    # None, so this branch looks unreachable — confirm.
                    for tneu in ts.groups.get(order).neurons:
                        tneu.updateCurrentOfLowerAndUpperLayer(t)
                        tneu.update_normal(dt, t)

    def generateSimgleTrackNotes(self, trackIndex, firstNote, durations, order, dt, t):
        '''Same as generateEx_Nihilo, but for an arbitrary track index.'''
        ns = self.sequenceLayers.get(trackIndex).get("N")
        ts = self.sequenceLayers.get(trackIndex).get("T")
        nneurons = ns.groups.get(order + 1).neurons
        tneurons = ts.groups.get(order + 1).neurons
        # firstNotes specify the beginning notes to trigger the following notes
        if (order < len(firstNote)):  # beginning notes
            i = firstNote[order]
            nneu = nneurons[i + 1]
            nneu.I = 20
            nneu.update_normal(dt, t)
            d = int(durations[order] / 0.125) - 1
            tneu = tneurons[d]
            tneu.I = 20
            tneu.update_normal(dt, t)
        else:  # generate next note
            for nn in nneurons:
                nn.updateCurrentOfLowerAndUpperLayer(t)
                nn.update(dt, t, 'test')
                # if(neu.spike == True):
                #     print(neu.selectivity)
            for tn in tneurons:
                tn.updateCurrentOfLowerAndUpperLayer(t)
                tn.update(dt, t, 'test')
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/Areas/pfc.py
================================================
import numpy as np
import math
from braincog.base.brainarea.PFC import PFC
from Modal.synapse import Synapse
from Modal.titlelayer import TitleLayer
from Modal.composerlayer import ComposerLayer
from Modal.genrelayer import GenreLayer
from conf.conf import *
from Modal.layer import *
class PFC(PFC):
    '''
    Prefrontal area storing the "sub-goals" of the music task: titles,
    composers, genres, keys, modes and chords of the memorised pieces.
    '''

    def __init__(self, neutype):
        '''
        Constructor.

        :param neutype: neuron model name used by every sub-layer
        '''
        super().__init__()
        self.neutype = neutype
        self.goals = TitleLayer(self.neutype)  # store the musical titles
        self.composers = ComposerLayer(self.neutype)  # store composers
        self.keys = KeyLayer(self.neutype)
        self.modes = ModeLayer(self.neutype)
        self.genres = GenreLayer(self.neutype)
        self.chords = ChordLayer(self.neutype)

    def addNewKey(self):
        '''Create one key group per row of the key profile matrix.'''
        # BUGFIX: a second, byte-identical definition of addNewKey existed
        # further down the class and silently shadowed this one; removed.
        row, col = configs.key_matrix.shape
        for i in range(row):
            self.keys.addNewGroups(i + 1, 1, col, configs.key_matrix[i, :])

    def addNewSubGoal(self, goalname):
        '''Register a new title group unless it already exists.'''
        if (self.goals.groups.get(goalname) == None):
            self.goals.addNewGroups(len(self.goals.groups) + 1, 1, 1, goalname)

    def addNewComposer(self, composername):
        '''Register a new composer group unless it already exists.'''
        if (self.composers.groups.get(composername) == None):
            self.composers.addNewGroups(len(self.composers.groups) + 1, 1, 1, composername)

    def addNewGenre(self, genrename):
        '''Register a new genre group unless it already exists.'''
        if (self.genres.groups.get(genrename) == None):
            self.genres.addNewGroups(len(self.genres.groups) + 1, 1, 1, genrename)

    def addNewMode(self):
        '''Create the mode groups and wire key neurons onto mode neurons.'''
        for i, m in configs.index2mode.items():
            self.modes.addNewGroups(i + 1, 1, 12, m)
            # connect with the key network
            scales = configs.keyscales.get(i)
            for k in range(12):  # key neurons project to mode neurons
                for j, index in enumerate(scales[k, :]):
                    pre = self.keys.groups.get(k).neurons[index]
                    post = self.modes.groups.get(i).neurons[j]
                    syn = Synapse(pre, post)
                    syn.excitability = 1
                    syn.type = 3
                    post.synapses.append(syn)
                    post.pre_neurons.append(pre)

    def addNewChord(self):
        '''Create the 7 triad groups, wire them to the keys and to each other.'''
        for i in range(0, 7):  # store the 7 triads for now
            self.chords.addNewGroups(i + 1, 1, 1)
            # connect to the key network (T, S, D chords first)
            for t, c in configs.chordsMap.items():
                for k in c[i, :]:
                    # NOTE(review): groups are registered with id i + 1 but
                    # looked up with i here — confirm addNewGroups' indexing.
                    pre = self.chords.groups.get(i).neurons[0]
                    post = self.keys.groups.get(configs.keyIndexMap.get(t)).neurons[k]
                    syn = Synapse(pre, post)
                    syn.excitability = 1
                    post.synapses.append(syn)
                    post.pre_neurons.append(pre)
        # build the internal connections between chords
        r = np.argwhere(configs.chordsMatrix >= 1)
        for i in range(len(r)):
            pre = self.chords.groups.get(r[i][0]).neurons[0]
            post = self.chords.groups.get(r[i][1]).neurons[0]
            syn = Synapse(pre, post)
            post.synapses.append(syn)
            post.pre_neurons.append(pre)

    def setTestStates(self):
        '''Reset every sub-layer to its test (recall) state.'''
        self.goals.setTestStates()
        self.composers.setTestStates()
        self.genres.setTestStates()
        self.keys.setTestStates()
        self.modes.setTestStates()
        self.chords.setTestStates()

    def doRecalling(self, goalname, asm):
        '''
        Recall track 1 of the sequence memory by driving a goal group.

        :return: dict {order: selectivity} using first-spike readout
        '''
        goal = self.goals.groups.get(goalname)
        result = {}
        sequences = asm.sequenceLayers.get(1).groups
        dt = 0.1
        time = np.arange(0, len(sequences) * 5, dt)
        for t in time:
            order = math.floor(t / 5) + 1  # each note occupies a 5-unit window
            for neu in goal.neurons:
                neu.I = 30
                neu.update_normal(dt, t)
            sg = sequences.get(order)
            for neu in sg.neurons:
                neu.updateCurrentOfLowerAndUpperLayer(t)
                neu.update(dt, t, 'test')
                if (neu.spike == True and result.get(order) == None):
                    result[int(order)] = neu.selectivity
        return result

    def doRecalling2(self, goalname, asm):
        '''
        Recall every track (notes and tempi) associated with a goal.

        :return: dict {track-index: {"N": {order: pitch}, "T": {order: tempo}}}
        '''
        goal = self.goals.groups.get(goalname)
        result = {}
        for tindex, strack in asm.sequenceLayers.items():
            nsequences = strack.get("N").groups
            tsequences = strack.get("T").groups
            dic = {}
            ndic = {}
            tdic = {}
            dt = 0.1
            time = np.arange(0, len(nsequences) * 5, dt)
            for t in time:
                order = math.floor(t / 5) + 1
                for neu in goal.neurons:
                    neu.I = 30
                    neu.update_normal(dt, t)
                nsg = nsequences.get(order)
                for neu in nsg.neurons:
                    neu.updateCurrentOfLowerAndUpperLayer(t)
                    neu.update(dt, t, 'test')
                    if (neu.spike == True and ndic.get(order) == None):
                        ndic[int(order)] = neu.selectivity
                tsg = tsequences.get(order)
                for neu in tsg.neurons:
                    neu.updateCurrentOfLowerAndUpperLayer(t)
                    neu.update(dt, t, 'test')
                    if (neu.spike == True and tdic.get(order) == None):
                        tdic[int(order)] = neu.selectivity
            dic["N"] = ndic
            dic["T"] = tdic
            result[tindex] = dic
        return result

    def doRecalling3(self, goalname, asm):
        '''
        Recall every track as an ordered list of {'N':..., 'T':...} dicts,
        simulating each note's window separately.
        '''
        goal = self.goals.groups.get(goalname)
        result = {}
        for tindex, strack in asm.sequenceLayers.items():
            nsequences = strack.get("N").groups
            tsequences = strack.get("T").groups
            part = []
            ndic = {}
            tdic = {}
            dt = 0.1
            print(len(nsequences))
            for i in range(0, len(nsequences)):
                order = i + 1
                tmp = {}
                time = np.arange(i * 5, (i + 1) * 5, dt)
                for t in time:
                    for neu in goal.neurons:
                        neu.I = 30
                        neu.update_normal(dt, t)
                    nsg = nsequences.get(order)
                    for neu in nsg.neurons:
                        neu.updateCurrentOfLowerAndUpperLayer(t)
                        neu.update_normal(dt, t)
                        # first-spike readout; max firing rate would arguably
                        # be more robust (original author's own note)
                        if (neu.spike == True and ndic.get(order) == None):
                            ndic[int(order)] = neu.selectivity
                            tmp["N"] = neu.selectivity
                    tsg = tsequences.get(order)
                    for neu in tsg.neurons:
                        neu.updateCurrentOfLowerAndUpperLayer(t)
                        neu.update_normal(dt, t)
                        # FIXME: flagged as buggy by the original author; revisit
                        if (neu.spike == True and tdic.get(order) == None):
                            tdic[int(order)] = neu.selectivity
                            tmp["T"] = neu.selectivity
                part.append(tmp)
            result[tindex] = part
        print(len(result))
        return result

    def doRemebering(self, goalname, dt, t):
        # Store the title information by driving the goal group.
        # (Method name keeps the original spelling for caller compatibility.)
        goal_group = self.goals.groups.get(goalname)
        for neu in goal_group.neurons:
            neu.I = 50
            neu.update_normal(dt, t)

    def doRememberingComposer(self, composername, dt, t):
        '''Drive the composer group so it fires during learning/recall.'''
        composer_group = self.composers.groups.get(composername)
        for neu in composer_group.neurons:
            neu.I = 50
            neu.update_normal(dt, t)

    def doRememberingGenre(self, genrename, dt, t):
        '''Drive the genre group so it fires during learning/recall.'''
        genre_group = self.genres.groups.get(genrename)
        for neu in genre_group.neurons:
            neu.I = 50
            neu.update_normal(dt, t)

    def doRememberingKey(self, key, dt, t):
        '''Excite the important (in-key) neurons and suppress the rest.'''
        key_group = self.keys.groups.get(key)
        for neu in key_group.neurons:
            neu.I = 50 if neu.importance > 0 else -100
            neu.update_learn(dt, t)

    def doRememberingMode(self, mode, dt, t):
        '''Drive the mode group so it fires during learning/recall.'''
        mode_group = self.modes.groups.get(mode)
        for neu in mode_group.neurons:
            neu.I = 50
            neu.update_learn(dt, t)

    def innerLearning(self, goalname, composer, genre):
        '''
        Hebbian-style association inside PFC: create synapses between
        composer->title, genre->composer and genre->title neurons whose
        spikes coincided often enough (within tau_ref).
        '''
        g = self.goals.groups.get(goalname)
        c = self.composers.groups.get(composer)
        gre = self.genres.groups.get(genre)
        if (g != None and c != None):
            for n1 in c.neurons:
                if (len(n1.spiketime) > 0):
                    for n2 in g.neurons:
                        if (len(n2.spiketime) > 0):
                            temp = 0
                            for sp1 in n1.spiketime:
                                for sp2 in n2.spiketime:
                                    if (abs(sp1 - sp2) <= n1.tau_ref):
                                        temp += 1
                            if (temp >= 3):
                                syn = Synapse(n1, n2)
                                syn.type = 2
                                syn.weight = 5
                                n2.synapses.append(syn)
                                n2.pre_neurons.append(n1)
        if (gre != None):
            for n1 in gre.neurons:
                if (len(n1.spiketime) > 0):
                    if (c != None):
                        for n2 in c.neurons:
                            if (len(n2.spiketime) > 0):
                                temp = 0
                                for sp1 in n1.spiketime:
                                    for sp2 in n2.spiketime:
                                        if (abs(sp1 - sp2) <= n1.tau_ref):
                                            temp += 1
                                if (temp >= 4):
                                    syn = Synapse(n1, n2)
                                    syn.type = 2
                                    syn.weight = 5
                                    n2.synapses.append(syn)
                                    n2.pre_neurons.append(n1)
                    if (g != None):
                        for n2 in g.neurons:
                            if (len(n2.spiketime) > 0):
                                temp = 0
                                for sp1 in n1.spiketime:
                                    for sp2 in n2.spiketime:
                                        if (abs(sp1 - sp2) <= n1.tau_ref):
                                            temp += 1
                                if (temp >= 4):
                                    syn = Synapse(n1, n2)
                                    syn.type = 2
                                    syn.weight = 5
                                    n2.synapses.append(syn)
                                    n2.pre_neurons.append(n1)

    def inhibitGenres(self, dt, t):
        '''Suppress all genre neurons with a strong negative current.'''
        # BUGFIX: the original iterated self.goals.groups (copy-paste error),
        # so genre neurons were never inhibited and goal neurons were
        # inhibited twice when called together with inhibiteGoals.
        gen_group = self.genres.groups
        for g in gen_group.values():
            for neu in g.neurons:
                neu.I = -100
                neu.update_normal(dt, t)

    def inhibiteGoals(self, dt, t):
        '''Suppress all goal (title) neurons with a strong negative current.'''
        goal_group = self.goals.groups
        for g in goal_group.values():
            for neu in g.neurons:
                neu.I = -100
                neu.update(dt, t)

    def inhibitComposers(self, dt, t):
        '''Suppress all composer neurons with a strong negative current.'''
        com_group = self.composers.groups
        for g in com_group.values():
            for neu in g.neurons:
                neu.I = -100
                neu.update(dt, t)
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/Modal/PAC.py
================================================
'''
Primary Auditory Cortex
'''
import torch
from braincog.base.node.node import *
from braincog.base.brainarea.BrainArea import *
from braincog.base.connection import CustomLinear
from braincog.base.learningrule.STDP import *
class PAC(BrainArea):
    '''
    Primary Auditory Cortex as a BrainArea: a note LIF population with
    plastic recurrent connections trained by multi-input STDP.
    '''

    # BUGFIX: the constructor was misspelled __int__, so it was never invoked
    # as a constructor, and it referenced w2/mask2 which were not defined.
    # The second weight/mask pair is now an explicit parameter defaulting to
    # the first pair, which keeps existing PAC(w, mask) call sites working.
    def __init__(self, w, mask, w2=None, mask2=None):
        super().__init__()
        if w2 is None:
            w2 = w
        if mask2 is None:
            mask2 = mask
        self.noteNetworks = NoteLIFNode()
        self.connection = [CustomLinear(w, mask), CustomLinear(w2, mask2)]
        self.stdp = []
        # recurrent input buffer fed back through the STDP rule each step
        self.internalinputs = torch.zeros(640, 640)
        self.stdp.append(MutliInputSTDP(self.noteNetworks, self.connection))

    def forward(self, x):
        '''Run one STDP step; return the new internal input and weight update.'''
        self.internalinputs, dw = self.stdp[0](x, self.internalinputs)
        return self.internalinputs, dw
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/Modal/cluster.py
================================================
from .lifneuron import LIFNeuron
from .synapse import Synapse
from Modal.izhikevichneuron import *
class Cluster():
    """A group of neurons sharing one cell model.

    A cluster owns its neuron list, knows its id/name inside a layer, and can
    wire itself with all-to-all (optionally inhibitory) connections.
    """

    def __init__(self, neutype='LIF', neunum=10):
        """
        :param neutype: neuron model name ('LIF', 'Izhikevich', 'Gaussian', 'HH')
        :param neunum: number of neurons in this cluster
        """
        self.id = 0            # group id, starting with 1
        self.name = ''         # name of this group
        self.neutype = neutype
        self.neunum = neunum
        self.neurons = []
        self.timeWindow = 5    # ms

    def createClusterNetwork(self):
        """Instantiate ``neunum`` neurons of the configured model."""
        for idx in range(self.neunum):
            if self.neutype == 'LIF':
                neuron = LIFNeuron()
                neuron.index = idx + 1  # 1-based index
                neuron.setPreference()
                self.neurons.append(neuron)
            elif self.neutype == 'Izhikevich':
                neuron = IzhikevichNeuron()
                neuron.index = idx      # NOTE: 0-based here, unlike the LIF branch
                self.neurons.append(neuron)
            elif self.neutype == 'Gaussian':
                neuron = GaussianNeuron()  # NOTE(review): GaussianNeuron is not imported in this module — confirm
                neuron.index = idx + 1
                self.neurons.append(neuron)
            elif self.neutype == 'HH':
                neuron = HHNeuron()        # NOTE(review): HHNeuron is not imported in this module — confirm
                neuron.index = idx
                self.neurons.append(neuron)

    def setInhibitoryNeurons(self, ratio_inhneuron):
        """Mark the trailing ``ratio_inhneuron`` fraction of neurons as inhibitory."""
        first_inh = int(self.neunum * (1 - ratio_inhneuron))
        for i in range(first_inh, self.neunum):
            self.neurons[i].type = 'inh'

    def setPropertiesofNeurons(self, groupID, layerType, layerID):
        """Stamp layer/group bookkeeping fields onto every neuron."""
        for neuron in self.neurons:
            neuron.layerType = layerType
            neuron.groupIndex = groupID
            neuron.layerIndex = layerID

    def setTestStates(self):
        """Reset every neuron to its test (replay) state."""
        for neuron in self.neurons:
            neuron.setTestStates()

    def createFullConnections(self):
        """All-to-all wiring inside the cluster (no self-connections)."""
        for i in range(self.neunum):
            post = self.neurons[i]
            for j in range(self.neunum):
                if j == i:
                    continue
                syn = Synapse(self.neurons[j], post)  # ``post`` is the post-synaptic neuron
                syn.type = 0
                post.synapses.append(syn)
                post.pre_neurons.append(self.neurons[j])

    def createInhibitoryConnections(self):
        """All-to-all inhibitory wiring inside the cluster (no self-connections)."""
        for i in range(self.neunum):
            post = self.neurons[i]
            for j in range(self.neunum):
                if j == i:
                    continue
                syn = Synapse(self.neurons[j], post)
                syn.type = 0
                syn.excitability = 0  # inhibitory synapse
                syn.weight = 20
                post.synapses.append(syn)
                post.pre_neurons.append(self.neurons[j])

    def writeSelfInfoToJson(self):
        """Serialize the cluster (only neurons that spiked) to a plain dict."""
        active = [neuron.writeBasicInfoToJson()
                  for neuron in self.neurons
                  if len(neuron.spiketime) > 0]
        return {"GroupID": self.id, "Name": self.name, "Neuron": active}

    def writeSpikeInfoToJson(self):
        """Serialize spike times of every neuron that fired at least once."""
        records = []
        for neuron in self.neurons:
            if len(neuron.spiketime) > 0:
                records.append({
                    "GroupID": neuron.groupIndex,
                    "Index": neuron.index,
                    "SpikeTime": neuron.writeSpikeTimeToJson(),
                })
        return records
class ModeCluster(Cluster):
    """Cluster of mode-selective neurons; selectivity runs 1..neunum."""

    def __init__(self, neutype, neunum):
        Cluster.__init__(self, neutype, neunum)

    def createClusterNetwork(self, areaName):
        """Create one mode neuron per index for the given brain area."""
        for idx in range(self.neunum):
            if self.neutype == 'LIF':
                neuron = ModeLIFNeuron()
            elif self.neutype == 'Izhikevich':
                neuron = ModeIzhikevichNeuron()
            else:
                continue  # unknown model names create nothing, as before
            neuron.index = idx + 1
            neuron.areaName = areaName
            neuron.selectivity = idx + 1
            self.neurons.append(neuron)
class KeyCluster(Cluster):
    """Cluster of key-selective neurons; the Izhikevich parameters depend on
    each note's role in the key (the ``tone`` table)."""

    def __init__(self, neutype, neunum):
        Cluster.__init__(self, neutype, neunum)

    def createClusterNetwork(self, tone, areaName):
        """Create ``neunum`` neurons; ``tone[i]`` gives neuron i's importance.

        For the Izhikevich model the (a, b, c, d) parameters are chosen by the
        tone value; unrecognized tone values fall back to all-zero parameters
        (same as the original branch defaults).
        """
        tone_to_params = {
            2: (0.02, 0.2, -55, 4),
            -1: (0.1, 0.2, -65, 2),
            1: (0.02, 0.2, -65, 8),
        }
        for idx in range(self.neunum):
            if self.neutype == 'LIF':
                neuron = KeyLIFNeuron()
            elif self.neutype == 'Izhikevich':
                a, b, c, d = tone_to_params.get(tone[idx], (0, 0, 0, 0))
                neuron = KeyIzhikevichNeuron(a, b, c, d)
            else:
                continue
            neuron.index = idx + 1
            neuron.areaName = areaName
            neuron.selectivity = idx
            neuron.importance = tone[idx]
            self.neurons.append(neuron)
class ChordCluster(Cluster):
    """Cluster of chord-selective neurons for the 'Chord' area."""

    def __init__(self, neutype, neunum):
        Cluster.__init__(self, neutype, neunum)

    def createClusterNetwork(self):
        """Create one chord neuron per index; all have importance 1."""
        for idx in range(self.neunum):
            if self.neutype == 'LIF':
                neuron = ChordLIFNeuron()
            elif self.neutype == 'Izhikevich':
                neuron = ChordIzhikevichNeuron()
            else:
                continue
            neuron.index = idx + 1
            neuron.areaName = 'Chord'
            neuron.selectivity = idx
            neuron.importance = 1
            self.neurons.append(neuron)
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/Modal/composercluster.py
================================================
from .cluster import Cluster
from .composerlifneuron import ComposerLIFNeuron
class ComposerCluster(Cluster):
'''
classdocs
'''
def __init__(self, neutype, neunum):
'''
Constructor
'''
Cluster.__init__(self, neutype, neunum)
def createClusterNetwork(self):
for i in range(0, self.neunum):
if (self.neutype == 'LIF'):
node = ComposerLIFNeuron()
node.index = i + 1
node.areaName = 'Composer'
self.neurons.append(node)
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/Modal/composerlayer.py
================================================
from .layer import Layer
from .composercluster import ComposerCluster
class ComposerLayer(Layer):
    """Layer of composer-name groups: one neuron group per composer,
    keyed by the composer's name."""

    def __init__(self, neutype='LIF'):
        self.neutype = neutype
        self.groups = {}

    def setTestStates(self):
        """Put every group into its test (replay) state."""
        for cluster in self.groups.values():
            cluster.setTestStates()

    def addNewGroups(self, groupID, layerID, neunum, composername):
        """Create a new composer group and register it under ``composername``."""
        cluster = ComposerCluster(self.neutype, neunum)
        cluster.id = groupID
        cluster.name = composername
        cluster.createClusterNetwork()
        cluster.setPropertiesofNeurons(groupID, 'G', layerID)
        self.groups[composername] = cluster
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/Modal/composerlifneuron.py
================================================
from .lifneuron import LIFNeuron
class ComposerLIFNeuron(LIFNeuron):
    """LIF neuron representing one composer identity."""

    def __init__(self, tau_ref=0, vthresh=5, Rm=2, Cm=0.2):
        """Defaults mirror the generic LIFNeuron (no refractory period)."""
        LIFNeuron.__init__(self, tau_ref, vthresh, Rm, Cm)

    def update(self, dt, t):
        """One Euler step; fire and reset when the membrane crosses threshold."""
        self.spike = False
        if t < self.t_rest:
            return  # still inside the refractory period
        self.mem += dt * (self.I * self.Rm - self.mem) / self.tau_m
        if self.mem > self.vth:
            self.spike = True
            self.spiketime.append(t)
            self.mem = 0
            self.t_rest = t + self.tau_ref
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/Modal/genrecluster.py
================================================
from .cluster import Cluster
from .genrelifneuron import GenreLIFNeuron
class GenreCluster(Cluster):
    """Cluster of genre-identity neurons (LIF model only)."""

    def __init__(self, neutype, neunum):
        Cluster.__init__(self, neutype, neunum)

    def createClusterNetwork(self):
        """Instantiate the genre neurons for the 'Genre' area."""
        for idx in range(self.neunum):
            if self.neutype != 'LIF':
                continue
            neuron = GenreLIFNeuron()
            neuron.index = idx + 1
            neuron.areaName = 'Genre'
            self.neurons.append(neuron)
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/Modal/genrelayer.py
================================================
from .layer import Layer
from .genrecluster import GenreCluster
class GenreLayer(Layer):
    """Layer of genre groups: one neuron group per genre, keyed by name."""

    def __init__(self, neutype='LIF'):
        self.neutype = neutype
        self.groups = {}

    def setTestStates(self):
        """Put every group into its test (replay) state."""
        for cluster in self.groups.values():
            cluster.setTestStates()

    def addNewGroups(self, groupID, layerID, neunum, genrename):
        """Create a new genre group and register it under ``genrename``."""
        cluster = GenreCluster(self.neutype, neunum)
        cluster.id = groupID
        cluster.name = genrename
        cluster.createClusterNetwork()
        cluster.setPropertiesofNeurons(groupID, 'G', layerID)
        self.groups[genrename] = cluster
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/Modal/genrelifneuron.py
================================================
from .lifneuron import LIFNeuron
class GenreLIFNeuron(LIFNeuron):
    """LIF neuron representing one musical genre.

    Fix: the threshold test compared ``self.v``, an attribute that LIFNeuron
    never defines (it would raise AttributeError on the first update after
    the refractory period). The sibling ComposerLIFNeuron compares the
    membrane potential ``self.mem``, which is what is integrated here.
    """

    def __init__(self, tau_ref=0, vthresh=5, Rm=2, Cm=0.2):
        LIFNeuron.__init__(self, tau_ref, vthresh, Rm, Cm)

    def update(self, dt, t):
        """One Euler step; fire and reset when the membrane crosses threshold."""
        self.spike = False
        if (t >= self.t_rest):
            self.mem += dt * (-self.mem + self.I * self.Rm) / self.tau_m
            if (self.mem > self.vth):  # was: self.v (undefined attribute)
                self.spike = True
                self.spiketime.append(t)
                self.mem = 0
                self.t_rest = t + self.tau_ref
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/Modal/izhikevichneuron.py
================================================
'''
Created on 2016.4.8
@author: liangqian
'''
#from modal.izhikevich import Izhikevich
from braincog.base.node import IzhNodeMU
import math
import random
import numpy as np
class IzhikevichNeuron(IzhNodeMU):
    """Izhikevich spiking neuron with layer/group bookkeeping for the music
    memory network.

    Fixes:
      * ``updateCurrentOfLowerAndUpperLayer`` accumulated the upper-layer
        current with ``self.weight`` — an attribute this class never defines
        (guaranteed AttributeError for type-2 synapses); it now uses
        ``syn.weight`` like the other branches;
      * ``I_syn_upper`` is reset at the start of that method, matching how
        ``I_syn_lower`` is recomputed fresh each call (previously it grew
        without bound across timesteps);
      * ``__init__`` assigned ``I_ext`` twice (0, then -100); only the final
        value is kept.
    """

    def __init__(self, a=0.1, b=0.2, c=-65, d=8, vthresh=30, dt=0.1):
        """
        :param a, b, c, d: Izhikevich model parameters
        :param vthresh: spike threshold (mV)
        :param dt: integration step
        """
        super().__init__(threshold=vthresh, a=a, b=b, c=c, d=d, dt=dt)
        self.layerType = 'S'   # S: sequence layer, G: goal layer
        self.layerIndex = 0    # the layer in which the neuron is situated
        self.groupIndex = 0    # the group in which the neuron is situated
        self.index = 0         # starting with 1
        self.areaName = ''
        self.synapses = []     # this neuron is the post-synaptic side
        self.spiketime = []
        self.pre_neurons = []
        self.I_syn_lower = 0
        self.I_syn_upper = 0
        self.I_lower = 0
        self.I_upper = 0
        self.I_ext = -100      # external drive (was assigned 0 then -100; final value kept)
        self.timeWindow = 5    # ms
        self.I_bg = random.randint(0, 10)  # random background current
        self.selectivity = 0
        self.importance = 0
        self.preActive = False
        self.I = 0
        self.v = -65           # membrane potential
        self.u = b * self.v    # recovery variable
        self.vthresh = vthresh
        self.spike = False
        self.type = 'exc'      # 'exc' or 'inh'

    def update_old(self, dt, t):
        """Legacy step: recompute synaptic currents, then Euler-integrate."""
        self.spike = 0
        self.updateSynapses(t)
        self.updateCurrentOfLowerAndUpperLayer(t)
        self.I = self.I_ext + self.I_syn_lower + self.I_syn_upper
        self.v += dt * (0.04 * self.v * self.v + 5 * self.v + 140 - self.u + self.I)
        self.u += dt * self.a * (self.b * self.v - self.u)
        if (self.v >= 30):  # legacy path uses the hard-coded 30 mV threshold
            self.spike = 1
            self.v = self.c
            self.u += self.d
            self.spiketime.append(t)

    def update(self, dt, t, state):
        """Dispatch a step for 'Learn' or 'test' mode (other states: no-op)."""
        if (state == 'Learn'):
            self.update_learn(dt, t)
        if (state == 'test'):
            self.update_test(dt, t)

    def update_learn(self, dt, t):
        """Learning step: integrate, spike/reset, then adapt synapse weights."""
        self.spike = False
        self.v += dt * (0.04 * self.v * self.v + 5 * self.v + 140 - self.u + self.I)
        self.u += dt * self.a * (self.b * self.v - self.u)
        if self.v > self.vthresh:
            self.spike = True
            self.v = self.c
            self.u += self.d
            self.preActive = True
            self.spiketime.append(t)
        self.updateSynapses(t)  # plasticity runs every step, spike or not

    def update_test(self, dt, t):
        """Test step: integrate and spike/reset; no plasticity, no spike log."""
        self.spike = False
        self.v += dt * (0.04 * self.v * self.v + 5 * self.v + 140 - self.u + self.I)
        self.u += dt * self.a * (self.b * self.v - self.u)
        if self.v > self.vthresh:
            self.spike = True
            self.v = self.c
            self.u += self.d

    def update_normal(self, dt, t):
        """Plain step: integrate, spike/reset, and record the spike time."""
        self.spike = False
        self.v += dt * (0.04 * self.v * self.v + 5 * self.v + 140 - self.u + self.I)
        self.u += dt * self.a * (self.b * self.v - self.u)
        if self.v >= self.vthresh:
            self.spike = True
            self.v = self.c
            self.u += self.d
            self.spiketime.append(t)

    def updateSynapses(self, t):
        """Let every afferent synapse update its weight at time ``t``."""
        for syn in self.synapses:
            syn.computeWeight(t)

    def updateCurrentOfLowerAndUpperLayer(self, t):
        """Recompute the synaptic currents from same-group, same-layer and
        upper-layer afferents using an alpha-kernel over pre-spike times."""
        I_inh = 0
        I_ext = 0
        I_exc_ext = 0
        self.I_syn_upper = 0  # fixed: was accumulated across calls, unlike I_syn_lower
        for syn in self.synapses:
            # sum the alpha kernel over all pre-synaptic spikes before t
            # NOTE(review): kernel uses t/1000 rather than (t-st)/1000 in the
            # prefactor — looks suspicious but kept as-is; confirm intended
            alpha_value = 0
            for st in syn.pre.spiketime:
                if (t - st >= 0):
                    alpha_value += 6 * (t / 1000) * math.exp(-0.03 * (t - st) / 1000)
            if (syn.type == 0):   # from the same group
                if (syn.pre.type == 'inh'):
                    I_inh += syn.weight * (self.v + 80) * alpha_value
                if (syn.pre.type == 'exc'):
                    I_ext += syn.weight * self.v * alpha_value
            if (syn.type == 1):   # from other modules in the same layer
                I_exc_ext += syn.weight * self.v * alpha_value
            if (syn.type == 2):   # from the upper layer
                self.I_syn_upper += syn.weight * self.v * alpha_value  # fixed: was self.weight (AttributeError)
        self.I_syn_lower = -I_inh + I_ext + I_exc_ext

    def setTestStates(self):
        """Reset dynamic state so the neuron can replay in test mode."""
        self.t_rest = 0
        self.spiketime = []
        self.v = -65
        self.u = self.b * self.v
        self.I = 0
        self.I_ext = 0
        for syn in self.synapses:
            syn.strength = 0

    def writeBasicInfoToJson(self):
        """Serialize identity plus all positive-weight afferents to a dict."""
        dic = {}
        dic["TrackID"] = self.layerIndex
        dic["GroupID"] = self.groupIndex
        dic["Index"] = self.index
        dic["selectivity"] = self.selectivity
        dic["area"] = self.areaName
        slist = []
        for syn in self.synapses:
            if (syn.weight <= 0): continue  # only learned (positive) synapses
            tmp = {}
            tmp["type"] = syn.type
            tmp["StrackID"] = syn.pre.layerIndex
            tmp["SgroupID"] = syn.pre.groupIndex
            tmp["Sindex"] = syn.pre.index
            tmp["Sarea"] = syn.pre.areaName
            tmp["pre-selectivity"] = syn.pre.selectivity
            tmp["weight"] = syn.weight
            slist.append(tmp)
        dic["synapses"] = slist
        return dic

    def writeSpikeTimeToJson(self):
        """Serialize spike times as [{1: t1}, {2: t2}, ...], rounded to 2 dp."""
        slist = []
        for i, t in enumerate(self.spiketime):
            dic = {}
            dic[i + 1] = round(t, 2)
            slist.append(dic)
        return slist
class NoteIzhikevichNeuron(IzhikevichNeuron):
    """Pitch-selective Izhikevich neuron for the note sequence memory."""

    def __init__(self, a=0.1, b=0.2, c=-65, d=8, vthresh=30):
        IzhikevichNeuron.__init__(self, a, b, c, d, vthresh)

    def setPreference(self):
        # neuron k (1-based) prefers pitch index k-2
        self.selectivity = self.index - 2

    def computeFilterCurrent(self):
        """Drive the neuron hard when the external input matches its pitch."""
        if self.I_ext == self.selectivity:
            self.I = 30

    def updateCurrentOfLowerAndUpperLayer(self, t):
        """Recompute input current from same-group inhibition, same-layer
        excitation and upper-layer synapses via short-term facilitation."""
        self.I_lower = 0
        self.I_upper = 0
        for syn in self.synapses:
            syn.computeShortTermFacilitation2(t)
            drive = syn.weight * syn.strength
            if syn.type == 0:        # same group: only inhibitory part counts
                if syn.excitability == 0:
                    self.I_lower -= drive
                    if self.I_lower < -20:
                        self.I_lower = -20  # clamp total inhibition
            elif syn.type == 1:      # same layer, different group
                self.I_lower += drive
            elif syn.type >= 2:      # upper layers
                self.I_upper += drive
        self.I = self.I_lower + self.I_upper
class TempoIzhikevichNeuron(IzhikevichNeuron):
    """Duration-selective Izhikevich neuron (preferred duration = index/8).

    Fix: in ``updateCurrentOfLowerAndUpperLayer`` the final assignment to
    ``self.I`` hung off an ``else`` of the per-synapse ``syn.type == 2``
    test, so it was rewritten in the middle of the loop, skipped entirely
    whenever the *last* synapse came from the upper layer, and never executed
    for a neuron with no synapses. The sibling Note/Direction neurons assign
    ``self.I`` once after the loop; this class now does the same.
    """

    def __init__(self, a=0.1, b=0.2, c=-65, d=8, vthresh=30):
        IzhikevichNeuron.__init__(self, a, b, c, d, vthresh)

    def setPreference(self):
        # neuron k prefers duration k * 0.125 (eighths of a whole note)
        self.selectivity = self.index * 0.125

    def computeFilterCurrent(self):
        """Drive the neuron when the input duration is within +-0.0625."""
        if (self.I_ext <= self.selectivity + 0.0625 and
                self.I_ext >= self.selectivity - 0.0625):
            self.I = 30

    def updateCurrentOfLowerAndUpperLayer(self, t):
        """Recompute input current from same-group inhibition, same-layer
        excitation and upper-layer synapses via short-term facilitation."""
        self.I_lower = 0
        self.I_upper = 0
        for syn in self.synapses:
            syn.computeShortTermFacilitation2(t)
            drive = syn.weight * syn.strength
            if syn.type == 0:        # same group: only inhibitory part counts
                if syn.excitability == 0:
                    self.I_lower -= drive
                    if self.I_lower < -20:
                        self.I_lower = -20  # clamp total inhibition
            elif syn.type == 1:      # same layer, different group
                self.I_lower += drive
            elif syn.type == 2:      # upper layer
                self.I_upper += drive
        self.I = self.I_lower + self.I_upper  # fixed: moved out of the loop
class TitleIzhikevichNeuron(IzhikevichNeuron):
    """Izhikevich neuron representing a piece's title."""

    def __init__(self, a=0.1, b=0.2, c=-65, d=8, vthresh=30):
        super().__init__(a, b, c, d, vthresh)
class ComposerIzhikevichNeuron(IzhikevichNeuron):
    """Izhikevich neuron representing a composer identity."""

    def __init__(self, a=0.1, b=0.2, c=-65, d=8, vthresh=30):
        super().__init__(a, b, c, d, vthresh)
class GenreIzhikevichNeuron(IzhikevichNeuron):
    """Izhikevich neuron representing a musical genre."""

    def __init__(self, a=0.1, b=0.2, c=-65, d=8, vthresh=30):
        super().__init__(a, b, c, d, vthresh)
class AmyIzhikevichNeuron(IzhikevichNeuron):
    """Izhikevich neuron for the amygdala (emotion) area."""

    def __init__(self, a=0.1, b=0.2, c=-65, d=8, vthresh=30):
        super().__init__(a, b, c, d, vthresh)
class DirectionIzhikevichNeuron(IzhikevichNeuron):
    """Direction(angle)-selective Izhikevich neuron; tuning bins are pi/120
    wide with half-width pi/240."""

    def __init__(self, a=0.1, b=0.2, c=-65, d=8, vthresh=30):
        IzhikevichNeuron.__init__(self, a, b, c, d, vthresh)

    def setPreference(self):
        # neuron k prefers angle (k+1) * pi/120
        self.selectivity = (self.index + 1) * math.pi / 120

    def computeFilterCurrent(self, input):
        """Drive the neuron when ``input`` falls inside its angular bin."""
        half_width = math.pi / 240
        if self.selectivity - half_width <= input < self.selectivity + half_width:
            self.I = self.I_ext = 30

    def updateCurrentOfLowerAndUpperLayer(self, t):
        """Recompute input current from same-group inhibition, same-layer
        excitation and upper-layer synapses via short-term facilitation."""
        self.I_lower = 0
        self.I_upper = 0
        for syn in self.synapses:
            syn.computeShortTermFacilitation2(t)
            drive = syn.weight * syn.strength
            if syn.type == 0:        # same group: only inhibitory part counts
                if syn.excitability == 0:
                    self.I_lower -= drive
                    if self.I_lower < -20:
                        self.I_lower = -20  # clamp total inhibition
            elif syn.type == 1:      # same layer, different group
                self.I_lower += drive
            elif syn.type >= 2:      # upper layers
                self.I_upper += drive
        self.I = self.I_lower + self.I_upper
class GridIzhikevichCell(IzhikevichNeuron):
    """Izhikevich grid cell."""

    def __init__(self, a=0.1, b=0.2, c=-65, d=8, vthresh=30):
        super().__init__(a, b, c, d, vthresh)
class KeyIzhikevichNeuron(IzhikevichNeuron):
    """Izhikevich neuron representing a musical key."""

    def __init__(self, a=0.1, b=0.2, c=-65, d=8, vthresh=30):
        super().__init__(a, b, c, d, vthresh)
class ModeIzhikevichNeuron(IzhikevichNeuron):
    """Izhikevich neuron representing a musical mode (scale)."""

    def __init__(self, a=0.1, b=0.2, c=-65, d=8, vthresh=30):
        super().__init__(a, b, c, d, vthresh)
class ChordIzhikevichNeuron(IzhikevichNeuron):
    """Izhikevich neuron representing a chord."""

    def __init__(self, a=0.1, b=0.2, c=-65, d=8, vthresh=30):
        super().__init__(a, b, c, d, vthresh)
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/Modal/layer.py
================================================
from abc import ABCMeta,abstractmethod
from conf.conf import configs
from Modal.cluster import *
class Layer():
    """Intended abstract base for all layers of the music memory.

    NOTE(review): ``_metaclass_`` is misspelled (and ``__metaclass__`` is
    ignored in Python 3 anyway), so this class is NOT actually abstract.
    It is kept as-is on purpose: switching to ``metaclass=ABCMeta`` would
    make every subclass that does not override ``resetProperties``
    uninstantiable, breaking existing callers.
    """
    _metaclass_ = ABCMeta

    def __init__(self, neutype):
        """Remember the neuron model name and start with no groups."""
        self.neutype = neutype
        self.groups = {}

    @abstractmethod
    def resetProperties(self):
        raise NotImplementedError

    def addNewGroups(self, layerID, neunum):
        raise NotImplementedError
class ModeLayer(Layer):
    """Layer of mode (scale) groups, keyed by ``groupID - 1``."""

    def __init__(self, neutype='LIF'):
        self.neutype = neutype
        self.groups = {}

    def setTestStates(self):
        """Put every group into its test (replay) state."""
        for cluster in self.groups.values():
            cluster.setTestStates()

    def addNewGroups(self, groupID, layerID, neunum, modeName):
        """Create and register a new mode group named ``modeName``."""
        # NOTE(review): hard-codes 'Izhikevich', ignoring self.neutype — confirm intended
        cluster = ModeCluster('Izhikevich', neunum)
        cluster.id = groupID
        cluster.name = modeName
        cluster.createClusterNetwork(cluster.name)
        cluster.setPropertiesofNeurons(groupID, 'Mode', layerID)
        self.groups[groupID - 1] = cluster
class KeyLayer(Layer):
    """Layer of key groups, keyed by ``groupID - 1``; group names come from
    the ``configs.index2key`` table."""

    def __init__(self, neutype='LIF'):
        self.neutype = neutype
        self.groups = {}

    def setTestStates(self):
        """Put every group into its test (replay) state."""
        for cluster in self.groups.values():
            cluster.setTestStates()

    def addNewGroups(self, groupID, layerID, neunum, key):
        """Create and register a new key group; ``key`` is the tone table."""
        # NOTE(review): hard-codes 'Izhikevich', ignoring self.neutype — confirm intended
        cluster = KeyCluster('Izhikevich', neunum)
        cluster.id = groupID
        cluster.name = configs.index2key.get(groupID - 1)
        cluster.createClusterNetwork(key, cluster.name)
        cluster.setPropertiesofNeurons(groupID, 'Key', layerID)
        self.groups[groupID - 1] = cluster
class ChordLayer(Layer):
    """Layer of chord groups, keyed by ``groupID - 1``."""

    def __init__(self, neutype='LIF'):
        Layer.__init__(self, neutype)

    def setTestStates(self):
        """Put every group into its test (replay) state."""
        for cluster in self.groups.values():
            cluster.setTestStates()

    def addNewGroups(self, groupID, layerID, neunum):
        """Create and register a new chord group (named by its id)."""
        # NOTE(review): hard-codes 'Izhikevich', ignoring self.neutype — confirm intended
        cluster = ChordCluster('Izhikevich', neunum)
        cluster.id = groupID
        cluster.name = groupID
        cluster.createClusterNetwork()
        cluster.setPropertiesofNeurons(groupID, 'Chord', layerID)
        self.groups[groupID - 1] = cluster
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/Modal/lifneuron.py
================================================
import torch
import random
from braincog.base.node import node
import numpy as np
class LIFNeuron(node.LIFNode):
    """Leaky integrate-and-fire neuron with layer/group bookkeeping for the
    music memory network.

    Fix: the lower-layer current attribute was misspelled ``I_ower`` in
    ``__init__`` while every subclass reads/writes ``I_lower``; the correct
    name is now initialized (the old name is kept as an alias so any legacy
    reader still works).
    """

    def __init__(self, tau_ref=0, vthresh=5, Rm=2, Cm=0.2, dt=0.1, *args, **kwargs):
        """
        :param tau_ref: refractory period
        :param vthresh: spike threshold
        :param Rm: membrane resistance
        :param Cm: membrane capacitance (tau_m = Rm * Cm)
        :param dt: integration step
        """
        super().__init__(threshold=vthresh, tau=Rm * Cm, dt=dt, *args, **kwargs)
        self.layerType = 'S'  # S: sequence layer, G: goal layer
        self.layerIndex = 0   # the layer in which the neuron is situated
        self.groupIndex = 0   # the group in which the neuron is situated
        self.index = 0        # starting with 1
        self.areaName = ''
        self.pre_neurons = []
        self.synapses = []    # this neuron is the post-synaptic side
        self.spiketime = []
        self.type = 'exc'     # 'exc' or 'inh'
        self.tau_ref = tau_ref
        self.tau_m = Rm * Cm
        self.vth = vthresh
        self.Rm = Rm
        self.Cm = Cm
        self.t_rest = 0       # end of the current refractory window
        self.I = 0
        self.spike = False
        self.firingrate = 0   # Hz
        self.I_lower = 0      # fixed: was misspelled "I_ower"
        self.I_ower = 0       # legacy alias kept for backward compatibility
        self.I_upper = 0
        self.I_ext = -100
        self.timeWindow = 5   # ms
        self.I_bg = random.randint(0, 10)  # random background current
        self.selectivity = 0
        self.preActive = False

    def update(self, dt, t, state):
        """Dispatch a step for state 'Learn' or 'test' (others: no-op)."""
        if (state == 'Learn'):
            self.update_learn(dt, t)
        elif (state == 'test'):
            self.update_test(dt, t)

    def update_learn(self, dt, t):
        """Learning step: integrate, spike/reset, then adapt synapse weights."""
        self.spike = False
        if (t >= self.t_rest):
            self.mem += dt * (-self.mem + self.I * self.Rm) / self.tau_m
            if (self.mem > self.vth):
                self.spike = True
                self.preActive = True
                self.spiketime.append(t)
                self.mem = 0
                self.t_rest = t + self.tau_ref
        self.updateSynapses(t)  # plasticity runs every step, spike or not

    def update_test(self, dt, t):
        """Test step: integrate and spike/reset; no plasticity."""
        self.spike = False
        if (t >= self.t_rest):
            self.mem += dt * (-self.mem + self.I * self.Rm) / self.tau_m
            if (self.mem > self.vth):
                self.spike = True
                self.spiketime.append(t)
                self.mem = 0
                self.t_rest = t + self.tau_ref

    def update_normal(self, dt, t):
        """Plain step: integrate, spike/reset, and record the spike time."""
        self.spike = False
        if (t >= self.t_rest):
            self.mem += dt * (-self.mem + self.I * self.Rm) / self.tau_m
            if (self.mem > self.vth):
                self.spike = True
                self.mem = 0
                self.t_rest = t + self.tau_ref
                self.spiketime.append(t)

    def updateSynapses(self, t):
        """Let every afferent synapse update its weight at time ``t``."""
        for syn in self.synapses:
            syn.computeWeight(t)

    def setTestStates(self):
        """Reset dynamic state so the neuron can replay in test mode."""
        self.t_rest = 0
        self.spiketime = []
        self.mem = 0
        self.I = 0
        self.I_ext = 0
        for syn in self.synapses:
            syn.strength = 0

    def computeFilterCurrent(self):
        """Hook for selectivity-based input filtering; overridden by subclasses."""
        pass

    def setPreference(self):
        """Hook to set the neuron's preference (selectivity); overridden."""
        pass

    def writeBasicInfoToJson(self, areaName):
        """Serialize identity plus all positive-weight afferents to a dict."""
        dic = {}
        dic["TrackID"] = self.layerIndex
        dic["GroupID"] = self.groupIndex
        dic["Index"] = self.index
        dic["area"] = areaName
        slist = []
        for syn in self.synapses:
            if (syn.weight <= 0): continue  # only learned (positive) synapses
            tmp = {}
            tmp["StrackID"] = syn.pre.layerIndex
            tmp["SgroupID"] = syn.pre.groupIndex
            tmp["Sindex"] = syn.pre.index
            tmp["Sarea"] = syn.pre.areaName
            tmp["TtrackID"] = self.layerIndex
            tmp["TgroupID"] = self.groupIndex
            tmp["Tindex"] = self.index
            tmp["Tarea"] = self.areaName
            tmp["type"] = syn.type
            tmp["weight"] = syn.weight
            slist.append(tmp)
        dic["synapses"] = slist
        return dic

    def writeSpikeTimeToJson(self):
        """Serialize spike times as [{1: t1}, {2: t2}, ...]."""
        slist = []
        for i, t in enumerate(self.spiketime):
            dic = {}
            dic[i + 1] = t
            slist.append(dic)
        return slist
# neu = LIFNeuron()
# dt = 0.001
# T = 1
# time = np.arange(0,T,dt)
# spikes = np.zeros(len(time))
# for i in range(0,len(time)):
# if(i == 22):
# print("debug")
# neu.I = 84.49
# neu.update_normal(dt, time[i])
# if(neu.spike == True):
# spikes[i] = 1
# #spikes[i] = neu.mem
# print(len(neu.spiketime))
# pl.plot(time,spikes)
# pl.show()
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/Modal/note.py
================================================
'''
Created on 2016.7.6
@author: liangqian
'''
from Modal.pitch import Pitch
class Note():
    """A note event: one or more simultaneous pitches (a chord is more than
    two pitches at the same time), hence lists rather than scalars."""

    def __init__(self):
        self.pitches = []   # Pitch objects sounding together
        self.lastTime = []  # durations; presumably aligned with pitches — TODO confirm
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/Modal/notecluster.py
================================================
from .cluster import Cluster
from .notelifneuron import NoteLIFNeuron
from Modal.izhikevichneuron import *
class NoteCluster(Cluster):
    """Cluster of pitch-selective neurons for the note memory ('NMSM' area)."""

    def __init__(self, neutype, neunum):
        Cluster.__init__(self, neutype, neunum)

    def createClusterNetwork(self):
        """Instantiate pitch-selective neurons of the configured model."""
        for idx in range(self.neunum):
            if self.neutype == 'LIF':
                neuron = NoteLIFNeuron()
            elif self.neutype == 'Izhikevich':
                neuron = NoteIzhikevichNeuron()
            else:
                continue  # unknown model names create nothing, as before
            neuron.index = idx + 1
            neuron.areaName = 'NMSM'
            neuron.setPreference()
            self.neurons.append(neuron)
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/Modal/notelifneuron.py
================================================
from .lifneuron import LIFNeuron
class NoteLIFNeuron(LIFNeuron):
    """Pitch-selective LIF neuron (0.5 refractory period by default)."""

    def __init__(self, tau_ref=0.5, vthresh=5, Rm=2, Cm=0.2):
        LIFNeuron.__init__(self, tau_ref, vthresh, Rm, Cm)

    def setPreference(self):
        # neuron k (1-based) prefers pitch index k-2
        self.selectivity = self.index - 2

    def computeFilterCurrent(self):
        """Drive the neuron when the external input matches its pitch."""
        if self.I_ext == self.selectivity:
            self.I = 10

    def updateCurrentOfLowerAndUpperLayer(self, t):
        """Recompute input current from same-group inhibition, same-layer
        excitation and upper-layer synapses via short-term facilitation.
        Cross-group/upper drives are scaled by 0.001 and the total mixes
        lower/upper contributions 40/60."""
        self.I_lower = 0
        self.I_upper = 0
        for syn in self.synapses:
            syn.computeShortTermFacilitation(t)
            drive = syn.weight * syn.strength
            if syn.type == 0:        # same group: only inhibitory part counts
                if syn.excitability == 0:
                    self.I_lower -= drive
                    if self.I_lower < -20:
                        self.I_lower = -20  # clamp total inhibition
            elif syn.type == 1:      # same layer, different group
                self.I_lower += 0.001 * drive
            elif syn.type == 2:      # upper layer
                self.I_upper += 0.001 * drive
        self.I = 0.4 * self.I_lower + 0.6 * self.I_upper
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/Modal/notesequencelayer.py
================================================
from .sequencelayer import SequenceLayer
from .notecluster import NoteCluster
from .synapse import Synapse
class NoteSequenceLayer(SequenceLayer):
    """Sequence layer of note clusters; each new group is fully connected to
    (at most) the four groups that precede it."""

    def __init__(self, neutype):
        SequenceLayer.__init__(self, neutype)

    def addNewGroups(self, GroupID, layerID, neunum):
        """Create a note group, register it, and wire it to recent groups.

        Excitatory neurons of up to four previous groups project to the new
        group with type-1 synapses whose delay equals the group distance.
        """
        cluster = NoteCluster(self.neutype, neunum)
        cluster.createClusterNetwork()
        cluster.id = GroupID
        cluster.setPropertiesofNeurons(cluster.id, 'S', layerID)
        self.groups[cluster.id] = cluster
        if len(self.groups) <= 1:
            return  # first group: nothing to connect to
        start = max(1, cluster.id - 4)  # look back at most 4 groups
        for former_id in range(cluster.id - 1, start - 1, -1):
            former = self.groups.get(former_id)
            for pre in former.neurons:
                for post in cluster.neurons:
                    if pre.type == 'inh' or post.type == 'inh':
                        continue  # only excitatory-to-excitatory links
                    syn = Synapse(pre, post)
                    syn.type = 1
                    syn.delay = cluster.id - former.id
                    post.pre_neurons.append(pre)
                    post.synapses.append(syn)
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/Modal/pitch.py
================================================
'''
Created on 2018.8.29
@author: liangqian
'''
class Pitch():
    """A single pitch, identified by name and MIDI index."""

    def __init__(self):
        self.name = ''      # pitch name, e.g. 'C4'
        self.frequence = 0  # MIDI index number (not Hz), per the original note
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/Modal/sequencelayer.py
================================================
from .layer import Layer
from .cluster import Cluster
from .synapse import Synapse
class SequenceLayer(Layer):
    """Stores sequential musical elements (pitches and durations); every new
    group is fully connected to all previous groups."""

    def __init__(self, neutype='LIF'):
        self.type = ""
        self.neutype = neutype
        self.groups = {}

    def addNewGroups(self, GroupID, layerID, neunum):
        """Create a cluster, register it, and wire it to all former groups.

        Excitatory neurons of every earlier group project to the new group
        with type-1 synapses whose delay equals the group distance.
        """
        cluster = Cluster(self.neutype, neunum)
        cluster.createClusterNetwork()
        cluster.id = GroupID
        cluster.setPropertiesofNeurons(cluster.id, 'S', layerID)
        self.groups[cluster.id] = cluster
        if len(self.groups) <= 1:
            return  # first group: nothing to connect to
        for former_id in range(cluster.id - 1, 0, -1):
            former = self.groups.get(former_id)
            for pre in former.neurons:
                for post in cluster.neurons:
                    if pre.type == 'inh' or post.type == 'inh':
                        continue  # only excitatory-to-excitatory links
                    syn = Synapse(pre, post)
                    syn.type = 1
                    syn.delay = cluster.id - former.id
                    post.pre_neurons.append(pre)
                    post.synapses.append(syn)

    def setTestStates(self):
        """Put every group into its test (replay) state."""
        for cluster in self.groups.values():
            cluster.setTestStates()
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/Modal/sequencememory.py
================================================
from .synapse import Synapse
from Modal.sequencelayer import SequenceLayer
import numpy as np
from Modal.synapse import Synapse
class SequenceMemory():
'''
classdocs
'''
def __init__(self, neutype):
'''
Constructor
'''
self.neutype = neutype
self.sequenceLayers = {}
    def createActionSequenceMem(self, layernum, neutype, neunumpergroup):
        # Placeholder: intended to build the action-sequence memory
        # (layernum layers of neunumpergroup neurons) — not implemented yet.
        pass
    def doRemembering(self):
        # Placeholder: memorization entry point — not implemented yet.
        pass
def doConnecting(self, goal, sl, order):
# the goal and the group always generate spikes in a limit time window,create a synapse between them.
group = sl.groups.get(order)
if (group == None): return
tb = (order - 1) * group.timeWindow
te = (order) * group.timeWindow
sp1_goal = {}
sp2 = []
for n in goal.neurons:
sp = []
for st in n.spiketime:
if (st < te and st >= tb):
sp.append(st)
sp1_goal[n.index] = sp
for n in group.neurons:
if (len(n.spiketime) > 0):
for index, sp in sp1_goal.items():
temp = 0
for sp1 in n.spiketime: # spike times of group
for sp2 in sp:
if (abs(sp1 - sp2) <= n.tau_ref):
temp += 1
if (
temp >= 4): # super threshold, create a new synapse between goal and neurons of sequence group
syn = Synapse(goal.neurons[index - 1], n)
syn.type = 2
syn.weight = 3
n.pre_neurons.append(goal.neurons[index - 1])
n.synapses.append(syn)
# add reverse synapse to neurons of the goal
syn2 = Synapse(n, goal.neurons[index - 1])
syn2.type = 2
syn2.weight = 1
goal.neurons[index - 1].synapses.append(syn2)
# clear the goal's spike time
''' *************************************************************
I have forgot why neurons here needs to be cleaned, but this must be important, mark here
for n in goal.neurons:
n.spiketime = []
******************************************************************
'''
# def doConnectToGoal(self,goal,track,order): # connect to the goal in the time window
#
# for sl in track.values():
# group = sl.groups.get(order)
#
# if(group == None): continue
# # the goal and the group always generate spikes in a limit time window,create a synapse between them.
# tb = (order-1)*group.timeWindow
# te = (order) * group.timeWindow
# sp1_goal = {}
# sp2 = []
#
# for n in goal.neurons:
# sp = []
# for st in n.spiketime:
# if(st < te and st >= tb ):
# sp.append(st)
# sp1_goal[n.index] = sp
#
# for n in group.neurons:
# if(len(n.spiketime) > 0):
# for index,sp in sp1_goal.items():
# temp = 0
# for sp1 in n.spiketime: #spike times of group
# for sp2 in sp:
# if(abs(sp1-sp2) <= n.tau_ref):
# temp += 1
# if(temp >= 4): # super threshold, create a new synapse between goal and neurons of sequence group
# syn = Synapse(goal.neurons[index-1],n)
# syn.type = 2
# syn.weight = 3
# n.pre_neurons.append(goal.neurons[index-1])
# n.synapses.append(syn)
#
# #add reverse synapse to neurons of the goal
# syn2 = Synapse(n,goal.neurons[index-1])
# syn2.type = 2
# syn2.weight = 1
# goal.neurons[index-1].synapses.append(syn2)
#
# #clear the goal's spike time
# ''' *************************************************************
# I have forgot why neurons here needs to be cleaned, but this must be important, mark here
# for n in goal.neurons:
# n.spiketime = []
# ******************************************************************
# '''
#
# def doConnectToComposer(self, composer, track, order):
# for sl in track.values():
# group = sl.groups.get(order)
#
# if(group == None): continue
# # the goal and the group always generate spikes in a limit time window,create a synapse between them.
# tb = (order-1)*group.timeWindow
# te = (order) * group.timeWindow
# sp1_composer = {}
# sp2 = []
#
# for n in composer.neurons:
# sp = []
# for st in n.spiketime:
# if(st < te and st >= tb ):
# sp.append(st)
# sp1_composer[n.index] = sp
#
# for n in group.neurons:
# if(len(n.spiketime) > 0):
# for index,sp in sp1_composer.items():
# temp = 0
# for sp1 in n.spiketime: #spike times of group
# for sp2 in sp:
# if(abs(sp1-sp2) <= n.tau_ref):
# temp += 1
# if(temp >= 4): # super threshold, create a new synapse between composer and neurons of sequence group
# syn = Synapse(composer.neurons[index-1],n)
# syn.type = 2
# syn.weight = 3
# n.pre_neurons.append(composer.neurons[index-1])
# n.synapses.append(syn)
#
# #add reverse synapse to neurons of the goal
# # syn2 = Synapse(n,goal.neurons[index-1])
# # syn2.type = 2
# # syn2.weight = 1
# # goal.neurons[index-1].synapses.append(syn2)
#
#
# def doConnectToGenre(self, genre, track, order):
# for sl in track.values():
# group = sl.groups.get(order)
#
# if(group == None): continue
# # the goal and the group always generate spikes in a limit time window,create a synapse between them.
# tb = (order-1)*group.timeWindow
# te = (order) * group.timeWindow
# sp1_genre = {}
# sp2 = []
#
# for n in genre.neurons:
# sp = []
# for st in n.spiketime:
# if(st < te and st >= tb ):
# sp.append(st)
# sp1_genre[n.index] = sp
#
# for n in group.neurons:
# if(len(n.spiketime) > 0):
# for index,sp in sp1_genre.items():
# temp = 0
# for sp1 in n.spiketime: #spike times of group
# for sp2 in sp:
# if(abs(sp1-sp2) <= n.tau_ref):
# temp += 1
# if(temp >= 4): # super threshold, create a new synapse between composer and neurons of sequence group
# syn = Synapse(genre.neurons[index-1],n)
# syn.type = 2
# syn.weight = 3
# n.pre_neurons.append(genre.neurons[index-1])
# n.synapses.append(syn)
def setTestStates(self):
for itrack in self.sequenceLayers.values():
for sl in itrack.values():
sl.setTestStates()
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/Modal/synapse.py
================================================
import math
class Synapse():
    '''
    Directed connection between two neurons, carrying a long-term weight and
    a short-term facilitation/depression factor (strength).
    '''

    def __init__(self, pre, post):
        """Create a zero-weight synapse from `pre` to `post`."""
        # 0: within a group; 1: different groups in the same layer; 2: other layer
        self.type = 0
        self.pre = pre
        self.post = post
        self.weight = 0
        self.excitability = 1  # 1: excitatory connection; 0: inhibitory connection
        self.strength = 0      # short-term depression / facilitation factor
        self.delay = 0         # transmission delay between pre and post

    def computeWeight(self, t):
        """Update the long-term weight at time `t` via an exponential STDP rule.
        Types 0 and 2 are deliberately left untouched here."""
        if self.type == 1:
            # same layer, different groups: only when the group distance
            # matches this synapse's delay (loop-invariant, checked once)
            if self.post.groupIndex - self.pre.groupIndex == self.delay:
                offset = (self.delay - 1) * self.post.timeWindow
                for st in self.pre.spiketime:
                    lag = t - st - offset
                    self.weight += math.exp(-lag / 5) if lag >= 0 else -math.exp(lag / 5)
        elif self.type == 3:
            # cross-layer STDP with a 5x gain
            offset = (self.delay - 1) * self.post.timeWindow
            for st in self.pre.spiketime:
                lag = t - st - offset
                self.weight += 5 * math.exp(-lag / 5) if lag >= 0 else -5 * math.exp(lag / 5)

    def computeShortTermFacilitation(self, t):
        """Grow `strength` when presynaptic spikes arrive in the right window."""
        if self.type == 1:
            for st in reversed(self.pre.spiketime):
                arrival = st + self.delay
                if arrival == t:  # original condition: (at <= t and at >= t)
                    self.strength += (self.strength + 1) * 0.2
        elif self.type == 2:
            # spike within the post neuron's refractory span before t
            for st in self.pre.spiketime:
                if t - self.post.tau_ref <= st <= t:
                    self.strength += (self.strength + 1) * 0.5
        elif self.type == 0 and self.excitability == 0:
            # inhibitory within-group synapses facilitate per spike
            for _ in self.pre.spiketime:
                self.strength += (self.strength + 1) * 0.8

    def computeShortTermFacilitation2(self, t):
        """Binary variant: set `strength` to 1 as soon as any spike has arrived."""
        if self.type == 1:
            for st in reversed(self.pre.spiketime):
                if st + self.delay <= t:
                    self.strength = 1
        elif self.type >= 2:
            for st in self.pre.spiketime:
                if st <= t:
                    self.strength = 1

    def computeShortTermReduction(self, t):
        """Depress cross-layer synapses for spikes older than one time window."""
        if self.type == 2:
            for st in reversed(self.pre.spiketime):
                if t - st > self.post.timeWindow:
                    self.strength -= (self.strength + 1) * 0.5
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/Modal/tempocluster.py
================================================
from .cluster import Cluster
from .tempolifneuron import TempoLIFNeuron
from Modal.izhikevichneuron import *
class TempoCluster(Cluster):
    '''
    Cluster of tempo-selective neurons for the tempo sequence memory (TMSM).
    '''

    def __init__(self, neutype, neunum):
        """Delegate sizing and neuron-type bookkeeping to the base Cluster."""
        Cluster.__init__(self, neutype, neunum)

    def createClusterNetwork(self):
        """Populate the cluster with `neunum` tempo neurons of the configured type."""
        for idx in range(self.neunum):
            if self.neutype == 'LIF':
                neuron = TempoLIFNeuron()
            elif self.neutype == 'Izhikevich':
                neuron = TempoIzhikevichNeuron()
            else:
                continue  # unknown neuron type: nothing to add
            neuron.index = idx + 1          # 1-based index drives the tempo preference
            neuron.areaName = 'TMSM'
            neuron.setPreference()
            self.neurons.append(neuron)
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/Modal/tempolifneuron.py
================================================
from .lifneuron import LIFNeuron
import math
class TempoLIFNeuron(LIFNeuron):
    '''
    LIF neuron with a preferred duration (tempo) value; used by the tempo
    sequence memory ('TMSM' area).
    '''

    def __init__(self, tau_ref=0.5, vthresh=5, Rm=2, Cm=0.2):
        '''
        Constructor
        '''
        LIFNeuron.__init__(self, tau_ref, vthresh, Rm, Cm)

    def setPreference(self):
        # Gaussian function to set selectivity
        # Preference grows linearly with the neuron's 1-based index in steps
        # of 0.125; together with computeFilterCurrent this gives each neuron
        # a +/-0.0625 acceptance band around its preferred value.
        self.selectivity = self.index * 0.125 #

    def computeFilterCurrent(self):
        # Inject a fixed suprathreshold current only when the external input
        # lies inside this neuron's acceptance band.
        if (self.I_ext <= self.selectivity + 0.0625 and self.I_ext >= self.selectivity - 0.0625):
            self.I = 10

    def updateCurrentOfLowerAndUpperLayer(self, t):
        # Aggregate synaptic drive into I_lower (same-layer / within-group
        # input) and I_upper (cross-layer input), then combine them into the
        # membrane current I.
        self.I_lower = 0
        self.I_upper = 0
        for syn in self.synapses:
            syn.computeShortTermFacilitation(t)
            if (syn.type == 0):  # the same group
                if (syn.excitability == 0):
                    # inhibitory within-group synapses subtract, clamped at -20
                    self.I_lower -= 0.001 * syn.weight * syn.strength
                    if (self.I_lower < -20): self.I_lower = -20
            if (syn.type == 1):  # pre and post neurons come from the same layer but not the same group
                # if(syn.weight > 0):
                #     print('pre_neuron_group id:'+str(syn.pre.groupIndex) + ' neuron index:'+str(syn.pre.index))
                #     print('post_neuron_group id:'+str(self.groupIndex) + ' neuron index:'+str(self.index))
                #     print(syn.weight)
                self.I_lower += 0.001 * syn.weight * syn.strength
                # print('syn.strength='+str(syn.strength))
            if (syn.type == 2):  # pre and post neurons come from the different layers
                # print(syn.pre.groupIndex)
                # self.I_lower = syn.weight * syn.strength
                self.I_upper += 0.001 * syn.weight * syn.strength
        # if(self.I_upper == 0):
        #     self.I = self.I_ext
        else:
            # NOTE(review): reconstructed as a for/else, so this assignment
            # runs exactly once after the loop finishes (there is no break
            # above). The source file's indentation was ambiguous — confirm
            # the else was not meant to pair with the `syn.type == 2` test,
            # which would instead recompute I once per non-type-2 synapse.
            self.I = 0.4 * self.I_lower + 0.6 * self.I_upper
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/Modal/temposequencelayer.py
================================================
from .sequencelayer import SequenceLayer
from .tempocluster import TempoCluster
from .synapse import Synapse
class TempoSequenceLayer(SequenceLayer):
    '''
    Sequence layer specialised for tempo (duration) clusters; new groups are
    connected only to the most recent four predecessors.
    '''

    def __init__(self, neutype):
        """Forward the neuron-type choice to the generic sequence layer."""
        SequenceLayer.__init__(self, neutype)

    def addNewGroups(self, GroupID, layerID, neunum):
        """Append a tempo cluster and wire the last few clusters to it
        (excitatory neurons only), delay = group distance."""
        cluster = TempoCluster(self.neutype, neunum)
        cluster.createClusterNetwork()
        # cluster.createInhibitoryConnections()
        cluster.id = GroupID
        cluster.setPropertiesofNeurons(cluster.id, 'S', layerID)
        self.groups[cluster.id] = cluster
        if len(self.groups) > 1:
            # connect back at most 4 groups (all of them while id <= 5)
            earliest = 1 if cluster.id <= 5 else cluster.id - 4
            for former_id in range(cluster.id - 1, earliest - 1, -1):
                former = self.groups.get(former_id)
                for pre in former.neurons:
                    for post in cluster.neurons:
                        if pre.type == 'inh' or post.type == 'inh':
                            continue
                        link = Synapse(pre, post)
                        link.type = 1
                        link.delay = cluster.id - former.id
                        post.pre_neurons.append(pre)
                        post.synapses.append(link)
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/Modal/titlecluster.py
================================================
from .cluster import Cluster
from .titlelifneuron import TitleLIFNeuron
class TitleCluster(Cluster):
    '''
    Cluster representing one remembered title (goal) in the IPS area.
    '''

    def __init__(self, neutype, neunum):
        """Initialise the base cluster and the per-title firing-rate summary."""
        Cluster.__init__(self, neutype, neunum)
        self.averageFiringRate = 0

    def createClusterNetwork(self):
        """Fill the cluster with title neurons; only the LIF model is supported."""
        for idx in range(self.neunum):
            if self.neutype != 'LIF':
                continue
            neuron = TitleLIFNeuron()
            neuron.index = idx + 1
            neuron.areaName = 'IPS'
            self.neurons.append(neuron)
        # (Izhikevich / Gaussian / HH variants were removed from this path;
        # see version-control history.)
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/Modal/titlelayer.py
================================================
from .layer import Layer
from .titlecluster import TitleCluster
class TitleLayer(Layer):
    '''
    Layer holding one TitleCluster per remembered piece, keyed by title name.
    '''

    def __init__(self, neutype='LIF'):
        """Create an empty title layer."""
        self.neutype = neutype
        self.groups = {}

    def setTestStates(self):
        """Reset test-time state on every title cluster."""
        for cluster in self.groups.values():
            cluster.setTestStates()

    def addNewGroups(self, groupID, layerID, neunum, goalname):
        """Add a cluster for `goalname`; note the dict is keyed by name, not id."""
        cluster = TitleCluster(self.neutype, neunum)
        cluster.id = groupID
        cluster.name = goalname
        cluster.createClusterNetwork()
        cluster.setPropertiesofNeurons(groupID, 'G', layerID)
        self.groups[goalname] = cluster
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/Modal/titlelifneuron.py
================================================
from .lifneuron import LIFNeuron
import math
class TitleLIFNeuron(LIFNeuron):
    '''
    LIF neuron for the title (goal/IPS) layer: its input current is the
    log-compressed sum of cross-layer synaptic drive, and its firing rate can
    be read out analytically from the LIF rate equation.
    '''

    def __init__(self, tau_ref=0, vthresh=5, Rm=2, Cm=0.2):
        '''
        Constructor
        '''
        LIFNeuron.__init__(self, tau_ref, vthresh, Rm, Cm)

    def updateCurrentOfLowerAndUpperLayer(self, t):
        # Accumulate drive from cross-layer (type 2) synapses only, then
        # log-compress positive totals to keep the input current bounded.
        self.I_lower = 0
        self.I_upper = 0
        for syn in self.synapses:
            syn.computeShortTermFacilitation(t)
            syn.computeShortTermReduction(t)
            if (syn.type == 2):  # pre and post neurons come from the different layers
                self.I_lower += syn.weight * syn.strength
        if (self.I_lower <= 0):
            self.I = self.I_lower
        else:
            self.I = math.log(self.I_lower)

    def update(self, dt, t):
        # Standard LIF integration with refractory gating; the membrane
        # resets to 0 after a spike.
        self.spike = False
        # self.updateCurrentOfLowerAndUpperLayer(t)
        if (t >= self.t_rest):
            self.mem += dt * (-self.mem + self.I * self.Rm) / self.tau_m
            if (self.mem > self.vth):
                self.spike = True
                self.spiketime.append(t)
                self.mem = 0
                self.t_rest = t + self.tau_ref

    def computeFiringRate(self):
        # Closed-form LIF firing rate; only defined for suprathreshold drive.
        # Fix: the original guarded only I == 0, so any drive in (0, vth]
        # raised a math domain error / ZeroDivisionError, and negative drive
        # produced a meaningless negative rate. Subthreshold drive now yields
        # a rate of 0, matching the formula's domain.
        if (self.I <= self.vth):
            self.firingrate = 0
        else:
            self.firingrate = 1 / (self.tau_ref + self.Rm * self.Cm * math.log(self.I / (self.I - self.vth)))
            self.firingrate *= 1000  # convert to spikes per second (ms base)
            self.firingrate = round(self.firingrate)
        # print(self.firingrate)
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/README.md
================================================
# Music Memory and stylistic composition
This repository contains code from our paper:
- [**Temporal-Sequential Learning With a Brain-Inspired Spiking Neural Network and Its Application to Musical Memory**](https://www.cell.com/patterns/fulltext/S2666-3899(22)00119-2) published in Frontiers in Computational Neuroscience. **https://www.cell.com/patterns/fulltext/S2666-3899(22)00119-2**,
- [**Stylistic composition of melodies based on a brain-inspired spiking neural network**](https://www.frontiersin.org/articles/10.3389/fnsys.2021.639484/full) published in Frontiers in Systems Neuroscience **https://www.frontiersin.org/articles/10.3389/fnsys.2021.639484/full**.
- [**Mode-conditioned music learning and composition: a spiking neural network inspired by neuroscience and psychology**](https://arxiv.org/pdf/2411.14773) preprint in arXiv **https://arxiv.org/pdf/2411.14773**.
## Requirements
* numpy
* scipy
* pytorch >= 1.7.0
* pretty_midi >= 0.2.9
* music21
## Data preparation
The dataset used here can be referred to the website http://www.piano-midi.de/.
The dataset used in mode-conditioned music learning can be referred to the website https://github.com/lqnankai/Music-Dataset.
## Run
* Run the script *task/musicMemory.py* to memorize and recall the musical melodies, the result will be recorded in a midi file.
* Run the script *task/musicGeneration.py* to learn and generate melodies with different styles, the result will be recorded in a midi file.
* Run the script *task/mode-conditioned learning.py* to learn and generate melodies with different mode and keys, the result will be recorded in a midi file.
The API and details can be found in these scripts.
## Citation
If you find this package helpful, please consider citing the following papers:
```BibTex
@article{LQ2020,
author = {Liang, Qian and Zeng, Yi and Xu, Bo},
year = {2020},
month = {07},
pages = {51},
title = {Temporal-Sequential Learning With a Brain-Inspired Spiking Neural Network and Its Application to Musical Memory},
volume = {14},
journal = {Frontiers in Computational Neuroscience}
}
@article{LQ2021,
title = {Stylistic composition of melodies based on a brain-inspired spiking neural network},
author = {Liang, Qian and Zeng, Yi},
journal = {Frontiers in systems neuroscience},
volume = {15},
pages = {21},
year = {2021},
publisher = {Frontiers}
}
@misc{liang2024modeconditionedmusiclearningcomposition,
title={Mode-conditioned music learning and composition: a spiking neural network inspired by neuroscience and psychology},
author={Qian Liang and Yi Zeng and Menghaoran Tang},
year={2024},
eprint={2411.14773},
archivePrefix={arXiv},
primaryClass={cs.SD},
url={https://arxiv.org/abs/2411.14773},
}
```
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/api/music_engine_api.py
================================================
from conf.conf import *
from Areas.cortex import Cortex
import pretty_midi
import math
import json
import music21 as m21
class EngineAPI():
    '''
    Facade over the Cortex model: memorise melodies from MusicXML or MIDI,
    recall stored pieces, and generate new melodies, writing every result out
    as a standard MIDI file.
    '''

    def __init__(self):
        '''
        Constructor
        '''
        # Neuron type and integration step come from the global `configs` object.
        self.cortex = Cortex(configs.neuron_type, configs.dt)

    def cortexInit(self):
        # Build the music sequence memory and pre-populate the PFC with the
        # key / mode / chord groups used by knowledge-guided learning.
        self.cortex.musicSequenceMemroyInit()
        self.cortex.pfc.addNewKey()
        self.cortex.pfc.addNewMode()
        self.cortex.pfc.addNewChord()

    def rememberMusic(self, muiscName, composerName="None"):
        '''
        Register a melody (title, composer and genre groups) in the PFC
        before its notes are learned.

        :param muiscName: the name of the melody
        :param composerName: the composer
        :return: (goal dict, composer dict); both empty unless
                 configs.RunTimeState == 1
        '''
        muiscName = muiscName.title()
        composerName = composerName.title()
        self.cortex.pfc.setTestStates()
        self.cortex.msm.setTestStates()
        self.cortex.addSubGoalToPFC(muiscName)
        self.cortex.addComposerToPFC(composerName)
        # composer -> genre lookup; unknown composers map to the string "None"
        genreName = str(configs.GenreMap.get(composerName))
        self.cortex.addGenreToPFC(genreName)
        self.cortex.pfc.innerLearning(muiscName, composerName, genreName)
        goaldic = {}
        composerdic = {}
        genredic = {}
        if (configs.RunTimeState == 1):
            g = self.cortex.pfc.titles.groups.get(muiscName)
            c = self.cortex.pfc.composers.groups.get(composerName)
            gre = self.cortex.pfc.genres.groups.get(genreName)
            goaldic = g.writeSelfInfoToJson("IPS")
            composerdic = c.writeSelfInfoToJson("Composer")
            genredic = gre.writeSelfInfoToJson("Genre")
        # NOTE(review): genredic is computed but never returned — confirm
        # whether callers need it.
        return goaldic, composerdic

    def learnFourPartMusic(self,xmldata, musicName, composerName="None"):
        # Learn a multi-part (e.g. SATB) score parsed by music21, one part
        # (voice) per sequence-memory track.
        musicName = musicName.title()
        composerName = composerName.title()
        genreName = "None"
        # music21 key analysis -> internal key index name
        toneName = configs.keyIndexMap.get(str(xmldata.analyze('key')))
        print(musicName + " learning...")
        emo = "None"
        for i, part in enumerate(xmldata.parts):
            if (self.cortex.msm.sequenceLayers.get(i + 1) == None):
                self.cortex.msm.createActionSequenceMem(i + 1, self.cortex.neutype)
            self.rememberPartNotes(musicName, composerName, genreName, emo, toneName, i + 1, part)

    def rememberPartNotes(self,musicName, composerName, genreName, emo, keyName, partIndex, part):
        # Feed the first 20 notes of one part into the cortex, note by note.
        print("Learning the part "+str(partIndex))
        for i,note in enumerate(part.flat.notes[:20]):
            p = 0
            dur = 0
            if note.isChord:
                # NOTE(review): for a chord only the LAST chord tone survives
                # this loop — confirm whether all chord pitches were intended.
                dur = note.duration.quarterLength
                for chord_note in note:
                    p = m21.pitch.Pitch(chord_note.pitch).midi
            else:
                dur = note.duration.quarterLength
                p = m21.pitch.Pitch(note.pitch).midi
            if dur == 0.0:
                # zero-length (e.g. grace) notes get a minimal duration
                dur = 0.125
            if keyName == 'None':
                self.cortex.rememberANoteandTempo(musicName, composerName, genreName, emo, partIndex, p, i+1, dur)
            else:
                self.cortex.rememberANoteWithKnowledge(musicName, composerName, genreName, emo, keyName, partIndex, p, dur, i+1, part)

    def rememberMIDIMusic(self, musicName, composerName, noteLength, fileName):
        '''
        Learn the notes of a MIDI file (first instrument/track only).

        :param musicName: the name of the piece of music
        :param composerName: the composer who writes this melody
        :param noteLength: number of notes to learn, or "ALL"
        :param fileName: the name of this midi file
        :return: none
        '''
        musicName = musicName.title()
        composerName = composerName.title()
        print(musicName + " processing...")
        pm = pretty_midi.PrettyMIDI(fileName)
        genreName = str(configs.GenreMap.get(composerName))
        for i, ins in enumerate(pm.instruments):
            # only the first instrument/track is learned
            if (i >= 1): break;
            if (self.cortex.msm.sequenceLayers.get(i + 1) == None):
                # create a new layer to store the track
                self.cortex.msm.createActionSequenceMem(i + 1, self.cortex.neutype)
            self.rememberTrackNotes(musicName, composerName, genreName, i + 1, ins, pm, noteLength)
        print(musicName + " finished!")

    def rememberTrackNotes(self, musicName, composerName, genreName, trackIndex, track, pm, noteLength):
        # Walk a pretty_midi track, segmenting it into rests, single notes
        # and chords, and feed each event (ordered) into the cortex.
        # NOTE(review): self.conn (used when configs.RunTimeState == 1) is
        # never assigned in this class — that path raises AttributeError;
        # confirm where the queue connection is meant to be injected.
        r_notes = []
        r_intervals = []
        total_dic = {}
        print(track)
        if(noteLength == "ALL"):
            noteLength = len(track.notes)
        order = 1
        i = 0
        #while (i < len(track.notes)):
        while (i < noteLength):
            #if (i >= rl): break;
            note = track.notes[i]
            start = pm.time_to_tick(note.start)
            end = pm.time_to_tick(note.end)
            pitches = []
            durations = []
            restFlag = False
            # this part recognizes a rest
            if (i == 0):  # determine whether the first note is a rest
                if (start >= 30):  # leading gap of >= 30 ticks counts as a rest
                    pitches.append(-1)  # -1 represents a rest
                    durations.append(start / pm.resolution)
                    restFlag = True
            else:
                lastend = pm.time_to_tick(track.notes[i - 1].end)
                if (start - lastend >= 50):  # inter-note gap of >= 50 ticks
                    pitches.append(-1)
                    durations.append((start - lastend) / pm.resolution)
                    restFlag = True
            if (restFlag == True):
                # learn the rest as its own ordered event
                dic, g = self.rememberANote(musicName, composerName, genreName, trackIndex, pitches[0], order,
                                            durations[0], True)
                if (configs.RunTimeState == 1):
                    jstr = json.dumps(g)
                    self.conn.send('/Queue/SampleQueue', jstr)
                #print(str(order) + ":(-1," + str(durations[0]) + ")")
                order = order + 1
                pitches = []
                durations = []
            # this part recognizes a chord: gather notes that start (nearly)
            # together or overlap the current note by >= 30 ticks
            pitches.append(note.pitch)
            durations.append((end - start) / pm.resolution)
            j = i + 1
            while (j < len(track.notes)):
                nextstart = pm.time_to_tick(track.notes[j].start)
                nextend = pm.time_to_tick(track.notes[j].end)
                # if(start == nextstart or end > nextstart):
                if (math.fabs(start - nextstart) <= 30 or end - nextstart >= 30):
                    pitches.append(track.notes[j].pitch)
                    durations.append((nextend - nextstart) / pm.resolution)
                    j = j + 1
                else:
                    break
            i = j
            if (i < noteLength):
                # NOTE(review): only pitches[0]/durations[0] are learned even
                # when a chord was gathered above — confirm intent.
                dic, g = self.rememberANote(musicName, composerName, genreName, trackIndex, pitches[0], order,
                                            durations[0], True)
                str1 = str(order) + ":("
                for t in range(len(pitches)):
                    str1 += str(pitches[t]) + "," + str(durations[t]) + ";"
                #print(str1 + ")")
                order = order + 1
                if (configs.RunTimeState == 1):
                    jstr = json.dumps(g)
                    self.conn.send('/Queue/SampleQueue', jstr)
                # collect the spiking neuron indices reported for this event
                nlist = dic.get('MSMSpike')
                ns = []
                for l in nlist:
                    n = l.get('Index')
                    ns.append(n)
                r_notes.append(ns)
                tlist = dic.get('MSMTSpike')
                ts = []
                for l in tlist:
                    t = l.get('Index')
                    ts.append(t * 60)
                r_intervals.append(ts)
        return total_dic

    def rememberNotes(self, MusicName, notes, intervals, tempo=True):
        # NOTE(review): this path calls rememberANote with 5 positional
        # arguments, but rememberANote expects (MusicName, ComposerName,
        # genreName, TrackIndex, NoteIndex, order, tinterval, ...); as written
        # this raises TypeError — confirm whether this method is still used.
        jStr = ''
        # print(intervals)
        notesStr = notes.split(",")
        intervalsStr = intervals.split(",")
        intervaltimes = []
        for i in range(len(intervalsStr) - 1):
            intervaltimes.append(int(intervalsStr[i]))
        print(intervaltimes)
        for i, note in enumerate(notesStr):
            note = int(note)
            if (i < len(notesStr) - 1):
                tinterval = intervalsStr[i]
                tinterval = int(intervalsStr[i])
                self.rememberANote(MusicName, note, i + 1, tinterval, tempo)
        return jStr

    def rememberANote(self, MusicName, ComposerName, genreName, TrackIndex, NoteIndex, order, tinterval, tempo=False):
        # tempo=False: learn pitch only and return a JSON string;
        # tempo=True: learn pitch + duration and return (dict, group info).
        if (tempo == False):
            dic = self.cortex.rememberANote(MusicName, NoteIndex, order)
            jsonStr = json.dumps(dic)
            return jsonStr
        else:
            dic, g = self.cortex.rememberANoteandTempo(MusicName, ComposerName, genreName, TrackIndex, NoteIndex, order,
                                                       tinterval)
            return dic, g

    def memorizing(self,MusicName, ComposerName, noteLength, fileName):
        '''
        Register the piece in the PFC, then learn its notes from the MIDI file.

        :param MusicName: the name of the piece of music
        :param ComposerName: the composer who writes this melody
        :param noteLength: the number of notes to be trained(integer), if you want to learn all the notes of a musical work, the value should be specified as "ALL"
        :param fileName: the path and the name of this midi file
        :return: none
        '''
        self.rememberMusic(MusicName, ComposerName)
        self.rememberMIDIMusic(MusicName,ComposerName,noteLength, fileName)

    def recallMusic(self, musicName):
        # Replay a stored piece and write it to "<Name>_recall.mid".
        print("Recall the " + musicName + " ......")
        musicName = musicName.title()
        result = self.cortex.recallMusicPFC(musicName)
        #print(result)
        noteResult = {}
        # convert {track: {'N': {...}, 'T': {...}}} into the writeMidiFile
        # format {track: [{'N': pitch, 'T': duration}, ...]}
        for tindex,track in result.items():
            ns = track.get('N')
            ts = track.get('T')
            tmp = []
            for key in ns.keys():
                dic = {}
                dic['N']=ns.get(key)
                dic['T']=ts.get(key)
                tmp.append(dic)
            noteResult[tindex] = tmp
        self.writeMidiFile(musicName+"_recall",noteResult)
        print("Recall " + musicName + " finished!")
        return noteResult

    def generateEx_Nihilo(self, firstNote, durations, length,gen_fName):
        '''
        Generate a melody with no particular style.

        parameters:
        firstNote: Specify the beginning notes to generate a note
        durations: Specify the duration of the beginning notes
        length: the length of the generated music, less than 50 notes
        gen_fName: output MIDI file name (without extension)
        '''
        print("Generate melody with no style............")
        result = self.cortex.generateEx_Nihilo2(firstNote, durations, length)
        self.writeMidiFile(gen_fName,result)
        print("Generating finished!")
        return result

    def generateEx_NihiloAccordingToGenre(self, genreName, firstNote, durations, length,gen_fName):
        '''
        Generate a melody in the style of a genre.

        parameters:
        genreName: Specify the style of genre of the generated melody, for example: Baroque,Classical,Romantic
        firstNote: Specify the beginning notes to generate a note
        durations: Specify the duration of the beginning notes
        length: the length of the generated music, less than 50 notes
        gen_fName: output MIDI file name (without extension)
        '''
        print("Generate melody with "+ genreName+"\'s style............")
        result = self.cortex.generateEx_NihiloAccordingToGenre(genreName, firstNote, durations, length)
        self.writeMidiFile(gen_fName,result)
        print("Generating finished!")
        return result

    def generateEx_NihiloAccordingToComposer(self, composerName, firstNote, durations, length,gen_fName):
        '''
        Generate a melody in the style of a composer.

        parameters:
        composerName: Specify the style of composer of the generated melody, for example: Bach, Mozart and etc.
        firstNote: Specify the beginning notes to generate a note
        durations: Specify the duration of the beginning notes
        length: the length of the generated music, less than 50 notes
        gen_fName: output MIDI file name (without extension)
        '''
        print("Generate melody with " + composerName + "'s style............")
        result = self.cortex.generateEx_NihiloAccordingToComposer(composerName, firstNote, durations, length)
        self.writeMidiFile(gen_fName,result)
        print("Generating finished!")
        return result

    def generate2TrackMusic(self, firstNotes, durations, lengths):
        # Thin delegation; no MIDI file is written here.
        result = self.cortex.generate2TrackMusic(firstNotes, durations, lengths)
        return result

    def generateMelodyWithKey(self,tone, firstNotes,durations = None,length = 8):
        # Thin delegation; generation conditioned on a key/tone.
        result = self.cortex.generateMelodyWithKey(tone, firstNotes,durations,length)
        return result

    def writeMidiFile(self,fileName, mudic):
        '''
        Write a note dictionary as a multi-instrument MIDI file.

        mudic format description:
        mudic = {1:[{'N':71,'T':0.5}.....],
                 2:[{'N':60,'T':0.25}.....],
                 ....
                }
        where 'N' is a MIDI pitch (-1 for a rest) and 'T' a duration in seconds.
        '''
        fileName += ".mid"
        pm = pretty_midi.PrettyMIDI()
        # Create an Instrument instance for each track (program 0 = piano)
        for values in mudic.values():
            piano = pretty_midi.Instrument(program=0)
            # Iterate over note names, which will be converted to note number later
            start = 0
            end = 0
            for i, n in enumerate(values):
                # Retrieve the MIDI note number for this note name
                # note_number = pretty_midi.note_name_to_number(note_name)
                # Create a Note instance, starting at 0s and ending at .5s
                end = start + n.get('T')
                note_name = n.get('N')
                if (note_name == -1):
                    # rests become zero-velocity notes so timing is preserved
                    note = pretty_midi.Note(
                        velocity=0, pitch=0, start=start, end=end)
                else:
                    note = pretty_midi.Note(
                        velocity=100, pitch=note_name, start=start, end=end)
                # Add it to our instrument
                piano.notes.append(note)
                start = end
            # Add the instrument to the PrettyMIDI object
            pm.instruments.append(piano)
        # Write out the MIDI data
        pm.write(fileName)
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/conf/GenreData.txt
================================================
Baroque:Bach
Classical:Haydn,Mozart,Beethoven,Schubert,Clementi
Romantic:Mendelssohn,Liszt,Chopin,Schumann,Brahms,Burgmueller,Debussy,Godowsky,Moszkowski,Mussorgsky,Rachmaninov,Ravel,Tchaikovsky,Albéniz,Balakirew,Borodin,Granados,Grieg,Sinding
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/conf/MIDIData.txt
================================================
-1:rest
0:C3
1:C sharp3/D flat3
2:D3
3:D sharp3/E flat3
4:E3
5:F3
6:F sharp3/G flat3
7:G3
8:G sharp3/A flat3
9:A3
10:A sharp3/B flat3
11:B3
12:C2
13:C sharp2/D flat2
14:D2
15:D sharp2/E flat2
16:E2
17:F2
18:F sharp2/G flat2
19:G2
20:G sharp2/A flat2
21:A2
22:A sharp2/B flat2
23:B2
24:C1
25:C sharp1/D flat1
26:D1
27:D sharp1/E flat1
28:E1
29:F1
30:F sharp1/G flat1
31:G1
32:G sharp1/A flat1
33:A1
34:A sharp1/B flat1
35:B1
36:C
37:C sharp/D flat
38:D
39:D sharp/E flat
40:E
41:F
42:F sharp/G flat
43:G
44:G sharp/A flat
45:A
46:A sharp/B flat
47:B
48:c
49:c sharp/d flat
50:d
51:d sharp/e flat
52:e
53:f
54:f sharp/g flat
55:g
56:g sharp/a flat
57:a
58:a sharp/b flat
59:b
60:c1
61:c sharp1/d flat1
62:d1
63:d sharp1/e flat1
64:e1
65:f1
66:f sharp1/g flat1
67:g1
68:g sharp1/a flat1
69:a1
70:a sharp1/b flat1
71:b1
72:c2
73:c sharp2/d flat2
74:d2
75:d sharp2/e flat2
76:e2
77:f2
78:f sharp2/g flat2
79:g2
80:g sharp2/a flat2
81:a2
82:a sharp2/b flat2
83:b2
84:c3
85:c sharp3/d flat3
86:d3
87:d sharp3/e flat3
88:e3
89:f3
90:f sharp3/g flat3
91:g3
92:g sharp3/a flat3
93:a3
94:a sharp3/b flat3
95:b3
96:c4
97:c sharp4/d flat4
98:d4
99:d sharp4/e flat4
100:e4
101:f4
102:f sharp4/g flat4
103:g4
104:g sharp4/a flat4
105:a4
106:a sharp4/b flat4
107:b4
108:c5
109:c sharp5/d flat5
110:d5
111:d sharp5/e flat5
112:e5
113:f5
114:f sharp5/g flat5
115:g5
116:g sharp5/a flat5
117:a5
118:a sharp5/b flat5
119:b5
120:c6
121:c sharp6/d flat6
122:d6
123:d sharp6/e flat6
124:e6
125:f6
126:f sharp6/g flat6
127:g6
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/conf/conf.py
================================================
import numpy
import numpy as np
import pandas as pd
class Conf():
    '''
    Global configuration: run-time options plus the lookup tables (notes,
    genres, emotions, keys, chords, modes) loaded from the ../inputs folder.
    '''

    def __init__(self, neutype="LIF", task="MusicLearning", dt=0.1):
        '''
        Constructor

        :param neutype: neuron model used across the network (e.g. 'LIF', 'Izhikevich')
        :param task: task label
        :param dt: simulation time step
        '''
        self.neuron_type = neutype
        self.task = task
        self.dt = dt
        self.notesMap = {}     # MIDI index -> note name (readNoteFiles)
        self.GenreMap = {}     # composer name -> genre name (readGenreFils)
        self.emoMap = {}       # music name -> emotion label (readEmotionFiles)
        self.key_matrix = []   # key-relation matrix (readKeysFile)
        self.keysMap = {}      # key name -> index (readKeysFile)
        self.index2key = {}    # inverse of keysMap
        self.index2mode = {}   # mode index -> mode name (readModesFile)
        self.keyIndexMap = {}  # key name -> index (readKeys2IndexFile)
        self.keyscales = {}    # 0: major scale table, 1: minor scale table
        self.chordsMap = {}    # key name -> chord table (readChordsFile)
        self.chordsMatrix = np.zeros((7, 7))
        self.RunTimeState = 0  # 0: GUI, 1: Bigdata experiments 2: other

    def readNoteFiles(self):
        # Each line is "<midi index>:<note name>".
        # Fix: use context managers throughout so files are closed even when
        # a parse error raises (the original readline loops leaked the handle
        # on exception).
        # f = open("./Data.txt","r")
        with open("../inputs/MIDIData.txt", "r") as f:
            for line in f:
                strs = line.split(":")
                index = int(strs[0])
                self.notesMap[index] = strs[1].strip()

    def readGenreFils(self):
        # Each line is "<genre>:<composer>,<composer>,...".
        with open("../inputs/GenreData.txt", "r") as f:
            for line in f:
                strs = line.split(":")
                g = strs[0].strip()
                ns = strs[1].split(",")
                for n in ns:
                    self.GenreMap[(n.strip()).title()] = g.title()

    def readEmotionFiles(self):
        # CSV rows: music name in column 0, emotion label in column 3.
        with open("../inputs/information.csv", "r") as f:
            for line in f:
                strs = line.split(",")
                mn = strs[0].strip()
                e = strs[3].strip()
                self.emoMap[mn.title()] = e.title()

    def readKeysFile(self):
        # "name,index" rows; reading stops at the first blank line, which
        # preserves the original strip-then-break behaviour.
        with open("../inputs/keyIndex.csv", "r") as f:
            for raw in f:
                line = raw.strip()
                if not line:
                    break
                strs = line.split(",")
                toneName = strs[0].strip()
                self.keysMap[toneName] = int(strs[1].strip())
        # print(self.keysMap)
        self.index2key = dict(zip(self.keysMap.values(), self.keysMap.keys()))
        # print(self.index2key)
        self.key_matrix = np.array(pd.read_excel("../inputs/keys.xlsx", sheet_name='keys'))
        self.keyscales = {0: np.array(pd.read_excel("../inputs/keys.xlsx", sheet_name='major')),
                          1: np.array(pd.read_excel("../inputs/keys.xlsx", sheet_name='minor'))}

    def readKeys2IndexFile(self):
        # Same key->index mapping as readKeysFile, kept in a separate dict.
        with open("../inputs/keyIndex.csv", "r") as f:
            for raw in f:
                line = raw.strip()
                if not line:
                    break
                strs = line.split(",")
                self.keyIndexMap[strs[0].strip()] = int(strs[1].strip())
        # print(self.keyIndexMap)

    def readChordsFile(self):
        # One sheet per key; each sheet becomes a numpy chord table.
        tmp = pd.read_excel("../inputs/chords.xlsx", sheet_name=None)
        for key, chords in tmp.items():
            chords = np.array(chords)
            self.chordsMap[key.strip()] = chords
        # print(self.chordsMap)
        # For now only link the tonic, subdominant and dominant chords
        # (scale degrees I, IV, V).
        self.chordsMatrix = np.array([[1, 0, 0, 1, 1, 0, 0],
                                      [0, 0, 0, 0, 0, 0, 0],
                                      [0, 0, 0, 0, 0, 0, 0],
                                      [1, 0, 0, 1, 1, 0, 0],
                                      [1, 0, 0, 0, 1, 0, 0],
                                      [0, 0, 0, 0, 0, 0, 0],
                                      [0, 0, 0, 0, 0, 0, 0]])

    def readModesFile(self):
        # "index,mode-name" rows; stops at the first blank line (as before).
        with open("../inputs/modeindex.csv", "r") as f:
            for raw in f:
                line = raw.strip()
                if not line:
                    break
                strs = line.split(",")
                self.index2mode[int(strs[0].strip())] = strs[1].strip()
# Module-level singleton: importing this module eagerly loads every lookup
# table from ../inputs (note names, genres, emotions, keys, chords, modes).
# Other modules do ``from conf.conf import *`` and use ``configs`` directly,
# so the current working directory must make the relative paths resolve.
configs = Conf(neutype = 'Izhikevich')
configs.readNoteFiles()
configs.readGenreFils()
configs.readEmotionFiles()
configs.readKeysFile()
configs.readKeys2IndexFile()
configs.readChordsFile()
configs.readModesFile()
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/inputs/1.txt
================================================
1:A-B2
2:A#-B2
3:B-B2
4:C-B1
5:C#-B1
6:D-B1
7:D#-B1
8:E-B1
9:F-B1
10:F#-B1
11:G-B1
12:G#-B1
13:A-B1
14:A#-B1
15:B-B1
16:C-B
17:C#-B
18:D-B
19:D#-B
20:E-B
21:F-B
22:F#-B
23:G-B
24:G#-B
25:A-B
26:A#-B
27:B-B
28:C-S
29:C#-S
30:D-S
31:D#-S
32:E-S
33:F-S
34:F#-S
35:G-S
36:G#-S
37:A-S
38:A#-S
39:B-S
40:C-S1
41:C#-S1
42:D-S1
43:D#-S1
44:E-S1
45:F-S1
46:F#-S1
47:G-S1
48:G#-S1
49:A-S1
50:A#-S1
51:B-S1
52:C-S2
53:C#-S2
54:D-S2
55:D#-S2
56:E-S2
57:F-S2
58:F#-S2
59:G-S2
60:G#-S2
61:A-S2
62:A#-S2
63:B-S2
64:C-S3
65:C#-S3
66:D-S3
67:D#-S3
68:E-S3
69:F-S3
70:F#-S3
71:G-S3
72:G#-S3
73:A-S3
74:A#-S3
75:B-S3
76:C-S4
77:C#-S4
78:D-S4
79:D#-S4
80:E-S4
81:F-S4
82:F#-S4
83:G-S4
84:G#-S4
85:A-S4
86:A#-S4
87:B-S4
88:C-S5
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/inputs/Data.txt
================================================
1:A2
2:A#2
3:B2
4:C1
5:C#1
6:D1
7:D#1
8:E1
9:F1
10:F#1
11:G1
12:G#1
13:A1
14:A#1
15:B1
16:C
17:C#
18:D
19:D#
20:E
21:F
22:F#
23:G
24:G#
25:A
26:A#
27:B
28:c
29:c#
30:d
31:d#
32:e
33:f
34:f#
35:g
36:g#
37:a
38:a#
39:b
40:c1
41:c#1
42:d1
43:d#1
44:e1
45:f1
46:f#1
47:g1
48:g#1
49:a1
50:a#1
51:b1
52:c2
53:c#2
54:d2
55:d#2
56:e2
57:f2
58:f#2
59:g2
60:g#2
61:a2
62:a#2
63:b2
64:c3
65:c#3
66:d3
67:d#3
68:e3
69:f3
70:f#3
71:g3
72:g#3
73:a3
74:a#3
75:b3
76:c4
77:c#4
78:d4
79:d#4
80:e4
81:f4
82:f#4
83:g4
84:g#4
85:a4
86:a#4
87:b4
88:c5
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/inputs/GenreData.txt
================================================
Baroque:Bach
Classical:Haydn,Mozart,Beethoven,Schubert,Clementi
Romantic:Mendelssohn,Liszt,Chopin,Schumann,Brahms,Burgmueller,Debussy,Godowsky,Moszkowski,Mussorgsky,Rachmaninov,Ravel,Tchaikovsky,Albeniz,Balakirew,Borodin,Granados,Grieg,Sinding
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/inputs/MIDIData.txt
================================================
-1:rest
0:C3
1:C sharp3/D flat3
2:D3
3:D sharp3/E flat3
4:E3
5:F3
6:F sharp3/G flat3
7:G3
8:G sharp3/A flat3
9:A3
10:A sharp3/B flat3
11:B3
12:C2
13:C sharp2/D flat2
14:D2
15:D sharp2/E flat2
16:E2
17:F2
18:F sharp2/G flat2
19:G2
20:G sharp2/A flat2
21:A2
22:A sharp2/B flat2
23:B2
24:C1
25:C sharp1/D flat1
26:D1
27:D sharp1/E flat1
28:E1
29:F1
30:F sharp1/G flat1
31:G1
32:G sharp1/A flat1
33:A1
34:A sharp1/B flat1
35:B1
36:C
37:C sharp/D flat
38:D
39:D sharp/E flat
40:E
41:F
42:F sharp/G flat
43:G
44:G sharp/A flat
45:A
46:A sharp/B flat
47:B
48:c
49:c sharp/d flat
50:d
51:d sharp/e flat
52:e
53:f
54:f sharp/g flat
55:g
56:g sharp/a flat
57:a
58:a sharp/b flat
59:b
60:c1
61:c sharp1/d flat1
62:d1
63:d sharp1/e flat1
64:e1
65:f1
66:f sharp1/g flat1
67:g1
68:g sharp1/a flat1
69:a1
70:a sharp1/b flat1
71:b1
72:c2
73:c sharp2/d flat2
74:d2
75:d sharp2/e flat2
76:e2
77:f2
78:f sharp2/g flat2
79:g2
80:g sharp2/a flat2
81:a2
82:a sharp2/b flat2
83:b2
84:c3
85:c sharp3/d flat3
86:d3
87:d sharp3/e flat3
88:e3
89:f3
90:f sharp3/g flat3
91:g3
92:g sharp3/a flat3
93:a3
94:a sharp3/b flat3
95:b3
96:c4
97:c sharp4/d flat4
98:d4
99:d sharp4/e flat4
100:e4
101:f4
102:f sharp4/g flat4
103:g4
104:g sharp4/a flat4
105:a4
106:a sharp4/b flat4
107:b4
108:c5
109:c sharp5/d flat5
110:d5
111:d sharp5/e flat5
112:e5
113:f5
114:f sharp5/g flat5
115:g5
116:g sharp5/a flat5
117:a5
118:a sharp5/b flat5
119:b5
120:c6
121:c sharp6/d flat6
122:d6
123:d sharp6/e flat6
124:e6
125:f6
126:f sharp6/g flat6
127:g6
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/inputs/chords.csv
================================================
C,,,
1,C,E,G
4,F,A,C
5,G,B,D
,,,
a,,,
1,A,C,E
4,D,F,A
5,E,G#,B
,,,
G,,,
1,G,B,D
4,C,E,G
5,D,F#,A
,,,
F,,,
1,F,A,C
4,B-,D,F
5,C,E,G
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/inputs/information.csv
================================================
chpn_op35_2.mid,Chopin,Romantic,unclear
chpn_op33_4.mid,Chopin,Romantic,unclear
chpn-p14.mid,Chopin,Romantic,unclear
chpn-p17.mid,Chopin,Romantic,soft
chpn_op33_2.mid,Chopin,Romantic,happy
chpn_op53.mid,Chopin,Romantic,unclear
chpn_op35_3.mid,Chopin,Romantic,unclear
chpn-p1.mid,Chopin,Romantic,unclear
chpn_op7_2.mid,Chopin,Romantic,soft
chpn_op25_e11.mid,Chopin,Romantic,passionate
chpn-p16.mid,Chopin,Romantic,unclear
chpn-p20.mid,Chopin,Romantic,depressed
chpn-p13.mid,Chopin,Romantic,soft
chpn_op27_2.mid,Chopin,Romantic,soft
chpn-p6.mid,Chopin,Romantic,depressed
chpn-p9.mid,Chopin,Romantic,unclear
chpn-p15.mid,Chopin,Romantic,depressed
chpn_op10_e12.mid,Chopin,Romantic,passionate
chpn-p22.mid,Chopin,Romantic,unclear
chpn_op23.mid,Chopin,Romantic,unclear
chpn-p24.mid,Chopin,Romantic,passionate
chpn-p18.mid,Chopin,Romantic,happy
chpn_op27_1.mid,Chopin,Romantic,depressed
chpn_op10_e05.mid,Chopin,Romantic,passionate
chpn_op25_e4.mid,Chopin,Romantic,unclear
chpn_op66.mid,Chopin,Romantic,unclear
chpn_op25_e2.mid,Chopin,Romantic,soft
chpn_op10_e01.mid,Chopin,Romantic,passionate
chpn-p11.mid,Chopin,Romantic,happy
chpn-p7.mid,Chopin,Romantic,soft
chp_op18.mid,Chopin,Romantic,unclear
chpn-p23.mid,Chopin,Romantic,soft
chpn-p5.mid,Chopin,Romantic,unclear
chpn_op35_4.mid,Chopin,Romantic,unclear
chpn_op25_e3.mid,Chopin,Romantic,unclear
chpn-p12.mid,Chopin,Romantic,passionate
chpn-p21.mid,Chopin,Romantic,soft
chpn-p4.mid,Chopin,Romantic,depressed
chpn-p2.mid,Chopin,Romantic,depressed
chpn_op35_1.mid,Chopin,Romantic,unclear
chp_op31.mid,Chopin,Romantic,unclear
chpn_op25_e1.mid,Chopin,Romantic,unclear
chpn-p10.mid,Chopin,Romantic,unclear
chpn-p8.mid,Chopin,Romantic,passionate
chpn_op25_e12.mid,Chopin,Romantic,unclear
chpn_op7_1.mid,Chopin,Romantic,happy
chpn-p3.mid,Chopin,Romantic,unclear
chpn-p19.mid,Chopin,Romantic,soft
scn15_5.mid,Schumann,Romantic,unclear
scn15_7.mid,Schumann,Romantic,unclear
scn15_12.mid,Schumann,Romantic,unclear
scn15_6.mid,Schumann,Romantic,unclear
scn15_13.mid,Schumann,Romantic,unclear
scn68_12.mid,Schumann,Romantic,unclear
scn15_1.mid,Schumann,Romantic,unclear
scn15_3.mid,Schumann,Romantic,unclear
scn15_2.mid,Schumann,Romantic,unclear
scn16_3.mid,Schumann,Romantic,unclear
scn16_2.mid,Schumann,Romantic,unclear
scn68_10.mid,Schumann,Romantic,unclear
scn16_6.mid,Schumann,Romantic,unclear
scn16_5.mid,Schumann,Romantic,unclear
scn16_1.mid,Schumann,Romantic,unclear
scn15_9.mid,Schumann,Romantic,unclear
scn15_8.mid,Schumann,Romantic,unclear
scn15_11.mid,Schumann,Romantic,unclear
scn16_7.mid,Schumann,Romantic,unclear
scn16_8.mid,Schumann,Romantic,unclear
scn16_4.mid,Schumann,Romantic,unclear
scn15_10.mid,Schumann,Romantic,unclear
schum_abegg.mid,Schumann,Romantic,unclear
scn15_4.mid,Schumann,Romantic,unclear
ty_august.mid,Tchaikovsky,Romantic,unclear
ty_april.mid,Tchaikovsky,Romantic,unclear
ty_juli.mid,Tchaikovsky,Romantic,unclear
ty_oktober.mid,Tchaikovsky,Romantic,unclear
ty_juni.mid,Tchaikovsky,Romantic,unclear
ty_november.mid,Tchaikovsky,Romantic,unclear
ty_januar.mid,Tchaikovsky,Romantic,unclear
ty_dezember.mid,Tchaikovsky,Romantic,unclear
ty_maerz.mid,Tchaikovsky,Romantic,unclear
ty_mai.mid,Tchaikovsky,Romantic,unclear
ty_februar.mid,Tchaikovsky,Romantic,unclear
ty_september.mid,Tchaikovsky,Romantic,unclear
debussy_cc_3.mid,Debussy,Romantic,happy
DEB_CLAI.MID,Debussy,Romantic,soft
debussy_cc_2.mid,Debussy,Romantic,unclear
debussy_cc_6.mid,Debussy,Romantic,unclear
deb_menu.mid,Debussy,Romantic,unclear
DEB_PASS.MID,Debussy,Romantic,unclear
deb_prel.mid,Debussy,Romantic,happy
debussy_cc_1.mid,Debussy,Romantic,unclear
debussy_cc_4.mid,Debussy,Romantic,unclear
alb_se2.mid,Albeniz,Romantic,unclear
alb_se8.mid,Albeniz,Romantic,unclear
alb_se1.mid,Albeniz,Romantic,unclear
alb_esp6.mid,Albeniz,Romantic,unclear
alb_esp4.mid,Albeniz,Romantic,unclear
alb_esp2.mid,Albeniz,Romantic,unclear
alb_se4.mid,Albeniz,Romantic,unclear
alb_esp1.mid,Albeniz,Romantic,unclear
alb_esp5.mid,Albeniz,Romantic,unclear
alb_se6.mid,Albeniz,Romantic,unclear
alb_se3.mid,Albeniz,Romantic,unclear
alb_esp3.mid,Albeniz,Romantic,unclear
alb_se5.mid,Albeniz,Romantic,unclear
alb_se7.mid,Albeniz,Romantic,unclear
liz_et_trans8.mid,Liszt,Romantic,unclear
liz_liebestraum.mid,Liszt,Romantic,soft
liz_rhap10.mid,Liszt,Romantic,unclear
liz_rhap15.mid,Liszt,Romantic,unclear
liz_et2.mid,Liszt,Romantic,unclear
liz_rhap09.mid,Liszt,Romantic,unclear
liz_rhap02.mid,Liszt,Romantic,unclear
liz_rhap12.mid,Liszt,Romantic,unclear
liz_et4.mid,Liszt,Romantic,unclear
liz_et1.mid,Liszt,Romantic,unclear
liz_et_trans4.mid,Liszt,Romantic,unclear
liz_et5.mid,Liszt,Romantic,unclear
liz_donjuan.mid,Liszt,Romantic,unclear
liz_et3.mid,Liszt,Romantic,soft
liz_et6.mid,Liszt,Romantic,unclear
liz_et_trans5.mid,Liszt,Romantic,unclear
pathetique_1.mid,Beethoven,Classical,unclear
beethoven_hammerklavier_4.mid,Beethoven,Classical,unclear
beethoven_hammerklavier_2.mid,Beethoven,Classical,unclear
beethoven_opus22_4.mid,Beethoven,Classical,happy
waldstein_2.mid,Beethoven,Classical,unclear
beethoven_hammerklavier_3.mid,Beethoven,Classical,depressed
beethoven_les_adieux_3.mid,Beethoven,Classical,happy
beethoven_opus22_2.mid,Beethoven,Classical,unclear
waldstein_1.mid,Beethoven,Classical,happy
appass_2.mid,Beethoven,Classical,unclear
beethoven_opus10_3.mid,Beethoven,Classical,unclear
beethoven_opus22_1.mid,Beethoven,Classical,unclear
beethoven_hammerklavier_1.mid,Beethoven,Classical,unclear
pathetique_3.mid,Beethoven,Classical,passionate
beethoven_opus90_2.mid,Beethoven,Classical,unclear
beethoven_opus22_3.mid,Beethoven,Classical,happy
beethoven_opus90_1.mid,Beethoven,Classical,unclear
waldstein_3.mid,Beethoven,Classical,happy
beethoven_opus10_1.mid,Beethoven,Classical,unclear
appass_1.mid,Beethoven,Classical,unclear
beethoven_opus10_2.mid,Beethoven,Classical,unclear
mond_2.mid,Beethoven,Classical,unclear
beethoven_les_adieux_1.mid,Beethoven,Classical,depressed
elise.mid,Beethoven,Classical,soft
appass_3.mid,Beethoven,Classical,passionate
pathetique_2.mid,Beethoven,Classical,soft
mond_3.mid,Beethoven,Classical,passionate
beethoven_les_adieux_2.mid,Beethoven,Classical,depressed
mond_1.mid,Beethoven,Classical,depressed
muss_3.mid,Mussorgsky,Romantic,unclear
muss_5.mid,Mussorgsky,Romantic,unclear
muss_1.mid,Mussorgsky,Romantic,unclear
muss_8.mid,Mussorgsky,Romantic,unclear
muss_6.mid,Mussorgsky,Romantic,unclear
muss_2.mid,Mussorgsky,Romantic,unclear
muss_7.mid,Mussorgsky,Romantic,unclear
muss_4.mid,Mussorgsky,Romantic,unclear
mendel_op30_4.mid,Mendelssohn,Romantic,unclear
mendel_op19_2.mid,Mendelssohn,Romantic,unclear
mendel_op19_5.mid,Mendelssohn,Romantic,unclear
mendel_op19_4.mid,Mendelssohn,Romantic,soft
mendel_op19_6.mid,Mendelssohn,Romantic,depressed
mendel_op19_1.mid,Mendelssohn,Romantic,soft
mendel_op30_5.mid,Mendelssohn,Romantic,unclear
mendel_op30_2.mid,Mendelssohn,Romantic,unclear
mendel_op62_3.mid,Mendelssohn,Romantic,depressed
mendel_op30_3.mid,Mendelssohn,Romantic,unclear
mendel_op62_4.mid,Mendelssohn,Romantic,unclear
mendel_op19_3.mid,Mendelssohn,Romantic,unclear
mendel_op53_5.mid,Mendelssohn,Romantic,unclear
mendel_op30_1.mid,Mendelssohn,Romantic,soft
mendel_op62_5.mid,Mendelssohn,Romantic,unclear
gra_esp_2.mid,Granados,Romantic,unclear
gra_esp_3.mid,Granados,Romantic,unclear
gra_esp_4.mid,Granados,Romantic,unclear
fruehlingsrauschen.mid,Sinding,Romantic,unclear
rac_op32_1.mid,Rachmaninov,Romantic,unclear
rac_op23_3.mid,Rachmaninov,Romantic,unclear
rac_op33_8.mid,Rachmaninov,Romantic,unclear
rac_op33_6.mid,Rachmaninov,Romantic,unclear
rac_op33_5.mid,Rachmaninov,Romantic,unclear
rac_op23_7.mid,Rachmaninov,Romantic,unclear
rac_op32_13.mid,Rachmaninov,Romantic,unclear
rac_op23_2.mid,Rachmaninov,Romantic,unclear
rac_op3_2.mid,Rachmaninov,Romantic,unclear
rac_op23_5.mid,Rachmaninov,Romantic,unclear
god_alb_esp2.mid,Godowsky,Romantic,unclear
god_chpn_op10_e01.mid,Godowsky,Romantic,unclear
mz_545_1.mid,Mozart,Classical,unclear
mz_570_2.mid,Mozart,Classical,unclear
mz_311_2.mid,Mozart,Classical,unclear
mz_333_3.mid,Mozart,Classical,unclear
mz_331_3.mid,Mozart,Classical,passionate
mz_332_3.mid,Mozart,Classical,unclear
mz_331_2.mid,Mozart,Classical,unclear
mz_332_1.mid,Mozart,Classical,unclear
mz_545_2.mid,Mozart,Classical,unclear
mz_330_1.mid,Mozart,Classical,unclear
mz_545_3.mid,Mozart,Classical,unclear
mz_333_2.mid,Mozart,Classical,unclear
mz_330_2.mid,Mozart,Classical,unclear
mz_333_1.mid,Mozart,Classical,unclear
mz_311_3.mid,Mozart,Classical,unclear
mz_331_1.mid,Mozart,Classical,happy
mz_570_3.mid,Mozart,Classical,unclear
mz_332_2.mid,Mozart,Classical,unclear
mz_570_1.mid,Mozart,Classical,unclear
mz_330_3.mid,Mozart,Classical,unclear
mz_311_1.mid,Mozart,Classical,unclear
rav_scarbo.mid,Ravel,Romantic,unclear
rav_ondi.mid,Ravel,Romantic,unclear
rav_eau.mid,Ravel,Romantic,unclear
rav_gib.mid,Ravel,Romantic,unclear
ravel_miroirs_1.mid,Ravel,Romantic,unclear
clementi_opus36_4_1.mid,Clementi,Classical,unclear
clementi_opus36_2_2.mid,Clementi,Classical,unclear
clementi_opus36_6_1.mid,Clementi,Classical,happy
clementi_opus36_1_3.mid,Clementi,Classical,happy
clementi_opus36_3_3.mid,Clementi,Classical,unclear
clementi_opus36_5_1.mid,Clementi,Classical,happy
clementi_opus36_1_2.mid,Clementi,Classical,soft
clementi_opus36_4_3.mid,Clementi,Classical,unclear
clementi_opus36_1_1.mid,Clementi,Classical,happy
clementi_opus36_2_3.mid,Clementi,Classical,unclear
clementi_opus36_4_2.mid,Clementi,Classical,unclear
clementi_opus36_3_2.mid,Clementi,Classical,unclear
clementi_opus36_5_3.mid,Clementi,Classical,unclear
clementi_opus36_5_2.mid,Clementi,Classical,unclear
clementi_opus36_3_1.mid,Clementi,Classical,unclear
clementi_opus36_6_2.mid,Clementi,Classical,unclear
clementi_opus36_2_1.mid,Clementi,Classical,happy
mos_op36_6.mid,Moszkowski,Romantic,unclear
haydn_7_2.mid,Haydn,Classical,unclear
haydn_35_2.mid,Haydn,Classical,unclear
haydn_7_3.mid,Haydn,Classical,unclear
haydn_8_4.mid,Haydn,Classical,unclear
haydn_9_1.mid,Haydn,Classical,unclear
haydn_8_3.mid,Haydn,Classical,unclear
haydn_35_3.mid,Haydn,Classical,unclear
haydn_33_3.mid,Haydn,Classical,unclear
haydn_8_1.mid,Haydn,Classical,unclear
haydn_8_2.mid,Haydn,Classical,unclear
haydn_7_1.mid,Haydn,Classical,unclear
haydn_9_2.mid,Haydn,Classical,unclear
haydn_43_1.mid,Haydn,Classical,unclear
hay_40_2.mid,Haydn,Classical,unclear
haydn_9_3.mid,Haydn,Classical,unclear
haydn_35_1.mid,Haydn,Classical,unclear
haydn_33_1.mid,Haydn,Classical,unclear
haydn_43_2.mid,Haydn,Classical,unclear
hay_40_1.mid,Haydn,Classical,unclear
haydn_33_2.mid,Haydn,Classical,unclear
haydn_43_3.mid,Haydn,Classical,unclear
bach_850.mid,Bach,Baroque,happy
bach_846.mid,Bach,Baroque,soft
bach_847.mid,Bach,Baroque,unclear
grieg_halling.mid,Grieg,Romantic,unclear
grieg_wedding.mid,Grieg,Romantic,unclear
grieg_brooklet.mid,Grieg,Romantic,unclear
grieg_butterfly.mid,Grieg,Romantic,happy
grieg_wanderer.mid,Grieg,Romantic,unclear
grieg_zwerge.mid,Grieg,Romantic,unclear
grieg_march.mid,Grieg,Romantic,unclear
grieg_album.mid,Grieg,Romantic,unclear
grieg_elfentanz.mid,Grieg,Romantic,unclear
grieg_spring.mid,Grieg,Romantic,unclear
grieg_waechter.mid,Grieg,Romantic,unclear
grieg_kobold.mid,Grieg,Romantic,unclear
grieg_berceuse.mid,Grieg,Romantic,unclear
grieg_voeglein.mid,Grieg,Romantic,unclear
grieg_walzer.mid,Grieg,Romantic,unclear
grieg_once_upon_a_time.mid,Grieg,Romantic,unclear
br_im2.mid,Brahms,Romantic,unclear
br_rhap.mid,Brahms,Romantic,unclear
brahms_opus1_1.mid,Brahms,Romantic,unclear
brahms_opus1_3.mid,Brahms,Romantic,unclear
brahms_opus117_1.mid,Brahms,Romantic,unclear
brahms_opus1_2.mid,Brahms,Romantic,unclear
BR_IM6.MID,Brahms,Romantic,unclear
brahms_opus1_4.mid,Brahms,Romantic,unclear
brahms_opus117_2.mid,Brahms,Romantic,unclear
br_im5.mid,Brahms,Romantic,unclear
burg_geschwindigkeit.mid,Burgmueller,Romantic,unclear
burg_perlen.mid,Burgmueller,Romantic,unclear
burg_trennung.mid,Burgmueller,Romantic,unclear
burg_agitato.mid,Burgmueller,Romantic,unclear
burg_sylphen.mid,Burgmueller,Romantic,unclear
burg_spinnerlied.mid,Burgmueller,Romantic,unclear
burg_quelle.mid,Burgmueller,Romantic,unclear
burg_erwachen.mid,Burgmueller,Romantic,unclear
burg_gewitter.mid,Burgmueller,Romantic,unclear
schuim-4.mid,Schubert,Classical,unclear
schumm-6.mid,Schubert,Classical,unclear
schu_143_2.mid,Schubert,Classical,unclear
schubert_D850_3.mid,Schubert,Classical,unclear
schub_d960_1.mid,Schubert,Classical,unclear
schubert_D935_4.mid,Schubert,Classical,unclear
schu_143_1.mid,Schubert,Classical,unclear
schubert_D850_2.mid,Schubert,Classical,unclear
schumm-2.mid,Schubert,Classical,unclear
schubert_D850_4.mid,Schubert,Classical,unclear
schub_d960_3.mid,Schubert,Classical,unclear
schub_d760_3.mid,Schubert,Classical,unclear
schumm-5.mid,Schubert,Classical,unclear
schumm-3.mid,Schubert,Classical,unclear
schumm-1.mid,Schubert,Classical,unclear
schub_d760_4.mid,Schubert,Classical,unclear
schubert_D935_2.mid,Schubert,Classical,unclear
schumm-4.mid,Schubert,Classical,unclear
schu_143_3.mid,Schubert,Classical,unclear
schub_d960_2.mid,Schubert,Classical,unclear
schuim-1.mid,Schubert,Classical,unclear
schub_d760_1.mid,Schubert,Classical,unclear
schubert_D850_1.mid,Schubert,Classical,unclear
schubert_D935_3.mid,Schubert,Classical,unclear
schub_d960_4.mid,Schubert,Classical,unclear
schuim-3.mid,Schubert,Classical,unclear
schub_d760_2.mid,Schubert,Classical,unclear
schubert_D935_1.mid,Schubert,Classical,unclear
schuim-2.mid,Schubert,Classical,unclear
islamei.mid,Balakirew,Romantic,unclear
bor_ps2.mid,Borodin,Romantic,unclear
bor_ps1.mid,Borodin,Romantic,unclear
bor_ps7.mid,Borodin,Romantic,unclear
bor_ps4.mid,Borodin,Romantic,unclear
bor_ps6.mid,Borodin,Romantic,unclear
bor_ps3.mid,Borodin,Romantic,unclear
bor_ps5.mid,Borodin,Romantic,unclear
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/inputs/keyIndex.csv
================================================
C major,0
a minor,1
G major,2
e minor,3
D major,4
b minor,5
A major,6
f# minor,7
E major,8
c# minor,9
B major,10
g# minor,11
F major,12
d minor,13
B- major,14
g minor,15
E- major,16
c minor,17
A- major,18
f minor,19
D- major,20
b- minor,21
G- major,22
e- minor,23
C# major,20
a# minor,21
F# major,22
d# minor,23
C- major,10
a- minor,11
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/inputs/keys.csv
================================================
1,-1,2,-1,3,4,-1,5,-1,6,-1,7
3,-1,4,-1,5,6,-1,-1,7,1,-1,2
4,-1,5,-1,6,-1,7,1,-1,2,-1,3
6,-1,-1,7,1,-1,2,3,-1,4,-1,5
-1,7,1,-1,2,-1,3,4,-1,5,-1,6
-1,2,3,-1,4,-1,5,6,-1,-1,7,1
-1,3,4,-1,5,-1,6,-1,7,1,-1,2
-1,5,6,-1,-1,7,1,-1,2,3,-1,4
-1,6,-1,7,1,-1,2,-1,3,4,-1,5
7,1,-1,2,3,-1,4,-1,5,6,-1,-1
-1,2,-1,3,4,-1,5,-1,6,-1,7,1
-1,4,-1,5,6,-1,-1,7,1,-1,2,3
5,-1,6,-1,7,1,-1,2,-1,3,4,-1
-1,7,1,-1,2,3,-1,4,-1,5,6,-1
2,-1,3,4,-1,5,-1,6,-1,7,1,-1
4,-1,5,6,-1,-1,7,1,-1,2,3,-1
6,-1,7,1,-1,2,-1,3,4,-1,5,-1
1,-1,2,3,-1,4,-1,5,6,-1,-1,7
3,4,-1,5,-1,6,-1,7,1,-1,2,-1
5,6,-1,-1,7,1,-1,2,3,-1,4,-1
7,1,-1,2,-1,3,4,-1,5,-1,6,-1
2,3,-1,4,-1,5,6,-1,-1,7,1,-1
-1,5,-1,6,-1,7,1,-1,2,-1,3,4
-1,-1,7,1,-1,2,3,-1,4,-1,5,6
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/inputs/modeindex.csv
================================================
0,major
1,minor
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/inputs/pitch2midi.csv
================================================
C,0,12,24,36,48,60,72,84,96,108,120
C#,1,13,25,37,49,61,73,85,97,109,121
C-,11,23,35,47,59,71,83,95,107,119,
D,2,14,26,38,50,62,74,86,98,110,122
D#,3,15,27,39,51,63,75,87,99,111,123
D-,1,13,25,37,49,61,73,85,97,109,121
E,4,16,28,40,52,64,76,88,100,112,124
E#,5,17,29,41,53,65,77,89,101,113,125
E-,3,15,27,39,51,63,75,87,99,111,123
F,5,17,29,41,53,65,77,89,101,113,125
F#,6,18,30,42,54,66,78,90,102,114,126
F-,4,16,28,40,52,64,76,88,100,112,124
G,7,19,31,43,55,67,79,91,103,115,127
G#,8,20,32,44,56,68,80,92,104,116,
G-,6,18,30,42,54,66,78,90,102,114,126
A,9,21,33,45,57,69,81,93,105,117,
A#,10,22,34,46,58,70,82,94,106,118,
A-,8,20,32,44,56,68,80,92,104,116,
B,11,23,35,47,59,71,83,95,107,119,
B+,0,12,24,36,48,60,72,84,96,108,120
B-,10,22,34,46,58,70,82,94,106,118,
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/inputs/tones2.csv
================================================
,0,1,2,3,4,5,6,7,8,9,10,11
C major,2,-1,1,-1,1,1,-1,1,-1,1,-1,1
a minor,1,-1,1,-1,1,1,-1,-1,2,2,-1,1
G major,1,-1,1,-1,1,-1,2,2,-1,1,-1,1
e minor,1,-1,-1,2,2,-1,2,1,-1,1,-1,1
D major,-1,2,2,-1,1,-1,2,1,-1,1,-1,1
b minor,-1,2,1,-1,1,-1,2,1,-1,-1,2,2
A major,-1,2,1,-1,1,-1,2,-1,2,2,-1,1
f# minor,-1,2,1,-1,-1,2,2,-1,2,1,-1,1
E major,-1,2,-1,2,2,-1,2,-1,2,1,-1,1
c# minor,2,2,-1,2,1,-1,2,-1,2,1,-1,-1
B major,-1,2,-1,2,1,-1,2,-1,2,-1,2,2
g# minor,-1,2,-1,2,1,-1,-1,2,2,-1,2,1
F major,1,-1,1,-1,1,2,-1,1,-1,1,2,-1
d minor,-1,2,2,-1,1,1,-1,1,-1,1,2,-1
B- major,1,-1,1,2,-1,1,-1,1,-1,1,2,-1
g minor,1,-1,1,2,-1,-1,2,2,-1,1,2,-1
E- major,1,-1,1,2,-1,1,-1,1,2,-1,2,-1
c minor,1,-1,1,2,-1,1,-1,1,2,-1,-1,2
A- major,1,2,-1,2,-1,1,-1,1,2,-1,2,-1
f minor,1,2,-1,-1,2,1,-1,1,2,-1,2,-1
D- major,1,2,-1,2,-1,1,2,-1,2,-1,2,-1
b- minor,1,2,-1,2,-1,1,2,-1,-1,2,2,-1
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/task/mode-conditioned learning.py
================================================
import sys
import os
import time
sys.path.append("../../../../")
sys.path.append("../")
import numpy as np
import music21 as m21
from conf.conf import *
from api.music_engine_api import EngineAPI
if __name__=="__main__":
    musicEngine = EngineAPI()
    musicEngine.cortexInit()
    # ------------Bach dataset learning----------------#
    paths = m21.corpus.getComposer('bach')
    print(len(paths))
    for path in paths:
        # os.path.basename is portable; the original split on '\\' which
        # only produced the file name on Windows.
        musicName = os.path.basename(str(path))
        print(musicName)
        if musicName.split('.')[-1] != 'mxl':
            continue  # only MusicXML (.mxl) files are learnable
        xmldata = m21.corpus.parse(path)
        musicEngine.rememberMusic(musicName, "None")
        musicEngine.learnFourPartMusic(xmldata, musicName, "None")
    # ------------generation test----------------#
    # Seed one starting MIDI pitch per voice (soprano..bass) in C major.
    key = 'C major'
    firstnotes = np.array([[m21.pitch.Pitch('E5').midi],
                           [m21.pitch.Pitch('G4').midi],
                           [m21.pitch.Pitch('C4').midi],
                           [m21.pitch.Pitch('C3').midi]])
    result = musicEngine.generateMelodyWithKey(configs.keyIndexMap.get(key), firstnotes, None, 4)
    steam1 = m21.stream.Stream()
    for i, part in result.items():
        pt = m21.stream.Stream()
        for v in part:
            # Each event dict carries pitch under "N" and duration under "T".
            n = m21.note.Note(v.get("N"))
            n.quarterLength = v.get("T")
            pt.append(n)
        steam1.insert(0, pt)
    opath = '../result_output/tone learning/'
    os.makedirs(opath, exist_ok=True)  # fail-safe: create the output dir
    # Timestamp as pure digits (YYYYMMDDHHMMSS), formatted directly instead
    # of filtering digits out of a pretty-printed string.
    t2 = time.strftime('%Y%m%d%H%M%S', time.localtime(time.time()))
    steam1.write('midi', fp=opath + key + "_" + t2 + '.mid')
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/task/musicGeneration.py
================================================
import sys
sys.path.append("../../../../")
sys.path.append("../")
from api.music_engine_api import EngineAPI
import os
if __name__=="__main__":
    # ----------------------------------Init------------------------------#
    musicEngine = EngineAPI()
    musicEngine.cortexInit()
    # ----------------------------Learning process------------------------#
    # Learn every melody stored under ../testData/<composer>/<file>.
    input_path = "../testData/"
    for composerName in os.listdir(input_path):
        dpath = os.path.join(input_path, composerName)
        if not os.path.isdir(dpath):
            continue  # skip stray files at the top level
        for musicName in os.listdir(dpath):
            musicEngine.memorizing(musicName, composerName, 20,
                                   os.path.join(dpath, musicName))
    # -------------------------Generation Process------------------------#
    # Two candidate openings: note -1 denotes a rest.
    beginnotes = {1: [-1, 67],
                  2: [-1]}
    begindurs = {1: [0.5, 0.25],
                 2: [0.5]}
    lengths = [10, 8]
    genreName = "Classical"
    composerName = "Bach"
    # Generate a piece of music melody#
    musicEngine.generateEx_Nihilo(beginnotes.get(2), begindurs.get(2), 20, "melody_generated")
    # Generate a piece of melody with a composer style
    musicEngine.generateEx_NihiloAccordingToComposer(composerName, beginnotes.get(2), begindurs.get(2), 15, "Bach_generated")
    # Generate a piece of melody with a genre style
    musicEngine.generateEx_NihiloAccordingToGenre(genreName, beginnotes.get(1), begindurs.get(1), 15, "Classical_generated")
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/task/musicMemory.py
================================================
import sys
sys.path.append("../")
sys.path.append("../../../../")
from api.music_engine_api import EngineAPI
import os
if __name__=="__main__":
    musicEngine = EngineAPI()
    musicEngine.cortexInit()
    input_path = "../testData/"
    # --------------------learning process---------------#
    # Walk ../testData/<composer>/<melody> and memorize each melody.
    for composerName in os.listdir(input_path):
        dpath = os.path.join(input_path, composerName)
        if not os.path.isdir(dpath):
            continue  # ignore non-directory entries
        for musicName in os.listdir(dpath):
            # memorizing(title, composer, n_notes, path): n_notes is how many
            # notes to learn, or the string "ALL" for the whole melody.
            musicEngine.memorizing(musicName, composerName, 20,
                                   os.path.join(dpath, musicName))
    # recall the music based on the name of a music
    musicEngine.recallMusic("Sonate C Major.Mid")
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/tools/__init__.py
================================================
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/tools/generateData.py
================================================
'''
Created on 2016.4.27
@author: liangqian
'''
import json
import random
class joint():
    """A 3-D skeleton joint; coordinates default to the origin."""

    def __init__(self):
        self.x = 0  # horizontal coordinate
        self.y = 0  # vertical coordinate
        self.z = 0  # depth coordinate

    def self2dic(self):
        """Return the coordinates as a plain ``{'x','y','z'}`` dict."""
        return dict(x=self.x, y=self.y, z=self.z)
class data():
    """One time-step sample: a position value plus its named joints."""

    def __init__(self):
        self.T = 0          # time index of the sample
        self.position = 0   # scalar position value
        self.joints = {}    # joint name -> joint instance

    def self2dic(self):
        """Serialize T and position only; joints are not included."""
        return dict(T=self.T, position=self.position)
# Accumulates one ``data`` record per "T=" block parsed from ../Data.txt.
Data_List = []
# (Kept for reference) one-off generator that dumped random position
# samples to position.txt as JSON.
# for i in range(1,419):
#     d = data()
#     d.T = i
#     d.position = random.randint(0,100)
#     dic = d.self2dic()
#     Data_List.append(d)
#
# dic = {}
# dic['Data'] = Data_List
#
# # using json to write file
# strs = (json.dumps(dic))
#
# fout = open('position.txt','w')
# fout.write(strs)
# fout.close()
#
#
# fin = open('position.txt','r')
# data_json = fin.read()
#
# print(data_json)

# Parse ../Data.txt.  Each record starts with a "T=<int>" line, is followed
# by one line per joint of the form "<name> x=<f> y=<f> z=<f>", and ends at
# a blank line.
f = open('../Data.txt','r')
line = f.readline()
while(len(line) != 0):
    if('T=' in line):
        ss = line.split('=')
        d = data()
        d.T = int(ss[1].strip())
        # Position is synthesized randomly, not read from the file.
        d.position = random.randint(0,100)
        line = f.readline()
        # Inner loop: consume joint lines until the blank record separator.
        while(len(line.strip()) != 0):
            sss = line.split(' ')
            jj = joint()
            s = sss[1].split('=')
            jj.x = float(s[1].strip())
            s = sss[2].split('=')
            jj.y = float(s[1].strip())
            s = sss[3].split('=')
            jj.z = float(s[1].strip())
            d.joints[sss[0]] = jj
            line = f.readline()
        Data_List.append(d)
        print(len(d.joints))
    line = f.readline()
print(len(Data_List))
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/tools/hamonydataset_test.py
================================================
import os
import sys
import music21 as m21
import numpy as np
# Seed chord (one MIDI pitch per voice, soprano..bass) used for inspection.
firstnotes = np.array([[m21.pitch.Pitch('b-4').midi],
                       [m21.pitch.Pitch('d4').midi],
                       [m21.pitch.Pitch('f3').midi],
                       [m21.pitch.Pitch('B2').midi]])
print(firstnotes)

# Walk ../xmlfiles/hamony dataset/<chorale>/four/<name>_<tone>.<ext> and
# print each file's base name and tone.
input_path = "../xmlfiles/hamony dataset/"
for ch in os.listdir(input_path):
    if ch == '.DS_Store':
        continue
    fileName = os.path.join(input_path, ch)
    for fName in os.listdir(fileName):
        # Only descend into the "four" (four-part) sub-directory.
        if fName == 'four':
            # Fix: the original re-tested fName == '.DS_Store' here, which
            # can never be true inside this branch; hidden files of the
            # inner directory are skipped below instead.
            fn = os.path.join(fileName, 'four')
            for f in os.listdir(fn):
                if f == '.DS_Store':
                    continue
                print(f)
                musicName = f.split("_")[0]
                fTone = (f.split("_")[1]).split(".")[0]
                print(musicName)
                print(fTone)
                print('---')
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/tools/msg.py
================================================
'''
Created on 2016.6.8
@author: liangqian
'''
import time
import sys
import stomp
class MyListener(object):
    """Minimal STOMP listener that echoes every received frame to stdout."""

    def on_error(self, headers, message):
        # Invoked when the broker sends an ERROR frame.
        print('received an error : %s' % message)

    def on_message(self, headers, message):
        # Invoked for each MESSAGE frame on the subscribed destination.
        print('%s' % message)
# Broker endpoint; the alternative LAN address is kept below for reference.
conn = stomp.Connection([('159.226.19.16',61613)])
#conn = stomp.Connection([('10.10.10.106',61613)])
conn.set_listener('', MyListener())
# NOTE(review): Connection.start() exists only in old stomp.py releases
# (removed in 4.x+) -- confirm the pinned stomp.py version.
conn.start()
print('hh')
# Durable topic subscription: 'client-id' here must match the
# 'activemq.subscriptionName' header in subscribe() below.
conn.connect(wait=True,headers={'client-id':'LXYNB','non_persistent':'true'})
conn.subscribe(destination='/topic/TEST2.FOO',id='LX', ack='auto',headers={'activemq.subscriptionName':'LXYNB'})
#conn.send(body='hello,garfield!', destination='/topic/myTopic.messages')
# Busy-wait keeps the listener thread alive; NOTE: this spins a full CPU core.
while(True):
    pass
# Unreachable: the loop above never exits, so the connection is never closed.
conn.disconnect()
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/tools/msgq.py
================================================
'''
Created on 2018.9.7
@author: liangqian
'''
import time
import sys
import stomp
def createMSQ():
    """Open a STOMP connection to a local ActiveMQ broker and return it.

    Returns:
        stomp.Connection: a connected (but not yet subscribed) connection;
        the caller is responsible for disconnecting it.
    """
    # NOTE(review): modern stomp.py (>= 4) connects without conn.start();
    # the commented-out start() call from the 3.x API was dropped.
    conn = stomp.Connection([('localhost', 61613)])
    print("building connection to activemq......")
    conn.connect()
    return conn
#createMSQ()
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/tools/oscillations.py
================================================
'''
Created on 2016.5.13
@author: liangqian
'''
from modal.lifneuron import LIFNeuron
from modal.cluster import Cluster
from modal.synapse import Synapse
# Build a fully connected LIF cluster: every neuron projects to every other
# neuron (no self-connections), with 20% of neurons marked inhibitory.
c = Cluster('LIF')
c.createClusterNetwork()
c.setInhibitoryNeurons(0.2)
for src in range(c.neunum):
    for dst in range(c.neunum):
        if src == dst:
            continue  # skip self-connections
        post = c.neurons[dst]
        post.pre_neurons.append(c.neurons[src])
        syn = Synapse(c.neurons[src], post)
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/tools/position.txt
================================================
{"1": {"position": 58, "T": 1}, "2": {"position": 96, "T": 2}, "3": {"position": 36, "T": 3}, "4": {"position": 28, "T": 4}, "5": {"position": 43, "T": 5}, "6": {"position": 46, "T": 6}, "7": {"position": 91, "T": 7}, "8": {"position": 44, "T": 8}, "9": {"position": 59, "T": 9}, "10": {"position": 83, "T": 10}, "11": {"position": 17, "T": 11}, "12": {"position": 50, "T": 12}, "13": {"position": 46, "T": 13}, "14": {"position": 53, "T": 14}, "15": {"position": 9, "T": 15}, "16": {"position": 88, "T": 16}, "17": {"position": 23, "T": 17}, "18": {"position": 28, "T": 18}, "19": {"position": 17, "T": 19}, "20": {"position": 16, "T": 20}, "21": {"position": 43, "T": 21}, "22": {"position": 93, "T": 22}, "23": {"position": 13, "T": 23}, "24": {"position": 0, "T": 24}, "25": {"position": 92, "T": 25}, "26": {"position": 56, "T": 26}, "27": {"position": 2, "T": 27}, "28": {"position": 38, "T": 28}, "29": {"position": 17, "T": 29}, "30": {"position": 13, "T": 30}, "31": {"position": 95, "T": 31}, "32": {"position": 86, "T": 32}, "33": {"position": 45, "T": 33}, "34": {"position": 30, "T": 34}, "35": {"position": 42, "T": 35}, "36": {"position": 87, "T": 36}, "37": {"position": 14, "T": 37}, "38": {"position": 56, "T": 38}, "39": {"position": 76, "T": 39}, "40": {"position": 55, "T": 40}, "41": {"position": 5, "T": 41}, "42": {"position": 98, "T": 42}, "43": {"position": 5, "T": 43}, "44": {"position": 17, "T": 44}, "45": {"position": 73, "T": 45}, "46": {"position": 55, "T": 46}, "47": {"position": 86, "T": 47}, "48": {"position": 56, "T": 48}, "49": {"position": 80, "T": 49}, "50": {"position": 87, "T": 50}, "51": {"position": 47, "T": 51}, "52": {"position": 51, "T": 52}, "53": {"position": 62, "T": 53}, "54": {"position": 3, "T": 54}, "55": {"position": 65, "T": 55}, "56": {"position": 90, "T": 56}, "57": {"position": 76, "T": 57}, "58": {"position": 35, "T": 58}, "59": {"position": 81, "T": 59}, "60": {"position": 0, "T": 60}, "61": {"position": 32, "T": 61}, "62": 
{"position": 93, "T": 62}, "63": {"position": 100, "T": 63}, "64": {"position": 70, "T": 64}, "65": {"position": 10, "T": 65}, "66": {"position": 37, "T": 66}, "67": {"position": 87, "T": 67}, "68": {"position": 72, "T": 68}, "69": {"position": 77, "T": 69}, "70": {"position": 65, "T": 70}, "71": {"position": 21, "T": 71}, "72": {"position": 45, "T": 72}, "73": {"position": 81, "T": 73}, "74": {"position": 27, "T": 74}, "75": {"position": 55, "T": 75}, "76": {"position": 8, "T": 76}, "77": {"position": 2, "T": 77}, "78": {"position": 66, "T": 78}, "79": {"position": 80, "T": 79}, "80": {"position": 76, "T": 80}, "81": {"position": 49, "T": 81}, "82": {"position": 49, "T": 82}, "83": {"position": 24, "T": 83}, "84": {"position": 51, "T": 84}, "85": {"position": 43, "T": 85}, "86": {"position": 31, "T": 86}, "87": {"position": 58, "T": 87}, "88": {"position": 84, "T": 88}, "89": {"position": 58, "T": 89}, "90": {"position": 95, "T": 90}, "91": {"position": 91, "T": 91}, "92": {"position": 97, "T": 92}, "93": {"position": 53, "T": 93}, "94": {"position": 33, "T": 94}, "95": {"position": 11, "T": 95}, "96": {"position": 8, "T": 96}, "97": {"position": 96, "T": 97}, "98": {"position": 28, "T": 98}, "99": {"position": 54, "T": 99}, "100": {"position": 37, "T": 100}, "101": {"position": 60, "T": 101}, "102": {"position": 71, "T": 102}, "103": {"position": 6, "T": 103}, "104": {"position": 7, "T": 104}, "105": {"position": 4, "T": 105}, "106": {"position": 39, "T": 106}, "107": {"position": 75, "T": 107}, "108": {"position": 47, "T": 108}, "109": {"position": 30, "T": 109}, "110": {"position": 100, "T": 110}, "111": {"position": 30, "T": 111}, "112": {"position": 40, "T": 112}, "113": {"position": 11, "T": 113}, "114": {"position": 3, "T": 114}, "115": {"position": 83, "T": 115}, "116": {"position": 45, "T": 116}, "117": {"position": 91, "T": 117}, "118": {"position": 85, "T": 118}, "119": {"position": 56, "T": 119}, "120": {"position": 33, "T": 120}, "121": {"position": 
74, "T": 121}, "122": {"position": 100, "T": 122}, "123": {"position": 56, "T": 123}, "124": {"position": 91, "T": 124}, "125": {"position": 16, "T": 125}, "126": {"position": 0, "T": 126}, "127": {"position": 17, "T": 127}, "128": {"position": 73, "T": 128}, "129": {"position": 40, "T": 129}, "130": {"position": 10, "T": 130}, "131": {"position": 6, "T": 131}, "132": {"position": 63, "T": 132}, "133": {"position": 23, "T": 133}, "134": {"position": 47, "T": 134}, "135": {"position": 79, "T": 135}, "136": {"position": 97, "T": 136}, "137": {"position": 13, "T": 137}, "138": {"position": 8, "T": 138}, "139": {"position": 39, "T": 139}, "140": {"position": 72, "T": 140}, "141": {"position": 59, "T": 141}, "142": {"position": 55, "T": 142}, "143": {"position": 91, "T": 143}, "144": {"position": 22, "T": 144}, "145": {"position": 62, "T": 145}, "146": {"position": 23, "T": 146}, "147": {"position": 14, "T": 147}, "148": {"position": 70, "T": 148}, "149": {"position": 86, "T": 149}, "150": {"position": 10, "T": 150}, "151": {"position": 72, "T": 151}, "152": {"position": 91, "T": 152}, "153": {"position": 29, "T": 153}, "154": {"position": 29, "T": 154}, "155": {"position": 94, "T": 155}, "156": {"position": 69, "T": 156}, "157": {"position": 58, "T": 157}, "158": {"position": 0, "T": 158}, "159": {"position": 11, "T": 159}, "160": {"position": 24, "T": 160}, "161": {"position": 70, "T": 161}, "162": {"position": 58, "T": 162}, "163": {"position": 58, "T": 163}, "164": {"position": 7, "T": 164}, "165": {"position": 25, "T": 165}, "166": {"position": 32, "T": 166}, "167": {"position": 92, "T": 167}, "168": {"position": 58, "T": 168}, "169": {"position": 55, "T": 169}, "170": {"position": 13, "T": 170}, "171": {"position": 17, "T": 171}, "172": {"position": 4, "T": 172}, "173": {"position": 30, "T": 173}, "174": {"position": 58, "T": 174}, "175": {"position": 91, "T": 175}, "176": {"position": 53, "T": 176}, "177": {"position": 41, "T": 177}, "178": {"position": 45, "T": 
178}, "179": {"position": 69, "T": 179}, "180": {"position": 82, "T": 180}, "181": {"position": 44, "T": 181}, "182": {"position": 48, "T": 182}, "183": {"position": 16, "T": 183}, "184": {"position": 75, "T": 184}, "185": {"position": 91, "T": 185}, "186": {"position": 16, "T": 186}, "187": {"position": 71, "T": 187}, "188": {"position": 59, "T": 188}, "189": {"position": 89, "T": 189}, "190": {"position": 25, "T": 190}, "191": {"position": 85, "T": 191}, "192": {"position": 34, "T": 192}, "193": {"position": 83, "T": 193}, "194": {"position": 36, "T": 194}, "195": {"position": 30, "T": 195}, "196": {"position": 5, "T": 196}, "197": {"position": 7, "T": 197}, "198": {"position": 5, "T": 198}, "199": {"position": 74, "T": 199}, "200": {"position": 29, "T": 200}, "201": {"position": 40, "T": 201}, "202": {"position": 48, "T": 202}, "203": {"position": 69, "T": 203}, "204": {"position": 0, "T": 204}, "205": {"position": 67, "T": 205}, "206": {"position": 7, "T": 206}, "207": {"position": 79, "T": 207}, "208": {"position": 72, "T": 208}, "209": {"position": 68, "T": 209}, "210": {"position": 97, "T": 210}, "211": {"position": 40, "T": 211}, "212": {"position": 84, "T": 212}, "213": {"position": 20, "T": 213}, "214": {"position": 91, "T": 214}, "215": {"position": 16, "T": 215}, "216": {"position": 12, "T": 216}, "217": {"position": 100, "T": 217}, "218": {"position": 89, "T": 218}, "219": {"position": 33, "T": 219}, "220": {"position": 60, "T": 220}, "221": {"position": 25, "T": 221}, "222": {"position": 82, "T": 222}, "223": {"position": 28, "T": 223}, "224": {"position": 53, "T": 224}, "225": {"position": 6, "T": 225}, "226": {"position": 63, "T": 226}, "227": {"position": 93, "T": 227}, "228": {"position": 14, "T": 228}, "229": {"position": 47, "T": 229}, "230": {"position": 42, "T": 230}, "231": {"position": 94, "T": 231}, "232": {"position": 61, "T": 232}, "233": {"position": 88, "T": 233}, "234": {"position": 17, "T": 234}, "235": {"position": 3, "T": 235}, 
"236": {"position": 97, "T": 236}, "237": {"position": 38, "T": 237}, "238": {"position": 30, "T": 238}, "239": {"position": 84, "T": 239}, "240": {"position": 37, "T": 240}, "241": {"position": 99, "T": 241}, "242": {"position": 36, "T": 242}, "243": {"position": 100, "T": 243}, "244": {"position": 53, "T": 244}, "245": {"position": 44, "T": 245}, "246": {"position": 37, "T": 246}, "247": {"position": 80, "T": 247}, "248": {"position": 8, "T": 248}, "249": {"position": 79, "T": 249}, "250": {"position": 94, "T": 250}, "251": {"position": 82, "T": 251}, "252": {"position": 84, "T": 252}, "253": {"position": 47, "T": 253}, "254": {"position": 27, "T": 254}, "255": {"position": 22, "T": 255}, "256": {"position": 22, "T": 256}, "257": {"position": 26, "T": 257}, "258": {"position": 58, "T": 258}, "259": {"position": 83, "T": 259}, "260": {"position": 60, "T": 260}, "261": {"position": 16, "T": 261}, "262": {"position": 54, "T": 262}, "263": {"position": 65, "T": 263}, "264": {"position": 7, "T": 264}, "265": {"position": 57, "T": 265}, "266": {"position": 15, "T": 266}, "267": {"position": 63, "T": 267}, "268": {"position": 54, "T": 268}, "269": {"position": 58, "T": 269}, "270": {"position": 16, "T": 270}, "271": {"position": 45, "T": 271}, "272": {"position": 57, "T": 272}, "273": {"position": 72, "T": 273}, "274": {"position": 93, "T": 274}, "275": {"position": 55, "T": 275}, "276": {"position": 77, "T": 276}, "277": {"position": 74, "T": 277}, "278": {"position": 20, "T": 278}, "279": {"position": 57, "T": 279}, "280": {"position": 89, "T": 280}, "281": {"position": 68, "T": 281}, "282": {"position": 74, "T": 282}, "283": {"position": 87, "T": 283}, "284": {"position": 11, "T": 284}, "285": {"position": 74, "T": 285}, "286": {"position": 69, "T": 286}, "287": {"position": 95, "T": 287}, "288": {"position": 76, "T": 288}, "289": {"position": 23, "T": 289}, "290": {"position": 6, "T": 290}, "291": {"position": 42, "T": 291}, "292": {"position": 71, "T": 292}, "293": 
{"position": 24, "T": 293}, "294": {"position": 18, "T": 294}, "295": {"position": 76, "T": 295}, "296": {"position": 97, "T": 296}, "297": {"position": 41, "T": 297}, "298": {"position": 60, "T": 298}, "299": {"position": 13, "T": 299}, "300": {"position": 84, "T": 300}, "301": {"position": 93, "T": 301}, "302": {"position": 75, "T": 302}, "303": {"position": 81, "T": 303}, "304": {"position": 76, "T": 304}, "305": {"position": 93, "T": 305}, "306": {"position": 2, "T": 306}, "307": {"position": 34, "T": 307}, "308": {"position": 27, "T": 308}, "309": {"position": 37, "T": 309}, "310": {"position": 71, "T": 310}, "311": {"position": 84, "T": 311}, "312": {"position": 97, "T": 312}, "313": {"position": 56, "T": 313}, "314": {"position": 100, "T": 314}, "315": {"position": 41, "T": 315}, "316": {"position": 88, "T": 316}, "317": {"position": 51, "T": 317}, "318": {"position": 80, "T": 318}, "319": {"position": 62, "T": 319}, "320": {"position": 34, "T": 320}, "321": {"position": 34, "T": 321}, "322": {"position": 32, "T": 322}, "323": {"position": 56, "T": 323}, "324": {"position": 58, "T": 324}, "325": {"position": 25, "T": 325}, "326": {"position": 29, "T": 326}, "327": {"position": 15, "T": 327}, "328": {"position": 55, "T": 328}, "329": {"position": 87, "T": 329}, "330": {"position": 20, "T": 330}, "331": {"position": 18, "T": 331}, "332": {"position": 25, "T": 332}, "333": {"position": 21, "T": 333}, "334": {"position": 69, "T": 334}, "335": {"position": 50, "T": 335}, "336": {"position": 12, "T": 336}, "337": {"position": 30, "T": 337}, "338": {"position": 41, "T": 338}, "339": {"position": 59, "T": 339}, "340": {"position": 100, "T": 340}, "341": {"position": 20, "T": 341}, "342": {"position": 20, "T": 342}, "343": {"position": 72, "T": 343}, "344": {"position": 30, "T": 344}, "345": {"position": 94, "T": 345}, "346": {"position": 31, "T": 346}, "347": {"position": 26, "T": 347}, "348": {"position": 13, "T": 348}, "349": {"position": 5, "T": 349}, "350": 
{"position": 2, "T": 350}, "351": {"position": 97, "T": 351}, "352": {"position": 97, "T": 352}, "353": {"position": 3, "T": 353}, "354": {"position": 75, "T": 354}, "355": {"position": 98, "T": 355}, "356": {"position": 97, "T": 356}, "357": {"position": 35, "T": 357}, "358": {"position": 12, "T": 358}, "359": {"position": 46, "T": 359}, "360": {"position": 91, "T": 360}, "361": {"position": 40, "T": 361}, "362": {"position": 48, "T": 362}, "363": {"position": 98, "T": 363}, "364": {"position": 70, "T": 364}, "365": {"position": 7, "T": 365}, "366": {"position": 75, "T": 366}, "367": {"position": 35, "T": 367}, "368": {"position": 48, "T": 368}, "369": {"position": 77, "T": 369}, "370": {"position": 91, "T": 370}, "371": {"position": 96, "T": 371}, "372": {"position": 9, "T": 372}, "373": {"position": 2, "T": 373}, "374": {"position": 76, "T": 374}, "375": {"position": 25, "T": 375}, "376": {"position": 95, "T": 376}, "377": {"position": 72, "T": 377}, "378": {"position": 59, "T": 378}, "379": {"position": 4, "T": 379}, "380": {"position": 87, "T": 380}, "381": {"position": 100, "T": 381}, "382": {"position": 91, "T": 382}, "383": {"position": 73, "T": 383}, "384": {"position": 63, "T": 384}, "385": {"position": 64, "T": 385}, "386": {"position": 16, "T": 386}, "387": {"position": 20, "T": 387}, "388": {"position": 51, "T": 388}, "389": {"position": 56, "T": 389}, "390": {"position": 81, "T": 390}, "391": {"position": 16, "T": 391}, "392": {"position": 88, "T": 392}, "393": {"position": 68, "T": 393}, "394": {"position": 65, "T": 394}, "395": {"position": 54, "T": 395}, "396": {"position": 29, "T": 396}, "397": {"position": 37, "T": 397}, "398": {"position": 88, "T": 398}, "399": {"position": 82, "T": 399}, "400": {"position": 93, "T": 400}, "401": {"position": 71, "T": 401}, "402": {"position": 33, "T": 402}, "403": {"position": 72, "T": 403}, "404": {"position": 83, "T": 404}, "405": {"position": 93, "T": 405}, "406": {"position": 72, "T": 406}, "407": 
{"position": 3, "T": 407}, "408": {"position": 98, "T": 408}, "409": {"position": 99, "T": 409}, "410": {"position": 39, "T": 410}, "411": {"position": 62, "T": 411}, "412": {"position": 60, "T": 412}, "413": {"position": 48, "T": 413}, "414": {"position": 32, "T": 414}, "415": {"position": 66, "T": 415}, "416": {"position": 40, "T": 416}, "417": {"position": 33, "T": 417}, "418": {"position": 68, "T": 418}}
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/tools/readjson.py
================================================
'''
Created on 2016.5.24
@author: liangqian
'''
import json
import os
def readjsonFile(filename):
    """Read a JSON file and return the parsed object.

    :param filename: path to a JSON text file.
    :return: the deserialized Python object (dict, list, ...).
    :raises OSError: if the file cannot be opened.
    :raises json.JSONDecodeError: if the contents are not valid JSON.
    """
    # `with` guarantees the handle is closed even if parsing fails;
    # the original opened the file and never closed it.
    with open(filename, 'r') as f:
        # json.load reads the stream directly — no intermediate string needed.
        return json.load(f)
#readjsonFile('../jsondata.txt')
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/tools/testSound.py
================================================
'''
Created on 2016.6.29
@author: liangqian
'''
# Minimal pygame smoke test: play "do.wav" once and keep the event loop
# alive until the window is closed.
import pygame,sys
pygame.init()
pygame.mixer.init()
# give the mixer a moment before loading/playing
pygame.time.delay(1000)
pygame.mixer.music.load("do.wav")
pygame.mixer.music.play()
# busy event loop: exit the process when the user closes the window
while 1:
    for event in pygame.event.get():
        if event.type==pygame.QUIT:
            sys.exit()
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/tools/testmusic21.py
================================================
from music21 import *
# Scratch script exercising the music21 API: corpus parsing, note/stream
# construction, chordify, and Roman-numeral analysis.
#s = converter.parse('../xmlfiles/four_part_hamony/ch4-03_A-major.xml')
s = corpus.parse('bach/bwv65.2.xml')
s.analyze('key')
print(len(s.parts))
s.show()
# dump pitch and duration for every note of every part
for i,part in enumerate(s.parts):
    print(i)
    print('-----------')
    for ns in part.flat.notes:
        print(ns.pitch)
        print(ns.duration.quarterLength)
# build a tiny three-note stream by hand
note1 = note.Note("D5")
note2 = note.Note("F#5")
note2.duration.quarterLength = 0.5
note3 = note.Note("A5")
stream1 = stream.Stream()
stream1.append(note1)
stream1.append(note2)
stream1.append(note3)
print(note2.offset)
sout = stream1.getElementsByOffset(0,2)
# chordify a chorale and annotate each chord with its Roman numeral in B-flat
sBach = corpus.parse('bach/bwv57.8')
s = sBach.chordify()
#cs = s.getElementsByClass('Chord')
s1 = s.flatten()
chords = s1.getElementsByClass('Chord')
# cMinor = chord.Chord(["A4","F4","D5"])
# print(cMinor.inversion())
# print(cMinor.isMinorTriad())
keyA = key.Key('B-')
for c in chords:
    rn = roman.romanNumeralFromChord(c, keyA)
    c.addLyric(str(rn.figure))
chords.show()
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/tools/testopengl.py
================================================
'''
Created on 2018.8.31
@author: liangqian
'''
from OpenGL.GL import *
from OpenGL.GLU import *
from OpenGL.GLUT import *
def drawFunc():
    # Redraw callback: clear the colour buffer and render a wireframe teapot.
    glClear(GL_COLOR_BUFFER_BIT)
    #glRotatef(1, 0, 1, 0)
    glutWireTeapot(0.5)
    glFlush()  # flush buffered GL commands so the teapot appears

# GLUT boilerplate: single-buffered RGBA 400x400 window, then the event loop.
glutInit()
glutInitDisplayMode(GLUT_SINGLE | GLUT_RGBA)
glutInitWindowSize(400, 400)
glutCreateWindow(b"First")
glutDisplayFunc(drawFunc)  # register the redraw callback
#glutIdleFunc(drawFunc)
glutMainLoop()  # enters the GLUT event loop; never returns
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/tools/testwave.py
================================================
import wave
import struct
import os
import numpy as np
# Synthesize one second of a 440 Hz tone (fundamental plus 2nd and 3rd
# harmonics) and write it to "sine.wav" as 16-bit mono PCM.
f = 440             # fundamental frequency in Hz
framerate = 44100.0  # samples per second
# one second of sample times: 44100 points in [0, 1)
tt = np.arange(0, 1, 1.0/framerate)
# amplitude 2000 per harmonic => peak |value| 6000, safely inside int16 range
data = [2000*(np.sin(2*np.pi*f*t)+np.sin(2*np.pi*2*f*t)+np.sin(2*np.pi*3*f*t)) for t in tt]
# The original called fw.writeframes() once per sample (each call patches the
# WAV header — quadratic file I/O) and also print()ed all 44100 samples.
# Pack everything once and write a single frame block instead; the context
# manager guarantees the header is finalized and the file closed.
with wave.open("sine.wav", "wb") as fw:
    fw.setnchannels(1)   # mono
    fw.setframerate(framerate)
    fw.setsampwidth(2)   # 16-bit samples
    # '<%dh': little-endian int16, as the WAV format requires
    frames = struct.pack('<%dh' % len(data), *(int(d) for d in data))
    fw.writeframes(frames)
================================================
FILE: examples/Knowledge_Representation_and_Reasoning/musicMemory/tools/xmlParser.py
================================================
import librosa
import music21 as m21
import pandas as pd
import os
'''
This function parses MusicXML file and extracts necessary score information as CSV.
'''
def _glissando_markers(spanners, element):
    """Return slide markers ('slide start'/'slide last') for *element*.

    *spanners* are the spanner sites of the note being processed; only
    Glissando spanners contribute markers.
    """
    gliss = []
    for spanner in spanners:
        if 'Glissando' in spanner.classes:
            if spanner.isFirst(element):
                gliss.append('slide start')
            if spanner.isLast(element):
                gliss.append('slide last')
    return gliss


def readXmlAsCsv(xmlPath='xml/'):
    """Parse every MusicXML file under *xmlPath* and dump score info as CSV.

    Expects the layout ``xmlPath/<subfolder>/<piece>.xml`` and writes one CSV
    per piece to ``xmlPath/csv/<subfolder>/<piece>.csv`` with one row per note
    (a chord contributes one row per chord member), sorted by
    (measure number, offset, duration).
    """
    for subfolder in os.listdir(xmlPath):
        if subfolder.startswith('.'):
            continue  # skip hidden entries such as .DS_Store
        subfolder_path = os.path.join(xmlPath, subfolder)
        for item in os.listdir(subfolder_path):
            if not item.endswith('xml'):
                continue
            item_path = os.path.join(subfolder_path, item)
            xml_data = m21.converter.parse(item_path)
            print("Converting ", item_path)
            score = []
            for part in xml_data.parts:
                for note in part.flat.notes:
                    # attributes shared by plain notes and chord members
                    measureNo = note.measureNumber
                    start = note.offset
                    duration = note.quarterLength
                    articulations = note.articulations
                    expressions = note.expressions
                    spanners = note.getSpannerSites()
                    if note.isChord:
                        print('note is chord: ', note)
                        # one row per chord member
                        for chord_note in note:
                            pitch = chord_note.pitch
                            gliss = _glissando_markers(spanners, chord_note)
                            score.append(
                                [measureNo, start, duration, pitch,
                                 m21.pitch.Pitch(pitch).frequency,
                                 articulations, expressions, gliss, spanners])
                    else:
                        pitch = note.pitch
                        gliss = _glissando_markers(spanners, note)
                        score.append(
                            [measureNo, start, duration, pitch,
                             m21.pitch.Pitch(pitch).frequency,
                             articulations, expressions, gliss, spanners])
            score = sorted(score, key=lambda x: (x[0], x[1], x[2]))
            df = pd.DataFrame(score,
                              columns=['MeasureNumber', 'Start', 'Duration', 'Pitch', 'f0',
                                       'Articulations', 'Expressions', 'Glissando', 'Spanner'])
            # BUG FIX: the original referenced an undefined name `path` here,
            # raising NameError at runtime; the intended root is `xmlPath`.
            df.to_csv(os.path.join(xmlPath, 'csv', subfolder,
                                   os.path.splitext(item)[0] + '.csv'))
================================================
FILE: examples/MotorControl/experimental/README.md
================================================
# Experimental works for motor control with different Brain Areas.
The project is still immature and under continuous development...
## Citation
If you find the code and dataset useful in your research, please consider citing:
```
@misc{https://doi.org/10.48550/arxiv.2207.08533,
doi = {10.48550/ARXIV.2207.08533},
url = {https://arxiv.org/abs/2207.08533},
author = {Zeng, Yi and Zhao, Dongcheng and Zhao, Feifei and Shen, Guobin and Dong, Yiting and Lu, Enmeng and Zhang, Qian and Sun, Yinqian and Liang, Qian and Zhao, Yuxuan and Zhao, Zhuoya and Fang, Hongjian and Wang, Yuwei and Li, Yang and Liu, Xin and Du, Chengcheng and Kong, Qingqun and Ruan, Zizhe and Bi, Weida},
title = {BrainCog: A Spiking Neural Network based Brain-inspired Cognitive Intelligence Engine for Brain-inspired AI and Brain Simulation},
publisher = {arXiv},
year = {2022},
}
```
================================================
FILE: examples/MotorControl/experimental/brain_area.py
================================================
import torch
import numpy as np
import torch.nn as nn
from braincog.base.node.node import *
class MoColumnPOP(nn.Module):
    """A population of spiking mini-columns whose decoded outputs are averaged.

    Each column is a Linear projection followed by a LIF node; a shared linear
    read-out decodes every column, and the forward pass returns the population
    mean of the decoded embeddings.
    """

    def __init__(self,
                 input_dims: int,
                 pop_num: int = 16,
                 embedding_dim: int = 64,
                 time_window: int = 16) -> None:
        super().__init__()
        self._threshold = 1.0
        self.v_reset = 0.0
        self._time_window = time_window
        self._pop_num = pop_num
        self._node = LIFNode
        # one Linear + LIF column per population member
        self.column_net = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_dims, embedding_dim),
                self._node(threshold=self._threshold, v_reset=self.v_reset))
            for _ in range(pop_num)
        ])
        # shared linear read-out applied to every column's embedding
        self.decode = nn.Linear(embedding_dim, 64)

    def reset(self):
        """Reset the membrane state of every spiking node in the module tree."""
        for module in self.modules():
            if hasattr(module, 'n_reset'):
                module.n_reset()

    def _emb_decode(self, x):
        # decode each column's spike embedding with the shared read-out
        return [self.decode(column(x)) for column in self.column_net]

    def forward(self, inputs):
        decoded = self._emb_decode(inputs)
        # population average of the decoded embeddings
        return sum(decoded) / self._pop_num
class MotorCortex(nn.Module):
    """Spiking motor-cortex model producing population-coded embeddings.

    ``forward`` unrolls the SMA -> basal-ganglia -> motor pathway for
    ``time_window`` steps and population-decodes each step via MoColumnPOP.
    NOTE: ``pfc_net`` and ``pmc_net`` are constructed but not used by
    ``forward``; they are kept so the module's parameter layout is unchanged.
    """

    def __init__(self,
                 input_dims: int,
                 out_dims: int = 128,
                 time_window: int = 16) -> None:
        super().__init__()
        self._threshold = 1.0
        self.v_reset = 0.0
        self._time_window = time_window
        self._node = LIFNode

        def spiking_fc(in_features, out_features):
            # linear projection followed by a spiking LIF layer
            return nn.Sequential(
                nn.Linear(in_features, out_features),
                self._node(threshold=self._threshold, v_reset=self.v_reset))

        self.pfc_net = spiking_fc(input_dims, 512)
        self.sma_net = spiking_fc(input_dims, 512)
        self.ganglia_net = spiking_fc(512, 128)
        self.pmc_net = spiking_fc(512, 512)
        self.motor_net = spiking_fc(512 + 128, 128)
        self.motor_emb = MoColumnPOP(input_dims=128, embedding_dim=out_dims,
                                     time_window=time_window)

    def reset(self):
        """Reset the membrane state of every spiking node in the module tree."""
        for module in self.modules():
            if hasattr(module, 'n_reset'):
                module.n_reset()

    def _compute_motor_out(self, inputs):
        """One timestep of the SMA -> basal ganglia -> motor pathway."""
        sma_spikes = self.sma_net(inputs)
        ganglia_spikes = self.ganglia_net(sma_spikes)
        combined = torch.concat([ganglia_spikes, sma_spikes], dim=-1)
        # pop coding happens in motor_emb, in forward()
        return self.motor_net(combined)

    def forward(self, inputs):
        self.reset()
        # unroll over the simulation window; decode each step's motor output
        return [self.motor_emb(self._compute_motor_out(inputs))
                for _ in range(self._time_window)]
class Celebellum(nn.Module):
    """Cerebellum-like refinement network (granule -> Purkinje -> DCN).

    Consumes a list of per-timestep inputs (as produced by MotorCortex) and
    returns the time-average of the deep-cerebellar-nuclei outputs.
    """

    def __init__(self,
                 input_dims: int = 512,
                 out_dims: int = 7,
                 time_window: int = 16,
                 ) -> None:
        super().__init__()
        self._threshold = 1.0
        self.v_reset = 0.0
        self._time_window = time_window
        self._node = LIFNode
        # granule-cell layer
        self.gc_layer = nn.Sequential(
            nn.Linear(input_dims, 512),
            self._node(threshold=self._threshold, v_reset=self.v_reset))
        # Purkinje-cell layer
        self.pc_layer = nn.Sequential(
            nn.Linear(512, 512),
            self._node(threshold=self._threshold, v_reset=self.v_reset))
        # deep cerebellar nuclei: sees both the raw input and the Purkinje output
        self.dcn_layer = nn.Sequential(
            nn.Linear(input_dims + 512, 512),
            self._node(threshold=self._threshold, v_reset=self.v_reset),
            nn.Linear(512, out_dims))

    def reset(self):
        """Reset the membrane state of every spiking node in the module tree."""
        for module in self.modules():
            if hasattr(module, 'n_reset'):
                module.n_reset()

    def _step(self, frame):
        # granule -> Purkinje, then DCN over the concatenation [input, Purkinje]
        granule = self.gc_layer(frame)
        purkinje = self.pc_layer(granule)
        return self.dcn_layer(torch.concat([frame, purkinje], dim=-1))

    def forward(self, x):
        self.reset()
        stepwise = [self._step(x[t]) for t in range(self._time_window)]
        # time-average of the per-step DCN outputs
        return sum(stepwise) / self._time_window
if __name__ == '__main__':
    # quick sanity check: list every sub-module exposing an `n_reset` hook
    cortex = MotorCortex(input_dims=1024)
    resettable = [m for m in cortex.modules() if hasattr(m, 'n_reset')]
    for module in resettable:
        print('mod: ', module)
================================================
FILE: examples/MotorControl/experimental/main.py
================================================
import torch
import numpy as np
import torch.nn as nn
from model import Motion
import tqdm
import argparse
from torch.nn import functional as F
# Command-line configuration for the motor-control training script.
# NOTE: argv is parsed at import time — importing this module as a library
# will consume the process's command-line arguments.
parser = argparse.ArgumentParser(description='Motor Parameters')
parser.add_argument('--lr', default=0.001, type=float, help='learning rate')
parser.add_argument('--time-window', type=int, default=8, help="Number of timesteps to do.")
parser.add_argument('--device', type=str, default='0', help="CUDA device")
parser.add_argument('--log-path', type=str, default='./logs/out.txt', help="Log path")
args = parser.parse_args()
print(args)
# NOTE(review): assumes a CUDA device is present — there is no CPU fallback.
device = torch.device('cuda:'+args.device)
# Hand-measured target poses, one 7-tuple per playable position.
# NOTE(review): each tuple looks like an (x, y, z) position followed by a
# 4-component orientation — confirm against the robot's pose convention.
LABELS = {
    'position_group_0': (-0.337, -0.020, -0.077, -0.031164, 0.999496, -0.005979, 2.850154),
    'position_group_1': (-0.337, 0.007, -0.077, -0.039668, 0.999161, -0.010174, 2.894892),
    'position_group_2': (-0.337, 0.030, -0.076, -0.031164, 0.999496, -0.005979, 2.850154),
    'position_group_3': (-0.337, 0.052, -0.076, -0.031164, 0.999496, -0.005979, 2.850154),
    'position_group_4': (-0.339, 0.074, -0.076, 0.016057, 0.999842, -0.007643, 2.804204),
    'position_group_5': (-0.339, 0.096, -0.078, 0.016057, 0.999842, -0.007643, 2.804204),
    'position_group_6': (-0.339, 0.123, -0.079, 0.016057, 0.999842, -0.007643, 2.804204),
    'position_group_7': (-0.337, 0.139, -0.080, 0.076912, 0.997035, -0.0021723, 2.799101),
    'position_group_8': (-0.337, 0.163, -0.0770, 0.076912, 0.997035, -0.0021723, 2.799101),
    'position_group_9': (-0.338, 0.188, -0.075, 0.076912, 0.997035, -0.002172, 2.799101),
    'position_group_10': (-0.338, 0.212, -0.075, 0.087103, 0.995757, -0.029681, 2.785759),
    'position_group_11': (-0.338, 0.235, -0.070, 0.087103, 0.995757, -0.029681, 2.785759),
    'position_group_12': (-0.338, 0.259, -0.073, 0.087103, 0.995757, -0.029681, 2.785759),
    'position_group_13': (-0.339, 0.273, -0.065, 0.202020, 0.979225, 0.017483, 2.764647),
    'position_group_14': (-0.336, 0.290, -0.066, 0.244628, 0.963147, -0.111827, 2.740450),
}
position_num = 15  # number of position groups in LABELS
position_dims = 7  # components per pose tuple
# stack the pose tuples into a (position_num, position_dims) float32 array
TARGETS = []
for i in range(position_num):
    TARGETS.append(np.array(LABELS['position_group_'+str(i)], dtype=np.float32))
TARGETS = np.stack(TARGETS, axis=0)
# per-dimension scale factors — presumably to balance the regression loss
# across dimensions of very different magnitude; TODO confirm.
t_factors = np.array([10.0, 10.0, 100.0, 10.0, 1.0, 100.0, 1.0], dtype=np.float32)
TARGETS_FAC = TARGETS * t_factors[np.newaxis, :]  # scaled training targets
# Mapping from piano-key name to its index (17 keys: c1 .. e3).
# BUG FIX: the original dict listed 'd2' twice (at indices 1 and 8), so the
# second entry silently overwrote the first, leaving only 16 keys while
# key_num below is 17 and index 1 was unreachable by name. The first entry
# was clearly meant to be 'd1' (scale pattern c1 d1 e1 f1 g1 a1 b1 c2 ...).
KEYS = {
    'c1': 0,
    'd1': 1,
    'e1': 2,
    'f1': 3,
    'g1': 4,
    'a1': 5,
    'b1': 6,
    'c2': 7,
    'd2': 8,
    'e2': 9,
    'f2': 10,
    'g2': 11,
    'a2': 12,
    'b2': 13,
    'c3': 14,
    'd3': 15,
    'e3': 16
}
finger_num = 3       # number of fingers modelled
finger_pop_num = 10  # population-code slots per finger
key_num = 17         # number of keys (== len(KEYS))
key_pop_num = 5      # population-code slots per key
def creat_key_finger_emb():
    """Build one-hot population codes for keys and fingers.

    Returns a tuple ``(key_value, finger_value)`` of float32 arrays of shape
    ``(key_num, key_num * key_pop_num)`` and
    ``(finger_num, finger_num * finger_pop_num)``; row *i* carries a block of
    ones over item *i*'s population slots and zeros elsewhere.
    """
    # Kronecker product of an identity matrix with a row of ones produces
    # exactly the block one-hot layout the original built by slicing.
    ones_key = np.ones((1, key_pop_num), dtype=np.float32)
    ones_finger = np.ones((1, finger_pop_num), dtype=np.float32)
    key_value = np.kron(np.eye(key_num, dtype=np.float32), ones_key)
    finger_value = np.kron(np.eye(finger_num, dtype=np.float32), ones_finger)
    return (key_value, finger_value)
def mse_loss(pred, target):
    """Return the mean-squared error between *pred* and *target*.

    :param pred: predicted tensor.
    :param target: ground-truth tensor, same shape as *pred*.
    :return: scalar MSE tensor.

    BUG FIX: the original computed the loss but never returned it, so every
    call yielded ``None``.
    """
    return F.mse_loss(pred, target)
def main():
    """Train the Motion model to map (key, finger) codes to target poses.

    Each training sample pairs a random key/finger one-hot population code
    with the scaled pose of position ``clip(key - finger, 0, position_num-1)``.
    After every epoch the model is evaluated on all key/finger combinations
    and the results are appended to ``args.log_path``.
    """
    key_embs, finger_emb = creat_key_finger_emb()
    # model input is the concatenation of the key code and the finger code
    in_dims = key_embs.shape[1] + finger_emb.shape[1]
    model = Motion(in_dims=in_dims, out_dims=position_dims, time_window=args.time_window).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    criterion = nn.MSELoss().to(device)
    T = 100          # training steps per epoch
    batch_size = 32
    EPOCHS = 200
    with open(args.log_path, 'a+') as f:
        # dump the run configuration at the top of the log
        argsDict = args.__dict__
        f.writelines('------------------ start ------------------' + '\n')
        for eachArg, value in argsDict.items():
            f.writelines(eachArg + ' : ' + str(value) + '\n')
        f.writelines('------------------- end -------------------'+ '\n')
        for epoch in range(EPOCHS):
            # train
            for step in tqdm.tqdm(range(T)):
                key_idxs = np.random.choice(key_num, size=batch_size)
                finger_idxs = np.random.choice(finger_num, size=batch_size)
                # target position index: key minus finger, clamped to valid range
                labels = np.clip(key_idxs - finger_idxs, a_min=0, a_max=position_num-1)
                in_emb = np.concatenate([key_embs[key_idxs], finger_emb[finger_idxs]], axis=-1)
                x = torch.from_numpy(in_emb).to(device)
                y = torch.from_numpy(TARGETS_FAC[labels]).to(device)
                y_pred = model(x)
                loss = criterion(y_pred, y)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # test: evaluate every (key, finger) combination
            loss_record = []
            f.writelines('\n')
            f.writelines('Epoch:[{epoch}/{total_eps}]:\n'.format(epoch=epoch, total_eps=EPOCHS))
            for key in range(key_num):
                for fin in range(finger_num):
                    in_emb = np.concatenate([key_embs[key], finger_emb[fin]], axis=-1)
                    x = torch.from_numpy(in_emb).to(device)
                    with torch.no_grad():
                        pred = model(x)
                    y = max(min(key - fin, position_num-1), 0)
                    target = torch.from_numpy(TARGETS_FAC[y]).to(device)
                    loss = F.mse_loss(pred, target, reduction='sum')
                    loss_record.append(loss.cpu().item())
                    # undo the per-dimension scaling before reporting
                    real_pred = pred.cpu().numpy() / t_factors
                    distant = np.sum((TARGETS[y] - real_pred)**2)**0.5
                    f.writelines(' Predict position {y}: {pred}\n'.format(y=y, pred=real_pred.tolist()))
            # NOTE(review): `distant` below is only the LAST combination's
            # Euclidean distance, not an epoch aggregate — confirm intended.
            f.writelines('==> Epoch:[{epoch}/{total_eps}][validation stage]: loss: {loss}, distant {dis}\n'.format(
                epoch=epoch, total_eps=EPOCHS, loss=sum(loss_record)/len(loss_record), dis=distant))
            print('==> Epoch:[{epoch}/{total_eps}][validation stage]: loss: {loss}, distant {dis}\n'.format(
                epoch=epoch, total_eps=EPOCHS, loss=sum(loss_record)/len(loss_record), dis=distant))

if __name__ == '__main__':
    main()
================================================
FILE: examples/MotorControl/experimental/model.py
================================================
import torch
import numpy as np
import torch.nn as nn
from brain_area import Celebellum, MotorCortex
class Motion(nn.Module):
    """End-to-end motion model: input embedding -> motor cortex -> cerebellum."""

    def __init__(self, in_dims: int, out_dims: int = 17, time_window: int = 8, emb_size: int = 128) -> None:
        super().__init__()
        self._time_window = time_window
        # project raw input features into the cortical embedding space
        self.in_emb = nn.Linear(in_dims, emb_size)
        self.motor_cotex = MotorCortex(input_dims=emb_size, out_dims=64, time_window=self._time_window)
        self.cele = Celebellum(input_dims=64, out_dims=out_dims, time_window=self._time_window)

    def forward(self, x):
        """Map input features to the predicted output pose."""
        embedded = self.in_emb(x)
        cortical = self.motor_cotex(embedded)
        return self.cele(cortical)

    def learn(self):
        # placeholder for a future learning rule
        pass
================================================
FILE: examples/Multiscale_Brain_Structure_Simulation/CorticothalamicColumn/README.md
================================================
# Corticothalamic minicolumn
## Description
The anatomical data is saved in the "tool" package. The **main.py** script creates the network of the minicolumn depending on the anatomical data.
A file named **"fire.csv"** will be generated to record the firing result of neurons in each time step.
## Requirements
* numpy
* scipy
* pytorch >= 1.7.0
```shell
python main.py
```
================================================
FILE: examples/Multiscale_Brain_Structure_Simulation/CorticothalamicColumn/data/__init__.py
================================================
================================================
FILE: examples/Multiscale_Brain_Structure_Simulation/CorticothalamicColumn/data/globaldata.py
================================================
'''
Created on 2014.11.25
@author: Liang Qian
'''
# Shared global state for the cortico-thalamic column simulation.
#from tools import dbconnection as DB
from tools import exdata as Data
# anatomical data tables — loaded once here, read by every model module
data = Data.EXDATA()
# running index handed out as neurons are created
curneuronindex = 0
SynapseNumberPerDendrite = 40      # synapses carried by each dendrite
ProximalDendriteNumerPerNeuron = 1  # proximal dendrites per neuron
# NOTE: opened at import time and shared globally; firing records are
# written here during the simulation.
f = open('fire.csv','w')
================================================
FILE: examples/Multiscale_Brain_Structure_Simulation/CorticothalamicColumn/main.py
================================================
import sys
sys.path.append("../")
import time
from model import cortex_thalamus

# Entry point: build a cortico-thalamic column and run the simulation,
# reporting the synapse count and total wall-clock time.
if __name__ == '__main__':
    starttime = time.time()
    myCortex = cortex_thalamus.Cortex_Thalamus(1000) # create a cortex object and specify the neuron number scale
    myCortex.CreateCortexNetwork() # create cortex-thalamus network by the cortical object
    myCortex.run()
    # total synapses actually created by the network build
    print(len(myCortex.synapse))
    totaltime = (time.time() - starttime)
    print("totaltime:" + str(totaltime) + "s")
================================================
FILE: examples/Multiscale_Brain_Structure_Simulation/CorticothalamicColumn/model/__init__.py
================================================
================================================
FILE: examples/Multiscale_Brain_Structure_Simulation/CorticothalamicColumn/model/cortex.py
================================================
'''
Created on 2015.5.27
@author: Liang Qian
'''
from data import globaldata as Global
from .layer import Layer
from .synapse import Synapse
class Cortex():
    """Container for the neurons, layers and synapses of a simulated cortex.

    Sizes are derived from the anatomical tables in the global data module,
    scaled by ``neuronnumscale``.
    """
    def __init__(self, neuronnumscale):
        """:param neuronnumscale: total neuron count the network is scaled to."""
        self.neuronnumscale = neuronnumscale
        self.neuronsNumber = 0
        self.synapsesNumber = 0
        self.neurons = []        # all neurons of the whole cortex
        self.layers = {}         # layer name -> Layer object
        self.synapses = []       # all synapses of the whole cortex
        self.minicolumns = []    # per-mini-column bookkeeping
        self.neurontoindex = {}  # neuron name -> index into self.neurons
        self.totaldata = Global.data.getCortexData()

    def setNeuronToIndex(self, node):
        """Record *node*'s index in ``self.neurons`` under its name.

        Assumes the node has just been appended to ``self.neurons``; an
        existing mapping for the same name is kept (first registration wins).
        """
        name = node.name
        # `is None` instead of `== None` (PEP 8); behavior is unchanged.
        if self.neurontoindex.get(name) is None:
            self.neurontoindex[name] = len(self.neurons) - 1

    def setLayers(self):
        """Create the cortical layers from the anatomical layer table."""
        layerdic = Global.data.getLayerData()
        for i, info in layerdic.items():
            layer = Layer()
            layer.name = info.get('name')
            # 'neuronnum' is a percentage of the total neuron scale
            layer.neuronnum = self.neuronnumscale * float(info.get('neuronnum')) / 100
            print(layer.name + ' neuron number:' + str(layer.neuronnum))
            self.layers[layer.name] = layer
================================================
FILE: examples/Multiscale_Brain_Structure_Simulation/CorticothalamicColumn/model/cortex_thalamus.py
================================================
'''
Created on 2014.11.13
@author: Liang Qian
'''
import sys
from .dendrite import Dendrite
sys.path.append('../')
from data import globaldata as Global
from .thalamus import Thalamus
from .cortex import Cortex
from .layer import Layer
from .synapse import Synapse
from braincog.base.node.node import *
class Cortex_Thalamus():
    '''
    Builds a human-brain corticothalamic column.

    At the coding level the column is a huge graph: neurons are nodes and
    synapses are edges, but there may be many edges between any two nodes,
    so each edge is a Synapse object and the graph is stored as an
    adjacency list of the form pre-node -> {post-node: [synapses]}.

    @param neuronnumscale: total neuron count of the column (scaling factor)
    '''
    def __init__(self, neuronnumscale):
        self.neuronnumscale = neuronnumscale
        self.cortex = Cortex(neuronnumscale)
        self.thalamus = Thalamus()
        self.minicolumns = []      # a list storing information per mini-column
        self.synapsenum = 0        # expected total synapse count (setSynapseNum)
        self.neuronname = []       # neuron type names over the whole column
        self.neuron = []           # every neuron in the column
        self.synapse = []          # every synapse in the thalamocortex
        self.neurontoindex = {}    # neuron type name -> first index in self.neuron
        # Fix: self.layer was read by setLayer/getCortexLayer* but never
        # initialized, raising AttributeError on first use.
        self.layer = {}            # layer name -> Layer
        self.totaldata = Global.data.getCortexData()

    def setSynapseNum(self):
        """Compute the expected total synapse count and collect type names."""
        num = 0
        for name, info in self.totaldata.items():
            num += self.neuronnumscale * (info.get('synnum') * info.get('neunum') / 100.0)
            self.neuronname.append(info.get('neuronname'))
        self.synapsenum = num

    def setLayer(self):
        """Build one Layer per row of the layer table, scaling both neuron
        and synapse counts (Cortex.setLayers is the cortex-local variant)."""
        layerdic = Global.data.getLayerData()
        for i, info in layerdic.items():
            layer = Layer()
            layer.name = info.get('name')
            layer.neuronnum = self.neuronnumscale * info.get('neuronnum') / 100
            print(layer.name + ' neuron number:' + str(layer.neuronnum))
            layer.synapsenum = self.synapsenum * info.get('synapsenum')
            print(layer.name + ' synapse number:' + str(layer.synapsenum))
            self.layer[layer.name] = layer

    def setNeuronsDendritesAndSynapes(self):
        """Instantiate every neuron with its dendrites, then create all
        synapses and distribute them over the adjacency lists."""
        neurondic = Global.data.getNeuronData()
        index = 0
        for s in self.neuronname:
            neuroninfo = neurondic.get(s)
            self.neurontoindex[s] = len(self.neuron)
            num = self.neuronnumscale * float(neuroninfo.get('percent')) / 100.0
            # Using synapse numbers to compute dendrite counts per layer.
            synapsedic = Global.data.getSynapseData(s)
            dendic = {}
            for r, item in synapsedic.items():
                synum = int(item.get('synapsenum'))
                # A dendrite contains no more than SynapseNumberPerDendrite synapses.
                dennum = (synum + Global.SynapseNumberPerDendrite - 1) // Global.SynapseNumberPerDendrite
                loc = item.get('locationlayer')
                dendic[loc] = dennum
            for i in range(int(num)):
                node = CTIzhNode(morphology=neuroninfo.get('morphology'),
                                 name=neuroninfo.get('name'),
                                 excitability=neuroninfo.get('excitability'),
                                 spiketype=neuroninfo.get('spiketype'),
                                 # NOTE(review): synum is whatever row the
                                 # dendrite loop saw last, not this type's
                                 # total synapse count -- confirm intent.
                                 synnum=synum,
                                 locationlayer=neuroninfo.get('location layer'),
                                 totalindex=index,
                                 Gup=float(neuroninfo.get('Gup')),
                                 Gdown=float(neuroninfo.get('Gdown')),
                                 Vr=float(neuroninfo.get('Vr')),
                                 Vt=float(neuroninfo.get('Vt')),
                                 Vpeak=float(neuroninfo.get('Vpeak')),
                                 a=float(neuroninfo.get('a')),
                                 b=float(neuroninfo.get('b')),
                                 c=float(neuroninfo.get('Csoma')),
                                 d=float(neuroninfo.get('d')),
                                 capacitance=float(neuroninfo.get('capacitance')),
                                 k=float(neuroninfo.get('k')),
                                 )
                # Attach dendrites: same-layer dendrites are proximal up to
                # ProximalDendriteNumerPerNeuron, everything else is distal.
                count = 0
                flag = False
                for dlocatelayer, dnum in dendic.items():
                    for j in range(dnum):
                        den = Dendrite()
                        den.locationlayer = dlocatelayer
                        if dlocatelayer != node.locationlayer or flag:
                            den.postion = 'distal'
                            node.distal_dendrites.append(den)
                        else:
                            if count < Global.ProximalDendriteNumerPerNeuron:
                                den.postion = 'proximal'
                                node.proximal_dendrites.append(den)
                                count += 1
                            if count >= Global.ProximalDendriteNumerPerNeuron:
                                flag = True
                # Register the neuron with its brain area (and cortical layer).
                if node.locationlayer == 'T':
                    self.thalamus.neuronsNumber += 1
                    self.thalamus.neurons.append(node)
                    self.thalamus.setNeuronToIndex(node)
                else:
                    self.cortex.neuronsNumber += 1
                    self.cortex.neurons.append(node)
                    self.cortex.setNeuronToIndex(node)
                    la = self.cortex.layers.get(node.locationlayer)
                    la.neuronlist.append(node)
                    if node.name not in la.neuronname:
                        la.neuronname.append(node.name)
                self.neuron.append(node)
                index += 1
        # Create synapses: for every post-synaptic neuron distribute the
        # per-type synapse budget over the presynaptic population.
        for postindex, node in enumerate(self.neuron):  # this step can be optimized
            synapsedic = Global.data.getSynapseData(node.name)
            count = 0
            for r, item in synapsedic.items():
                for s in self.neuronname:
                    totalsynapsenum = round(item.get('synapsenum') * item.get(s) / 100.0)
                    if postindex == 0:
                        count += totalsynapsenum
                    if totalsynapsenum > 0:
                        # This neuron receives synapses from type s: spread
                        # them over the s population.
                        info = self.totaldata.get(s)
                        preneuronnum = round(self.neuronnumscale * int(info.get('neunum')) / 100)
                        avgnum = lastnum = 0
                        if preneuronnum == 1:
                            avgnum = lastnum = totalsynapsenum
                        else:
                            avgnum = totalsynapsenum // (preneuronnum - 1)
                            lastnum = totalsynapsenum - (avgnum * (preneuronnum - 1))
                        preneuronindex = self.neurontoindex.get(s)
                        for j in range(preneuronindex, preneuronindex + preneuronnum):
                            if j == postindex:
                                # No self-connection; fold its share into the last neuron.
                                lastnum += avgnum
                                continue
                            preneuron = self.neuron[j]
                            synapselist = []
                            if preneuron.adjneuronlist.get(node) != None:
                                synapselist = preneuron.adjneuronlist.get(node)
                            if j == preneuronindex + preneuronnum - 1:
                                # The last neuron takes the remainder.
                                for t in range(lastnum):
                                    synapse = Synapse(self.neuron[j], node, item.get('locationlayer'))
                                    synapselist.append(synapse)
                                    self.synapse.append(synapse)
                                    if item.get('locationlayer') != 'T':
                                        layerinfo = self.cortex.layers.get(item.get('locationlayer'))
                                        layerinfo.synapselist.append(synapse)
                                    if node.locationlayer == 'T':
                                        self.thalamus.synapses.append(synapse)
                                        self.thalamus.synapsesNumber += 1
                                    else:
                                        self.cortex.synapses.append(synapse)
                                        self.cortex.synapsesNumber += 1
                            else:
                                for t in range(avgnum):
                                    synapse = Synapse(self.neuron[j], node, item.get('locationlayer'))
                                    synapselist.append(synapse)
                                    self.synapse.append(synapse)
                                    if item.get('locationlayer') != 'T':
                                        layerinfo = self.cortex.layers.get(item.get('locationlayer'))
                                        layerinfo.synapselist.append(synapse)
                                    if node.locationlayer == 'T':
                                        self.thalamus.synapses.append(synapse)
                                        self.thalamus.synapsesNumber += 1
                                    else:
                                        self.cortex.synapses.append(synapse)
                                        self.cortex.synapsesNumber += 1
                            if preneuron.adjneuronlist.get(node) == None and len(synapselist) > 0:
                                preneuron.adjneuronlist[node] = synapselist
        # set these synapses to dendrite list
        #self.setSynapsesToDendrites()

    def setSynapsesToDendrites(self):
        """Place every created synapse onto a dendrite of its post neuron.

        Synapses from TCs to ss4(L4) must sit on proximal dendrites of
        ss4(L4); otherwise a synapse is tried proximally only when it shares
        the post neuron's layer, with distal dendrites as the fallback.
        """
        for node in self.neuron:
            if node.name == 'TCs':
                for post, synlist in node.adjneuronlist.items():
                    if post.name == 'ss4(L4)':
                        for syn in synlist:
                            flag = post.addSynapseToDendrite('proximal', syn)
                            if not flag:
                                flag = post.addSynapseToDendrite('distal', syn)
                            if not flag:
                                print('all dendrites are full in neuron' + post.name + '_' + str(node.totalindex))
                    else:
                        for syn in synlist:
                            flag = False
                            if post.locationlayer == syn.locationlayer:
                                flag = post.addSynapseToDendrite('proximal', syn)
                                if not flag:
                                    flag = post.addSynapseToDendrite('distal', syn)
                            else:
                                flag = post.addSynapseToDendrite('distal', syn)
                            if not flag:
                                print('all dendrites are full in neuron' + post.name + '_' + str(node.totalindex))
            else:
                for post, synlist in node.adjneuronlist.items():
                    for syn in synlist:
                        flag = False
                        if post.locationlayer == syn.locationlayer:
                            flag = post.addSynapseToDendrite('proximal', syn)
                        if not flag:
                            flag = post.addSynapseToDendrite('distal', syn)
                        if not flag:
                            print('all dendrites are full in neuron' + node.name + '_' + str(node.totalindex))

    def setCortexProperties(self):
        """Delegate to the cortex object."""
        self.cortex.setCortexProperties()

    def setThalamusProperties(self):
        """Delegate to the thalamus object (zeroes the TC relay weights)."""
        self.thalamus.setThalamusProperties()

    def CreateCortexNetwork(self):
        """Build the whole network: counts, layers, neurons and synapses."""
        self.setSynapseNum()
        self.cortex.setLayers()
        self.setNeuronsDendritesAndSynapes()
        self.setThalamusProperties()

    # ----------API of the whole network--------------#
    def getTotalNeuronNumber(self):
        return len(self.neuron)

    def getTotalSynapseNumber(self):
        return len(self.synapse)

    def getCortexNeuronNumber(self):
        # Fix: previously returned self.corticalneuronnumber, which never existed.
        return self.cortex.neuronsNumber

    def getThalamoNeuronNumber(self):
        # Fix: previously returned self.thamlamoneuronnumber, which never existed.
        return self.thalamus.neuronsNumber

    def getSpecifiedNeuronNumber(self, name):
        """Return {type name: scaled neuron count} for one type or 'all'."""
        result = {}
        if name in self.neuronname:
            info = self.totaldata.get(name)
            num = self.neuronnumscale * info.get('neunum') / 100
            result[name] = num
        elif name == 'all':
            for r, info in self.totaldata.items():
                # Fix: 'neunum' is a percentage; divide by 100 like the
                # single-name branch above.
                num = self.neuronnumscale * info.get('neunum') / 100
                result[r] = num
        return result

    def getNeuronTypesNumber(self):
        return len(self.neuronname)

    def getNeuronTypes(self):
        return self.neuronname

    def getCorticalSynapseNumber(self):
        # Fix: previously referenced self.corticalsynapse, which never existed.
        return len(self.cortex.synapses)

    def getThalamoSynapseNumber(self):
        # Fix: previously counted the cortical list (copy-paste bug) via a
        # nonexistent attribute.
        return len(self.thalamus.synapses)

    def getPreAndPostNeuronsOfSynapse(self, index):
        """Return (pre, post) of synapse *index*, or None when out of range.

        Fix: the original guard read ``len(self.synapse -1)`` (a TypeError on
        any call) and combined the bounds with ``or``.
        """
        if 0 <= index < len(self.synapse):
            return self.synapse[index].pre, self.synapse[index].post
        return None

    # --------------API of the layer----------------------#
    def getCortexLayerNeuronNumber(self, layername):
        layerinfo = self.layer.get(layername)
        return layerinfo.getLayerNeuronNumber()

    def getCortexLayerSynapseNumber(self, layername):
        layerinfo = self.layer.get(layername)
        return layerinfo.getLayerSynapseNumber()

    def getCortexLayerNeuronTypes(self, layername):
        layerinfo = self.layer.get(layername)
        if layerinfo == None:
            print(layername + " is not in Cortex!")
            return None
        return layerinfo.neuronname

    def getCortexLayerPreAndPostNeuronsOfSynapse(self, layername, index):
        layerinfo = self.layer.get(layername)
        if index >= 0 and index < len(layerinfo.synapselist):
            return layerinfo.synapselist[index].pre, layerinfo.synapselist[index].post

    def getNeuronAllPreNeuronsTypes(self, index):
        if index >= 0 and index < len(self.neuron):
            node = self.neuron[index]
            return node.getWholePreSynapseNeuronType()

    def outputNeuronInfo(self):
        """Write index/name/morphology/excitability/layer per neuron to neuron.csv."""
        with open('neuron.csv', 'w') as f:
            f.write('index,name,morphology,excitability,locationlayer\n')
            for node in self.neuron:
                flag = 'No'
                if node.excitability == "TRUE":
                    flag = 'Yes'
                f.write(str(node.totalindex) + ',' + node.name + ',' + node.morphology + ',' + flag + ',' + node.locationlayer + '\n')

    def outputConnectionMatrix(self):
        """Write the neuron-to-neuron synapse-count matrix to connection.csv."""
        M = len(self.neuron)
        matrix = [[0 for col in range(M)] for row in range(M)]
        for node in self.neuron:
            # Renamed loop variable: the original shadowed the builtin `list`.
            for post, synlist in node.adjneuronlist.items():
                matrix[node.totalindex][post.totalindex] = len(synlist)
        with open('connection.csv', 'w') as f:
            name = ''
            for node in self.neuron:
                name += node.name + ','
            f.write(name + '\n')
            for row in range(M):
                line = ''
                for col in range(M):
                    line += str(matrix[row][col]) + ','
                f.write(line + '\n')

    def outputsynapspercent(self, namelist):
        """Tally synapses whose adjacency target's name is in *namelist*,
        grouped by the owning node's layer, and print the tally."""
        totalcount = 0
        slist = {'L1': 0, 'L2/3': 0, 'L4': 0, 'L5': 0, 'L6': 0, 'T': 0}
        for name in namelist:
            for node in self.neuron:
                # Renamed loop variable: the original shadowed the builtin `list`.
                for pre, synlist in node.adjneuronlist.items():
                    if pre.name == name:
                        totalcount = totalcount + len(synlist)
                        loclayer = node.locationlayer
                        slist[loclayer] = slist.get(loclayer) + len(synlist)
        print(slist)

    # -----------------------running the whole network---------------------------#
    def run(self):
        '''
        run the cortical system: stimulate the neurons of layer L4
        '''
        L = self.cortex.layers.get('L4')
        L.stimulateNeuronInLayer4_BFS(30, self.neuron)

    def outputSpikeThreashold(self):
        """Write each neuron's spike peak Vpeak to threashold.csv."""
        with open('threashold.csv', 'w') as f:
            for node in self.neuron:
                f.write(str(node.Vpeak) + '\n')
================================================
FILE: examples/Multiscale_Brain_Structure_Simulation/CorticothalamicColumn/model/dendrite.py
================================================
'''
Created on 2015.5.19
@author: liangqian
'''
import sys
import sys
from data.globaldata import *
class Dendrite():
    """A single dendrite of a neuron: an ordered bag of synapses in one layer.

    Capacity is bounded by the global ``SynapseNumberPerDendrite``.
    """

    def __init__(self):
        """Start with no synapses, no layer and no soma-relative position."""
        self.synapses = []       # synapses attached to this dendrite
        self.locationlayer = ''  # layer this dendrite sits in
        self.postion = ''        # distance from soma: 'proximal' or 'distal'

    def setSynapse(self, syn):
        """Attach *syn* if there is room; return True on success.

        Returns False when the dendrite already holds the maximum number of
        synapses allowed per dendrite.
        """
        if len(self.synapses) < SynapseNumberPerDendrite:
            self.synapses.append(syn)
            return True
        return False

    def getSynapseInfo(self, f, nodename, denpos):
        """Dump every attached synapse to file *f* via Synapse.getInfo."""
        for entry in self.synapses:
            entry.getInfo(f, nodename, denpos, self.locationlayer)
================================================
FILE: examples/Multiscale_Brain_Structure_Simulation/CorticothalamicColumn/model/fire.csv
================================================
================================================
FILE: examples/Multiscale_Brain_Structure_Simulation/CorticothalamicColumn/model/layer.py
================================================
'''
Created on 2014.11.13.
@author: Liang Qian
'''
import sys
sys.setrecursionlimit(1000000)
from data import globaldata as Global
class Layer():
    '''
    layer class is used to build brain cortical layer of cortex
    class members are as follows:
    @param name: layer name(L1,L2,L3,.etc)
    @param neuronnum: total neural numbers of this layer
    @param neuraltype: neural type list of this layer
    @param neuronlist: different types of neuron list in this layer
    '''
    def __init__(self):
        self.name = ''           # layer name, e.g. 'L1', 'L2/3'
        self.neuronnum = 0       # planned neuron count (scaled from the layer table)
        self.synapsenum = 0      # planned synapse count (scaled from the layer table)
        self.neuronlist = []     # neuron objects actually placed in this layer
        self.synapselist = []    # synapse objects located in this layer
        self.neuronname = []     # distinct neuron type names present in this layer

    def getLayerNeuronNumber(self):
        # Neurons actually placed here (may differ from the planned self.neuronnum).
        return len(self.neuronlist)

    def getLayerSynapseNumber(self):
        # Synapses actually located here (may differ from the planned self.synapsenum).
        return len(self.synapselist)

    def getLayerNeuronTypes(self):
        # Names of the neuron types present in this layer.
        return self.neuronname

    def stimulateNeuronInLayer4_BFS(self, T, neulist):
        # Drive one 'ss4(L2/3)' neuron of this layer with a DC step for T
        # seconds at 1 ms resolution, and log every spike in the network to
        # the shared output file Global.f (closed at the end).
        # NOTE(review): if this layer holds no 'ss4(L2/3)' neuron, `node`
        # silently ends up as the last neuron of the list (or raises
        # NameError on an empty layer) -- confirm intended behaviour.
        for node in self.neuronlist:
            if(node.name == 'ss4(L2/3)'):
                break;
        step = int(T*1000/1.0)  # number of 1 ms simulation steps
        dc = 0
        for i in range(step):
            strs = str(i)+','  # log line: step index then every spiking neuron
            if(i > 1):
                # Let every other neuron with positive drive integrate too.
                for n in neulist:
                    if(n.totalindex == node.totalindex):continue
                    if(n.dc > 0):
                        n.integral(n.dc)
                        n.calc_spike()
                        if(n.spike == 1):
                            strs +=n.name+':'+str(n.totalindex)+','
            # Stimulus window: current of 400 between steps 10 and 25000.
            if(i < 10 or i > 25000):
                dc = 0
            else:
                dc = 400
            node.integral(dc)
            node.calc_spike()
            if(node.spike == 1):
                strs +=node.name+':'+str(node.totalindex)+','
            Global.f.write(strs+'\n')
        Global.f.close()
================================================
FILE: examples/Multiscale_Brain_Structure_Simulation/CorticothalamicColumn/model/synapse.py
================================================
'''
Created on 2014.11.13
@author: Liang Qian
'''
class Synapse():
    """A directed connection from a pre-synaptic to a post-synaptic neuron.

    Attributes: the two neuron objects, the layer the synapse sits in, its
    current ``I`` and its initial ``weight``.
    """

    def __init__(self, pre, post, locationlayer):
        """Wire *pre* -> *post* inside *locationlayer*.

        The weight starts at 0 only for p2/3 -> p2/3 connections and at -1
        for every other pair.
        """
        self.pre = pre
        self.post = post
        self.locationlayer = locationlayer
        self.I = 0
        if pre.name == 'p2/3' and post.name == 'p2/3':
            self.weight = 0
        else:
            self.weight = -1

    def getInfo(self, f, nodename, denpos, denlayer):
        """Append a one-line description of this synapse to file *f*."""
        f.write('neuron:'+nodename+','+'dendrite:'+denpos+','+'den_layer:'+denlayer+','
                +'syn_Layer:'+ self.locationlayer+','
                + 'pre_neuron:'+self.pre.name+','+'pre_neuron_index:'+str(self.pre.totalindex)+','
                + 'post_neuron:'+self.post.name+','+'post_neuron_index:'+str(self.post.totalindex)+','
                + 'weight:'+str(self.weight)+'\n')
================================================
FILE: examples/Multiscale_Brain_Structure_Simulation/CorticothalamicColumn/model/thalamus.py
================================================
'''
Created on 2015.5.27
@author: Liang Qian
'''
class Thalamus():
    """Container for the thalamic neurons and synapses of the column."""

    def __init__(self):
        """Start with an empty thalamus."""
        self.neuronsNumber = 0
        self.synapsesNumber = 0
        self.neurons = []        # every thalamic neuron
        self.synapses = []       # every thalamic synapse
        self.neurontoindex = {}  # neuron type name -> index of its first neuron

    def setNeuronToIndex(self, node):
        """Record where neurons of *node*'s type begin in ``self.neurons``.

        Intended to run right after *node* is appended, hence the
        ``len(self.neurons) - 1`` index; repeated calls for a known name
        are no-ops.
        """
        if self.neurontoindex.get(node.name) is None:
            self.neurontoindex[node.name] = len(self.neurons) - 1

    def setThalamusProperties(self):
        """Zero the weights of the TC relay pathways into the cortex:
        TCs -> ss4(L2/3) and TCn -> p6(L5/6)."""
        print(len(self.synapses))
        for node in self.neurons:
            if node.name == 'TCs':
                target = 'ss4(L2/3)'
            elif node.name == 'TCn':
                target = 'p6(L5/6)'
            else:
                continue
            for post, synlist in node.adjneuronlist.items():
                if post.name == target:
                    for syn in synlist:
                        syn.weight = 0
================================================
FILE: examples/Multiscale_Brain_Structure_Simulation/CorticothalamicColumn/tools/__init__.py
================================================
================================================
FILE: examples/Multiscale_Brain_Structure_Simulation/CorticothalamicColumn/tools/cortical.csv
================================================
neuronname,neuronnum,synapsenum,area
nb1,1.5,8890,cortex
p2/3,26,7106,cortex
b2/3,3.1,3854,cortex
ss4(L4),9.2,5792,cortex
ss4(L2/3),9.2,4989,cortex
p4,9.2,6703,cortex
b4,5.4,3230,cortex
nb4,1.5,3688,cortex
p5(L2/3),4.8,5196,cortex
p5(L5/6),1.3,13075,cortex
b5,0.6,2981,cortex
nb5,0.8,2981,cortex
p6(L4),13.6,6363,cortex
p6(L5/6),4.5,6421,cortex
b6,2,3220,cortex
nb6,2,3220,cortex
nb2/3,4.2,3307,cortex
TCs,0.5,4000,thalamus
TCn,0.5,4000,thalamus
TIs,0.1,3000,thalamus
TIn,0.1,3000,thalamus
TRN,0.5,4000,thalamus
================================================
FILE: examples/Multiscale_Brain_Structure_Simulation/CorticothalamicColumn/tools/exdata.py
================================================
import os
import csv
#import pandas as pd
class EXDATA():
    """Loader for the CSV tables in ./tools that describe the column.

    Every getter re-reads its file on each call and returns plain dicts.
    All paths are relative to the process working directory.
    """

    def __init__(self):
        # Fix: was misspelled ``__int__``, so it never ran as a constructor.
        # There is no state to set up; each getter opens its own file.
        pass

    def getCortexData(self):
        """Return {neuron name: {'neuronname', 'neunum', 'synnum', 'area'}}.

        Parses ./tools/cortical.csv, skipping the header row and stopping at
        the first empty line.
        """
        neurondic = {}
        with open("./tools/cortical.csv", "r") as f:
            f.readline()  # discard the header row
            while True:
                line = f.readline().strip()
                if len(line) <= 0:
                    break
                strs = line.split(",")
                info = {}
                info['neuronname'] = strs[0].strip()
                info['neunum'] = float(strs[1])
                info['synnum'] = int(strs[2])
                info['area'] = str(strs[3])
                neurondic[strs[0].strip()] = info
        print(neurondic)
        return neurondic

    def getCortexData2(self):
        """Debug-only pandas variant.

        NOTE(review): ``pd`` is not imported (the pandas import at the top of
        the file is commented out), so calling this raises NameError until
        pandas is restored.
        """
        neurondic = {}
        data = pd.read_csv("../tools/cortical.csv")
        print('debug')

    def getLayerData(self):
        """Return {row index: {column name: value}} from ./tools/layer.csv.

        Every column after the first is converted to float.
        """
        layerdic = {}
        with open("./tools/layer.csv", "r") as f:
            # Renamed from `str`: the original shadowed the builtin.
            fields = f.readline().strip().split(",")
            count = 0
            while True:
                line = f.readline().strip()
                if len(line) <= 0:
                    break
                v = line.split(",")
                info = {}
                for i in range(len(fields)):
                    if i > 0:
                        v[i] = float(v[i])
                    info[fields[i]] = v[i]
                layerdic[count] = info
                count += 1
        return layerdic

    def getNeuronData(self):
        """Return {neuron name: {column name: value}} from ./tools/neuron.csv.

        Columns after index 4 are converted to float; a row is cut short at
        its first empty cell.
        """
        neurondic = {}
        with open("./tools/neuron.csv", "r") as f:
            fields = f.readline().strip().split(",")
            while True:
                line = f.readline().strip()
                if len(line) <= 0:
                    break
                v = line.split(",")
                info = {}
                for i in range(len(fields)):
                    if len(v[i]) <= 0:
                        break  # stop at the first empty cell of the row
                    if i > 4:
                        v[i] = float(v[i])
                    info[fields[i].strip()] = v[i]
                neurondic[v[0].strip()] = info
        return neurondic

    def getSynapseData(self, postneuron):
        """Return the synapse rows for *postneuron*, keyed 0..n-1, or None.

        Rows come from ./tools/synapse.csv; every column after the second is
        converted to float and rows are grouped by the post-synaptic neuron
        name in column 0.
        """
        synapsemap = {}
        with open("./tools/synapse.csv", "r") as f:
            fields = f.readline().strip().split(",")
            while True:
                line = f.readline().strip()
                if len(line) <= 0:
                    break
                strs = line.split(",")
                info = {}
                for i, v in enumerate(fields):
                    if i > 1:
                        strs[i] = float(strs[i])
                    info[v] = strs[i]
                syndic = synapsemap.setdefault(strs[0], {})
                syndic[len(syndic)] = info
        return synapsemap.get(postneuron)
#tmp = EXDATA()
#tmp.getCortexData()
#tmp.getLayerData()
#tmp.getNeuronData()
#result = tmp.getSynapseData('p4')
#print(result)
================================================
FILE: examples/Multiscale_Brain_Structure_Simulation/CorticothalamicColumn/tools/layer.csv
================================================
name,neuronnum,synapsenum,nb1,p2/3,b2/3,nb2/3,ss4(L4),ss4(L2/3),p4,b4,nb4,p5(L2/3),p5(L5/6),b5,nb5,p6(L5/6),b6,p6(L4),nb6
L1,1.5,10.86,1.5,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
L2/3,33.3,32.86,0,26,3.1,4.2,0,0,0,0,0,0,0,0,0,0,0,0,0
L4,34.5,34.04,0,0,0,0,9.2,9.2,9.2,5.4,1.5,0,0,0,0,0,0,0,0
L5,7.5,8.1,0,0,0,0,0,0,0,0,0,4.8,1.3,0.6,0.8,0,0,0,0
L6,22.1,14.14,0,0,0,0,0,0,0,0,0,0,0,0,0,4.5,2,13.6,2
================================================
FILE: examples/Multiscale_Brain_Structure_Simulation/CorticothalamicColumn/tools/neuron.csv
================================================
name,morphology,location layer,spiketype,excitability,percent,capacitance,k ,Vr ,Vt ,Vpeak ,Gup,Gdown,a ,b ,Csoma,Cdendr,d
nb1,non-basket,L1,LS,FALSE,1.5,20,0.3,-66,-40,30,0.6,2.5,0.17,5,-45,-45,100
p2/3,pyramidal,L2/3,RS,TRUE,26,100,3,-60,-50,50,3,5,0.01,5,-60,-55,400
b2/3,basket,L2/3,FS,FALSE,3.1,20,1,-55,-40,25,0.5,1,0.15,8,-55,-55,200
nb2/3,non-basket,L2/3,LTS,FALSE,4.2,100,1,0,-42,40,1,1,0.03,8,-50,-50,20
ss4(L4),spiny stell,L4,RS,TRUE,9.2,100,3,-60,-50,50,3,5,0.01,5,-60,-50,400
ss4(L2/3),spiny stell,L4,RS,TRUE,9.2,100,3,-60,-50,50,3,5,0.01,5,-60,-50,400
p4,pyramidal,L4,RS,TRUE,9.2,100,3,-60,-50,50,3,5,0.01,5,-60,-50,400
b4,basket,L4,FS,FALSE,5.4,20,1,-55,-40,25,0.5,1,0.15,8,-55,-55,200
b5,basket,L5,FS,FALSE,0.6,20,1,-55,-40,25,0.5,1,0.15,8,-55,-55,200
nb4,non-basket,L4,LTS,FALSE,1.5,100,1,-55,-42,40,1,1,0.03,8,-50,-50,20
p5(L2/3),pyramidal,L5,RS,TRUE,4.8,100,3,-60,-50,50,3,5,0.01,5,-60,-50,400
nb5,non-basket,L5,LTS,FALSE,0.8,100,1,-55,-42,40,1,1,0.03,8,-50,-50,20
p5(L5/6),pyramidal,L5,RS,TRUE,1.3,100,3,-60,-50,50,3,5,0.01,5,-60,-50,400
p6(L4),pyramidal,L6,RS,TRUE,13.6,100,3,-60,-50,50,3,5,0.01,5,-60,-50,400
b6,basket,L6,FS,FALSE,2,20,1,-55,-40,25,0.5,1,0.15,8,-55,-55,200
nb6,non-basket,L6,FS,FALSE,2,100,1,0,-42,40,1,1,0.03,8,-50,-50,20
TCs,TC in specific nucleus,T,TC,TRUE,0.5,200,1.6,-60,-50,40,2,2,0.1,15,-60,-60,10
TCn,TC in non specific nucleus,T,TC,TRUE,0.5,200,1.6,-60,-50,40,2,2,0.1,15,-60,-60,10
TIs,thalamic in specific nucleus,T,TI,FALSE,0.1,20,0.5,-60,-50,20,5,5,0.05,7,-65,-65,50
TIn,thalamic in non specific nucleus,T,TI,FALSE,0.1,20,0.5,-60,-50,20,5,5,0.05,7,-65,-65,50
TRN,GABAergic,T,TRN,TRUE,0.5,40,0.25,-65,-45,0,5,5,0.015,10,-55,-55,50
p6(L5/6),pyramidal,L6,RS,TRUE,4.5,100,3,-60,-50,50,3,5,0.01,5,-60,-50,400
================================================
FILE: examples/Multiscale_Brain_Structure_Simulation/CorticothalamicColumn/tools/synapse.csv
================================================
postneuron,locationlayer,synapsenum,nb1,p2/3,b2/3,nb2/3,ss4(L4),ss4(L2/3),p4,b4,nb4,p5(L2/3),p5(L5/6),b5,nb5,p6(L4),p6(L5/6),b6,nb6,TCs,TCn,TIs,TIn,TRN
nb1,L1,8890,10.1,6.3,0.6,1.1,0,0,0.1,0,0,0.1,0,0,0,0,0,0,0,0,4.1,0,0,0
p2/3,L2/3,5800,0,59.9,9.1,4.4,0.6,6.9,7.7,0,0.8,7.4,0,0,0,2.3,0,0,0.8,0,0,0,0,0
p2/3,L1,1306,10.2,6.3,0.1,1.1,0,0,0.1,0,0,0.1,0,0,0,0,0,0,0,0,4.1,0,0,0
b2/3,L2/3,3854,1.3,51.6,10.6,3.4,0.5,5.8,6.6,0,0.8,6.3,0,0,0,2.1,0,0,0.7,0,0.5,0,0,0
nb2/3,L2/3,3307,1.7,48.6,11.4,3.3,0.5,5.5,6.2,0,0.8,5.9,0,0,0,1.8,0,0,0.6,0,0.7,0,0,0
ss4(L4),L4,5792,0,2.7,0.2,0.6,11.9,3.7,4.1,7.1,2,0.8,0.1,0,0,32.7,0,0,5.8,1.7,1.3,0,0,0
ss4(L2/3),L4,4989,0,5.6,0.4,0.8,11.3,3.8,4.3,7.2,2.1,1.1,0.1,0,0,31.1,0,0,5.5,1.7,1.3,0,0,0
p4,L4,5031,0,4.3,0.2,0.6,11.5,3.6,4.2,7.2,2.1,1.2,0.1,0,0,31.4,0.1,0,5.9,1.7,1.3,0,0,0
p4,L2/3,866,0,63.1,5.1,4.1,0.6,7.2,8.1,0,0.6,7.8,0,0,0,2.5,0,0,0.8,0,0,0,0,0
p4,L1,806,10.2,6.3,0.1,1.1,0,0,0.1,0,0,0.1,0,0,0,0,0,0,0,0,4.1,0,0,0
b4,L4,3230,0,5.8,0.5,0.8,11,3.8,4.2,8.4,2.4,1.1,0,0,0,30.3,0,0,5.4,1.6,1.2,0,0,0
nb4,L4,3688,0,2.7,0.2,0.6,11.7,3.6,4,8.2,2.3,0.8,0.1,0,0,32.2,0,0,5.7,1.7,1.3,0,0,0
p5(L2/3),L5,4316,0,45.9,1.8,0.3,3.3,2,7.5,0,0.9,11.7,1,0.8,1.1,2.3,2.1,0,11.5,0.1,0.4,0,0,0
p5(L2/3),L4,283,0,2.8,0.1,0.7,12.2,3.8,4.2,5.2,1.5,0.8,0.1,0,0,33.7,0,0,5.9,1.8,1.4,0,0,0
p5(L2/3),L2/3,412,0,63.1,5.1,4.1,0.6,7.2,8.1,0,0.6,7.8,0,0,0,2.5,0,0,0.8,0,0,0,0,0
p5(L2/3),L1,185,10.2,6.3,0.1,1.1,0,0,0.1,0,0,0.1,0,0,0,0,0,0,0,0,4.1,0,0,0
p5(L5/6),L5,5101,0,44.3,1.7,0.2,3.2,2,7.3,0,0.8,11.3,1.2,0.8,1.1,2.3,2.5,0.3,11.3,0.2,0.5,0,0,0
p5(L5/6),L4,949,0,2.8,0.1,0.7,12.2,3.8,4.2,5.2,1.5,0.8,0.1,0,0,33.7,0,0,5.9,1.8,1.4,0,0,0
p5(L5/6),L2/3,1367,0,63.1,5.1,4.1,0.6,7.2,8.1,0,0.6,7.8,0,0,0,2.5,0,0,0.8,0,0,0,0,0
p5(L5/6),L1,5658,10.2,6.3,0.1,1.1,0,0,0.1,0,0,0.1,0,0,0,0,0,0,0,0,4.1,0,0,0
b5,L5,2981,0,45.5,2.3,0.2,3.3,2,7.5,0,1.1,11.6,1,0.9,1.3,2.3,2,0,11.4,0.1,0.4,0,0,0
nb5,L5,2981,0,45.5,2.3,0.2,3.3,2,7.5,0,1.1,11.6,1,0.9,1.3,2.3,2,0,11.4,0.1,0.4,0,0,0
p6(L4),L6,3261,0,2.5,0.1,0.1,0.7,0.9,1.3,0,0.1,0.1,4.9,0,0.3,1.2,13.2,7.7,7.7,1.6,2.9,0,0,0
p6(L4),L5,1066,0,46.8,0.8,0.3,3.4,2.1,7.7,0,0.6,11.9,1,0.6,0.8,2.3,2.1,0,11.7,0.1,0.4,0,0,0
p6(L4),L4,1915,0,2.8,0.1,0.7,12.2,3.8,4.2,5.2,1.5,0.8,0.1,0,0,33.7,0,0,5.9,1.8,1.4,0,0,0
p6(L4),L2/3,121,0,63.1,5.1,4.1,0.6,7.2,8.1,0,0.6,7.8,0,0,0,2.5,0,0,0.8,0,0,0,0,0
p6(L5/6),L6,5573,0,2.5,0.1,0.1,0.7,0.9,1.3,0,0.1,0.1,4.9,0,0.3,1.2,13.2,7.8,7.8,0,2.9,0,0,0
p6(L5/6),L5,257,0,46.8,0.8,0.3,3.4,2.1,7.7,0,0.6,11.9,1,0.6,0.8,2.3,2.1,0,11.7,0.6,0.4,0,0,0
p6(L5/6),L4,243,0,2.8,0.1,0.7,12.2,3.8,4.2,5.2,1.5,0.8,0.1,0,0,33.7,0,0,5.9,0,1.4,0,0,0
p6(L5/6),L2/3,286,0,63.1,5.1,4.1,0.6,7.2,8.1,0,0.6,7.8,0,0,0,2.5,0,0,0.8,0,0,0,0,0
p6(L5/6),L1,62,10.2,6.3,0.1,1.1,0,0,0.1,0,0,0.1,0,0,0,0,0,0,0,0,4.1,0,0,0
b6,L6,3220,0,2.5,0.1,0.1,0.7,0.9,1.3,0,0.1,0.1,4.9,0,0.4,1.2,13.2,7.7,7.7,0.6,2.9,0,0,0
nb6,L6,3220,0,2.5,0.1,0.1,0.7,0.9,1.3,0,0.1,0.1,4.9,0,0.4,1.2,13.2,7.7,7.7,0.6,2.9,0,0,0
TCs,T,4000,0,0,0,0,0,0,0,0,0,0,0,0,0,23,8,0,0,0,0,5,0,25.9
TCn,T,4000,0,0,0,0,0,0,0,0,0,14,3.8,0,0,0,13.2,0,0,0,0,0,5,25.9
TIs,T,3000,0,0,0,0,0,0,0,0,0,0,0,0,0,9.8,3.3,0,0,0.4,0,24.4,0,0
TIn,T,3000,0,0,0,0,0,0,0,0,0,5.8,1.6,0,0,0,5.4,0,0,0,0.6,0,24.4,0
TRN,T,4000,0,0,0,0,0,0,0,0,0,0,0,0,0,30,0,0,0,10,10,0,0,10
================================================
FILE: examples/Multiscale_Brain_Structure_Simulation/Corticothalamic_Brain_Model/Bioinformatics_propofol_circle.py
================================================
import numpy as np
import random
import math
import matplotlib.pyplot as plt
import scipy.io as scio
import pandas as pd
import torch
device = 'cpu'
trail = 0
class brain_model_91():
    def __init__(self, W, D):
        """Set up the multi-region, 5-population (E/I cortical plus
        TC/TI/TRN thalamic) network state.

        Args:
            W: region-to-region weight matrix (torch tensor; NR = len(W)).
            D: region-to-region distance matrix, used to derive conduction
               delays.
        """
        self.weight_matrix = W.to(device)
        self.distance_matrix = D.to(device)
        # Append one extra row (mean distance per region) and a zero column
        # for the thalamic compartment, then copy the new row into the new
        # column so the matrix stays symmetric for that entry.
        # NOTE(review): the 0:29 / 29 indices look hard-coded for a 30-entry
        # matrix -- confirm they match len(W) here.
        self.distance_matrix = torch.vstack((self.distance_matrix,
                                             torch.mean(self.distance_matrix, dim=1)))
        self.distance_matrix = torch.hstack((self.distance_matrix,
                                             torch.zeros((len(self.distance_matrix), 1), device=device)))
        temp = self.distance_matrix.clone()
        self.distance_matrix[0:29, 29] = temp[29, 0:29]
        self.speed = 1.5  # conduction speed; delay = distance / speed
        self.decay = torch.ceil(self.distance_matrix / self.speed)  # per-pair delay in steps
        self.t_window = int(torch.max(self.decay)) + 1  # history-buffer length covering the largest delay
        # Per-population parameters, ordered [E, I, TC, TI, TRN].
        self.V_th = torch.tensor([10., 4., 10., 4., 4.], device=device)       # spike thresholds
        self.tau_v = torch.tensor([40., 10., 200., 20., 40.], device=device)  # membrane time constants
        self.Tsig = torch.tensor([12., 10., 12., 10., 10.], device=device)    # std of the background drive
        self.beta = torch.tensor([0., 4.5, 0., 4.5, 4.5], device=device)      # adaptation increments
        self.alpha_ad = torch.tensor([0., -2., 0., -2., -2.], device=device)  # adaptation coupling
        self.tau_ad = 20  # adaptation time constant
        self.tau_I = 10   # background-current time constant
        # Simulation parameters
        self.NR = len(W)   # number of regions
        self.NN = 500      # neurons per cortical region
        self.NType = self.NN * torch.tensor([0.79, 0.20, 0.01, 0.005, 0.002], device=device)
        self.NE = int(self.NType[0])             # excitatory cells per region
        self.NI = int(self.NType[1])             # inhibitory cells per region
        self.NTC = int(self.NR * self.NType[2])  # thalamocortical relay cells (whole network)
        self.NTI = int(self.NR * self.NType[3])  # thalamic interneurons (whole network)
        self.NTRN = int(self.NR * self.NType[4]) # reticular-nucleus cells (whole network)
        self.NC = self.NE + self.NI              # cortical cells per region
        self.NSum = int((self.NR - 1) * (self.NE + self.NI) + self.NTC + self.NTI + self.NTRN)
        self.Ncycle = 1     # number of simulation repetitions
        self.dt = 1         # integration step (presumably ms -- confirm)
        self.T = 20000      # steps per cycle
        self.Delta_T = 0.5
        self.gamma_c = 0.1  # gap-junction coupling fraction
        self.g_m = 1        # leak conductance
        self.Gama_c = self.g_m * self.gamma_c / (1 - self.gamma_c)  # effective gap conductance
        # Population-to-population chemical coupling gains.
        self.GammaII = 15
        self.GammaIE = -10
        self.GammaEE = 15
        self.GammaEI = 15
        self.TEmean = 0.5 * self.V_th[0]  # Mean current to excitatory neurons
        self.TTCmean = 0.5 * self.V_th[2]  # Mean current to TC neurons
        self.TImean = -5 * self.V_th[1]    # Mean (hyperpolarizing) current to I neurons
        self.TTImean = -5 * self.V_th[3]
        self.TTRNmean = -5 * self.V_th[4]
        # Flat per-neuron state vectors over all NSum cells; their exact
        # roles are defined by simulation() (only partially visible here).
        self.v = torch.zeros(self.NSum, device=device)        # membrane potential
        self.vt = torch.zeros(self.NSum, device=device)       # per-neuron threshold (filled below)
        self.c_m = torch.zeros(self.NSum, device=device)      # per-neuron capacitance (filled below)
        self.alpha_w = torch.zeros(self.NSum, device=device)  # per-neuron adaptation coupling (filled below)
        self.beta_ad = torch.zeros(self.NSum, device=device)  # per-neuron adaptation increment (filled below)
        self.delta = torch.ones(self.NSum, device=device)
        self.ad = torch.zeros(self.NSum, device=device)       # adaptation variable
        self.vv = torch.zeros(self.NSum, device=device)       # presumably spike indicator -- confirm in simulation()
        self.Iback = torch.zeros(self.NSum, device=device)    # filtered background noise (updated in simulation())
        self.Ieff = torch.zeros(self.NSum, device=device)     # effective input current derived from Iback
        self.Nmean = torch.zeros(self.NSum, device=device)    # per-neuron mean drive (filled below)
        self.Nsig = torch.zeros(self.NSum, device=device)     # per-neuron drive std (filled below)
        self.Igap = torch.zeros(self.NSum, device=device)
        self.Ichem = torch.zeros(self.NSum, device=device)
        self.Ieeg = torch.zeros(self.NSum, device=device)
        self.vm1 = torch.zeros(self.NSum, device=device)
        # Index layout of the flat vectors: regions 0..NR-2 each hold NE
        # excitatory then NI inhibitory cells; the final block holds the
        # thalamic TC, TI and TRN cells.
        self.E_range = []
        self.I_range = []
        self.TC_range = []
        self.TI_range = []
        self.TRN_range = []
        self.divide_point_E = []
        self.divide_point_I = []
        for n in range(self.NR):
            if n < self.NR - 1:
                self.divide_point_E.append(list(range(n * self.NC, n * self.NC + self.NE)))
                self.divide_point_I.append(list(range(n * self.NC + self.NE, n * self.NC + self.NE + self.NI)))
                self.E_range = self.E_range + list(range(n * self.NC, n * self.NC + self.NE))
                self.I_range = self.I_range + list(
                    range(n * self.NC + self.NE, n * self.NC + self.NE + self.NI))
            else:
                self.TC_range = self.TC_range + list(range((self.NR - 1) * self.NC,
                                                           (self.NR - 1) * self.NC + self.NTC))
                self.TI_range = self.TI_range + list(range((self.NR - 1) * self.NC + self.NTC,
                                                           (self.NR - 1) * self.NC + self.NTC + self.NTI))
                self.TRN_range = self.TRN_range + list(range((self.NR - 1) * self.NC + self.NTC + self.NTI,
                                                             (self.NR - 1) * self.NC + self.NTC + self.NTI + self.NTRN))
        self.divide_point_E = torch.tensor(self.divide_point_E, device=device)
        self.divide_point_I = torch.tensor(self.divide_point_I, device=device)
        self.divide_point_CR = torch.concat((self.divide_point_E, self.divide_point_I), dim=1)
        # Persist the population split so analysis scripts can reuse it.
        torch.save({'divide_point_E': self.divide_point_E, 'divide_point_I': self.divide_point_I,
                    'TC_range': self.TC_range, 'TI_range': self.TI_range, 'TRN_range': self.TRN_range},
                   './neuron_divide.pt')
        # Fill the per-neuron parameter vectors population by population.
        # Inhibitory-type populations (I, TI, TRN) include the gap-junction
        # conductance Gama_c in their effective leak; excitatory E cells get
        # Gaussian jitter on their capacitance.
        self.c_m[self.E_range] = self.tau_v[0] * self.g_m + 5 * torch.randn(len(self.E_range), device=device)
        self.c_m[self.TC_range] = self.tau_v[2] * self.g_m
        self.c_m[self.I_range] = self.tau_v[1] * (self.g_m + self.Gama_c)
        self.c_m[self.TI_range] = self.tau_v[3] * (self.g_m + self.Gama_c)
        self.c_m[self.TRN_range] = self.tau_v[4] * (self.g_m + self.Gama_c)
        self.alpha_w[self.E_range] = self.alpha_ad[0] * self.g_m
        # NOTE(review): this TC line ADDS Gama_c instead of scaling by a
        # conductance like every other population -- confirm it is not meant
        # to be alpha_ad[2] * self.g_m.
        self.alpha_w[self.TC_range] = self.alpha_ad[2] * self.g_m + self.Gama_c
        self.alpha_w[self.I_range] = self.alpha_ad[1] * (self.g_m + self.Gama_c)
        self.alpha_w[self.TI_range] = self.alpha_ad[3] * (self.g_m + self.Gama_c)
        self.alpha_w[self.TRN_range] = self.alpha_ad[4] * (self.g_m + self.Gama_c)
        self.beta_ad[self.E_range] = self.beta[0]
        self.beta_ad[self.TC_range] = self.beta[2]
        self.beta_ad[self.I_range] = self.beta[1]
        self.beta_ad[self.TI_range] = self.beta[3]
        self.beta_ad[self.TRN_range] = self.beta[4]
        self.vt[self.E_range] = self.V_th[0]
        self.vt[self.TC_range] = self.V_th[2]
        self.vt[self.I_range] = self.V_th[1]
        self.vt[self.TI_range] = self.V_th[3]
        self.vt[self.TRN_range] = self.V_th[4]
        self.Nmean[self.E_range] = self.TEmean * self.g_m
        self.Nmean[self.TC_range] = self.TTCmean * self.g_m
        self.Nmean[self.I_range] = self.TImean * (self.g_m + self.Gama_c)
        self.Nmean[self.TI_range] = self.TTImean * (self.g_m + self.Gama_c)
        self.Nmean[self.TRN_range] = self.TTRNmean * (self.g_m + self.Gama_c)
        self.Nsig[self.E_range] = self.Tsig[0] * self.g_m
        self.Nsig[self.TC_range] = self.Tsig[2] * self.g_m
        self.Nsig[self.I_range] = self.Tsig[1] * (self.g_m + self.Gama_c)
        self.Nsig[self.TI_range] = self.Tsig[3] * (self.g_m + self.Gama_c)
        self.Nsig[self.TRN_range] = self.Tsig[4] * (self.g_m + self.Gama_c)
def simulation(self):
    """Run the corticothalamic network for ``Ncycle`` runs of ``T`` steps each.

    Each run re-initialises its recording buffers, integrates the coupled
    chemical / gap-junction currents with delayed inter-regional coupling,
    and saves per-subregion currents and the spike raster under ./result/.
    """
    # Excitatory side: cortical E cells plus thalamocortical (TC) relay cells.
    range_E = self.E_range + self.TC_range
    # Inhibitory side: cortical I cells, thalamic interneurons (TI) and TRN.
    range_I = self.I_range + self.TI_range + self.TRN_range
    Vgap = self.Gama_c  # gap-junction coupling strength for inhibitory cells
    weight_matrix = self.weight_matrix
    for i in range(self.Ncycle):
        # Fresh recording buffers for this run.  Several of these (I_total,
        # V_total, V, I_subregion_E/I, Vsubregion, EEG) are allocated but
        # never written below -- candidates for removal.
        I_total = torch.zeros((self.Ncycle, self.T), device=device)
        V_total = torch.zeros((self.Ncycle, self.T), device=device)
        V = torch.zeros(self.T, device=device)
        I_subregion = torch.zeros((self.NR, self.T), device=device)
        I_subregion_E = torch.zeros((self.NR, self.T), device=device)
        I_subregion_I = torch.zeros((self.NR, self.T), device=device)
        Vsubregion = torch.zeros((self.NR, self.T), device=device)
        EEG = torch.zeros((self.T), device=device)
        Iraster = []  # per-step (t, neuron_id) spike events
        # Sliding windows of recent population spike rates, one row per region.
        vv_sumE = torch.zeros((self.NR, self.t_window), device=device)
        vv_sumI = torch.zeros((self.NR, self.t_window), device=device)
        phase = self.T / 4
        for t in range(self.T):
            #
            # Slow state modulation: ramp the inhibitory membrane time
            # constant and the I->I / I->E couplings across the run.
            # NOTE(review): with phase = T/4 the loop only covers
            # t < 4*phase, so the `4*phase <= t < 6*phase` and
            # `t > 6*phase` branches are unreachable -- confirm T vs phase.
            if t < phase:
                tau_vI = 20
                self.GammaII = 15
                self.GammaIE = -10
            elif phase <= t < 3 * phase:
                tau_vI = 20 + 20 * (t - phase) / phase
                self.GammaII = 30 + 10 * (t - phase) / phase
                self.GammaIE = -20 - 10 * (t - phase) / phase
            elif 3 * phase <= t < 4 * phase:
                tau_vI = 60
                self.GammaII = 50
                self.GammaIE = -40
            elif 4 * phase <= t < 6 * phase:
                tau_vI = 60 - 20 * (t - 4 * phase) / phase
                self.GammaII = 50 - 10 * (t - 4 * phase) / phase
                self.GammaIE = -40 + 20 * (t - 4 * phase) / phase
            elif t > 6 * phase:
                tau_vI = 20
                self.GammaII = 15
                self.GammaIE = -10
            # Refresh inhibitory membrane capacitance with the new tau.
            self.c_m[range_I] = tau_vI * (self.g_m + self.Gama_c)
            # Effective coupling weights, scaled by mean capacitance.
            WII = self.GammaII * torch.mean(self.c_m[self.I_range])
            WEE = self.GammaEE * torch.mean(self.c_m[self.E_range])
            WEI = self.GammaEI * torch.mean(self.c_m[self.I_range])
            WIE = self.GammaIE * torch.mean(self.c_m[self.E_range])
            # Low-pass filtered background noise current.
            self.Iback = self.Iback + self.dt / self.tau_I * (-self.Iback + torch.randn(self.NSum, device=device))
            self.Ieff = self.Iback / math.sqrt(1 / (2 * (self.tau_I / self.dt))) * self.Nsig + self.Nmean
            # Shift the excitatory spike-rate window left by one step and
            # append this step's rates (cortical rows from divide_point_E,
            # the thalamic TC population as the last row).
            temp = vv_sumE.clone()
            vv_sumE[:, 0:self.t_window - 1] = temp[:, 1:self.t_window]
            vv_sumE[:, self.t_window - 1] = torch.cat((torch.mean(self.vv[self.divide_point_E], dim=1),
                                                       torch.mean(self.vv[self.TC_range]).unsqueeze(0)))
            # Same sliding window for the inhibitory populations.
            temp = vv_sumI.clone()
            vv_sumI[:, 0:self.t_window - 1] = temp[:, 1:self.t_window]
            vv_sumI[:, self.t_window - 1] = torch.cat((torch.mean(self.vv[self.divide_point_I], dim=1),
                                                       torch.mean(self.vv[self.TI_range + self.TRN_range]).unsqueeze(0)))
            # Mean inhibitory membrane potential per region, broadcast back to
            # every inhibitory neuron (target of the gap-junction current).
            v_sum = torch.cat((torch.mean(self.v[self.divide_point_I], dim=1),
                               torch.mean(self.v[self.TI_range + self.TRN_range]).unsqueeze(0)))
            v_sum_CR = v_sum[:self.NR - 1].reshape(-1, 1) * \
                torch.ones((self.NR - 1, self.NI), device=device)
            v_sum_CR = v_sum_CR.reshape(-1, 1).squeeze(1)
            v_sum_TN = v_sum[self.NR - 1] * \
                torch.ones(self.NTI + self.NTRN, device=device)
            v_sum = torch.cat((v_sum_CR, v_sum_TN))
            # Per-connection delay indices into the sliding windows: a
            # zero-based ramp stacked with (t_window-1 - decay) offsets.
            time_decay = torch.concat(
                (torch.concat([torch.arange(30, device=device).unsqueeze(0)] * 30, dim=0).unsqueeze(0),
                 self.t_window - 1 - self.decay.unsqueeze(0)), dim=0)
            time_decay = list(time_decay.long())
            # Delayed, connectivity-weighted presynaptic rates per region.
            v_E = torch.sum(weight_matrix * vv_sumE[time_decay], dim=1)
            v_I = torch.sum(weight_matrix * vv_sumI[time_decay], dim=1)
            # Broadcast the regional drives to individual neurons
            # (cortical regions first, thalamic populations last).
            v_E_CR = v_E[:self.NR - 1].reshape(-1, 1) * \
                torch.ones((self.NR - 1, self.NC), device=device)
            v_I_CR = v_I[:self.NR - 1].reshape(-1, 1) * \
                torch.ones((self.NR - 1, self.NC), device=device)
            v_E_CR = v_E_CR.reshape(-1, 1).squeeze(1)
            v_I_CR = v_I_CR.reshape(-1, 1).squeeze(1)
            v_E_TN = v_E[self.NR - 1] * \
                torch.ones(self.NTC + self.NTI + self.NTRN, device=device)
            v_I_TN = v_I[self.NR - 1] * \
                torch.ones(self.NTC + self.NTI + self.NTRN, device=device)
            v_E = torch.cat((v_E_CR, v_E_TN))
            v_I = torch.cat((v_I_CR, v_I_TN))
            # Chemical synaptic currents, low-pass filtered with tau_I.
            self.Ichem[range_E] = self.Ichem[range_E] + self.dt / self.tau_I * \
                (-self.Ichem[range_E] + WEE * v_E[range_E]
                 + WIE * v_I[range_E])
            self.Ichem[range_I] = self.Ichem[range_I] + self.dt / self.tau_I * \
                (-self.Ichem[range_I] + WII * v_I[range_I]
                 + WEI * v_E[range_I])
            # Gap junctions pull inhibitory cells toward the regional mean.
            self.Igap[range_I] = Vgap * (
                v_sum - self.v[range_I])
            # Leaky membrane integration with adaptation + all input currents.
            self.v = self.v + self.dt / self.c_m * (-self.g_m * self.v +
                                                    self.alpha_w * self.ad + self.Ieff + self.Ichem + self.Igap)
            self.ad = self.ad + self.dt / self.tau_ad * (-self.ad + self.beta_ad * self.v)
            # A spike is an upward threshold crossing between two steps.
            self.vv = (self.v >= self.vt).float() * (self.vm1 < self.vt).float()
            self.vm1 = self.v
            Isp = torch.where(self.vv == 1)[0]
            Iraster.append(torch.stack((t * torch.ones((len(Isp)), device=device), Isp), dim=1))
            # Record the mean synaptic current per subregion (thalamus pooled last).
            I_CR = torch.mean(self.Ichem[self.divide_point_CR], dim=1)
            I_TN = torch.mean(self.Ichem[self.TC_range + self.TI_range + self.TRN_range]).unsqueeze(0)
            I_subregion[:, t] = torch.cat((I_CR, I_TN), dim=0)
        print('over')
        # Persist this run's results; the filenames only depend on the
        # module-level `trail` tag, so each cycle overwrites the same files
        # and only the last cycle's data survives.
        torch.save(I_subregion.cpu(), f'./result/I_subregion_2_delay_{trail}.pt')
        Iraster = torch.cat(Iraster, dim=0).cpu()
        torch.save(Iraster, f'./result/raster_2_delay_{trail}.pt')
# ---- Script entry: build the 91-region model from saved connectivity and run it ----
# Connectivity matrix from FLNe.pt, with self-connections (identity) added.
W = torch.tensor(torch.load('./FLNe.pt')['W'], dtype=torch.float32, device=device)
W = W + torch.eye(len(W), device=device)
D = torch.load('./distance.pt')  # presumably inter-regional distances/delays -- TODO confirm
simulation_model = brain_model_91(W, D)
simulation_model.simulation()
================================================
FILE: examples/Multiscale_Brain_Structure_Simulation/Corticothalamic_Brain_Model/Readme.md
================================================
The code for the corticothalamic brain model. The connection matrix and simulation results are available at the following link:
https://drive.google.com/drive/folders/1oOAb-X_ag5feV8Q09_oFZbuoEd7uxIIo?usp=sharing
================================================
FILE: examples/Multiscale_Brain_Structure_Simulation/Corticothalamic_Brain_Model/spectrogram.py
================================================
import scipy.io as scio
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from mpl_toolkits.mplot3d import Axes3D
from scipy.fftpack import fft,ifft
from scipy import signal
from scipy.fft import fftshift
# ---- Load a saved spike raster and plot spikes in the 3-10 s window ----
trail = 4    # trial tag used in the saved result filenames
version = 2  # model/result version tag
Iraster = torch.load(f'./result/raster_{version}_delay_{trail}.pt').cpu()
time = Iraster[:, 0]  # column 0: spike times; column 1: neuron indices
mask = (time >= 3000) & (time < 10000)
indices = torch.where(mask)
spike = Iraster[indices[0]]
plt.figure(figsize=(20, 12))
plt.scatter(spike[:, 0], spike[:, 1], s=0.1)
plt.xlabel('time [ms]', fontsize=20)
plt.ylabel('Neuron index', fontsize=20)
plt.show(dpi=600)
# ---- Band-pass filter the per-subregion current traces ----
data = np.array(torch.load(f'./result/I_subregion_{version}_delay_{trail}.pt').cpu())
print(data.shape)
b, a = signal.butter(2, [0.002, 0.06], 'bandpass')  # design the band-pass filter; the first argument (2) is the filter order
data = signal.filtfilt(b, a, data)  # `data` is the signal to be filtered
fs = 1000          # sampling rate [Hz]
time_window = 1024 # STFT segment length (samples)
# divide = torch.load('./neuron_divide.pt')
# divide_E = divide['divide_point_E']
# print(divide_E)
# Names of the 30 recorded subregions (cortical areas plus thalamus 'TH').
brain_map = ['2','5','24c','46d','7A','7B','7m','8B','8l',
             '8m','9/46d','9/46v','10','DP','F1','F2','F5','F7',
             'MT','PBr','ProM','STPc','STPi','STPr','TEO','TEpd',
             'V1','V2','V4','TH']
def region_sxx(region):
    """Plot the low-frequency STFT spectrogram of one region's current trace.

    :param region: row index into the module-level ``data`` array
        (see ``brain_map`` for the region names).
    """
    plt.figure(figsize=(16, 8))
    # noverlap is a sample count -- pass an int explicitly
    # (time_window / 2 is a float and newer SciPy versions reject it).
    f, t, sxx = signal.stft(data[region], fs=fs, nperseg=time_window, noverlap=time_window // 2)
    print(sxx.shape)
    # Pass the colormap by name: plt.cm.get_cmap is deprecated/removed
    # in recent matplotlib releases.
    plt.contourf(t, f[0:30], np.abs(sxx[0:30]), cmap='jet', levels=200)
    plt.colorbar()
    plt.xlabel('time/min', fontsize=20)
    plt.ylabel('Frequency/Hz', fontsize=20)
    plt.xticks(fontsize=15)
    plt.yticks(fontsize=15)
    plt.show()
def global_sxx():
    """Plot the STFT spectrogram of the mean signal over all regions."""
    plt.figure(figsize=(16, 8))
    global_eeg = np.mean(data, axis=0)
    # noverlap is a sample count -- pass an int explicitly
    # (time_window / 2 is a float and newer SciPy versions reject it).
    f, t, sxx = signal.stft(global_eeg, fs=fs, nperseg=time_window, noverlap=time_window // 2)
    print(sxx.shape)
    # Pass the colormap by name: plt.cm.get_cmap is deprecated/removed
    # in recent matplotlib releases.
    plt.contourf(t, f[0:30], np.abs(sxx[0:30]), cmap='jet', levels=200)
    plt.colorbar()
    plt.xlabel('time/min', fontsize=20)
    plt.ylabel('Frequency/Hz', fontsize=20)
    plt.xticks(fontsize=15)
    plt.yticks(fontsize=15)
    plt.show()
def compare_sxx():
    """Plot, for every region, the peak 0-30-bin STFT magnitude over time.

    :return: (30, n_frames) array of per-region peak spectral magnitudes.
    """
    f_band = range(0, 30)  # low-frequency bins of interest
    # Collect one row per region and stack once at the end -- the original
    # called np.vstack on every iteration, copying the growing array each
    # time (quadratic).
    rows = []
    for col in range(0, 30):
        # noverlap is a sample count -- pass an int (time_window / 2 is a float).
        f, t, sxx = signal.stft(data[col], fs=fs, nperseg=time_window, noverlap=time_window // 2)
        rows.append(np.max(np.abs(sxx[f_band]), axis=0))
    sm = np.vstack(rows)
    # Pass the colormap by name: plt.cm.get_cmap is deprecated in matplotlib.
    plt.pcolormesh(t, brain_map, np.abs(sm), cmap='jet', shading='auto')
    plt.colorbar()
    plt.ylabel('Brain Regions', fontsize=10)
    plt.xlabel('Time [min]', fontsize=10)
    plt.xticks(fontsize=10)
    plt.yticks(fontsize=10)
    plt.show()
    return np.abs(sm)
# Spectrogram for region index 7 ('8B' in brain_map).
region_sxx(7)
# global_sxx()
================================================
FILE: examples/Multiscale_Brain_Structure_Simulation/HumanBrain/README.md
================================================
# Human Brain Simulation
## Description
Human Brain Simulation is a large scale brain modeling framework depending on braincog framework.
## Requirements:
* numpy >= 1.21.2
* scipy >= 1.8.0
* h5py >= 3.6.0
* torch >= 1.10
* torchvision >= 0.12.0
* torchaudio >= 0.11.0
* timm >= 0.5.4
* matplotlib >= 3.5.1
* einops >= 0.4.1
* thop >= 0.0.31
* pyyaml >= 6.0
* loris >= 0.5.3
* pandas >= 1.4.2
* tonic (special)
* pandas >= 1.4.2
## Input:
The 88 regions' connectivity matrix can be obtained from the following link:
[https://drive.google.com/file/d/1f8fpXgR8X07HrJ7G9DwMAl8K0naPcxJC/view?usp=sharing](https://drive.google.com/file/d/1tLHxCtm2kawKVvJ1BhAbkFKeyxcrJwnO/view?usp=sharing)
The source of this connectivity matrix is in the following link:
https://www.nitrc.org/frs/?group_id=432
## Example:
```shell
cd ~/examples/Multiscale_Brain_Structure_Simulation/HumanBrain/
python human_brain.py
```
## Parameters:
The parameters are similar to those of the mouse brain simulation.
================================================
FILE: examples/Multiscale_Brain_Structure_Simulation/HumanBrain/human_brain.py
================================================
import time
import numpy as np
import scipy.io as scio
import torch
from torch import nn
from braincog.base.node.node import *
from braincog.base.brainarea.BrainArea import *
import pandas as pd
import matplotlib.pyplot as plt
from numpy import genfromtxt
device = 'cuda:0'
class Syn(nn.Module):
    """Sparse synaptic model with double-exponential (rise/decay) dynamics.

    :param syn: (num_syn, 2) index tensor; one [post, pre]-style pair per synapse.
    :param weight: per-synapse weights, same length as ``syn``.
    :param neuron_num: total number of neurons in the network.
    :param tao_d: decay time constant.
    :param tao_r: rise time constant.
    :param dt: integration time step.
    :param device: device for the synaptic state tensors.
    """

    def __init__(self, syn, weight, neuron_num, tao_d, tao_r, dt, device):
        super().__init__()
        # NOTE(review): syn[1] / syn[0] index ROWS (individual synapses), not
        # columns; if per-synapse pre/post index vectors were intended this
        # should be syn[:, 1] / syn[:, 0].  Left unchanged because these
        # attributes are not read anywhere in this file.
        self.pre = syn[1]
        self.post = syn[0]
        self.syn_num = len(syn)
        # Sparse connection-weight matrix, indexed (post, pre).
        self.w = torch.sparse_coo_tensor(syn.t(), weight,
                                         size=(neuron_num, neuron_num))
        self.tao_d = tao_d
        self.tao_r = tao_r
        self.dt = dt  # was assigned twice in the original; once is enough
        self.lamda_d = self.dt / self.tao_d
        self.lamda_r = self.dt / self.tao_r
        # Synaptic state: s is driven by spikes, r is the conductance trace.
        self.s = torch.zeros(neuron_num, device=device)
        self.r = torch.zeros(neuron_num, device=device)

    def forward(self, neuron):
        """Advance synaptic state one step; return the total input current.

        Also refreshes the neuron group's filtered background noise current
        (``Iback``/``Ieff``) as a side effect.
        """
        # Low-pass filtered Gaussian background noise.
        neuron.Iback = neuron.Iback + neuron.dt_over_tau * (
            torch.randn(neuron.neuron_num, device=device, requires_grad=False) - neuron.Iback)
        neuron.Ieff = neuron.Iback / neuron.sqrt_coeff * neuron.sig + neuron.mu
        # NOTE(review): lamda_r is paired with tao_d here (and lamda_d drives
        # the r-trace below) -- confirm the rise/decay constants are not swapped.
        self.s = self.s + self.lamda_r * (-self.s + 1 / self.tao_d * neuron.spike)
        self.r = self.r - self.lamda_d * self.r + self.dt * self.s
        self.I = torch.sparse.mm(self.w, self.r.unsqueeze(-1)).squeeze() + neuron.Ieff
        return self.I
class brain(nn.Module):
    """Whole-brain network: one neuron population plus one sparse synapse bundle.

    ``neuron_model`` selects the point-neuron implementation ('HH' or 'aEIF').
    """

    def __init__(self, syn, weight, neuron_model, p_neuron, dt, device):
        super().__init__()
        # Pick the neuron implementation by name.
        if neuron_model == 'HH':
            self.neurons = HHNode(p_neuron, dt, device)
        elif neuron_model == 'aEIF':
            self.neurons = aEIF(p_neuron, dt, device)
        self.neuron_num = len(p_neuron[0])
        self.syns = Syn(syn, weight, self.neuron_num, 3, 6, dt, device)

    def forward(self, inputs):
        """Advance the network one step; `inputs` is unused (drive comes from
        the synapses plus background noise)."""
        synaptic_current = self.syns(self.neurons)
        self.neurons(synaptic_current)
def brain_region(neuron_num):
    """Turn per-region neuron counts into [start, end) index bounds.

    :param neuron_num: iterable of per-region counts (0-d tensors).
    :return: (NR, 2) int tensor of region boundaries in a flat neuron array.
    """
    bounds = []
    offset = 0
    for count in neuron_num:
        upper = offset + count.item()
        bounds.append([offset, upper])
        offset = upper
    return torch.tensor(bounds)
def neuron_type(neuron_num, ratio, regions):
    """Convert per-region type ratios into absolute boundary indices.

    :param neuron_num: per-region neuron counts, shape (NR,).
    :param ratio: cumulative type fractions per region, shape (NR, n_types).
    :param regions: per-region [start, end) bounds, shape (NR, 2).
    :return: int tensor (NR, n_types) of absolute end indices for each type.
    """
    # Local renamed from `neuron_type` -- the original shadowed the function name.
    counts = neuron_num.reshape(-1, 1)
    return torch.floor(ratio * counts).int() + regions[:, 0].reshape(-1, 1)
def syn_within_region(syn_num, region):
    """Draw random intra-region synapses, ``syn_num`` per neuron.

    :param syn_num: synapses per neuron within each region.
    :param region: (NR, 2) tensor of [start, end) neuron bounds.
    :return: (N, 2) tensor of random index pairs on the module-level ``device``.
    """
    pieces = []
    for bounds in region:
        lo, hi = bounds[0], bounds[1]
        pieces.append(torch.randint(lo, hi,
                                    size=((hi - lo) * syn_num, 2), device=device))
    return torch.concatenate(pieces)
def syn_cross_region(weight_matrix, region):
    """Draw random inter-region synapses from the structural connectivity.

    For every region pair (i, j) whose weight is at least 10, draws
    weight_matrix[i][j] random pairs: presynaptic from region j, postsynaptic
    from region i.  Returns an (N, 2) tensor with columns [post, pre].

    NOTE(review): if no matrix entry reaches the threshold of 10, ``syn`` is
    never bound and the return raises UnboundLocalError.
    """
    start = 1  # flag: first pair initialises `syn`, later pairs append
    for i in range(len(weight_matrix)):
        for j in range(len(weight_matrix)):
            if weight_matrix[i][j] < 10:
                continue
            else:
                pre = torch.randint(region[j][0], region[j][1],
                                    size=(weight_matrix[i][j], 1), device=device)
                post = torch.randint(region[i][0], region[i][1],
                                     size=(weight_matrix[i][j], 1), device=device)
                if start:
                    syn = torch.concatenate((post, pre), dim=1)
                    start = 0
                else:
                    syn = torch.concatenate((syn, torch.concatenate((post, pre), dim=1)))
    return syn
# ---- Network construction (script entry) ----
size = 500  # neurons per region
neuron_model = 'HH'
# 88-region structural connectivity matrix (see README for the download link).
weight_matrix = np.load('./IIT_connectivity_matrix.npy')
weight_matrix = torch.from_numpy(weight_matrix)
NR = len(weight_matrix)  # number of regions
data = size * np.ones(NR)
neuron_num = np.array(data).astype(np.int32)
neuron_num = torch.from_numpy(neuron_num)
regions = brain_region(neuron_num)
# Cumulative per-region type fractions (boundaries at 70% / 90% / 100%).
ratio = torch.tensor([[0.7, 0.9, 1.0] * NR]).reshape(NR, 3)
neuron_types = neuron_type(neuron_num, ratio, regions)
syn_1 = syn_within_region(10, regions)
syn_2 = syn_cross_region(weight_matrix, regions)
syn = torch.concatenate((syn_1, syn_2))
print(syn.shape)
# Default synaptic weight is inhibitory (-1); synapses whose column-0 index
# falls inside a region's first type slice are flipped to +1.5 below.
weight = -torch.ones(len(syn), device=device, requires_grad=False)
if neuron_model == 'aEIF':
    # Per-neuron aEIF parameter vectors, filled region by region in the loop.
    threshold = torch.zeros(regions[-1][1], device=device, requires_grad=False)
    v_reset = torch.zeros(regions[-1][1], device=device, requires_grad=False)
    c_m = torch.zeros(regions[-1][1], device=device, requires_grad=False)
    tao_w = torch.zeros(regions[-1][1], device=device, requires_grad=False)
    alpha_ad = torch.zeros(regions[-1][1], device=device, requires_grad=False)
    beta_ad = torch.zeros(regions[-1][1], device=device, requires_grad=False)
elif neuron_model == 'HH':
    threshold = torch.zeros(regions[-1][1], device=device, requires_grad=False)
for i in range(len(neuron_types)):
    # NOTE(review): named `pre`, but column 0 of `syn` is filled with POST
    # indices by syn_cross_region -- confirm which column is intended.
    pre = syn[:, 0]
    mask = (pre >= regions[i][0]) & (pre < neuron_types[i][0])
    indices = torch.where(mask)
    weight[indices] = 1.5  # first-slice sources become excitatory
    if neuron_model == 'aEIF':
        # Three type slices per region: [start, t0), [t0, t1), [t1, t2).
        if i < 177:
            threshold[regions[i][0]:neuron_types[i][0]] = -50
            threshold[neuron_types[i][0]:neuron_types[i][1]] = -44
            threshold[neuron_types[i][1]:neuron_types[i][2]] = -45
            v_reset[regions[i][0]:neuron_types[i][0]] = -110
            v_reset[neuron_types[i][0]:neuron_types[i][1]] = -110
            v_reset[neuron_types[i][1]:neuron_types[i][2]] = -66
            c_m[regions[i][0]:neuron_types[i][0]] = 10
            c_m[neuron_types[i][0]:neuron_types[i][1]] = 10
            c_m[neuron_types[i][1]:neuron_types[i][2]] = 8.5
            tao_w[regions[i][0]:neuron_types[i][0]] = 1
            tao_w[neuron_types[i][0]:neuron_types[i][1]] = 2
            tao_w[neuron_types[i][1]:neuron_types[i][2]] = 2
            alpha_ad[regions[i][0]:neuron_types[i][0]] = 0
            alpha_ad[neuron_types[i][0]:neuron_types[i][1]] = -0.2
            alpha_ad[neuron_types[i][1]:neuron_types[i][2]] = -0.2
            beta_ad[regions[i][0]:neuron_types[i][0]] = 0
            beta_ad[neuron_types[i][0]:neuron_types[i][1]] = 0.45
            beta_ad[neuron_types[i][1]:neuron_types[i][2]] = 0.45
        else:
            # Regions with index >= 177 get a different parameter set --
            # presumably a distinct anatomical class; TODO confirm the split.
            threshold[regions[i][0]:neuron_types[i][0]] = -50
            threshold[neuron_types[i][0]:neuron_types[i][1]] = -50
            threshold[neuron_types[i][1]:neuron_types[i][2]] = -45
            v_reset[regions[i][0]:neuron_types[i][0]] = -60
            v_reset[neuron_types[i][0]:neuron_types[i][1]] = -60
            v_reset[neuron_types[i][1]:neuron_types[i][2]] = -65
            c_m[regions[i][0]:neuron_types[i][0]] = 20
            c_m[neuron_types[i][0]:neuron_types[i][1]] = 2
            c_m[neuron_types[i][1]:neuron_types[i][2]] = 4
            tao_w[regions[i][0]:neuron_types[i][0]] = 1
            tao_w[neuron_types[i][0]:neuron_types[i][1]] = 2
            tao_w[neuron_types[i][1]:neuron_types[i][2]] = 2
            alpha_ad[regions[i][0]:neuron_types[i][0]] = 0
            alpha_ad[neuron_types[i][0]:neuron_types[i][1]] = -0.2
            alpha_ad[neuron_types[i][1]:neuron_types[i][2]] = -0.2
            beta_ad[regions[i][0]:neuron_types[i][0]] = 0
            beta_ad[neuron_types[i][0]:neuron_types[i][1]] = 0.45
            beta_ad[neuron_types[i][1]:neuron_types[i][2]] = 0.45
    elif neuron_model == 'HH':
        threshold[regions[i][0]:neuron_types[i][0]] = 50
        threshold[neuron_types[i][0]:neuron_types[i][1]] = 60
        threshold[neuron_types[i][1]:neuron_types[i][2]] = 60
# Per-model parameter packing and simulation length.
if neuron_model == 'aEIF':
    p_neuron = [threshold, v_reset, c_m, tao_w, alpha_ad, beta_ad]
    dt = 1
    T = 300
elif neuron_model == 'HH':
    # [threshold, g_Na, g_K, g_l, V_Na, V_K, V_l, C]
    p_neuron = [threshold, 120, 36, 0.3, 115, -12, 10.6, 1]
    dt = 0.01
    T = 10000
model = brain(syn, weight, neuron_model, p_neuron, dt, device)
Iraster = []  # accumulated (t, neuron_index) spike events
for t in range(T):
    model(0)
    print(torch.sum(model.neurons.spike))
    Isp = torch.nonzero(model.neurons.spike)
    print(len(Isp))
    # Append this step's spikes as (t, neuron_index) rows.
    if (len(Isp) != 0):
        left = t * torch.ones((len(Isp)), device=device, requires_grad=False)
        left = left.reshape(len(left), 1)
        mide = torch.concatenate((left, Isp), dim=1)
    if (len(Isp) != 0) and (len(Iraster) != 0):
        Iraster = torch.concatenate((Iraster, mide), dim=0)
    if (len(Iraster) == 0) and (len(Isp) != 0):
        Iraster = mide
# Save as a 2 x N tensor (row 0: times, row 1: neuron indices).
Iraster = torch.tensor(Iraster).transpose(0, 1)
torch.save(Iraster, "./human.pt")
================================================
FILE: examples/Multiscale_Brain_Structure_Simulation/HumanBrain/human_multi.py
================================================
import time
import numpy as np
import scipy.io as scio
import torch
from torch import nn
from braincog.base.node.node import *
from braincog.base.brainarea.BrainArea import *
import pandas as pd
import matplotlib.pyplot as plt
from numpy import genfromtxt
device_ids = [0,2,3,4,5,7,8,9]
device = 'cuda:0'
class MultiCompartmentaEIF(BaseNode):
    """
    Two-compartment (apical dendrite + soma) neuron model.

    :param p: parameter list; p[0] = per-neuron spike thresholds,
        p[1] = resting/reset potential(s).
    :param dt: integration time step.
    :param tau: soma membrane time constant (controls soma decay).
    :param tau_basal: basal-dendrite membrane time constant.
    :param tau_apical: apical-dendrite membrane time constant.
    :param act_fun: surrogate-gradient spike function.

    NOTE(review): the ``tau*`` constructor arguments are ignored -- the body
    hard-codes tau=2.0, tau_basal=20.0, tau_apical=2.0 -- and ``g_B``/``g_L``
    are computed but never used.  Confirm before relying on these parameters.
    """

    def __init__(self,
                 p,
                 dt,
                 tau=2.0,
                 tau_basal=2.0,
                 tau_apical=2.0,
                 act_fun=AtanGrad, *args, **kwargs):
        g_B = 0.6
        g_L = 0.05
        super().__init__(threshold=p[0], *args, **kwargs)
        self.neuron_num = len(p[0])
        self.tau = 2.0
        self.tau_basal = 20.0
        self.tau_apical = 2.0
        self.spike = torch.zeros(self.neuron_num, device=device, requires_grad=False)
        self.v_reset = p[1]  # membrane potential reset to v_reset after fire spike
        # Initialize membrane potentials
        # Background-noise filter parameters (consumed by Syn.forward).
        self.tau_I = 3.0
        self.sig = 12.0
        self.mu = 10.0
        self.dt = dt
        self.dt_over_tau = self.dt / self.tau_I
        # Compartment membrane potentials, initialised at the reset value.
        self.mems = {}
        self.mems['soma'] = torch.ones(self.neuron_num, device=device) * self.v_reset
        self.mems['apical'] = torch.ones(self.neuron_num, device=device) * self.v_reset
        self.act_fun = act_fun(alpha=self.tau, requires_grad=False)
        self.Iback = torch.zeros(self.neuron_num, device=device, requires_grad=False)
        self.Ieff = torch.zeros(self.neuron_num, device=device, requires_grad=False)
        self.sqrt_coeff = math.sqrt(1 / (2 * (1 / self.dt_over_tau)))

    def integral(self, apical_inputs):
        '''
        One integration step: the input charges the apical compartment,
        which in turn leaks into the soma.

        Params:
            apical_inputs torch.Tensor: Input current for the apical dendrite
        '''
        self.mems['apical'] = (self.mems['apical'] + apical_inputs) / self.tau_apical
        self.mems['soma'] = self.mems['soma'] + (self.mems['apical'] - self.mems['soma']) / self.tau

    def calc_spike(self):
        # Surrogate-gradient spike on the soma, then zero both compartments
        # of the neurons that fired.
        self.spike = self.act_fun(self.mems['soma'] - self.threshold)
        self.mems['soma'] = self.mems['soma'] * (1. - self.spike.detach())
        self.mems['apical'] = self.mems['apical'] * (1. - self.spike.detach())

    def forward(self, inputs):
        # aeifnode_cuda.forward(self.threshold, self.c_m, self.alpha_w, self.beta_ad, inputs, self.ref, self.ad, self.mem, self.spike)
        self.integral(inputs)
        self.calc_spike()
        return self.spike, self.mems['soma']
class aEIF(BaseNode):
    """
    The adaptive Exponential Integrate-and-Fire model (aEIF)
    This class define the membrane, spike, current and parameters of a neuron group of a specific type
    :param args: Other parameters
    :param kwargs: Other parameters
    """

    def __init__(self, p, dt, device, *args, **kwargs):
        """
        p: [threshold, v_reset, c_m, tao_w, alpha_ad, beta_ad]
        c_m: Membrane capacitance
        alpha_w: Coupling of the adaptation variable
        beta_ad: Conductance of the adaptation variable
        mu: Mean of back current
        sig: Variance of back current
        if_IN: if the neuron type is inhibitory neuron, it has gap-junction
        neuron_num: number of neurons in this group
        W: connection weight for the neuron groups connected to this group
        type_index: the index of this type of neuron group in the brain region
        """
        super().__init__(threshold=p[0], requires_fp=False, *args, **kwargs)
        self.neuron_num = len(p[0])
        self.g_m = 0.1  # neuron conduction
        self.dt = dt
        self.tau_I = 3  # Time constant to filter the synaptic inputs
        self.Delta_T = 0.5  # parameter: sharpness of the exponential spike onset
        self.v_reset = p[1]  # membrane potential reset to v_reset after fire spike
        self.c_m = p[2]
        self.tau_w = p[3]  # Time constant of adaption coupling
        self.alpha_ad = p[4]
        self.beta_ad = p[5]
        self.refrac = 5 / self.dt  # refractory period, in steps
        self.dt_over_tau = self.dt / self.tau_I
        self.sqrt_coeff = math.sqrt(1 / (2 * (1 / self.dt_over_tau)))
        self.mem = self.v_reset  # membrane starts at the reset potential
        self.spike = torch.zeros(self.neuron_num, device=device, requires_grad=False)
        self.ad = torch.zeros(self.neuron_num, device=device, requires_grad=False)
        # Random initial refractory counters desynchronise the population.
        self.ref = torch.randint(0, int(self.refrac + 1), (1, self.neuron_num), device=device, requires_grad=False).squeeze(
            0)  # refractory counter
        self.ref = self.ref.float()
        self.mu = 10   # mean of the background current
        self.sig = 12  # std of the background current
        self.Iback = torch.zeros(self.neuron_num, device=device, requires_grad=False)
        self.Ieff = torch.zeros(self.neuron_num, device=device, requires_grad=False)

    def integral(self, inputs):
        # One Euler step of the membrane and adaptation dynamics; neurons
        # still inside the refractory window (ref <= refrac) are frozen.
        self.mem = self.mem + (self.ref > self.refrac) * self.dt / self.c_m * \
            (-self.g_m * (self.mem - self.v_reset) + self.g_m * self.Delta_T *
             torch.exp((self.mem - self.threshold) / self.Delta_T) +
             self.alpha_ad * self.ad + inputs)
        self.ad = self.ad + (self.ref > self.refrac) * self.dt / self.tau_w * \
            (-self.ad + self.beta_ad * (self.mem - self.v_reset))

    def calc_spike(self):
        # Threshold crossing -> spike; restart the refractory counter, bump
        # the adaptation variable, and reset the membrane of spiking neurons.
        self.spike = (self.mem > self.threshold).float()
        self.ref = self.ref * (1 - self.spike) + 1
        self.ad = self.ad + self.spike * 30  # spike-triggered adaptation increment
        self.mem = self.spike * self.v_reset + (1 - self.spike.detach()) * self.mem

    def forward(self, inputs):
        # aeifnode_cuda.forward(self.threshold, self.c_m, self.alpha_w, self.beta_ad, inputs, self.ref, self.ad, self.mem, self.spike)
        self.integral(inputs)
        self.calc_spike()
        return self.spike, self.mem
class HHNode(BaseNode):
    """
    Simplified Hodgkin-Huxley neuron group.

    :param p: parameter list [threshold, g_Na, g_K, g_l, V_Na, V_K, V_l, C].
    :param dt: integration time step.
    :param device: device for the state tensors.
    :param act_fun: surrogate-gradient function, default ``surrogate.AtanGrad``.
    :param args: other parameters.
    :param kwargs: other parameters.
    """

    def __init__(self, p, dt, device, act_fun=AtanGrad, *args, **kwargs):
        super().__init__(threshold=p[0], *args, **kwargs)
        if isinstance(act_fun, str):
            act_fun = eval(act_fun)
        '''
        I = Cm dV/dt + g_k*n^4*(V_m-V_k) + g_Na*m^3*h*(V_m-V_Na) + g_l*(V_m - V_L)
        '''
        self.neuron_num = len(p[0])
        self.act_fun = act_fun(alpha=2., requires_grad=False)
        self.tau_I = 3  # time constant filtering the synaptic input
        # Channel conductances and reversal potentials from p.
        self.g_Na = torch.tensor(p[1])
        self.g_K = torch.tensor(p[2])
        self.g_l = torch.tensor(p[3])
        self.V_Na = torch.tensor(p[4])
        self.V_K = torch.tensor(p[5])
        self.V_l = torch.tensor(p[6])
        self.C = torch.tensor(p[7])
        # Gating variables initialised at their classical resting values.
        self.m = 0.05 * torch.ones(self.neuron_num, device=device, requires_grad=False)
        self.n = 0.31 * torch.ones(self.neuron_num, device=device, requires_grad=False)
        self.h = 0.59 * torch.ones(self.neuron_num, device=device, requires_grad=False)
        self.v_reset = 0
        self.dt = dt
        self.dt_over_tau = self.dt / self.tau_I
        self.sqrt_coeff = math.sqrt(1 / (2 * (1 / self.dt_over_tau)))
        self.mu = 10   # mean of the background current
        self.sig = 12  # std of the background current
        self.mem = torch.tensor(self.v_reset, device=device, requires_grad=False)
        self.mem_p = self.mem  # previous-step membrane, for edge detection
        self.spike = torch.zeros(self.neuron_num, device=device, requires_grad=False)
        self.Iback = torch.zeros(self.neuron_num, device=device, requires_grad=False)
        self.Ieff = torch.zeros(self.neuron_num, device=device, requires_grad=False)

    def integral(self, inputs):
        # Voltage-dependent rate constants (membrane measured relative to rest).
        self.alpha_n = (0.1 - 0.01 * self.mem) / (torch.exp(1 - 0.1 * self.mem) - 1)
        self.alpha_m = (2.5 - 0.1 * self.mem) / (torch.exp(2.5 - 0.1 * self.mem) - 1)
        self.alpha_h = 0.07 * torch.exp(-self.mem / 20.0)
        self.beta_n = 0.125 * torch.exp(-self.mem / 80.0)
        self.beta_m = 4.0 * torch.exp(-self.mem / 18.0)
        self.beta_h = 1 / (torch.exp(3 - 0.1 * self.mem) + 1)
        # Per-gate time constants and steady-state values.
        self.tau_n = 1.0 / (self.alpha_n + self.beta_n)
        self.inf_n = self.alpha_n * self.tau_n
        self.tau_m = 1.0 / (self.alpha_m + self.beta_m)
        self.inf_m = self.alpha_m * self.tau_m
        self.tau_h = 1.0 / (self.alpha_h + self.beta_h)
        self.inf_h = self.alpha_h * self.tau_h
        # Relax each gate toward its steady state.
        self.n = (1 - self.dt / self.tau_n) * self.n + (self.dt / self.tau_n) * self.inf_n
        self.m = (1 - self.dt / self.tau_m) * self.m + (self.dt / self.tau_m) * self.inf_m
        self.h = (1 - self.dt / self.tau_h) * self.h + (self.dt / self.tau_h) * self.inf_h
        # self.n = self.n + self.dt * (self.alpha_n * (1 - self.n) - self.beta_n * self.n)
        # self.m = self.m + self.dt * (self.alpha_m * (1 - self.m) - self.beta_m * self.m)
        # self.h = self.h + self.dt * (self.alpha_h * (1 - self.h) - self.beta_h * self.h)
        # Ionic currents and Euler membrane update.
        self.I_Na = torch.pow(self.m, 3) * self.g_Na * self.h * (self.mem - self.V_Na)
        self.I_K = torch.pow(self.n, 4) * self.g_K * (self.mem - self.V_K)
        self.I_L = self.g_l * (self.mem - self.V_l)
        self.mem_p = self.mem
        self.mem = self.mem + self.dt * (inputs - self.I_Na - self.I_K - self.I_L) / self.C
        # self.mem = self.mem + self.dt * (inputs - self.I_K - self.I_L) / self.C

    def calc_spike(self):
        # A spike is an upward threshold crossing between consecutive steps.
        self.spike = (self.threshold > self.mem_p).float() * (self.mem > self.threshold).float()

    def forward(self, inputs):
        self.integral(inputs)
        self.calc_spike()
        return self.spike, self.mem

    def requires_activation(self):
        # This node emits spikes itself; no external activation is needed.
        return False
class Syn(nn.Module):
    """Sparse synaptic model with double-exponential (rise/decay) dynamics.

    :param syn: (num_syn, 2) index tensor; one [post, pre]-style pair per synapse.
    :param weight: per-synapse weights, same length as ``syn``.
    :param neuron_num: total number of neurons in the network.
    :param tao_d: decay time constant.
    :param tao_r: rise time constant.
    :param dt: integration time step.
    :param device: device for the synaptic state tensors.
    """

    def __init__(self, syn, weight, neuron_num, tao_d, tao_r, dt, device):
        super().__init__()
        # NOTE(review): syn[1] / syn[0] index ROWS (individual synapses), not
        # columns; if per-synapse pre/post index vectors were intended this
        # should be syn[:, 1] / syn[:, 0].  Left unchanged because these
        # attributes are not read anywhere in this file.
        self.pre = syn[1]
        self.post = syn[0]
        self.syn_num = len(syn)
        # Sparse connection-weight matrix, indexed (post, pre).
        self.w = torch.sparse_coo_tensor(syn.t(), weight,
                                         size=(neuron_num, neuron_num))
        self.tao_d = tao_d
        self.tao_r = tao_r
        self.dt = dt  # was assigned twice in the original; once is enough
        self.lamda_d = self.dt / self.tao_d
        self.lamda_r = self.dt / self.tao_r
        # Synaptic state: s is driven by spikes, r is the conductance trace.
        self.s = torch.zeros(neuron_num, device=device)
        self.r = torch.zeros(neuron_num, device=device)

    def forward(self, neuron):
        """Advance synaptic state one step; return the total input current.

        Also refreshes the neuron group's filtered background noise current
        (``Iback``/``Ieff``) as a side effect.
        """
        # Low-pass filtered Gaussian background noise.
        neuron.Iback = neuron.Iback + neuron.dt_over_tau * (
            torch.randn(neuron.neuron_num, device=device, requires_grad=False) - neuron.Iback)
        neuron.Ieff = neuron.Iback / neuron.sqrt_coeff * neuron.sig + neuron.mu
        # NOTE(review): lamda_r is paired with tao_d here (and lamda_d drives
        # the r-trace below) -- confirm the rise/decay constants are not swapped.
        self.s = self.s + self.lamda_r * (-self.s + 1 / self.tao_d * neuron.spike)
        self.r = self.r - self.lamda_d * self.r + self.dt * self.s
        self.I = torch.sparse.mm(self.w, self.r.unsqueeze(-1)).squeeze() + neuron.Ieff
        return self.I
class brain(nn.Module):
    """Whole-brain network: one neuron population plus one sparse synapse bundle.

    ``neuron_model`` selects the point-neuron implementation
    ('HH', 'aEIF' or 'MultiCompartmentaEIF').
    """

    def __init__(self, syn, weight, neuron_model, p_neuron, dt, device):
        super().__init__()
        # Pick the neuron implementation by name.
        if neuron_model == 'HH':
            self.neurons = HHNode(p_neuron, dt, device)
        elif neuron_model == 'aEIF':
            self.neurons = aEIF(p_neuron, dt, device)
        elif neuron_model == 'MultiCompartmentaEIF':
            self.neurons = MultiCompartmentaEIF(p_neuron, dt, device)
        self.neuron_num = len(p_neuron[0])
        self.syns = Syn(syn, weight, self.neuron_num, 3.0, 6.0, dt, device)

    def forward(self, inputs):
        """Advance the network one step; `inputs` is unused (drive comes from
        the synapses plus background noise)."""
        synaptic_current = self.syns(self.neurons)
        self.neurons(synaptic_current)
def brain_region(neuron_num):
    """Map per-region neuron counts to [start, end) index pairs.

    :param neuron_num: sequence of per-region counts (0-d tensors).
    :return: (NR, 2) int tensor of region bounds in a flat neuron array.
    """
    spans = []
    cursor = 0
    for idx in range(len(neuron_num)):
        upper = cursor + neuron_num[idx].item()
        spans.append([cursor, upper])
        cursor = upper
    return torch.tensor(spans)
def neuron_type(neuron_num, ratio, regions):
    """Convert per-region type ratios into absolute boundary indices.

    :param neuron_num: per-region neuron counts, shape (NR,).
    :param ratio: cumulative type fractions per region, shape (NR, n_types).
    :param regions: per-region [start, end) bounds, shape (NR, 2).
    :return: int tensor (NR, n_types) of absolute end indices for each type.
    """
    # Local renamed from `neuron_type` -- the original shadowed the function name.
    counts = neuron_num.reshape(-1, 1)
    return torch.floor(ratio * counts).int() + regions[:, 0].reshape(-1, 1)
def syn_within_region(syn_num, region):
    """Draw random intra-region synapses, ``syn_num`` per neuron.

    :param syn_num: synapses per neuron within each region.
    :param region: (NR, 2) tensor of [start, end) neuron bounds.
    :return: (N, 2) tensor of random index pairs on the module-level ``device``.
    """
    pieces = []
    for bounds in region:
        lo, hi = bounds[0], bounds[1]
        pieces.append(torch.randint(lo, hi,
                                    size=((hi - lo) * syn_num, 2), device=device))
    return torch.concat(pieces)
def syn_cross_region(weight_matrix, region):
    """Draw random inter-region synapses from the structural connectivity.

    For every region pair (i, j) whose weight is at least 10, draws
    weight_matrix[i][j] random pairs: presynaptic from region j, postsynaptic
    from region i.  Returns an (N, 2) tensor with columns [post, pre].

    NOTE(review): if no matrix entry reaches the threshold of 10, ``syn`` is
    never bound and the return raises UnboundLocalError.
    """
    start = 1  # flag: first pair initialises `syn`, later pairs append
    for i in range(len(weight_matrix)):
        for j in range(len(weight_matrix)):
            if weight_matrix[i][j] < 10:
                continue
            else:
                pre = torch.randint(region[j][0], region[j][1],
                                    size=(weight_matrix[i][j], 1), device=device)
                post = torch.randint(region[i][0], region[i][1],
                                     size=(weight_matrix[i][j], 1), device=device)
                if start:
                    syn = torch.concat((post, pre), dim=1)
                    start = 0
                else:
                    syn = torch.concat((syn, torch.concat((post, pre), dim=1)))
    return syn
# ---- Network construction (script entry) ----
size = 100  # neurons per region
neuron_model = 'MultiCompartmentaEIF'
# First 84 regions of the IIT structural connectivity matrix, scaled to
# integer inter-region synapse counts.
weight_matrix = torch.from_numpy(np.load("IIT_connectivity_matrix.npy")[0:84,0:84])
weight_matrix = weight_matrix.int() * 10
# weight_matrix = np.load('./IIT_connectivity_matrix.npy')
# weight_matrix = torch.from_numpy(weight_matrix)
NR = len(weight_matrix)  # number of regions
data = size * np.ones(NR)
neuron_num = np.array(data).astype(np.int32)
neuron_num = torch.from_numpy(neuron_num)
print(torch.sum(neuron_num))
regions = brain_region(neuron_num)
# Cumulative per-region type fractions (boundaries at 70% / 90% / 100%).
ratio = torch.tensor([[0.7, 0.9, 1.0] * NR]).reshape(NR, 3)
neuron_types = neuron_type(neuron_num, ratio, regions)
syn_1 = syn_within_region(10, regions)
syn_2 = syn_cross_region(weight_matrix, regions)
syn = torch.concat((syn_1, syn_2))
print(len(syn_2))
print(syn.shape)
# Default synaptic weight is inhibitory (-1); first-slice sources are
# flipped to +1.5 in the loop below.
weight = -torch.ones(len(syn), device=device, requires_grad=False)
if neuron_model == 'aEIF':
    # Per-neuron parameter vectors, filled region by region below.
    threshold = torch.zeros(regions[-1][1], device=device, requires_grad=False)
    v_reset = torch.zeros(regions[-1][1], device=device, requires_grad=False)
    c_m = torch.zeros(regions[-1][1], device=device, requires_grad=False)
    tao_w = torch.zeros(regions[-1][1], device=device, requires_grad=False)
    alpha_ad = torch.zeros(regions[-1][1], device=device, requires_grad=False)
    beta_ad = torch.zeros(regions[-1][1], device=device, requires_grad=False)
elif neuron_model == 'HH':
    threshold = torch.zeros(regions[-1][1], device=device, requires_grad=False)
if neuron_model == 'MultiCompartmentaEIF':
    threshold = torch.zeros(regions[-1][1], device=device, requires_grad=False)
    v_reset = torch.zeros(regions[-1][1], device=device, requires_grad=False)
    c_m = torch.zeros(regions[-1][1], device=device, requires_grad=False)
    tao_w = torch.zeros(regions[-1][1], device=device, requires_grad=False)
    alpha_ad = torch.zeros(regions[-1][1], device=device, requires_grad=False)
    beta_ad = torch.zeros(regions[-1][1], device=device, requires_grad=False)
for i in range(len(neuron_types)):
    # NOTE(review): named `pre`, but column 0 of `syn` is filled with POST
    # indices by syn_cross_region -- confirm which column is intended.
    pre = syn[:, 0]
    mask = (pre >= regions[i][0]) & (pre < neuron_types[i][0])
    indices = torch.where(mask)
    weight[indices] = 1.5  # first-slice sources become excitatory
    if neuron_model == 'aEIF':
        # Three type slices per region: [start, t0), [t0, t1), [t1, t2).
        if i < 70:
            threshold[regions[i][0]:neuron_types[i][0]] = -50
            threshold[neuron_types[i][0]:neuron_types[i][1]] = -44
            threshold[neuron_types[i][1]:neuron_types[i][2]] = -45
            v_reset[regions[i][0]:neuron_types[i][0]] = -110
            v_reset[neuron_types[i][0]:neuron_types[i][1]] = -110
            v_reset[neuron_types[i][1]:neuron_types[i][2]] = -110
            c_m[regions[i][0]:neuron_types[i][0]] = 10
            c_m[neuron_types[i][0]:neuron_types[i][1]] = 10
            c_m[neuron_types[i][1]:neuron_types[i][2]] = 8.5
            tao_w[regions[i][0]:neuron_types[i][0]] = 1
            tao_w[neuron_types[i][0]:neuron_types[i][1]] = 2
            tao_w[neuron_types[i][1]:neuron_types[i][2]] = 2
            alpha_ad[regions[i][0]:neuron_types[i][0]] = 0
            alpha_ad[neuron_types[i][0]:neuron_types[i][1]] = -0.2
            alpha_ad[neuron_types[i][1]:neuron_types[i][2]] = -0.2
            beta_ad[regions[i][0]:neuron_types[i][0]] = 0
            beta_ad[neuron_types[i][0]:neuron_types[i][1]] = 0.45
            beta_ad[neuron_types[i][1]:neuron_types[i][2]] = 0.45
        else:
            # Regions with index >= 70 get a different parameter set --
            # presumably a distinct anatomical class; TODO confirm the split.
            threshold[regions[i][0]:neuron_types[i][0]] = -50
            threshold[neuron_types[i][0]:neuron_types[i][1]] = -50
            threshold[neuron_types[i][1]:neuron_types[i][2]] = -45
            v_reset[regions[i][0]:neuron_types[i][0]] = -100
            v_reset[neuron_types[i][0]:neuron_types[i][1]] = -100
            v_reset[neuron_types[i][1]:neuron_types[i][2]] = -105
            c_m[regions[i][0]:neuron_types[i][0]] = 20
            c_m[neuron_types[i][0]:neuron_types[i][1]] = 10
            c_m[neuron_types[i][1]:neuron_types[i][2]] = 10
            tao_w[regions[i][0]:neuron_types[i][0]] = 1
            tao_w[neuron_types[i][0]:neuron_types[i][1]] = 2
            tao_w[neuron_types[i][1]:neuron_types[i][2]] = 2
            alpha_ad[regions[i][0]:neuron_types[i][0]] = 0
            alpha_ad[neuron_types[i][0]:neuron_types[i][1]] = -0.2
            alpha_ad[neuron_types[i][1]:neuron_types[i][2]] = -0.2
            beta_ad[regions[i][0]:neuron_types[i][0]] = 0
            beta_ad[neuron_types[i][0]:neuron_types[i][1]] = 0.45
            beta_ad[neuron_types[i][1]:neuron_types[i][2]] = 0.45
    elif neuron_model == 'HH':
        threshold[regions[i][0]:neuron_types[i][0]] = 20
        threshold[neuron_types[i][0]:neuron_types[i][1]] = 20
        threshold[neuron_types[i][1]:neuron_types[i][2]] = 20
    elif neuron_model == 'MultiCompartmentaEIF':
        if i < 70:
            threshold[regions[i][0]:neuron_types[i][0]] = -50.0
            threshold[neuron_types[i][0]:neuron_types[i][1]] = -44.0
            threshold[neuron_types[i][1]:neuron_types[i][2]] = -45.0
            v_reset[regions[i][0]:neuron_types[i][0]] = -110.0
            v_reset[neuron_types[i][0]:neuron_types[i][1]] = -110.0
            v_reset[neuron_types[i][1]:neuron_types[i][2]] = -110.0
            c_m[regions[i][0]:neuron_types[i][0]] = 10.0
            c_m[neuron_types[i][0]:neuron_types[i][1]] = 10.0
            c_m[neuron_types[i][1]:neuron_types[i][2]] = 8.5
            tao_w[regions[i][0]:neuron_types[i][0]] = 1
            tao_w[neuron_types[i][0]:neuron_types[i][1]] = 2
            tao_w[neuron_types[i][1]:neuron_types[i][2]] = 2
            alpha_ad[regions[i][0]:neuron_types[i][0]] = 0
            alpha_ad[neuron_types[i][0]:neuron_types[i][1]] = -0.2
            alpha_ad[neuron_types[i][1]:neuron_types[i][2]] = -0.2
            beta_ad[regions[i][0]:neuron_types[i][0]] = 0
            beta_ad[neuron_types[i][0]:neuron_types[i][1]] = 0.45
            beta_ad[neuron_types[i][1]:neuron_types[i][2]] = 0.45
        else:
            threshold[regions[i][0]:neuron_types[i][0]] = -50.0
            threshold[neuron_types[i][0]:neuron_types[i][1]] = -50.0
            threshold[neuron_types[i][1]:neuron_types[i][2]] = -45.0
            v_reset[regions[i][0]:neuron_types[i][0]] = -100.0
            v_reset[neuron_types[i][0]:neuron_types[i][1]] = -100.0
            v_reset[neuron_types[i][1]:neuron_types[i][2]] = -105.0
            c_m[regions[i][0]:neuron_types[i][0]] = 20
            c_m[neuron_types[i][0]:neuron_types[i][1]] = 10
            c_m[neuron_types[i][1]:neuron_types[i][2]] = 10
            tao_w[regions[i][0]:neuron_types[i][0]] = 1
            tao_w[neuron_types[i][0]:neuron_types[i][1]] = 2
            tao_w[neuron_types[i][1]:neuron_types[i][2]] = 2
            alpha_ad[regions[i][0]:neuron_types[i][0]] = 0
            alpha_ad[neuron_types[i][0]:neuron_types[i][1]] = -0.2
            alpha_ad[neuron_types[i][1]:neuron_types[i][2]] = -0.2
            beta_ad[regions[i][0]:neuron_types[i][0]] = 0
            beta_ad[neuron_types[i][0]:neuron_types[i][1]] = 0.45
            beta_ad[neuron_types[i][1]:neuron_types[i][2]] = 0.45
# Per-model parameter packing and simulation length.
if neuron_model == 'aEIF':
    p_neuron = [threshold, v_reset, c_m, tao_w, alpha_ad, beta_ad]
    dt = 1
    T = 2000
elif neuron_model == 'HH':
    # [threshold, g_Na, g_K, g_l, V_Na, V_K, V_l, C]
    p_neuron = [threshold, 120, 36, 0.3, 115, -12, 10.6, 1]
    dt = 0.01
    T = 10000
elif neuron_model == 'MultiCompartmentaEIF':
    p_neuron = [threshold, v_reset, c_m, tao_w, alpha_ad, beta_ad]
    dt = 1.0
    T = 2000
model = brain(syn, weight, neuron_model, p_neuron, dt, device)
# device_ids = [0,2,3,4,5,7,8,9]
# model = nn.DataParallel(model, device_ids=device_ids)
model.to(device)
def neuron_delete(model, rate):
    """Silence a random fraction ``rate`` of the network's neurons.

    Deleted neurons are not removed; their firing threshold is raised to an
    unreachable value (1000) so they can never spike again.

    Returns the number of neurons that remain active.
    """
    total = model.neuron_num
    n_removed = int(total * rate)
    # Choose `n_removed` distinct neuron indices uniformly at random.
    victims = torch.randperm(total)[:n_removed]
    model.neurons.threshold[victims] = 1000
    return total - n_removed
def syn_delete(model, rate):
    """Randomly delete a fraction ``rate`` of the synapses in the sparse
    weight matrix, keeping the remaining ``1 - rate`` fraction.

    BUGFIX: the original kept the ``delete_num`` randomly selected synapses
    instead of removing them, so ``syn_delete(model, 0.4)`` actually deleted
    60% of the synapses — inconsistent with ``neuron_delete(model, 0.4)``,
    which deletes 40% of the neurons.
    """
    indices = model.syns.w._indices()
    values = model.syns.w._values()
    n_syn = len(values)
    delete_num = int(n_syn * rate)
    # Keep the synapses NOT selected for deletion.
    keep = torch.randperm(n_syn)[delete_num:]
    new_w = torch.sparse_coo_tensor(indices[:, keep], values[keep],
                                    size=(model.neuron_num, model.neuron_num))
    model.syns.w = new_w
def syn_strength(model, rate):
    """Scale every excitatory (positive-weight) synapse by ``rate``.

    Inhibitory (non-positive) weights are left untouched.  The sparse weight
    matrix is rebuilt with the scaled values.
    """
    w = model.syns.w
    idx = w._indices()
    vals = w._values()
    # Multiply only positive entries; keep negative/zero entries as-is.
    scaled = torch.where(vals > 0, vals * rate, vals)
    model.syns.w = torch.sparse_coo_tensor(
        idx, scaled, size=(model.neuron_num, model.neuron_num))
# --- Run the simulation, applying three successive perturbations ---
Iraster = []    # accumulated (time, neuron-id) spike events
fire_rate = []  # population firing rate per time step
count_n = model.neuron_num  # active-neuron count used to normalise the rate
for t in range(T):
    # Perturbation schedule: at T/4 silence 40% of the neurons, at T/2
    # remove synapses, and at 3T/4 triple the excitatory synapse strength.
    if t == int(T/4):
        count_n = neuron_delete(model, 0.4)
    if t == int(T/4 * 2):
        syn_delete(model, 0.4)
    if t == int(T/4 * 3):
        syn_strength(model, 3)
    model(0)  # advance the network one step with zero external input
    # print(torch.sum(model.neurons.spike))
    Isp = torch.nonzero(model.neurons.spike)  # indices of neurons that fired
    print(len(Isp))
    fire_rate.append(len(Isp)/count_n)
    if (len(Isp) != 0):
        # Build (t, neuron) pairs for this step's spikes.
        left = t * torch.ones((len(Isp)), device=device, requires_grad=False)
        left = left.reshape(len(left), 1)
        mide = torch.concat((left, Isp), dim=1)
    if (len(Isp) != 0) and (len(Iraster) != 0):
        Iraster = torch.concat((Iraster, mide), dim=0)
    if (len(Iraster) == 0) and (len(Isp) != 0):
        Iraster = mide  # first non-empty step initialises the raster tensor
torch.save(fire_rate, './fire_rate.pt')
plt.plot(fire_rate)
plt.xlabel('time/mm')  # NOTE(review): label probably means 'time/ms' — confirm
plt.ylabel('fire_rate')
# plt.axvline(x=[500, 1000, 1500], color='b', linestyle='--')
plt.show()
# Transpose to rows: [spike times, neuron ids] for scatter plotting.
Iraster = torch.tensor(Iraster).transpose(0, 1)
torch.save(Iraster, "./human_MultiCompartmentaEIF100.pt")
Iraster = Iraster.cpu()
plt.figure(figsize=(15, 15))
plt.scatter(Iraster[0], Iraster[1], c='k', marker='.', s=0.001)
# NOTE(review): the tensor is saved as 'human_*' but the figure as
# 'mouse_*' — confirm the intended species/file naming.
plt.savefig('mouse_MultiCompartmentaEIF100.png')
plt.show()
================================================
FILE: examples/Multiscale_Brain_Structure_Simulation/Human_Brain_Model/NA.py
================================================
import numpy as np
import random
import math
import matplotlib.pyplot as plt
import matplotlib
# matplotlib.use('TkAgg')
import scipy.io as scio
import pandas as pd
import torch
import networkx as nx
from collections import defaultdict
import community as community_louvain
from matplotlib.ticker import MaxNLocator, FuncFormatter
def histogram_entropy(data, bins='auto'):
    """Estimate the entropy of one-dimensional data with the histogram method.

    Args:
        data (np.ndarray): one-dimensional data array.
        bins (int or str): number of histogram bins; defaults to 'auto'.

    Returns:
        float: the estimated entropy (natural log).
    """
    hist, edges = np.histogram(data, bins=bins, density=True)
    # density=True gives a pdf; multiply by the (uniform) bin width to get
    # the probability mass of each bin.
    width = edges[1] - edges[0]
    p = hist * width
    p = p[p > 0]  # drop empty bins so log() is well defined
    return -np.sum(p * np.log(p))
def hub_degree(df, W_new):
    """Bar-plot each region's (weighted) degree and highlight in red the
    ten highest-degree hub regions, both as bars and as tick labels."""
    degree = torch.sum(W_new, dim=0)
    _, top_idx = torch.topk(degree, 10)
    top_idx = top_idx.tolist()
    names = df['Identifier']
    plt.figure(figsize=(40, 18))
    plt.bar(names.values, degree)
    # Overdraw the hub bars in red.
    plt.bar(names.iloc[top_idx].values, degree[top_idx], color='r', label='Top 10 Degree')
    plt.gca().yaxis.set_major_locator(MaxNLocator(integer=False, prune='lower', nbins=15))
    plt.xticks(rotation=90, fontsize=30)
    plt.ylabel('Degree', fontsize=40)
    plt.yticks(fontsize=25)
    plt.legend(fontsize=40)
    # Colour hub tick labels red as well.
    hub_names = names.iloc[top_idx].values
    for pos, tick in enumerate(plt.gca().get_xticklabels()):
        if names.iloc[pos] in hub_names:
            tick.set_color('r')
    plt.grid(axis='y')
    plt.show()
def visual(df, W_new):
    """3-D scatter of the regions' (x, y, z) coordinates, with an edge drawn
    between every pair whose connection weight exceeds 0.1."""
    xs = df['x'].values
    ys = df['y'].values
    zs = df['z'].values
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(xs, ys, zs)
    n = len(xs)
    # Upper-triangle pairs only (undirected edges, no self-loops).
    for a in range(n):
        for b in range(a + 1, n):
            if W_new[a, b] > 0.1:
                ax.plot([xs[a], xs[b]],
                        [ys[a], ys[b]],
                        [zs[a], zs[b]],
                        'k-', lw=1)
    plt.show()
if __name__ == "__main__":
    # Load the structural connectivity matrix and keep the first 84 regions.
    W = np.load('./IIT_connectivity_matrix.npy')
    W = torch.from_numpy(W).float()
    W = W[0:84, 0:84]
    # Reorder regions: indices 35-48 are moved to the end
    # (presumably the subcortical/thalamic block — confirm against the atlas).
    new_order = list(range(0, 35)) + list(range(49, 84)) + list(range(35, 49))
    W_new = W[new_order, :][:, new_order]
    M = torch.max(W_new)
    W_new = W_new / M  # normalise weights to [0, 1]
    # NOTE(review): nx.from_numpy_matrix was removed in networkx 3.0
    # (use nx.from_numpy_array there) — confirm the pinned networkx version.
    G = nx.from_numpy_matrix(W_new.numpy())
    # Louvain
    partition = community_louvain.best_partition(G)
    community_groups = defaultdict(list)
    for node, community in partition.items():
        community_groups[community].append(node)
    df = pd.read_csv('brain_regions.csv')
    labels = df['Identifier'].values
    # for community, nodes in community_groups.items():
    #     print(f"Community {community}: {nodes}")
    # Heat-map of the normalised connectivity matrix with region labels.
    fig, ax = plt.subplots(figsize=(20, 20))
    cax = ax.imshow(W_new.cpu().numpy(), cmap='viridis')
    ax.set_xticks(np.arange(len(labels)))
    ax.set_yticks(np.arange(len(labels)))
    ax.set_xticklabels(labels, rotation=90, fontsize=20)
    ax.set_yticklabels(labels, fontsize=20)
    fig.colorbar(cax, shrink=0.8)
    # plt.tight_layout()
    plt.show()
    hub_degree(df, W_new)
================================================
FILE: examples/Multiscale_Brain_Structure_Simulation/Human_Brain_Model/Readme.md
================================================
main_84.py and main_246.py are the scripts that run the simulation. The required data files can be obtained from the following link:
https://drive.google.com/drive/folders/14KPqJsJXIo-bCmGCuDuBadLRmaYn78J2?usp=sharing
================================================
FILE: examples/Multiscale_Brain_Structure_Simulation/Human_Brain_Model/gc.py
================================================
import scipy.io as scio
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy import signal
import torch
import matplotlib.colors as mcolors
# Compare the simulated per-subregion currents against recorded EEG under
# anaesthesia, via cross-spectral density (CSD) in the 3-30 Hz band.
scale = 1.2
version = 1
EEG_m = np.array(torch.load(f'./result/I_subregion_{version}_{scale}.pt').cpu())
EEG_c1 = np.load('./dataset/data_awake_1ug.npy')
EEG_c2 = np.load('./dataset/data_2ug.npy')
EEG_c3 = np.load('./dataset/data_3ug.npy')
EEG_m = np.mean(EEG_m, axis=0)  # average over subregions -> single model trace
lowcut = 3.0  # lower cutoff frequency (Hz)
highcut = 30.0  # upper cutoff frequency (Hz)
fs = 1000  # sampling rate (Hz)
# Design a band-pass filter with a Butterworth design.
# butter() arguments: filter order, normalised frequency range, filter type.
b, a = signal.butter(4, [lowcut / (0.5 * fs), highcut / (0.5 * fs)], btype='band')
# Apply the filter (filtfilt gives zero-phase filtering).
EEG_m = signal.filtfilt(b, a, EEG_m)
EEG_C = EEG_c3  # compare against the 3 ug recording
EEG_C = signal.filtfilt(b, a, EEG_C)
t = 80  # number of 1-second analysis windows
mat_all = np.zeros((t, 30))
for j in range(64):  # presumably 64 EEG channels — confirm data layout
    mat = np.zeros((t, 30))
    for i in range(t):
        # CSD between 1 s of the model trace and channel j of window i.
        f, Cxy = signal.csd(EEG_m[i*1000:(i+1)*1000], EEG_C[i][j], fs=fs, nperseg=1024)
        mat[i] = np.abs(Cxy[:30])  # keep the first 30 frequency bins
    mat_all += mat / np.max(mat)  # accumulate, normalised per channel
plt.figure(figsize=(16,8))
norm = mcolors.LogNorm(vmin=0.001, vmax=1)  # NOTE(review): unused — confirm
cm = plt.cm.get_cmap('jet')
# Time-frequency map averaged over channels.
plt.contourf(np.linspace(0, 8, 80) ,f[:30], mat_all.T / 64, cmap=cm, levels=200)
plt.colorbar()
plt.xlabel('time/min', fontsize=20)
plt.ylabel('Frequency/Hz', fontsize=20)
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)
plt.show()
================================================
FILE: examples/Multiscale_Brain_Structure_Simulation/Human_Brain_Model/main_246.py
================================================
import numpy as np
import random
import math
import matplotlib.pyplot as plt
import scipy.io as scio
import pandas as pd
import torch
import tqdm
device = 'cuda:3'  # GPU used for the simulation
trail = 5          # trial index used in some (commented) output names
version = 4        # version tag embedded in output file names
scale = 0.1        # global scaling factor applied to the connectome weights
class brain_model():
    """Multi-scale spiking model of the human brain (246-region parcellation).

    The first ``C`` (cortical) regions each hold NE excitatory (E) and NI
    inhibitory (I) neurons; the remaining regions are thalamic with TC / TI /
    TRN populations.  Neurons follow an adaptive LIF-style update driven by
    background noise, chemical coupling between regions, and gap junctions
    among inhibitory neurons.
    """

    def __init__(self, W):
        # W: (NR, NR) inter-region connectivity matrix (already scaled).
        self.weight_matrix = W.to(device)
        # Inter-region distances; all zeros here, so all conduction delays
        # are 0 and the history window collapses to length 1.
        self.distance_matrix = torch.zeros_like(self.weight_matrix, device=device)
        self.speed = 1.5  # conduction speed used to turn distance into delay
        self.decay = torch.ceil(self.distance_matrix / self.speed)  # delay (steps) per region pair
        self.t_window = int(torch.max(self.decay)) + 1  # length of the rate-history buffer
        # Per-population parameters, ordered [E, I, TC, TI, TRN]:
        self.V_th = torch.tensor([10., 4., 10., 4., 4.], device=device)      # spike thresholds
        self.tau_v = torch.tensor([40., 10., 30., 20., 40.], device=device)  # membrane time constants
        self.Tsig = torch.tensor([12., 10., 12., 10., 10.], device=device)   # background-noise std
        self.beta = torch.tensor([0., 4.5, 0., 4.5, 4.5], device=device)     # adaptation coupling
        self.alpha_ad = torch.tensor([0., -2., 0., -2., -2.], device=device) # adaptation strength
        self.tau_ad = 20  # adaptation time constant
        self.tau_I = 10   # synaptic/noise current time constant
        # Simulation parameters
        self.NR = len(W)   # number of regions
        self.C = 210       # number of cortical regions; the rest are thalamic
        self.NN = 500      # neurons per cortical region
        self.NType = self.NN * torch.tensor([0.80, 0.20, 0.01, 0.004, 0.004], device=device)
        self.NE = 400      # excitatory neurons per cortical region
        self.NI = 100      # inhibitory neurons per cortical region
        self.NTC = 30      # thalamocortical neurons per thalamic region
        self.NTI = 15      # thalamic interneurons per thalamic region
        self.NTRN = 6      # reticular-nucleus neurons per thalamic region
        self.NC = self.NE + self.NI  # neurons per cortical region
        # Total neuron count over cortical + thalamic regions.
        self.NSum = int((self.C) * (self.NE + self.NI) + (self.NR - self.C) * (self.NTC + self.NTI + self.NTRN))
        self.NT = self.NTC + self.NTI + self.NTRN  # neurons per thalamic region
        print(self.NSum)
        print(self.NC)
        print(self.NTI + self.NTC + self.NTRN)
        self.Ncycle = 1  # number of independent simulation cycles
        self.dt = 1      # integration step (ms)
        self.T = 8000    # simulated steps per cycle
        self.Delta_T = 0.5
        # self.refrac = 5 / self.dt
        # self.ref = self.refrac*torch.zeros((self.NN, 1)).squeeze(1)
        self.gamma_c = 0.1
        self.g_m = 1  # leak conductance
        self.Gama_c = self.g_m * self.gamma_c / (1 - self.gamma_c)  # gap-junction conductance
        # Population-to-population coupling gains (modulated during anaesthesia).
        self.GammaII = 15
        self.GammaIE = -10
        self.GammaEE = 15
        self.GammaEI = 15
        self.TEmean = 0.5 * self.V_th[0]  # Mean current to excitatory neurons
        self.TTCmean = 0.5 * self.V_th[2]  # Mean current to TC neurons
        self.TImean = -5 * self.V_th[1]    # mean (hyperpolarising) drive to I neurons
        self.TTImean = -5 * self.V_th[3]
        self.TTRNmean = -5 * self.V_th[4]
        # Per-neuron state vectors (length NSum):
        self.v = torch.zeros(self.NSum, device=device)        # membrane potential
        self.vt = torch.zeros(self.NSum, device=device)       # per-neuron threshold
        self.c_m = torch.zeros(self.NSum, device=device)      # membrane capacitance
        self.alpha_w = torch.zeros(self.NSum, device=device)  # adaptation coupling
        self.beta_ad = torch.zeros(self.NSum, device=device)  # adaptation gain
        self.delta = torch.ones(self.NSum, device=device)
        self.ad = torch.zeros(self.NSum, device=device)       # adaptation variable
        self.vv = torch.zeros(self.NSum, device=device)       # spike indicator (0/1)
        self.Iback = torch.zeros(self.NSum, device=device)    # OU background noise state
        self.Istimu = torch.zeros(self.NSum, device=device)  # stimulate current
        self.Ieff = torch.zeros(self.NSum, device=device)     # effective noise current
        self.Nmean = torch.zeros(self.NSum, device=device)    # per-neuron mean drive
        self.Nsig = torch.zeros(self.NSum, device=device)     # per-neuron noise std
        self.Igap = torch.zeros(self.NSum, device=device)     # gap-junction current
        self.Ichem = torch.zeros(self.NSum, device=device)    # chemical synaptic current
        self.Ieeg = torch.zeros(self.NSum, device=device)
        self.vm1 = torch.zeros(self.NSum, device=device)      # previous-step potential
        self.reset = torch.zeros(self.NSum, device=device)    # 1 -> reset v on spike
        # Global index lists per population, plus per-region index tables.
        self.E_range = []
        self.I_range = []
        self.TC_range = []
        self.TI_range = []
        self.TRN_range = []
        self.divide_point_E = []
        self.divide_point_I = []
        self.divide_point_TC = []
        self.divide_point_TI_TRN = []
        for n in range(self.NR):
            if n < self.C:
                # Cortical region: NE excitatory followed by NI inhibitory.
                self.divide_point_E.append(list(range(n * self.NC, n * self.NC + self.NE)))
                self.divide_point_I.append(list(range(n * self.NC + self.NE, n * self.NC + self.NE + self.NI)))
                self.E_range = self.E_range + list(range(n * self.NC, n * self.NC + self.NE))
                self.I_range = self.I_range + list(
                    range(n * self.NC + self.NE, n * self.NC + self.NE + self.NI))
            else:
                # Thalamic region: TC, then TI, then TRN neurons.
                s = self.C * self.NC + (n-self.C) * self.NT
                self.divide_point_TC.append(list(range(s, s + self.NTC)))
                self.divide_point_TI_TRN.append(list(range(s + self.NTC,
                                                           s + self.NTC + self.NTI))
                                                + list(range(s + self.NTC + self.NTI,
                                                             s + self.NTC + self.NTI + self.NTRN)))
                self.TC_range = self.TC_range + list(range(s,
                                                           s + self.NTC))
                self.TI_range = self.TI_range + list(range(s + self.NTC,
                                                           s + self.NTC + self.NTI))
                self.TRN_range = self.TRN_range + list(range(s + self.NTC + self.NTI,
                                                             s + self.NTC + self.NTI + self.NTRN))
        self.divide_point_E = torch.tensor(self.divide_point_E, device=device)
        self.divide_point_I = torch.tensor(self.divide_point_I, device=device)
        self.divide_point_TC = torch.tensor(self.divide_point_TC, device=device)
        self.divide_point_TI_TRN = torch.tensor(self.divide_point_TI_TRN, device=device)
        # Per-region excitatory+inhibitory index tables (cortex / thalamus).
        self.divide_point_CR = torch.concat((self.divide_point_E, self.divide_point_I), dim=1)
        self.divide_point_TN = torch.concat((self.divide_point_TC, self.divide_point_TI_TRN), dim=1)
        torch.save({'divide_point_E': self.divide_point_E, 'divide_point_I': self.divide_point_I,
                    'TC_range': self.TC_range, 'TI_range': self.TI_range, 'TRN_range': self.TRN_range},
                   './neuron_divide.pt')
        # Heterogeneous excitatory capacitance (+/- 7.5 uniform jitter).
        self.c_m[self.E_range] = self.tau_v[0] * self.g_m + 15 * (torch.rand(size=(len(self.E_range),), device=device) - 0.5)
        # self.c_m[self.E_range] = self.tau_v[0] * self.g_m
        self.c_m[self.TC_range] = self.tau_v[2] * self.g_m
        self.c_m[self.I_range] = self.tau_v[1] * (self.g_m + self.Gama_c)
        self.c_m[self.TI_range] = self.tau_v[3] * (self.g_m + self.Gama_c)
        self.c_m[self.TRN_range] = self.tau_v[4] * (self.g_m + self.Gama_c)
        self.alpha_w[self.E_range] = self.alpha_ad[0] * self.g_m
        self.alpha_w[self.TC_range] = self.alpha_ad[2] * self.g_m + self.Gama_c
        self.alpha_w[self.I_range] = self.alpha_ad[1] * (self.g_m + self.Gama_c)
        self.alpha_w[self.TI_range] = self.alpha_ad[3] * (self.g_m + self.Gama_c)
        self.alpha_w[self.TRN_range] = self.alpha_ad[4] * (self.g_m + self.Gama_c)
        self.beta_ad[self.E_range] = self.beta[0]
        self.beta_ad[self.TC_range] = self.beta[2]
        self.beta_ad[self.I_range] = self.beta[1]
        self.beta_ad[self.TI_range] = self.beta[3]
        self.beta_ad[self.TRN_range] = self.beta[4]
        self.vt[self.E_range] = self.V_th[0]
        self.vt[self.TC_range] = self.V_th[2]
        self.vt[self.I_range] = self.V_th[1]
        self.vt[self.TI_range] = self.V_th[3]
        self.vt[self.TRN_range] = self.V_th[4]
        # Only E and TC populations reset their potential after a spike.
        self.reset[self.E_range] = 1
        self.reset[self.TC_range] = 1
        self.reset[self.I_range] = 0
        self.reset[self.TI_range] = 0
        self.reset[self.TRN_range] = 0
        self.Nmean[self.E_range] = self.TEmean * self.g_m
        self.Nmean[self.TC_range] = self.TTCmean * self.g_m
        self.Nmean[self.I_range] = self.TImean * (self.g_m + self.Gama_c)
        self.Nmean[self.TI_range] = self.TTImean * (self.g_m + self.Gama_c)
        self.Nmean[self.TRN_range] = self.TTRNmean * (self.g_m + self.Gama_c)
        self.Nsig[self.E_range] = self.Tsig[0] * self.g_m
        self.Nsig[self.TC_range] = self.Tsig[2] * self.g_m
        self.Nsig[self.I_range] = self.Tsig[1] * (self.g_m + self.Gama_c)
        self.Nsig[self.TI_range] = self.Tsig[3] * (self.g_m + self.Gama_c)
        self.Nsig[self.TRN_range] = self.Tsig[4] * (self.g_m + self.Gama_c)

    def simulation(self, per):
        """Run the Ncycle x T time-stepped simulation.

        ``per`` selects the region that would receive the (currently
        disabled) stimulation current.  Per-region mean synaptic currents and
        the spike raster are saved under ./result/.
        """
        range_E = self.E_range + self.TC_range  # all excitatory-like neurons
        range_I = self.I_range + self.TI_range + self.TRN_range  # all inhibitory-like
        Vgap = self.Gama_c  # gap-junction conductance
        weight_matrix = self.weight_matrix
        print(123)
        for i in range(self.Ncycle):
            I_total = torch.zeros((self.Ncycle, self.T), device=device)
            V_total = torch.zeros((self.Ncycle, self.T), device=device)
            V = torch.zeros(self.T, device=device)
            I_subregion = torch.zeros((self.NR, self.T), device=device)
            I_subregion_E = torch.zeros((self.NR, self.T), device=device)
            I_subregion_I = torch.zeros((self.NR, self.T), device=device)
            Vsubregion = torch.zeros((self.NR, self.T), device=device)
            EEG = torch.zeros((self.T), device=device)
            Iraster = []
            # Regional firing-rate history (one column per delay step).
            vv_sumE = torch.zeros((self.NR, self.t_window), device=device)
            vv_sumI = torch.zeros((self.NR, self.t_window), device=device)
            phase = self.T / 8
            for t in tqdm.tqdm(range(self.T)):
                # Anaesthesia protocol: ramp the inhibitory time constant and
                # the I->I / I->E gains up (phases 1-3), hold (3-5), ramp
                # down (5-7), then return to baseline (7-8).
                if t < phase:
                    tau_vI = 20
                    self.GammaII = 15
                    self.GammaIE = -10
                elif phase <= t < 3 * phase:
                    tau_vI = 20 + 20 * (t - phase) / phase
                    self.GammaII = 15 + 20 * (t - phase) / phase
                    self.GammaIE = -10 - 20 * (t - phase) / phase
                elif 3 * phase <= t < 5 * phase:
                    tau_vI = 60
                    self.GammaII = 55
                    self.GammaIE = -50
                elif 5 * phase <= t < 7 * phase:
                    tau_vI = 60 - 20 * (t - 5 * phase) / phase
                    self.GammaII = 55 - 20 * (t - 5 * phase) / phase
                    self.GammaIE = -50 + 20 * (t - 5 * phase) / phase
                elif 7 * phase <= t < 8 * phase:
                    tau_vI = 20
                    self.GammaII = 15
                    self.GammaIE = -10
                self.c_m[range_I] = tau_vI * (self.g_m + self.Gama_c)
                # Effective coupling weights scale with mean capacitance.
                WII = self.GammaII * torch.mean(self.c_m[self.I_range])
                WEE = self.GammaEE * torch.mean(self.c_m[self.E_range])
                WEI = self.GammaEI * torch.mean(self.c_m[self.I_range])
                WIE = self.GammaIE * torch.mean(self.c_m[self.E_range])
                # Ornstein-Uhlenbeck-style background noise current.
                self.Iback = self.Iback + self.dt / self.tau_I * (-self.Iback + torch.randn(self.NSum, device=device))
                self.Ieff = (self.Iback / math.sqrt(1 / (2 * (self.tau_I / self.dt))) * self.Nsig + self.Nmean)
                # Shift the history buffers and append this step's regional
                # mean firing (E/TC into vv_sumE, I/TI/TRN into vv_sumI).
                temp = vv_sumE.clone()
                vv_sumE[:, 0:self.t_window - 1] = temp[:, 1:self.t_window]
                vv_sumE[:, self.t_window - 1] = torch.cat((torch.mean(self.vv[self.divide_point_E], dim=1),
                                                           torch.mean(self.vv[self.divide_point_TC], dim=1)))
                temp = vv_sumI.clone()
                vv_sumI[:, 0:self.t_window - 1] = temp[:, 1:self.t_window]
                vv_sumI[:, self.t_window - 1] = torch.cat((torch.mean(self.vv[self.divide_point_I], dim=1),
                                                           torch.mean(
                                                               self.vv[self.divide_point_TI_TRN], dim=1)))
                # Regional mean inhibitory potential (gap-junction target),
                # broadcast back to per-neuron resolution.
                v_sum = torch.cat((torch.mean(self.v[self.divide_point_I], dim=1),
                                   torch.mean(self.v[self.divide_point_TI_TRN], dim=1)))
                v_sum_CR = v_sum[:self.C].reshape(-1, 1) * \
                    torch.ones((self.C, self.NI), device=device)
                v_sum_CR = v_sum_CR.reshape(-1, 1).squeeze(1)
                v_sum_TN = v_sum[self.C:].reshape(-1, 1) * \
                    torch.ones(self.NR - self.C, self.NTI + self.NTRN, device=device)
                v_sum_TN = v_sum_TN.reshape(-1, 1).squeeze(1)
                v_sum = torch.cat((v_sum_CR, v_sum_TN))
                # Index [region, t - delay] into the history buffers to fetch
                # delayed rates, then weight by the connectome.
                time_decay = torch.concat(
                    (torch.concat([torch.arange(self.NR, device=device).unsqueeze(0)] * self.NR, dim=0).unsqueeze(0),
                     self.t_window - 1 - self.decay.unsqueeze(0)), dim=0)
                time_decay = list(time_decay.long())
                v_E = torch.sum(weight_matrix * vv_sumE[time_decay], dim=1)
                v_I = torch.sum(weight_matrix * vv_sumI[time_decay], dim=1)
                # Broadcast the per-region drive to individual neurons
                # (cortical regions: NC neurons, the 36 thalamic: NT neurons).
                v_E_CR = v_E[:self.NR - 36].reshape(-1, 1) * \
                    torch.ones((self.NR - 36, self.NC), device=device)
                v_I_CR = v_I[:self.NR - 36].reshape(-1, 1) * \
                    torch.ones((self.NR - 36, self.NC), device=device)
                v_E_CR = v_E_CR.reshape(-1, 1).squeeze(1)
                v_I_CR = v_I_CR.reshape(-1, 1).squeeze(1)
                v_E_TN = v_E[self.NR - 36:].reshape(-1, 1) * \
                    torch.ones(36, self.NTC + self.NTI + self.NTRN, device=device)
                v_I_TN = v_I[self.NR - 36:].reshape(-1, 1) * \
                    torch.ones(36, self.NTC + self.NTI + self.NTRN, device=device)
                v_E_TN = v_E_TN.reshape(-1, 1).squeeze(1)
                v_I_TN = v_I_TN.reshape(-1, 1).squeeze(1)
                v_E = torch.cat((v_E_CR, v_E_TN))
                v_I = torch.cat((v_I_CR, v_I_TN))
                # Chemical synaptic currents (low-pass-filtered drive).
                self.Ichem[range_E] = self.Ichem[range_E] + self.dt / self.tau_I * \
                    (-self.Ichem[range_E] + WEE * v_E[range_E]
                     + WIE * v_I[range_E])
                self.Ichem[range_I] = self.Ichem[range_I] + self.dt / self.tau_I * \
                    (-self.Ichem[range_I] + WII * v_I[range_I]
                     + WEI * v_E[range_I])
                # Gap-junction current pulls I neurons toward the regional mean.
                self.Igap[range_I] = Vgap * (
                    v_sum - self.v[range_I])
                # Optional stimulation of region `per` (pulses of
                # 10 - Ichem at 300-700, 1300-1700, 2300-2700, 3300-3700 ms
                # into divide_point_E[per] / divide_point_TC[per]) -- disabled.
                # Adaptive-LIF membrane and adaptation updates.
                self.v = self.v + self.dt / self.c_m * (-self.g_m * self.v +
                                                        self.alpha_w * self.ad + self.Istimu + self.Ieff + self.Ichem + self.Igap)
                self.ad = self.ad + self.dt / self.tau_ad * (-self.ad + self.beta_ad * self.v)
                # Spike = upward threshold crossing this step.
                self.vv = (self.v >= self.vt).float() * (self.vm1 < self.vt).float()
                self.v = self.v * (1 - self.vv * self.reset)  # reset E/TC only
                # self.v = self.v * (1 - self.vv)
                self.vm1 = self.v
                Isp = torch.where(self.vv == 1)[0]
                Iraster.append(torch.stack((t * torch.ones((len(Isp)), device=device), Isp), dim=1))
                # Record mean synaptic current per region (cortex + thalamus).
                I_CR = torch.mean(self.Ichem[self.divide_point_CR], dim=1)
                I_TN = torch.mean(self.Ichem[self.divide_point_TN], dim=1)
                I_subregion[:, t] = torch.cat((I_CR, I_TN), dim=0)
            print('over')
            # torch.save(I_subregion.cpu(), f'./result/I_subregion_{version}_{scale}_{per}.pt')
            torch.save(I_subregion.cpu(), f'./result/I_subregion_{version}_{scale}.pt')
            Iraster = torch.cat(Iraster, dim=0).cpu()
            torch.save(Iraster, f'./result/raster_{version}_{scale}.pt')
            # torch.save(Iraster, f'./result/raster_{version}_{scale}_{per}.pt')
# --- Load the 246-region connectome from CSV and run one simulation ---
file_path = './human.csv'
df = pd.read_csv(file_path, header=None)
W = df.to_numpy()
W = torch.from_numpy(W).float()
# from NA import histogram_entropy
# degree = torch.sum(W, dim=0)
# print(histogram_entropy(degree))
degree = torch.sum(W, dim=1)  # weighted degree per region (currently unused)
W = scale * W.to(device)
with torch.no_grad():
    simulation_model = brain_model(W)
    simulation_model.simulation(0)
# Sweep of stimulation targets over all 210 cortical regions (disabled):
# for per in range(210):
#     with torch.no_grad():
#         simulation_model = brain_model(W)
#         simulation_model.simulation(per)
================================================
FILE: examples/Multiscale_Brain_Structure_Simulation/Human_Brain_Model/main_84.py
================================================
import numpy as np
import random
import math
import matplotlib.pyplot as plt
import scipy.io as scio
import pandas as pd
import torch
import tqdm
device = 'cuda:1'  # GPU used for the simulation
trail = 1          # trial index used in some (commented) output names
version = 1        # version tag embedded in output file names
scale = 1.2        # global scaling factor applied to the connectome weights
class brain_model():
    """Multi-scale spiking model of the human brain (84-region parcellation).

    The first ``C`` = 70 (cortical) regions each hold NE excitatory (E) and
    NI inhibitory (I) neurons; the remaining 14 regions are thalamic with
    TC / TI / TRN populations.  Neurons follow an adaptive LIF-style update
    driven by background noise, chemical coupling between regions, and gap
    junctions among inhibitory neurons.
    """

    def __init__(self, W):
        # W: (NR, NR) inter-region connectivity matrix (already scaled).
        self.weight_matrix = W.to(device)
        # Inter-region distances; all zeros here, so all conduction delays
        # are 0 and the history window collapses to length 1.
        self.distance_matrix = torch.zeros_like(self.weight_matrix, device=device)
        self.speed = 1.5  # conduction speed used to turn distance into delay
        self.decay = torch.ceil(self.distance_matrix / self.speed)  # delay (steps) per region pair
        self.t_window = int(torch.max(self.decay)) + 1  # length of the rate-history buffer
        # Per-population parameters, ordered [E, I, TC, TI, TRN]:
        self.V_th = torch.tensor([10., 4., 10., 4., 4.], device=device)      # spike thresholds
        self.tau_v = torch.tensor([40., 10., 30., 20., 40.], device=device)  # membrane time constants
        self.Tsig = torch.tensor([12., 10., 12., 10., 10.], device=device)   # background-noise std
        self.beta = torch.tensor([0., 4.5, 0., 4.5, 4.5], device=device)     # adaptation coupling
        self.alpha_ad = torch.tensor([0., -2., 0., -2., -2.], device=device) # adaptation strength
        self.tau_ad = 20  # adaptation time constant
        self.tau_I = 10   # synaptic/noise current time constant
        # Simulation parameters
        self.NR = len(W)   # number of regions
        self.C = 70        # number of cortical regions; the rest are thalamic
        self.NN = 500      # neurons per cortical region
        self.NType = self.NN * torch.tensor([0.80, 0.20, 0.01, 0.004, 0.004], device=device)
        self.NE = 400      # excitatory neurons per cortical region
        self.NI = 100      # inhibitory neurons per cortical region
        self.NTC = 30      # thalamocortical neurons per thalamic region
        self.NTI = 15      # thalamic interneurons per thalamic region
        self.NTRN = 6      # reticular-nucleus neurons per thalamic region
        self.NC = self.NE + self.NI  # neurons per cortical region
        # Total neuron count over cortical + thalamic regions.
        self.NSum = int((self.C) * (self.NE + self.NI) + (self.NR - self.C) * (self.NTC + self.NTI + self.NTRN))
        self.NT = self.NTC + self.NTI + self.NTRN  # neurons per thalamic region
        print(self.NSum)
        print(self.NC)
        print(self.NTI + self.NTC + self.NTRN)
        self.Ncycle = 1   # number of independent simulation cycles
        self.dt = 1       # integration step (ms)
        self.T = 80000    # simulated steps per cycle
        self.Delta_T = 0.5
        # self.refrac = 5 / self.dt
        # self.ref = self.refrac*torch.zeros((self.NN, 1)).squeeze(1)
        self.gamma_c = 0.1
        self.g_m = 1  # leak conductance
        self.Gama_c = self.g_m * self.gamma_c / (1 - self.gamma_c)  # gap-junction conductance
        # Population-to-population coupling gains (modulated during anaesthesia).
        self.GammaII = 15
        self.GammaIE = -10
        self.GammaEE = 15
        self.GammaEI = 15
        self.TEmean = 0.5 * self.V_th[0]  # Mean current to excitatory neurons
        self.TTCmean = 0.5 * self.V_th[2]  # Mean current to TC neurons
        self.TImean = -5 * self.V_th[1]    # mean (hyperpolarising) drive to I neurons
        self.TTImean = -5 * self.V_th[3]
        self.TTRNmean = -5 * self.V_th[4]
        # Per-neuron state vectors (length NSum):
        self.v = torch.zeros(self.NSum, device=device)        # membrane potential
        self.vt = torch.zeros(self.NSum, device=device)       # per-neuron threshold
        self.c_m = torch.zeros(self.NSum, device=device)      # membrane capacitance
        self.alpha_w = torch.zeros(self.NSum, device=device)  # adaptation coupling
        self.beta_ad = torch.zeros(self.NSum, device=device)  # adaptation gain
        self.delta = torch.ones(self.NSum, device=device)
        self.ad = torch.zeros(self.NSum, device=device)       # adaptation variable
        self.vv = torch.zeros(self.NSum, device=device)       # spike indicator (0/1)
        self.Iback = torch.zeros(self.NSum, device=device)    # OU background noise state
        self.Istimu = torch.zeros(self.NSum, device=device)  # stimulate current
        self.Ieff = torch.zeros(self.NSum, device=device)     # effective noise current
        self.Nmean = torch.zeros(self.NSum, device=device)    # per-neuron mean drive
        self.Nsig = torch.zeros(self.NSum, device=device)     # per-neuron noise std
        self.Igap = torch.zeros(self.NSum, device=device)     # gap-junction current
        self.Ichem = torch.zeros(self.NSum, device=device)    # chemical synaptic current
        self.Ieeg = torch.zeros(self.NSum, device=device)
        self.vm1 = torch.zeros(self.NSum, device=device)      # previous-step potential
        self.reset = torch.zeros(self.NSum, device=device)    # 1 -> reset v on spike
        # Global index lists per population, plus per-region index tables.
        self.E_range = []
        self.I_range = []
        self.TC_range = []
        self.TI_range = []
        self.TRN_range = []
        self.divide_point_E = []
        self.divide_point_I = []
        self.divide_point_TC = []
        self.divide_point_TI_TRN = []
        for n in range(self.NR):
            if n < self.C:
                # Cortical region: NE excitatory followed by NI inhibitory.
                self.divide_point_E.append(list(range(n * self.NC, n * self.NC + self.NE)))
                self.divide_point_I.append(list(range(n * self.NC + self.NE, n * self.NC + self.NE + self.NI)))
                self.E_range = self.E_range + list(range(n * self.NC, n * self.NC + self.NE))
                self.I_range = self.I_range + list(
                    range(n * self.NC + self.NE, n * self.NC + self.NE + self.NI))
            else:
                # Thalamic region: TC, then TI, then TRN neurons.
                s = self.C * self.NC + (n-self.C) * self.NT
                self.divide_point_TC.append(list(range(s, s + self.NTC)))
                self.divide_point_TI_TRN.append(list(range(s + self.NTC,
                                                           s + self.NTC + self.NTI))
                                                + list(range(s + self.NTC + self.NTI,
                                                             s + self.NTC + self.NTI + self.NTRN)))
                self.TC_range = self.TC_range + list(range(s,
                                                           s + self.NTC))
                self.TI_range = self.TI_range + list(range(s + self.NTC,
                                                           s + self.NTC + self.NTI))
                self.TRN_range = self.TRN_range + list(range(s + self.NTC + self.NTI,
                                                             s + self.NTC + self.NTI + self.NTRN))
        self.divide_point_E = torch.tensor(self.divide_point_E, device=device)
        self.divide_point_I = torch.tensor(self.divide_point_I, device=device)
        self.divide_point_TC = torch.tensor(self.divide_point_TC, device=device)
        self.divide_point_TI_TRN = torch.tensor(self.divide_point_TI_TRN, device=device)
        # Per-region excitatory+inhibitory index tables (cortex / thalamus).
        self.divide_point_CR = torch.concat((self.divide_point_E, self.divide_point_I), dim=1)
        self.divide_point_TN = torch.concat((self.divide_point_TC, self.divide_point_TI_TRN), dim=1)
        torch.save({'divide_point_E': self.divide_point_E, 'divide_point_I': self.divide_point_I,
                    'TC_range': self.TC_range, 'TI_range': self.TI_range, 'TRN_range': self.TRN_range},
                   './neuron_divide.pt')
        # Heterogeneous excitatory capacitance (+/- 7.5 uniform jitter).
        self.c_m[self.E_range] = self.tau_v[0] * self.g_m + 15 * (torch.rand(size=(len(self.E_range),), device=device) - 0.5)
        # self.c_m[self.E_range] = self.tau_v[0] * self.g_m
        self.c_m[self.TC_range] = self.tau_v[2] * self.g_m
        self.c_m[self.I_range] = self.tau_v[1] * (self.g_m + self.Gama_c)
        self.c_m[self.TI_range] = self.tau_v[3] * (self.g_m + self.Gama_c)
        self.c_m[self.TRN_range] = self.tau_v[4] * (self.g_m + self.Gama_c)
        self.alpha_w[self.E_range] = self.alpha_ad[0] * self.g_m
        self.alpha_w[self.TC_range] = self.alpha_ad[2] * self.g_m + self.Gama_c
        self.alpha_w[self.I_range] = self.alpha_ad[1] * (self.g_m + self.Gama_c)
        self.alpha_w[self.TI_range] = self.alpha_ad[3] * (self.g_m + self.Gama_c)
        self.alpha_w[self.TRN_range] = self.alpha_ad[4] * (self.g_m + self.Gama_c)
        self.beta_ad[self.E_range] = self.beta[0]
        self.beta_ad[self.TC_range] = self.beta[2]
        self.beta_ad[self.I_range] = self.beta[1]
        self.beta_ad[self.TI_range] = self.beta[3]
        self.beta_ad[self.TRN_range] = self.beta[4]
        self.vt[self.E_range] = self.V_th[0]
        self.vt[self.TC_range] = self.V_th[2]
        self.vt[self.I_range] = self.V_th[1]
        self.vt[self.TI_range] = self.V_th[3]
        self.vt[self.TRN_range] = self.V_th[4]
        # Only E and TC populations reset their potential after a spike.
        self.reset[self.E_range] = 1
        self.reset[self.TC_range] = 1
        self.reset[self.I_range] = 0
        self.reset[self.TI_range] = 0
        self.reset[self.TRN_range] = 0
        self.Nmean[self.E_range] = self.TEmean * self.g_m
        self.Nmean[self.TC_range] = self.TTCmean * self.g_m
        self.Nmean[self.I_range] = self.TImean * (self.g_m + self.Gama_c)
        self.Nmean[self.TI_range] = self.TTImean * (self.g_m + self.Gama_c)
        self.Nmean[self.TRN_range] = self.TTRNmean * (self.g_m + self.Gama_c)
        self.Nsig[self.E_range] = self.Tsig[0] * self.g_m
        self.Nsig[self.TC_range] = self.Tsig[2] * self.g_m
        self.Nsig[self.I_range] = self.Tsig[1] * (self.g_m + self.Gama_c)
        self.Nsig[self.TI_range] = self.Tsig[3] * (self.g_m + self.Gama_c)
        self.Nsig[self.TRN_range] = self.Tsig[4] * (self.g_m + self.Gama_c)

    def simulation(self, per):
        """Run the Ncycle x T time-stepped simulation.

        ``per`` selects the region that would receive the (currently
        disabled) stimulation current.  Per-region mean synaptic currents and
        the spike raster are saved under ./result/.
        """
        range_E = self.E_range + self.TC_range  # all excitatory-like neurons
        range_I = self.I_range + self.TI_range + self.TRN_range  # all inhibitory-like
        Vgap = self.Gama_c  # gap-junction conductance
        weight_matrix = self.weight_matrix
        print(123)
        for i in range(self.Ncycle):
            I_total = torch.zeros((self.Ncycle, self.T), device=device)
            V_total = torch.zeros((self.Ncycle, self.T), device=device)
            V = torch.zeros(self.T, device=device)
            I_subregion = torch.zeros((self.NR, self.T), device=device)
            I_subregion_E = torch.zeros((self.NR, self.T), device=device)
            I_subregion_I = torch.zeros((self.NR, self.T), device=device)
            Vsubregion = torch.zeros((self.NR, self.T), device=device)
            EEG = torch.zeros((self.T), device=device)
            Iraster = []
            # Regional firing-rate history (one column per delay step).
            vv_sumE = torch.zeros((self.NR, self.t_window), device=device)
            vv_sumI = torch.zeros((self.NR, self.t_window), device=device)
            phase = self.T / 8
            for t in tqdm.tqdm(range(self.T)):
                # Anaesthesia protocol: ramp the inhibitory time constant and
                # the I->I / I->E gains up (phases 1-3), hold (3-5), ramp
                # down (5-7), then return to baseline (7-8).
                if t < phase:
                    tau_vI = 20
                    self.GammaII = 15
                    self.GammaIE = -10
                elif phase <= t < 3 * phase:
                    tau_vI = 20 + 20 * (t - phase) / phase
                    self.GammaII = 15 + 20 * (t - phase) / phase
                    self.GammaIE = -10 - 20 * (t - phase) / phase
                elif 3 * phase <= t < 5 * phase:
                    tau_vI = 60
                    self.GammaII = 55
                    self.GammaIE = -50
                elif 5 * phase <= t < 7 * phase:
                    tau_vI = 60 - 20 * (t - 5 * phase) / phase
                    self.GammaII = 55 - 20 * (t - 5 * phase) / phase
                    self.GammaIE = -50 + 20 * (t - 5 * phase) / phase
                elif 7 * phase <= t < 8 * phase:
                    tau_vI = 20
                    self.GammaII = 15
                    self.GammaIE = -10
                self.c_m[range_I] = tau_vI * (self.g_m + self.Gama_c)
                # Effective coupling weights scale with mean capacitance.
                WII = self.GammaII * torch.mean(self.c_m[self.I_range])
                WEE = self.GammaEE * torch.mean(self.c_m[self.E_range])
                WEI = self.GammaEI * torch.mean(self.c_m[self.I_range])
                WIE = self.GammaIE * torch.mean(self.c_m[self.E_range])
                # Ornstein-Uhlenbeck-style background noise current.
                self.Iback = self.Iback + self.dt / self.tau_I * (-self.Iback + torch.randn(self.NSum, device=device))
                self.Ieff = (self.Iback / math.sqrt(1 / (2 * (self.tau_I / self.dt))) * self.Nsig + self.Nmean)
                # Shift the history buffers and append this step's regional
                # mean firing (E/TC into vv_sumE, I/TI/TRN into vv_sumI).
                temp = vv_sumE.clone()
                vv_sumE[:, 0:self.t_window - 1] = temp[:, 1:self.t_window]
                vv_sumE[:, self.t_window - 1] = torch.cat((torch.mean(self.vv[self.divide_point_E], dim=1),
                                                           torch.mean(self.vv[self.divide_point_TC], dim=1)))
                temp = vv_sumI.clone()
                vv_sumI[:, 0:self.t_window - 1] = temp[:, 1:self.t_window]
                vv_sumI[:, self.t_window - 1] = torch.cat((torch.mean(self.vv[self.divide_point_I], dim=1),
                                                           torch.mean(
                                                               self.vv[self.divide_point_TI_TRN], dim=1)))
                # Regional mean inhibitory potential (gap-junction target),
                # broadcast back to per-neuron resolution.
                v_sum = torch.cat((torch.mean(self.v[self.divide_point_I], dim=1),
                                   torch.mean(self.v[self.divide_point_TI_TRN], dim=1)))
                v_sum_CR = v_sum[:self.C].reshape(-1, 1) * \
                    torch.ones((self.C, self.NI), device=device)
                v_sum_CR = v_sum_CR.reshape(-1, 1).squeeze(1)
                v_sum_TN = v_sum[self.C:].reshape(-1, 1) * \
                    torch.ones(self.NR - self.C, self.NTI + self.NTRN, device=device)
                v_sum_TN = v_sum_TN.reshape(-1, 1).squeeze(1)
                v_sum = torch.cat((v_sum_CR, v_sum_TN))
                # Index [region, t - delay] into the history buffers to fetch
                # delayed rates, then weight by the connectome.
                time_decay = torch.concat(
                    (torch.concat([torch.arange(self.NR, device=device).unsqueeze(0)] * self.NR, dim=0).unsqueeze(0),
                     self.t_window - 1 - self.decay.unsqueeze(0)), dim=0)
                time_decay = list(time_decay.long())
                v_E = torch.sum(weight_matrix * vv_sumE[time_decay], dim=1)
                v_I = torch.sum(weight_matrix * vv_sumI[time_decay], dim=1)
                # Broadcast the per-region drive to individual neurons
                # (cortical regions: NC neurons, the 14 thalamic: NT neurons).
                v_E_CR = v_E[:self.NR - 14].reshape(-1, 1) * \
                    torch.ones((self.NR - 14, self.NC), device=device)
                v_I_CR = v_I[:self.NR - 14].reshape(-1, 1) * \
                    torch.ones((self.NR - 14, self.NC), device=device)
                v_E_CR = v_E_CR.reshape(-1, 1).squeeze(1)
                v_I_CR = v_I_CR.reshape(-1, 1).squeeze(1)
                v_E_TN = v_E[self.NR - 14:].reshape(-1, 1) * \
                    torch.ones(14, self.NTC + self.NTI + self.NTRN, device=device)
                v_I_TN = v_I[self.NR - 14:].reshape(-1, 1) * \
                    torch.ones(14, self.NTC + self.NTI + self.NTRN, device=device)
                v_E_TN = v_E_TN.reshape(-1, 1).squeeze(1)
                v_I_TN = v_I_TN.reshape(-1, 1).squeeze(1)
                v_E = torch.cat((v_E_CR, v_E_TN))
                v_I = torch.cat((v_I_CR, v_I_TN))
                # Chemical synaptic currents (low-pass-filtered drive).
                self.Ichem[range_E] = self.Ichem[range_E] + self.dt / self.tau_I * \
                    (-self.Ichem[range_E] + WEE * v_E[range_E]
                     + WIE * v_I[range_E])
                self.Ichem[range_I] = self.Ichem[range_I] + self.dt / self.tau_I * \
                    (-self.Ichem[range_I] + WII * v_I[range_I]
                     + WEI * v_E[range_I])
                # Gap-junction current pulls I neurons toward the regional mean.
                self.Igap[range_I] = Vgap * (
                    v_sum - self.v[range_I])
                # Optional stimulation of region `per` (pulses of
                # 15 - Ichem at 300-700, 1300-1700, 2300-2700, 3300-3700 ms
                # into divide_point_E[per] / divide_point_TC[per]) -- disabled.
                # Adaptive-LIF membrane and adaptation updates.
                self.v = self.v + self.dt / self.c_m * (-self.g_m * self.v +
                                                        self.alpha_w * self.ad + self.Istimu + self.Ieff + self.Ichem + self.Igap)
                self.ad = self.ad + self.dt / self.tau_ad * (-self.ad + self.beta_ad * self.v)
                # Spike = upward threshold crossing this step.
                self.vv = (self.v >= self.vt).float() * (self.vm1 < self.vt).float()
                self.v = self.v * (1 - self.vv * self.reset)  # reset E/TC only
                # self.v = self.v * (1 - self.vv)
                self.vm1 = self.v
                Isp = torch.where(self.vv == 1)[0]
                Iraster.append(torch.stack((t * torch.ones((len(Isp)), device=device), Isp), dim=1))
                # Record mean synaptic current per region (cortex + thalamus).
                I_CR = torch.mean(self.Ichem[self.divide_point_CR], dim=1)
                I_TN = torch.mean(self.Ichem[self.divide_point_TN], dim=1)
                I_subregion[:, t] = torch.cat((I_CR, I_TN), dim=0)
            print('over')
            torch.save(I_subregion.cpu(), f'./result/I_subregion_{version}_{scale}.pt')
            Iraster = torch.cat(Iraster, dim=0).cpu()
            # torch.save(Iraster, f'./result/raster_{version}_{scale}:{trail}.pt')
            torch.save(Iraster, f'./result/raster_{version}_{scale}.pt')
# --- Load, reorder and normalise the 84-region connectome, then simulate ---
W = np.load('./IIT_connectivity_matrix.npy')
W = torch.from_numpy(W).float()
W = W[0:84, 0:84]
# Move regions 35-48 to the end (presumably the subcortical/thalamic block).
new_order = list(range(0,35)) + list(range(49,84)) + list(range(35,49))
W_new = W[new_order, :][:, new_order]
M = torch.max(W_new)
W_new = W_new / M  # normalise weights to [0, 1]
W_new = scale * W_new.to(device)
# Converts continuous value weights to binary weights
# for i in range(len(W_new)):
#     for j in range(len(W_new)):
#         if W_new[i][j] > 0.1 * scale:
#             W_new[i][j] = scale
#         else:
#             W_new[i][j] = 0
# The entropy of the degree distribution of the connection matrix is calculated by the histogram method
# from NA import histogram_entropy
# degree = torch.sum(W_new, dim=0)
# print(histogram_entropy(degree))
simulation_model = brain_model(W_new)
simulation_model.simulation(0)
for per in range(70):  # Select the brain region to be injected with the stimulation current
    simulation_model = brain_model(W_new)
    simulation_model.simulation(per)
================================================
FILE: examples/Multiscale_Brain_Structure_Simulation/Human_Brain_Model/pci.py
================================================
import matplotlib.pyplot as plt
import torch
import numpy as np
import pandas as pd
# Neuron-index range of each of the 84 brain regions: regions 0-69 hold
# 500 neurons each, regions 70-83 hold 51 neurons each.
range_list = [
    [i * 500, (i + 1) * 500] if i < 70
    else [70 * 500 + (i - 70) * 51, 70 * 500 + (i + 1 - 70) * 51]
    for i in range(84)
]
def generate_rm(Iraster):
    """Bin a spike raster into per-region firing-rate matrices.

    The raster is split into four consecutive 1000 ms epochs; within each
    epoch the firing rate of every region in the module-level ``range_list``
    is computed in 40 ms bins.  (The original repeated the same bin-rate
    computation four times verbatim, once per epoch offset.)

    :param Iraster: (n_spikes, 2) tensor; column 0 holds spike times [ms],
        column 1 holds global neuron indices.
    :return: tuple of four (len(range_list), 25) numpy arrays, one per epoch.
    """
    time_window = 40
    n_bins = int(1000 / time_window)
    rms = [np.zeros((len(range_list), n_bins)) for _ in range(4)]
    times = Iraster[:, 0]
    for i, (lo, hi) in enumerate(range_list):
        for ji, j in enumerate(range(0, 1000, time_window)):
            for epoch, rm in enumerate(rms):
                start = j + epoch * 1000
                # spikes falling inside this 40 ms bin of the current epoch
                in_bin = Iraster[torch.where(
                    (times >= start) & (times < start + time_window))[0]]
                # ... restricted to the neurons belonging to region i
                neuron = in_bin[:, 1]
                region = in_bin[torch.where((neuron >= lo) & (neuron < hi))[0]]
                # rate = spike count / (bin length x region size)
                rm[i][ji] = len(region) / (time_window * (hi - lo))
    return tuple(rms)
def lempel_ziv_complexity(data):
    """Count the number of distinct patterns (Lempel-Ziv complexity) of a
    binary matrix, scanning column by column.

    MATLAB-style LZ counting loop as used for the perturbational complexity
    index (PCI) computation in this example.

    :param data: 2-D array of 0/1 integers (rows x columns).
    :return: integer complexity count ``c``.
    """
    c = 1  # pattern counter
    r = 1  # current column being parsed (1-based)
    q = 1  # column currently searched for a recurrence of the pattern
    k = 1  # length of the current candidate pattern
    i = 1  # start row of the current candidate pattern
    L1 = data.shape[0]  # rows per column
    L2 = data.shape[1]  # number of columns
    while 1:
        # When searching within the current column, only look at the part
        # before the candidate pattern itself.
        if q == r:
            a = i + k - 1
        else:
            a = L1
        # NOTE(review): str() of 0/1 entries yields one character each, so
        # the 'in' test below is a substring search over the column's bits.
        if ''.join(map(str, data[i:i + k, r - 1])) in ''.join(map(str, data[0:a, q - 1])):
            # Pattern seen before: extend it by one symbol.
            k = k + 1
            if i + k > L1:
                # Reached the end of the column: advance to the next column.
                r = r + 1
                if r > L2:
                    break
                else:
                    i = 0
                    q = r - 1
                    k = 1
        else:
            # Not found in column q: try an earlier column.
            q = q - 1
            if q < 1:
                # Exhausted all columns: the candidate is a new pattern.
                c = c + 1
                i = i + k
                if i + 1 > L1:
                    r = r + 1
                    if r > L2:
                        break
                    else:
                        i = 0
                        q = r - 1
                        k = 1
                else:
                    q = r
                    k = 1
    c = c + 1
    return c
scale = 1.2
version = 2
# Baseline (unperturbed) rasters of five repeated trials of the same run.
Iraster1 = torch.load(f'./result/raster_{version}_{scale}:1.pt').cpu()
Iraster2 = torch.load(f'./result/raster_{version}_{scale}:2.pt').cpu()
Iraster3 = torch.load(f'./result/raster_{version}_{scale}:3.pt').cpu()
Iraster4 = torch.load(f'./result/raster_{version}_{scale}:4.pt').cpu()
Iraster5 = torch.load(f'./result/raster_{version}_{scale}:5.pt').cpu()


def _pci(rm_p, rm):
    """PCI of one epoch: normalized Lempel-Ziv complexity of the binarized
    difference between perturbed (rm_p) and baseline (rm) rate matrices.

    Factors out the computation the original repeated four times per
    perturbation site.  Prints and returns the PCI value.
    """
    bm = (np.abs(rm_p - rm) > 0.001).astype(int)
    c = lempel_ziv_complexity(bm)
    p1 = np.mean(bm)
    # Source entropy of the binary matrix.  The 1e-12 terms guard log2(0):
    # the original only guarded the p1 term, so an all-different matrix
    # (p1 == 1) produced -inf via log2(1 - p1).
    HL = - p1 * np.log2(p1 + 1e-12) - (1 - p1) * np.log2(1 - p1 + 1e-12) + 1e-12
    L = bm.shape[0] * bm.shape[1]
    pci = c * (np.log2(L) / L) / HL
    print(pci)
    return pci


for x, Iraster in enumerate([Iraster1, Iraster2, Iraster3, Iraster4, Iraster5]):
    pcis = [[], [], [], []]
    # Baseline rate matrices do not depend on the perturbation site; hoisted
    # (the original recomputed them for every one of the 84 sites).
    rms = generate_rm(Iraster)
    for per in range(0, 84):
        print(per)
        # Raster of the trial in which region `per` received the stimulation.
        Iraster_p = torch.load(f'./result/raster_{version}_{scale}_{per}.pt').cpu()
        rms_p = generate_rm(Iraster_p)
        for e in range(4):
            pcis[e].append(_pci(rms_p[e], rms[e]))
    np.save(f'pci_all_{version}_{x}.npy', pcis)
================================================
FILE: examples/Multiscale_Brain_Structure_Simulation/Human_Brain_Model/pci_246.py
================================================
import matplotlib.pyplot as plt
import torch
import numpy as np
import pandas as pd
# Neuron-index range of each of the 246 brain regions: regions 0-209 hold
# 500 neurons each, regions 210-245 hold 51 neurons each.
range_list = [
    [i * 500, (i + 1) * 500] if i < 210
    else [210 * 500 + (i - 210) * 51, 210 * 500 + (i + 1 - 210) * 51]
    for i in range(246)
]
def generate_rm(Iraster):
    """Bin a spike raster into per-region firing-rate matrices.

    The raster is split into four consecutive 1000 ms epochs; within each
    epoch the firing rate of every region in the module-level ``range_list``
    is computed in 40 ms bins.  (The original repeated the same bin-rate
    computation four times verbatim, once per epoch offset.)

    :param Iraster: (n_spikes, 2) tensor; column 0 holds spike times [ms],
        column 1 holds global neuron indices.
    :return: tuple of four (len(range_list), 25) numpy arrays, one per epoch.
    """
    time_window = 40
    n_bins = int(1000 / time_window)
    rms = [np.zeros((len(range_list), n_bins)) for _ in range(4)]
    times = Iraster[:, 0]
    for i, (lo, hi) in enumerate(range_list):
        for ji, j in enumerate(range(0, 1000, time_window)):
            for epoch, rm in enumerate(rms):
                start = j + epoch * 1000
                # spikes falling inside this 40 ms bin of the current epoch
                in_bin = Iraster[torch.where(
                    (times >= start) & (times < start + time_window))[0]]
                # ... restricted to the neurons belonging to region i
                neuron = in_bin[:, 1]
                region = in_bin[torch.where((neuron >= lo) & (neuron < hi))[0]]
                # rate = spike count / (bin length x region size)
                rm[i][ji] = len(region) / (time_window * (hi - lo))
    return tuple(rms)
def lempel_ziv_complexity(data):
    """Count the number of distinct patterns (Lempel-Ziv complexity) of a
    binary matrix, scanning column by column.

    MATLAB-style LZ counting loop as used for the perturbational complexity
    index (PCI) computation in this example.

    :param data: 2-D array of 0/1 integers (rows x columns).
    :return: integer complexity count ``c``.
    """
    c = 1  # pattern counter
    r = 1  # current column being parsed (1-based)
    q = 1  # column currently searched for a recurrence of the pattern
    k = 1  # length of the current candidate pattern
    i = 1  # start row of the current candidate pattern
    L1 = data.shape[0]  # rows per column
    L2 = data.shape[1]  # number of columns
    while 1:
        # When searching within the current column, only look at the part
        # before the candidate pattern itself.
        if q == r:
            a = i + k - 1
        else:
            a = L1
        # NOTE(review): str() of 0/1 entries yields one character each, so
        # the 'in' test below is a substring search over the column's bits.
        if ''.join(map(str, data[i:i + k, r - 1])) in ''.join(map(str, data[0:a, q - 1])):
            # Pattern seen before: extend it by one symbol.
            k = k + 1
            if i + k > L1:
                # Reached the end of the column: advance to the next column.
                r = r + 1
                if r > L2:
                    break
                else:
                    i = 0
                    q = r - 1
                    k = 1
        else:
            # Not found in column q: try an earlier column.
            q = q - 1
            if q < 1:
                # Exhausted all columns: the candidate is a new pattern.
                c = c + 1
                i = i + k
                if i + 1 > L1:
                    r = r + 1
                    if r > L2:
                        break
                    else:
                        i = 0
                        q = r - 1
                        k = 1
                else:
                    q = r
                    k = 1
    c = c + 1
    return c
scale = 0.1
version = 4
# Baseline (unperturbed) raster of the 246-region run.
Iraster1 = torch.load(f'./result/raster_{version}_{scale}.pt').cpu()


def _pci(rm_p, rm):
    """PCI of one epoch: normalized Lempel-Ziv complexity of the binarized
    difference between perturbed (rm_p) and baseline (rm) rate matrices.

    Factors out the computation the original repeated four times per
    perturbation site.  Prints and returns the PCI value.
    """
    bm = (np.abs(rm_p - rm) > 0.001).astype(int)
    c = lempel_ziv_complexity(bm)
    p1 = np.mean(bm)
    # Source entropy of the binary matrix.  The 1e-12 terms guard log2(0):
    # the original only guarded the p1 term, so an all-different matrix
    # (p1 == 1) produced -inf via log2(1 - p1).
    HL = - p1 * np.log2(p1 + 1e-12) - (1 - p1) * np.log2(1 - p1 + 1e-12) + 1e-12
    L = bm.shape[0] * bm.shape[1]
    pci = c * (np.log2(L) / L) / HL
    print(pci)
    return pci


# Single baseline trial here (pci.py iterates five); the unused `x` counter
# of the original has been dropped.
for Iraster in [Iraster1]:
    pcis = [[], [], [], []]
    # Baseline rate matrices do not depend on the perturbation site; hoisted
    # (the original recomputed them for every one of the 246 sites).
    rms = generate_rm(Iraster)
    for per in range(0, 246):
        print(per)
        # Raster of the trial in which region `per` received the stimulation.
        Iraster_p = torch.load(f'./result/raster_{version}_{scale}_{per}.pt').cpu()
        rms_p = generate_rm(Iraster_p)
        for e in range(4):
            pcis[e].append(_pci(rms_p[e], rms[e]))
    np.save(f'pci_all_{version}_246.npy', pcis)
================================================
FILE: examples/Multiscale_Brain_Structure_Simulation/Human_Brain_Model/spectrogram.py
================================================
import scipy.io as scio
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from mpl_toolkits.mplot3d import Axes3D
from scipy.fftpack import fft,ifft
from scipy import signal
from scipy.fft import fftshift
import json
# Run selectors used to build the result file names.
trail = 5
scale = 1.2
version = 2
per = 2
# Iraster = torch.load(f'./result/raster_{version}_{scale}.pt').cpu()
# Spike raster: column 0 = spike time [ms], column 1 = neuron index.
Iraster = torch.load(f'./result/raster_{version}_{scale}.pt').cpu()
time = Iraster[:, 0]
# Keep spikes from the first 8000 ms.
mask = (time >= 0) & (time < 8000)
indices = torch.where(mask)
spike = Iraster[indices[0]]
neuron = spike[:, 1]
# mask = (neuron >= 17000) & (neuron < 18000)
# Keep all neurons (the commented mask above selects a single range instead).
mask = (neuron >= 0)
indices = torch.where(mask)
spike = spike[indices[0]]
plt.figure(figsize=(20, 12))
plt.scatter(spike[:, 0], spike[:, 1], s=0.1)
plt.xlabel('time [ms]', fontsize=20)
plt.ylabel('Neuron index', fontsize=20)
plt.title(f'{scale}')
# NOTE(review): plt.show() takes no dpi argument in current matplotlib;
# dpi belongs to plt.figure()/savefig() — confirm intent.
plt.show(dpi=600)
# Per-subregion synaptic-current traces saved by the simulation.
data = np.array(torch.load(f'./result/I_subregion_{version}_{scale}.pt').cpu())
fs = 1000           # sampling rate [Hz]
time_window = 1024  # STFT window length [samples]
# Configure the band-pass filter (order 2, normalized band 0.002-0.06,
# i.e. 1-30 Hz at fs = 1000).  (Translated from the original Chinese comment.)
b, a = signal.butter(2, [0.002, 0.06], 'bandpass')
# Zero-phase filter `data`, the signal to be filtered.
data = signal.filtfilt(b, a, data)
def region_sxx(region):
    """Plot the STFT spectrogram (first 30 frequency bins) of one region's
    current trace and return the plotted magnitude.

    :param region: row index into the module-level ``data`` array.
    :return: |STFT| restricted to frequency bins 0-29.
    """
    plt.figure(figsize=(16, 8))
    f, t, sxx = signal.stft(data[region], fs=fs, nperseg=time_window,
                            noverlap=time_window / 2)
    magnitude = np.abs(sxx[0:30])
    plt.contourf(t, f[0:30], magnitude, cmap=plt.cm.get_cmap('jet'), levels=200)
    plt.colorbar()
    plt.xlabel('time/min', fontsize=20)
    plt.ylabel('Frequency/Hz', fontsize=20)
    plt.xticks(fontsize=15)
    plt.yticks(fontsize=15)
    plt.show()
    return magnitude
def global_sxx():
    """Plot the STFT spectrogram of the mean signal across all regions."""
    plt.figure(figsize=(16, 8))
    mean_signal = np.mean(data, axis=0)
    f, t, sxx = signal.stft(mean_signal, fs=fs, nperseg=time_window,
                            noverlap=time_window / 2)
    print(sxx.shape)
    cm = plt.cm.get_cmap('jet')
    # plt.pcolormesh(t, f[2:10], np.abs(sxx[2:10]), cmap=cm, shading='auto')
    plt.contourf(t, f[0:30], np.abs(sxx[0:30]), cmap=cm, levels=200)
    plt.colorbar()
    plt.xlabel('time/min', fontsize=20)
    plt.ylabel('Frequency/Hz', fontsize=20)
    plt.xticks(fontsize=15)
    plt.yticks(fontsize=15)
    plt.show()
def compare_sxx():
    """Plot, for every region, the time course of the maximum STFT magnitude
    within the first 30 frequency bins, and return that (84, n_frames) matrix.

    The original computed region 0 separately before the loop and np.vstack-ed
    one row at a time (quadratic copying); the loop now covers all 84 regions
    uniformly and stacks once.

    :return: |sm| — non-negative per-region peak-magnitude matrix.
    """
    f_band = range(0, 30)
    rows = []
    for col in range(0, 84):
        f, t, sxx = signal.stft(data[col], fs=fs, nperseg=time_window,
                                noverlap=time_window / 2)
        rows.append(np.max(np.abs(sxx[f_band]), axis=0))
    sm = np.vstack(rows)
    cm = plt.cm.get_cmap('jet')
    plt.pcolormesh(t, range(0, 84), np.abs(sm), cmap=cm, shading='auto')
    # plt.pcolormesh(t, f, sxx[5:50,:],cmap=cm)
    plt.colorbar()
    plt.ylabel('Brain Regions', fontsize=10)
    plt.xlabel('Time [min]', fontsize=10)
    plt.xticks(fontsize=10)
    plt.yticks(fontsize=10)
    plt.show()
    return np.abs(sm)
def fit(xx, yy):
    """Ordinary least-squares fit of a line ``y = w * x + b``.

    :param xx: sequence of x values.
    :param yy: sequence of y values (same length as ``xx``).
    :return: (w, b) — slope and intercept.
    """
    n_points = len(xx)
    x_bar = np.average(xx)
    # Slope: w = sum(y * (x - x_bar)) / (sum(x^2) - n * x_bar^2).
    cross_sum = 0
    square_sum = 0
    for idx in range(n_points):
        cross_sum += yy[idx] * (xx[idx] - x_bar)
        square_sum += xx[idx] ** 2
    w = cross_sum / (square_sum - n_points * (x_bar ** 2))
    # Intercept: mean residual of y - w * x.
    residual_sum = 0
    for idx in range(n_points):
        residual_sum += yy[idx] - w * xx[idx]
    b = residual_sum / n_points
    return w, b
# Structural connectivity, preprocessed the same way as in the simulation:
# keep 84 regions, reorder so cortical regions come first and subcortical
# (original indices 35-48) last, normalize to [0, 1], then scale.
W = np.load('./IIT_connectivity_matrix.npy')
W = torch.from_numpy(W).float()
W = W[0:84, 0:84]
new_order = list(range(0,35)) + list(range(49,84)) + list(range(35,49))
W_new = W[new_order, :][:, new_order]
M = torch.max(W_new)
W_new = W_new / M
W = scale * W_new
# Weighted in-/out-degree of every region (row/column sums).
in_degree = torch.sum(W, dim=1).numpy()
out_degree = torch.sum(W, dim=0).numpy()
global_sxx()
sm = compare_sxx()
# Three panels: in-degree vs. mean spectral power in three consecutive
# STFT-frame windows, labelled Awake / Micro-consciousness / Unconsciousness.
fig, axs = plt.subplots(1, 3, figsize=(15, 5))
# Panel 0: frames 0-2.  Cortical regions are rows 0-69, subcortical 70-83.
axs[0].scatter(in_degree[0:70],np.mean(sm[0:70,0:3], axis=1), c='blue', label='Cortical')
axs[0].scatter(in_degree[70:],np.mean(sm[70:,0:3], axis=1), c='orange', label='Subcortical', marker="^")
w1, b1 = fit(in_degree[0:70],np.mean(sm[0:70,0:3], axis=1))
w2, b2 = fit(in_degree[70:],np.mean(sm[70:,0:3], axis=1))
x1 = np.linspace(0, 6, 100)
y1 = w1 * x1 + b1
x2 = np.linspace(0, 6, 100)
y2 = w2 * x2 + b2
axs[0].plot(x1, y1, c='blue')
axs[0].plot(x2, y2, c='orange', linestyle='--')
# NOTE(review): np.corrcoef yields the Pearson r, but the annotation below
# is labelled r^2 — verify which quantity is intended.
r1 = np.corrcoef(in_degree[0:70],np.mean(sm[0:70,0:3], axis=1))[0, 1]
r2 = np.corrcoef(in_degree[70:],np.mean(sm[70:,0:3], axis=1))[0, 1]
axs[0].text(x1[-20], y1[-20]+0.5, f'$r^2={r1:.2f}$', fontsize=15, color='black')
axs[0].text(x2[-20], y2[-20]+1, f'$r^2={r2:.2f}$', fontsize=15, color='black')
axs[0].set_title("Awake", fontsize=15)
axs[0].tick_params(axis='x', labelsize=15)
axs[0].tick_params(axis='y', labelsize=15)
axs[0].legend()
# Panel 1: frames 3-6.
axs[1].scatter(in_degree[0:70],np.mean(sm[0:70,3:7], axis=1), c='blue', label='Cortical')
axs[1].scatter(in_degree[70:],np.mean(sm[70:,3:7], axis=1), c='orange', label='Subcortical', marker="^")
w1, b1 = fit(in_degree[0:70],np.mean(sm[0:70,3:7], axis=1))
w2, b2 = fit(in_degree[70:],np.mean(sm[70:,3:7], axis=1))
x1 = np.linspace(0, 6, 100)
y1 = w1 * x1 + b1
x2 = np.linspace(0, 6, 100)
y2 = w2 * x2 + b2
axs[1].plot(x1, y1, c='blue')
axs[1].plot(x2, y2, c='orange', linestyle='--')
# NOTE(review): 0.01 is subtracted from the correlation coefficients here
# and in panel 2 (but not panel 0) — confirm this adjustment is intended.
r1 = np.corrcoef(in_degree[0:70],np.mean(sm[0:70,3:7], axis=1))[0, 1] - 0.01
r2 = np.corrcoef(in_degree[70:],np.mean(sm[70:,3:7], axis=1))[0, 1]-0.01
axs[1].text(x1[-20], y1[-20]+1, f'$r^2={r1:.2f}$', fontsize=15, color='black')
axs[1].text(x2[-20], y2[-20]+1, f'$r^2={r2:.2f}$', fontsize=15, color='black')
axs[1].set_title("Micro-consciousness", fontsize=15)
axs[1].tick_params(axis='x', labelsize=15)
axs[1].tick_params(axis='y', labelsize=15)
axs[1].legend()
# Panel 2: frames 7-9.
axs[2].scatter(in_degree[0:70],np.mean(sm[0:70,7:10], axis=1), c='blue', label='Cortical')
axs[2].scatter(in_degree[70:],np.mean(sm[70:,7:10], axis=1), c='orange', label='Subcortical', marker="^")
w1, b1 = fit(in_degree[0:70],np.mean(sm[0:70,7:10], axis=1))
w2, b2 = fit(in_degree[70:],np.mean(sm[70:,7:10], axis=1))
x1 = np.linspace(0, 6, 100)
y1 = w1 * x1 + b1
x2 = np.linspace(0, 6, 100)
y2 = w2 * x2 + b2
axs[2].plot(x1, y1, c='blue')
axs[2].plot(x2, y2, c='orange', linestyle='--')
r1 = np.corrcoef(in_degree[0:70],np.mean(sm[0:70,7:10], axis=1))[0, 1] - 0.01
r2 = np.corrcoef(in_degree[70:],np.mean(sm[70:,7:10], axis=1))[0, 1]-0.01
axs[2].text(x1[-20], y1[-20]+1, f'$r^2={r1:.2f}$', fontsize=15, color='black')
axs[2].text(x2[-20], y2[-20]+1, f'$r^2={r2:.2f}$', fontsize=15, color='black')
axs[2].set_title("Unconsciousness", fontsize=15)
axs[2].tick_params(axis='x', labelsize=15)
axs[2].tick_params(axis='y', labelsize=15)
axs[2].legend()
plt.tight_layout()
plt.show()
================================================
FILE: examples/Multiscale_Brain_Structure_Simulation/Human_PFC_Model/README.md
================================================
## Input:
* The program takes electrophysiological data from six cortical columns as input; the number in a data file's name indicates the number of neurons. Background current input is enabled by default. Files named `data<number>` contain random input stimuli, while the picture-stimulus input files are provided for human parameters and mouse parameters, respectively. Both support four picture shapes: circle, square, triangle, and star.
Link:https://drive.google.com/drive/folders/1AVc2aNTxkcsGAPlq1SuWtatGzyQRPCmp?usp=sharing
## Output:
* A data file, generated by the program, recording the firing time points of each neuron.
## Application:
* The program can be modified to adapt the discharge environment of each PFC model.
================================================
FILE: examples/Multiscale_Brain_Structure_Simulation/Human_PFC_Model/Six_Layer_PFC.py
================================================
import scipy.io as scio
import math
import random as rand
import copy
import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from braincog.base.learningrule.STP import short_time
class six_layer_pfc():
    """
    Holder of the global parameters of the six-layer PFC network simulation.

    :param SizeHistOutput: Set the peak value of the number of EPSP considered to be modified
    :param SizeHistInput: Set the number of possible spikes in the input neuron
    """
    def __init__(self):
        self.pi = 3.14159265418
        self.MaxNumSTperN = 20        # max synapse types per neuron (buffer size)
        self.SizeHistOutput = 10      # length of the per-neuron spike-history ring buffer
        self.SizeHistInput = 1000000  # capacity of each input spike train
        self.NumCtrPar = 5            # number of control parameters (CtrPar columns)
        self.NumVar = 2               # state variables per neuron (v[0], v[1])
        self.NumNeuPar = 12           # rows of NeuPar consumed per neuron
        self.NumSynTypePar = 8        # rows of STypPar per synapse type
        self.NumSynPar = 7            # rows of SynPar per synapse
        self.TRUE = 1                 # C-style boolean constants used throughout
        self.FALSE = 0
def picture(self, path=None):
    """Scatter-plot the spike raster stored in a results .mat file.

    :param path: path to a .mat file containing an 'STMtx' cell array of
        per-neuron spike-time vectors terminated by a -1 entry.
    """
    data = scio.loadmat(path)
    STMtx = data['STMtx']
    neuron = []
    time = []
    # One spike-time vector per neuron; index neurons from 1.
    for n, cell in enumerate(STMtx[0], start=1):
        for spike_t in cell[0]:
            if spike_t == -1:
                break  # -1 marks the end of this neuron's spike train
            neuron.append(n)
            time.append(spike_t)
    neuron = np.array(neuron)
    time = np.array(time)
    plt.scatter(time, neuron, c='k', marker='.')
    plt.show()
def mex_function(self, path=None):
    """
    Run the event-driven network simulation defined by a MATLAB parameter file.

    Python port of a C mex function: loads neuron/synapse parameters from
    ``path``, builds the connection matrix, integrates the network from
    Tstart to Tstop, writes voltage/conductance traces to ``IDN_*.dat`` /
    ``IDN2_*.dat`` and spike times to ``ISIu*_*.dat``, and finally bundles
    everything into ``PFC_<N+M>N_500ms.mat``.

    :param CtrPar: Store electrophysiological parameters of neurons
    :param NumViewGroups: Create arrays and parameters related to synaptic preservation of neuronal groups
    """
    data = scio.loadmat(path)
    # Local aliases for the global constants.
    pi = self.pi
    MaxNumSTperN = self.MaxNumSTperN
    SizeHistOutput = self.SizeHistOutput
    SizeHistInput = self.SizeHistInput
    NumCtrPar = self.NumCtrPar
    NumVar = self.NumVar
    NumNeuPar = self.NumNeuPar
    NumSynTypePar = self.NumSynTypePar
    NumSynPar = self.NumSynPar
    TRUE = self.TRUE
    FALSE = self.FALSE
    # Parameter matrices from the .mat file.
    CtrPar = data['CtrPar']
    NeuPar = data['NeuPar']
    NPList = data['NPList']
    STypPar = data['STypPar']
    SynPar = data['SynPar']
    SPMtx = data['SPMtx']
    EvtMtx = data['evtmtx']
    EvtTimes = data['evttimes']
    ViewList = data['ViewList']
    InpSTtrains = data['InpSTtrains']
    NoiseDistr = data['NoiseDistr']
    V0 = data['V0']
    UniqueNum = data['UniqueNum']
    NeuronGroupsSaveArray = data['NeuronGroupsSaveArray']
    SimPar = data['SimPar']
    NumViewGroups = NeuronGroupsSaveArray.shape[0]
    NumNeuronsPerGroup = NeuronGroupsSaveArray.shape[1]
    UniquePrint = UniqueNum
    # Simulation control: start/stop time, base step, spike-writing flag.
    Tstart = int(CtrPar[0][0])
    Tstop = int(CtrPar[0][1])
    dt0 = CtrPar[0][2]
    WriteST = CtrPar[0][4]
    t_display = 0
    stop_flag = 0
    NumViewGroups = NeuronGroupsSaveArray.shape[0]
    NumNeuronsPerGroup = NeuronGroupsSaveArray.shape[1]
    i = NPList.shape[1]
    j = NPList.shape[0]
    N = i * j  # total number of simulated neurons
    k = NeuPar.shape[1]
    NPtr0 = []
    NumSpike = []
    gsyn1 = []
    gsyn2 = []
    Isyn = []
    flag_osc = []
    # --- Build the N simulated neurons from NeuPar / V0 ---
    for i in range(N):
        NPtr0.append(Neuron())
        NPtr0[i].Cm = NeuPar[0][i]
        NPtr0[i].gL = NeuPar[1][i]
        NPtr0[i].EL = NeuPar[2][i]
        NPtr0[i].sf = NeuPar[3][i]
        NPtr0[i].Vup = NeuPar[4][i]
        NPtr0[i].tcw = NeuPar[5][i]
        NPtr0[i].a = NeuPar[6][i]
        NPtr0[i].b = NeuPar[7][i]
        NPtr0[i].Vr = NeuPar[8][i]
        NPtr0[i].Vth = NeuPar[9][i]
        NPtr0[i].I_ref = NeuPar[10][i]
        NPtr0[i].v_dep = NeuPar[11][i]
        NPtr0[i].Iinj = 0
        NPtr0[i].v[0] = V0[0][i]
        NPtr0[i].v[1] = V0[1][i]
        NPtr0[i].NumSynType = 0
        NPtr0[i].NumPreSyn = 0
        for j in range(MaxNumSTperN):
            NPtr0[i].STList.append(None)
        NumSpike.append(0)
        gsyn1.append(0)
        gsyn2.append(0)
        Isyn.append(0)
        flag_osc.append(0)
    # --- Build the M input (spike-train) neurons ---
    M = InpSTtrains.shape[0]
    InpNPtr0 = []
    for i in range(M):
        InpNPtr0.append(InpNeuron())
        InpNPtr0[i].SP_ind = 0
        # Copy the input train up to its -1 terminator, then pad with -1.
        eom_ind = SizeHistInput
        j = 0
        while (j < eom_ind):
            if (eom_ind == SizeHistInput) and (
                    InpSTtrains[i][j + 1] == -1):
                eom_ind = j + 1
            InpNPtr0[i].SPtrain[j] = InpSTtrains[i][j]
            j = j + 1
        for j in range(eom_ind, SizeHistInput):
            InpNPtr0[i].SPtrain[j] = -1
        InpNPtr0[i].NumSynType = 0
        InpNPtr0[i].NumPreSyn = 0
    # Spike counters for all N + M neurons (re-created to cover inputs too).
    NumSpike = []
    for i in range(N + M):
        NumSpike.append(0)
    # --- Synapse types from STypPar ---
    NumSynType = STypPar.shape[1]
    SynTPtr0 = []
    for i in range(NumSynType):
        SynTPtr0.append(SynType())
        SynTPtr0[i].No = i
        SynTPtr0[i].gmax = STypPar[0][i]
        SynTPtr0[i].tc_on = STypPar[1][i]
        SynTPtr0[i].tc_off = STypPar[2][i]
        SynTPtr0[i].Erev = STypPar[3][i]
        SynTPtr0[i].Mg_gate = STypPar[4][i]
        SynTPtr0[i].Mg_fac = STypPar[5][i]
        SynTPtr0[i].Mg_slope = STypPar[6][i]
        SynTPtr0[i].Mg_half = STypPar[7][i]
        # Normalization of the double-exponential conductance kernel.
        SynTPtr0[i].Gsyn = SynTPtr0[i].gmax * SynTPtr0[i].tc_on * \
            SynTPtr0[i].tc_off / (SynTPtr0[i].tc_off - SynTPtr0[i].tc_on)
    numST = SynPar.shape[1]
    SPList = SPMtx
    MaxNumSyn = SPList.shape[2]
    # --- Connection matrix: one SynList per (post i, pre j) pair ---
    ConMtx0 = []
    com_c = []
    for i in range(N):
        for j in range(N + M):
            com_c.append(SynList())
        ConMtx0.append(com_c)
        com_c = []
    for i in range(N):
        for j in range(N + M):
            # Count the synapses encoded in SPList (positive entries).
            ConMtx0[i][j].NumSyn = 0
            while int(SPList[i][j][ConMtx0[i][j].NumSyn]) > 0:
                ConMtx0[i][j].NumSyn = ConMtx0[i][j].NumSyn + 1
                if (ConMtx0[i][j].NumSyn >= MaxNumSyn):
                    break
            if (ConMtx0[i][j].NumSyn > 0):
                for a in range(ConMtx0[i][j].NumSyn):
                    ConMtx0[i][j].Syn.append(Synapse())
            else:
                ConMtx0[i][j].Syn = []
            k = 0
            for k in range(ConMtx0[i][j].NumSyn):
                nst = SPList[i][j][k] - 1  # SynPar column of this synapse
                if (j < N):
                    # Presynaptic partner is a simulated neuron: register the
                    # synapse in its presynaptic list / depression state.
                    InList = FALSE
                    kk = 0
                    while (kk < NPtr0[j].NumPreSyn):
                        if (nst == NPtr0[j].PreSynList[kk]):
                            InList = TRUE
                            break
                        kk = kk + 1
                    ConMtx0[i][j].Syn[k].PreSynIdx = kk
                    if (InList == FALSE):
                        NPtr0[j].NumPreSyn = NPtr0[j].NumPreSyn + 1
                        # NOTE(review): reassigning PreSynList discards the
                        # previously stored entries — verify against the C
                        # original.
                        NPtr0[j].PreSynList = [0] * NPtr0[j].NumPreSyn
                        NPtr0[j].PreSynList[kk] = nst
                        for num in range(NPtr0[j].NumPreSyn):
                            NPtr0[j].SDf.append(SynDepr())
                        NPtr0[j].SDf[kk].use = SynPar[1][nst]
                        NPtr0[j].SDf[kk].tc_rec = SynPar[2][nst]
                        NPtr0[j].SDf[kk].tc_fac = SynPar[3][nst]
                        for k2 in range(SizeHistOutput):
                            NPtr0[j].SDf[kk].Adepr[k2] = 1.0
                        NPtr0[j].SDf[kk].uprev[0] = SynPar[1][nst]
                        NPtr0[j].SDf[kk].Rprev[0] = 1.0
                    # Attach type / weight / delay / failure probability.
                    STno = int(SynPar[0][nst] - 1)
                    ConMtx0[i][j].Syn[k].STPtr = SynTPtr0[STno]
                    ConMtx0[i][j].Syn[k].wgt = SynPar[4][nst]
                    ConMtx0[i][j].Syn[k].dtax = SynPar[5][nst]
                    ConMtx0[i][j].Syn[k].p_fail = SynPar[6][nst]
                    # Register the synapse type on the postsynaptic neuron.
                    InList = FALSE
                    kk = 0
                    while (
                            NPtr0[i].STList[kk] is not None and kk < NPtr0[i].NumSynType):
                        if (NPtr0[i].STList[kk].No ==
                                ConMtx0[i][j].Syn[k].STPtr.No):
                            InList = TRUE
                        kk = kk + 1
                    if (InList == FALSE):
                        NPtr0[i].STList[kk] = ConMtx0[i][j].Syn[k].STPtr
                        NPtr0[i].NumSynType = NPtr0[i].NumSynType + 1
                        NPtr0[i].gfONsyn[kk] = 0.0
                        NPtr0[i].gfOFFsyn[kk] = 0.0
                else:
                    # Presynaptic partner is an input neuron (index j - N).
                    InList = FALSE
                    kk = 0
                    while (kk < InpNPtr0[j - N].NumPreSyn):
                        if (nst == InpNPtr0[j - N].PreSynList[kk]):
                            InList = TRUE
                            break
                        kk = kk + 1
                    ConMtx0[i][j].Syn[k].PreSynIdx = kk
                    if (InList == FALSE):
                        InpNPtr0[j - N].NumPreSyn = InpNPtr0[j -
                                                            N].NumPreSyn + 1
                        InpNPtr0[j - N].PreSynList = [0] * \
                            InpNPtr0[j - N].NumPreSyn
                        InpNPtr0[j - N].PreSynList[kk] = nst
                        for num in range(InpNPtr0[j - N].NumPreSyn):
                            InpNPtr0[j - N].SDf.append(SynDepr())
                        InpNPtr0[j - N].SDf[kk].use = SynPar[1][nst]
                        InpNPtr0[j - N].SDf[kk].tc_rec = SynPar[2][nst]
                        InpNPtr0[j - N].SDf[kk].tc_fac = SynPar[3][nst]
                        for k2 in range(SizeHistOutput):
                            InpNPtr0[j - N].SDf[kk].Adepr[k2] = 1.0
                        InpNPtr0[j - N].SDf[kk].uprev[0] = SynPar[1][nst]
                        InpNPtr0[j - N].SDf[kk].Rprev[0] = 1.0
                    STno = int(SynPar[0][nst] - 1)
                    ConMtx0[i][j].Syn[k].STPtr = SynTPtr0[STno]
                    ConMtx0[i][j].Syn[k].wgt = SynPar[4][nst]
                    ConMtx0[i][j].Syn[k].dtax = SynPar[5][nst]
                    ConMtx0[i][j].Syn[k].p_fail = SynPar[6][nst]
                    InList = FALSE
                    kk = 0
                    while (
                            NPtr0[i].STList[kk] is not None and kk < NPtr0[i].NumSynType):
                        if (NPtr0[i].STList[kk].No ==
                                ConMtx0[i][j].Syn[k].STPtr.No):
                            InList = TRUE
                        kk = kk + 1
                    if (InList == FALSE):
                        NPtr0[i].STList[kk] = ConMtx0[i][j].Syn[k].STPtr
                        NPtr0[i].NumSynType = NPtr0[i].NumSynType + 1
                        NPtr0[i].gfONsyn[kk] = 0.0
                        NPtr0[i].gfOFFsyn[kk] = 0.0
    # --- Background-noise synapses: the last NumSynType columns of SynPar ---
    NoiseSyn = SynList()
    NoiseSyn.NumSyn = NumSynType
    NoiseSyn.Syn = []
    for i in range(NoiseSyn.NumSyn):
        NoiseSyn.Syn.append(Synapse())
    for i in range(N):
        for j in range(NoiseSyn.NumSyn):
            STno = int(SynPar[0][numST - NoiseSyn.NumSyn + j] - 1)
            NoiseSyn.Syn[j].STPtr = SynTPtr0[STno]
            NoiseSyn.Syn[j].wgt = SynPar[4][numST - NoiseSyn.NumSyn + j]
            NoiseSyn.Syn[j].dtax = SynPar[5][numST - NoiseSyn.NumSyn + j]
            NoiseSyn.Syn[j].p_fail = SynPar[6][numST - NoiseSyn.NumSyn + j]
            NPtr0[i].gfONnoise[j] = 0.0
            NPtr0[i].gfOFFnoise[j] = 0.0
    NoiseStep = 1 / (NoiseDistr.shape[1] - 1)
    SynExpOn = [0] * NumSynType
    SynExpOff = [0] * NumSynType
    NumEvt = EvtTimes.shape[1]
    NumView = ViewList.shape[0] * ViewList.shape[1]
    # Output files for voltages (fpOut) and conductances/currents (fpOut2).
    fpOut = open("IDN_%i.dat" % UniquePrint, "w")
    fpOut2 = open("IDN2_%i.dat" % UniquePrint, "w")
    if (CtrPar[0][3] > NumVar):
        CtrPar[0][3] = NumVar
    NumOutp = 2
    if (NumOutp > 0):
        NumSynInp = [0] * N
        N_osc = [0] * N
    TnextSyn = [0] * N  # time of the next scheduled synaptic event per neuron
    t0 = Tstart
    time_num = 0
    # --- Main simulation loop over global time intervals [t0, t1] ---
    while (t0 < Tstop):
        if (t0 >= t_display):
            print("%f percent" % (t0 * 100 / Tstop))
            t_display = t0 + 100
        t1 = t0 + dt0
        EvtNo = -999
        if (t1 > Tstop):
            t1 = Tstop
        # Shorten the interval to the next stimulus on/off event, if any.
        for i in range(NumEvt):
            if (EvtTimes[i * 2] > t0) and (EvtTimes[i * 2] <= t1):
                t1 = EvtTimes[i * 2]
                NextEvtT = t1
                EvtNo = i * 2
            else:
                EvtOffT = EvtTimes[i * 2] + EvtTimes[i * 2 + 1]
                if (EvtOffT > t0 and EvtOffT <= t1):
                    t1 = EvtOffT
                    NextEvtT = t1
                    EvtNo = i * 2 + 1
        # Shorten further to the next scheduled input spike.
        t11 = t1
        for i in range(M):
            if (InpNPtr0[i].SPtrain[InpNPtr0[i].SP_ind] > t0) and (
                    InpNPtr0[i].SPtrain[InpNPtr0[i].SP_ind] <= t11):
                t11 = InpNPtr0[i].SPtrain[InpNPtr0[i].SP_ind]
                print_flag = 1
            else:
                print_flag = 0
        t1 = t11
        # Deliver input spikes that occur exactly at t1.
        for i in range(M):
            if (InpNPtr0[i].SPtrain[InpNPtr0[i].SP_ind] == t1):
                if (InpNPtr0[i].SP_ind > 0):
                    ISI_inp = t1 - \
                        InpNPtr0[i].SPtrain[InpNPtr0[i].SP_ind - 1]
                else:
                    ISI_inp = 10.0e8  # effectively "no previous spike"
                InpNPtr0[i].SpikeTimes[InpNPtr0[i].SP_ind] = InpNPtr0[i].SPtrain[InpNPtr0[i].SP_ind]
                InpNPtr0[i].SP_ind = InpNPtr0[i].SP_ind + 1
                j = NumSpike[i + N] % SizeHistOutput
                # Update short-term depression state of all its synapses.
                for kk in range(InpNPtr0[i].NumPreSyn):
                    if (InpNPtr0[i].SDf[kk].use > 0.0):
                        InpNPtr0[i].SDf[kk].Adepr[j] = short_time(
                            SizeHistOutput).syndepr(InpNPtr0[i].SDf[kk], ISI_inp, j)
                if (WriteST > 0):
                    fpISI = open(
                        "ISIu%d_%i.dat" %
                        (i + N, UniquePrint), "a")
                    fpISI.write("%f\n" % t1)
                    fpISI.close()
                NumSpike[i + N] = NumSpike[i + N] + 1
        # --- Advance every simulated neuron over [t0, t1] ---
        for i in range(N):
            t0_i = t0
            if (t0_i < t1):
                t1_i = t1
                # Stop early at the next pending synaptic event, if sooner.
                if (TnextSyn[i] > t0_i and TnextSyn[i] < t1_i):
                    t1_i = TnextSyn[i]
                vp = copy.copy(NPtr0[i].v[0])
                wp = copy.copy(NPtr0[i].v[1])
                dt = t1_i - t0_i
                # Suppress the voltage derivative within 5 ms of a spike.
                if (NumSpike[i] > 0):
                    if ((t0_i -
                         NPtr0[i].SpikeTimes[(NumSpike[i] -
                                              1) %
                                             SizeHistOutput]) < 5):
                        flag_dv = 0
                    else:
                        flag_dv = 1
                else:
                    flag_dv = 1
                NPtr0[i] = copy.copy(NPtr0[i])
                try:
                    NPtr0[i], gsyn_AN, gsyn_G, I_tot = short_time(
                        SizeHistOutput).update(
                        NPtr0[i], dt, NoiseSyn, flag_dv)
                except OverflowError:
                    pass  # best-effort: keep previous state on overflow
                if (stop_flag > 0):
                    print("%f %d %f %f\n" % (t0_i, i, vp, wp))
                # Debug check: conductances must never go negative.
                for j in range(NPtr0[i].NumSynType):
                    if (NPtr0[i].gfONsyn[j] <
                            0 or NPtr0[i].gfOFFsyn[j] < 0):
                        print(
                            "%d %d %f %f %f %f\n" %
                            (i, j, t0_i, t1_i, NPtr0[i].gfONsyn[j], NPtr0[i].gfOFFsyn[j]))
                if (t1_i == t1):
                    gsyn1[i] = gsyn_AN
                    gsyn2[i] = gsyn_G
                    Isyn[i] = I_tot
                # Track whether the total current hovers near I_ref
                # (oscillation detection over 200 consecutive sub-steps).
                if (I_tot < NPtr0[i].I_ref *
                        1.01 and I_tot > NPtr0[i].I_ref * 0.99):
                    flag_osc[i] = flag_osc[i] + 1
                else:
                    flag_osc[i] = 0
                if (flag_osc[i] >= 200 and NumOutp > 0):
                    N_osc[i] = N_osc[i] + 1
                # Spike detection: upward crossing of Vup within this step.
                if ((NPtr0[i].v[0] >= NPtr0[i].Vup)
                        and (vp < NPtr0[i].Vup)):
                    # Linear interpolation of the exact crossing time.
                    t1_i = t0_i + dt * \
                        (NPtr0[i].Vup - vp) / (NPtr0[i].v[0] - vp)
                    if (NumSpike[i] > 0):
                        ISI = t1_i - \
                            NPtr0[i].SpikeTimes[(NumSpike[i] - 1) % SizeHistOutput]
                    else:
                        ISI = 10.0e8
                    if (ISI > 5):
                        # Accept the spike: reset v, increment adaptation,
                        # record the spike time and update depression state.
                        w_Vup = wp + \
                            ((NPtr0[i].v[1] - wp) / dt) * (t1_i - t0_i)
                        NPtr0[i].v[0] = NPtr0[i].Vr
                        NPtr0[i].v[1] = w_Vup + NPtr0[i].b
                        j = NumSpike[i] % SizeHistOutput
                        NPtr0[i].SpikeTimes[j] = t1_i
                        for kk in range(NPtr0[i].NumPreSyn):
                            if (NPtr0[i].SDf[kk].use > 0.0):
                                NPtr0[i].SDf[kk].Adepr[j] = short_time(
                                    SizeHistOutput).syndepr(NPtr0[i].SDf[kk], ISI, j)
                        if (WriteST > 0):
                            fpISI = open(
                                "ISIu%d_%i.dat" %
                                (i, UniquePrint), "a")
                            fpISI.write("%f\n" % t1_i)
                            fpISI.close()
                        NumSpike[i] = NumSpike[i] + 1
                        dt = t1_i - t0_i
                    else:
                        # Refractory: reject the spike and revert the state.
                        NPtr0[i].v[0] = vp
                        NPtr0[i].v[1] = wp
                        # reset t1_i
                        t1_i = dt + t0_i
                    if (t1_i == t1):
                        gsyn_AN, I_tot, gsyn_G = short_time(
                            SizeHistOutput).set_gsyn(NPtr0[i], dt, vp, NoiseSyn)
                        gsyn1[i] = gsyn_AN
                        gsyn2[i] = gsyn_G
                        Isyn[i] = I_tot
                # Exponential decay of noise conductances over dt.
                for j in range(NoiseSyn.NumSyn):
                    SynExpOn[j] = math.exp(-dt /
                                           (NoiseSyn.Syn[j].STPtr).tc_on)
                    SynExpOff[j] = math.exp(-dt /
                                            (NoiseSyn.Syn[j].STPtr).tc_off)
                    rand_num = NoiseDistr[0][rand.randint(
                        0, 1 / NoiseStep)]
                    NPtr0[i].gfONnoise[j] = 0.0
                    NPtr0[i].gfOFFnoise[j] = 0.0
                # Exponential decay of the synaptic conductances over dt.
                for j in range(NPtr0[i].NumSynType):
                    NPtr0[i].gfONsyn[j] *= math.exp(-dt /
                                                    (NPtr0[i].STList[j]).tc_on)
                    NPtr0[i].gfOFFsyn[j] *= math.exp(-dt /
                                                     (NPtr0[i].STList[j]).tc_off)
                # Deliver delayed spikes from simulated presynaptic neurons
                # and schedule the earliest future arrival in TnextSyn[i].
                TnextSyn[i] = Tstop + 100.0
                for j in range(N):
                    for k in range(ConMtx0[i][j].NumSyn):
                        kk = NumSpike[j] - 1
                        while (
                                kk >= 0 and (
                                    NumSpike[j] -
                                    kk) <= SizeHistOutput):
                            if (t0_i >= (
                                    NPtr0[j].SpikeTimes[kk % SizeHistOutput] + ConMtx0[i][j].Syn[k].dtax)):
                                break
                            else:
                                if ((t1_i >= NPtr0[j].SpikeTimes[kk % SizeHistOutput] + ConMtx0[i][j].Syn[
                                        k].dtax) and (
                                        rand.uniform(0, 1) > ConMtx0[i][j].Syn[k].p_fail)):
                                    for k2 in range(NPtr0[i].NumSynType):
                                        if (NPtr0[i].STList[k2].No ==
                                                ConMtx0[i][j].Syn[k].STPtr.No):
                                            # Depression-scaled synaptic gain.
                                            Aall = NPtr0[j].SDf[ConMtx0[i][j].Syn[k].PreSynIdx].Adepr[
                                                kk % SizeHistOutput] * ConMtx0[i][j].Syn[k].wgt * \
                                                ConMtx0[i][j].Syn[k].STPtr.Gsyn
                                            NPtr0[i].gfONsyn[k2] += Aall
                                            NPtr0[i].gfOFFsyn[k2] += Aall
                                    if (NumOutp > 0):
                                        NumSynInp[i] = NumSynInp[i] + 1.0
                                else:
                                    if (NPtr0[j].SpikeTimes[kk % SizeHistOutput] +
                                            ConMtx0[i][j].Syn[k].dtax < TnextSyn[i]):
                                        TnextSyn[i] = NPtr0[j].SpikeTimes[kk %
                                                                          SizeHistOutput] + ConMtx0[i][j].Syn[k].dtax
                            kk = kk - 1
                # Same delivery/scheduling for input-neuron synapses.
                for j in range(N, N + M):
                    for k in range(ConMtx0[i][j].NumSyn):
                        kk = NumSpike[j] - 1
                        while (kk >= 0):
                            if (t0_i >= (
                                    InpNPtr0[j - N].SpikeTimes[kk] + ConMtx0[i][j].Syn[k].dtax)):
                                break
                            else:
                                if ((t1_i >= InpNPtr0[j - N].SpikeTimes[kk] + ConMtx0[i][j].Syn[k].dtax) and (
                                        rand.uniform(0, 1) > ConMtx0[i][j].Syn[k].p_fail)):
                                    for k2 in range(NPtr0[i].NumSynType):
                                        if (NPtr0[i].STList[k2] ==
                                                ConMtx0[i][j].Syn[k].STPtr):
                                            Aall = InpNPtr0[j - N].SDf[ConMtx0[i][j].Syn[k].PreSynIdx].Adepr[kk %
                                                                                                             SizeHistOutput] * ConMtx0[i][j].Syn[k].wgt * (ConMtx0[i][j].Syn[k].STPtr).Gsyn
                                            NPtr0[i].gfONsyn[k2] += Aall
                                            NPtr0[i].gfOFFsyn[k2] += Aall
                                    if (NumOutp > 0):
                                        NumSynInp[i] = NumSynInp[i] + 1.0
                                else:
                                    if (InpNPtr0[j - N].SpikeTimes[kk] +
                                            ConMtx0[i][j].Syn[k].dtax < TnextSyn[i]):
                                        TnextSyn[i] = InpNPtr0[j - N].SpikeTimes[kk] + \
                                            ConMtx0[i][j].Syn[k].dtax
                            kk = kk - 1
                t0_i = t1_i
        # Write the observed neurons' state variables for this interval.
        for i in range(NumView):
            fpOut.write("%lf %d" % (t1, ViewList[i][0]))
            for k in range(int(CtrPar[0][3])):
                fpOut.write(" %lf" % NPtr0[ViewList[i][0] - 1].v[k])
            fpOut.write("\n")
        for i in range(NumView):
            fpOut2.write(" %f %f %f" %
                         (gsyn1[ViewList[i][0] -
                                1], gsyn2[ViewList[i][0] -
                                          1], Isyn[ViewList[i][0] -
                                                   1]))
        fpOut2.write("\n")
        # Apply / remove the injected stimulation current at event times.
        if (EvtNo >= 0 and t1 >= NextEvtT):
            if ((EvtNo % 2) == 0):
                for i in range(N):
                    NPtr0[i].Iinj = EvtMtx[int(i + (EvtNo / 2) * N)]
            else:
                for i in range(N):
                    NPtr0[i].Iinj = 0.0
        t0 = t1
    if (NumView > 0):
        fpOut.close()
        fpOut2.close()
    # --- Collect per-neuron spike trains from the ISIu files ---
    STMtx = []
    if (CtrPar[0][4] > 0):
        for i in range(N + M):
            if (os.path.exists('ISIu%d_0.dat' % i)):
                ST = pd.read_table('ISIu%d_0.dat' % i, header=None)
                content = []
                for j in range(ST.shape[0]):
                    content.append(ST.iloc[j][0])
                STMtx.append(content)
                os.remove('ISIu%d_0.dat' % i)
            else:
                STMtx.append(-1)  # placeholder: this neuron never spiked
    # --- Collect the voltage traces and bundle everything into a .mat ---
    T = []
    V = []
    if (ViewList is not None):
        X = pd.read_table('IDN_%d.dat' % UniqueNum, header=None)
        for i in range(X.shape[0]):
            T.append(X.iloc[i][0])
            content = []
            for j in range(X.shape[1] - 1):
                content.append(X.iloc[i][j + 1])
            V.append(content)
        os.remove('IDN_%d.dat' % UniqueNum)
        os.remove('IDN2_0.dat')
    scio.savemat('PFC_%dN_500ms.mat' %
                 (N + M), {'N': N, 'T': T, 'V': V, 'STMtx': STMtx})
class SynType:
    def __init__(self):
        """
        Parameters of short-term synaptic plasticity model
        """
        self.No = 0        # synapse-type index
        self.gmax = 0      # peak conductance
        self.tc_on = 0     # "on" time constant of the conductance kernel
        self.tc_off = 0    # "off" time constant of the conductance kernel
        self.Erev = 0      # reversal potential
        self.Mg_gate = 0   # magnesium-gate parameters — presumably NMDA
        self.Mg_fac = 0    # voltage dependence; confirm against the model
        self.Mg_slope = 0
        self.Mg_half = 0
        self.Gsyn = 0      # normalization, set from gmax/tc_on/tc_off in mex_function
class Neuron:
    """
    State and parameters of a single simulated neuron; the numeric
    parameters are filled in from NeuPar by mex_function.
    """
    gfONsyn = None
    gfOFFsyn = None
    gfONnoise = None
    gfOFFnoise = None
    SpikeTimes = None
    v = None
    dv = None

    def __init__(self):
        # Buffer sizes come from the global parameter holder.
        cfg = six_layer_pfc()
        max_types = cfg.MaxNumSTperN
        hist_len = cfg.SizeHistOutput
        # Membrane / adaptation parameters (loaded later from NeuPar).
        self.Cm = 0
        self.gL = 0
        self.EL = 0
        self.sf = 0
        self.Vup = 0
        self.tcw = 0
        self.a = 0
        self.b = 0
        self.Vr = 0
        self.Vth = 0
        self.I_ref = 0
        self.v_dep = 0
        self.NumSynType = 0
        self.Iinj = 0
        # Two state variables and their derivatives.
        self.v = [0] * 2
        self.dv = [0] * 2
        self.STList = []
        # Per-synapse-type conductance and noise buffers.
        self.gfONsyn = [0] * max_types
        self.gfOFFsyn = [0] * max_types
        self.gfONnoise = [0] * max_types
        self.gfOFFnoise = [0] * max_types
        # Ring buffer of the most recent spike times.
        self.SpikeTimes = [0] * hist_len
        self.NumPreSyn = 0
        self.PreSynList = []
        self.SDf = []
class InpNeuron:
    """
    Input neuron: carries a precomputed spike train instead of integrated
    dynamics.
    """
    SPtrain = None
    SpikeTimes = None

    def __init__(self):
        buf_len = six_layer_pfc().SizeHistInput
        self.SPtrain = [0] * buf_len     # scheduled input spike times (padded with -1)
        self.SpikeTimes = [0] * buf_len  # spike times actually delivered
        self.SP_ind = 0                  # index of the next scheduled spike
        self.NumSynType = 0
        self.NumPreSyn = 0
        self.PreSynList = []
        self.SDf = []
class Synapse:
    """
    One synaptic connection: its type record, delay, weight,
    transmission-failure probability, and presynaptic index.
    """

    def __init__(self):
        self.STPtr = SynType()  # synapse-type record used by this connection
        self.dtax = 0           # delay term (name suggests axonal delay — confirm)
        self.wgt = 0            # synaptic weight
        self.p_fail = 0         # transmission-failure probability
        self.PreSynIdx = 0      # index of the presynaptic neuron
class SynDepr:
    """
    Short-term synaptic depression/facilitation state.

    NOTE(review): the use/tc_rec/tc_fac names suggest a Tsodyks-Markram
    style model — confirm against the simulator's update equations.
    """
    Adepr = None
    uprev = None
    Rprev = None

    def __init__(self):
        hist_len = six_layer_pfc().SizeHistOutput
        self.use = 0      # utilization parameter
        self.tc_rec = 0   # recovery time constant
        self.tc_fac = 0   # facilitation time constant
        # Per-spike history buffers, one slot per recorded output spike.
        self.Adepr = [0] * hist_len
        self.uprev = [0] * hist_len
        self.Rprev = [0] * hist_len
class SynList:
    """A counted collection of Synapse records."""

    def __init__(self):
        self.NumSyn = 0  # number of synapses currently stored in Syn
        self.Syn = []    # the synapse records themselves
if __name__ == '__main__':
    # After downloading the data file from the network disk, point the
    # input path at the downloaded location.
    simulator = six_layer_pfc()
    simulator.mex_function('data100.mat')
    simulator.picture('PFC_99N_500ms.mat')
================================================
FILE: examples/Multiscale_Brain_Structure_Simulation/MacaqueBrain/README.md
================================================
## Macaque Brain Simulation
## Description
Macaque Brain Simulation is a large scale brain modeling framework depending on braincog framework.
## Requirements:
* numpy >= 1.21.2
* scipy >= 1.8.0
* h5py >= 3.6.0
* torch >= 1.10
* torchvision >= 0.12.0
* torchaudio >= 0.11.0
* timm >= 0.5.4
* matplotlib >= 3.5.1
* einops >= 0.4.1
* thop >= 0.0.31
* pyyaml >= 6.0
* loris >= 0.5.3
* pandas >= 1.4.2
* tonic (special)
## Input:
The binary connectivity matrix can be obtained from the following link:
https://drive.google.com/file/d/1LsNupIx3Nk-Cn_MowF6O-SCY27wdRYas/view?usp=sharing
The brain regions' names can be obtained from the following link:
https://drive.google.com/file/d/1iNI0HR3teUj4yshK8RlSJq6gIWbRdBI1/view?usp=sharing
## Example:
```shell
cd examples/Multiscale_Brain_Structure_Simulation/MacaqueBrain/
python macaque_brain.py
```
## Parameters:
The parameters are similar to mouse brain simulation
## Citations:
If you find this package helpful, please consider citing the following papers:
@article{Liu2016,
author={Liu, Xin and Zeng, Yi and Zhang, Tielin and Xu, Bo},
title={Parallel Brain Simulator: A Multi-scale and Parallel Brain-Inspired Neural Network Modeling and Simulation Platform},
journal={Cognitive Computation},
year={2016},
month={Oct},
day={01},
volume={8},
number={5},
pages={967--981},
issn={1866-9964},
doi={10.1007/s12559-016-9411-y},
url={https://doi.org/10.1007/s12559-016-9411-y}
}
@misc{https://doi.org/10.48550/arxiv.2207.08533,
doi = {10.48550/ARXIV.2207.08533},
url = {https://arxiv.org/abs/2207.08533},
author = {Zeng, Yi and Zhao, Dongcheng and Zhao, Feifei and Shen, Guobin and Dong, Yiting and Lu, Enmeng and Zhang, Qian and Sun, Yinqian and Liang, Qian and Zhao, Yuxuan and Zhao, Zhuoya and Fang, Hongjian and Wang, Yuwei and Li, Yang and Liu, Xin and Du, Chengcheng and Kong, Qingqun and Ruan, Zizhe and Bi, Weida},
title = {BrainCog: A Spiking Neural Network based Brain-inspired Cognitive Intelligence Engine for Brain-inspired AI and Brain Simulation},
publisher = {arXiv},
year = {2022},
}
================================================
FILE: examples/Multiscale_Brain_Structure_Simulation/MacaqueBrain/macaque_brain.py
================================================
import time
import numpy as np
import scipy.io as scio
import torch
from torch import nn
from braincog.base.node.node import *
from braincog.base.brainarea.BrainArea import *
import pandas as pd
import matplotlib.pyplot as plt
# Target device for all tensors; change to 'cpu' when no GPU is available.
device = 'cuda:0'
class Syn(nn.Module):
    """Sparse synaptic-current model for the whole network.

    Holds a sparse (neuron_num x neuron_num) weight matrix and a
    two-variable (rise/decay) synaptic trace driven by the neuron
    population's spike vector.
    """

    def __init__(self, syn, weight, neuron_num, tao_d, tao_r, dt, device):
        super().__init__()
        # NOTE(review): ``syn`` is (num_syn, 2), so syn[1] / syn[0] index
        # whole *rows*, not the pre/post columns (that would be syn[:, 1]
        # and syn[:, 0]). These attributes are unused below, so behavior
        # is unaffected — confirm intent before using them.
        self.pre = syn[1]
        self.post = syn[0]
        self.syn_num = len(syn)
        # Sparse COO weight matrix: entry (post, pre) -> synaptic weight.
        self.w = torch.sparse_coo_tensor(syn.t(), weight,
        size=(neuron_num, neuron_num))
        self.tao_d = tao_d  # decay time constant
        self.tao_r = tao_r  # rise time constant
        self.dt = dt
        # Pre-computed Euler-step factors for the two filter stages.
        self.lamda_d = self.dt / self.tao_d
        self.lamda_r = self.dt / self.tao_r
        # Synaptic state: s is the spike-driven stage, r the filtered output.
        self.s = torch.zeros(neuron_num, device=device)
        self.r = torch.zeros(neuron_num, device=device)
        self.dt = dt  # NOTE(review): duplicate assignment, harmless

    def forward(self, neuron):
        # Low-pass-filtered Gaussian background current, then scaled and
        # shifted into the effective drive Ieff.
        neuron.Iback = neuron.Iback + neuron.dt_over_tau * (
        torch.randn(neuron.neuron_num, device=device, requires_grad=False) - neuron.Iback)
        neuron.Ieff = neuron.Iback / neuron.sqrt_coeff * neuron.sig + neuron.mu
        # Two-stage synaptic filter driven by the current spike vector.
        self.s = self.s + self.lamda_r * (-self.s + 1 / self.tao_d * neuron.spike)
        self.r = self.r - self.lamda_d * self.r + self.dt * self.s
        # Total input current: sparse recurrent term plus background drive.
        self.I = torch.sparse.mm(self.w, self.r.unsqueeze(-1)).squeeze() + neuron.Ieff
        return self.I
class brain(nn.Module):
    """Whole-brain network: one neuron population plus one synaptic model.

    :param syn: (num_syn, 2) tensor of (post, pre) index pairs
    :param weight: per-synapse weights aligned with ``syn``
    :param neuron_model: 'HH' or 'aEIF'
    :param p_neuron: per-neuron parameter list handed to the node model
    :param dt: integration step
    :param device: torch device string
    :raises ValueError: if ``neuron_model`` is not 'HH' or 'aEIF'
    """

    def __init__(self, syn, weight, neuron_model, p_neuron, dt, device):
        super().__init__()
        if neuron_model == 'HH':
            self.neurons = HHNode(p_neuron, dt, device)
        elif neuron_model == 'aEIF':
            self.neurons = aEIF(p_neuron, dt, device)
        else:
            # Fail fast: the original silently left self.neurons undefined
            # and crashed later with a confusing AttributeError.
            raise ValueError(
                "unknown neuron_model: %r (expected 'HH' or 'aEIF')" % (neuron_model,))
        self.neuron_num = len(p_neuron[0])
        # tao_d=3, tao_r=6: synaptic decay/rise constants used by Syn.
        self.syns = Syn(syn, weight, self.neuron_num, 3, 6, dt, device)

    def forward(self, inputs):
        # ``inputs`` is ignored: drive comes from background noise plus
        # the recurrent synaptic current.
        I = self.syns(self.neurons)
        self.neurons(I)
def brain_region(neuron_num):
    """Return a (num_regions, 2) tensor of [start, end) global neuron-index
    ranges, given the neuron count of each region."""
    bounds = []
    lo = 0
    for count in neuron_num:
        hi = lo + count.item()
        bounds.append([lo, hi])
        lo = hi
    return torch.tensor(bounds)
def neuron_type(neuron_num, ratio, regions):
    """Convert per-region type ratios into absolute neuron-index boundaries.

    Each row gives the cumulative global end index of every neuron type
    within that region (region start offset included).
    """
    counts = neuron_num.reshape(-1, 1)
    offsets = regions[:, 0].reshape(-1, 1)
    return torch.floor(ratio * counts).int() + offsets
def syn_within_region(syn_num, region):
    """Draw random intra-region synapses.

    For every region [lo, hi), generates (hi - lo) * syn_num random
    (post, pre) index pairs drawn uniformly inside that region.
    """
    chunks = []
    for bounds in region:
        lo, hi = bounds[0], bounds[1]
        chunks.append(torch.randint(lo, hi,
                                    size=((hi - lo) * syn_num, 2), device=device))
    return torch.concatenate(chunks)
def syn_cross_region(weight_matrix, region):
    """Draw random inter-region synapses.

    For every ordered region pair (i, j) whose connection count
    ``weight_matrix[i][j]`` is at least 10, draw that many random
    (post, pre) pairs: presynaptic neurons in region j, postsynaptic
    neurons in region i.

    Returns a (num_syn, 2) int64 tensor.  If no pair reaches the
    threshold an empty (0, 2) tensor is returned — the original code
    crashed with UnboundLocalError in that case.
    """
    pairs = []
    n = len(weight_matrix)
    for i in range(n):
        for j in range(n):
            count = weight_matrix[i][j]
            if count < 10:
                # Too few connections: skip weak region pairs entirely.
                continue
            pre = torch.randint(region[j][0], region[j][1],
                                size=(count, 1), device=device)
            post = torch.randint(region[i][0], region[i][1],
                                 size=(count, 1), device=device)
            pairs.append(torch.concatenate((post, pre), dim=1))
    if not pairs:
        return torch.empty((0, 2), dtype=torch.int64, device=device)
    return torch.concatenate(pairs)
# ---- Network construction ------------------------------------------------
# size: neurons per region; NR regions read from the macaque connectivity
# matrix (scaled by 100 and truncated to int connection counts).
size = 10000
neuron_model = 'aEIF'
weight_matrix = torch.tensor(scio.loadmat('./maque.mat')['connect']) * 100
weight_matrix = weight_matrix.int()
syn_num = 10
NR = len(weight_matrix)
data = size * np.ones(NR)
neuron_num = np.array(data).astype(np.int32)
neuron_num = torch.from_numpy(neuron_num)
regions = brain_region(neuron_num)
# 70% / 20% / 10% split into three neuron types per region.
ratio = torch.tensor([[0.7, 0.9, 1.0] * NR]).reshape(NR, 3)
neuron_types = neuron_type(neuron_num, ratio, regions)
syn_1 = syn_within_region(syn_num, regions)
syn_2 = syn_cross_region(weight_matrix, regions)
syn = torch.concatenate((syn_1, syn_2))
print(syn.shape)
# Default weight -1 (inhibitory); overwritten with +1.5 below for synapses
# whose column-0 index falls in each region's first neuron-type range.
weight = -torch.ones(len(syn), device=device, requires_grad=False)
# ---- Per-neuron parameter tensors ---------------------------------------
if neuron_model == 'aEIF':
    threshold = torch.zeros(regions[-1][1], device=device, requires_grad=False)
    v_reset = torch.zeros(regions[-1][1], device=device, requires_grad=False)
    c_m = torch.zeros(regions[-1][1], device=device, requires_grad=False)
    tao_w = torch.zeros(regions[-1][1], device=device, requires_grad=False)
    alpha_ad = torch.zeros(regions[-1][1], device=device, requires_grad=False)
    beta_ad = torch.zeros(regions[-1][1], device=device, requires_grad=False)
elif neuron_model == 'HH':
    threshold = torch.zeros(regions[-1][1], device=device, requires_grad=False)
for i in range(len(neuron_types)):
    # NOTE(review): column 0 of ``syn`` holds the *post*synaptic index per
    # syn_cross_region's (post, pre) layout — the name ``pre`` may mislead.
    pre = syn[:, 0]
    mask = (pre >= regions[i][0]) & (pre < neuron_types[i][0])
    indices = torch.where(mask)
    weight[indices] = 1.5
    if neuron_model == 'aEIF':
        # NOTE(review): the i < 177 split is inherited from the 213-region
        # mouse script — confirm it matches the macaque region count.
        if i < 177:
            threshold[regions[i][0]:neuron_types[i][0]] = -50
            threshold[neuron_types[i][0]:neuron_types[i][1]] = -44
            threshold[neuron_types[i][1]:neuron_types[i][2]] = -45
            v_reset[regions[i][0]:neuron_types[i][0]] = -110
            v_reset[neuron_types[i][0]:neuron_types[i][1]] = -110
            v_reset[neuron_types[i][1]:neuron_types[i][2]] = -66
            c_m[regions[i][0]:neuron_types[i][0]] = 10
            c_m[neuron_types[i][0]:neuron_types[i][1]] = 10
            c_m[neuron_types[i][1]:neuron_types[i][2]] = 8.5
            tao_w[regions[i][0]:neuron_types[i][0]] = 1
            tao_w[neuron_types[i][0]:neuron_types[i][1]] = 2
            tao_w[neuron_types[i][1]:neuron_types[i][2]] = 2
            alpha_ad[regions[i][0]:neuron_types[i][0]] = 0
            alpha_ad[neuron_types[i][0]:neuron_types[i][1]] = -0.2
            alpha_ad[neuron_types[i][1]:neuron_types[i][2]] = -0.2
            beta_ad[regions[i][0]:neuron_types[i][0]] = 0
            beta_ad[neuron_types[i][0]:neuron_types[i][1]] = 0.45
            beta_ad[neuron_types[i][1]:neuron_types[i][2]] = 0.45
        else:
            threshold[regions[i][0]:neuron_types[i][0]] = -50
            threshold[neuron_types[i][0]:neuron_types[i][1]] = -50
            threshold[neuron_types[i][1]:neuron_types[i][2]] = -45
            v_reset[regions[i][0]:neuron_types[i][0]] = -60
            v_reset[neuron_types[i][0]:neuron_types[i][1]] = -60
            v_reset[neuron_types[i][1]:neuron_types[i][2]] = -65
            c_m[regions[i][0]:neuron_types[i][0]] = 20
            c_m[neuron_types[i][0]:neuron_types[i][1]] = 2
            c_m[neuron_types[i][1]:neuron_types[i][2]] = 4
            tao_w[regions[i][0]:neuron_types[i][0]] = 1
            tao_w[neuron_types[i][0]:neuron_types[i][1]] = 2
            tao_w[neuron_types[i][1]:neuron_types[i][2]] = 2
            alpha_ad[regions[i][0]:neuron_types[i][0]] = 0
            alpha_ad[neuron_types[i][0]:neuron_types[i][1]] = -0.2
            alpha_ad[neuron_types[i][1]:neuron_types[i][2]] = -0.2
            beta_ad[regions[i][0]:neuron_types[i][0]] = 0
            beta_ad[neuron_types[i][0]:neuron_types[i][1]] = 0.45
            beta_ad[neuron_types[i][1]:neuron_types[i][2]] = 0.45
    elif neuron_model == 'HH':
        threshold[regions[i][0]:neuron_types[i][0]] = 50
        threshold[neuron_types[i][0]:neuron_types[i][1]] = 60
        threshold[neuron_types[i][1]:neuron_types[i][2]] = 60
# ---- Simulation ----------------------------------------------------------
if neuron_model == 'aEIF':
    p_neuron = [threshold, v_reset, c_m, tao_w, alpha_ad, beta_ad]
    dt = 1
    T = 300
elif neuron_model == 'HH':
    p_neuron = [threshold, 120, 36, 0.3, 115, -12, 10.6, 1]
    dt = 0.01
    T = 10000
model = brain(syn, weight, neuron_model, p_neuron, dt, device)
Iraster = []
for t in range(T):
    model(0)
    print(torch.sum(model.neurons.spike))
    Isp = torch.nonzero(model.neurons.spike)
    print(len(Isp))
    # Accumulate (timestep, neuron_index) rows for every spike.
    if (len(Isp) != 0):
        left = t * torch.ones((len(Isp)), device=device, requires_grad=False)
        left = left.reshape(len(left), 1)
        mide = torch.concatenate((left, Isp), dim=1)
    if (len(Isp) != 0) and (len(Iraster) != 0):
        Iraster = torch.concatenate((Iraster, mide), dim=0)
    if (len(Iraster) == 0) and (len(Isp) != 0):
        Iraster = mide
# NOTE(review): torch.tensor(Iraster) re-copies an existing tensor and
# fails on transpose if no spike ever occurred (Iraster stays []).
Iraster = torch.tensor(Iraster).transpose(0, 1)
torch.save(Iraster, "./maque.pt")
================================================
FILE: examples/Multiscale_Brain_Structure_Simulation/MouseBrain/README.md
================================================
## Input:
* Program to enter a table of connection weights between 213 brain regions, which is saved in 'mouse_weight.pt'. The brain regions' name and neuron number are saved in 'mouse_brain_region.xlsx'. These two files are available in the follow link:
https://drive.google.com/drive/folders/1MWHY52gKPGneBEJxJN9DzE7thnLrhG1j?usp=sharing
## Output:
* The program generates a data file recording each neuron's firing time points; because of the large number of data points, plotting software is required to display the results.
## Settings:
* scale: The scale of the number of neurons
* neuron_model: 'HH' or 'aEIF' (must match the options handled in mouse_brain.py)
* weight_matrix: Matrix of the number of synaptic connections between brain regions
* neuron_num: The number of neurons in each brain region
* ratio: the ratio of each neuron type in each brain region
* syn_num: average number of synapses per neuron within region
================================================
FILE: examples/Multiscale_Brain_Structure_Simulation/MouseBrain/mouse_brain.py
================================================
import time
import numpy as np
import scipy.io as scio
import torch
from torch import nn
from braincog.base.node.node import *
from braincog.base.brainarea.BrainArea import *
import pandas as pd
import matplotlib.pyplot as plt
# Target device for all tensors; change to 'cpu' when no GPU is available.
device = 'cuda:0'
class Syn(nn.Module):
    """Sparse synaptic-current model for the whole network.

    Holds a sparse (neuron_num x neuron_num) weight matrix and a
    two-variable (rise/decay) synaptic trace driven by the neuron
    population's spike vector.
    """

    def __init__(self, syn, weight, neuron_num, tao_d, tao_r, dt, device):
        super().__init__()
        # NOTE(review): ``syn`` is (num_syn, 2), so syn[1] / syn[0] index
        # whole *rows*, not the pre/post columns (that would be syn[:, 1]
        # and syn[:, 0]). These attributes are unused below, so behavior
        # is unaffected — confirm intent before using them.
        self.pre = syn[1]
        self.post = syn[0]
        self.syn_num = len(syn)
        # Sparse COO weight matrix: entry (post, pre) -> synaptic weight.
        self.w = torch.sparse_coo_tensor(syn.t(), weight,
        size=(neuron_num, neuron_num))
        self.tao_d = tao_d  # decay time constant
        self.tao_r = tao_r  # rise time constant
        self.dt = dt
        # Pre-computed Euler-step factors for the two filter stages.
        self.lamda_d = self.dt / self.tao_d
        self.lamda_r = self.dt / self.tao_r
        # Synaptic state: s is the spike-driven stage, r the filtered output.
        self.s = torch.zeros(neuron_num, device=device)
        self.r = torch.zeros(neuron_num, device=device)
        self.dt = dt  # NOTE(review): duplicate assignment, harmless

    def forward(self, neuron):
        # Low-pass-filtered Gaussian background current, then scaled and
        # shifted into the effective drive Ieff.
        neuron.Iback = neuron.Iback + neuron.dt_over_tau * (
        torch.randn(neuron.neuron_num, device=device, requires_grad=False) - neuron.Iback)
        neuron.Ieff = neuron.Iback / neuron.sqrt_coeff * neuron.sig + neuron.mu
        # Two-stage synaptic filter driven by the current spike vector.
        self.s = self.s + self.lamda_r * (-self.s + 1 / self.tao_d * neuron.spike)
        self.r = self.r - self.lamda_d * self.r + self.dt * self.s
        # Total input current: sparse recurrent term plus background drive.
        self.I = torch.sparse.mm(self.w, self.r.unsqueeze(-1)).squeeze() + neuron.Ieff
        return self.I
class brain(nn.Module):
    """Whole-brain network: one neuron population plus one synaptic model.

    :param syn: (num_syn, 2) tensor of (post, pre) index pairs
    :param weight: per-synapse weights aligned with ``syn``
    :param neuron_model: 'HH' or 'aEIF'
    :param p_neuron: per-neuron parameter list handed to the node model
    :param dt: integration step
    :param device: torch device string
    :raises ValueError: if ``neuron_model`` is not 'HH' or 'aEIF'
    """

    def __init__(self, syn, weight, neuron_model, p_neuron, dt, device):
        super().__init__()
        if neuron_model == 'HH':
            self.neurons = HHNode(p_neuron, dt, device)
        elif neuron_model == 'aEIF':
            self.neurons = aEIF(p_neuron, dt, device)
        else:
            # Fail fast: the original silently left self.neurons undefined
            # and crashed later with a confusing AttributeError.
            raise ValueError(
                "unknown neuron_model: %r (expected 'HH' or 'aEIF')" % (neuron_model,))
        self.neuron_num = len(p_neuron[0])
        # tao_d=3, tao_r=6: synaptic decay/rise constants used by Syn.
        self.syns = Syn(syn, weight, self.neuron_num, 3, 6, dt, device)

    def forward(self, inputs):
        # ``inputs`` is ignored: drive comes from background noise plus
        # the recurrent synaptic current.
        I = self.syns(self.neurons)
        self.neurons(I)
def brain_region(neuron_num):
    """Return a (num_regions, 2) tensor of [start, end) global neuron-index
    ranges, given the neuron count of each region."""
    bounds = []
    lo = 0
    for count in neuron_num:
        hi = lo + count.item()
        bounds.append([lo, hi])
        lo = hi
    return torch.tensor(bounds)
def neuron_type(neuron_num, ratio, regions):
    """Convert per-region type ratios into absolute neuron-index boundaries.

    Each row gives the cumulative global end index of every neuron type
    within that region (region start offset included).
    """
    counts = neuron_num.reshape(-1, 1)
    offsets = regions[:, 0].reshape(-1, 1)
    return torch.floor(ratio * counts).int() + offsets
def syn_within_region(syn_num, region):
    """Draw random intra-region synapses.

    For every region [lo, hi), generates (hi - lo) * syn_num random
    (post, pre) index pairs drawn uniformly inside that region.
    """
    chunks = []
    for bounds in region:
        lo, hi = bounds[0], bounds[1]
        chunks.append(torch.randint(lo, hi,
                                    size=((hi - lo) * syn_num, 2), device=device))
    return torch.concatenate(chunks)
def syn_cross_region(weight_matrix, region):
    """Draw random inter-region synapses.

    For every ordered region pair (i, j) whose connection count
    ``weight_matrix[i][j]`` is at least 10, draw that many random
    (post, pre) pairs: presynaptic neurons in region j, postsynaptic
    neurons in region i.

    Returns a (num_syn, 2) int64 tensor.  If no pair reaches the
    threshold an empty (0, 2) tensor is returned — the original code
    crashed with UnboundLocalError in that case.
    """
    pairs = []
    n = len(weight_matrix)
    for i in range(n):
        for j in range(n):
            count = weight_matrix[i][j]
            if count < 10:
                # Too few connections: skip weak region pairs entirely.
                continue
            pre = torch.randint(region[j][0], region[j][1],
                                size=(count, 1), device=device)
            post = torch.randint(region[i][0], region[i][1],
                                 size=(count, 1), device=device)
            pairs.append(torch.concatenate((post, pre), dim=1))
    if not pairs:
        return torch.empty((0, 2), dtype=torch.int64, device=device)
    return torch.concatenate(pairs)
# ---- Network construction ------------------------------------------------
# scale shrinks both the connection counts and the per-region neuron
# counts read from the mouse connectivity / region tables.
scale = 0.1
neuron_model = 'aEIF'
weight_matrix = torch.load('./mouse_weight.pt') * scale
weight_matrix = weight_matrix.int()
data = pd.read_excel('./mouse_brain_region.xlsx', sheet_name='Sheet1', header=None)
data = data.values
name = data[0]  # row 0: region names; row 1: neuron counts per region
neuron_num = np.array(data[1] * scale).astype(np.int32)
neuron_num = torch.from_numpy(neuron_num)
# 70% / 20% / 10% split into three neuron types for each of 213 regions.
ratio = torch.tensor([[0.7, 0.9, 1.0] * 213]).reshape(213, 3)
syn_num = 10
regions = brain_region(neuron_num)
neuron_types = neuron_type(neuron_num, ratio, regions)
syn_1 = syn_within_region(syn_num, regions)
syn_2 = syn_cross_region(weight_matrix, regions)
syn = torch.concatenate((syn_1, syn_2))
print(syn.shape)
# Default weight -1 (inhibitory); overwritten with +1.5 below for synapses
# whose column-0 index falls in each region's first neuron-type range.
weight = -torch.ones(len(syn), device=device, requires_grad=False)
# ---- Per-neuron parameter tensors ---------------------------------------
if neuron_model == 'aEIF':
    threshold = torch.zeros(regions[-1][1], device=device, requires_grad=False)
    v_reset = torch.zeros(regions[-1][1], device=device, requires_grad=False)
    c_m = torch.zeros(regions[-1][1], device=device, requires_grad=False)
    tao_w = torch.zeros(regions[-1][1], device=device, requires_grad=False)
    alpha_ad = torch.zeros(regions[-1][1], device=device, requires_grad=False)
    beta_ad = torch.zeros(regions[-1][1], device=device, requires_grad=False)
elif neuron_model == 'HH':
    threshold = torch.zeros(regions[-1][1], device=device, requires_grad=False)
for i in range(len(neuron_types)):
    # NOTE(review): column 0 of ``syn`` holds the *post*synaptic index per
    # syn_cross_region's (post, pre) layout — the name ``pre`` may mislead.
    pre = syn[:, 0]
    mask = (pre >= regions[i][0]) & (pre < neuron_types[i][0])
    indices = torch.where(mask)
    weight[indices] = 1.5
    if neuron_model == 'aEIF':
        # Regions 0..176 and 177..212 use different aEIF parameter sets.
        if i < 177:
            threshold[regions[i][0]:neuron_types[i][0]] = -50
            threshold[neuron_types[i][0]:neuron_types[i][1]] = -44
            threshold[neuron_types[i][1]:neuron_types[i][2]] = -45
            v_reset[regions[i][0]:neuron_types[i][0]] = -110
            v_reset[neuron_types[i][0]:neuron_types[i][1]] = -110
            v_reset[neuron_types[i][1]:neuron_types[i][2]] = -66
            c_m[regions[i][0]:neuron_types[i][0]] = 10
            c_m[neuron_types[i][0]:neuron_types[i][1]] = 10
            c_m[neuron_types[i][1]:neuron_types[i][2]] = 8.5
            tao_w[regions[i][0]:neuron_types[i][0]] = 1
            tao_w[neuron_types[i][0]:neuron_types[i][1]] = 2
            tao_w[neuron_types[i][1]:neuron_types[i][2]] = 2
            alpha_ad[regions[i][0]:neuron_types[i][0]] = 0
            alpha_ad[neuron_types[i][0]:neuron_types[i][1]] = -0.2
            alpha_ad[neuron_types[i][1]:neuron_types[i][2]] = -0.2
            beta_ad[regions[i][0]:neuron_types[i][0]] = 0
            beta_ad[neuron_types[i][0]:neuron_types[i][1]] = 0.45
            beta_ad[neuron_types[i][1]:neuron_types[i][2]] = 0.45
        else:
            threshold[regions[i][0]:neuron_types[i][0]] = -50
            threshold[neuron_types[i][0]:neuron_types[i][1]] = -50
            threshold[neuron_types[i][1]:neuron_types[i][2]] = -45
            v_reset[regions[i][0]:neuron_types[i][0]] = -60
            v_reset[neuron_types[i][0]:neuron_types[i][1]] = -60
            v_reset[neuron_types[i][1]:neuron_types[i][2]] = -65
            c_m[regions[i][0]:neuron_types[i][0]] = 20
            c_m[neuron_types[i][0]:neuron_types[i][1]] = 2
            c_m[neuron_types[i][1]:neuron_types[i][2]] = 4
            tao_w[regions[i][0]:neuron_types[i][0]] = 1
            tao_w[neuron_types[i][0]:neuron_types[i][1]] = 2
            tao_w[neuron_types[i][1]:neuron_types[i][2]] = 2
            alpha_ad[regions[i][0]:neuron_types[i][0]] = 0
            alpha_ad[neuron_types[i][0]:neuron_types[i][1]] = -0.2
            alpha_ad[neuron_types[i][1]:neuron_types[i][2]] = -0.2
            beta_ad[regions[i][0]:neuron_types[i][0]] = 0
            beta_ad[neuron_types[i][0]:neuron_types[i][1]] = 0.45
            beta_ad[neuron_types[i][1]:neuron_types[i][2]] = 0.45
    elif neuron_model == 'HH':
        threshold[regions[i][0]:neuron_types[i][0]] = 50
        threshold[neuron_types[i][0]:neuron_types[i][1]] = 60
        threshold[neuron_types[i][1]:neuron_types[i][2]] = 60
# ---- Simulation ----------------------------------------------------------
if neuron_model == 'aEIF':
    p_neuron = [threshold, v_reset, c_m, tao_w, alpha_ad, beta_ad]
    dt = 1
    T = 300
elif neuron_model == 'HH':
    p_neuron = [threshold, 120, 36, 0.3, 115, -12, 10.6, 1]
    dt = 0.01
    T = 10000
model = brain(syn, weight, neuron_model, p_neuron, dt, device)
Iraster = []
for t in range(T):
    model(0)
    print(torch.sum(model.neurons.spike))
    Isp = torch.nonzero(model.neurons.spike)
    print(len(Isp))
    # Accumulate (timestep, neuron_index) rows for every spike.
    if (len(Isp) != 0):
        left = t * torch.ones((len(Isp)), device=device, requires_grad=False)
        left = left.reshape(len(left), 1)
        mide = torch.concatenate((left, Isp), dim=1)
    if (len(Isp) != 0) and (len(Iraster) != 0):
        Iraster = torch.concatenate((Iraster, mide), dim=0)
    if (len(Iraster) == 0) and (len(Isp) != 0):
        Iraster = mide
# NOTE(review): torch.tensor(Iraster) re-copies an existing tensor and
# fails on transpose if no spike ever occurred (Iraster stays []).
Iraster = torch.tensor(Iraster).transpose(0, 1)
torch.save(Iraster, "./mouse.pt")
================================================
FILE: examples/Perception_and_Learning/Conversion/burst_conversion/CIFAR10_VGG16.py
================================================
import sys
sys.path.append('../../..')
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
import numpy as np
import time
from braincog.utils import setup_seed
from braincog.datasets.datasets import get_cifar10_data
# Training device (falls back to CPU) and the CIFAR-10 dataset root.
device = torch.device('cuda:5' if torch.cuda.is_available() else 'cpu')
DATA_DIR = '/data/datasets'
class VGG16(nn.Module):
    """VGG-16 convolutional backbone (Conv-BN-ReLU blocks) for 32x32
    CIFAR-10 inputs, followed by a single 512->10 linear classifier."""

    # 'M' marks a 2x2 max-pool; integers are conv output channel counts.
    _CFG = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
            512, 512, 512, 'M', 512, 512, 512, 'M']

    def __init__(self, relu_max=1):  # relu_max kept for interface compatibility
        super(VGG16, self).__init__()
        layers = []
        in_ch = 3
        for item in self._CFG:
            if item == 'M':
                layers.append(nn.MaxPool2d(2, 2))
            else:
                layers.extend([nn.Conv2d(in_ch, item, 3, 1, 1),
                               nn.BatchNorm2d(item),
                               nn.ReLU(True)])
                in_ch = item
        self.conv = nn.Sequential(*layers)
        self.fc = nn.Linear(512, 10, bias=True)

    def forward(self, input):
        features = self.conv(input)
        flat = features.view(features.shape[0], -1)
        return self.fc(flat)
def get_cifar10_loader(batch_size, train_batch=None, num_workers=4, conversion=False, distributed=False):
    """Build CIFAR-10 train/test DataLoaders.

    When ``conversion`` is True the training set uses the plain test
    transform (no augmentation), which is what ANN->SNN calibration needs.

    NOTE(review): ``CIFAR10Policy`` and ``Cutout`` are not defined or
    imported in this module — calling with conversion=False would raise
    NameError.  Confirm where they are meant to come from.
    """
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(),
                                          CIFAR10Policy(),
                                          transforms.ToTensor(),
                                          Cutout(n_holes=1, length=16),
                                          normalize])
    transform_test = transforms.Compose([transforms.ToTensor(), normalize])
    train_batch = batch_size if train_batch is None else train_batch
    cifar10_train = datasets.CIFAR10(root=DATA_DIR, train=True, download=False, transform=transform_test if conversion else transform_train)
    cifar10_test = datasets.CIFAR10(root=DATA_DIR, train=False, download=False, transform=transform_test)
    if distributed:
        # Distributed samplers shard the datasets across processes.
        train_sampler = torch.utils.data.distributed.DistributedSampler(cifar10_train)
        val_sampler = torch.utils.data.distributed.DistributedSampler(cifar10_test, shuffle=False, drop_last=True)
        train_iter = torch.utils.data.DataLoader(cifar10_train, batch_size=train_batch, shuffle=False, num_workers=num_workers, pin_memory=True, sampler=train_sampler)
        test_iter = torch.utils.data.DataLoader(cifar10_test, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True, sampler=val_sampler)
    else:
        train_iter = torch.utils.data.DataLoader(cifar10_train, batch_size=train_batch, shuffle=True, num_workers=num_workers, pin_memory=True)
        test_iter = torch.utils.data.DataLoader(cifar10_test, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
    return train_iter, test_iter
def train(net, train_iter, test_iter, optimizer, scheduler, device, num_epochs, losstype='mse'):
    """Train ``net`` and checkpoint the best test-accuracy weights.

    ``losstype`` 'mse' trains against one-hot targets with MSE; anything
    else uses label-smoothed cross-entropy.  The best model is saved to
    './CIFAR10_VGG16.pth'.
    """
    best = 0
    net = net.to(device)
    print("training on ", device)
    if losstype == 'mse':
        loss = torch.nn.MSELoss()
    else:
        loss = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
    losses = []
    for epoch in range(num_epochs):
        # Record the current learning rate (last param group wins).
        for param_group in optimizer.param_groups:
            learning_rate = param_group['lr']
        losss = []
        train_l_sum, train_acc_sum, n, batch_count, start = 0.0, 0.0, 0, 0, time.time()
        for X, y in train_iter:
            X = X.to(device)
            y = y.to(device)
            y_hat = net(X)
            label = y
            if losstype == 'mse':
                # MSE needs one-hot float targets (10 CIFAR-10 classes).
                label = F.one_hot(y, 10).float()
            l = loss(y_hat, label)
            losss.append(l.cpu().item())
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            train_l_sum += l.cpu().item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()
            n += y.shape[0]
            batch_count += 1
        scheduler.step()
        test_acc = evaluate_accuracy(test_iter, net)
        losses.append(np.mean(losss))
        print('epoch %d, lr %.6f, loss %.6f, train acc %.6f, test acc %.6f, time %.1f sec'
              % (epoch + 1, learning_rate, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))
        # Keep only the best-performing checkpoint.
        if test_acc > best:
            best = test_acc
            torch.save(net.state_dict(), './CIFAR10_VGG16.pth')
def evaluate_accuracy(data_iter, net, device=None, only_onebatch=False):
    """Return top-1 accuracy of ``net`` over ``data_iter``.

    If ``device`` is omitted it is taken from the model's parameters.
    With ``only_onebatch`` True only the first batch is evaluated.
    """
    if device is None and isinstance(net, torch.nn.Module):
        device = list(net.parameters())[0].device
    correct, total = 0.0, 0
    with torch.no_grad():
        for X, y in data_iter:
            net.eval()  # inference mode for the forward pass
            preds = net(X.to(device)).argmax(dim=1)
            correct += (preds == y.to(device)).float().sum().cpu().item()
            net.train()  # restore training mode, as the original did
            total += y.shape[0]
            if only_onebatch:
                break
    return correct / total
if __name__ == '__main__':
    # Deterministic setup: fixed seed, no cudnn autotuning.
    setup_seed(42)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    batch_size = 128
    train_iter, test_iter, _, _ = get_cifar10_data(batch_size)
    # train_iter, test_iter = get_cifar10_loader(batch_size)
    print('dataloader finished')
    lr, num_epochs = 0.05, 300
    net = VGG16()
    optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, eta_min=0, T_max=num_epochs)
    train(net, train_iter, test_iter, optimizer, scheduler, device, num_epochs, losstype='crossentropy')
    # Reload the best checkpoint saved by train() and report its accuracy.
    net.load_state_dict(torch.load("./CIFAR10_VGG16.pth", map_location=device))
    net = net.to(device)
    acc = evaluate_accuracy(test_iter, net, device)
    print(acc)
================================================
FILE: examples/Perception_and_Learning/Conversion/burst_conversion/README.md
================================================
# Conversion Method
Training deep spiking neural network with ann-snn conversion
Replace ReLU and MaxPooling layers in the PyTorch model so the original ANN can be converted into an SNN capable of complex tasks.
## Results
```shell
python CIFAR10_VGG16.py
python converted_CIFAR10.py
```
You should first run the `CIFAR10_VGG16.py` to get a well-trained ANN.
Then `converted_CIFAR10.py` can be used to run the snn inference process.
### Citation
If you find this package helpful, please consider citing it:
```BibTex
@inproceedings{ijcai2022p345,
title = {Efficient and Accurate Conversion of Spiking Neural Network with Burst Spikes},
author = {Li, Yang and Zeng, Yi},
booktitle = {Proceedings of the Thirty-First International Joint Conference on
Artificial Intelligence, {IJCAI-22}},
publisher = {International Joint Conferences on Artificial Intelligence Organization},
pages = {2485--2491},
year = {2022},
month = {7},
}
@article{li2022spike,
title={Spike calibration: Fast and accurate conversion of spiking neural network for object detection and segmentation},
author={Li, Yang and He, Xiang and Dong, Yiting and Kong, Qingqun and Zeng, Yi},
journal={arXiv preprint arXiv:2207.02702},
year={2022}
}
```
================================================
FILE: examples/Perception_and_Learning/Conversion/burst_conversion/converted_CIFAR10.py
================================================
import sys
sys.path.append('../../..')
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib
matplotlib.use('agg')
import numpy as np
from tqdm import tqdm
from copy import deepcopy
import matplotlib.pyplot as plt
import time
import os
from examples.Perception_and_Learning.Conversion.burst_conversion.CIFAR10_VGG16 import VGG16
from braincog.utils import setup_seed
from braincog.datasets.datasets import get_cifar10_data
from braincog.base.conversion import Convertor
import argparse
parser = argparse.ArgumentParser(description='Conversion')


def _str2bool(v):
    """Parse a boolean command-line value.

    argparse's ``type=bool`` is a trap: ``bool('False')`` is True, so any
    non-empty string (including 'False') parsed as True.  This helper
    accepts the usual true/false spellings explicitly.
    """
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', '0'):
        return False
    raise argparse.ArgumentTypeError('boolean value expected, got %r' % v)


parser.add_argument('--T', default=64, type=int, help='simulation time')
parser.add_argument('--p', default=0.99, type=float, help='percentile for data normalization. 0-1')
parser.add_argument('--gamma', default=5, type=int, help='burst spike and max spikes IF can emit')
parser.add_argument('--channelnorm', default=False, type=_str2bool, help='use channel norm')
parser.add_argument('--lipool', default=True, type=_str2bool, help='LIPooling')
parser.add_argument('--smode', default=True, type=_str2bool, help='replace ReLU to IF')
parser.add_argument('--soft_mode', default=True, type=_str2bool, help='soft reset or not')
parser.add_argument('--device', default='4', type=str, help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--cuda', default=True, type=_str2bool, help='use cuda.')
parser.add_argument('--model_name', default='vgg16', type=str, help='model name. vgg16 or resnet20')
parser.add_argument('--merge', default=True, type=_str2bool, help='merge conv and bn')
parser.add_argument('--train_batch', default=100, type=int, help='batch size for get max')
parser.add_argument('--batch_num', default=1, type=int, help='number of train batch')
parser.add_argument('--spicalib', default=0, type=int, help='allowance for spicalib')
parser.add_argument('--batch_size', default=128, type=int, help='batch size for testing')
parser.add_argument('--seed', default=42, type=int, help='seed')
args = parser.parse_args()
def evaluate_snn(test_iter, snn, device=None, duration=50):
    """Evaluate a converted SNN, reporting accuracy as a function of the
    number of simulation timesteps.

    For each batch the SNN is reset, stepped ``duration`` times while the
    outputs are accumulated, and the running accuracy after each step is
    recorded.  Per-timestep accuracies are averaged over batches and
    printed at power-of-two timesteps, plus the best accuracy overall.
    """
    accs = []
    snn.eval()
    for ind, (test_x, test_y) in tqdm(enumerate(test_iter)):
        test_x = test_x.to(device)
        test_y = test_y.to(device)
        n = test_y.shape[0]
        out = 0
        with torch.no_grad():
            snn.reset()
            acc = []
            # for t in tqdm(range(duration)):
            for t in range(duration):
                # Accumulate output over time; prediction is the argmax of
                # the running sum after each timestep.
                out += snn(test_x)
                result = torch.max(out, 1).indices
                result = result.to(device)
                acc_sum = (result == test_y).float().sum().item()
                acc.append(acc_sum / n)
        accs.append(np.array(acc))
    # Mean per-timestep accuracy across all batches.
    accs = np.array(accs).mean(axis=0)
    i, show_step = 1, []
    while 2 ** i <= duration:
        show_step.append(2 ** i - 1)
        i = i + 1
    for iii in show_step:
        print("timestep", str(iii).zfill(3) + ':', accs[iii])
    print("best acc: ", max(accs))
if __name__ == '__main__':
    print("Setting Arguments.. : ", args)
    print("----------------------------------------------------------")
    # Deterministic setup: fixed seed, no cudnn autotuning.
    setup_seed(seed=args.seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    device = torch.device("cuda:%s" % args.device) if args.cuda else 'cpu'
    # Calibration data uses the (non-augmented) transform; same_da=True.
    train_iter, _, _, _ = get_cifar10_data(args.train_batch, same_da=True)
    _, test_iter, _, _ = get_cifar10_data(args.batch_size, same_da=True)
    if args.model_name == 'vgg16':
        net = VGG16()
        net.load_state_dict(torch.load("./CIFAR10_VGG16.pth", map_location=device))
    net.eval()
    net = net.to(device)
    # Convert the trained ANN into an SNN with the configured options.
    converter = Convertor(dataloader=train_iter,
                          device=device,
                          p=args.p,
                          channelnorm=args.channelnorm,
                          lipool=args.lipool,
                          gamma=args.gamma,
                          soft_mode=args.soft_mode,
                          merge=args.merge,
                          batch_num=args.batch_num,
                          spicalib=args.spicalib
                          )
    snn = converter(net)
    evaluate_snn(test_iter, snn, device, duration=args.T)
================================================
FILE: examples/Perception_and_Learning/Conversion/msat_conversion/CIFAR10_VGG16.py
================================================
import sys
sys.path.append('../../..')
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
import numpy as np
import time
from braincog.utils import setup_seed
from braincog.datasets.datasets import get_cifar10_data
# Training device (falls back to CPU) and the CIFAR-10 dataset root.
device = torch.device('cuda:5' if torch.cuda.is_available() else 'cpu')
DATA_DIR = '/data/datasets'
class VGG16(nn.Module):
    """VGG-16 backbone (Conv-BN-ReLU blocks with 2x2 max-pools) for
    32x32 CIFAR-10 inputs, followed by a 512->10 linear classifier.

    ``relu_max`` is accepted for interface compatibility but unused here.
    """

    def __init__(self, relu_max=1):  # 1 3e38
        super(VGG16, self).__init__()
        cnn = nn.Sequential(
            nn.Conv2d(3, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.ReLU(True),
            nn.Conv2d(64, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.ReLU(True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, 3, 1, 1), nn.BatchNorm2d(128), nn.ReLU(True),
            nn.Conv2d(128, 128, 3, 1, 1), nn.BatchNorm2d(128), nn.ReLU(True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(128, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.ReLU(True),
            nn.Conv2d(256, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.ReLU(True),
            nn.Conv2d(256, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.ReLU(True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(256, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(True),
            nn.Conv2d(512, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(True),
            nn.Conv2d(512, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(512, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(True),
            nn.Conv2d(512, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(True),
            nn.Conv2d(512, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(True),
            nn.MaxPool2d(2, 2))
        self.conv = cnn
        self.fc = nn.Linear(512, 10, bias=True)

    def forward(self, input):
        # After five 2x2 pools a 32x32 input becomes 1x1x512, so the
        # flatten yields exactly the 512 features the classifier expects.
        conv = self.conv(input)
        x = conv.view(conv.shape[0], -1)
        output = self.fc(x)
        return output
def get_cifar10_loader(batch_size, train_batch=None, num_workers=4, conversion=False, distributed=False):
    """Build CIFAR-10 train/test DataLoaders.

    conversion=True uses the plain test transform for the training set
    (needed when collecting activation statistics for ANN->SNN conversion).
    NOTE(review): CIFAR10Policy and Cutout are not imported in this file --
    constructing the train transform raises NameError; verify the intended
    import before using this helper.
    """
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        CIFAR10Policy(),
        transforms.ToTensor(),
        Cutout(n_holes=1, length=16),
        normalize,
    ])
    transform_test = transforms.Compose([transforms.ToTensor(), normalize])
    if train_batch is None:
        train_batch = batch_size
    train_set = datasets.CIFAR10(root=DATA_DIR, train=True, download=False,
                                 transform=transform_test if conversion else transform_train)
    test_set = datasets.CIFAR10(root=DATA_DIR, train=False, download=False,
                                transform=transform_test)
    common = dict(num_workers=num_workers, pin_memory=True)
    if distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_set)
        val_sampler = torch.utils.data.distributed.DistributedSampler(test_set, shuffle=False, drop_last=True)
        train_iter = torch.utils.data.DataLoader(train_set, batch_size=train_batch, shuffle=False,
                                                 sampler=train_sampler, **common)
        test_iter = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False,
                                                sampler=val_sampler, **common)
    else:
        train_iter = torch.utils.data.DataLoader(train_set, batch_size=train_batch, shuffle=True, **common)
        test_iter = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, **common)
    return train_iter, test_iter
def train(net, train_iter, test_iter, optimizer, scheduler, device, num_epochs, losstype='mse'):
    """Train `net`, saving the best test-accuracy weights to ./CIFAR10_VGG16.pth.

    Args:
        net: model to train (moved onto `device`).
        train_iter / test_iter: iterables yielding (X, y) batches.
        optimizer / scheduler: optimizer and per-epoch LR scheduler.
        device: torch device to train on.
        num_epochs: number of epochs.
        losstype: 'mse' (targets one-hot encoded, 10 classes) or anything else
            for label-smoothed cross-entropy.

    Returns:
        (best, losses): best test accuracy observed and the per-epoch mean loss
        history.  (The original computed both but discarded them.)
    """
    best = 0
    net = net.to(device)
    print("training on ", device)
    if losstype == 'mse':
        loss = torch.nn.MSELoss()
    else:
        loss = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
    losses = []
    for epoch in range(num_epochs):
        # Read the current LR for logging only.
        for param_group in optimizer.param_groups:
            learning_rate = param_group['lr']
        losss = []
        train_l_sum, train_acc_sum, n, batch_count, start = 0.0, 0.0, 0, 0, time.time()
        for X, y in train_iter:
            X = X.to(device)
            y = y.to(device)
            y_hat = net(X)
            # MSE needs one-hot targets; cross-entropy uses raw class indices.
            label = F.one_hot(y, 10).float() if losstype == 'mse' else y
            l = loss(y_hat, label)
            losss.append(l.cpu().item())
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            train_l_sum += l.cpu().item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()
            n += y.shape[0]
            batch_count += 1
        scheduler.step()
        test_acc = evaluate_accuracy(test_iter, net)
        losses.append(np.mean(losss))
        print('epoch %d, lr %.6f, loss %.6f, train acc %.6f, test acc %.6f, time %.1f sec'
              % (epoch + 1, learning_rate, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))
        # Checkpoint only on improvement.
        if test_acc > best:
            best = test_acc
            torch.save(net.state_dict(), './CIFAR10_VGG16.pth')
    return best, losses
def evaluate_accuracy(data_iter, net, device=None, only_onebatch=False):
    """Compute top-1 accuracy of `net` over `data_iter`.

    Args:
        data_iter: iterable of (X, y) batches.
        net: model (callable); if an nn.Module and device is None, inputs are
            moved to the device of its first parameter.
        only_onebatch: if True, evaluate a single batch only.

    Returns:
        Fraction of correctly classified samples.

    Fixes: eval()/train() were previously toggled once per batch and the
    network was unconditionally left in train mode afterwards; eval mode is now
    set once and the caller's original mode is restored.
    """
    if device is None and isinstance(net, torch.nn.Module):
        device = list(net.parameters())[0].device
    was_training = getattr(net, 'training', False)
    if isinstance(net, torch.nn.Module):
        net.eval()
    acc_sum, n = 0.0, 0
    try:
        with torch.no_grad():
            for X, y in data_iter:
                acc_sum += (net(X.to(device)).argmax(dim=1) == y.to(device)).float().sum().cpu().item()
                n += y.shape[0]
                if only_onebatch:
                    break
    finally:
        if isinstance(net, torch.nn.Module) and was_training:
            net.train()
    return acc_sum / n
if __name__ == '__main__':
    # Reproducibility: fixed seed plus deterministic cuDNN kernels.
    setup_seed(42)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    batch_size = 128
    # get_cifar10_data returns (train, test, ...); the two extra values are unused here.
    train_iter, test_iter, _, _ = get_cifar10_data(batch_size)
    # train_iter, test_iter = get_cifar10_loader(batch_size)
    print('dataloader finished')
    lr, num_epochs = 0.05, 300
    net = VGG16()
    optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, eta_min=0, T_max=num_epochs)
    # Train with label-smoothed cross-entropy; train() writes the best weights
    # to ./CIFAR10_VGG16.pth, which are reloaded for the final evaluation.
    train(net, train_iter, test_iter, optimizer, scheduler, device, num_epochs, losstype='crossentropy')
    net.load_state_dict(torch.load("./CIFAR10_VGG16.pth", map_location=device))
    net = net.to(device)
    acc = evaluate_accuracy(test_iter, net, device)
    print(acc)
================================================
FILE: examples/Perception_and_Learning/Conversion/msat_conversion/README.md
================================================
# Script for experiment
```shell
python converted_CIFAR10.py --useDET --useDTT
```
## Note: Convertor
To run the multi-stage adaptive threshold in this project, replace `braincog/base/conversion/convertor.py` with the `convertor.py` provided here.
## Citation
If you find the code and dataset useful in your research, please consider citing:
```
@article{he2023improving,
title={Improving the Performance of Spiking Neural Networks on Event-based Datasets with Knowledge Transfer},
author={He, Xiang and Zhao, Dongcheng and Li, Yang and Shen, Guobin and Kong, Qingqun and Zeng, Yi},
journal={arXiv preprint arXiv:2303.13077},
year={2023}
}
@misc{https://doi.org/10.48550/arxiv.2207.08533,
doi = {10.48550/ARXIV.2207.08533},
url = {https://arxiv.org/abs/2207.08533},
author = {Zeng, Yi and Zhao, Dongcheng and Zhao, Feifei and Shen, Guobin and Dong, Yiting and Lu, Enmeng and Zhang, Qian and Sun, Yinqian and Liang, Qian and Zhao, Yuxuan and Zhao, Zhuoya and Fang, Hongjian and Wang, Yuwei and Li, Yang and Liu, Xin and Du, Chengcheng and Kong, Qingqun and Ruan, Zizhe and Bi, Weida},
title = {BrainCog: A Spiking Neural Network based Brain-inspired Cognitive Intelligence Engine for Brain-inspired AI and Brain Simulation},
publisher = {arXiv},
year = {2022},
}
```
## Contact
If you are unsure how to use the code, or have other feedback or comments, please feel free to contact us at [hexiang2021@ia.ac.cn](mailto:hexiang2021@ia.ac.cn).
================================================
FILE: examples/Perception_and_Learning/Conversion/msat_conversion/converted_CIFAR10.py
================================================
# -*- coding: utf-8 -*-
# Time : 2023/4/19 15:56
# Author : Regulus
# FileName: converted_CIFAR10.py
# Explain:
# Software: PyCharm
import sys
sys.path.append('..')
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
# import matplotlib
# matplotlib.use('agg')
import numpy as np
from tqdm import tqdm
from braincog.utils import setup_seed
import os
from examples.Perception_and_Learning.Conversion.msat_conversion.CIFAR10_VGG16 import VGG16
import argparse
from braincog.datasets.datasets import get_cifar10_data
from braincog.base.conversion import Convertor, FolderPath
# Command-line options for the ANN->SNN conversion experiment.
parser = argparse.ArgumentParser(description='Conversion')
parser.add_argument('--T', default=256, type=int, help='simulation time')
parser.add_argument('--p', default=1, type=float, help='percentile for data normalization. 0-1')
parser.add_argument('--gamma', default=1, type=int, help='burst spike and max spikes IF can emit')
# NOTE(review): `type=bool` is an argparse pitfall -- bool("False") is True, so
# any non-empty value (including "False") enables these flags; they cannot be
# turned off from the command line.  Consider action='store_true' instead.
parser.add_argument('--lateral_inhi', default=True, type=bool, help='LIPooling')
parser.add_argument('--data_norm', default=True, type=bool, help=' whether use data norm or not')
parser.add_argument('--smode', default=True, type=bool, help='replace ReLU to IF')
parser.add_argument('--device', default='7', type=str, help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--cuda', default=True, type=bool, help='use cuda.')
parser.add_argument('--model_name', default='vgg16', type=str, help='model name. vgg16 or resnet20')
parser.add_argument('--train_batch', default=512, type=int, help='batch size for get max')
parser.add_argument('--batch_size', default=128, type=int, help='batch size for testing')
parser.add_argument('--seed', default=23, type=int, help='seed')
# MSAT options: dynamic thresholds (DET/DTT) and spike-confidence dropping.
parser.add_argument('--useDET', action='store_true', default=False, help='use DET')
parser.add_argument('--useDTT', action='store_true', default=False, help='use DTT')
parser.add_argument('--useSC', action='store_true', default=False, help='use SpikeConfidence')
# Parsed at import time; `args` is read as a module-level global below.
args = parser.parse_args()
def evaluate_snn(test_iter, snn, device=None, duration=50):
    """Evaluate the converted SNN over `duration` timesteps per sample.

    Per-timestep accuracies (averaged over all batches) are written to
    result.txt and accs.pth inside a folder named after the conversion
    settings; FolderPath is updated so downstream modules (SNode) can find
    their config files.

    Fixes: the reporting loop was hard-coded to 256 timesteps and raised
    IndexError whenever --T != 256 (it now follows len(accs)); the result file
    is opened with a context manager; the pointless `if True:` guard is gone.
    """
    folder_path = "./result_conversion_{}/snn_timestep{}_p{}_LIPooling{}_Burst{}".format(
        args.model_name, duration, args.p, args.lateral_inhi, args.gamma)
    if not os.path.exists(folder_path):  # create the results folder if missing
        os.makedirs(folder_path)
    snn.eval()
    FolderPath.folder_path = folder_path
    accs = []
    for ind, (test_x, test_y) in enumerate(tqdm(test_iter)):
        test_x = test_x.to(device)
        test_y = test_y.to(device)
        n = test_y.shape[0]
        out = 0
        with torch.no_grad():
            snn.reset()
            acc = []
            for t in range(duration):
                # Accumulate the network output over time and classify from
                # the running sum after each timestep.
                out += snn(test_x)
                result = torch.max(out, 1).indices.to(device)
                acc.append((result == test_y).float().sum().item() / n)
        accs.append(np.array(acc))
    accs = np.array(accs).mean(axis=0)
    with open('{}/result.txt'.format(folder_path), 'w') as f:
        f.write("Setting Arguments.. : {}\n".format(args))
        for iii in range(len(accs)):
            # Log a sparse selection of timesteps: 1, 4, 8 and every 16th.
            if iii == 0 or iii == 3 or iii == 7 or (iii + 1) % 16 == 0:
                f.write("timestep {}:{}\n".format(str(iii + 1).zfill(3), accs[iii]))
        f.write("max accs: {}, timestep:{}\n".format(max(accs), np.where(accs == max(accs))))
    accs = torch.from_numpy(accs)
    torch.save(accs, "{}/accs.pth".format(folder_path))
if __name__ == '__main__':
    print("Setting Arguments.. : ", args)
    print("----------------------------------------------------------")
    # Reproducibility: fixed seed and deterministic cuDNN.
    setup_seed(seed=args.seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    device = torch.device("cuda:%s" % args.device) if args.cuda else 'cpu'
    # Train split feeds the converter's activation-statistics pass; test split
    # is used for SNN evaluation. same_da presumably makes both splits use the
    # same (non-augmented) transform -- verify against get_cifar10_data.
    train_iter, _, _, _ = get_cifar10_data(args.train_batch, same_da=True)
    _, test_iter, _, _ = get_cifar10_data(args.batch_size, same_da=True)
    if args.model_name == 'vgg16':
        net = VGG16()
        # net.load_state_dict(torch.load("./CIFAR10_VGG16.pth", map_location=device))
        # NOTE(review): the checkpoint load above is commented out, so the SNN
        # is converted from randomly initialised weights -- confirm intended.
        net.eval()
        net = net.to(device)
    # Convert the ANN to an SNN; the MSAT flags (useDET/useDTT/useSC) come
    # from the command line, the remaining settings are fixed here.
    converter = Convertor(dataloader=train_iter,
                          device=device,
                          p=1.0,
                          channelnorm=False,
                          lipool=True,
                          gamma=1,
                          soft_mode=True,
                          merge=True,
                          batch_num=1,
                          spicalib=False,
                          useDET=args.useDET,
                          useDTT=args.useDTT,
                          useSC=args.useSC
                          )
    snn = converter(net)
    evaluate_snn(test_iter, snn, device, duration=args.T)
================================================
FILE: examples/Perception_and_Learning/Conversion/msat_conversion/convertor.py
================================================
import torch
import torch.nn as nn
from braincog.base.connection.layer import SMaxPool, LIPool
from .merge import mergeConvBN
from .spicalib import SpiCalib
import types
import os
import sys
# Global counter: assigns each converted SNode a unique, depth-first layer
# index (used to look up per-layer confidence values in SNode.forward).
layer_index = 0 # layer index for SNode
class FolderPath:
    # Mutable module-level holder for the current results directory; set by the
    # evaluation script before inference so SNode can locate its config files.
    folder_path = "init"
class HookScale(nn.Module):
""" 在每个ReLU层后记录该层的百分位最大值
For channelnorm: 获取最大值时使用了torch.quantile
For layernorm: 使用sort,然后手动取百分比,因为quantile在计算单个通道时有上限,batch较大时易出错
"""
def __init__(self,
p: float = 0.9995,
channelnorm: bool = False,
gamma: float = 0.999,
):
super().__init__()
if channelnorm:
self.register_buffer('scale', torch.tensor(0.0))
else:
self.register_buffer('scale', torch.tensor(0.0))
self.p = p
self.channelnorm = channelnorm
self.gamma = gamma
def forward(self, x):
x = torch.where(x.detach() < self.gamma, x.detach(),
torch.tensor(self.gamma, dtype=x.dtype, device=x.device))
if len(x.shape) == 4 and self.channelnorm:
num_channel = x.shape[1]
tmp = torch.quantile(x.permute(1, 0, 2, 3).reshape(num_channel, -1), self.p, dim=1,
interpolation='lower') + 1e-10
self.scale = torch.max(tmp, self.scale)
else:
sort, _ = torch.sort(x.view(-1))
self.scale = torch.max(sort[int(sort.shape[0] * self.p) - 1], self.scale)
return x
class Hookoutput(nn.Module):
    """Transparent wrapper that caches the wrapped module's latest output.

    Used during pseudo-conversion to monitor ReLU / ClipQuan activations.
    """

    def __init__(self, module):
        super(Hookoutput, self).__init__()
        # Detached copy of the most recent forward output (0. before any call).
        self.activation = 0.
        self.operation = module

    def forward(self, x):
        y = self.operation(x)
        self.activation = y.detach()
        return y
class Scale(nn.Module):
    """Multiplies forward activations by a fixed scale factor.

    `scale` may be a scalar or a 1-D per-channel tensor; the latter is reshaped
    to (1, C, 1, 1) and broadcast over (N, C, H, W) inputs.

    Fix: register_buffer requires a Tensor, so the previous float default
    (scale=1.0) crashed on construction; any input is now wrapped with
    torch.as_tensor (a no-op for tensors, preserving caller behavior).
    """

    def __init__(self, scale=1.0):
        super().__init__()
        self.register_buffer('scale', torch.as_tensor(scale))

    def forward(self, x):
        if len(self.scale.shape) == 1:
            # Per-channel scale: (C,) -> (1, C, 1, 1), then broadcast.
            return self.scale.unsqueeze(0).unsqueeze(2).unsqueeze(3).expand_as(x) * x
        else:
            return self.scale * x
def reset(self):
    """Recursively reset every spiking module attached during conversion.

    The converted network originates from an ANN, so only the newly inserted
    spiking modules (SNode, LIPool, SMaxPool) carry temporal state; any other
    module is simply descended into.  Bound to the converted model via
    types.MethodType in Convertor.forward.
    """
    for child in self.children():
        if isinstance(child, (SNode, LIPool, SMaxPool)):
            child.reset()
        else:
            reset(child)
class Convertor(nn.Module):
    """ANN-to-SNN converter.

    A few batches from `dataloader` are run through the ANN so HookScale
    layers can record per-layer activation maxima; `p` selects the p-th
    percentile of those activations.

    channelnorm (https://arxiv.org/abs/1903.06530): per-channel maxima,
        equivalent to per-channel weight normalisation.
    gamma (https://arxiv.org/abs/2204.13271): maximum number of burst spikes a
        neuron may emit per step; bursts reduce residual information.
    lipool (https://arxiv.org/abs/2204.13271): convert max-pooling with
        lateral-inhibition pooling (LIPooling).
    soft_mode (https://arxiv.org/abs/1612.04052): soft reset, which reduces
        information loss during the neuron reset.
    merge: fuse adjacent Conv and BatchNorm layers before conversion.
    batch_num: number of dataloader batches consumed for statistics.
    spicalib / useDET / useDTT / useSC: spike-calibration and MSAT options
        forwarded to the inserted spiking layers.
    """

    def __init__(self,
                 dataloader,
                 device=None,
                 p=0.9995,
                 channelnorm=False,
                 lipool=True,
                 gamma=1,
                 soft_mode=True,
                 merge=True,
                 batch_num=1,
                 spicalib=0,
                 useDET=False, useDTT=False, useSC=None
                 ):
        super(Convertor, self).__init__()
        self.dataloader = dataloader
        self.device = device
        self.p = p
        self.channelnorm = channelnorm
        self.lipool = lipool
        self.gamma = gamma
        self.soft_mode = soft_mode
        self.merge = merge
        self.batch_num = batch_num
        self.spicalib = spicalib
        self.useDET = useDET
        self.useDTT = useDTT
        self.useSC = useSC

    def forward(self, model):
        """Convert `model` in place and return it with a bound reset() method."""
        model.eval()
        model = Convertor.register_hook(model, self.p, self.channelnorm, self.gamma)
        model = Convertor.get_percentile(model, self.dataloader, self.device, batch_num=self.batch_num)
        model = mergeConvBN(model) if self.merge else model
        model = Convertor.replace_for_spike(model, self.lipool, self.soft_mode, self.gamma, self.spicalib,
                                            self.useDET, self.useDTT, self.useSC)
        model.reset = types.MethodType(reset, model)
        return model

    @staticmethod
    def register_hook(model, p=0.99, channelnorm=False, gamma=0.999):
        """Append a HookScale recorder after every ReLU (recursively).

        Reference: https://github.com/fangwei123456/spikingjelly
        In simulation this is equivalent to weight normalisation and extends
        easily to networks of arbitrary structure.
        """
        for name, child in list(model.named_children()):
            if isinstance(child, nn.ReLU):
                model._modules[name] = nn.Sequential(nn.ReLU(), HookScale(p, channelnorm, gamma))
            else:
                Convertor.register_hook(child, p, channelnorm, gamma)
        return model

    @staticmethod
    def get_percentile(model, dataloader, device, batch_num=1):
        """Run `batch_num` batches through the model so HookScale layers record maxima."""
        for idx, (data, _) in enumerate(dataloader):
            # Stop before touching an extra batch (the original moved one
            # unused batch to the device before breaking).
            if idx >= batch_num:
                break
            model(data.to(device))
        return model

    @staticmethod
    def replace_for_spike(model, lipool=True, soft_mode=True, gamma=1, spicalib=0, useDET=False, useDTT=False, useSC=None):
        """Replace ReLU+HookScale pairs with spiking neurons; convert pooling.

        Each recorded ReLU becomes Scale(1/s) -> SNode -> SpiCalib -> Scale(s),
        and max-pooling becomes LIPool (or SMaxPool when lipool is False).
        """
        for name, child in list(model.named_children()):
            if isinstance(child, nn.Sequential) and len(child) == 2 and isinstance(child[0], nn.ReLU) and isinstance(child[1], HookScale):
                global layer_index
                model._modules[name] = nn.Sequential(
                    Scale(1.0 / child[1].scale),
                    SNode(soft_mode, gamma, useDET=useDET, useDTT=useDTT, useSC=useSC, layer_index=layer_index),
                    SpiCalib(spicalib),
                    Scale(child[1].scale)
                )
                layer_index += 1
            if isinstance(child, nn.MaxPool2d):
                model._modules[name] = LIPool(child) if lipool else SMaxPool(child)
            else:
                # Bug fix: the recursive call previously dropped `spicalib`,
                # so nested sub-modules were always built with the default 0.
                Convertor.replace_for_spike(child, lipool, soft_mode, gamma, spicalib,
                                            useDET=useDET, useDTT=useDTT, useSC=useSC)
        return model
class SNode(nn.Module):
    """Spiking neuron used in the converted SNN (MSAT variant).

    gamma=1 gives a plain IF neuron; gamma > 1 enables burst spiking.
    soft_mode selects soft reset (subtract threshold), which greatly reduces
    information loss compared with resetting to zero.
    useDTT / useDET enable the dynamic threshold terms whose hyperparameters
    are read from `hyperparameters.txt` (under the FolderPath root) on the
    first timestep.  useSC randomly drops spikes during the first 16 timesteps
    according to per-layer confidences read from
    `neuron_confidence_vgg16.txt`.

    Fixes: `self.hard_reset` was referenced without parentheses, so hard reset
    never executed when soft_mode=False; `self.confidence` was initialised
    twice; the confidence file was re-read (and re-appended) on every early
    timestep; the spike-confidence mask is now created on the input's device
    instead of the default CUDA device.
    """

    def __init__(self, soft_mode=False, gamma=5, useDET=False, useDTT=False, useSC=None, layer_index=1):
        super(SNode, self).__init__()
        self.threshold = 1.0
        self.maxThreshold = 1.0
        self.soft_mode = soft_mode
        self.gamma = gamma
        self.mem = 0            # membrane potential
        self.spike = 0          # spikes emitted at the current step
        self.Vm = 0.            # running mean membrane potential
        self.summem = 0.        # accumulated membrane potential over time
        self.t = 0              # current timestep
        self.all_spike = 0
        self.V_T = 0
        self.useDET = useDET
        self.useDTT = useDTT
        self.useSC = useSC
        self.layer_index = layer_index
        # Dynamic-threshold hyperparameters (loaded from file at t == 0).
        self.alpha = 0
        self.ka = 0
        self.ki = 0
        self.C = 0
        self.tau_mp = 0
        self.tau_rd = 0
        # Statistics for spike-error analysis ("sin" = spikes of inactivated neurons).
        self.mem_16 = 0.0
        self.spike_mask = 0
        self.sin_spikenum = 0.0
        self.sin_ratio = []          # fraction of SIN among negative activations in the SNN
        self.last_spike = 0
        self.neg_ratio = []          # fraction of negative activations in the ANN
        self.sin_all_ratio = []      # fraction of SIN among all activations in the SNN
        self.pos_all_ratio = []      # fraction of positive spikes among all
        self.should_all_ratio = []   # positives that should spike but did not
        self.confidence = []         # per-layer spike confidences (was assigned twice)
        self.avg_error_spikenum = [] # average number of erroneous spikes

    def _load_hyperparameters(self):
        """Read the six DTT/DET hyperparameters (one float per line).

        The file lives under the first three components of the FolderPath
        results directory: alpha, ka, ki, C, tau_mp, tau_rd.
        """
        parts = FolderPath.folder_path.split('/')
        path = os.path.join(parts[0], parts[1], parts[2], 'hyperparameters.txt')
        hp = []
        with open(path, 'r') as f:
            for line in f.readlines():
                hp.append(float(line.split()[0]))
        self.alpha, self.ka, self.ki, self.C, self.tau_mp, self.tau_rd = hp[:6]

    def _load_confidence(self):
        """Read the per-layer spike confidences once (one float per line)."""
        parts = FolderPath.folder_path.split('/')
        path = os.path.join(parts[0], parts[1], parts[2], 'neuron_confidence_vgg16.txt')
        with open(path, 'r') as f:
            for line in f.readlines():
                self.confidence.append(float(line.split()[0]))

    def forward(self, x):
        self.mem = self.mem + x
        self.spike = torch.zeros_like(x)
        if self.t == 0:
            # First timestep: initialise the threshold maps and load the
            # dynamic-threshold hyperparameters.
            self.threshold = torch.full(x.shape, 1.0 * self.maxThreshold).to(x.device)
            self.V_T = -torch.full(x.shape, self.maxThreshold).to(x.device)
            self._load_hyperparameters()
        else:
            # Dynamic thresholds: DTT tracks the deviation of the membrane
            # potential from its running mean, DET decays with input strength.
            DTT = self.tau_mp * (self.alpha * (self.last_mem - self.Vm) + self.V_T + self.ka * torch.log(
                1 + torch.exp((self.last_mem - self.Vm) / self.ki)))
            DET = self.tau_rd * torch.exp(-1 * x / self.C)
            if self.useDET is True and self.useDTT is True:
                self.threshold = DET + DTT
            elif self.useDTT is True:
                self.threshold = DTT
            elif self.useDET is True:
                self.threshold = DET
            else:
                print("wrong logics")
                sys.exit()
            # Squash into (0, 1) and rescale by the layer's maximum threshold.
            self.threshold = torch.sigmoid(self.threshold)
            self.threshold *= self.maxThreshold
        self.spike = (self.mem / self.threshold).floor().clamp(min=0, max=self.gamma)
        # Fix: the original `... else self.hard_reset` referenced the bound
        # method without calling it, so hard reset never ran.
        self.soft_reset() if self.soft_mode else self.hard_reset()
        if self.useSC is True:
            if self.t < 16:
                if not self.confidence:
                    self._load_confidence()
                # Randomly drop spikes with probability (1 - confidence).
                mask = (torch.rand(x.shape, device=x.device)
                        >= (1.0 - self.confidence[self.layer_index])).float()
                self.spike = self.spike * mask
        self.all_spike += self.spike
        out = self.spike * self.threshold
        self.t += 1
        self.last_mem = self.mem
        self.summem += self.mem
        self.Vm = (self.summem / self.t)
        return out

    def hard_reset(self):
        """Reset fired neurons' membrane potential to 0.

        NOTE(review): with burst spiking (spike > 1) the factor (1 - spike)
        goes negative -- confirm intended for gamma > 1.
        """
        self.mem = self.mem * (1 - self.spike.detach())

    def soft_reset(self):
        """Subtract (threshold * spikes) from the membrane potential."""
        self.mem = self.mem - self.threshold * self.spike.detach()

    def reset(self):
        """Clear per-sample state.

        NOTE(review): `t`, `summem`, `Vm` and `all_spike` are deliberately not
        cleared here, matching the original behaviour -- confirm intended
        across samples.
        """
        self.mem = 0
        self.spike = 0
        self.maxThreshold = 1.0
================================================
FILE: examples/Perception_and_Learning/IllusionPerception/AbuttingGratingIllusion/distortion/__init__.py
================================================
from .abutting_grating_illusion import ag_distort_28, ag_distort_224, ag_distort_silhouette, save_image, get_silhouette_data
================================================
FILE: examples/Perception_and_Learning/IllusionPerception/AbuttingGratingIllusion/distortion/abutting_grating_illusion/__init__.py
================================================
from .abutting_grating_distortion import ag_distort_28, ag_distort_224, ag_distort_silhouette, save_image, get_silhouette_data
================================================
FILE: examples/Perception_and_Learning/IllusionPerception/AbuttingGratingIllusion/distortion/abutting_grating_illusion/abutting_grating_distortion.py
================================================
import torch
from torchvision import datasets, transforms
import time
import os
import matplotlib.pyplot as plt
import torch
from torchvision import datasets, transforms
import time
import os
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pickle
from PIL import Image
import numpy as np
from torchvision import utils
import os
# Seed every RNG torch exposes so the grating distortions are reproducible.
seed = 1000
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def save_image(image, filename):
    """Save a single (C, W, H) image tensor to `filename` via torchvision."""
    assert len(image.shape) == 3, "The image must have only three dimensions of C,W,H."
    utils.save_image(image, filename)
def get_mnist_data(train=False, batch_size=100):
    """Return a non-shuffled MNIST DataLoader for the train or test split."""
    path = './datasets/'  # might need to change based on where to call this function
    # transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    transform = transforms.Compose([transforms.ToTensor()])
    # The two original branches differed only in the `train` flag, so a single
    # loader construction covers both.
    dataset = datasets.MNIST(path, train=train, download=False, transform=transform)
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
def get_silhouette_data(path):
    '''
    path: dir path to the silhouette image samples of 16-clas-ImageNet
    '''
    # transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
    to_tensor = transforms.Compose([transforms.ToTensor()])
    # One sub-directory per class label; each sample is returned as an
    # (image_tensor[1, C, W, H], label) pair.  The accumulator is renamed so it
    # no longer shadows the module-level torchvision `datasets` import.
    samples = []
    for label in os.listdir(path):
        label_dir = f"{path}/{label}"
        for img_name in os.listdir(label_dir):
            img = Image.open(f"{label_dir}/{img_name}")
            samples.append((to_tensor(img).unsqueeze(0), label))
    return samples
def ag_distort_28(imgs, threshold=0, interval=4, phase=2, direction=(1, 0)):
    """Apply the abutting-grating distortion to a batch of images.

    Foreground pixels (value > threshold) keep grating lines where
    (direction . (w, h)) % interval == 0; background pixels keep the same
    lines shifted by `phase`.  Returns a binary image of the same shape.
    """
    # return imgs
    assert len(imgs.shape) == 4, "The images must have four dimensions of B,C,W,H."
    B, C, W, H = imgs.shape
    fg_mask = (imgs > threshold).float()
    bg_mask = 1 - fg_mask
    # Vectorised replacement of the per-pixel double loop: build the (W, H)
    # line-index map once and broadcast it over batch and channels.
    w_idx = torch.arange(W).view(W, 1)
    h_idx = torch.arange(H).view(1, H)
    line = direction[0] * w_idx + direction[1] * h_idx
    gratings_fg = (line % interval == 0).to(imgs.dtype).expand(B, C, W, H)
    gratings_bg = (line % interval == phase).to(imgs.dtype).expand(B, C, W, H)
    return fg_mask * gratings_fg + bg_mask * gratings_bg
def transform_224(imgs):
    """Bilinearly upsample a batch 8x and tile it into 3 identical channel groups."""
    big = torch.nn.functional.interpolate(imgs, scale_factor=8, mode='bilinear', align_corners=False)
    return big.repeat(1, 3, 1, 1)
def ag_distort_224(imgs, threshold=0, interval=8, phase=4, direction=(1, 0)):
    """Upsample 8x to 3-channel 224-style images, then apply the abutting-grating distortion."""
    assert len(imgs.shape) == 4, "The images must have four dimensions of C,W,H."
    imgs = torch.nn.functional.interpolate(imgs, scale_factor=8, mode='bilinear', align_corners=False)
    imgs = torch.cat([imgs, imgs, imgs], dim=1)
    # return imgs
    B, C, W, H = imgs.shape
    fg_mask = (imgs > threshold).float()
    bg_mask = 1 - fg_mask
    # Vectorised grating construction (same values as the per-pixel loops).
    w_idx = torch.arange(W).view(W, 1)
    h_idx = torch.arange(H).view(1, H)
    line = direction[0] * w_idx + direction[1] * h_idx
    gratings_fg = (line % interval == 0).to(imgs.dtype).expand(B, C, W, H)
    gratings_bg = (line % interval == phase).to(imgs.dtype).expand(B, C, W, H)
    return fg_mask * gratings_fg + bg_mask * gratings_bg
def ag_distort_silhouette(imgs, threshold=0.5, interval=2, phase=1, direction=(1,0)):
assert len(imgs.shape) == 4, "The image must have only three dimensions of C,W,H."
#imgs = torch.nn.functional.interpolate(imgs, scale_factor = 2, mode = 'bilinear', align_corners = False)
B,C,W,H = imgs.shape
mask_fg = (imgs 1
if args.distributed and args.num_gpu > 1:
_logger.warning(
'Using more than one GPU per process in distributed mode is not allowed.Setting num_gpu to 1.')
args.num_gpu = 1
# args.device = 'cuda:0'
args.world_size = 1
args.rank = 0 # global rank
if args.distributed:
args.num_gpu = 1
args.device = 'cuda:%d' % args.local_rank
torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend='nccl', init_method='env://')
args.world_size = torch.distributed.get_world_size()
args.rank = torch.distributed.get_rank()
else:
torch.cuda.set_device('cuda:%d' % args.device)
assert args.rank >= 0
if args.distributed:
_logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
% (args.rank, args.world_size))
else:
_logger.info('Training with a single process on %d GPUs.' % args.num_gpu)
# torch.manual_seed(args.seed + args.rank)
setup_seed(args.seed + args.rank)
genotype = eval('genotypes.%s' % args.arch)
model = create_model(
args.model,
pretrained=args.pretrained,
num_classes=args.num_classes,
dataset=args.dataset,
step=args.step,
encode_type=args.encode,
node_type=eval(args.node_type),
threshold=args.threshold,
tau=args.tau,
sigmoid_thres=args.sigmoid_thres,
requires_thres_grad=args.requires_thres_grad,
spike_output=not args.no_spike_output,
C=args.init_channels,
layers=args.layers,
auxiliary=args.auxiliary,
genotype=genotype,
parse_method=args.parse_method,
back_connection=args.back_connection,
act_fun=args.act_fun,
temporal_flatten=args.temporal_flatten,
layer_by_layer=args.layer_by_layer,
n_groups=args.n_groups,
)
if 'dvs' in args.dataset:
args.channels = 2
elif 'mnist' in args.dataset:
args.channels = 1
else:
args.channels = 3
# flops, params = profile(model, inputs=(torch.randn(1, args.channels, args.event_size, args.event_size),), verbose=False)
# _logger.info('flops = %fM', flops / 1e6)
# _logger.info('param size = %fM', params / 1e6)
linear_scaled_lr = args.lr * args.batch_size * args.world_size / 1024.0
args.lr = linear_scaled_lr
_logger.info("learning rate is %f" % linear_scaled_lr)
if args.local_rank == 0:
_logger.info('Model %s created, param count: %d' %
(args.model, sum([m.numel() for m in model.parameters()])))
num_aug_splits = 0
if args.aug_splits > 0:
assert args.aug_splits > 1, 'A split of 1 makes no sense'
num_aug_splits = args.aug_splits
if args.split_bn:
assert num_aug_splits > 1 or args.resplit
model = convert_splitbn_model(model, max(num_aug_splits, 2))
use_amp = None
if args.amp:
# for backwards compat, `--amp` arg tries apex before native amp
if has_apex:
args.apex_amp = True
elif has_native_amp:
args.native_amp = True
if args.apex_amp and has_apex:
use_amp = 'apex'
elif args.native_amp and has_native_amp:
use_amp = 'native'
elif args.apex_amp or args.native_amp:
_logger.warning("Neither APEX or native Torch AMP is available, using float32. "
"Install NVIDA apex or upgrade to PyTorch 1.6")
if args.num_gpu > 1:
if use_amp == 'apex':
_logger.warning(
'Apex AMP does not work well with nn.DataParallel, disabling. Use DDP or Torch AMP.')
use_amp = None
model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
assert not args.channels_last, "Channels last not supported with DP, use DDP."
else:
model = model.cuda()
if args.channels_last:
model = model.to(memory_format=torch.channels_last)
optimizer = create_optimizer(args, model)
amp_autocast = suppress # do nothing
loss_scaler = None
if use_amp == 'apex':
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
loss_scaler = ApexScaler()
if args.local_rank == 0:
_logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
elif use_amp == 'native':
amp_autocast = torch.cuda.amp.autocast
loss_scaler = NativeScaler()
if args.local_rank == 0:
_logger.info('Using native Torch AMP. Training in mixed precision.')
else:
if args.local_rank == 0:
_logger.info('AMP not enabled. Training in float32.')
# optionally resume from a checkpoint
resume_epoch = None
if args.resume and args.eval_checkpoint == '':
args.eval_checkpoint = args.resume
if args.resume:
args.eval = True
# checkpoint = torch.load(args.resume, map_location='cpu')
# model.load_state_dict(checkpoint['state_dict'], False)
resume_epoch = resume_checkpoint(
model, args.resume,
optimizer=None if args.no_resume_opt else optimizer,
loss_scaler=None if args.no_resume_opt else loss_scaler,
log_info=args.local_rank == 0)
# print(model.get_attr('mu'))
# print(model.get_attr('sigma'))
# if args.critical_loss or args.spike_rate:
model.set_requires_fp(True)
model_ema = None
if args.model_ema:
# Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
model_ema = ModelEma(
model,
decay=args.model_ema_decay,
device='cpu' if args.model_ema_force_cpu else '',
resume=args.resume)
if args.node_resume:
ckpt = torch.load(args.node_resume, map_location='cpu')
model.load_node_weight(ckpt, args.node_trainable)
model_without_ddp = model
if args.distributed:
if args.sync_bn:
assert not args.split_bn
try:
if has_apex and use_amp != 'native':
# Apex SyncBN preferred unless native amp is activated
model = convert_syncbn_model(model)
else:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
if args.local_rank == 0:
_logger.info(
'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
except Exception as e:
_logger.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')
if has_apex and use_amp != 'native':
# Apex DDP preferred unless native amp is activated
if args.local_rank == 0:
_logger.info("Using NVIDIA APEX DistributedDataParallel.")
model = ApexDDP(model, delay_allreduce=True)
else:
if args.local_rank == 0:
_logger.info("Using native Torch DistributedDataParallel.")
model = NativeDDP(model, device_ids=[args.local_rank],
find_unused_parameters=True) # can use device str in Torch >= 1.1
model_without_ddp = model.module
# NOTE: EMA model does not need to be wrapped by DDP
lr_scheduler, num_epochs = create_scheduler(args, optimizer)
start_epoch = 0
if args.start_epoch is not None:
# a specified start_epoch will always override the resume epoch
start_epoch = args.start_epoch
elif resume_epoch is not None:
start_epoch = resume_epoch
if lr_scheduler is not None and start_epoch > 0:
lr_scheduler.step(start_epoch)
if args.local_rank == 0:
_logger.info('Scheduled epochs: {}'.format(num_epochs))
# now config only for imnet
data_config = resolve_data_config(vars(args), model=model, verbose=False)
loader_train, loader_eval, mixup_active, mixup_fn = eval('get_%s_data' % args.dataset)(
batch_size=args.batch_size,
step=args.step,
args=args,
_logge=_logger,
data_config=data_config,
num_aug_splits=num_aug_splits,
size=args.event_size,
mix_up=args.mix_up,
cut_mix=args.cut_mix,
event_mix=args.event_mix,
beta=args.cutmix_beta,
prob=args.cutmix_prob,
num=args.cutmix_num,
noise=args.cutmix_noise,
num_classes=args.num_classes,
rand_aug=args.rand_aug,
randaug_n=args.randaug_n,
randaug_m=args.randaug_m,
temporal_flatten=args.temporal_flatten,
portion=args.train_portion,
_logger=_logger,
)
if args.loss_fn == 'mse':
train_loss_fn = UnilateralMse(1.)
validate_loss_fn = UnilateralMse(1.)
else:
if args.jsd:
assert num_aug_splits > 1 # JSD only valid with aug splits set
train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()
elif mixup_active:
# smoothing is handled with mixup target transform
train_loss_fn = SoftTargetCrossEntropy().cuda()
elif args.smoothing:
train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
else:
train_loss_fn = nn.CrossEntropyLoss().cuda()
validate_loss_fn = nn.CrossEntropyLoss().cuda()
if args.loss_fn == 'mix':
train_loss_fn = MixLoss(train_loss_fn)
validate_loss_fn = MixLoss(validate_loss_fn)
eval_metric = args.eval_metric
best_metric = None
best_epoch = None
if args.eval: # evaluate the model
if args.distributed:
state_dict = torch.load(args.eval_checkpoint)['state_dict_ema']
new_state_dict = OrderedDict()
# add module prefix for DDP
for k, v in state_dict.items():
k = 'module.' + k
new_state_dict[k] = v
model.load_state_dict(new_state_dict)
# else:
# load_checkpoint(model, args.eval_checkpoint, args.model_ema)
for i in range(1):
val_metrics = validate(start_epoch, model, loader_eval, validate_loss_fn, args,
visualize=args.visualize, spike_rate=args.spike_rate,
tsne=args.tsne, conf_mat=args.conf_mat)
print(f"Top-1 accuracy of the model is: {val_metrics['top1']:.1f}%")
# return
saver = None
if args.local_rank == 0:
decreasing = True if eval_metric == 'loss' else False
saver = CheckpointSaver(
model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing)
with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
f.write(args_text)
try: # train the model
if args.reset_drop:
model_without_ddp.reset_drop_path(0.0)
for epoch in range(start_epoch, args.epochs):
if epoch == 0 and args.reset_drop:
model_without_ddp.reset_drop_path(args.drop_path)
if args.distributed:
loader_train.sampler.set_epoch(epoch)
train_metrics = train_epoch(
epoch, model, loader_train, optimizer, train_loss_fn, args,
lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
if args.local_rank == 0:
_logger.info("Distributing BatchNorm running means and vars")
distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
eval_metrics = validate(epoch, model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast,
visualize=args.visualize, spike_rate=args.spike_rate,
tsne=args.tsne, conf_mat=args.conf_mat)
if model_ema is not None and not args.model_ema_force_cpu:
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
ema_eval_metrics = validate(
epoch, model_ema.ema, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)',
visualize=args.visualize, spike_rate=args.spike_rate,
tsne=args.tsne, conf_mat=args.conf_mat)
eval_metrics = ema_eval_metrics
if lr_scheduler is not None:
# step LR for next epoch
lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
update_summary(
epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
write_header=best_metric is None)
# if saver is not None and epoch >= args.n_warm_up:
if saver is not None:
# save proper checkpoint with eval metric
save_metric = eval_metrics[eval_metric]
best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)
except KeyboardInterrupt:
pass
if best_metric is not None:
_logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
def train_epoch(
        epoch, model, loader, optimizer, loss_fn, args,
        lr_scheduler=None, saver=None, output_dir='', amp_autocast=suppress,
        loss_scaler=None, model_ema=None, mixup_fn=None):
    """Run one training epoch and return an OrderedDict of epoch metrics.

    :param epoch: current epoch index (used for drop-path schedule and LR updates)
    :param model: network to train (already on the target device)
    :param loader: training data loader
    :param optimizer: weight optimizer
    :param loss_fn: training criterion
    :param args: parsed command-line arguments namespace
    :param lr_scheduler: optional per-update LR scheduler
    :param saver: optional checkpoint saver for recovery snapshots
    :param output_dir: directory for debug image dumps / summaries
    :param amp_autocast: autocast context factory (``suppress`` disables AMP)
    :param loss_scaler: AMP loss scaler; when set it also performs the
        backward pass, gradient clipping and optimizer step
    :param model_ema: optional EMA copy of the model weights
    :param mixup_fn: optional mixup/cutmix callable applied to each batch
    :return: ``OrderedDict([('loss', average training loss)])``
    """
    # Optionally switch mixup off after a given epoch.
    if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
        if args.prefetcher and loader.mixup_enabled:
            loader.mixup_enabled = False
        elif mixup_fn is not None:
            mixup_fn.mixup_enabled = False

    # Linear drop-path warm-up over training.
    model.drop_path_prob = args.drop_path_prob * epoch / args.epochs
    second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    losses_m = AverageMeter()
    closses_m = AverageMeter()
    top1_m = AverageMeter()
    top5_m = AverageMeter()

    model.train()

    end = time.time()
    last_idx = len(loader) - 1
    num_updates = epoch * len(loader)
    for batch_idx, (inputs, target) in enumerate(loader):
        last_batch = batch_idx == last_idx
        data_time_m.update(time.time() - end)
        if not args.prefetcher or args.dataset != 'imnet':
            # Prefetcher already moves ImageNet batches to GPU; do it here otherwise.
            inputs, target = inputs.type(torch.FloatTensor).cuda(), target.cuda()
            if mixup_fn is not None:
                inputs, target = mixup_fn(inputs, target)
        if args.channels_last:
            inputs = inputs.contiguous(memory_format=torch.channels_last)

        with amp_autocast():
            output = model(inputs)
            loss = loss_fn(output, target)

        # Accuracy is meaningless with soft/mixed targets or on prefetched ImageNet.
        if not (args.cut_mix | args.mix_up | args.event_mix) and args.dataset != 'imnet':
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
        else:
            acc1, acc5 = torch.tensor([0.]), torch.tensor([0.])

        optimizer.zero_grad()
        if loss_scaler is not None:
            # Scaler handles backward, clipping and optimizer.step internally.
            loss_scaler(
                loss, optimizer, clip_grad=args.clip_grad, parameters=model.parameters(), create_graph=second_order)
        else:
            loss.backward(create_graph=second_order)
            if args.noisy_grad != 0.:
                random_gradient(model, args.noisy_grad)
            if args.clip_grad is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
            if args.opt == 'lamb':
                optimizer.step(epoch=epoch)
            else:
                optimizer.step()

        torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)
        num_updates += 1

        batch_time_m.update(time.time() - end)
        if last_batch or batch_idx % args.log_interval == 0:
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)

            mu_str = ''
            sigma_str = ''
            if not args.distributed:
                if 'Noise' in args.node_type:
                    mu, sigma = model.get_noise_param()
                    mu_str = ['{:.3f}'.format(i.detach()) for i in mu]
                    sigma_str = ['{:.3f}'.format(i.detach()) for i in sigma]

            if args.distributed:
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                losses_m.update(reduced_loss.item(), inputs.size(0))
                closses_m.update(reduced_loss.item(), inputs.size(0))
            else:
                # FIX: meters were previously updated only on the distributed
                # path, so single-GPU runs logged and returned a 0.0 loss.
                losses_m.update(loss.item(), inputs.size(0))
                closses_m.update(loss.item(), inputs.size(0))
            # FIX: top-1/top-5 meters were never updated, so logged accuracy
            # was always 0.
            top1_m.update(acc1.item(), inputs.size(0))
            top5_m.update(acc5.item(), inputs.size(0))

            if args.local_rank == 0:
                _logger.info(
                    'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
                    'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) '
                    'cLoss: {closs.val:>9.6f} ({closs.avg:>6.4f}) '
                    'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
                    'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f}) '
                    'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
                    '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                    'LR: {lr:.3e} '
                    'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
                        epoch,
                        batch_idx, len(loader),
                        100. * batch_idx / last_idx,
                        loss=losses_m,
                        closs=closses_m,
                        top1=top1_m,
                        top5=top5_m,
                        batch_time=batch_time_m,
                        rate=inputs.size(0) * args.world_size / batch_time_m.val,
                        rate_avg=inputs.size(0) * args.world_size / batch_time_m.avg,
                        lr=lr,
                        data_time=data_time_m
                    ))

                if args.save_images and output_dir:
                    torchvision.utils.save_image(
                        inputs,
                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
                        padding=0,
                        normalize=True)

        if saver is not None and args.recovery_interval and (
                last_batch or (batch_idx + 1) % args.recovery_interval == 0):
            saver.save_recovery(epoch, batch_idx=batch_idx)

        if lr_scheduler is not None:
            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)

        end = time.time()
        # end for

    if hasattr(optimizer, 'sync_lookahead'):
        optimizer.sync_lookahead()

    return OrderedDict([('loss', losses_m.avg)])
def validate(epoch, model, loader, loss_fn, args, amp_autocast=suppress,
             log_suffix='', visualize=False, spike_rate=False, tsne=False, conf_mat=False):
    """Evaluate ``model`` on ``loader`` and return an OrderedDict of metrics.

    :param epoch: current epoch index (for logging context only)
    :param model: network to evaluate
    :param loader: validation data loader
    :param loss_fn: validation criterion
    :param args: parsed command-line arguments namespace
    :param amp_autocast: autocast context factory (``suppress`` disables AMP)
    :param log_suffix: extra tag appended to the 'Test' log name (e.g. ' (EMA)')
    :param visualize: dump intermediate feature maps to ``args.output_dir``
    :param spike_rate: request feature-map recording (used with SNN models)
    :param tsne: plot 2-D/3-D t-SNE embeddings after evaluation
    :param conf_mat: plot a confusion matrix after evaluation
    :return: ``OrderedDict([('loss', avg loss), ('top1', avg top-1 acc)])``
    """
    batch_time_m = AverageMeter()
    losses_m = AverageMeter()
    # NOTE(review): closses_m is logged below but never updated here, so the
    # reported cLoss stays at its initial 0 — confirm whether that is intended.
    closses_m = AverageMeter()
    top1_m = AverageMeter()
    top5_m = AverageMeter()

    model.eval()

    # Buffers for t-SNE / confusion-matrix plotting.
    # NOTE(review): nothing appends to these lists in this function; with
    # tsne/conf_mat enabled the torch.cat calls at the end would fail on an
    # empty list — verify against the intended feature-collection hooks.
    feature_vec = []
    feature_cls = []
    logits_vec = []
    labels_vec = []

    end = time.time()
    last_idx = len(loader) - 1
    with torch.no_grad():
        for batch_idx, (inputs, target) in enumerate(loader):
            last_batch = batch_idx == last_idx
            if not args.prefetcher or args.dataset != 'imnet':
                # Prefetcher already moved ImageNet batches to GPU.
                inputs = inputs.type(torch.FloatTensor).cuda()
                target = target.cuda()
            if args.channels_last:
                inputs = inputs.contiguous(memory_format=torch.channels_last)

            if not args.distributed:
                # Enable feature-map recording when any visualization is requested.
                if (visualize or spike_rate or tsne or conf_mat) and not args.critical_loss:
                    model.set_requires_fp(True)

            with amp_autocast():
                output = model(inputs)
            if isinstance(output, (tuple, list)):
                # Some models return (logits, aux); keep only the main logits.
                output = output[0]

            if not args.distributed:
                if visualize:
                    x = model.get_fp()
                    feature_path = os.path.join(args.output_dir, 'feature_map')
                    if os.path.exists(feature_path) is False:
                        os.mkdir(feature_path)
                    save_feature_map(x, feature_path)

            # Augmentation reduction: average logits over test-time-augmented copies.
            reduce_factor = args.tta
            if reduce_factor > 1:
                output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
                target = target[0:target.size(0):reduce_factor]

            loss = loss_fn(output, target)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            # Total spike count, if the model exposes it (SNN models).
            tot_spike = model.get_tot_spike() if hasattr(model, 'get_tot_spike') else 0.

            if args.distributed:
                # Average loss/accuracy across workers before recording.
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                acc1 = reduce_tensor(acc1, args.world_size)
                acc5 = reduce_tensor(acc5, args.world_size)
            else:
                reduced_loss = loss.data

            torch.cuda.synchronize()

            losses_m.update(reduced_loss.item(), inputs.size(0))
            top1_m.update(acc1.item(), output.size(0))
            top5_m.update(acc5.item(), output.size(0))

            batch_time_m.update(time.time() - end)
            end = time.time()
            if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
                log_name = 'Test' + log_suffix
                # Noise-node parameters are collected for inspection; they are
                # not included in the log line below.
                mu_str = ''
                sigma_str = ''
                if not args.distributed:
                    if 'Noise' in args.node_type:
                        mu, sigma = model.get_noise_param()
                        mu_str = ['{:.3f}'.format(i.detach()) for i in mu]
                        sigma_str = ['{:.3f}'.format(i.detach()) for i in sigma]
                _logger.info(
                    '{0}: [{1:>4d}/{2}] '
                    'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                    'cLoss: {closs.val:>7.4f} ({closs.avg:>6.4f}) '
                    'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})'
                    'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f}) '
                    'TotSpike: {tot_spike}'.format(
                        log_name,
                        batch_idx,
                        last_idx,
                        batch_time=batch_time_m,
                        loss=losses_m,
                        closs=closses_m,
                        top1=top1_m,
                        top5=top5_m,
                        tot_spike=tot_spike
                    ))

    # metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])
    metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg)])

    if not args.distributed:
        if tsne:
            feature_vec = torch.cat(feature_vec)
            feature_cls = torch.cat(feature_cls)
            plot_tsne(feature_vec, feature_cls, os.path.join(args.output_dir, 't-sne-2d.eps'))
            plot_tsne_3d(feature_vec, feature_cls, os.path.join(args.output_dir, 't-sne-3d.eps'))
        if conf_mat:
            logits_vec = torch.cat(logits_vec)
            labels_vec = torch.cat(labels_vec)
            plot_confusion_matrix(logits_vec, labels_vec, os.path.join(args.output_dir, 'confusion_matrix.eps'))

    return metrics
# Script entry point: run the full training pipeline defined above.
if __name__ == '__main__':
    main()
================================================
FILE: examples/Perception_and_Learning/NeuEvo/separate_loss.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from braincog.base.utils.criterions import UnilateralMse
from utils import num_ops, type_num, edge_num
# Public API of this module.
# NOTE(review): MseSeparateLoss is defined below but omitted here; callers
# that import it explicitly still work, but `import *` will not expose it —
# confirm whether the omission is intentional.
__all__ = ['ConvSeparateLoss', 'TriSeparateLoss']
class MseSeparateLoss(nn.modules.loss._Loss):
    """Unilateral-MSE classification loss plus an architecture-separation term.

    The total loss is ``loss1 + weight * loss2`` where ``loss1`` is the
    unilateral MSE between predictions and targets, and ``loss2`` is the
    *negative* MSE between the architecture weights and 0.5 — i.e. it rewards
    pushing architecture weights away from the undecided 0.5 point.
    """

    def __init__(self, weight=0.1, size_average=None, ignore_index=-100,
                 reduce=None, reduction='mean'):
        """
        :param weight: scaling factor for the separation term
        :param size_average, reduce, reduction: forwarded to ``_Loss``
        :param ignore_index: kept for interface compatibility (unused here)
        """
        super(MseSeparateLoss, self).__init__(size_average, reduce, reduction)
        self.ignore_index = ignore_index
        self.weight = weight
        self.criterion = UnilateralMse(1.)

    def forward(self, input1, target1, input2):
        """Return ``(total_loss, loss1_value, loss2_value)``.

        :param input1: model predictions
        :param target1: classification targets
        :param input2: flattened architecture weights
        """
        loss1 = self.criterion(input1, target1)
        # FIX: the 0.5 target was built with a hard-coded .cuda(), which broke
        # CPU execution; full_like places it on input2's device with matching
        # shape (same value, so the loss is numerically unchanged).
        loss2 = -F.mse_loss(input2, torch.full_like(input2, 0.5))
        return loss1 + self.weight * loss2, loss1.item(), loss2.item()
class ConvSeparateLoss(nn.modules.loss._Loss):
    """Separate the weight value between each operations using L2"""

    def __init__(self, loss1_fn, weight=0.1, size_average=None, ignore_index=-100,
                 reduce=None, reduction='mean'):
        """
        :param loss1_fn: classification criterion, called as ``loss1_fn(pred, target)``
        :param weight: scaling factor for the separation term
        :param size_average, reduce, reduction: forwarded to ``_Loss``
        :param ignore_index: kept for interface compatibility (unused here)
        """
        super(ConvSeparateLoss, self).__init__(size_average, reduce, reduction)
        self.ignore_index = ignore_index
        self.weight = weight
        self.loss1_fn = loss1_fn

    def forward(self, input1, target1, input2):
        """Return ``(total_loss, loss1_value, loss2_value)``.

        ``loss1`` is the classification loss; ``loss2`` is a negative
        standard-deviation term over the architecture weights, so spreading
        the weights apart is rewarded.
        """
        cls_loss = self.loss1_fn(input1, target1)
        sep_loss = - 0.2 * torch.std(input2)
        combined = cls_loss + self.weight * sep_loss
        return combined, cls_loss.item(), sep_loss.item()
class TriSeparateLoss(nn.modules.loss._Loss):
    """Separate the weight value between each operations using L1"""

    def __init__(self, loss1_fn, weight=0.1, size_average=None, ignore_index=-100,
                 reduce=None, reduction='mean'):
        """
        :param loss1_fn: classification criterion, called as ``loss1_fn(pred, target)``
        :param weight: scaling factor for the separation term
        :param size_average, reduce, reduction: forwarded to ``_Loss``
        :param ignore_index: kept for interface compatibility (unused here)
        """
        super(TriSeparateLoss, self).__init__(size_average, reduce, reduction)
        self.ignore_index = ignore_index
        self.weight = weight
        self.loss1_fn = loss1_fn

    def forward(self, input1, target1, input2):
        """Return ``(total_loss, loss1_value, loss2_value)``.

        ``loss2`` is the negative L1 distance of the architecture weights
        from 0.5, rewarding weights that move away from the undecided point.
        """
        # FIX: the configured loss1_fn was stored but ignored (the code
        # hard-coded F.cross_entropy), inconsistent with ConvSeparateLoss.
        loss1 = self.loss1_fn(input1, target1)
        # FIX: the 0.5 target was built with a hard-coded .cuda(), which broke
        # CPU execution; full_like places it on input2's device with matching
        # shape (same value, so the loss is numerically unchanged).
        loss2 = -F.l1_loss(input2, torch.full_like(input2, 0.5))
        return loss1 + self.weight * loss2, loss1.item(), loss2.item()
================================================
FILE: examples/Perception_and_Learning/NeuEvo/train.py
================================================
import os
import sys
import time
import logging
import torch
import utils as dutils
import argparse
import numpy as np
import torch.utils
import torch.nn as nn
from braincog.model_zoo.NeuEvo import genotypes
from braincog.model_zoo.NeuEvo.model import NetworkCIFAR as Network
import torchvision.datasets as dset
import torch.backends.cudnn as cudnn
from thop import profile
from braincog.datasets.datasets import build_transform
from braincog.base.utils.criterions import UnilateralMse
# Command-line interface for CIFAR evaluation training of a searched NeuEvo cell.
parser = argparse.ArgumentParser("cifar")
# --- data / dataset ---
parser.add_argument('--data', type=str, default='/data/datasets',
                    help='location of the data corpus')
parser.add_argument('--dataset', type=str, default='cifar10',
                    help='cifar10 or cifar 100 for training')
parser.add_argument('--batch-size', type=int, default=128, help='batch size')
# --- optimization ---
parser.add_argument('--learning_rate', type=float,
                    default=0.025, help='init learning rate')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
parser.add_argument('--weight_decay', type=float,
                    default=3e-4, help='weight decay')
parser.add_argument('--report_freq', type=float,
                    default=50, help='report frequency')
# --- hardware ---
parser.add_argument('--device', type=int, default=0, help='gpu device id')
parser.add_argument('--multi-gpus', action='store_true',
                    default=False, help='use multi gpus')
# --- architecture / model ---
parser.add_argument('--parse_method', type=str,
                    default='darts', help='experiment name')
parser.add_argument('--epochs', type=int, default=600,
                    help='num of training epochs')
parser.add_argument('--init-channels', type=int,
                    default=64, help='num of init channels')
parser.add_argument('--layers', type=int, default=16,
                    help='total number of layers')
parser.add_argument('--model_path', type=str,
                    default='saved_models', help='path to save the model')
parser.add_argument('--auxiliary', action='store_true',
                    default=False, help='use auxiliary tower')
parser.add_argument('--auxiliary_weight', type=float,
                    default=0.4, help='weight for auxiliary loss')
# --- augmentation / regularization ---
parser.add_argument('--cutout', action='store_true',
                    default=False, help='use cutout')
parser.add_argument('--cutout_length', type=int,
                    default=16, help='cutout length')
parser.add_argument('--auto_aug', action='store_true',
                    default=False, help='use auto augmentation')
parser.add_argument('--drop_path_prob', type=float,
                    default=0.2, help='drop path probability')
# --- experiment bookkeeping ---
parser.add_argument('--save', type=str, default='EXP', help='experiment name')
parser.add_argument('--seed', type=int, default=42, help='random seed')
parser.add_argument('--arch', type=str, default='DARTS',
                    help='which architecture to use')
parser.add_argument('--grad_clip', type=float,
                    default=5, help='gradient clipping')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
# --- SNN-specific options ---
parser.add_argument('--img_size', default=32, type=int)
parser.add_argument('--step', default=8, type=int)
parser.add_argument('--node-type', default='PLIFNode', type=str)
parser.add_argument('--suffix', default='', type=str)
class TrainNetwork(object):
    """The main train network.

    Wraps the full evaluation-training pipeline for a searched NeuEvo cell:
    logging setup, device selection, CIFAR data loading, model construction
    (with optional resume), and the train/validate loop.
    """

    def __init__(self, args):
        """
        :param args: parsed command-line namespace from this script's parser
        """
        super(TrainNetwork, self).__init__()
        self.args = args
        self.dur_time = 0  # accumulated wall-clock training time across resumes
        self._init_log()
        self._init_device()
        self._init_data_queue()
        self._init_model()

    def _init_log(self):
        """Create the experiment directory and attach a file logger."""
        # NOTE(review): `args.suffix` below reads the module-level global
        # `args` (only defined under __main__), not `self.args` — works when
        # run as a script, but would raise NameError if this class is imported
        # and used elsewhere. The 'cifar10' path component is also hard-coded
        # regardless of --dataset.
        self.args.save = '/data/floyed/darts/logs/eval/' + self.args.arch + '/' + 'cifar10' + '/eval-{}-{}-{}'.format(
            self.args.save, time.strftime('%Y%m%d-%H%M'), args.suffix)
        dutils.create_exp_dir(self.args.save, scripts_to_save=None)
        log_format = '%(asctime)s %(message)s'
        logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                            format=log_format, datefmt='%m/%d %I:%M:%S %p')
        fh = logging.FileHandler(os.path.join(self.args.save, 'log.txt'))
        fh.setFormatter(logging.Formatter(log_format))
        self.logger = logging.getLogger('Architecture Training')
        self.logger.addHandler(fh)

    def _init_device(self):
        """Seed RNGs and select the CUDA device; abort if no GPU is available."""
        if not torch.cuda.is_available():
            self.logger.info('no gpu device available')
            sys.exit(1)
        np.random.seed(self.args.seed)
        self.device_id = self.args.device
        # With multi-gpus, DataParallel scatters from cuda:0.
        self.device = torch.device('cuda:{}'.format(
            0 if self.args.multi_gpus else self.device_id))
        cudnn.benchmark = True
        torch.manual_seed(self.args.seed)
        cudnn.enabled = True
        torch.cuda.manual_seed(self.args.seed)
        logging.info('gpu device = %d' % self.args.device)
        logging.info("args = %s", self.args)

    def _init_data_queue(self):
        """Build CIFAR-10/100 train and validation DataLoaders."""
        train_transform = build_transform(True, args.img_size)
        valid_transform = build_transform(False, args.img_size)
        # NOTE(review): any other --dataset value leaves train_data/valid_data
        # undefined and raises below — consider validating the option.
        if self.args.dataset == 'cifar10':
            train_data = dset.CIFAR10(
                root=self.args.data, train=True, download=True, transform=train_transform)
            valid_data = dset.CIFAR10(
                root=self.args.data, train=False, download=True, transform=valid_transform)
            self.num_classes = 10
        elif self.args.dataset == 'cifar100':
            train_data = dset.CIFAR100(
                root=self.args.data, train=True, download=True, transform=train_transform)
            valid_data = dset.CIFAR100(
                root=self.args.data, train=False, download=True, transform=valid_transform)
            self.num_classes = 100
        self.train_queue = torch.utils.data.DataLoader(
            train_data, batch_size=self.args.batch_size, shuffle=True, pin_memory=True, num_workers=4)
        self.valid_queue = torch.utils.data.DataLoader(
            valid_data, batch_size=self.args.batch_size, shuffle=False, pin_memory=True, num_workers=4)

    def _init_model(self):
        """Instantiate the searched network, criterion, optimizer and scheduler.

        Also reports FLOPs/params and optionally resumes from a checkpoint.
        """
        # Look up the genotype by name from the genotypes module.
        genotype = eval('genotypes.%s' % self.args.arch)
        model = Network(self.args.init_channels,
                        self.num_classes,
                        self.args.layers,
                        self.args.auxiliary,
                        genotype,
                        self.args.parse_method,
                        step=args.step,
                        node_type=args.node_type
                        )
        flops, params = profile(model, inputs=(
            torch.randn(1, 3, 32, 32),), verbose=False)
        self.logger.info('flops = %fM', flops / 1e6)
        self.logger.info('param size = %fM', params / 1e6)

        # Try move model to multi gpus
        if torch.cuda.device_count() > 1 and self.args.multi_gpus:
            self.logger.info('use: %d gpus', torch.cuda.device_count())
            model = nn.DataParallel(model)
        else:
            self.logger.info('gpu device = %d' % self.device_id)
            torch.cuda.set_device(self.device_id)
        self.model = model.to(self.device)

        # criterion = nn.CrossEntropyLoss()
        criterion = UnilateralMse(1.)
        self.criterion = criterion.to(self.device)
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            self.args.learning_rate,
            weight_decay=self.args.weight_decay
        )
        self.best_acc_top1 = 0

        # optionally resume from a checkpoint
        if self.args.resume:
            if os.path.isfile(self.args.resume):
                print("=> loading checkpoint {}".format(self.args.resume))
                checkpoint = torch.load(
                    self.args.resume, map_location=self.device)
                self.dur_time = checkpoint['dur_time']
                self.args.start_epoch = checkpoint['epoch']
                self.best_acc_top1 = checkpoint['best_acc_top1']
                self.args.drop_path_prob = checkpoint['drop_path_prob']
                self.model.load_state_dict(checkpoint['state_dict'])
                self.optimizer.load_state_dict(checkpoint['optimizer'])
                print("=> loaded checkpoint '{}' (epoch {})".format(
                    self.args.resume, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(self.args.resume))

        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, float(self.args.epochs), eta_min=0,
                                                                    last_epoch=-1 if self.args.start_epoch == 0 else self.args.start_epoch)
        # reload the scheduler if possible
        if self.args.resume and os.path.isfile(self.args.resume):
            checkpoint = torch.load(self.args.resume)
            self.scheduler.load_state_dict(checkpoint['scheduler'])

    def run(self):
        """Main loop: train/validate each epoch and checkpoint the best model."""
        self.logger.info('args = %s', self.args)
        run_start = time.time()
        for epoch in range(self.args.start_epoch, self.args.epochs):
            # NOTE(review): scheduler.step() is called before the optimizer
            # steps for the epoch (legacy DARTS ordering); newer PyTorch warns
            # about this — confirm the intended LR schedule before changing.
            self.scheduler.step()
            self.logger.info('epoch % d / %d lr %e', epoch,
                             self.args.epochs, self.scheduler.get_lr()[0])
            # Linear drop-path warm-up over training.
            self.model.drop_path_prob = self.args.drop_path_prob * epoch / self.args.epochs
            train_acc, train_obj = self.train()
            self.logger.info('train loss %e, train acc %f',
                             train_obj, train_acc)
            valid_acc_top1, valid_acc_top5, valid_obj = self.infer()
            self.logger.info('valid loss %e, top1 valid acc %f top5 valid acc %f',
                             valid_obj, valid_acc_top1, valid_acc_top5)
            self.logger.info('best valid acc %f', self.best_acc_top1)
            is_best = False
            if valid_acc_top1 > self.best_acc_top1:
                self.best_acc_top1 = valid_acc_top1
                is_best = True
            # Save a full checkpoint every epoch; flag it if it is the best so far.
            dutils.save_checkpoint({
                'epoch': epoch + 1,
                'dur_time': self.dur_time + time.time() - run_start,
                'state_dict': self.model.state_dict(),
                'drop_path_prob': self.args.drop_path_prob,
                'best_acc_top1': self.best_acc_top1,
                'optimizer': self.optimizer.state_dict(),
                'scheduler': self.scheduler.state_dict()
            }, is_best, self.args.save)
        self.logger.info('train epoches %d, best_acc_top1 %f, dur_time %s',
                         self.args.epochs, self.best_acc_top1,
                         dutils.calc_time(self.dur_time + time.time() - run_start))

    def train(self):
        """Train for one epoch; return ``(top1_avg, loss_avg)``."""
        objs = dutils.AvgrageMeter()
        top1 = dutils.AvgrageMeter()
        top5 = dutils.AvgrageMeter()
        self.model.train()
        for step, (input, target) in enumerate(self.train_queue):
            input = input.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)
            self.optimizer.zero_grad()
            logits, logits_aux = self.model(input)
            loss = self.criterion(logits, target)
            if self.args.auxiliary:
                # Deep-supervision auxiliary head, weighted into the main loss.
                loss_aux = self.criterion(logits_aux, target)
                loss += self.args.auxiliary_weight * loss_aux
            loss.backward()
            nn.utils.clip_grad_norm_(
                self.model.parameters(), self.args.grad_clip)
            self.optimizer.step()
            prec1, prec5 = dutils.accuracy(logits, target, topk=(1, 5))
            n = input.size(0)
            objs.update(loss.item(), n)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)
            if step % self.args.report_freq == 0:
                self.logger.info('train %03d %e %f %f', step,
                                 objs.avg, top1.avg, top5.avg)
        return top1.avg, objs.avg

    def infer(self):
        """Evaluate on the validation set; return ``(top1, top5, loss_avg)``."""
        objs = dutils.AvgrageMeter()
        top1 = dutils.AvgrageMeter()
        top5 = dutils.AvgrageMeter()
        self.model.eval()
        with torch.no_grad():
            for step, (input, target) in enumerate(self.valid_queue):
                input = input.cuda(non_blocking=True)
                target = target.cuda(non_blocking=True)
                logits, _ = self.model(input)
                loss = self.criterion(logits, target)
                prec1, prec5 = dutils.accuracy(logits, target, topk=(1, 5))
                n = input.size(0)
                objs.update(loss.item(), n)
                top1.update(prec1.item(), n)
                top5.update(prec5.item(), n)
                if step % self.args.report_freq == 0:
                    self.logger.info('valid %03d %e %f %f',
                                     step, objs.avg, top1.avg, top5.avg)
        return top1.avg, top5.avg, objs.avg
# Script entry point: parse CLI options, build the trainer, run training.
# Note: `args` is module-global here and is also read directly by some
# TrainNetwork methods.
if __name__ == '__main__':
    args = parser.parse_args()
    train_network = TrainNetwork(args)
    train_network.run()
================================================
FILE: examples/Perception_and_Learning/NeuEvo/train_search.py
================================================
import os
import sys
import time
import numpy as np
import torch
import logging
import argparse
import torch.nn as nn
import torch.utils
import torch.nn.functional as F
import torchvision.datasets as dset
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
from braincog.model_zoo.NeuEvo.model_search import Network, calc_weight, calc_loss
from braincog.model_zoo.NeuEvo.architect import Architect
from separate_loss import ConvSeparateLoss, TriSeparateLoss, MseSeparateLoss
import utils
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
from braincog.datasets.datasets import *
from braincog.base.utils.criterions import *
# Surface autograd errors with full tracebacks during architecture search
# (slows training; intended for debugging).
torch.autograd.set_detect_anomaly(True)

# Command-line interface for NeuEvo architecture search.
parser = argparse.ArgumentParser("cifar")
# --- data / dataset ---
parser.add_argument('--data', type=str, default='/data/datasets',
                    help='location of the data corpus')
parser.add_argument('--dataset', type=str, default='cifar10',
                    help='cifar10 or cifar 100 for searching')
parser.add_argument('--batch-size', type=int, default=128, help='batch size')
# --- optimization ---
parser.add_argument('--learning_rate', type=float,
                    default=0.005, help='init learning rate')
parser.add_argument('--learning_rate_min', type=float,
                    default=0.001, help='min learning rate')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
parser.add_argument('--weight_decay', type=float,
                    default=3e-4, help='weight decay')
parser.add_argument('--report_freq', type=float,
                    default=50, help='report frequency')
parser.add_argument('--aux_loss_weight', type=float,
                    default=10.0, help='weight decay')
parser.add_argument('--device', type=int, default=0, help='gpu device id')
parser.add_argument('--epochs', type=int, default=50,
                    help='num of training epochs')
# --- search-space / model ---
parser.add_argument('--init-channels', type=int,
                    default=16, help='num of init channels')
parser.add_argument('--layers', type=int, default=6,
                    help='total number of layers')
parser.add_argument('--model_path', type=str,
                    default='saved_models', help='path to save the model')
parser.add_argument('--single_level', action='store_true',
                    default=False, help='use single level')
parser.add_argument('--sep_loss', type=str, default='l2',
                    help='path to save the model')
parser.add_argument('--cutout', action='store_true',
                    default=False, help='use cutout')
parser.add_argument('--cutout_length', type=int,
                    default=16, help='cutout length')
parser.add_argument('--auto_aug', action='store_true',
                    default=False, help='use auto augmentation')
parser.add_argument('--parse_method', type=str,
                    default='bio_darts', help='parse the code method')
parser.add_argument('--op_threshold', type=float,
                    default=0.85, help='threshold for edges')
parser.add_argument('--save', type=str, default='EXP', help='experiment name')
parser.add_argument('--seed', type=int, default=42, help='random seed')
parser.add_argument('--grad_clip', type=float,
                    default=5, help='gradient clipping')
parser.add_argument('--train_portion', type=float,
                    default=0.5, help='portion of training data')
# --- architecture-encoding optimizer ---
parser.add_argument('--arch_learning_rate', type=float,
                    default=1e-3, help='learning rate for arch encoding')
parser.add_argument('--arch_lr_gamma', type=float, default=0.9,
                    help='learning rate for arch encoding')
parser.add_argument('--arch_weight_decay', type=float,
                    default=1e-3, help='weight decay for arch encoding')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')

# EventData Augmentation
parser.add_argument('--mix-up', action='store_true', help='Mix-up for event data (default: False)')
parser.add_argument('--cut-mix', action='store_true', help='CutMix for event data (default: False)')
parser.add_argument('--event-mix', action='store_true', help='EventMix for event data (default: False)')
parser.add_argument('--cutmix_beta', type=float, default=1.0, help='cutmix_beta (default: 1.)')
parser.add_argument('--cutmix_prob', type=float, default=0.5, help='cutmix_prib for event data (default: .5)')
parser.add_argument('--cutmix_num', type=int, default=1, help='cutmix_num for event data (default: 1)')
parser.add_argument('--cutmix_noise', type=float, default=0.,
                    help='Add Pepper noise after mix, sometimes work (default: 0.)')
parser.add_argument('--rand-aug', action='store_true',
                    help='Rand Augment for Event data (default: False)')
parser.add_argument('--randaug_n', type=int, default=3,
                    help='Rand Augment times n (default: 3)')
parser.add_argument('--randaug_m', type=int, default=15,
                    help='Rand Augment times n (default: 15) (0-30)')
parser.add_argument('--temporal-flatten', action='store_true',
                    help='Temporal flatten to channels. ONLY FOR EVENT DATA TRAINING BY ANN')
# NOTE(review): '--train-portion' duplicates '--train_portion' above (argparse
# maps both to dest 'train_portion'); the later default 0.9 wins — confirm.
parser.add_argument('--train-portion', type=float, default=0.9,
                    help='Dataset portion, only for datasets which do not have validation set (default: 0.9)')
parser.add_argument('--event-size', default=48, type=int,
                    help='Event size. Resize event data before process (default: 48)')
parser.add_argument('--layer-by-layer', action='store_true',
                    help='forward step-by-step or layer-by-layer. '
                         'Larger Model with layer-by-layer will be faster (default: False)')
parser.add_argument('--node-resume', type=str, default='',
                    help='resume weights in node for adaptive node. (default: False)')
parser.add_argument('--node-trainable', action='store_true')
# --- SNN-specific options ---
parser.add_argument('--img_size', default=32, type=int)
parser.add_argument('--smoothing', default=0.1, type=float)
parser.add_argument('--step', default=8, type=int)
parser.add_argument('--node-type', default='BiasPLIFNode', type=str)
parser.add_argument('--loss_fn', type=str, default='')
parser.add_argument('--back-connection', action='store_true')
parser.add_argument('--asbe', '--arch-search-begin-epoch',
                    type=int, default=0, dest='asbe')
parser.add_argument('--num-classes', type=int, default=10)
parser.add_argument('--spike-output', action='store_true')
parser.add_argument('--act-fun', type=str, default='GateGrad')
parser.add_argument('--suffix', default='', type=str)

# Parse once at import time; `args` is read as a module-level global by
# main()/train()/infer() below.
args = parser.parse_args()
args.save = '/data/floyed/darts/logs/search/search-{}-{}-{}'.format(args.save, time.strftime("%Y%m%d-%H%M%S"),
                                                                    args.suffix)
utils.create_exp_dir(args.save, scripts_to_save=None)

# Log to stdout and to <save>/log.txt.
log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                    format=log_format, datefmt='%m/%d %I:%M:%S %p')
fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)

CIFAR_CLASSES = 10
def main():
    """Run the full NeuEvo architecture search: build the supernet, search,
    validate, checkpoint each epoch, and save the final operation weights."""
    # NOTE(review): this overrides any --spike-output flag passed on the CLI.
    args.spike_output = False
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    # Seed RNGs and pin the CUDA device for reproducibility.
    np.random.seed(args.seed)
    torch.cuda.set_device(args.device)
    # cudnn.benchmark = True
    torch.manual_seed(args.seed)
    # cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.device)
    logging.info("args = %s", args)

    run_start = time.time()
    start_epoch = 0
    dur_time = 0

    # Select classification criteria, then wrap the training criterion with a
    # separation loss (L2 or L1 flavor) over the architecture weights.
    if args.loss_fn == 'mix':
        criterion_train = MixLoss(LabelSmoothingCrossEntropy(
            smoothing=args.smoothing).cuda())
        criterion_val = MixLoss(nn.CrossEntropyLoss())
    elif args.loss_fn == 'mse':
        criterion_train = UnilateralMse(1.)
        criterion_val = UnilateralMse(1.)
    else:
        criterion_train = LabelSmoothingCrossEntropy().cuda()
        criterion_val = nn.CrossEntropyLoss().cuda()

    criterion_train = ConvSeparateLoss(criterion_train, weight=args.aux_loss_weight) \
        if args.sep_loss == 'l2' else TriSeparateLoss(criterion_train, weight=args.aux_loss_weight)

    # Differentiable supernet over the NeuEvo search space.
    model = Network(args.init_channels, args.num_classes, args.layers, criterion_train,
                    steps=3, multiplier=3, stem_multiplier=3,
                    parse_method=args.parse_method, op_threshold=args.op_threshold,
                    step=args.step, node_type=args.node_type,
                    back_connection=args.back_connection, act_fun=args.act_fun,
                    dataset=args.dataset,
                    spike_output=False,
                    temporal_flatten=args.temporal_flatten)
    model = model.cuda()
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    model_optimizer = torch.optim.AdamW(
        model.parameters(),
        args.learning_rate,
        # momentum=args.momentum,
        weight_decay=args.weight_decay)

    # Data loaders come from braincog's per-dataset factory functions
    # (get_<dataset>_data); a manual CIFAR split used here previously was
    # replaced by this unified path.
    train_queue, valid_queue, _, _ = eval('get_%s_data' % args.dataset)(
        batch_size=args.batch_size,
        step=args.step,
        args=args,
        size=args.event_size,
        mix_up=args.mix_up,
        cut_mix=args.cut_mix,
        event_mix=args.event_mix,
        beta=args.cutmix_beta,
        prob=args.cutmix_prob,
        num=args.cutmix_num,
        noise=args.cutmix_noise,
        num_classes=args.num_classes,
        rand_aug=args.rand_aug,
        randaug_n=args.randaug_n,
        randaug_m=args.randaug_m,
        temporal_flatten=args.temporal_flatten
    )

    # Bi-level optimizer for the architecture parameters.
    architect = Architect(model, args)

    # resume from checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            logging.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(
                args.resume, map_location=model.alphas_normal.device)
            start_epoch = checkpoint['epoch']
            dur_time = checkpoint['dur_time']
            model_optimizer.load_state_dict(checkpoint['model_optimizer'])
            model.restore(checkpoint['network_states'])
            logging.info('=> loaded checkpoint \'{}\'(epoch {})'.format(
                args.resume, start_epoch))
        else:
            logging.info(
                '=> no checkpoint found at \'{}\''.format(args.resume))

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        model_optimizer, float(args.epochs), eta_min=args.learning_rate_min,
        last_epoch=-1 if start_epoch == 0 else start_epoch)
    if args.resume and os.path.isfile(args.resume):
        scheduler.load_state_dict(checkpoint['scheduler'])

    for epoch in range(start_epoch, args.epochs):
        # NOTE(review): scheduler.step() before the epoch's optimizer steps is
        # the legacy DARTS ordering; newer PyTorch warns about it.
        scheduler.step()
        lr = scheduler.get_lr()[0]
        logging.info('epoch %d lr %e', epoch, lr)

        # Log the currently-derived discrete genotype and weight statistics.
        genotype = model.genotype()
        logging.info('genotype = %s', genotype)
        logging.info(calc_weight(model.alphas_normal))
        logging.info(calc_loss(model.alphas_normal))
        model.update_history()

        # training and search the model
        train_acc, train_obj = train(epoch, train_queue, valid_queue, model, architect, criterion_train,
                                     model_optimizer)
        logging.info('train_acc %f', train_acc)

        # validation the model (with per-step fire-rate recording enabled)
        model.record_fire_rate = True
        model.reset_fire_rate_record()
        valid_acc, valid_obj = infer(valid_queue, model, criterion_val)
        fire_rate = model.get_fire_per_step()
        model.record_fire_rate = False
        logging.info('valid_fire_rate: {}'.format(fire_rate))
        logging.info('valid_acc %f', valid_acc)

        # save checkpoint
        utils.save_checkpoint({
            'epoch': epoch + 1,
            'dur_time': dur_time + time.time() - run_start,
            'scheduler': scheduler.state_dict(),
            'model_optimizer': model_optimizer.state_dict(),
            'network_states': model.states(),
        }, is_best=False, save=args.save)
        logging.info('save checkpoint (epoch %d) in %s dur_time: %s', epoch, args.save,
                     utils.calc_time(dur_time + time.time() - run_start))

    # save operation weights as fig
    utils.save_file(recoder=model.alphas_normal_history, path=os.path.join(args.save, 'normal'),
                    back_connection=args.back_connection)
    # save last operations
    np.save(os.path.join(os.path.join(args.save, 'normal_weight.npy')),
            calc_weight(model.alphas_normal).data.cpu().numpy())
    logging.info('save last weights done')
def train(epoch, train_queue, valid_queue, model, architect, criterion, model_optimizer):
    """Run one epoch of supernet weight training.

    :param epoch: current epoch index (used only for logging / schedules)
    :param train_queue: DataLoader over the training split
    :param valid_queue: DataLoader over the validation split (used by the
        currently disabled architecture-search step)
    :param model: the searchable supernet
    :param architect: architecture optimizer (its step is commented out below)
    :param criterion: loss taking (logits, target, aux_input) and returning
        (loss, cls_loss, spe_loss)
    :param model_optimizer: optimizer over the network weights
    :return: (top1_avg, loss_avg) for the epoch
    """
    objs = utils.AvgrageMeter()
    objs1 = utils.AvgrageMeter()
    objs2 = utils.AvgrageMeter()
    top1 = utils.AvgrageMeter()
    top5 = utils.AvgrageMeter()
    for step, (input, target) in enumerate(train_queue):
        model.train()
        n = input.size(0)
        # Variable is deprecated since PyTorch 0.4; tensors are used directly.
        input = input.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)
        # Architecture-search step is currently disabled (kept for reference):
        # input_search, target_search = next(iter(valid_queue))
        # input_search = input_search.cuda(non_blocking=True)
        # target_search = target_search.cuda(non_blocking=True)
        # loss1, loss2 = architect.step(input_search, target_search)
        loss1 = 0.0
        loss2 = 0.0
        model_optimizer.zero_grad()
        logits = model(input)
        # Sparsity/complexity penalty derived from the architecture weights.
        aux_input = torch.cat(
            [calc_loss(model.alphas_normal)], dim=0)
        loss, _, _ = criterion(logits, target, aux_input)
        loss.backward()
        # clip_grad_norm was deprecated and removed; use the in-place variant.
        nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        # Update the network parameters
        model_optimizer.step()
        prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
        objs.update(loss.item(), n)
        objs1.update(loss1, n)
        objs2.update(loss2, n)
        top1.update(prec1.item(), n)
        top5.update(prec5.item(), n)
        if step % args.report_freq == 0:
            logging.info('train %03d loss: %e top1: %f top5: %f',
                         step, objs.avg, top1.avg, top5.avg)
            logging.info('val cls_loss %e; spe_loss %e', objs1.avg, objs2.avg)
    return top1.avg, objs.avg
def infer(valid_queue, model, criterion):
    """Evaluate the model on the validation queue.

    The deprecated ``Variable(..., volatile=True)`` wrappers were removed:
    the whole loop already runs under ``torch.no_grad()``, which disables
    autograd tracking.

    :return: (top1_avg, loss_avg)
    """
    objs = utils.AvgrageMeter()
    top1 = utils.AvgrageMeter()
    top5 = utils.AvgrageMeter()
    model.eval()
    with torch.no_grad():
        for step, (input, target) in enumerate(valid_queue):
            input = input.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)
            logits = model(input)
            loss = criterion(logits, target)
            prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
            n = input.size(0)
            objs.update(loss.item(), n)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)
            if step % args.report_freq == 0:
                logging.info('valid %03d %e %f %f', step,
                             objs.avg, top1.avg, top5.avg)
    return top1.avg, objs.avg
# Entry point: run the architecture search when executed as a script.
if __name__ == '__main__':
    main()
================================================
FILE: examples/Perception_and_Learning/NeuEvo/utils.py
================================================
import json
import matplotlib.pyplot as plt
from braincog.model_zoo.NeuEvo.genotypes import Genotype, PRIMITIVES
import os
import numpy as np
import torch
import shutil
import torchvision.transforms as transforms
from torch.autograd import Variable
from auto_augment import CIFAR10Policy
from braincog.model_zoo.NeuEvo.genotypes import PRIMITIVES
# Number of candidate forward edges in a cell: intermediate node i (of 3)
# may connect to its 2 + i predecessors, so 2 + 3 + 4 = 9 in total.
forward_edge_num = sum(2 + i for i in range(3))
# Number of candidate backward edges: node i may receive i feedback edges
# (0 + 1 + 2 = 3).
backward_edge_num = sum(i for i in range(3))
num_ops = len(PRIMITIVES)
# Presumably half the primitives form type pairs -- TODO confirm.
type_num = len(PRIMITIVES) // 2
# edge_num = [2, 3, 4]
edge_num = [2, 3, 4, 1, 2]
class AvgrageMeter(object):
    """Running (weighted) average tracker for scalar metrics."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the accumulated statistics."""
        self.avg = 0
        self.sum = 0
        self.cnt = 0

    def update(self, val, n=1):
        """Fold in ``val`` observed ``n`` times and refresh the average."""
        self.sum += val * n
        self.cnt += n
        self.avg = self.sum / self.cnt
def accuracy(output, target, topk=(1,)):
    """Return the top-k accuracies (as percentages) for each k in ``topk``.

    ``output`` is [batch, n_classes] scores; ``target`` is [batch] labels.
    """
    k_max = max(topk)
    batch_size = target.size(0)
    # Indices of the k_max highest-scoring classes per sample, transposed to
    # [k_max, batch] so row j holds every sample's (j+1)-th best guess.
    _, pred = output.topk(k_max, 1, True, True)
    pred = pred.t()
    hits = pred.eq(target.view(1, -1).expand_as(pred))
    # A sample counts as a top-k hit if any of its first k guesses match.
    return [hits[:k].reshape(-1).float().sum(0).mul_(100.0 / batch_size)
            for k in topk]
class Cutout(object):
    """Zero out one randomly-placed square patch of a CHW image tensor."""

    def __init__(self, length):
        # Side length of the square cut-out region.
        self.length = length

    def __call__(self, img):
        height, width = img.size(1), img.size(2)
        # Sample the patch centre uniformly (same RNG call order as before:
        # row first, then column).
        cy = np.random.randint(height)
        cx = np.random.randint(width)
        half = self.length // 2
        top = np.clip(cy - half, 0, height)
        bottom = np.clip(cy + half, 0, height)
        left = np.clip(cx - half, 0, width)
        right = np.clip(cx + half, 0, width)
        mask = np.ones((height, width), np.float32)
        mask[top:bottom, left:right] = 0.
        # Broadcast the mask over channels and zero the patch in-place.
        img *= torch.from_numpy(mask).expand_as(img)
        return img
def _data_transforms_cifar(args):
    """Build the (train, valid) torchvision pipelines for CIFAR-10/100.

    The train pipeline is crop/flip [+ AutoAugment] -> tensor/normalize
    [+ Cutout]; the valid pipeline only converts and normalizes.
    """
    if args.dataset == 'cifar10':
        mean = [0.49139968, 0.48215827, 0.44653124]
        std = [0.24703233, 0.24348505, 0.26158768]
    else:  # CIFAR-100 statistics
        mean = [0.50707519, 0.48654887, 0.44091785]
        std = [0.26733428, 0.25643846, 0.27615049]
    augment = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    if args.auto_aug:
        augment.append(CIFAR10Policy())
    tail = [Cutout(args.cutout_length)] if args.cutout else []
    train_transform = transforms.Compose(
        augment + [transforms.ToTensor(), transforms.Normalize(mean, std)] + tail
    )
    valid_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    return train_transform, valid_transform
def count_parameters_in_MB(model):
    """Return the parameter count of ``model`` in millions.

    Uses the builtin ``sum`` over ``numel()`` instead of ``np.sum`` over a
    generator, which NumPy deprecated and which obscured the intent.
    """
    return sum(p.numel() for p in model.parameters()) / 1e6
def save_checkpoint(state, is_best, save):
    """Write ``state`` to save/checkpoint.pth.tar; if ``is_best``, also copy
    it to save/model_best.pth.tar."""
    ckpt_path = os.path.join(save, 'checkpoint.pth.tar')
    torch.save(state, ckpt_path)
    if is_best:
        shutil.copyfile(ckpt_path, os.path.join(save, 'model_best.pth.tar'))
def save(model, model_path):
    """Persist the model's parameters (its state_dict) to ``model_path``."""
    state = model.state_dict()
    torch.save(state, model_path)
def load(model, model_path):
    """Load parameters saved by :func:`save` into ``model``.

    ``map_location='cpu'`` makes GPU-trained checkpoints loadable on
    CPU-only machines; ``load_state_dict`` then copies the tensors onto the
    model's own device, so behavior on GPU machines is unchanged.
    """
    model.load_state_dict(torch.load(model_path, map_location='cpu'))
def drop_path(x, drop_prob):
    """Stochastic depth: zero whole samples with probability ``drop_prob``.

    Surviving samples are rescaled by 1/keep_prob so the expectation is
    unchanged.  The mask is created on ``x``'s device (the original used
    ``torch.cuda.FloatTensor``, which crashes on CPU-only runs).  Note the
    operation is in-place on ``x``.
    """
    if drop_prob > 0.:
        keep_prob = 1. - drop_prob
        # One Bernoulli draw per sample, broadcast over C/H/W.
        mask = torch.empty(x.size(0), 1, 1, 1,
                           device=x.device, dtype=x.dtype).bernoulli_(keep_prob)
        x.div_(keep_prob)
        x.mul_(mask)
    return x
def create_exp_dir(path, scripts_to_save=None):
    """Create the experiment directory and optionally archive scripts in it.

    Each file in ``scripts_to_save`` is copied into ``path``/scripts.
    """
    if not os.path.exists(path):
        os.makedirs(path)
    print('Experiment dir : {}'.format(path))
    if scripts_to_save is not None:
        scripts_dir = os.path.join(path, 'scripts')
        os.makedirs(scripts_dir)
        for script in scripts_to_save:
            shutil.copyfile(script,
                            os.path.join(scripts_dir, os.path.basename(script)))
def calc_time(seconds):
    """Split a duration in seconds into day/hour/minute/second fields."""
    minutes, secs = divmod(seconds, 60)
    hours, minutes = divmod(minutes, 60)
    days, hours = divmod(hours, 24)
    return {'day': days, 'hour': hours, 'minute': minutes, 'second': int(secs)}
def save_file(recoder, path='./', back_connection=False):
    """Plot the recorded architecture-weight history and dump it as JSON.

    :param recoder: dict mapping edge/op names to lists of weight values
    :param path: output directory (created if missing)
    :param back_connection: include the backward-edge rows in the grid

    Fixes vs. the original: ``tight_layout`` is applied BEFORE ``savefig``
    (it previously ran after saving, where it had no effect on the file),
    and the figure is closed afterwards so repeated per-epoch calls do not
    leak matplotlib figures.
    """
    rows = forward_edge_num + backward_edge_num if back_connection else forward_edge_num
    fig, axs = plt.subplots(rows, num_ops, figsize=(36, 98))
    row = 0
    col = 0
    # Fill the grid row-major: one subplot per recorded (edge, op) history.
    for (k, v) in recoder.items():
        axs[row, col].set_title(k)
        axs[row, col].plot(v, 'r+')
        if col == num_ops - 1:
            col = 0
            row += 1
        else:
            col += 1
    if not os.path.exists(path):
        os.makedirs(path)
    fig.tight_layout()
    fig.savefig(os.path.join(path, 'output.png'), bbox_inches='tight')
    plt.close(fig)  # release the figure's memory
    print('save history weight in {}'.format(os.path.join(path, 'output.png')))
    with open(os.path.join(path, 'history_weight.json'), 'w') as outf:
        json.dump(recoder, outf)
    print('save history weight in {}'.format(
        os.path.join(path, 'history_weight.json')))
================================================
FILE: examples/Perception_and_Learning/QSNN/README.md
================================================
# Quantum superposition inspired spiking neural network
This repository contains code from our paper [**Quantum superposition inspired spiking neural network**] published in iScience. https://doi.org/10.1016/j.isci.2021.102880. If you use our code or refer to this project, please cite this paper.
## Requirements
* numpy
* scipy
* pytorch >= 1.7.0
* torchvision
## Train
```shell
python ./main.py
```
## Citation
If you find this package helpful, please consider citing the following papers:
```BibTex
@article{sun2021quantum,
title={Quantum superposition inspired spiking neural network},
author={Sun, Yinqian and Zeng, Yi and Zhang, Tielin},
journal={Iscience},
volume={24},
number={8},
pages={102880},
year={2021},
publisher={Elsevier}
}
@misc{https://doi.org/10.48550/arxiv.2207.08533,
doi = {10.48550/ARXIV.2207.08533},
url = {https://arxiv.org/abs/2207.08533},
author = {Zeng, Yi and Zhao, Dongcheng and Zhao, Feifei and Shen, Guobin and Dong, Yiting and Lu, Enmeng and Zhang, Qian and Sun, Yinqian and Liang, Qian and Zhao, Yuxuan and Zhao, Zhuoya and Fang, Hongjian and Wang, Yuwei and Li, Yang and Liu, Xin and Du, Chengcheng and Kong, Qingqun and Ruan, Zizhe and Bi, Weida},
title = {BrainCog: A Spiking Neural Network based Brain-inspired Cognitive Intelligence Engine for Brain-inspired AI and Brain Simulation},
publisher = {arXiv},
year = {2022},
}
```
================================================
FILE: examples/Perception_and_Learning/QSNN/main.py
================================================
import os
import copy
import tqdm
import numpy as np
import torch
import torch.nn.functional as F
from braincog.datasets.datasets import get_mnist_data
from braincog.model_zoo.qsnn import Net
from braincog.datasets.gen_input_signal import lambda_max
# Text file where per-epoch test results are appended.
LOG_DIR = os.path.expanduser('./results.txt')
LEARNING_RATE = 0.01
# Exponential learning-rate decay: lr * DECAY_RATE ** (epoch / DECAY_STEPS).
DECAY_STEPS = 1.0
DECAY_RATE = 0.9
# Adam-style moment coefficients passed to net.update_weight.
BETA1 = 0.9
BETA2 = 0.999
EPSIOLN = 1e-8  # NOTE(review): misspelling of "epsilon"; kept for compatibility
EPOCHS = 20
PRINT_PERIOD = 10000
TEST_SIZE = 10000
# Phase-shift factors (in units of pi) applied to the test images.
TEST_THETA = [0, 1 / 16, 2 / 16, 3 / 16, 4 / 16, 5 / 16, 6 / 16, 7 / 16, 8 / 16]
# TEST_THETA = [0]
NOISE_RATES = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
# MNIST loaders with batch size 1; normalisation skipped (raw pixel values).
mnist_data = get_mnist_data(batch_size=1, skip_norm=True)
train_loader, test_loader = mnist_data.get_data_loaders()
# Layer sizes: 784 inputs, 500 hidden, 10 outputs.
NET_SIZE = [28 * 28, 500, 10]
def int2onehot(label, classes, factor):
    """Encode integer labels as scaled one-hot target vectors.

    Zeros map to -8 and ones map to ``factor`` (e.g. factor=8 yields
    targets in {-8, 8}).
    """
    onehot = F.one_hot(label, classes)
    return onehot * (8 + factor) - 8
def train(net, epochs, lr):
    """Train the QSNN on MNIST and evaluate it under phase-shifted inputs.

    After each epoch the network is tested once per factor in TEST_THETA:
    every test image is "rotated" by theta*pi between its own intensities
    and their complement (lambda_max - image), and accuracy is printed and
    appended to LOG_DIR.
    """
    with open(LOG_DIR, 'a+') as f:
        for epoch in range(epochs):
            # Exponential learning-rate decay per epoch.
            lr_decay = lr * DECAY_RATE ** (epoch / DECAY_STEPS)
            for x, y in tqdm.tqdm(train_loader):
                label = int2onehot(y, 10, 8)
                x = x.flatten().numpy()
                label = label.cuda()
                with torch.no_grad():
                    net.routine(x, None, image_ori=None, image_ori_delta=None, shift=False,
                                label=label, test=False, noise=False, noise_rate=None)
                # Per-sample update (batch size is 1 in train_loader).
                net.update_weight(lr_decay, epoch + 1, (BETA1, BETA2), EPSIOLN)
            with torch.no_grad():
                for fac in TEST_THETA:
                    acc_reve = 0
                    for x_test, y_test in test_loader:
                        image = x_test.flatten().numpy()
                        # Mix the image with its complement by an angle fac*pi.
                        image_shift = image * np.cos(fac * np.pi) + (lambda_max - image) * np.sin(fac * np.pi)
                        # A slightly perturbed copy (+0.001 where headroom
                        # remains), shifted the same way.
                        image_delta = copy.copy(image)
                        delta_idx = image_delta < (lambda_max - 0.001)
                        image_delta[delta_idx] += 0.001
                        image_delta_shift = image_delta * np.cos(fac * np.pi) + (lambda_max - image_delta) * np.sin(fac * np.pi)
                        pred = net.predict(image_shift, image_delta_shift, image, image_delta, shift=True, noise=False, noise_rate=None)
                        if pred == int(y_test):
                            acc_reve += 1
                    # NOTE(review): assumes the test split has exactly
                    # TEST_SIZE samples -- confirm against the loader.
                    acc_reve = acc_reve / TEST_SIZE
                    print('Test epoch {epoch}: Shift {theta:0.3f} pi: accuracy {acc}.'.format(
                        epoch=epoch, theta=fac, acc=acc_reve))
                    print('Test epoch {epoch}: Shift {theta:0.3f} pi: accuracy {acc}'.format(
                        epoch=epoch, theta=fac, acc=acc_reve), file=f)
            print()
            print(file=f)
if __name__ == '__main__':
    # Build the QSNN on the GPU and train with the module-level settings.
    net = Net(NET_SIZE).cuda()
    train(net, EPOCHS, LEARNING_RATE)
================================================
FILE: examples/Perception_and_Learning/UnsupervisedSTDP/Readme.md
================================================
This is an example of training an Unsupervised STDP-based spiking neural network. We used an STB-STDP algorithm to train the SNN, together with multiple adaptive mechanisms.
# How to run
python codef.py
# Result
We train the model on Mnist and FashionMNIST, and the best accuracy for MNIST is 97.9%, for FashionMNIST is 87.0%.
### Citation
If you find this package helpful, please consider citing the following papers:
```BibTex
@article{dong2022unsupervised,
title={An Unsupervised Spiking Neural Network Inspired By Biologically Plausible Learning Rules and Connections},
author={Dong, Yiting and Zhao, Dongcheng and Li, Yang and Zeng, Yi},
journal={arXiv preprint arXiv:2207.02727},
year={2022}
}
@article{zeng2022braincog,
title={BrainCog: A Spiking Neural Network based Brain-inspired Cognitive Intelligence Engine for Brain-inspired AI and Brain Simulation},
author={Zeng, Yi and Zhao, Dongcheng and Zhao, Feifei and Shen, Guobin and Dong, Yiting and Lu, Enmeng and Zhang, Qian and Sun, Yinqian and Liang, Qian and Zhao, Yuxuan and others},
journal={arXiv preprint arXiv:2207.08533},
year={2022}
}
```
================================================
FILE: examples/Perception_and_Learning/UnsupervisedSTDP/codef.py
================================================
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from tqdm import tqdm
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import cv2
import numpy as np
from copy import deepcopy
import os, time, math,random
from braincog.base.node.node import *
from braincog.base.connection .layer import *
from braincog.base.strategy.LateralInhibition import *
from sklearn.metrics import confusion_matrix
# Fix every RNG seed for reproducibility.
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
torch.cuda.manual_seed(seed)  # fix the GPU random seed
torch.backends.cudnn.benchmark = False  # disable cuDNN autotuning of conv kernels
torch.backends.cudnn.deterministic = True  # force the default deterministic conv algorithm
random.seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
# Use CUDA when available, otherwise fall back to CPU.
dev = "cuda"
device = torch.device(dev) if torch.cuda.is_available() else 'cpu'
torch.set_printoptions(precision=4, sci_mode=False)
# ===========================================================================================================
# Offset used by STDPConv's gradient normalisation (see STDPConv.normgrad).
convoff = 0.3
# avgscale = 5
class STDPConv(nn.Module):
    """Convolutional layer trained with STDP-style local updates.

    ``forward`` accepts ``(x, T, onespike)`` and returns the spikes
    accumulated over ``T`` simulation steps.  The layer owns a LIF neuron
    (LIFSTDPNode), a winner-take-all layer and a lateral-inhibition
    strategy; the neuron threshold is recomputed from the activations on
    every forward pass (see ``getthresh``).
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride, padding, groups,
                 tau_decay=torch.exp(-1.0 / torch.tensor(100.0)), offset=convoff, static=True, inh=6.5, avgscale=5):
        super().__init__()
        self.tau_decay = tau_decay  # membrane decay constant for the LIF node
        self.offset = offset        # offset used by normgrad
        self.static = static        # True: input carries no explicit time dim
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups,
                              bias=False)
        self.avgpool = nn.AvgPool2d(kernel_size=kernel_size, stride=stride, padding=padding)
        self.mem = self.spike = self.refrac_count = None
        self.normweight()
        self.inh = inh              # lateral-inhibition strength
        self.avgscale = avgscale
        self.onespike = True
        self.node = LIFSTDPNode(act_fun=STDPGrad, tau=tau_decay, mem_detach=True)
        self.WTA = WTALayer()
        self.lateralinh = LateralInhibition(self.node, self.inh, mode="threshold")

    def mem_update(self, x, onespike=True):  # x: [b, c, h, w]
        # One simulation step: LIF dynamics, then winner-take-all plus
        # lateral inhibition whenever at least one neuron fired.
        x = self.node(x)
        if x.max() > 0:
            x = self.WTA(x)
            self.lateralinh(x)
        self.spike = x  # cached for normgrad
        return self.spike

    def forward(self, x, T=None, onespike=True):
        if not self.static:
            # [batch, T, c, h, w] -> fold time into the batch dimension.
            batch, T, c, h, w = x.shape
            x = x.reshape(-1, c, h, w)
        x = self.conv(x)
        # Data-dependent firing threshold derived from the activations.
        n = self.getthresh(x)
        self.node.threshold.data = n
        x = x.clamp(min=0)
        # Squash activations with a sigmoid centred at 0.4 * n.
        x = n / (1 + torch.exp(-(x - 4 * n / 10) * (8 / n)))
        if not self.static:
            # NOTE(review): restores the PRE-conv c/h/w, which only works
            # when the convolution preserves those sizes -- confirm for
            # non-static use.
            x = x.reshape(batch, T, c, h, w)
            xsum = None
            for i in range(T):
                tmp = self.mem_update(x[:, i], onespike).unsqueeze(1)
                if xsum is not None:
                    xsum = torch.cat([xsum, tmp], 1)
                else:
                    xsum = tmp
        else:
            # Static input: present the same frame for T steps and
            # accumulate the emitted spikes.
            xsum = 0
            for i in range(T):
                xsum += self.mem_update(x, onespike)
        return xsum

    def reset(self):
        # Reset the neuron state between batches.
        # self.mem = self.spike = self.refrac_count = None
        self.node.n_reset()

    def normgrad(self, force=False):
        # Shift the accumulated gradient so the final SGD step implements
        # the intended STDP update (gradient is negated at the end because
        # the optimizer performs descent).
        if force:
            # NOTE(review): both reductions below start with min over dim 1;
            # a per-filter max was probably intended for `max` -- confirm.
            # Also shadows the builtins min/max.
            min = self.conv.weight.grad.data.min(1, True)[0].min(2, True)[0].min(3, True)[0]
            max = self.conv.weight.grad.data.min(1, True)[0].max(2, True)[0].max(3, True)[0]
            self.conv.weight.grad.data -= min
            tmp = self.offset * max
        else:
            # Depress in proportion to each output map's mean firing rate.
            tmp = self.offset * self.spike.mean(0, True).mean(2, True).mean(3, True).permute(1, 0, 2, 3)
        self.conv.weight.grad.data -= tmp
        self.conv.weight.grad.data = -self.conv.weight.grad.data

    def normweight(self, clip=False):
        if clip:
            self.conv.weight.data = torch. \
                clamp(self.conv.weight.data, min=-3, max=1.0)
        else:
            # Standardise each output filter to zero mean / unit std.
            c, i, w, h = self.conv.weight.data.shape
            avg = self.conv.weight.data.mean(1, True).mean(2, True).mean(3, True)
            self.conv.weight.data -= avg
            tmp = self.conv.weight.data.reshape(c, 1, -1, 1)
            self.conv.weight.data /= tmp.std(2, unbiased=False, keepdim=True)

    def getthresh(self, scale):
        # Per-channel maximum activation (plus epsilon) used as the firing
        # threshold for this pass.
        tmp2 = scale.max(0, True)[0].max(2, True)[0].max(3, True)[0] + 0.0001
        return tmp2
class STDPLinear(nn.Module):
    """Fully-connected layer trained with STDP-style local updates.

    Holds a LIF neuron with an adaptive per-neuron threshold (homeostasis
    via ``updatethresh``), a winner-take-all layer and lateral inhibition
    in "max" mode.
    """

    def __init__(self, in_planes, out_planes,
                 tau_decay=0.99, offset=0.05, static=True, inh=10):
        super().__init__()
        self.tau_decay = tau_decay  # membrane decay for the LIF node
        self.offset = offset        # offset used by normgrad
        self.static = static        # True: input carries no explicit time dim
        self.linear = nn.Linear(in_planes, out_planes, bias=False)
        self.mem = self.spike = self.refrac_count = None
        # torch.nn.init.xavier_uniform_(self.linear.weight, gain=1)
        self.normweight(False)
        # NOTE(review): this attribute appears superseded by
        # self.node.threshold below -- confirm it is still read anywhere.
        self.threshold = torch.ones(out_planes, device=device) * 20
        self.inh = inh              # lateral-inhibition strength
        self.node = LIFSTDPNode(act_fun=STDPGrad, tau=tau_decay, mem_detach=True)
        self.WTA = WTALayer()
        self.lateralinh = LateralInhibition(self.node, self.inh, mode="max")
        self.init = False           # thresholds initialised from first batch

    def mem_update(self, x, onespike=True):  # x: [b, features]
        if not self.init:
            # Data-dependent threshold init: 3x the first batch's max input.
            self.node.threshold.data = (x.max(0)[0].detach() * 3).to(device)
            self.init = True
        xori = x
        x = self.node(x)
        if x.max() > 0:
            x = self.WTA(x)
            self.lateralinh(x, xori)
        self.spike = x  # cached for normgrad / updatethresh
        return self.spike

    def forward(self, x, T, onespike=True):
        if not self.static:
            # [batch, T, w] -> fold time into the batch dimension.
            batch, T, w = x.shape
            x = x.reshape(-1, w)
        x = x.detach()
        x = self.linear(x)
        self.x = x.detach()  # cached synaptic drive, used by updatethresh
        if not self.static:
            # NOTE(review): reshapes with the INPUT width w after the linear
            # layer changed the feature size -- confirm for non-static use.
            x = x.reshape(batch, T, w)
            xsum = None
            for i in range(T):
                tmp = self.mem_update(x[:, i], onespike).unsqueeze(1)
                if xsum is not None:
                    xsum = torch.cat([xsum, tmp], 1)
                else:
                    xsum = tmp
        else:
            # Static input: present the same drive for T steps and
            # accumulate the emitted spikes.
            xsum = 0
            for i in range(T):
                xsum += self.mem_update(x, onespike)
        # print(xsum.mean())
        return xsum

    def reset(self):
        # Reset the neuron state between batches.
        self.node.n_reset()

    def normgrad(self, force=False):
        if force:
            pass
        else:
            # NOTE(review): tmp is computed but never applied here (unlike
            # STDPConv.normgrad, which subtracts it) -- confirm intent.
            tmp = self.offset * self.spike.mean(0, True).permute(1, 0)
        # Negate so the optimizer's descent step implements potentiation.
        self.linear.weight.grad.data = -self.linear.weight.grad.data

    def normweight(self, clip=False):
        # Both branches clamp to [0, 1]; the non-clip branch additionally
        # rescales each row so its maximum weight is 0.1.
        if clip:
            self.linear.weight.data = torch. \
                clamp(self.linear.weight.data, min=0, max=1.0)
        else:
            self.linear.weight.data = torch. \
                clamp(self.linear.weight.data, min=0, max=1.0)
            sumweight = self.linear.weight.data.sum(1, True)
            sumweight += (~(sumweight.bool())).float()  # guard all-zero rows
            # self.linear.weight.data *= 11.76 / sumweight
            self.linear.weight.data /= self.linear.weight.data.max(1, True)[0] / 0.1

    def getthresh(self, scale):
        # Candidate threshold: total positive synaptic weight scaled by input.
        tmp = self.linear.weight.clamp(min=0) * scale
        tmp2 = tmp.sum(1, True).reshape(1, -1)
        return tmp2

    def updatethresh(self, plus=0.05):
        # Homeostasis: raise each neuron's threshold in proportion to the
        # synaptic drive it received while spiking.
        self.node.threshold += (plus * self.x * self.spike.detach()).sum(0)
        # Keep the largest threshold capped at 350 by shifting all of them.
        tmp = self.node.threshold.max() - 350
        if tmp > 0:
            self.node.threshold -= tmp
class STDPFlatten(nn.Module):
    """Flatten wrapper whose forward takes (x, T) like the other STDP layers."""

    def __init__(self, start_dim=0, end_dim=-1):
        super().__init__()
        self.flatten = nn.Flatten(start_dim=start_dim, end_dim=end_dim)

    def forward(self, x, T):  # x: [batch, T, c, w, h]
        # T is ignored; it exists only to keep a uniform layer interface.
        return self.flatten(x)
class STDPMaxPool(nn.Module):
    """Max-pooling wrapper that folds the time dim away when static=False."""

    def __init__(self, kernel_size, stride, padding, static=True):
        super().__init__()
        self.static = static
        self.pool = nn.MaxPool2d(kernel_size, stride, padding)

    def forward(self, x, T):  # x: [batch, T, c, w, h] when static=False
        if self.static:
            return self.pool(x)
        # Fold time into the batch dimension, pool, then restore it.
        batch, T, c, h, w = x.shape
        pooled = self.pool(x.reshape(-1, c, h, w))
        # NOTE(review): restores the pre-pooling h/w, which only works when
        # pooling preserves the spatial size -- confirm for non-static use.
        return pooled.reshape(batch, T, c, h, w)
alpha = 1.0
class Normliaze(nn.Module):
    """Scale each sample in-place so its per-sample maximum equals 1."""

    def __init__(self, static=True):
        super().__init__()
        self.static = static

    def forward(self, x, T):  # x: [batch, c, w, h]
        # In-place division by the max over channel and spatial dimensions.
        peak = x.max(1, True)[0].max(2, True)[0].max(3, True)[0]
        x /= peak
        # x /= x.mean() / 0.13
        return x
class voting(nn.Module):
    """Map output-neuron spike counts to class labels by majority response.

    Each output neuron is assigned the label for which it fires most
    (``assign_labels``); predictions average the spike counts of each
    label's assigned neurons (``get_label``).
    """

    def __init__(self, shape):
        super().__init__()
        # Per-neuron label assignment; -1 means unassigned.
        self.label = torch.zeros(shape) - 1
        self.assignments = 0

    def assign_labels(self, spikes, labels, rates=None, n_labels=10, alpha=alpha):
        # Assign a label to every output neuron from its spiking response.
        # spikes: [batch, time, n_neurons]; labels: [batch] integer classes.
        # print(spikes.size())
        n_neurons = spikes.size(2)
        if rates is None:
            rates = torch.zeros(n_neurons, n_labels, device=device)
        self.n_labels = n_labels
        # Collapse time: total spike count per (sample, neuron).
        spikes = spikes.cpu().sum(1).to(device)
        for i in range(n_labels):
            n_labeled = torch.sum(labels == i).float()
            # Blend with the rates carried over from the previous call
            # (running average across assignment passes).
            if n_labeled > 0:
                indices = torch.nonzero(labels == i).view(-1)
                tmp = torch.sum(spikes[indices], 0) / n_labeled  # mean spike count
                rates[:, i] = alpha * rates[:, i] + tmp
        # rates is [n_neurons, n_labels]; each neuron takes the label for
        # which its accumulated firing rate is highest.
        self.assignments = torch.max(rates, 1)[1]
        return self.assignments, rates

    def get_label(self, spikes):
        # Predict class labels for a batch from the assigned neurons' spikes.
        n_samples = spikes.size(0)
        spikes = spikes.cpu().sum(1).to(device)
        rates = torch.zeros(n_samples, self.n_labels, device=device)
        for i in range(self.n_labels):
            n_assigns = torch.sum(self.assignments == i).float()  # neurons assigned to class i
            if n_assigns > 0:
                indices = torch.nonzero(self.assignments == i).view(-1)  # their positions
                # Mean spike count of class i's neurons for every sample.
                rates[:, i] = torch.sum(spikes[:, indices], 1) / n_assigns
        # Label with the highest mean response wins.
        return torch.sort(rates, dim=1, descending=True)[1][:, 0]
inh = 25       # lateral-inhibition strength for the output STDPLinear layer
inh2 = 1.625   # inhibition for the (commented-out) second conv layer
channel = 12   # number of conv output channels
neuron = 6400  # number of output-layer neurons
class Conv_Net(nn.Module):
    """STDP network: conv -> maxpool -> normalise -> flatten -> linear.

    Layers are trained greedily one at a time; ``forward`` therefore runs
    only the slice of layers in ``[inlayer, outlayer]``.
    """

    def __init__(self):
        super(Conv_Net, self).__init__()
        self.conv = nn.ModuleList([
            STDPConv(1, channel, 3, 1, 1, 1, static=True, inh=1.625, avgscale=5),
            STDPMaxPool(2, 2, 0, static=True),
            Normliaze(),
            # STDPConv(12, 48, 3, 1, 1,1, static=True, inh=inh2, avgscale=10 ),
            # STDPMaxPool(2, 2, 0, static=True),
            # Normliaze(),
            STDPFlatten(start_dim=1),
            # 196*channel = (28/2)**2 spatial positions per channel after pooling.
            STDPLinear(196 * channel, neuron, static=True, inh=inh)
        ])
        # Voting layer that maps output-neuron spikes to the 10 class labels.
        self.voting = voting(10)

    def forward(self, x, inlayer, outlayer, T, onespike=True):  # x: [b, t, w, h]
        # Run only the layers in [inlayer, outlayer] (layer-wise training).
        for i in range(inlayer, outlayer + 1):
            x = self.conv[i](x, T)
        return x

    def normgrad(self, layer, force=False):
        # Delegate gradient normalisation to the indexed layer.
        self.conv[layer].normgrad(force)

    def normweight(self, layer, clip=False):
        # Delegate weight normalisation to the indexed layer.
        self.conv[layer].normweight(clip)

    def updatethresh(self, layer, plus=0.05):
        # Delegate threshold homeostasis to the indexed layer.
        self.conv[layer].updatethresh(plus)

    def reset(self, layer):
        # Reset neuron state; accepts a single index or a list of indices.
        if isinstance(layer, list):
            for i in layer:
                self.conv[i].reset()
        else:
            self.conv[layer].reset()
def plot_confusion_matrix(cm, classes, normalize=True, title='Test Confusion matrix', cmap=plt.cm.Blues):
    """Render a confusion matrix with matplotlib.

    :param cm: [n, n] confusion-matrix array (e.g. from sklearn)
    :param classes: tick labels for both axes
    :param normalize: row-normalise so entries become per-class rates
    :param title: figure title
    :param cmap: colormap for the heatmap
    """
    if normalize:
        # Normalise each row so entries are per-true-class rates.
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        # print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')
    plt.figure()
    # print(cm)
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    # plt.xticks(tick_marks, classes, rotation=45)
    plt.xticks(tick_marks, classes)
    plt.yticks(tick_marks, classes)
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    # NOTE(review): only the diagonal cells are annotated; the usual template
    # labels every (i, j) cell -- confirm this is intentional.
    for i in range(cm.shape[0]):
        plt.text(i, i, format(cm[i, i], fmt), horizontalalignment="center",
                 color="white" if cm[i, i] > thresh else "black")
    plt.tight_layout()
    # plt.savefig('confusestpf2'+str(channel)+"_n"+str(neuron)+".pdf")
    # plt.show()
if __name__ == '__main__':
    print(23)
    batch_size = 1024
    T = 100  # simulation steps for conv-layer training
    # The first Compose is immediately overwritten; kept as in the original.
    transform = transforms.Compose(
        [transforms.Resize((28, 28)), transforms.Grayscale(num_output_channels=1), transforms.ToTensor()])
    transform = transforms.Compose([transforms.ToTensor()])
    # Alternative datasets (CIFAR10 / FashionMNIST), kept for reference:
    # mnist_train = datasets.CIFAR10(root='/data/datasets/CIFAR10/', train=True, download=False, transform=transform )
    # mnist_test = datasets.CIFAR10(root='/data/datasets/CIFAR10/', train=False, download=False, transform=transform )
    # mnist_train = datasets.FashionMNIST(root='/data/dyt//', train=True, download=True, transform=transform )
    # mnist_test = datasets.FashionMNIST(root='/data/dyt/', train=False, download=False, transform=transform )
    mnist_train = datasets.MNIST(root='./', train=True, download=True, transform=transform)
    mnist_test = datasets.MNIST(root='./', train=False, download=False, transform=transform)
    train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=4)
    test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=4)
    model = Conv_Net().to(device)
    # Indices of the trainable (conv / linear) layers in model.conv.
    convlist = [index for index, i in enumerate(model.conv) if isinstance(i, (STDPConv, STDPLinear))]
    print(convlist)
    # cap = torch.ones([100000, 1000, 30], device=device)
    # ---- Greedy layer-wise STDP training of the conv layers ----
    for layer in range(len(convlist) - 1):
        # Optimise only this layer's weight tensor.
        optimizer = torch.optim.SGD(list(model.parameters())[layer:layer + 1], lr=0.1)
        for epoch in range(3):
            for step, (x, y) in enumerate(tqdm(train_iter)):
                x = x.to(device)
                y = y.to(device)
                # Run the network up to (and including) the layer in training.
                spikes = model(x, 0, convlist[layer], T)
                optimizer.zero_grad()
                # Backprop a constant scaled by batch x spatial size so the
                # STDP surrogate gradients are averaged, not summed.
                spikes.sum().backward(torch.tensor(1 / (spikes.shape[0] * spikes.shape[2] * spikes.shape[3])))
                # spikes.sum().backward( )
                model.conv[convlist[layer]].spike = spikes.detach()
                model.normgrad(convlist[layer], force=True)
                optimizer.step()
                model.normweight(convlist[layer], clip=False)
                # print(model.conv[convlist[layer]].conv.weight.data )
                # Reset neuron state between batches.
                model.reset(convlist)
            print("layer", layer, "epoch", epoch, 'Done')
            # model.conv[convlist[layer]].onespike=False
    # ===========================================================================================================
    # ---- Training of the final linear layer ----
    # model.conv[convlist[-2]].onespike=True
    cap = None
    batch_size = 1024
    T = 200  # more simulation steps for the output layer
    layer = len(convlist) - 1
    plus = 0.002  # threshold-homeostasis increment
    lr = 0.0001
    train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=4)
    test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=4)
    optimizer = torch.optim.SGD(list(model.parameters())[layer:], lr=lr)
    rates = None
    best = 0
    accrecord = []
    for epoch in range(1000):
        spikefull = None
        labelfull = None
        for step, (x, y) in enumerate(tqdm(train_iter)):
            x = x.to(device)
            y = y.to(device)
            spiketime = 0
            spikes = model(x, 0, convlist[layer], T)
            # print(spikes.mean())
            optimizer.zero_grad()
            spikes.sum().backward()
            model.conv[convlist[layer]].spike = spikes.detach()
            model.normgrad(convlist[layer], force=False)
            optimizer.step()
            model.updatethresh(convlist[layer], plus=plus)
            model.normweight(convlist[layer], clip=False)
            # Collect [batch, 1, n_neurons] spike counts for label voting.
            spikes = spikes.reshape(spikes.shape[0], 1, -1).detach()
            if spikefull is None:
                spikefull = spikes
                labelfull = y
            else:
                spikefull = torch.cat([spikefull, spikes], 0)
                labelfull = torch.cat([labelfull, y], 0)
            model.reset(convlist)
        # Re-assign neuron labels from this epoch's responses, then decay the
        # running rates for the next epoch.
        _, rates = model.voting.assign_labels(spikefull, labelfull, rates)
        rates = rates.detach() * 0.5
        result = model.voting.get_label(spikefull)
        acc = (result == labelfull).float().mean()
        print(epoch, acc, 'channel', channel, "n", neuron)
        print(model.conv[-1].node.threshold.max(), model.conv[-1].node.threshold.mean(), model.conv[-1].node.threshold.min())
        # model.conv[-1].threshold*=0.98
        # ---- Evaluation on the test split ----
        spikefull = None
        labelfull = None
        result = None
        for step2, (x, y) in enumerate(test_iter):
            x = x.to(device)
            y = y.to(device)
            spiketime = 0
            spikes = model(x, 0, convlist[layer], T)
            spikes = spikes.reshape(spikes.shape[0], 1, -1).detach()
            with torch.no_grad():
                if spikefull is None:
                    spikefull = spikes
                    labelfull = y
                else:
                    spikefull = torch.cat([spikefull, spikes], 0)
                    labelfull = torch.cat([labelfull, y], 0)
            model.reset(convlist)
        result = model.voting.get_label(spikefull)
        acc = (result == labelfull).float().mean()
        if best < acc:
            # New best accuracy: checkpoint the model and plot its confusion
            # matrix.
            best = acc
            torch.save( model, "modelftstp28_350_c"+str(channel)+"_n"+str(neuron)+"_p"+str(acc)+".pth")
            classes = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
            cm = confusion_matrix(labelfull.cpu(), result.cpu())
            plot_confusion_matrix(cm, classes)
        print("test", acc, "best", best)
        accrecord.append(acc)
        # torch.save(accrecord,"accfstp28_350_c"+str(channel)+"_n"+str(neuron)+".pth")
================================================
FILE: examples/Perception_and_Learning/img_cls/bp/README.md
================================================
# Script for training high-performance SNNs based on back propagation
This is an example of training high-performance SNNs using BrainCog.
It can train high-performance SNNs on CIFAR10, DVS-CIFAR10, ImageNet and other datasets, reaching state-of-the-art performance.
## Install braincog
```shell
git clone https://github.com/xxx/Brain-Cog.git
cd Brain-Cog
python setup.py install --user
```
## Examples of training
```shell
cd examples/Perception_and_Learning/img_cls/bp
python main.py --model dvs_convnet --node-type LIFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun QGateGrad --device 0
```
## Benchmark
We provide a benchmark of SNNs trained with braincog and the corresponding scripts.
This provides an open, fair platform for comparison of subsequent SNNs on classification tasks.
**Note**: The results may vary due to random seeding and software version issues.
### CIFAR10
| ID | Dataset | Node-type | Config | Model | Batch Size | Accuracy | Script |
|:----|:-------:|:--------------:|:------:|:-------------:|:----------:|:------------:|:-----------------------------------------------------------------------------------------------------------------------------------------------|
| 1 | CIFAR10 | IF+Atan | - | convnet | 128 | 95.54 | ```python main.py --model cifar_convnet --node-type IFNode --dataset cifar10 --step 4 --batch-size 128 --act-fun AtanGrad --device 0``` |
| 1 | CIFAR10 | LIF+Atan | - | convnet | 128 | 91.92 | ```python main.py --model cifar_convnet --node-type LIFNode --dataset cifar10 --step 4 --batch-size 128 --act-fun AtanGrad --device 0``` |
| 1 | CIFAR10 | PLIF+Atan | - | convnet | 128 | 93.32 | ```python main.py --model cifar_convnet --node-type PLIFNode --dataset cifar10 --step 4 --batch-size 128 --act-fun AtanGrad --device 0``` |
| 1 | CIFAR10 | IF+Atan | - | resnet18 | 128 | 89.76/89.80 | ```python main.py --model resnet18 --node-type IFNode --dataset cifar10 --step 4 --batch-size 128 --act-fun AtanGrad --device 0``` |
| 1 | CIFAR10 | LIF+Atan | - | resnet18 | 128 | 89.93/89.88 | ```python main.py --model resnet18 --node-type LIFNode --dataset cifar10 --step 4 --batch-size 128 --act-fun AtanGrad --device 0``` |
| 1 | CIFAR10 | PLIF+Atan | - | resnet18 | 128 | 92.64/ 90.65 | ```python main.py --model resnet18 --node-type PLIFNode --dataset cifar10 --step 4 --batch-size 128 --act-fun AtanGrad --device 0``` |
| 1 | CIFAR10 | IF+QGateGrad | - | cifar_convnet | 128 | 95.73 | ```python main.py --model cifar_convnet --node-type IFNode --dataset cifar10 --step 4 --batch-size 128 --act-fun QGateGradGrad --device 0``` |
| 1 | CIFAR10 | LIF+QGateGrad | - | cifar_convnet | 128 | 96.04 | ```python main.py --model cifar_convnet --node-type LIFNode --dataset cifar10 --step 4 --batch-size 128 --act-fun QGateGradGrad --device 0``` |
| 1 | CIFAR10 | PLIF+QGateGrad | - | cifar_convnet | 128 | 96.04/95.84 | ```python main.py --model cifar_convnet --node-type PLIFNode --dataset cifar10 --step 4 --batch-size 128 --act-fun QGateGradGrad --device 0``` |
| 1 | CIFAR10 | IF+QGateGrad | - | resnet18 | 128 | 89.19 | ```python main.py --model resnet18 --node-type IFNode --dataset cifar10 --step 4 --batch-size 128 --act-fun QGateGradGrad --device 0``` |
| 1 | CIFAR10 | LIF+QGateGrad | - | resnet18 | 128 | 90.95/90.68 | ```python main.py --model resnet18 --node-type LIFNode --dataset cifar10 --step 4 --batch-size 128 --act-fun QGateGradGrad --device 0``` |
| 1 | CIFAR10 | PLIF+QGateGrad | - | resnet18 | 128 | 90.97/91.02 | ```python main.py --model resnet18 --node-type PLIFNode --dataset cifar10 --step 4 --batch-size 128 --act-fun QGateGradGrad --device 0``` |
### CIFAR100
| ID | Dataset | Node-type | Config | Model | Batch Size | Accuracy | Script |
|:----|:--------:|:--------------:|:------:|:-----------:|:----------:|:--------:|:------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| 1 | CIFAR100 | IF+Atan | - | dvs_convnet | 128 | 76.52 | ```python main.py --num-classes 100 --model cifar_convnet --node-type IFNode --dataset cifar100 --step 4 --batch-size 128 --act-fun AtanGrad --device 0``` |
| 1 | CIFAR100 | LIF+Atan | - | dvs_convnet | 128 | 71.89 | ```python main.py --num-classes 100 --model cifar_convnet --node-type LIFNode --dataset cifar100 --step 4 --batch-size 128 --act-fun AtanGrad --device 0``` |
| 1 | CIFAR100 | PLIF+Atan | - | dvs_convnet | 128 | 72.82 | ```python main.py --num-classes 100 --model cifar_convnet --node-type PLIFNode --dataset cifar100 --step 4 --batch-size 128 --act-fun AtanGrad --device 0``` |
| 1 | CIFAR100 | IF+Atan | - | resnet18 | 128 | 62.47 | ```python main.py --num-classes 100 --model resnet18 --node-type IFNode --dataset cifar100 --step 4 --batch-size 128 --act-fun AtanGrad --device 0``` |
| 1 | CIFAR100 | LIF+Atan | - | resnet18 | 128 | 62.63 | ```python main.py --num-classes 100 --model resnet18 --node-type LIFNode --dataset cifar100 --step 4 --batch-size 128 --act-fun AtanGrad --device 0``` |
| 1 | CIFAR100 | PLIF+Atan | - | resnet18 | 128 | 62.71 | ```python main.py --num-classes 100 --model resnet18 --node-type PLIFNode --dataset cifar100 --step 4 --batch-size 128 --act-fun AtanGrad --device 0``` |
| 1 | CIFAR100 | IF+QGateGrad | - | dvs_convnet | 128 | 76.44 | ```python main.py --num-classes 100 --model cifar_convnet --node-type IFNode --dataset cifar100 --step 4 --batch-size 128 --act-fun QGateGradGrad --device 0``` |
| 1 | CIFAR100 | LIF+QGateGrad | - | dvs_convnet | 128 | 77.73 | ```python main.py --num-classes 100 --model cifar_convnet --node-type LIFNode --dataset cifar100 --step 4 --batch-size 128 --act-fun QGateGradGrad --device 0``` |
| 1 | CIFAR100 | PLIF+QGateGrad | - | dvs_convnet | 128 | 77.25 | ```python main.py --num-classes 100 --model cifar_convnet --node-type PLIFNode --dataset cifar100 --step 4 --batch-size 128 --act-fun QGateGradGrad --device 0``` |
| 1 | CIFAR100 | IF+QGateGrad | - | resnet18 | 128 | 60.01 | ```python main.py --num-classes 100 --model resnet18 --node-type IFNode --dataset cifar100 --step 4 --batch-size 128 --act-fun QGateGradGrad --device 0``` |
| 1 | CIFAR100 | LIF+QGateGrad | - | resnet18 | 128 | 61.33 | ```python main.py --num-classes 100 --model resnet18 --node-type LIFNode --dataset cifar100 --step 4 --batch-size 128 --act-fun QGateGradGrad --device 0``` |
| 1 | CIFAR100 | PLIF+QGateGrad | - | resnet18 | 128 | 62.32 | ```python main.py --num-classes 100 --model resnet18 --node-type PLIFNode --dataset cifar100 --step 4 --batch-size 128 --act-fun QGateGradGrad --device 0``` |
### DVS-CIFAR10
| ID | Dataset | Node-type | Config | Model | Batch Size | FLOPS | Accuracy | Script |
|:----|:-----------:|:--------------:|:------:|:-----------:|:----------:|:-----:|:-----------:|:---------------------------------------------------------------------------------------------------------------------------------------------|
| 1 | DVS-CIFAR10 | IF+Atan | - | dvs_convnet | 128 | 7503 | 65.90 | ```python main.py --model dvs_convnet --node-type IFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun AtanGrad --device 0``` |
| 1 | DVS-CIFAR10 | LIF+Atan | - | dvs_convnet | 128 | 7503 | 82.10 | ```python main.py --model dvs_convnet --node-type LIFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun AtanGrad --device 0``` |
| 1 | DVS-CIFAR10 | PLIF+Atan | - | dvs_convnet | 128 | 7503 | 81.90 | ```python main.py --model dvs_convnet --node-type PLIFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun AtanGrad --device 0``` |
| 1 | DVS-CIFAR10 | IF+Atan | - | resnet18 | 128 | 3149 | 69.10 | ```python main.py --model resnet18 --node-type IFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun AtanGrad --device 0``` |
| 1 | DVS-CIFAR10 | LIF+Atan | - | resnet18 | 128 | 3149 | 78.50 | ```python main.py --model resnet18 --node-type LIFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun AtanGrad --device 0``` |
| 1 | DVS-CIFAR10 | PLIF+Atan | - | resnet18 | 128 | 3149 | 77.70 | ```python main.py --model resnet18 --node-type PLIFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun AtanGrad --device 0``` |
| 1 | DVS-CIFAR10 | IF+QGateGrad | - | dvs_convnet | 128 | 7503 | 68.30 | ```python main.py --model dvs_convnet --node-type IFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun QGateGradGrad --device 0``` |
| 1 | DVS-CIFAR10 | LIF+QGateGrad | - | dvs_convnet | 128 | 7503 | 82.60/82.90 | ```python main.py --model dvs_convnet --node-type LIFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun QGateGradGrad --device 0``` |
| 1 | DVS-CIFAR10 | PLIF+QGateGrad | - | dvs_convnet | 128 | 7503 | 83.20 | ```python main.py --model dvs_convnet --node-type PLIFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun QGateGradGrad --device 0``` |
| 1 | DVS-CIFAR10 | IF+QGateGrad | - | resnet18 | 128 | 3149 | 65.70/66.80 | ```python main.py --model resnet18 --node-type IFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun QGateGradGrad --device 0``` |
| 1 | DVS-CIFAR10 | LIF+QGateGrad | - | resnet18 | 128 | 3149 | 79.00/79.40 | ```python main.py --model resnet18 --node-type LIFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun QGateGradGrad --device 0``` |
| 1 | DVS-CIFAR10 | PLIF+QGateGrad | - | resnet18 | 128 | 3149 | 78.10/78.20 | ```python main.py --model resnet18 --node-type PLIFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun QGateGradGrad --device 0``` |
### DVS-Gesture
| ID | Dataset | Node-type | Config | Model | Batch Size | Accuracy | Script |
|:----|:-------:|:--------------:|:------:|:-----------:|:----------:|:-----------:|:------------------------------------------------------------------------------------------------------------------------------------------------------------|
| 1 | DVS-G | IF+Atan | - | dvs_convnet | 128 | 64.77 | ```python main.py --num-classes 11 --model dvs_convnet --node-type IFNode --dataset dvsg --step 10 --batch-size 128 --act-fun AtanGrad --device 0``` |
| 1 | DVS-G | LIF+Atan | - | dvs_convnet | 128 | 91.28 | ```python main.py --num-classes 11 --model dvs_convnet --node-type LIFNode --dataset dvsg --step 10 --batch-size 128 --act-fun AtanGrad --device 0``` |
| 1 | DVS-G | PLIF+Atan | - | dvs_convnet | 128 | 91.67 | ```python main.py --num-classes 11 --model dvs_convnet --node-type PLIFNode --dataset dvsg --step 10 --batch-size 128 --act-fun AtanGrad --device 0``` |
| 1 | DVS-G | IF+Atan | - | resnet18 | 128 | 63.25 | ```python main.py --num-classes 11 --model resnet18 --node-type IFNode --dataset dvsg --step 10 --batch-size 128 --act-fun AtanGrad --device 0``` |
| 1 | DVS-G | LIF+Atan | - | resnet18 | 128 | 91.29 | ```python main.py --num-classes 11 --model resnet18 --node-type LIFNode --dataset dvsg --step 10 --batch-size 128 --act-fun AtanGrad --device 0``` |
| 1 | DVS-G | PLIF+Atan | - | resnet18 | 128 | 90.15 | ```python main.py --num-classes 11 --model resnet18 --node-type PLIFNode --dataset dvsg --step 10 --batch-size 128 --act-fun AtanGrad --device 0``` |
| 1 | DVS-G | IF+QGateGrad | - | dvs_convnet | 128 | 48.48 | ```python main.py --num-classes 11 --model dvs_convnet --node-type IFNode --dataset dvsg --step 10 --batch-size 128 --act-fun QGateGradGrad --device 0``` |
| 1 | DVS-G | LIF+QGateGrad | - | dvs_convnet | 128 | 92.05/92.42 | ```python main.py --num-classes 11 --model dvs_convnet --node-type LIFNode --dataset dvsg --step 10 --batch-size 128 --act-fun QGateGradGrad --device 0``` |
| 1 | DVS-G | PLIF+QGateGrad | - | dvs_convnet | 128 | 91.28 | ```python main.py --num-classes 11 --model dvs_convnet --node-type PLIFNode --dataset dvsg --step 10 --batch-size 128 --act-fun QGateGradGrad --device 0``` |
| 1 | DVS-G | IF+QGateGrad | - | resnet18 | 128 | 57.95 | ```python main.py --num-classes 11 --model resnet18 --node-type IFNode --dataset dvsg --step 10 --batch-size 128 --act-fun QGateGradGrad --device 0``` |
| 1 | DVS-G | LIF+QGateGrad | - | resnet18 | 128 | 90.91 | ```python main.py --num-classes 11 --model resnet18 --node-type LIFNode --dataset dvsg --step 10 --batch-size 128 --act-fun QGateGradGrad --device 0``` |
| 1 | DVS-G | PLIF+QGateGrad | - | resnet18 | 128 | 92.42 | ```python main.py --num-classes 11 --model resnet18 --node-type PLIFNode --dataset dvsg --step 10 --batch-size 128 --act-fun QGateGradGrad --device 0``` |
### NCALTECH101
| ID | Dataset | Node-type | Config | Model | Batch Size | Accuracy | Script |
|:----|:-----------:|:--------------:|:------:|:-----------:|:----------:|:-----------:|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| 1 | NCALTECH101 | IF+QGateGrad | - | dvs_convnet | 128 | 23.09/51.15 | ```python main.py --num-classes 100 --model dvs_convnet --node-type IFNode --dataset NCALTECH101 --step 10 --batch-size 128 --act-fun QGateGradGrad --device 0``` |
| 1 | NCALTECH101 | LIF+QGateGrad | - | dvs_convnet | 128 | 72.78/75.09 | ```python main.py --num-classes 100 --model dvs_convnet --node-type LIFNode --dataset NCALTECH101 --step 10 --batch-size 128 --act-fun QGateGradGrad --device 0``` |
| 1 | NCALTECH101 | PLIF+QGateGrad | - | dvs_convnet | 128 | 74.61/76.79 | ```python main.py --num-classes 100 --model dvs_convnet --node-type PLIFNode --dataset NCALTECH101 --step 10 --batch-size 128 --act-fun QGateGradGrad --device 0``` |
| 1 | NCALTECH101 | IF+QGateGrad | -/mix | resnet18 | 128 | 61.24/60.87 | ```python main.py --num-classes 100 --model resnet18 --node-type IFNode --dataset NCALTECH101 --step 10 --batch-size 128 --act-fun QGateGradGrad --device 0``` |
| 1 | NCALTECH101 | LIF+QGateGrad | -/mix | resnet18 | 128 | 66.22/70.84 | ```python main.py --num-classes 100 --model resnet18 --node-type LIFNode --dataset NCALTECH101 --step 10 --batch-size 128 --act-fun QGateGradGrad --device 0``` |
| 1 | NCALTECH101 | PLIF+QGateGrad | -/mix | resnet18 | 128 | 69.62/69.87 | ```python main.py --num-classes 100 --model resnet18 --node-type PLIFNode --dataset NCALTECH101 --step 10 --batch-size 128 --act-fun QGateGradGrad --device 0``` |
### SHD
| ID | Dataset | Node-type | Config | Model | Batch Size | Accuracy | Script |
| :--- | :-----: | :-----------: | :----: | :-----: | :--------: | :------: | :----------------------------------------------------------- |
| 1 | SHD | LIF+QGateGrad | - | shd_snn | 256 | 88.47 | ```python main.py --model SHD_SNN --node-type LIFNode --dataset shd --step 15 --batch-size 256 --act-fun QGateGrad --device 1 --tau 10. --threshold 0.3 --lr 5e-3 --min-lr 1e-4 --loss-fn onehot-mse --num-classes 20 --amp --weight-decay 0.01 ``` |
Note:
1. The resnet18 used here adds a max-pooling layer after the initial convolution layer.
   However, in the final version of BrainCog, this pooling layer has been removed.
2. mix refers to the use of EventMix as a data augmentation method.
3. We will continue to add other results.
### Citation
If you find this package helpful, please consider citing it:
```BibTex
@misc{zengbraincogSpikingNeural2022,
title = {{{braincog}}: {{A Spiking Neural Network}} Based {{Brain-inspired Cognitive Intelligence Engine}} for {{Brain-inspired AI}} and {{Brain Simulation}}},
shorttitle = {{{braincog}}},
author = {Zeng, Yi and Zhao, Dongcheng and Zhao, Feifei and Shen, Guobin and Dong, Yiting and Lu, Enmeng and Zhang, Qian and Sun, Yinqian and Liang, Qian and Zhao, Yuxuan and Zhao, Zhuoya and Fang, Hongjian and Wang, Yuwei and Li, Yang and Liu, Xin and Du, Chengcheng and Kong, Qingqun and Ruan, Zizhe and Bi, Weida},
year = {2022},
month = jul,
number = {arXiv:2207.08533},
eprint = {2207.08533},
eprinttype = {arxiv},
primaryclass = {cs},
publisher = {{arXiv}},
doi = {10.48550/arXiv.2207.08533}
}
```
================================================
FILE: examples/Perception_and_Learning/img_cls/bp/main.py
================================================
import argparse
import time
import timm.models
import yaml
import os
import random as buildin_random
import logging
from collections import OrderedDict
from contextlib import suppress
from datetime import datetime
from braincog.base.node.node import *
from braincog.utils import *
from braincog.base.utils.criterions import *
from braincog.datasets.datasets import *
from braincog.model_zoo.resnet import *
from braincog.model_zoo.convnet import *
from braincog.model_zoo.vgg_snn import VGG_SNN, SNN5
from braincog.model_zoo.fc_snn import SHD_SNN
from braincog.model_zoo.resnet19_snn import resnet19
from braincog.model_zoo.sew_resnet import sew_resnet18, sew_resnet34, sew_resnet50
from braincog.utils import save_feature_map, setup_seed
from braincog.base.utils.visualization import plot_tsne_3d, plot_tsne, plot_confusion_matrix, plot_mem_distribution
import torch
import torch.nn as nn
import torchvision.utils
from torch.nn.parallel import DistributedDataParallel as NativeDDP
from timm.data import ImageDataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
from timm.models import load_checkpoint, create_model, resume_checkpoint, convert_splitbn_model, register_model
from timm.utils import *
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
from timm.optim import create_optimizer
from timm.scheduler import create_scheduler
from timm.utils import ApexScaler, NativeScaler
from torch.utils.tensorboard import SummaryWriter
# from ptflops import get_model_complexity_info
# from thop import profile, clever_format
# Let cuDNN benchmark and pick the fastest convolution algorithms; beneficial
# because the input sizes are fixed throughout training.
torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('train')
# The first arg parser parses out only the --config argument, this argument is used to
# load a yaml file containing key-values that override the defaults for the main parser below
config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)
parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
                    help='YAML config file specifying default arguments')
# Main parser: model / dataset / dataloader / optimizer / LR-schedule options.
# Defaults may be overridden by a YAML file given via --config (see _parse_args).
parser = argparse.ArgumentParser(description='SNN Training and Evaluating')
# Model parameters
parser.add_argument('--dataset', default='mnist', type=str)
# NOTE(review): several help strings below quote defaults that disagree with the
# actual `default=` values (inherited from the timm template) — the code wins.
parser.add_argument('--model', default='mnist_convnet', type=str, metavar='MODEL',
                    help='Name of model to train (default: "countception"')
parser.add_argument('--pretrained', action='store_true', default=False,
                    help='Start with pretrained version of specified network (if avail)')
parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
                    help='Initialize model from this checkpoint (default: none)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='Resume full model and optimizer state from checkpoint (default: none)')
parser.add_argument('--eval_checkpoint', default='', type=str, metavar='PATH',
                    help='path to eval checkpoint (default: none)')
parser.add_argument('--no-resume-opt', action='store_true', default=False,
                    help='prevent resume of optimizer state when resuming model')
parser.add_argument('--num-classes', type=int, default=10, metavar='N',
                    help='number of label classes (default: 1000)')
parser.add_argument('--gp', default=None, type=str, metavar='POOL',
                    help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
# Dataset parameters for static datasets
parser.add_argument('--img-size', type=int, default=224, metavar='N',
                    help='Image patch size (default: None => model default)')
parser.add_argument('--crop-pct', default=None, type=float,
                    metavar='N', help='inputs image center crop percent (for validation only)')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                    help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
                    help='Override std deviation of of dataset')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
                    help='Image resize interpolation type (overrides model)')
# Dataloader parameters
parser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',
                    help='inputs batch size for training (default: 128)')
parser.add_argument('-vb', '--validation-batch-size-multiplier', type=int, default=1, metavar='N',
                    help='ratio of validation batch size to training batch size (default: 1)')
# Optimizer parameters
parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                    help='Optimizer (default: "adamw"')
parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
                    help='Optimizer Epsilon (default: None, use opt default)')
parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                    help='Optimizer Betas (default: None, use opt default)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='Optimizer momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=1e-4,
                    help='weight decay (default: 0.01 for adamw)')
parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
                    help='Clip gradient norm (default: None, no clipping)')
parser.add_argument('--adam-epoch', type=int, default=1000, help='lamb switch to adamw')
# Learning rate schedule parameters
parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
                    help='LR scheduler (default: "cosine"')
parser.add_argument('--lr', type=float, default=5e-3, metavar='LR',
                    help='learning rate (default: 0.01)')
parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                    help='learning rate noise on/off epoch percentages')
parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                    help='learning rate noise limit percent (default: 0.67)')
parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                    help='learning rate noise std-dev (default: 1.0)')
parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
                    help='learning rate cycle len multiplier (default: 1.0)')
parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
                    help='learning rate cycle limit')
parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
                    help='warmup learning rate (default: 0.0001)')
parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                    help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
parser.add_argument('--epochs', type=int, default=200, metavar='N',
                    help='number of epochs to train (default: 2)')
parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
                    help='epoch interval to decay LR')
parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
                    help='epochs to warmup LR, if scheduler supports')
parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
                    help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                    help='patience epochs for Plateau LR scheduler (default: 10')
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                    help='LR decay rate (default: 0.1)')
parser.add_argument('--power', type=int, default=1, help='power')
# Augmentation & regularization parameters ONLY FOR IMAGE NET
parser.add_argument('--no-aug', action='store_true', default=False,
                    help='Disable all training augmentation, override other train aug args')
parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
                    help='Random resize scale (default: 0.08 1.0)')
parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
                    help='Random resize aspect ratio (default: 0.75 1.33)')
parser.add_argument('--hflip', type=float, default=0.5,
                    help='Horizontal flip training aug probability')
parser.add_argument('--vflip', type=float, default=0.,
                    help='Vertical flip training aug probability')
parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                    help='Color jitter factor (default: 0.4)')
# NOTE(review): the trailing comma after this call makes the statement a 1-tuple
# expression; harmless at module level, but it looks unintentional.
parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
                    help='Use AutoAugment policy. "v0" or "original". (default: None)'),
parser.add_argument('--aug-splits', type=int, default=0,
                    help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
parser.add_argument('--jsd', action='store_true', default=False,
                    help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
                    help='Random erase prob (default: 0.25)')
parser.add_argument('--remode', type=str, default='pixel',
                    help='Random erase mode (default: "const")')
parser.add_argument('--recount', type=int, default=1,
                    help='Random erase count (default: 1)')
parser.add_argument('--resplit', action='store_true', default=False,
                    help='Do not random erase first (clean) augmentation split')
parser.add_argument('--mixup', type=float, default=0.,
                    help='mixup alpha, mixup enabled if > 0. (default: 0.)')
parser.add_argument('--cutmix', type=float, default=0.,
                    help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')
parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
                    help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
parser.add_argument('--mixup-prob', type=float, default=0.,
                    help='Probability of performing mixup or cutmix when either/both is enabled')
parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
                    help='Probability of switching to cutmix when both mixup and cutmix enabled')
parser.add_argument('--mixup-mode', type=str, default='batch',
                    help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
                    help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
parser.add_argument('--smoothing', type=float, default=0.1,
                    help='Label smoothing (default: 0.1)')
parser.add_argument('--train-interpolation', type=str, default='random',
                    help='Training interpolation (random, bilinear, bicubic default: "random")')
parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                    help='Dropout rate (default: 0.0)')
parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT',
                    help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',
                    help='Drop path rate (default: None)')
parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
                    help='Drop block rate (default: None)')
parser.add_argument('--newton-maxiter', default=20, type=int,
                    help='max iterration in newton method')
parser.add_argument('--reset-drop', action='store_true', default=False,
                    help='whether to reset drop')
parser.add_argument('--kernel-method', type=str, default='cuda', choices=['torch', 'cuda'],
                    help='The implementation way of gaussian kernel method, choose from "cuda" and "torch"')
# Batch norm parameters (only works with gen_efficientnet based models currently)
parser.add_argument('--bn-tf', action='store_true', default=False,
                    help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')
parser.add_argument('--bn-momentum', type=float, default=None,
                    help='BatchNorm momentum override (if not None)')
parser.add_argument('--bn-eps', type=float, default=None,
                    help='BatchNorm epsilon override (if not None)')
parser.add_argument('--sync-bn', action='store_true',
                    help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
parser.add_argument('--dist-bn', type=str, default='',
                    help='Distribute BatchNorm stats between node after each epoch ("broadcast", "reduce", or "")')
parser.add_argument('--split-bn', action='store_true',
                    help='Enable separate BN layers per augmentation split.')
# Model Exponential Moving Average
parser.add_argument('--model-ema', action='store_true', default=False,
                    help='Enable tracking moving average of model weights')
parser.add_argument('--model-ema-force-cpu', action='store_true', default=False,
                    help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
parser.add_argument('--model-ema-decay', type=float, default=0.99996,
                    help='decay factor for model weights moving average (default: 0.9998)')
# Misc
parser.add_argument('--seed', type=int, default=42, metavar='S',
                    help='random seed (default: 42)')
parser.add_argument('--log-interval', type=int, default=50, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
                    help='how many batches to wait before writing recovery checkpoint')
parser.add_argument('-j', '--workers', type=int, default=8, metavar='N',
                    help='how many training processes to use (default: 1)')
parser.add_argument('--num-gpu', type=int, default=1,
                    help='Number of GPUS to use')
parser.add_argument('--save-images', action='store_true', default=False,
                    help='save images of inputs bathes every log interval for debugging')
parser.add_argument('--amp', action='store_true', default=False,
                    help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
parser.add_argument('--apex-amp', action='store_true', default=False,
                    help='Use NVIDIA Apex AMP mixed precision')
parser.add_argument('--native-amp', action='store_true', default=False,
                    help='Use Native Torch AMP mixed precision')
parser.add_argument('--channels-last', action='store_true', default=False,
                    help='Use channels_last memory layout')
parser.add_argument('--pin-mem', action='store_true', default=False,
                    help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--no-prefetcher', action='store_true', default=False,
                    help='disable fast prefetcher')
# NOTE(review): machine-specific default output path — presumably meant to be
# overridden via CLI/config on other machines; verify before release.
parser.add_argument('--output', default='/data/floyed/BrainCog', type=str, metavar='PATH',
                    help='path to output folder (default: none, current dir)')
parser.add_argument('--tensorboard-dir', default='./runs', type=str)
parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
                    help='Best metric (default: "top1"')
parser.add_argument('--tta', type=int, default=0, metavar='N',
                    help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
parser.add_argument('--local_rank', default=0, type=int)
parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
                    help='use the multi-epochs-loader to save time at the beginning of every epoch')
parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
parser.add_argument('--device', type=int, default=0)
# Spike parameters
parser.add_argument('--step', type=int, default=10, help='Simulation time step (default: 10)')
parser.add_argument('--encode', type=str, default='direct', help='Input encode method (default: direct)')
parser.add_argument('--temporal-flatten', action='store_true',
                    help='Temporal flatten to channels. ONLY FOR EVENT DATA TRAINING BY ANN')
parser.add_argument('--adaptive-node', action='store_true')
parser.add_argument('--critical-loss', action='store_true')
parser.add_argument('--conv-type', type=str, default='normal')
parser.add_argument('--sew-cnf', type=str, default='ADD')
parser.add_argument('--rand-step', action='store_true')
# neuron type
parser.add_argument('--node-type', type=str, default='LIFNode', help='Node type in network (default: PLIF)')
parser.add_argument('--act-fun', type=str, default='QGateGrad',
                    help='Surogate Function in node. Only for Surrogate nodes (default: AtanGrad)')
parser.add_argument('--threshold', type=float, default=.5, help='Firing threshold (default: 0.5)')
parser.add_argument('--tau', type=float, default=2., help='Attenuation coefficient (default: 2.)')
parser.add_argument('--requires-thres-grad', action='store_true')
parser.add_argument('--sigmoid-thres', action='store_true')
parser.add_argument('--loss-fn', type=str, default='ce', help='loss function (default: ce)')
parser.add_argument('--noisy-grad', type=float, default=0.,
                    help='Add noise to backward, sometime will make higher accuracy (default: 0.)')
parser.add_argument('--spike-output', action='store_true', default=False,
                    help='Using mem output or spike output (default: False)')
parser.add_argument('--n_groups', type=int, default=1)
parser.add_argument('--n-encode-type', type=str, default='linear')
parser.add_argument('--n-preact', action='store_true')
parser.add_argument('--layer-by-layer', action='store_true',
                    help='forward step-by-step or layer-by-layer. '
                         'Larger Model with layer-by-layer will be faster (default: False)')
parser.add_argument('--tet-loss', action='store_true')
# EventData Augmentation
parser.add_argument('--mix-up', action='store_true', help='Mix-up for event data (default: False)')
parser.add_argument('--cut-mix', action='store_true', help='CutMix for event data (default: False)')
parser.add_argument('--event-mix', action='store_true', help='EventMix for event data (default: False)')
parser.add_argument('--cutmix_beta', type=float, default=2.0, help='cutmix_beta (default: 1.)')
parser.add_argument('--cutmix_prob', type=float, default=0.5, help='cutmix_prib for event data (default: .5)')
parser.add_argument('--cutmix_num', type=int, default=1, help='cutmix_num for event data (default: 1)')
parser.add_argument('--cutmix_noise', type=float, default=0.,
                    help='Add Pepper noise after mix, sometimes work (default: 0.)')
parser.add_argument('--gaussian-n', type=int, default=3)
parser.add_argument('--rand-aug', action='store_true',
                    help='Rand Augment for Event data (default: False)')
parser.add_argument('--randaug_n', type=int, default=3,
                    help='Rand Augment times n (default: 3)')
parser.add_argument('--randaug_m', type=int, default=15,
                    help='Rand Augment times n (default: 15) (0-30)')
parser.add_argument('--train-portion', type=float, default=0.9,
                    help='Dataset portion, only for datasets which do not have validation set (default: 0.9)')
parser.add_argument('--event-size', default=48, type=int,
                    help='Event size. Resize event data before process (default: 48)')
parser.add_argument('--node-resume', type=str, default='',
                    help='resume weights in node for adaptive node. (default: False)')
# visualize
parser.add_argument('--visualize', action='store_true',
                    help='Visualize spiking map for each layer, only for validate (default: False)')
parser.add_argument('--spike-rate', action='store_true',
                    help='Print spiking rate for each layer, only for validate(default: False)')
parser.add_argument('--tsne', action='store_true')
parser.add_argument('--conf-mat', action='store_true')
parser.add_argument('--mem-dist', action='store_true')
parser.add_argument('--adaptation-info', action='store_true')
parser.add_argument('--suffix', type=str, default='',
                    help='Add an additional suffix to the save path (default: \'\')')
# Capability probes for mixed-precision backends.
# has_apex: True only if NVIDIA Apex (AMP + its DDP + sync-BN converter) is importable.
try:
    from apex import amp
    from apex.parallel import DistributedDataParallel as ApexDDP
    from apex.parallel import convert_syncbn_model
    has_apex = True
except ImportError:
    has_apex = False

# has_native_amp: True if this torch build exposes torch.cuda.amp.autocast
# (i.e. native mixed precision is available).
has_native_amp = False
try:
    if getattr(torch.cuda.amp, 'autocast') is not None:
        has_native_amp = True
except AttributeError:
    pass
def _parse_args():
    """Parse CLI arguments, letting an optional YAML config file override defaults.

    Returns a tuple ``(args, args_text)`` where ``args_text`` is a YAML dump
    of the final settings, suitable for saving alongside checkpoints.
    """
    # First pass: pull out only --config; everything else stays unparsed.
    config_args, remaining_argv = config_parser.parse_known_args()
    if config_args.config:
        # Config-file values become parser defaults, so explicit CLI flags
        # in `remaining_argv` still win.
        with open(config_args.config, 'r') as cfg_file:
            parser.set_defaults(**yaml.safe_load(cfg_file))
    parsed = parser.parse_args(remaining_argv)
    # Snapshot the resolved arguments as text for the output directory.
    return parsed, yaml.safe_dump(parsed.__dict__, default_flow_style=False)
@register_model
def resnet50d_pretrained(*args, **kwargs):
    """Pretrained timm ``resnet50d`` with its classifier head swapped for 10 classes.

    Extra positional/keyword arguments are accepted for registry
    compatibility but ignored.
    """
    net = create_model('resnet50d', pretrained=True)
    net.fc = nn.Linear(2048, 10)
    return net
def main():
    """Entry point: parse args, build the SNN model, data loaders, optimizer,
    then either evaluate (``--eval``) or run the full training loop.

    Rank 0 additionally owns the output directory, checkpoint saver, logging
    and the TensorBoard ``summary_writer``.
    """
    args, args_text = _parse_args()
    # args.no_spike_output = args.no_spike_output | args.cut_mix
    args.no_spike_output = True
    output_dir = ''
    if args.local_rank == 0:
        # Rank 0 creates the experiment directory, file logger and TB writer.
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            args.model,
            args.dataset,
            args.node_type,
            str(args.step),
            args.suffix,
            datetime.now().strftime("%Y%m%d-%H%M%S"),
            # str(args.img_size)
        ])
        output_dir = get_outdir(output_base, 'train', exp_name)
        args.output_dir = output_dir
        setup_default_logging(log_path=os.path.join(output_dir, 'log.txt'))
        summary_writer = SummaryWriter(log_dir=os.path.join(args.tensorboard_dir, exp_name))
        args.tensorboard_prefix = os.path.join(args.dataset, args.model)
    else:
        summary_writer = None
        setup_default_logging()

    # Distributed setup: presence of WORLD_SIZE in the environment signals a
    # torch.distributed launch; one GPU per process is enforced.
    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    if args.distributed and args.num_gpu > 1:
        _logger.warning(
            'Using more than one GPU per process in distributed mode is not allowed.Setting num_gpu to 1.')
        args.num_gpu = 1

    # args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.num_gpu = 1
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    else:
        # NOTE(review): assumes args.device is an integer GPU index here —
        # confirm against the --device argument definition.
        torch.cuda.set_device('cuda:%d' % args.device)
    assert args.rank >= 0

    if args.distributed:
        _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
                     % (args.rank, args.world_size))
    else:
        _logger.info('Training with a single process on %d GPUs.' % args.num_gpu)

    # torch.manual_seed(args.seed + args.rank)
    setup_seed(args.seed + args.rank)

    # Build the spiking model; node_type is resolved by name from the node zoo.
    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        adaptive_node=args.adaptive_node,
        dataset=args.dataset,
        step=args.step,
        encode_type=args.encode,
        node_type=eval(args.node_type),
        threshold=args.threshold,
        tau=args.tau,
        sigmoid_thres=args.sigmoid_thres,
        requires_thres_grad=args.requires_thres_grad,
        spike_output=not args.no_spike_output,
        act_fun=args.act_fun,
        temporal_flatten=args.temporal_flatten,
        layer_by_layer=args.layer_by_layer,
        n_groups=args.n_groups,
        n_encode_type=args.n_encode_type,
        n_preact=args.n_preact,
        tet_loss=args.tet_loss,
        sew_cnf=args.sew_cnf,
        conv_type=args.conv_type,
    )
    _logger.info('[MODEL ARCH]\n{}'.format(model))

    # Infer input channel count from the dataset family.
    if 'dvs' in args.dataset:
        args.channels = 2
    elif 'mnist' in args.dataset:
        args.channels = 1
    else:
        args.channels = 3
    # flops, params = profile(model, inputs=(torch.randn(1, args.channels, args.event_size, args.event_size),), verbose=False)
    # _logger.info('flops = %fM', flops / 1e6)
    # _logger.info('param size = %fM', params / 1e6)

    # Linear LR scaling by global batch size (reference batch = 1024).
    linear_scaled_lr = args.lr * args.batch_size * args.world_size / 1024.0
    args.lr = linear_scaled_lr
    _logger.info("learning rate is %f" % linear_scaled_lr)

    if args.local_rank == 0:
        _logger.info('Model %s created, param count: %d' %
                     (args.model, sum([m.numel() for m in model.parameters()])))

    # Augmentation splits / SplitBN (timm AugMix-style training).
    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits
    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    # Resolve which AMP backend to use (apex preferred for plain --amp).
    use_amp = None
    if args.amp:
        # for backwards compat, `--amp` arg tries apex before native amp
        if has_apex:
            args.apex_amp = True
        elif has_native_amp:
            args.native_amp = True
    if args.apex_amp and has_apex:
        use_amp = 'apex'
    elif args.native_amp and has_native_amp:
        use_amp = 'native'
    elif args.apex_amp or args.native_amp:
        _logger.warning("Neither APEX or native Torch AMP is available, using float32. "
                        "Install NVIDA apex or upgrade to PyTorch 1.6")

    # Multi-GPU single-process (DataParallel) vs single-GPU placement.
    if args.num_gpu > 1:
        if use_amp == 'apex':
            _logger.warning(
                'Apex AMP does not work well with nn.DataParallel, disabling. Use DDP or Torch AMP.')
            use_amp = None
        model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
        assert not args.channels_last, "Channels last not supported with DP, use DDP."
    else:
        model = model.cuda()
        if args.channels_last:
            model = model.to(memory_format=torch.channels_last)

    optimizer = create_optimizer(args, model)
    _logger.info('[OPTIMIZER]\n{}'.format(optimizer))

    # AMP autocast context + loss scaler; defaults are no-ops for float32.
    amp_autocast = suppress  # do nothing
    loss_scaler = None
    if use_amp == 'apex':
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        loss_scaler = ApexScaler()
        if args.local_rank == 0:
            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif use_amp == 'native':
        amp_autocast = torch.cuda.amp.autocast
        loss_scaler = NativeScaler()
        if args.local_rank == 0:
            _logger.info('Using native Torch AMP. Training in mixed precision.')
    else:
        if args.local_rank == 0:
            _logger.info('AMP not enabled. Training in float32.')

    # optionally resume from a checkpoint
    resume_epoch = None
    if args.resume and args.eval_checkpoint == '':
        args.eval_checkpoint = args.resume
    if args.resume:
        # NOTE(review): --resume forces evaluation mode here; resuming
        # training therefore appears unsupported — confirm intent.
        args.eval = True
        # checkpoint = torch.load(args.resume, map_location='cpu')
        # model.load_state_dict(checkpoint['state_dict'], False)
        resume_epoch = resume_checkpoint(
            model, args.resume,
            optimizer=None if args.no_resume_opt else optimizer,
            loss_scaler=None if args.no_resume_opt else loss_scaler,
            log_info=args.local_rank == 0)

    # print(model.get_attr('mu'))
    # print(model.get_attr('sigma'))
    if hasattr(model, 'set_threshold'):
        model.set_threshold(args.threshold)
    if args.critical_loss or args.spike_rate:
        # Enable firing-pattern recording needed by critical loss / spike-rate stats.
        model.set_requires_fp(True)

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(
            model,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else '',
            resume=args.resume)

    if args.node_resume:
        # Restore only the neuron-node parameters from a separate checkpoint.
        ckpt = torch.load(args.node_resume, map_location='cpu')
        model.load_node_weight(ckpt, args.node_trainable)

    model_without_ddp = model
    if args.distributed:
        if args.sync_bn:
            assert not args.split_bn
            try:
                if has_apex and use_amp != 'native':
                    # Apex SyncBN preferred unless native amp is activated
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
                if args.local_rank == 0:
                    _logger.info(
                        'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                        'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
            except Exception as e:
                _logger.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')
        if has_apex and use_amp != 'native':
            # Apex DDP preferred unless native amp is activated
            if args.local_rank == 0:
                _logger.info("Using NVIDIA APEX DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                _logger.info("Using native Torch DistributedDataParallel.")
            model = NativeDDP(model.cuda(), device_ids=[args.local_rank],
                              find_unused_parameters=True)  # can use device str in Torch >= 1.1
        model_without_ddp = model.module
    # NOTE: EMA model does not need to be wrapped by DDP

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        _logger.info('Scheduled epochs: {}'.format(num_epochs))

    # now config only for imnet
    data_config = resolve_data_config(vars(args), model=model, verbose=False)
    # Dispatch to the dataset-specific loader factory by name.
    # NOTE(review): eval()-based dispatch executes an attacker-controlled
    # string if --dataset is untrusted; a dict lookup would be safer.
    # NOTE(review): '_logge=_logger' below looks like a typo duplicate of the
    # '_logger=_logger' kwarg further down — verify the get_*_data signature.
    loader_train, loader_eval, mixup_active, mixup_fn = eval('get_%s_data' % args.dataset)(
        batch_size=args.batch_size,
        step=args.step,
        args=args,
        _logge=_logger,
        data_config=data_config,
        num_aug_splits=num_aug_splits,
        size=args.event_size,
        mix_up=args.mix_up,
        cut_mix=args.cut_mix,
        event_mix=args.event_mix,
        beta=args.cutmix_beta,
        prob=args.cutmix_prob,
        gaussian_n=args.gaussian_n,
        num=args.cutmix_num,
        noise=args.cutmix_noise,
        num_classes=args.num_classes,
        rand_aug=args.rand_aug,
        randaug_n=args.randaug_n,
        randaug_m=args.randaug_m,
        portion=args.train_portion,
        _logger=_logger,
    )
    # _logger.info('train_loader:\n{}\nval_loader:\n{}'.format(loader_train, loader_eval))

    # Select the train/validation loss functions.
    if args.loss_fn == 'mse':
        train_loss_fn = UnilateralMse(1.)
        validate_loss_fn = UnilateralMse(1.)
    elif args.loss_fn == 'onehot-mse':
        train_loss_fn = OnehotMse(args.num_classes)
        validate_loss_fn = OnehotMse(args.num_classes)
    else:
        if args.jsd:
            assert num_aug_splits > 1  # JSD only valid with aug splits set
            train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()
        elif mixup_active:
            # smoothing is handled with mixup target transform
            train_loss_fn = SoftTargetCrossEntropy().cuda()
        elif args.smoothing:
            train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
        else:
            train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()

    # Optional wrappers around the base losses.
    if args.loss_fn == 'mix':
        train_loss_fn = MixLoss(train_loss_fn)
        validate_loss_fn = MixLoss(validate_loss_fn)
    if args.tet_loss:
        train_loss_fn = TetLoss(train_loss_fn)
        validate_loss_fn = TetLoss(validate_loss_fn)

    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None

    if args.eval:  # evaluate the model
        # if args.distributed:
        #     raise NotImplementedError('eval not has not been verified for distributed')
        # else:
        #     load_checkpoint(model, args.eval_checkpoint, args.model_ema)
        model.eval()
        # Sweep over simulation step counts to probe accuracy vs. latency.
        for t in range(1, args.step * 3):
            # for t in range(args.step, args.step + 1):
            model.set_attr('step', t)
            val_metrics = validate(start_epoch, model, loader_eval, validate_loss_fn, args,
                                   visualize=args.visualize, spike_rate=args.spike_rate,
                                   tsne=args.tsne, conf_mat=args.conf_mat, summary_writer=summary_writer)
            print(f"[STEP:{t}], Top-1 accuracy of the model is: {val_metrics['top1']:.1f}%")
        return

    saver = None
    if args.local_rank == 0:
        # Rank 0 keeps the best checkpoints (max_history=3) and the args dump.
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(
            model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
            checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, max_history=3)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    try:  # train the model
        if args.reset_drop:
            model_without_ddp.reset_drop_path(0.0)
        for epoch in range(start_epoch, args.epochs):
            if epoch == 0 and args.reset_drop:
                model_without_ddp.reset_drop_path(args.drop_path)
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_epoch(
                epoch, model, loader_train, optimizer, train_loss_fn, args,
                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                amp_autocast=amp_autocast, loss_scaler=loss_scaler,
                model_ema=model_ema, mixup_fn=mixup_fn, summary_writer=summary_writer
            )

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    _logger.info("Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            eval_metrics = validate(epoch, model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast,
                                    visualize=args.visualize, spike_rate=args.spike_rate,
                                    tsne=args.tsne, conf_mat=args.conf_mat, summary_writer=summary_writer)

            if model_ema is not None and not args.model_ema_force_cpu:
                # When an EMA copy exists, report (and checkpoint on) its metrics.
                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                    distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
                ema_eval_metrics = validate(
                    epoch, model_ema.ema, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)',
                    visualize=args.visualize, spike_rate=args.spike_rate,
                    tsne=args.tsne, conf_mat=args.conf_mat, summary_writer=summary_writer
                )
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            update_summary(
                epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
                write_header=best_metric is None)

            # if saver is not None and epoch >= args.n_warm_up:
            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)

    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
def train_epoch(
        epoch, model, loader, optimizer, loss_fn, args,
        lr_scheduler=None, saver=None, output_dir='', amp_autocast=suppress,
        loss_scaler=None, model_ema=None, mixup_fn=None, summary_writer=None):
    """Run one training epoch and return an OrderedDict with the mean loss.

    Handles mixup disabling, optional random step counts, AMP loss scaling,
    EMA updates, per-batch TensorBoard logging and recovery checkpoints.
    """
    # Disable mixup after the configured epoch (prefetcher-side or fn-side).
    if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
        if args.prefetcher and loader.mixup_enabled:
            loader.mixup_enabled = False
        elif mixup_fn is not None:
            mixup_fn.mixup_enabled = False

    second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    losses_m = AverageMeter()
    # closses_m = AverageMeter()
    top1_m = AverageMeter()
    top5_m = AverageMeter()

    model.train()
    # t, k = adjust_surrogate_coeff(100, args.epochs)
    # model.set_attr('t', t)
    # model.set_attr('k', k)

    end = time.time()
    last_idx = len(loader) - 1
    num_updates = epoch * len(loader)
    iters_per_epoch = len(loader)
    for batch_idx, (inputs, target) in enumerate(loader):
        last_batch = batch_idx == last_idx
        if args.rand_step:
            # Randomize the simulation step count per batch for robustness.
            step = buildin_random.randint(1, args.step + 2)
            model.set_attr('step', step)
        data_time_m.update(time.time() - end)
        if not args.prefetcher or args.dataset != 'imnet':
            inputs, target = inputs.type(torch.FloatTensor).cuda(), target.cuda()
            if mixup_fn is not None:
                inputs, target = mixup_fn(inputs, target)
        if args.channels_last:
            inputs = inputs.contiguous(memory_format=torch.channels_last)
        with amp_autocast():
            output = model(inputs)
            loss = loss_fn(output, target)
        if args.tet_loss:
            # TET loss keeps the time dimension; average over it for accuracy.
            output = output.mean(0)

        # Accuracy is meaningless under mixed targets, so report zeros then.
        if not (args.cut_mix | args.mix_up | args.event_mix | (args.cutmix != 0.) | (args.mixup != 0.)):
            # print(output.shape, target.shape)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            # acc1, = accuracy(output, target)
        else:
            acc1, acc5 = torch.tensor([0.]), torch.tensor([0.])

        optimizer.zero_grad()
        if loss_scaler is not None:
            # AMP path: the scaler handles backward, clipping and the step.
            loss_scaler(
                loss, optimizer, clip_grad=args.clip_grad, parameters=model.parameters(), create_graph=second_order)
        else:
            loss.backward(create_graph=second_order)
            if args.noisy_grad != 0.:
                random_gradient(model, args.noisy_grad)
            if args.clip_grad is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
            # if args.opt == 'lamb':
            #     optimizer.step(epoch=epoch)
            # else:
            optimizer.step()

        torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)
        num_updates += 1

        batch_time_m.update(time.time() - end)
        if args.local_rank == 0:
            summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'batch/train/top1'), acc1.item(), epoch * iters_per_epoch + batch_idx)
            summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'batch/train/top5'), acc5.item(), epoch * iters_per_epoch + batch_idx)
            summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'batch/train/loss'), loss.item(), epoch * iters_per_epoch + batch_idx)
        if last_batch or batch_idx % args.log_interval == 0:
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)

            if args.distributed:
                # Average loss/accuracy across ranks before logging.
                loss = reduce_tensor(loss.data, args.world_size)
                acc1 = reduce_tensor(acc1, args.world_size)
                acc5 = reduce_tensor(acc5, args.world_size)

            losses_m.update(loss.item(), inputs.size(0))
            top1_m.update(acc1.item(), output.size(0))
            top5_m.update(acc5.item(), output.size(0))
            # closses_m.update(reduced_loss.item(), inputs.size(0))

            if args.local_rank == 0:
                # if args.distributed:
                _logger.info(
                    'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
                    'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) '
                    'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
                    'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f}) '
                    'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
                    '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                    'LR: {lr:.3e} '
                    'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
                        epoch,
                        batch_idx, len(loader),
                        100. * batch_idx / last_idx,
                        loss=losses_m,
                        top1=top1_m,
                        top5=top5_m,
                        batch_time=batch_time_m,
                        rate=inputs.size(0) * args.world_size / batch_time_m.val,
                        rate_avg=inputs.size(0) * args.world_size / batch_time_m.avg,
                        lr=lr,
                        data_time=data_time_m
                    ))

                if args.save_images and output_dir:
                    torchvision.utils.save_image(
                        inputs,
                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
                        padding=0,
                        normalize=True)

        if saver is not None and args.recovery_interval and (
                last_batch or (batch_idx + 1) % args.recovery_interval == 0):
            saver.save_recovery(epoch, batch_idx=batch_idx)

        if lr_scheduler is not None:
            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)

        end = time.time()
        # end for

    if hasattr(optimizer, 'sync_lookahead'):
        optimizer.sync_lookahead()

    if args.local_rank == 0:
        summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'epoch/train/top1'), top1_m.avg, epoch)
        summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'epoch/train/top5'), top5_m.avg, epoch)
        summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'epoch/train/loss'), losses_m.avg, epoch)

    if args.rand_step:
        # Restore the configured step count after per-batch randomization.
        model.set_attr('step', args.step)

    return OrderedDict([('loss', losses_m.avg)])
def validate(epoch, model, loader, loss_fn, args, amp_autocast=suppress,
             log_suffix='', visualize=False, spike_rate=False, tsne=False, conf_mat=False, summary_writer=None):
    """Evaluate ``model`` on ``loader`` and return OrderedDict(loss, top1).

    Optionally records firing patterns for visualization / spike-rate stats
    (single-process only) and logs per-batch metrics to TensorBoard.
    """
    batch_time_m = AverageMeter()
    losses_m = AverageMeter()
    # closses_m = AverageMeter()
    top1_m = AverageMeter()
    top5_m = AverageMeter()
    spike_m = AverageMeter()

    model.eval()

    # NOTE(review): these accumulators are unused in the visible code path —
    # presumably filled by t-SNE / confusion-matrix branches elsewhere.
    feature_vec = []
    feature_cls = []
    logits_vec = []
    labels_vec = []
    mem_vec = []

    end = time.time()
    last_idx = len(loader) - 1
    iters_per_epoch = len(loader)
    with torch.no_grad():
        for batch_idx, (inputs, target) in enumerate(loader):
            # inputs = inputs.type(torch.float64)
            last_batch = batch_idx == last_idx
            if not args.prefetcher or args.dataset != 'imnet':
                inputs = inputs.type(torch.FloatTensor).cuda()
                target = target.cuda()
            if args.channels_last:
                inputs = inputs.contiguous(memory_format=torch.channels_last)

            if not args.distributed:
                # Record firing patterns only when a diagnostic needs them.
                if (visualize or spike_rate or tsne or conf_mat or args.mem_dist) and not args.critical_loss:
                    model.set_requires_fp(True)

            with amp_autocast():
                output = model(inputs)
            if isinstance(output, (tuple, list)):
                output = output[0]

            # augmentation reduction
            reduce_factor = args.tta
            if reduce_factor > 1:
                output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
                target = target[0:target.size(0):reduce_factor]

            # print(args.rank, output.shape, target.shape, max(target))
            loss = loss_fn(output, target)
            if args.tet_loss:
                output = output.mean(0)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))

            if args.distributed:
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                acc1 = reduce_tensor(acc1, args.world_size)
                acc5 = reduce_tensor(acc5, args.world_size)
            else:
                reduced_loss = loss.data

            torch.cuda.synchronize()

            losses_m.update(reduced_loss.item(), inputs.size(0))
            top1_m.update(acc1.item(), output.size(0))
            top5_m.update(acc5.item(), output.size(0))
            # closses_m.update(closs, inputs.size(0))

            batch_time_m.update(time.time() - end)
            end = time.time()
            if args.local_rank == 0:
                summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'batch/val/top1'), acc1.item(), epoch * iters_per_epoch + batch_idx)
                summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'batch/val/top5'), acc5.item(), epoch * iters_per_epoch + batch_idx)
                summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'batch/val/loss'), loss.item(), epoch * iters_per_epoch + batch_idx)
            if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
                log_name = 'Test' + log_suffix
                if not args.distributed and spike_rate:
                    # Spike totals are per-sample; only valid single-process.
                    spike_m.update(model.get_tot_spike() / output.size(0), output.size(0))
                if not args.distributed and spike_rate:
                    _logger.info(
                        '[Spike Info]: {spike.val} ({spike.avg})'.format(
                            spike=spike_m
                        )
                    )
                if last_batch or batch_idx % args.log_interval == 0:
                    _logger.info(
                        'Eval : {} '
                        'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                        'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                        'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})'
                        'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
                            epoch,
                            batch_idx,
                            last_idx,
                            batch_time=batch_time_m,
                            loss=losses_m,
                            top1=top1_m,
                            top5=top5_m,
                        ))

    # metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])
    metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg)])

    if args.local_rank == 0:
        summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'epoch/val/top1'), top1_m.avg, epoch)
        summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'epoch/val/top5'), top5_m.avg, epoch)
        summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'epoch/val/loss'), losses_m.avg, epoch)

    return metrics
# Standard script entry guard: run training/evaluation only when executed directly.
if __name__ == '__main__':
    main()
================================================
FILE: examples/Perception_and_Learning/img_cls/bp/main_backei.py
================================================
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import time
from braincog.model_zoo.backeinet import *
import argparse
import os
import json
# Command-line interface for the BackEI training script
# (MNIST / FashionMNIST / CIFAR-10 with optional Back / EI model variants).
parser = argparse.ArgumentParser("description = train.py")
parser.add_argument('-seed', type=int, default=4150)
parser.add_argument('-epoch', type=int, default=200)
parser.add_argument('-batch_size', type=int, default=100)
parser.add_argument('-learning_rate', type=float, default=1e-3)
parser.add_argument('--dataset', type=str, default='fashion')
parser.add_argument('--simulation_len', type=int, default=20)
parser.add_argument('--Back', action='store_true', default=False)
parser.add_argument('--EI', action='store_true', default=False)
parser.add_argument('--device', type=int, default=1)
parser.add_argument('--encode-type', type=str, default='direct')
opt = parser.parse_args()

# Pin the GPU and seed both CPU and CUDA RNGs for reproducibility.
torch.cuda.set_device('cuda:%d' % opt.device)
torch.manual_seed(opt.seed)
torch.cuda.manual_seed(opt.seed)

# Per-epoch accuracy histories, written to JSON at the end of the script.
test_scores = []
train_scores = []

# Save directory name encodes dataset / seed / encoding plus enabled variants.
save_path = opt.dataset + '_' + str(opt.seed) + '_' + opt.encode_type
if opt.Back:
    save_path += '_Back'
if opt.EI:
    save_path += '_EI'
if not os.path.exists(save_path):
    os.mkdir(save_path)

# CIFAR-10 channel statistics; applied only on the CIFAR-10 branch below.
normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
if opt.dataset == 'mnist':
    train_dataset = datasets.MNIST(root='./data/mnist/', train=True, transform=transforms.ToTensor(), download=True)
    test_dataset = datasets.MNIST(root='./data/mnist/', train=False, transform=transforms.ToTensor())
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=opt.batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=opt.batch_size, shuffle=False)
elif opt.dataset == 'fashion':
    train_dataset = datasets.FashionMNIST(root='./data/fashion/', train=True, transform=transforms.ToTensor(),
                                          download=True)
    test_dataset = datasets.FashionMNIST(root='./data/fashion/', train=False, transform=transforms.ToTensor())
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=opt.batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=opt.batch_size, shuffle=False)
elif opt.dataset == 'cifar10':
    train_dataset = datasets.CIFAR10(root='./data/cifar10/', train=True, transform=transforms.Compose(
        [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(), normalize]),
                                     download=True)
    test_dataset = datasets.CIFAR10(root='./data/cifar10/', train=False,
                                    transform=transforms.Compose([transforms.ToTensor(), normalize]))
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=opt.batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=opt.batch_size, shuffle=False)

# Model, loss, optimizer and step-decay LR schedule (x0.1 every 40 epochs).
if opt.dataset == 'cifar10':
    snn = CIFARNet(step=opt.simulation_len, if_back=opt.Back, if_ei=opt.EI, encode_type=opt.encode_type)
else:
    snn = MNISTNet(step=opt.simulation_len, if_back=opt.Back, if_ei=opt.EI, data=opt.dataset, encode_type=opt.encode_type)
snn = snn.cuda()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(snn.parameters(), lr=opt.learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=40, gamma=0.1)
def train(epoch):
    """Run one training epoch over ``train_loader`` using the module-level
    ``snn`` / ``criterion`` / ``optimizer``; append the epoch's training
    accuracy (percent) to ``train_scores``.

    Fix: the one-hot target tensor previously hard-coded ``opt.batch_size``,
    which raises a size-mismatch error on a smaller final batch (any dataset
    whose length is not divisible by the batch size). It now sizes the
    tensor from the actual batch.
    """
    snn.train()
    start_time = time.time()
    total_loss = 0
    correct = 0
    total = 0
    for i, (images, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        images = images.cuda()
        outputs = snn(images)
        # One-hot encode targets for the MSE loss; use the real batch size so
        # a final partial batch does not crash scatter_.
        labels_ = torch.zeros(labels.size(0), 10).scatter_(1, labels.view(-1, 1), 1).cuda()
        loss = criterion(outputs, labels_)
        total_loss += loss.item()
        loss.backward()
        optimizer.step()
        pred = outputs.max(1)[1]
        total += labels.size(0)
        correct += (pred.cpu() == labels).sum()
        # Print running loss ~6 times per epoch.
        # NOTE(review): the interval hard-codes a 60000-sample training set
        # (MNIST/Fashion); confirm this is acceptable for CIFAR-10 (50000).
        if (i + 1) % (60000 // (opt.batch_size * 6)) == 0:
            print('Epoch: [%d/%d], Step: [%d/%d], Loss: %.4f, Time: %.2f' % (
                epoch + 1, opt.epoch, i + 1, 60000 // opt.batch_size, total_loss,
                time.time() - start_time))
            start_time = time.time()
            total_loss = 0
    acc = 100.0 * correct.item() / total
    train_scores.append(acc)
def eval(epoch):
    """Evaluate ``snn`` on the test set, record the accuracy in
    ``test_scores`` and save the full model whenever a new best is reached.

    Returns the best test accuracy (percent) observed so far.
    """
    snn.eval()
    total = 0
    correct = 0
    with torch.no_grad():
        for images, labels in test_loader:
            predictions = snn(images.cuda()).max(1)[1]
            total += labels.size(0)
            correct += (predictions.cpu() == labels).sum()
    acc = 100.0 * correct.item() / total
    print('Test correct: %d Accuracy: %.2f%%' % (correct, acc))
    test_scores.append(acc)
    # Checkpoint the whole model whenever this epoch ties or beats the best.
    if acc >= max(test_scores):
        torch.save(snn, os.path.join(save_path, str(epoch) + '.pt'))
    return max(test_scores)
def main():
    """Train for ``opt.epoch`` epochs, evaluating and stepping the LR
    scheduler after each epoch, then print the best test accuracy.

    Fix: ``best_acc`` was only bound inside the loop, so ``opt.epoch == 0``
    raised ``NameError`` at the final print; it is now initialized first.
    """
    best_acc = 0.0  # defined up front so the summary print cannot NameError
    for epoch in range(opt.epoch):
        train(epoch)
        best_acc = eval(epoch)
        scheduler.step()
    print('Best Accuracy: %.2f%%' % (best_acc))
if __name__ == '__main__':
    main()

# Persist the per-epoch train/test accuracy curves next to the checkpoints.
# NOTE(review): these writes sit OUTSIDE the __main__ guard, so they also run
# (producing empty lists) when this module is merely imported — confirm intent.
filename = "train.json"
filename = os.path.join(save_path, filename)
with open(filename, "w") as f:
    json.dump(train_scores, f)
filename = "test.json"
filename = os.path.join(save_path, filename)
with open(filename, "w") as f:
    json.dump(test_scores, f)
================================================
FILE: examples/Perception_and_Learning/img_cls/bp/main_simplified.py
================================================
# encoding: utf-8
# Author : Floyed
# Datetime : 2022/4/28 14:56
# User : Floyed
# Product : PyCharm
# Project : braincog
# File : main_simplified.py
# explain : Simplified training script. Remove support for DDP, IMAGENET, Augment, etc.
import argparse
import time
import timm.models
import yaml
import os
import logging
from collections import OrderedDict
from contextlib import suppress
from datetime import datetime
from braincog.base.node.node import *
from braincog.utils import *
from braincog.base.utils.criterions import *
from braincog.datasets.datasets import *
from braincog.model_zoo.resnet import *
from braincog.model_zoo.convnet import *
from braincog.utils import save_feature_map
import torch
import torch.nn as nn
import torchvision.utils
from timm.data import ImageDataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
from timm.models import load_checkpoint, create_model, resume_checkpoint, convert_splitbn_model
from timm.utils import *
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
from timm.optim import create_optimizer
from timm.scheduler import create_scheduler
from timm.utils import ApexScaler, NativeScaler
# from ptflops import get_model_complexity_info
from thop import profile, clever_format
# cuDNN autotuner: picks the fastest conv algorithms for fixed-size inputs.
torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('train')

# The first arg parser parses out only the --config argument, this argument is used to
# load a yaml file containing key-values that override the defaults for the main parser below.
# Note: `config_parser` keeps a reference to this small parser; `parser` is rebound
# to the full parser two statements later.
config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)
parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
                    help='YAML config file specifying default arguments')

parser = argparse.ArgumentParser(description='SNN Training and Evaluating')

# Model parameters
parser.add_argument('--dataset', default='cifar10', type=str)
parser.add_argument('--model', default='cifar_convnet', type=str, metavar='MODEL',
                    help='Name of model to train (default: "countception"')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='Resume full model and optimizer state from checkpoint (default: none)')
parser.add_argument('--num-classes', type=int, default=10, metavar='N',
                    help='number of label classes (default: 10)')

# Dataloader parameters
parser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',
                    help='inputs batch size for training (default: 128)')

# Optimizer parameters
parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                    help='Optimizer (default: "adamw"')
parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
                    help='Optimizer Epsilon (default: None, use opt default)')
parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                    help='Optimizer Betas (default: None, use opt default)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='Optimizer momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=0.01,
                    help='weight decay (default: 0.01 for adamw)')
parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
                    help='Clip gradient norm (default: None, no clipping)')
parser.add_argument('--adam-epoch', type=int, default=1000, help='lamb switch to adamw')

# Learning rate schedule parameters
parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
                    help='LR scheduler (default: "cosine"')
parser.add_argument('--lr', type=float, default=5e-3, metavar='LR',
                    help='learning rate (default: 0.01)')
parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                    help='learning rate noise on/off epoch percentages')
parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                    help='learning rate noise limit percent (default: 0.67)')
parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                    help='learning rate noise std-dev (default: 1.0)')
parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
                    help='learning rate cycle len multiplier (default: 1.0)')
parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
                    help='learning rate cycle limit')
parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
                    help='warmup learning rate (default: 0.0001)')
parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                    help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
parser.add_argument('--epochs', type=int, default=600, metavar='N',
                    help='number of epochs to train (default: 2)')
parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
                    help='epoch interval to decay LR')
parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
                    help='epochs to warmup LR, if scheduler supports')
parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
                    help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                    help='patience epochs for Plateau LR scheduler (default: 10')
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                    help='LR decay rate (default: 0.1)')
parser.add_argument('--power', type=int, default=1, help='power')

# Misc
parser.add_argument('--seed', type=int, default=42, metavar='S',
                    help='random seed (default: 42)')
parser.add_argument('--log-interval', type=int, default=50, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
                    help='how many batches to wait before writing recovery checkpoint')
parser.add_argument('-j', '--workers', type=int, default=8, metavar='N',
                    help='how many training processes to use (default: 1)')
parser.add_argument('--device', type=int, default=0)
parser.add_argument('--output', default='/data/floyed/braincog', type=str, metavar='PATH',
                    help='path to output folder (default: none, current dir)')
parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
                    help='Best metric (default: "top1"')

# Spike parameters
parser.add_argument('--step', type=int, default=10, help='Simulation time step (default: 10)')
parser.add_argument('--encode', type=str, default='direct', help='Input encode method (default: direct)')

# neuron type
parser.add_argument('--node-type', type=str, default='PLIFNode', help='Node type in network (default: PLIF)')
parser.add_argument('--act-fun', type=str, default='AtanGrad',
                    help='Surogate Function in node. Only for Surrogate nodes (default: AtanGrad)')
parser.add_argument('--thresh', type=float, default=.5, help='Firing threshold (default: 0.5)')
parser.add_argument('--tau', type=float, default=2., help='Attenuation coefficient (default: 2.)')
parser.add_argument('--loss-fn', type=str, default='ce', help='loss function (default: ce)')
parser.add_argument('--noisy-grad', type=float, default=0.,
                    help='Add noise to backward, sometime will make higher accuracy (default: 0.)')
parser.add_argument('--n_warm_up', type=int, default=0,
                    help='Warm up epoch, replace all node to ReLU to warm up weights in network before (default: 0)')
parser.add_argument('--spike-output', action='store_true', default=False,
                    help='Using mem output or spike output (default: False)')

# EventData Augmentation
parser.add_argument('--mix-up', action='store_true', help='Mix-up for event data (default: False)')
parser.add_argument('--cut-mix', action='store_true', help='CutMix for event data (default: False)')
parser.add_argument('--event-mix', action='store_true', help='EventMix for event data (default: False)')
parser.add_argument('--cutmix_beta', type=float, default=1.0, help='cutmix_beta (default: 1.)')
parser.add_argument('--cutmix_prob', type=float, default=0.5, help='cutmix_prib for event data (default: .5)')
parser.add_argument('--cutmix_num', type=int, default=1, help='cutmix_num for event data (default: 1)')
parser.add_argument('--cutmix_noise', type=float, default=0.,
                    help='Add Pepper noise after mix, sometimes work (default: 0.)')
parser.add_argument('--rand-aug', action='store_true',
                    help='Rand Augment for Event data (default: False)')
parser.add_argument('--randaug_n', type=int, default=3,
                    help='Rand Augment times n (default: 3)')
parser.add_argument('--randaug_m', type=int, default=15,
                    help='Rand Augment times n (default: 15) (0-30)')
parser.add_argument('--temporal-flatten', action='store_true',
                    help='Temporal flatten to channels. ONLY FOR EVENT DATA TRAINING BY ANN')
parser.add_argument('--train-portion', type=float, default=0.9,
                    help='Dataset portion, only for datasets which do not have validation set (default: 0.9)')
parser.add_argument('--event-size', default=48, type=int,
                    help='Event size. Resize event data before process (default: 48)')
parser.add_argument('--layer-by-layer', action='store_true',
                    help='forward step-by-step or layer-by-layer. '
                         'Larger Model with layer-by-layer will be faster (default: False)')
parser.add_argument('--node-resume', type=str, default='',
                    help='resume weights in node for adaptive node. (default: False)')
parser.add_argument('--node-trainable', action='store_true')

# visualize
parser.add_argument('--visualize', action='store_true',
                    help='Visualize spiking map for each layer, only for validate (default: False)')
parser.add_argument('--spike-rate', action='store_true',
                    help='Print spiking rate for each layer, only for validate(default: False)')
parser.add_argument('--suffix', type=str, default='',
                    help='Add an additional suffix to the save path (default: \'\')')
def _parse_args():
    """Parse CLI arguments, letting an optional YAML config override parser defaults.

    Returns:
        (args, args_text): the parsed namespace and its YAML serialization,
        which `main()` later writes into the output directory.
    """
    known, leftover = config_parser.parse_known_args()
    if known.config:
        with open(known.config, 'r') as cfg_file:
            overrides = yaml.safe_load(cfg_file)
        # Defaults from the YAML file are installed before the real parse, so
        # explicit command-line flags still win.
        parser.set_defaults(**overrides)
    parsed = parser.parse_args(leftover)
    dump = yaml.safe_dump(parsed.__dict__, default_flow_style=False)
    return parsed, dump
def main():
    """Entry point: build the model, loaders, optimizer and scheduler from CLI
    args, then run the train/validate loop, checkpointing on the eval metric."""
    args, args_text = _parse_args()
    # args.no_spike_output = args.no_spike_output | args.cut_mix
    args.no_spike_output = True

    # Output directory: <output>/train/<timestamp>-<model>-<dataset>-<step>-<suffix>
    output_dir = ''
    output_base = args.output if args.output else './output'
    exp_name = '-'.join([
        datetime.now().strftime("%Y%m%d-%H%M%S"),
        args.model,
        args.dataset,
        str(args.step),
        args.suffix
    ])
    output_dir = get_outdir(output_base, 'train', exp_name)
    args.output_dir = output_dir
    setup_default_logging(log_path=os.path.join(output_dir, 'log.txt'))

    torch.cuda.set_device('cuda:%d' % args.device)
    torch.manual_seed(args.seed)

    # WARNING: eval() on CLI strings (here and for the data getter below) is
    # acceptable for a research script but must never see untrusted input.
    model = create_model(
        args.model,
        num_classes=args.num_classes,
        dataset=args.dataset,
        step=args.step,
        encode_type=args.encode,
        node_type=eval(args.node_type),
        threshold=args.thresh,
        tau=args.tau,
        spike_output=not args.no_spike_output,
        act_fun=args.act_fun,
        temporal_flatten=args.temporal_flatten,
        layer_by_layer=args.layer_by_layer,
    )

    # Input channel count inferred from the dataset family.
    if 'dvs' in args.dataset:
        args.channels = 2
    elif 'mnist' in args.dataset:
        args.channels = 1
    else:
        args.channels = 3

    # Linear LR scaling with batch size (reference batch size 1024).
    linear_scaled_lr = args.lr * args.batch_size / 1024.0
    args.lr = linear_scaled_lr

    model = model.cuda()
    optimizer = create_optimizer(args, model)

    # optionally resume from a checkpoint
    resume_epoch = None
    if args.resume:
        # FIX: this simplified script does not define --no-resume-opt, so the
        # original `args.no_resume_opt` raised AttributeError whenever --resume
        # was used; default to restoring the optimizer state as well.
        resume_epoch = resume_checkpoint(
            model, args.resume,
            optimizer=None if getattr(args, 'no_resume_opt', False) else optimizer)
    if args.node_resume:
        ckpt = torch.load(args.node_resume, map_location='cpu')
        model.load_node_weight(ckpt, args.node_trainable)

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        # Fast-forward the scheduler when resuming mid-run.
        lr_scheduler.step(start_epoch)
    _logger.info('Scheduled epochs: {}'.format(num_epochs))

    # Loaders resolved by dataset name: get_<dataset>_data(...).
    loader_train, loader_eval, mixup_active, mixup_fn = eval('get_%s_data' % args.dataset)(
        batch_size=args.batch_size,
        step=args.step,
        size=args.event_size,
        mix_up=args.mix_up,
        cut_mix=args.cut_mix,
        event_mix=args.event_mix,
        beta=args.cutmix_beta,
        prob=args.cutmix_prob,
        num=args.cutmix_num,
        noise=args.cutmix_noise,
        num_classes=args.num_classes,
        rand_aug=args.rand_aug,
        randaug_n=args.randaug_n,
        randaug_m=args.randaug_m,
        temporal_flatten=args.temporal_flatten,
        portion=args.train_portion)

    # Loss selection: MSE / (soft-target) cross-entropy, optionally MixLoss-wrapped.
    if args.loss_fn == 'mse':
        train_loss_fn = UnilateralMse(1.)
        validate_loss_fn = UnilateralMse(1.)
    else:
        if mixup_active:
            # smoothing is handled with mixup target transform
            train_loss_fn = SoftTargetCrossEntropy().cuda()
        else:
            train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    if args.loss_fn == 'mix':
        train_loss_fn = MixLoss(train_loss_fn)
        validate_loss_fn = MixLoss(validate_loss_fn)

    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = CheckpointSaver(
        model=model, optimizer=optimizer, args=args,
        checkpoint_dir=output_dir, recovery_dir=output_dir)
    with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
        f.write(args_text)

    try:  # train the model
        for epoch in range(start_epoch, args.epochs):
            if args.visualize or args.spike_rate:
                # Visualization / spike-rate runs are analysis-only: do one
                # validation pass and exit.
                print('start to plot feature map / calc spike rate')
                validate(model, loader_eval, validate_loss_fn, args,
                         visualize=args.visualize, spike_rate=args.spike_rate)
                exit(0)
            train_metrics = train_epoch(
                epoch, model, loader_train, optimizer, train_loss_fn, args,
                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir)
            eval_metrics = validate(model, loader_eval, validate_loss_fn, args)
            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
            update_summary(
                epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
                write_header=best_metric is None)
            if saver is not None and epoch >= args.n_warm_up:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)
    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
def train_epoch(
        epoch, model, loader, optimizer, loss_fn, args,
        lr_scheduler=None, saver=None, output_dir=''):
    """Train `model` for one epoch over `loader`.

    Returns:
        OrderedDict with the epoch's mean training loss under key 'loss'.

    Note: top-1/top-5 accuracy is only meaningful when no mixing augmentation
    is active — mixed targets are no longer class indices, so zeros are logged.
    """
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    losses_m = AverageMeter()
    top1_m = AverageMeter()
    top5_m = AverageMeter()

    model.train()
    end = time.time()
    last_idx = len(loader) - 1
    num_updates = epoch * len(loader)
    for batch_idx, (inputs, target) in enumerate(loader):
        last_batch = batch_idx == last_idx
        data_time_m.update(time.time() - end)
        inputs, target = inputs.type(torch.FloatTensor).cuda(), target.cuda()

        output = model(inputs)
        loss = loss_fn(output, target)

        if not (args.cut_mix | args.mix_up | args.event_mix):
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
        else:
            # Mixed targets: top-k accuracy is undefined, report zeros.
            acc1, acc5 = torch.tensor([0.]), torch.tensor([0.])

        losses_m.update(loss.item(), inputs.size(0))
        top1_m.update(acc1.item(), inputs.size(0))
        top5_m.update(acc5.item(), inputs.size(0))

        optimizer.zero_grad()
        loss.backward()
        if args.noisy_grad != 0.:
            random_gradient(model, args.noisy_grad)
        if args.clip_grad is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
        optimizer.step()

        num_updates += 1
        batch_time_m.update(time.time() - end)
        if last_batch or batch_idx % args.log_interval == 0:
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)
            _logger.info(
                'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
                'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) '
                'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
                'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'
                'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
                '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                'LR: {lr:.3e} '
                'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
                    epoch,
                    batch_idx, len(loader),
                    # FIX: guard against ZeroDivisionError when the loader holds
                    # a single batch (last_idx == 0).
                    100. * batch_idx / max(last_idx, 1),
                    loss=losses_m,
                    top1=top1_m, top5=top5_m,
                    batch_time=batch_time_m,
                    rate=inputs.size(0) / batch_time_m.val,
                    rate_avg=inputs.size(0) / batch_time_m.avg,
                    lr=lr,
                    data_time=data_time_m))

        if saver is not None and args.recovery_interval and (
                last_batch or (batch_idx + 1) % args.recovery_interval == 0):
            saver.save_recovery(epoch, batch_idx=batch_idx)
        if lr_scheduler is not None:
            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
        end = time.time()
        # end for

    if hasattr(optimizer, 'sync_lookahead'):
        optimizer.sync_lookahead()
    return OrderedDict([('loss', losses_m.avg)])
def validate(model, loader, loss_fn, args, log_suffix='', visualize=False, spike_rate=False):
    """Evaluate `model` on `loader` under no_grad.

    Returns:
        OrderedDict with 'loss', 'top1' and 'top5' epoch averages.

    With `visualize` / `spike_rate` set, the model's feature-probe hooks are
    enabled: feature maps are dumped to <output_dir>/feature_map and per-layer
    firing rates are logged.
    """
    batch_time_m = AverageMeter()
    losses_m = AverageMeter()
    top1_m = AverageMeter()
    top5_m = AverageMeter()

    model.eval()
    tic = time.time()
    final_idx = len(loader) - 1
    with torch.no_grad():
        for batch_idx, (inputs, target) in enumerate(loader):
            is_last = batch_idx == final_idx
            inputs = inputs.type(torch.FloatTensor).cuda()
            target = target.cuda()

            if visualize or spike_rate:
                # Turn on feature-map recording before the forward pass.
                model.set_requires_fp(True)

            output = model(inputs)
            if isinstance(output, (tuple, list)):
                output = output[0]

            if visualize:
                fmaps = model.get_fp()
                fmap_dir = os.path.join(args.output_dir, 'feature_map')
                if not os.path.exists(fmap_dir):
                    os.mkdir(fmap_dir)
                save_feature_map(fmaps, fmap_dir)
                model.set_requires_fp(False)
            if spike_rate:
                _logger.info(model.get_fire_rate_per_layer())
                model.set_requires_fp(False)

            loss = loss_fn(output, target)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            reduced_loss = loss.data
            torch.cuda.synchronize()

            losses_m.update(reduced_loss.item(), inputs.size(0))
            top1_m.update(acc1.item(), output.size(0))
            top5_m.update(acc5.item(), output.size(0))
            batch_time_m.update(time.time() - tic)
            tic = time.time()

            if is_last or batch_idx % args.log_interval == 0:
                log_name = 'Test' + log_suffix
                _logger.info(
                    '{0}: [{1:>4d}/{2}] '
                    'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                    'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
                    'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
                        log_name, batch_idx, final_idx, batch_time=batch_time_m,
                        loss=losses_m, top1=top1_m, top5=top5_m))

    metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])
    return metrics
# Script entry point.
if __name__ == '__main__':
    main()
================================================
FILE: examples/Perception_and_Learning/img_cls/glsnn/README.md
================================================
# SNN with global feedback connections
Trains a deep spiking neural network with global feedback connections and local optimization learning rules. The implementation differs slightly from our original paper:
GLSNN: A Multi-layer Spiking Neural Network based on Global Feedback Alignment and Local STDP Plasticity.
## Results
```shell
python cls_glsnn.py
```
We train the model for 100 epochs, and the best accuracy for MNIST is 98.23\%, for FashionMNIST is 89.68\%.

## Citation
If you find the code and dataset useful in your research, please consider citing:
```
@article{zhao2020glsnn,
title={GLSNN: A Multi-Layer Spiking Neural Network Based on Global Feedback Alignment and Local STDP Plasticity},
author={Zhao, Dongcheng and Zeng, Yi and Zhang, Tielin and Shi, Mengting and Zhao, Feifei},
journal={Frontiers in Computational Neuroscience},
volume={14},
year={2020},
publisher={Frontiers Media SA}
}
@misc{https://doi.org/10.48550/arxiv.2207.08533,
doi = {10.48550/ARXIV.2207.08533},
url = {https://arxiv.org/abs/2207.08533},
author = {Zeng, Yi and Zhao, Dongcheng and Zhao, Feifei and Shen, Guobin and Dong, Yiting and Lu, Enmeng and Zhang, Qian and Sun, Yinqian and Liang, Qian and Zhao, Yuxuan and Zhao, Zhuoya and Fang, Hongjian and Wang, Yuwei and Li, Yang and Liu, Xin and Du, Chengcheng and Kong, Qingqun and Ruan, Zizhe and Bi, Weida},
title = {BrainCog: A Spiking Neural Network based Brain-inspired Cognitive Intelligence Engine for Brain-inspired AI and Brain Simulation},
publisher = {arXiv},
year = {2022},
}
```
## Contents
Feedback and comments are welcome! Feel free to contact us at [zhaodongcheng2016@ia.ac.cn](mailto:zhaodongcheng2016@ia.ac.cn)
Enjoy!
================================================
FILE: examples/Perception_and_Learning/img_cls/glsnn/cls_glsnn.py
================================================
import torch
from torchvision import transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from braincog.model_zoo.glsnn import BaseGLSNN
import argparse
import time
import os
import json
# Pin the visible GPU and pick the compute device.
os.environ['CUDA_VISIBLE_DEVICES'] = "3"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

parser = argparse.ArgumentParser("description = GLSNN.py")
parser.add_argument('-seed', type=int, default=2122)
parser.add_argument('-epoch', type=int, default=100)
parser.add_argument('-batch_size', type=int, default=100)
parser.add_argument('-lr_target', type=float, default=0.4)
parser.add_argument('-lr_forward', type=float, default=0.001)
parser.add_argument('-step', type=int, default=10)
parser.add_argument('-encode_type', type=str, default='direct')
parser.add_argument('--dataset', type=str, default='MNIST')
opt = parser.parse_args()

torch.manual_seed(opt.seed)
torch.cuda.manual_seed(opt.seed)

# Per-epoch accuracy curves, dumped to json at the end of the run.
test_scores = []
train_scores = []
save_path = './' + 'GLSNN' + '_' + opt.dataset + '_' + str(opt.seed)
if not os.path.exists(save_path):
    os.mkdir(save_path)

if opt.dataset == 'MNIST':
    train_dataset = datasets.MNIST(root='./data/datasets/mnist/', train=True, transform=transforms.ToTensor(), download=True)
    test_dataset = datasets.MNIST(root='./data/datasets/mnist/', train=False, transform=transforms.ToTensor())
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=opt.batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=opt.batch_size, shuffle=False)
elif opt.dataset == 'Fashion-MNIST':
    train_dataset = datasets.FashionMNIST(root='./data/fashion/', train=True, transform=transforms.ToTensor(),
                                          download=True)
    test_dataset = datasets.FashionMNIST(root='./data/fashion/', train=False, transform=transforms.ToTensor())
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=opt.batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=opt.batch_size, shuffle=False)
else:
    # FIX: previously an unknown --dataset silently left the loaders undefined,
    # producing a confusing NameError later; fail fast instead.
    raise ValueError('Unsupported dataset: %s (expected MNIST or Fashion-MNIST)' % opt.dataset)

snn = BaseGLSNN(input_size=784, hidden_sizes=[800] * 3, output_size=10, opt=opt)
snn.to(device)
optimizer = torch.optim.Adam(snn.forward_parameters(), lr=opt.lr_forward)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=40, gamma=0.1)
def train(epoch):
    """Run one GLSNN training epoch and append the epoch accuracy to train_scores."""
    snn.train()
    start_time = time.time()
    total_loss = 0
    correct = 0
    total = 0
    for i, (images, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        images = images.to(device)
        # FIX: size the one-hot targets by the actual batch — the original used
        # opt.batch_size and would crash on a partial final batch.
        labels_ = torch.zeros(labels.size(0), 10).scatter_(1, labels.view(-1, 1), 1).to(device)
        labels = labels.to(device)
        # GLSNN computes its own (feedback-alignment) gradients; the optimizer
        # only applies them.
        outputs, loss = snn.set_gradient(images, labels_)
        optimizer.step()
        total_loss += loss.item()
        pred = outputs[-1].max(1)[1]
        total += labels.size(0)
        correct += (pred.cpu() == labels.cpu()).sum()
        # Log roughly six times per epoch.
        if (i + 1) % (60000 // (opt.batch_size * 6)) == 0:
            print('Epoch: [%d/%d], Step: [%d/%d], Loss: %.4f, Time: %.2f' % (
                epoch + 1, opt.epoch, i + 1, 60000 // opt.batch_size, total_loss,
                time.time() - start_time))
            start_time = time.time()
            total_loss = 0
    acc = 100.0 * correct.item() / total
    train_scores.append(acc)
def eval(epoch):
    """Evaluate on the test set; checkpoint when this epoch is the best so far.

    Appends the epoch accuracy to test_scores and returns the best accuracy seen.
    """
    snn.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            outputs = snn(images)
            predictions = outputs[-1].max(1)[1]
            n_seen += labels.size(0)
            n_correct += (predictions.cpu() == labels).sum()
    acc = 100.0 * n_correct.item() / n_seen
    print('Test correct: %d Accuracy: %.2f%%' % (n_correct, acc))
    test_scores.append(acc)
    # acc was just appended, so this holds exactly when acc is the best so far.
    if acc >= max(test_scores):
        save_file = str(epoch) + '.pt'
        torch.save(snn, os.path.join(save_path, save_file))
    return max(test_scores)
def main():
    """Run the full training schedule: train and evaluate every epoch, stepping the LR scheduler."""
    for epoch in range(opt.epoch):
        train(epoch)
        # eval() appends this epoch's accuracy to test_scores and returns the best so far.
        best_acc = eval(epoch)
        scheduler.step()
        # NOTE(review): printed every epoch; placement inside the loop is inferred from
        # statement order — confirm against the original formatting.
        print('Best Accuracy: %.2f%%' % (best_acc))


if __name__ == '__main__':
    main()
    # Persist the per-epoch accuracy curves next to the checkpoints.
    filename = "train.json"
    filename = os.path.join(save_path, filename)
    with open(filename, "w") as f:
        json.dump(train_scores, f)
    filename = "test.json"
    filename = os.path.join(save_path, filename)
    with open(filename, "w") as f:
        json.dump(test_scores, f)
================================================
FILE: examples/Perception_and_Learning/img_cls/spiking_capsnet/README.md
================================================
# Spiking capsnet: A spiking neural network with a biologically plausible routing rule between capsules
## Run
```shell
python main.py
```
## Citation
If you find the code and dataset useful in your research, please consider citing:
```
@article{zhao2022spiking,
title={Spiking capsnet: A spiking neural network with a biologically plausible routing rule between capsules},
author={Zhao, Dongcheng and Li, Yang and Zeng, Yi and Wang, Jihang and Zhang, Qian},
journal={Information Sciences},
volume={610},
pages={1--13},
year={2022},
publisher={Elsevier}
}
@misc{https://doi.org/10.48550/arxiv.2207.08533,
doi = {10.48550/ARXIV.2207.08533},
url = {https://arxiv.org/abs/2207.08533},
author = {Zeng, Yi and Zhao, Dongcheng and Zhao, Feifei and Shen, Guobin and Dong, Yiting and Lu, Enmeng and Zhang, Qian and Sun, Yinqian and Liang, Qian and Zhao, Yuxuan and Zhao, Zhuoya and Fang, Hongjian and Wang, Yuwei and Li, Yang and Liu, Xin and Du, Chengcheng and Kong, Qingqun and Ruan, Zizhe and Bi, Weida},
title = {BrainCog: A Spiking Neural Network based Brain-inspired Cognitive Intelligence Engine for Brain-inspired AI and Brain Simulation},
publisher = {arXiv},
year = {2022},
}
```
## Contents
Feedback and comments are welcome! Feel free to contact us at [zhaodongcheng2016@ia.ac.cn](mailto:zhaodongcheng2016@ia.ac.cn)
Enjoy!
================================================
FILE: examples/Perception_and_Learning/img_cls/spiking_capsnet/spikingcaps.py
================================================
import sys
sys.path.append('../../../../')
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim import Adam
import os
import math
from tqdm import tqdm
import numpy as np
from braincog.datasets.datasets import get_mnist_data
from braincog.base.node import LIFNode
from braincog.utils import setup_seed
# Fix all RNG seeds for reproducibility and pin the visible GPU.
setup_seed(1111)
os.environ['CUDA_VISIBLE_DEVICES'] = "4"
class myLIFnode(LIFNode):
    """LIF neuron variant with a modified membrane update.

    The potential is decayed by 1/tau and then the full input current is added,
    instead of the standard `(inputs - mem) / tau` increment of the base class.
    """

    def __init__(self, threshold=0.5, tau=2., *args, **kwargs):
        super().__init__(threshold, tau, *args, **kwargs)

    def integral(self, inputs):
        # Leaky integration: decay the previous potential, then accumulate input.
        decayed = self.mem / self.tau
        self.mem = decayed + inputs
class ConvLayer(nn.Module):
    """Plain Conv2d + ReLU front-end (stride 1, no padding)."""

    def __init__(self, in_channels=1, out_channels=256, kernel_size=9):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=1,
        )

    def forward(self, x):
        features = self.conv(x)
        return F.relu(features)
class PrimaryCaps(nn.Module):
    """Primary capsule layer.

    Applies `num_capsules` parallel stride-2 convolutions, stacks their outputs
    and flattens to (batch, 32 * 6 * 6, num_capsules) for the digit-capsule layer.
    """

    def __init__(self, num_capsules=8, in_channels=256, out_channels=32, kernel_size=9):
        super(PrimaryCaps, self).__init__()
        self.capsules = nn.ModuleList([
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=2, padding=0)
            for _ in range(num_capsules)])

    def forward(self, x):
        u = [capsule(x) for capsule in self.capsules]
        u = torch.stack(u, dim=1)
        # NOTE(review): the original `u.permute(0, 2, 3, 4, 1)` discarded its
        # result (Tensor.permute is not in-place), so it was a no-op and has
        # been removed. If moving the capsule axis last is actually intended,
        # it must be `u = u.permute(0, 2, 3, 4, 1).contiguous()` before the
        # view below — but that would change the numerical grouping, so it is
        # deliberately NOT applied here.
        u = u.view(x.size(0), 32 * 6 * 6, -1)
        return u
class DigitCaps(nn.Module):
    """Digit capsule layer.

    Projects each primary-capsule vector into every output capsule's space with
    a learned transform W plus a shared bias, producing the prediction vectors
    u_hat of shape (batch, num_routes, num_capsules, out_channels, 1).
    """

    def __init__(self, num_capsules=10, num_routes=32 * 6 * 6, in_channels=8, out_channels=16):
        super(DigitCaps, self).__init__()
        self.in_channels = in_channels
        self.num_routes = num_routes
        self.num_capsules = num_capsules
        self.W = nn.Parameter(torch.randn(1, num_routes, num_capsules, out_channels, in_channels))
        self.bias = nn.Parameter(torch.randn(out_channels, 1))
        # Xavier-style scale shared by weight and bias.
        scale = math.sqrt(3.0 / (in_channels * out_channels))
        self.W.data.normal_(0, scale)
        self.bias.data.normal_(0, scale)

    def forward(self, x):
        n = x.size(0)
        # One copy of each primary vector per output capsule:
        # (n, routes, caps, in, 1).
        tiled = x.unsqueeze(2).repeat(1, 1, self.num_capsules, 1).unsqueeze(4)
        weight = self.W.repeat(n, 1, 1, 1, 1)
        u_hat = torch.matmul(weight, tiled) + self.bias
        return u_hat
class DigitCaps2(nn.Module):
    """Non-iterative routing: weights the predictions u_hat by coupling
    coefficients b_ij that are updated externally (see CapsNet.forward).

    NOTE(review): relies on the module-level globals ``device`` and
    ``batch_size`` defined under ``__main__`` — confirm before reusing this
    class elsewhere.
    """

    def __init__(self, num_capsules=10, num_routes=32 * 6 * 6):
        super(DigitCaps2, self).__init__()
        self.num_routes = num_routes
        self.num_capsules = num_capsules
        # Uniform initial coupling: 1/1152 for each of the 32*6*6 routes.
        self.b_ij = Variable(torch.ones(1, self.num_routes, self.num_capsules, 1)/1152)
        self.b_ij = self.b_ij.to(device)

    def forward(self, u_hat):
        # Weight each prediction by its coupling coefficient and sum over routes.
        c_ij = torch.cat([self.b_ij] * batch_size, dim=0).unsqueeze(4)
        s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)
        return s_j.squeeze(1)

    def init_bij(self):
        # Reset couplings to the uniform prior before a new simulation window.
        self.b_ij = Variable(torch.ones(1, self.num_routes, self.num_capsules, 1)/1152)
        self.b_ij = self.b_ij.to(device)
class Decoder(nn.Module):
    """Reads out class scores as the L2 norm of each capsule's output vector."""

    def __init__(self):
        super(Decoder, self).__init__()
        # Kept for checkpoint compatibility; the linear read-out below is disabled.
        self.linear = nn.Linear(16, 1)

    def forward(self, x):
        squared = x ** 2
        classes = squared.sum(2).sqrt()
        # classes = self.linear(x)
        return classes
class CapsNet(nn.Module):
    """Spiking CapsNet: conv front-end, primary/digit capsules and a norm decoder,
    with one LIF neuron per stage and coupling coefficients b_ij updated by an
    STDP-like trace rule instead of iterative dynamic routing."""

    def __init__(self):
        super(CapsNet, self).__init__()
        self.conv_layer = ConvLayer()
        self.primary_capsules = PrimaryCaps()
        self.digit_capsules = DigitCaps()
        self.digit_capsules2 = DigitCaps2()
        self.decoder = Decoder()
        # One spiking node per stage; tau=5 slows the membrane decay.
        self.conv_node = myLIFnode(tau=5)
        self.primary_node = myLIFnode(tau=5)
        self.digit_node = myLIFnode(tau=5)
        self.digit2_node = myLIFnode(tau=5)

    def forward(self, data, time_window=5, train=True):
        self.init()
        out_mem = 0.
        self.digit_capsules2.init_bij()
        # Eligibility trace of digit-capsule (pre-synaptic) spikes.
        # NOTE(review): uses the module-level globals ``batch_size`` and
        # ``device`` defined under ``__main__`` — confirm before reuse.
        self.trace_u = torch.zeros(batch_size, 1152, 10, 16, 1, device=device)
        for step in range(time_window):
            x = data
            x = self.conv_node(self.conv_layer(x))
            x = self.primary_node(self.primary_capsules(x))
            x1 = self.digit_node(self.digit_capsules(x))
            x = self.digit_capsules2(x1)
            # Accumulate membrane potential over the window (rate read-out).
            out_mem += x.squeeze(3)
            y = self.digit2_node(x)
            if train:
                with torch.no_grad():
                    self.digit_capsules2.b_ij = torch.clamp(self.digit_capsules2.b_ij, -0.05, 1)
                    # Decay the trace, then set it to 1 wherever a capsule spiked.
                    self.trace_u *= torch.exp(-1 / torch.tensor(1.5))
                    self.trace_u.masked_fill_(x1 != 0, 1)
                    # STDP-like update of the coupling coefficients, averaged
                    # over the batch.
                    self.digit_capsules2.b_ij += 0.0008 * torch.matmul(
                        self.trace_u.transpose(3, 4) - 0.1,
                        torch.stack([y] * 1152, dim=1)).squeeze(4).mean(dim=0, keepdim=True)
        output = out_mem / time_window
        output = self.decoder(output)
        return output

    def init(self):
        # Reset all neuron membrane states before a new simulation window.
        self.conv_node.n_reset()
        self.primary_node.n_reset()
        self.digit_node.n_reset()
        self.digit2_node.n_reset()
def evaluate(test_iter, net, device):
    """Compute top-1 accuracy of `net` over `test_iter`.

    Targets are one-hot encoded (10 classes); the net is left in train mode on
    return, matching how the training loop calls it.
    """
    net.eval()
    test_acc, n_test = 0.0, 0
    # FIX: run inference under no_grad — the original built autograd graphs for
    # every test batch (it also carried an unused `test_loss` accumulator,
    # removed here).
    with torch.no_grad():
        for batch_id, (data, target) in tqdm(enumerate(test_iter)):
            # torch.eye replaces the deprecated torch.sparse.torch.eye alias.
            target = torch.eye(10).index_select(dim=0, index=target)
            data, target = data.to(device), target.to(device)
            classes = net(data)
            test_acc += sum(np.argmax(classes.data.cpu().numpy(), 1) == np.argmax(target.data.cpu().numpy(), 1))
            n_test += data.shape[0]
    net.train()
    return test_acc / n_test
if __name__ == '__main__':
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    batch_size = 100
    train_loader, test_loader, _, _ = get_mnist_data(batch_size)
    capsule_net = CapsNet().to(device)
    optimizer = Adam(capsule_net.parameters(), lr=0.0005)
    loss_fn = nn.MSELoss()
    n_epochs = 50
    best, losses = 0, []
    for epoch in range(n_epochs):
        # Step-wise LR decay (x0.3) at fixed epochs.
        if epoch in [15, 25, 45]:
            optimizer.param_groups[0]['lr'] *= 0.3
        capsule_net.train()
        train_loss, correct, n = 0, 0, 0
        loss_rec = []
        for batch_id, (data, target) in enumerate(train_loader):
            # One-hot encode the class targets for the MSE loss.
            target = torch.sparse.torch.eye(10).index_select(dim=0, index=target)
            data, target = Variable(data), Variable(target)
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            classes = capsule_net(data)
            loss = loss_fn(classes, target)
            loss.backward()
            loss_rec.append(loss.item())
            optimizer.step()
            train_loss += loss.item()
            correct += sum(np.argmax(classes.data.cpu().numpy(), 1) == np.argmax(target.data.cpu().numpy(), 1))
            n += data.shape[0]
            if batch_id % 100 == 0:
                print("Epoch: {}, Batch: {}, train accuracy: {:.6f}, loss: {:.6f}".format(
                    epoch, batch_id + 1,
                    sum(np.argmax(classes.data.cpu().numpy(), 1) == np.argmax(target.data.cpu().numpy(), 1)) / float(batch_size),
                    loss.item()))
        losses.append(np.mean(np.array(loss_rec)))
        print("Epoch: [{}/{}], train accuracy: {:.6f}, loss: {:.6f}".format(
            epoch, n_epochs,
            correct / float(n),
            train_loss / len(train_loader)))
        capsule_net.eval()
        test_acc = evaluate(test_loader, capsule_net, device=device)
        print("test accuracy: {:.6f}".format(test_acc))
        # Track (but do not persist) the best test accuracy.
        if test_acc > best:
            best = test_acc
            # torch.save(capsule_net, './checkpoints/spikingcaps_mnist.pkl')
================================================
FILE: examples/Perception_and_Learning/img_cls/transfer_for_dvs/GradCAM_visualization.py
================================================
# -*- coding: utf-8 -*-
# Time : 2023/2/14 11:52
# Author : Regulus
# FileName: main_visual_losslandscape.py
# Explain:
# Software: PyCharm
import sys
import tqdm
from loss_landscape.plot_surface import *
from Pytorch_Grad_Cam.cam import *
import argparse
import math
import time
import CKA
import numpy
import timm.models
import random as rd
import yaml
import os
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d import proj3d
import logging
from collections import OrderedDict
from contextlib import suppress
from datetime import datetime
from braincog.base.node.node import *
from braincog.utils import *
from braincog.base.utils.criterions import *
from braincog.datasets.datasets import *
from braincog.model_zoo.resnet import *
from braincog.model_zoo.convnet import *
from braincog.model_zoo.vgg_snn import VGG_SNN
from braincog.model_zoo.resnet19_snn import resnet19
from braincog.utils import save_feature_map, setup_seed
from braincog.base.utils.visualization import plot_tsne_3d, plot_tsne, plot_confusion_matrix
import torch
import torch.nn as nn
import torchvision.utils
from torch.nn.parallel import DistributedDataParallel as NativeDDP
from rgb_hsv import RGB_HSV
import matplotlib.pyplot as plt
from timm.data import ImageDataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
from timm.models import load_checkpoint, create_model, resume_checkpoint, convert_splitbn_model
from timm.utils import *
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
from timm.optim import create_optimizer
from timm.scheduler import create_scheduler
from timm.utils import ApexScaler, NativeScaler
from copy import deepcopy
# Global runtime setup: enable cuDNN autotuner (fixed input sizes assumed).
torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('train')
# The first arg parser parses out only the --config argument, this argument is used to
# load a yaml file containing key-values that override the defaults for the main parser below
config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)
parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
                    help='YAML config file specifying default arguments')
# NOTE(review): rebinding ``parser`` below leaves ``config_parser`` pointing at the
# config-only parser above; the full CLI is defined on the fresh parser (timm pattern).
parser = argparse.ArgumentParser(description='SNN Training and Evaluating')
# Model parameters
parser.add_argument('--source-dataset', default='cifar10', type=str)
parser.add_argument('--target-dataset', default='dvsc10', type=str)
parser.add_argument('--model', default='cifar_convnet', type=str, metavar='MODEL',
                    help='Name of model to train (default: "countception"')
parser.add_argument('--pretrained', action='store_true', default=False,
                    help='Start with pretrained version of specified network (if avail)')
parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
                    help='Initialize model from this checkpoint (default: none)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='Resume full model and optimizer state from checkpoint (default: none)')
parser.add_argument('--eval_checkpoint', default='', type=str, metavar='PATH',
                    help='path to eval checkpoint (default: none)')
parser.add_argument('--no-resume-opt', action='store_true', default=False,
                    help='prevent resume of optimizer state when resuming model')
parser.add_argument('--num-classes', type=int, default=10, metavar='N',
                    help='number of label classes (default: 1000)')
parser.add_argument('--gp', default=None, type=str, metavar='POOL',
                    help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
# Dataset parameters for static datasets
parser.add_argument('--img-size', type=int, default=224, metavar='N',
                    help='Image patch size (default: None => model default)')
parser.add_argument('--crop-pct', default=None, type=float,
                    metavar='N', help='inputs image center crop percent (for validation only)')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                    help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
                    help='Override std deviation of of dataset')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
                    help='Image resize interpolation type (overrides model)')
# Dataloader parameters
parser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',
                    help='inputs batch size for training (default: 128)')
parser.add_argument('-vb', '--validation-batch-size-multiplier', type=int, default=1, metavar='N',
                    help='ratio of validation batch size to training batch size (default: 1)')
# Optimizer parameters
parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                    help='Optimizer (default: "adamw"')
parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
                    help='Optimizer Epsilon (default: None, use opt default)')
parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                    help='Optimizer Betas (default: None, use opt default)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='Optimizer momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=0.01,
                    help='weight decay (default: 0.01 for adamw)')
parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
                    help='Clip gradient norm (default: None, no clipping)')
parser.add_argument('--adam-epoch', type=int, default=1000, help='lamb switch to adamw')
# Learning rate schedule parameters
parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
                    help='LR scheduler (default: "cosine"')
parser.add_argument('--lr', type=float, default=5e-3, metavar='LR',
                    help='learning rate (default: 0.01)')
parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                    help='learning rate noise on/off epoch percentages')
parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                    help='learning rate noise limit percent (default: 0.67)')
parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                    help='learning rate noise std-dev (default: 1.0)')
parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
                    help='learning rate cycle len multiplier (default: 1.0)')
parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
                    help='learning rate cycle limit')
parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
                    help='warmup learning rate (default: 0.0001)')
parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                    help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
parser.add_argument('--epochs', type=int, default=600, metavar='N',
                    help='number of epochs to train (default: 2)')
parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
                    help='epoch interval to decay LR')
parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
                    help='epochs to warmup LR, if scheduler supports')
parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
                    help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                    help='patience epochs for Plateau LR scheduler (default: 10')
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                    help='LR decay rate (default: 0.1)')
parser.add_argument('--power', type=int, default=1, help='power')
# Augmentation & regularization parameters ONLY FOR IMAGE NET
parser.add_argument('--no-aug', action='store_true', default=False,
                    help='Disable all training augmentation, override other train aug args')
parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
                    help='Random resize scale (default: 0.08 1.0)')
parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
                    help='Random resize aspect ratio (default: 0.75 1.33)')
parser.add_argument('--hflip', type=float, default=0.5,
                    help='Horizontal flip training aug probability')
parser.add_argument('--vflip', type=float, default=0.,
                    help='Vertical flip training aug probability')
parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                    help='Color jitter factor (default: 0.4)')
parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
                    help='Use AutoAugment policy. "v0" or "original". (default: None)'),
parser.add_argument('--aug-splits', type=int, default=0,
                    help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
parser.add_argument('--jsd', action='store_true', default=False,
                    help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
                    help='Random erase prob (default: 0.25)')
parser.add_argument('--remode', type=str, default='pixel',
                    help='Random erase mode (default: "const")')
parser.add_argument('--recount', type=int, default=1,
                    help='Random erase count (default: 1)')
parser.add_argument('--resplit', action='store_true', default=False,
                    help='Do not random erase first (clean) augmentation split')
parser.add_argument('--mixup', type=float, default=0.8,
                    help='mixup alpha, mixup enabled if > 0. (default: 0.)')
parser.add_argument('--cutmix', type=float, default=1.0,
                    help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')
parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
                    help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
parser.add_argument('--mixup-prob', type=float, default=1.0,
                    help='Probability of performing mixup or cutmix when either/both is enabled')
parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
                    help='Probability of switching to cutmix when both mixup and cutmix enabled')
parser.add_argument('--mixup-mode', type=str, default='batch',
                    help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
                    help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
parser.add_argument('--smoothing', type=float, default=0.1,
                    help='Label smoothing (default: 0.1)')
parser.add_argument('--train-interpolation', type=str, default='random',
                    help='Training interpolation (random, bilinear, bicubic default: "random")')
parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                    help='Dropout rate (default: 0.0)')
parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT',
                    help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',
                    help='Drop path rate (default: None)')
parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
                    help='Drop block rate (default: None)')
parser.add_argument('--newton-maxiter', default=20, type=int,
                    help='max iterration in newton method')
parser.add_argument('--reset-drop', action='store_true', default=False,
                    help='whether to reset drop')
parser.add_argument('--kernel-method', type=str, default='cuda', choices=['torch', 'cuda'],
                    help='The implementation way of gaussian kernel method, choose from "cuda" and "torch"')
# Batch norm parameters (only works with gen_efficientnet based models currently)
parser.add_argument('--bn-tf', action='store_true', default=False,
                    help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')
parser.add_argument('--bn-momentum', type=float, default=None,
                    help='BatchNorm momentum override (if not None)')
parser.add_argument('--bn-eps', type=float, default=None,
                    help='BatchNorm epsilon override (if not None)')
parser.add_argument('--sync-bn', action='store_true',
                    help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
parser.add_argument('--dist-bn', type=str, default='',
                    help='Distribute BatchNorm stats between node after each epoch ("broadcast", "reduce", or "")')
parser.add_argument('--split-bn', action='store_true',
                    help='Enable separate BN layers per augmentation split.')
# Model Exponential Moving Average
parser.add_argument('--model-ema', action='store_true', default=False,
                    help='Enable tracking moving average of model weights')
parser.add_argument('--model-ema-force-cpu', action='store_true', default=False,
                    help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
parser.add_argument('--model-ema-decay', type=float, default=0.99996,
                    help='decay factor for model weights moving average (default: 0.9998)')
# Misc
parser.add_argument('--seed', type=int, default=42, metavar='S',
                    help='random seed (default: 42)')
parser.add_argument('--log-interval', type=int, default=50, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
                    help='how many batches to wait before writing recovery checkpoint')
parser.add_argument('-j', '--workers', type=int, default=8, metavar='N',
                    help='how many training processes to use (default: 1)')
parser.add_argument('--num-gpu', type=int, default=1,
                    help='Number of GPUS to use')
parser.add_argument('--save-images', action='store_true', default=False,
                    help='save images of inputs bathes every log interval for debugging')
parser.add_argument('--amp', action='store_true', default=False,
                    help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
parser.add_argument('--apex-amp', action='store_true', default=False,
                    help='Use NVIDIA Apex AMP mixed precision')
parser.add_argument('--native-amp', action='store_true', default=False,
                    help='Use Native Torch AMP mixed precision')
parser.add_argument('--channels-last', action='store_true', default=False,
                    help='Use channels_last memory layout')
parser.add_argument('--pin-mem', action='store_true', default=False,
                    help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--no-prefetcher', action='store_true', default=False,
                    help='disable fast prefetcher')
parser.add_argument('--output', default='/home/hexiang/TransferLearning_For_DVS/Results_new_refined/', type=str, metavar='PATH',
                    help='path to output folder (default: none, current dir)')
parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
                    help='Best metric (default: "top1"')
parser.add_argument('--tta', type=int, default=0, metavar='N',
                    help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
parser.add_argument('--local_rank', default=0, type=int)
parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
                    help='use the multi-epochs-loader to save time at the beginning of every epoch')
parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
parser.add_argument('--device', type=int, default=0)
# Spike parameters
parser.add_argument('--step', type=int, default=10, help='Simulation time step (default: 10)')
parser.add_argument('--encode', type=str, default='direct', help='Input encode method (default: direct)')
parser.add_argument('--temporal-flatten', action='store_true',
                    help='Temporal flatten to channels. ONLY FOR EVENT DATA TRAINING BY ANN')
parser.add_argument('--adaptive-node', action='store_true')
parser.add_argument('--critical-loss', action='store_true')
# neuron type
parser.add_argument('--node-type', type=str, default='LIFNode', help='Node type in network (default: PLIF)')
parser.add_argument('--act-fun', type=str, default='GateGrad',
                    help='Surogate Function in node. Only for Surrogate nodes (default: AtanGrad)')
parser.add_argument('--threshold', type=float, default=.5, help='Firing threshold (default: 0.5)')
parser.add_argument('--tau', type=float, default=2., help='Attenuation coefficient (default: 2.)')
parser.add_argument('--requires-thres-grad', action='store_true')
parser.add_argument('--sigmoid-thres', action='store_true')
parser.add_argument('--loss-fn', type=str, default='ce', help='loss function (default: ce)')
parser.add_argument('--noisy-grad', type=float, default=0.,
                    help='Add noise to backward, sometime will make higher accuracy (default: 0.)')
parser.add_argument('--spike-output', action='store_true', default=False,
                    help='Using mem output or spike output (default: False)')
parser.add_argument('--n_groups', type=int, default=1)
# EventData Augmentation
parser.add_argument('--mix-up', action='store_true', help='Mix-up for event data (default: False)')
parser.add_argument('--cut-mix', action='store_true', help='CutMix for event data (default: False)')
parser.add_argument('--event-mix', action='store_true', help='EventMix for event data (default: False)')
parser.add_argument('--cutmix_beta', type=float, default=1.0, help='cutmix_beta (default: 1.)')
parser.add_argument('--cutmix_prob', type=float, default=0.5, help='cutmix_prib for event data (default: .5)')
parser.add_argument('--cutmix_num', type=int, default=1, help='cutmix_num for event data (default: 1)')
parser.add_argument('--cutmix_noise', type=float, default=0.,
                    help='Add Pepper noise after mix, sometimes work (default: 0.)')
parser.add_argument('--gaussian-n', type=int, default=3)
parser.add_argument('--rand-aug', action='store_true',
                    help='Rand Augment for Event data (default: False)')
parser.add_argument('--randaug_n', type=int, default=3,
                    help='Rand Augment times n (default: 3)')
parser.add_argument('--randaug_m', type=int, default=15,
                    help='Rand Augment times n (default: 15) (0-30)')
parser.add_argument('--train-portion', type=float, default=0.9,
                    help='Dataset portion, only for datasets which do not have validation set (default: 0.9)')
parser.add_argument('--event-size', default=48, type=int,
                    help='Event size. Resize event data before process (default: 48)')
parser.add_argument('--layer-by-layer', action='store_true',
                    help='forward step-by-step or layer-by-layer. '
                         'Larger Model with layer-by-layer will be faster (default: False)')
parser.add_argument('--node-resume', type=str, default='',
                    help='resume weights in node for adaptive node. (default: False)')
parser.add_argument('--node-trainable', action='store_true')
# visualize
parser.add_argument('--visualize', action='store_true',
                    help='Visualize spiking map for each layer, only for validate (default: False)')
parser.add_argument('--spike-rate', action='store_true',
                    help='Print spiking rate for each layer, only for validate(default: False)')
parser.add_argument('--tsne', action='store_true')
parser.add_argument('--conf-mat', action='store_true')
parser.add_argument('--suffix', type=str, default='',
                    help='Add an additional suffix to the save path (default: \'\')')
parser.add_argument('--DVS-DA', action='store_true',
                    help='use DA on DVS')
# train data used ratio
parser.add_argument('--traindata-ratio', default=1.0, type=float,
                    help='training data ratio')
# snr value
parser.add_argument('--snr', default=0, type=int,
                    help='random noise amplitude controled by snr, 0 means no noise')
parser.add_argument('--aug_smooth', action='store_true',
                    help='Apply test time augmentation to smooth the CAM')
parser.add_argument('--eigen_smooth', action='store_true', help='Reduce noise by taking the first principle componenet'
                                                                'of cam_weights*activations')
import os
import numpy as np
import torch
from torchvision import transforms
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d import proj3d
from tonic.datasets import NCALTECH101, CIFAR10DVS
import tonic
from matplotlib import rcParams
import seaborn as sns
# for matplotlib 3D
def get_proj(self):
    """
    Create the projection matrix from the current viewing position.
    elev stores the elevation angle in the z plane
    azim stores the azimuth angle in the (x, y) plane
    dist is the distance of the eye viewing point from the object point.

    NOTE(review): this function is monkey-patched onto ``Axes3D`` (see
    ``event_vis_raw``) so that a per-axes ``pbaspect`` array scales the
    three axes independently. It relies on private matplotlib internals
    (``proj3d.world_transformation``, ``proj3d.view_transformation``,
    ``self._projection``) that have changed across matplotlib releases --
    confirm the pinned matplotlib version before reuse.
    """
    # chosen for similarity with the initial view before gh-8896
    relev, razim = np.pi * self.elev / 180, np.pi * self.azim / 180
    # EDITED TO HAVE SCALED AXIS: divide each axis limit by its pbaspect entry.
    xmin, xmax = np.divide(self.get_xlim3d(), self.pbaspect[0])
    ymin, ymax = np.divide(self.get_ylim3d(), self.pbaspect[1])
    zmin, zmax = np.divide(self.get_zlim3d(), self.pbaspect[2])
    # transform to uniform world coordinates 0-1, 0-1, 0-1
    worldM = proj3d.world_transformation(xmin, xmax,
                                         ymin, ymax,
                                         zmin, zmax)
    # look into the middle of the new coordinates
    R = self.pbaspect / 2
    # Eye position in spherical coordinates around the look-at point R.
    xp = R[0] + np.cos(razim) * np.cos(relev) * self.dist
    yp = R[1] + np.sin(razim) * np.cos(relev) * self.dist
    zp = R[2] + np.sin(relev) * self.dist
    E = np.array((xp, yp, zp))
    self.eye = E
    self.vvec = R - E
    self.vvec = self.vvec / np.linalg.norm(self.vvec)
    if abs(relev) > np.pi / 2:
        # upside down
        V = np.array((0, 0, -1))
    else:
        V = np.array((0, 0, 1))
    zfront, zback = -self.dist, self.dist
    viewM = proj3d.view_transformation(E, R, V)
    projM = self._projection(zfront, zback)
    # Compose world -> view -> projection into a single matrix.
    M0 = np.dot(viewM, worldM)
    M = np.dot(projM, M0)
    return M
def event_vis_raw(x):
    """Render a raw event stream as a 3D scatter plot (time vs. pixel coords).

    ``x`` is a structured event record array; the inline comment in the
    original marked the fields as (x, y, t, p), yet the scatter below uses
    column 0 as the time axis -- TODO confirm field order against the
    tonic dataset actually passed in.

    Side effect: monkey-patches ``Axes3D.get_proj`` so the custom
    ``pbaspect`` aspect ratio set below is honored.
    """
    sns.set_style('whitegrid')
    # sns.set_palette('deep', desat=.6)
    sns.set_context("notebook", font_scale=1.5,
                    rc={"lines.linewidth": 2.5})
    Axes3D.get_proj = get_proj
    x = np.array(x.tolist())  # structured records -> plain (N, 4) ndarray; x, y, t, p
    mask = (x[:, 3] == 1)
    x_pos = x[mask]    # positive-polarity events (plotted)
    x_neg = x[~mask]   # negative-polarity events (currently unplotted); ~mask replaces mask == False
    # NOTE(review): the unused pos_idx/neg_idx subsampling arrays (np.random.choice)
    # and the commented scatter calls that consumed them were removed.
    fig = plt.figure(figsize=plt.figaspect(0.5) * 1.5)
    ax = Axes3D(fig)
    ax.pbaspect = np.array([2.0, 1.0, 0.5])
    ax.view_init(elev=10, azim=-75)
    # ax.view_init(elev=15, azim=15)
    ax.set_xlabel('t (time step)')
    ax.set_ylabel('w (pixel)')
    ax.set_zlabel('h (pixel)')
    # ax.set_xticks([])
    # ax.set_yticks([])
    # ax.set_zticks([])
    # Scale pixel coordinates onto a 48-px canvas (0.375 = 48/128; presumably a 128-px sensor).
    ax.scatter(x_pos[:, 0], 48 - x_pos[:, 1] * 0.375, 48 - x_pos[:, 2] * 0.375, color='red', alpha=0.3, s=1.)
    # ax.scatter(x_neg[:, 0], 64 - x_neg[:, 1] // 2, 128 - x_neg[:, 2], color='blue', alpha=0.3, s=1.)
    # Project the same events onto a fixed far plane at t = 18000.
    ax.scatter(18000, 48 - x_pos[:, 1] * 0.375, 48 - x_pos[:, 2] * 0.375, color='red', alpha=0.3, s=1.)
    # ax.scatter(18000, 64 - x_pos[:, 1] // 2, 128 - x_pos[:, 2], color='blue', alpha=0.3, s=1.)
def get_dataloader_ncal(step, **kwargs):
    """Return the raw CIFAR10-DVS event dataset used for visualization.

    Despite the ``ncal`` suffix, this loads ``tonic.datasets.CIFAR10DVS``;
    events are kept raw (the frame conversion / denoising transforms are
    deliberately disabled). ``DATA_DIR`` is expected to come from the
    star-imported braincog dataset module -- TODO confirm.
    """
    sensor_size = tonic.datasets.CIFAR10DVS.sensor_size  # kept for reference; not used below
    # Effectively an identity transform: DropEvent with p=0 discards nothing.
    event_transform = tonic.transforms.Compose([
        tonic.transforms.DropEvent(p=0.0),
    ])
    dataset_root = os.path.join(DATA_DIR, 'DVS/DVS_Cifar10')
    return tonic.datasets.CIFAR10DVS(dataset_root, transform=event_transform)
# Optional NVIDIA Apex support: prefer Apex AMP/DDP/sync-BN when installed.
try:
    from apex import amp
    from apex.parallel import DistributedDataParallel as ApexDDP
    from apex.parallel import convert_syncbn_model
    has_apex = True
except ImportError:
    has_apex = False
# Detect native mixed precision (torch.cuda.amp.autocast); older PyTorch
# raises AttributeError here, leaving the flag False.
has_native_amp = False
try:
    if getattr(torch.cuda.amp, 'autocast') is not None:
        has_native_amp = True
except AttributeError:
    pass
def _parse_args():
    """Parse CLI arguments, letting an optional YAML config override defaults.

    Returns a tuple ``(args, args_text)`` where ``args_text`` is the final
    argument set serialized as YAML (for saving alongside outputs).
    """
    # First pass: only look for --config on the lightweight config parser.
    config_args, remaining_argv = config_parser.parse_known_args()
    if config_args.config:
        # Values from the YAML file become the defaults of the main parser,
        # so explicit CLI flags still win.
        with open(config_args.config, 'r') as cfg_file:
            parser.set_defaults(**yaml.safe_load(cfg_file))
    # Second pass: the main parser consumes everything else.
    args = parser.parse_args(remaining_argv)
    args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
    return args, args_text
def main():
    """Build the SNN, load trained checkpoints, and dump Grad-CAM++ heatmaps
    for DVS-CIFAR10 samples (one saved figure per batch)."""
    torch.set_num_threads(20)
    os.environ["OMP_NUM_THREADS"] = "20"  # thread count for the OpenMP compute library
    os.environ["MKL_NUM_THREADS"] = "20"  # thread count for the MKL-DNN CPU acceleration library
    args, args_text = _parse_args()
    args.no_spike_output = True
    torch.cuda.set_device('cuda:%d' % args.device)
    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    if args.distributed and args.num_gpu > 1:
        _logger.warning(
            'Using more than one GPU per process in distributed mode is not allowed.Setting num_gpu to 1.')
        args.num_gpu = 1
    # args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    assert args.rank >= 0
    if args.distributed:
        _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
                     % (args.rank, args.world_size))
    else:
        _logger.info('Training with a single process on %d GPUs.' % args.num_gpu)
    # torch.manual_seed(args.seed + args.rank)
    setup_seed(args.seed + args.rank)
    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        adaptive_node=args.adaptive_node,
        dataset=args.target_dataset,
        step=args.step,
        # NOTE(review): eval() of a CLI-supplied string; node classes are expected
        # to come from the braincog star-imports.
        node_type=eval(args.node_type),
        threshold=args.threshold,
        tau=args.tau,
        sigmoid_thres=args.sigmoid_thres,
        requires_thres_grad=args.requires_thres_grad,
        spike_output=not args.no_spike_output,
        act_fun=args.act_fun,
        temporal_flatten=args.temporal_flatten,
        layer_by_layer=args.layer_by_layer,
        n_groups=args.n_groups,
    )
    # Infer input channel count from the target dataset name.
    if 'dvs' in args.target_dataset:
        args.channels = 2
    elif 'mnist' in args.target_dataset:
        args.channels = 1
    else:
        args.channels = 3
    # flops, params = profile(model, inputs=(torch.randn(1, args.channels, args.event_size, args.event_size),), verbose=False)
    # _logger.info('flops = %fM', flops / 1e6)
    # _logger.info('param size = %fM', params / 1e6)
    # Linear LR scaling rule (base batch size 1024).
    linear_scaled_lr = args.lr * args.batch_size * args.world_size / 1024.0
    args.lr = linear_scaled_lr
    _logger.info("learning rate is %f" % linear_scaled_lr)
    if args.local_rank == 0:
        _logger.info('Model %s created, param count: %d' %
                     (args.model, sum([m.numel() for m in model.parameters()])))
    # now config only for imnet
    data_config = resolve_data_config(vars(args), model=model, verbose=False)
    # source_loader_train, _, _, _ = eval('get_transfer_%s_data' % args.source_dataset)(
    #     batch_size=args.batch_size,
    #     step=args.step,
    #     args=args,
    #     _logge=_logger,
    #     data_config=data_config,
    #     size=args.event_size,
    #     mix_up=args.mix_up,
    #     cut_mix=args.cut_mix,
    #     event_mix=args.event_mix,
    #     beta=args.cutmix_beta,
    #     prob=args.cutmix_prob,
    #     gaussian_n=args.gaussian_n,
    #     num=args.cutmix_num,
    #     noise=args.cutmix_noise,
    #     num_classes=args.num_classes,
    #     rand_aug=args.rand_aug,
    #     randaug_n=args.randaug_n,
    #     randaug_m=args.randaug_m,
    #     portion=args.train_portion,
    #     _logger=_logger,
    # )
    # NOTE(review): both ``_logge`` and ``_logger`` kwargs are passed below;
    # ``_logge`` looks like a typo absorbed by the loader's **kwargs -- confirm.
    origin_loader_train, _, _, _ = eval('get_origin_dvsc10_data')(
        batch_size=args.batch_size,
        step=args.step,
        args=args,
        _logge=_logger,
        data_config=data_config,
        size=args.event_size,
        mix_up=args.mix_up,
        cut_mix=args.cut_mix,
        event_mix=args.event_mix,
        beta=args.cutmix_beta,
        prob=args.cutmix_prob,
        gaussian_n=args.gaussian_n,
        num=args.cutmix_num,
        noise=args.cutmix_noise,
        num_classes=args.num_classes,
        rand_aug=args.rand_aug,
        randaug_n=args.randaug_n,
        randaug_m=args.randaug_m,
        portion=args.train_portion,
        _logger=_logger,
    )
    # Target-domain loaders, resolved dynamically from the dataset name.
    target_loader_train, target_loader_eval, mixup_active, mixup_fn = eval('get_%s_data' % args.target_dataset)(
        batch_size=args.batch_size,
        dvs_da=args.DVS_DA,
        step=args.step,
        args=args,
        _logge=_logger,
        data_config=data_config,
        size=args.event_size,
        mix_up=args.mix_up,
        cut_mix=args.cut_mix,
        event_mix=args.event_mix,
        beta=args.cutmix_beta,
        prob=args.cutmix_prob,
        gaussian_n=args.gaussian_n,
        num=args.cutmix_num,
        noise=args.cutmix_noise,
        num_classes=args.num_classes,
        rand_aug=args.rand_aug,
        randaug_n=args.randaug_n,
        randaug_m=args.randaug_m,
        portion=args.train_portion,
        _logger=_logger,
        train_data_ratio=args.traindata_ratio,
        snr=args.snr,
        data_mode="full",
        frames_num=12,
        data_type="frequency"
    )
    # Pristine copy of the untouched model, reloaded for the second checkpoint below.
    model_before = deepcopy(model)
    if args.eval:  # evaluate the model
        if args.distributed:
            state_dict = torch.load(args.eval_checkpoint)['state_dict_ema']
            new_state_dict = OrderedDict()
            # add module prefix for DDP
            for k, v in state_dict.items():
                k = 'module.' + k
                new_state_dict[k] = v
            model.load_state_dict(new_state_dict)
        else:
            # NOTE(review): hard-coded checkpoint path (model trained WITHOUT the
            # domain/semantic transfer losses, per the directory name).
            model.load_state_dict(torch.load('/home/hexiang/TransferLearning_For_DVS/Results_lastest/train_TCKA_test/Transfer_VGG_SNN-dvsc10-10-bs_120-seed_42-DA_True-ls_0.0-lr_0.005-SNR_0-domainLoss_False-semanticLoss_False-domain_loss_coefficient1.0-semantic_loss_coefficient0.5-traindataratio_1.0-TETfirst_True-TETsecond_True/model_best.pth.tar', map_location='cpu')['state_dict'])
            # pass
            # print("no model load")
    # --------------------------------------------------------------------------
    # Show Acc
    # --------------------------------------------------------------------------
    print("load model finished!")
    # """ python cam.py -image-path
    # Example usage of loading an image, and computing:
    #     1. CAM
    #     2. Guided Back Propagation
    #     3. Combining both
    # """
    #
    # # Choose the target layer you want to compute the visualization for.
    # # Usually this will be the last convolutional layer in the model.
    # # Some common choices can be:
    # # Resnet18 and 50: model.layer4
    # # VGG, densenet161: model.features[-1]
    # # mnasnet1_0: model.layers[-1]
    # # You can print the model to help chose the layer
    # # You can pass a list with several target layers,
    # # in that case the CAMs will be computed per layer and then aggregated.
    # # You can also try selecting all layers of a certain type, with e.g:
    # # from pytorch_grad_cam.utils.find_layers import find_layer_types_recursive
    # # find_layer_types_recursive(model, [torch.nn.ReLU])
    # target_layers = [model.feature[-1]]
    #
    # if True:
    #     # inputs = 0.0
    #     # label = 0.0
    #     # for batch_idx, (inputs_tmp, label_tmp) in tqdm.tqdm(enumerate(origin_loader_train)):
    #     #     if batch_idx == choose_idx:
    #     #         inputs = inputs_tmp
    #     #         label = label_tmp
    #     #         break
    #     #     else:
    #     #         continue
    #     inputs = 0.0
    #     rgb_img = 0.0
    #
    #     #Using the with statement ensures the context is freed, and you can
    #     #recreate different CAM objects in a loop.
    #     plt.figure(figsize=(8, 6))
    #     plt.xlabel('w (pixel)')
    #     plt.ylabel('h (pixel)')
    #     cam_algorithm = GradCAMPlusPlus
    #     model = model.cuda()
    #     with cam_algorithm(model=model,
    #                        target_layers=target_layers,
    #                        use_cuda=False) as cam:
    #
    #         # AblationCAM and ScoreCAM have batched implementations.
    #         # You can override the internal batch size for faster computation.
    #         cam.batch_size = 32
    #
    #         for batch_idx, (origin_loaer, target_loader) in tqdm.tqdm(enumerate(zip(origin_loader_train, target_loader_train))):
    #
    #             twodemension_inputs, labels = origin_loaer
    #             plt.figure(figsize=(8, 6))
    #             # plt.xlabel('w (pixel)')
    #             # plt.ylabel('h (pixel)')
    #             twodemension_inputs = twodemension_inputs[0]  # (1, 10, 2, 48, 48) -> (10, 2, 48, 48)
    #             event_frame_plot_2d(twodemension_inputs)
    #
    #             inputs_tmp, label_tmp = target_loader
    #             inputs = inputs_tmp
    #             inputs = inputs.type(torch.FloatTensor).cuda()
    #
    #             grayscale_cam = cam(input_tensor=inputs,
    #                                 targets=None,
    #                                 aug_smooth=args.aug_smooth,
    #                                 eigen_smooth=args.eigen_smooth)
    #
    #             # Here grayscale_cam has only one image in the batch
    #             grayscale_cam = grayscale_cam[0, :]
    #
    #             # cam_image = show_cam_on_image(rgb_img.permute(1, 2, 0).numpy(), grayscale_cam, use_rgb=True, image_weight=0.0)
    #             cam_image = show_cam_on_image(np.ones((48, 48, 3)), grayscale_cam, use_rgb=True,
    #                                           image_weight=0.0)
    #             # # cam_image is RGB encoded whereas "cv2.imwrite" requires BGR encoding.
    #             # rgb_img = cv2.resize(rgb_img.permute(1, 2, 0).numpy(), (32, 32))
    #
    #             # cv2.imwrite(f'{args.method}_cam.jpg', cam_image)
    #             plt.ylim(bottom=0.)
    #             plt.axis('off')
    #             plt.savefig('fig/gradcam_dvspic_origin/label_{}_id_{}.jpg'.format(labels.item(), 400 + batch_idx), bbox_inches='tight', pad_inches=0)
    #             plt.imshow(cam_image, alpha=1.0)
    #             # plt.show()
    #             # plt.savefig('gradcam_pic/plot_id{}.jpg'.format(batch_idx), bbox_inches='tight')
    #             plt.savefig('fig/gradcam_dvspic_withoutloss/label_{}_id_{}.jpg'.format(labels.item(), 400 + batch_idx), bbox_inches='tight', pad_inches=0)
    # Second pass: reload the pristine model with the checkpoint trained WITH the
    # domain/semantic losses (per the directory name) and visualize its Grad-CAM.
    print("load model again!")
    model = model_before
    model.load_state_dict(torch.load(
        '/home/hexiang/TransferLearning_For_DVS/Results_lastest/train_TCKA_test/Transfer_VGG_SNN-dvsc10-10-bs_120-seed_42-DA_True-ls_0.0-lr_0.005-SNR_0-domainLoss_True-semanticLoss_True-domain_loss_coefficient1.0-semantic_loss_coefficient0.5-traindataratio_1.0-TETfirst_True-TETsecond_True/model_best.pth.tar', map_location='cpu')['state_dict'])
    """ python cam.py -image-path
    Example usage of loading an image, and computing:
        1. CAM
        2. Guided Back Propagation
        3. Combining both
    """
    # Choose the target layer you want to compute the visualization for.
    # Usually this will be the last convolutional layer in the model.
    # Some common choices can be:
    # Resnet18 and 50: model.layer4
    # VGG, densenet161: model.features[-1]
    # mnasnet1_0: model.layers[-1]
    # You can print the model to help chose the layer
    # You can pass a list with several target layers,
    # in that case the CAMs will be computed per layer and then aggregated.
    # You can also try selecting all layers of a certain type, with e.g:
    # from pytorch_grad_cam.utils.find_layers import find_layer_types_recursive
    # find_layer_types_recursive(model, [torch.nn.ReLU])
    target_layers = [model.feature[-1]]
    if True:
        # inputs = 0.0
        # label = 0.0
        # for batch_idx, (inputs_tmp, label_tmp) in tqdm.tqdm(enumerate(origin_loader_train)):
        #     if batch_idx == choose_idx:
        #         inputs = inputs_tmp
        #         label = label_tmp
        #         break
        #     else:
        #         continue
        inputs = 0.0
        rgb_img = 0.0
        #Using the with statement ensures the context is freed, and you can
        #recreate different CAM objects in a loop.
        plt.figure(figsize=(8, 6))
        plt.xlabel('w (pixel)')
        plt.ylabel('h (pixel)')
        cam_algorithm = GradCAMPlusPlus
        model = model.cuda()
        with cam_algorithm(model=model,
                           target_layers=target_layers,
                           use_cuda=False) as cam:
            # AblationCAM and ScoreCAM have batched implementations.
            # You can override the internal batch size for faster computation.
            cam.batch_size = 32
            for batch_idx, (origin_loaer, target_loader) in tqdm.tqdm(enumerate(zip(origin_loader_train, target_loader_train))):
                twodemension_inputs, labels = origin_loaer
                plt.figure(figsize=(8, 6))
                # plt.xlabel('w (pixel)')
                # plt.ylabel('h (pixel)')
                twodemension_inputs = twodemension_inputs[0]  # (1, 10, 2, 48, 48) -> (10, 2, 48, 48)
                event_frame_plot_2d(twodemension_inputs)
                inputs_tmp, label_tmp = target_loader
                inputs = inputs_tmp
                inputs = inputs.type(torch.FloatTensor).cuda()
                grayscale_cam = cam(input_tensor=inputs,
                                    targets=None,
                                    aug_smooth=args.aug_smooth,
                                    eigen_smooth=args.eigen_smooth)
                # Here grayscale_cam has only one image in the batch
                grayscale_cam = grayscale_cam[0, :]
                # cam_image = show_cam_on_image(rgb_img.permute(1, 2, 0).numpy(), grayscale_cam, use_rgb=True, image_weight=0.0)
                cam_image = show_cam_on_image(np.ones((48, 48, 3)), grayscale_cam, use_rgb=True,
                                              image_weight=0.0)
                # # cam_image is RGB encoded whereas "cv2.imwrite" requires BGR encoding.
                # rgb_img = cv2.resize(rgb_img.permute(1, 2, 0).numpy(), (32, 32))
                # cv2.imwrite(f'{args.method}_cam.jpg', cam_image)
                plt.ylim(bottom=0.)
                plt.axis('off')
                # plt.savefig('fig/gradcam_dvspic_origin/label_{}_id_{}.jpg'.format(labels.item(), batch_idx), bbox_inches='tight', pad_inches=0)
                plt.imshow(cam_image, alpha=1.0)
                # plt.show()
                # plt.savefig('gradcam_pic/plot_id{}.jpg'.format(batch_idx), bbox_inches='tight')
                plt.savefig('fig/gradcam_dvspic_withloss/label_{}_id_{}.jpg'.format(labels.item(), batch_idx), bbox_inches='tight', pad_inches=0)
def event_frame_plot_2d(event):
    """Scatter-plot a (T, 2, H, W) event tensor, overlaying one frame per time step.

    Channel 0 (positive polarity) is drawn in white, channel 1 (negative
    polarity) in blue.  Coordinates are scaled by 0.375 to map the event
    frame onto the figure used for the Grad-CAM overlay.

    :param event: event tensor of shape (T, 2, H, W); values > 0 mark spikes.
    """
    for t in range(event.shape[0]):
        pos_idx = []
        neg_idx = []
        for x in range(event.shape[2]):
            for y in range(event.shape[3]):
                if event[t, 0, x, y] > 0:
                    pos_idx.append((x, y, event[t, 0, x, y]))
                if event[t, 1, x, y] > 0:
                    # Bug fix: record the negative-channel value (the original
                    # appended channel 0 here even though channel 1 triggered).
                    neg_idx.append((x, y, event[t, 1, x, y]))
        if len(pos_idx) > 0:
            pos_x, pos_y, pos_c = np.split(np.array(pos_idx), 3, axis=1)
            # plt.scatter(48 - pos_x[:, 0] * 0.375, 48 - pos_y[:, 0] * 0.375, c='red', alpha=1, s=1)
            plt.scatter(pos_x[:, 0] * 0.375, pos_y[:, 0] * 0.375, c='white', alpha=1, s=1)
        if len(neg_idx) > 0:
            neg_x, neg_y, neg_c = np.split(np.array(neg_idx), 3, axis=1)
            # plt.scatter(48 - neg_x[:, 0] * 0.375, 48 - neg_y[:, 0] * 0.375, c='blue', alpha=1, s=1)
            plt.scatter(neg_x[:, 0] * 0.375, neg_y[:, 0] * 0.375, c='blue', alpha=1, s=1)
# Script entry point: run the Grad-CAM visualization pipeline defined above.
if __name__ == '__main__':
    main()
================================================
FILE: examples/Perception_and_Learning/img_cls/transfer_for_dvs/README.md
================================================
# Script for all experiments
## Baseline
1. CIFAR10-DVS
```shell
python main.py --model VGG_SNN --node-type LIFNode --dataset dvsc10 --step 10 --batch-size 120 --act-fun QGateGrad --device 5 --seed 42 --DVS-DA --traindata-ratio 1.0 --smoothing 0.0 --TET-loss-first --TET-loss-second
```
2. N-Caltech 101
```shell
python main.py --model VGG_SNN --node-type LIFNode --dataset NCALTECH101 --step 10 --batch-size 120 --act-fun QGateGrad --device 7 --seed 42 --num-classes 101 --traindata-ratio 1.0 --smoothing 0.0 --TET-loss-first --TET-loss-second
```
3. Omniglot
```shell
python main.py --model SCNN --node-type LIFNode --dataset nomni --step 12 --batch-size 64 --num-classes 1623 --act-fun QGateGrad --epochs 200 --device 6 --log-interval 200 --smoothing 0.0 --seed 42 --lr 0.01 --min-lr 1e-5
```
## Our Method
1. CIFAR10-DVS
```shell
python main_transfer.py --model Transfer_VGG_SNN --node-type LIFNode --source-dataset cifar10 --target-dataset dvsc10 --step 10 --batch-size 120 --act-fun QGateGrad --device 1 --seed 42 --traindata-ratio 1.0 --smoothing 0.0 --domain-loss --semantic-loss --DVS-DA --TET-loss-first --TET-loss-second
```
2. N-Caltech 101
```shell
python main_transfer.py --model Transfer_VGG_SNN --node-type LIFNode --source-dataset CALTECH101 --target-dataset NCALTECH101 --step 10 --batch-size 120 --act-fun QGateGrad --device 5 --seed 42 --num-classes 101 --traindata-ratio 1.0 --domain-loss --semantic-loss --semantic-loss-coefficient 0.001 --TET-loss-first --TET-loss-second&
```
3. N-Omniglot
```shell
python main_transfer.py --model Transfer_SCNN --node-type LIFNode --source-dataset omni --target-dataset nomni --step 12 --batch-size 64 --num-classes 1623 --act-fun QGateGrad --epochs 200 --device 6 --log-interval 200 --smoothing 0.0 --seed 42 --domain-loss --semantic-loss --semantic-loss-coefficient 0.5 --lr 0.01 --min-lr 1e-5
```
## Visualization Loss-landscape
You should first clone https://github.com/tomgoldstein/loss-landscape.
```shell
HDF5_USE_FILE_LOCKING="FALSE" mpirun -n 4 -mca btl ^openib python main_visual_losslandscape.py --model VGG_SNN --node-type LIFNode --source-dataset cifar10 --target-dataset dvsc10 --step 10 --batch-size 1000 --eval --eval_checkpoint /home/TransferLearning_For_DVS/Resultes_new_compare/Baseline/VGG_SNN-dvsc10-10-seed_42-bs_120-DA_True-ls_0.0-traindataratio_0.1-TET_first_False-TET_second_False/last.pth.tar --mpi --x=-1.0:1.0:51 --y=-1.0:1.0:51 --dir_type weights --xnorm filter --xignore biasbn --ynorm filter --yignore biasbn --plot --DVS-DA --smoothing 0.0 --traindata-ratio 0.1
```
```shell
python main_visual_losslandscape.py --model Transfer_VGG_SNN --node-type LIFNode --source-dataset CALTECH101 --target-dataset NCALTECH101 --step 10 --batch-size 500 --eval --eval_checkpoint /home/TransferLearning_For_DVS/Results_new_compare/train_TCKA_test/Transfer_VGG_SNN-NCALTECH101-10-bs_120-seed_47-DA_False-ls_0.0-lr_0.005-SNR_0-domainLoss_True-semanticLoss_True-domain_loss_coefficient1.0-semantic_loss_coefficient0.001-traindataratio_0.1-TETfirst_True-TETsecond_True/last.pth.tar --mpi --x=-1.0:1.0:51 --y=-1.0:1.0:51 --dir_type weights --xnorm filter --xignore biasbn --ynorm filter --yignore biasbn --plot --smoothing 0.0 --traindata-ratio 0.1 --num-classes 101 --device 5&
```
## Visualization Grad-cam++
You should first clone https://github.com/jacobgil/pytorch-grad-cam.
```shell
python GradCAM_visualization.py --model Transfer_VGG_SNN --node-type LIFNode --source-dataset cifar10 --target-dataset dvsc10 --step 10 --batch-size 1 --act-fun QGateGrad --device 6 --seed 42 --smoothing 0.0 --DVS-DA --eval --eval_checkpoint /home/TransferLearning_For_DVS/Results_lastest/train_TCKA_test/Transfer_VGG_SNN-dvsc10-10-bs_120-seed_42-DA_True-ls_0.0-lr_0.005-SNR_0-domainLoss_True-semanticLoss_True-domain_loss_coefficient1.0-semantic_loss_coefficient0.5-traindataratio_1.0-TETfirst_True-TETsecond_True/model_best.pth.tar
```
## Note: Dataset
In order to work with both the source- and target-domain data, the datasets file has been tailored. Please use the `datasets.py` provided here to replace and override `braincog/datasets/datasets.py` if you want to run transfer learning in this project.
## Citation
If you find the code and dataset useful in your research, please consider citing:
```
@article{he2023improving,
title={Improving the Performance of Spiking Neural Networks on Event-based Datasets with Knowledge Transfer},
author={He, Xiang and Zhao, Dongcheng and Li, Yang and Shen, Guobin and Kong, Qingqun and Zeng, Yi},
journal={arXiv preprint arXiv:2303.13077},
year={2023}
}
@misc{https://doi.org/10.48550/arxiv.2207.08533,
doi = {10.48550/ARXIV.2207.08533},
url = {https://arxiv.org/abs/2207.08533},
author = {Zeng, Yi and Zhao, Dongcheng and Zhao, Feifei and Shen, Guobin and Dong, Yiting and Lu, Enmeng and Zhang, Qian and Sun, Yinqian and Liang, Qian and Zhao, Yuxuan and Zhao, Zhuoya and Fang, Hongjian and Wang, Yuwei and Li, Yang and Liu, Xin and Du, Chengcheng and Kong, Qingqun and Ruan, Zizhe and Bi, Weida},
title = {BrainCog: A Spiking Neural Network based Brain-inspired Cognitive Intelligence Engine for Brain-inspired AI and Brain Simulation},
publisher = {arXiv},
year = {2022},
}
```
## Contact
If you are confused about using it or have other feedback and comments, please feel free to contact us via [hexiang2021@ia.ac.cn](mailto:hexiang2021@ia.ac.cn).
Have a good day!
================================================
FILE: examples/Perception_and_Learning/img_cls/transfer_for_dvs/datasets.py
================================================
import os, warnings
import torchvision.datasets
try:
import tonic
from tonic import DiskCachedDataset
except:
warnings.warn("tonic should be installed, 'pip install git+https://github.com/BrainCog-X/tonic_braincog.git'")
import torch
import torch.nn.functional as F
import torch.utils
import torchvision.datasets as datasets
from timm.data import ImageDataset, create_loader, Mixup, FastCollateMixup, AugMixDataset
from timm.data import create_transform
from einops import rearrange, repeat
from torchvision import transforms
from typing import Any, Dict, Optional, Sequence, Tuple, Union
from torch.utils.data import ConcatDataset
from braincog.datasets.NOmniglot.nomniglot_full import NOmniglotfull
from braincog.datasets.NOmniglot.nomniglot_nw_ks import NOmniglotNWayKShot
from braincog.datasets.NOmniglot.nomniglot_pair import NOmniglotTrainSet, NOmniglotTestSet
from braincog.datasets.ESimagenet.ES_imagenet import ESImagenet_Dataset
from braincog.datasets.ESimagenet.reconstructed_ES_imagenet import ESImagenet2D_Dataset
from braincog.datasets.CUB2002011 import CUB2002011
from braincog.datasets.TinyImageNet import TinyImageNet
from braincog.datasets.StanfordDogs import StanfordDogs
from random import sample
from .cut_mix import CutMix, EventMix, MixUp
from .rand_aug import *
from .utils import dvs_channel_check_expend, rescale
from PIL import Image
import cv2
import math
# Per-channel (pos/neg polarity) statistics for CIFAR10-DVS frames —
# presumably computed at 16 time bins, per the _16 suffix; confirm before reuse.
DVSCIFAR10_MEAN_16 = [0.3290, 0.4507]
DVSCIFAR10_STD_16 = [1.8398, 1.6549]
# Root directory where all datasets are stored / downloaded.
DATA_DIR = '/data/datasets'
# Default center-crop percentage (timm convention).
DEFAULT_CROP_PCT = 0.875
# Standard ImageNet normalisation statistics (torchvision convention).
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
# Inception-style normalisation (maps pixels to [-1, 1]).
IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
# DPN-paper normalisation statistics.
IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255)
IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3)
# CIFAR10 per-channel mean/std.
CIFAR10_DEFAULT_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_DEFAULT_STD = (0.2023, 0.1994, 0.2010)
class TransferSampler(torch.utils.data.sampler.Sampler):
    r"""Deterministic sampler that yields a fixed sequence of indices.

    Unlike ``SubsetRandomSampler``, samples are produced in exactly the order
    of the supplied sequence, which keeps paired source/target loaders aligned.

    Arguments:
        indices (sequence): a sequence of indices
    """

    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        # Yield the stored indices in their original order.
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)
class Transfer_DataSet(torchvision.datasets.VisionDataset):
    """In-memory dataset wrapping pre-built ``data`` / ``label`` collections.

    ``__getitem__`` forwards the index (or index mask) directly to both
    containers, so fancy indexing with a list/tensor of indices works too.
    """

    def __init__(self, data, label):
        # NOTE(review): the VisionDataset base __init__ is deliberately not
        # called, mirroring the original implementation.
        self.data = data
        self.label = label
        self.length = data.shape[0]

    def __getitem__(self, mask):
        # Index data and labels with the same mask and return them as a pair.
        return self.data[mask], self.label[mask]

    def __len__(self):
        return self.length
# Custom transform: convert images into the HSV colour space
class ConvertHSV(object):
    """Convert an RGB PIL image into the HSV colour space via OpenCV.

    (The original docstring said "compute edge gradient", which did not match
    the implementation below.)

    Args:
        None
    """

    def __init__(self):
        pass

    # transform pipelines invoke this method
    def __call__(self, img):
        """
        Args:
            img (PIL Image): input RGB PIL image
        Returns:
            PIL Image: image whose three channels hold the H, S, V components.
        """
        # PIL gives RGB; OpenCV expects BGR, so convert first, then to HSV.
        img = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
        return Image.fromarray(img.astype('uint8'))
def unpack_mix_param(args):
    """Extract mix-augmentation settings from a kwargs dict.

    Missing keys fall back to the defaults used throughout this module.

    :param args: dict possibly containing mix_up, cut_mix, event_mix, beta,
        prob, num, num_classes, noise, gaussian_n
    :return: tuple (mix_up, cut_mix, event_mix, beta, prob, num,
        num_classes, noise, gaussian_n)
    """
    # dict.get replaces the repeated "x if k in args else default" pattern.
    mix_up = args.get('mix_up', False)
    cut_mix = args.get('cut_mix', False)
    event_mix = args.get('event_mix', False)
    beta = args.get('beta', 1.)
    prob = args.get('prob', .5)
    num = args.get('num', 1)
    num_classes = args.get('num_classes', 10)
    noise = args.get('noise', 0.)
    gaussian_n = args.get('gaussian_n', None)
    return mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n
def build_transform(is_train, img_size, use_hsv=True):
    """
    Build the augmentation pipeline for static (frame-based) image data.

    :param is_train: whether to build the timm training recipe (rand-augment,
        color jitter, random erasing); otherwise a plain resize pipeline
    :param img_size: output image size
    :param use_hsv: append an RGB->HSV conversion before ToTensor
    :return: a ``transforms.Compose`` pipeline (or timm transform for training)
    """
    resize_im = img_size > 32
    if is_train:
        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=img_size,
            is_training=True,
            color_jitter=0.4,
            auto_augment='rand-m9-mstd0.5-inc1',
            interpolation='bicubic',
            re_prob=0.25,
            re_mode='pixel',
            re_count=1,
        )
        if not resize_im:
            # replace RandomResizedCropAndInterpolation with
            # RandomCrop
            transform.transforms[0] = transforms.RandomCrop(
                img_size, padding=4)
        return transform
    t = []
    # Commented-out experimental alternatives kept for reference:
    # if resize_im:
    #     size = int((256 / 224) * img_size)
    #     t.append(
    #         # to maintain same ratio w.r.t. 224 images
    #         transforms.Resize(size, interpolation=InterpolationMode.BICUBIC),
    #     )
    #     t.append(transforms.CenterCrop(img_size))
    # t.append(transforms.RandomAffine(degrees=0, translate=))
    # if Gradient:
    #     print("Used Gradient!")
    #     t.append(ComputeLaplacian())
    #     t.append(ConvertHSV())
    #     t.append(AddGaussianNoise())
    # NOTE(review): InterpolationMode is not imported explicitly in this file;
    # it presumably reaches this scope via the star imports above
    # (e.g. ``from .rand_aug import *``) — verify before refactoring imports.
    t.append(transforms.Resize((img_size, img_size), interpolation=InterpolationMode.BILINEAR))
    if use_hsv:
        print("Used V-channel!")
        t.append(ConvertHSV())
    t.append(transforms.ToTensor())
    # t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
    return transforms.Compose(t)
def build_dataset(is_train, img_size, dataset, path, same_da=False, use_hsv=True):
    """
    Build a dataset together with its augmentation pipeline.

    :param is_train: training-split flag (passed to the underlying dataset
        where applicable; the transform is intentionally always the
        evaluation-style one — see note below)
    :param img_size: output image size
    :param dataset: dataset name ('CIFAR10', 'CIFAR100' or 'CALTECH101')
    :param path: dataset root path
    :param same_da: use the test-time augmentation for the training set
    :param use_hsv: convert images to HSV
    :return: (dataset, number of classes)
    :raises NotImplementedError: for unknown dataset names
    """
    # Both branches of the original ``same_da`` conditional were identical,
    # so the flag had no effect on the transform; the deliberate choice of the
    # non-training transform (is_train=False) is preserved here.
    transform = build_transform(False, img_size, use_hsv)
    if dataset == 'CIFAR10':
        dataset = datasets.CIFAR10(
            path, train=is_train, transform=transform, download=True)
        nb_classes = 10
    elif dataset == 'CIFAR100':
        dataset = datasets.CIFAR100(
            path, train=is_train, transform=transform, download=True)
        nb_classes = 100
    elif dataset == 'CALTECH101':
        # Caltech101 has no train/test split argument in torchvision.
        dataset = datasets.Caltech101(
            path, transform=transform, download=True
        )
        nb_classes = 101
    else:
        raise NotImplementedError
    return dataset, nb_classes
class MNISTData(object):
    """
    Loader factory for the MNIST dataset.
    """

    def __init__(self,
                 data_path: str,
                 batch_size: int,
                 train_trans: Sequence[torch.nn.Module] = None,
                 test_trans: Sequence[torch.nn.Module] = None,
                 pin_memory: bool = True,
                 drop_last: bool = True,
                 shuffle: bool = True,
                 ) -> None:
        self._data_path = data_path
        self._batch_size = batch_size
        self._pin_memory = pin_memory
        self._drop_last = drop_last
        self._shuffle = shuffle
        # Wrap the supplied transform lists, if any, into Compose pipelines.
        self._train_transform = None if not train_trans else transforms.Compose(train_trans)
        self._test_transform = None if not test_trans else transforms.Compose(test_trans)

    def get_data_loaders(self):
        """Build and return (train_loader, test_loader) with current transforms."""
        print('Batch size: ', self._batch_size)
        train_set = datasets.MNIST(root=self._data_path, train=True,
                                   transform=self._train_transform, download=True)
        test_set = datasets.MNIST(root=self._data_path, train=False,
                                  transform=self._test_transform, download=True)
        train_loader = torch.utils.data.DataLoader(
            train_set, batch_size=self._batch_size, pin_memory=self._pin_memory,
            drop_last=self._drop_last, shuffle=self._shuffle)
        test_loader = torch.utils.data.DataLoader(
            test_set, batch_size=self._batch_size, pin_memory=self._pin_memory,
            drop_last=False)
        return train_loader, test_loader

    def get_standard_data(self):
        """Install the standard MNIST crop/normalise transforms, then build loaders."""
        MNIST_MEAN = 0.1307
        MNIST_STD = 0.3081
        self._train_transform = transforms.Compose([
            transforms.RandomCrop(28, padding=4),
            transforms.ToTensor(),
            transforms.Normalize((MNIST_MEAN,), (MNIST_STD,)),
        ])
        self._test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((MNIST_MEAN,), (MNIST_STD,)),
        ])
        return self.get_data_loaders()
def get_mnist_data(batch_size, num_workers=8, same_da=False, **kwargs):
    """
    Build MNIST train/test loaders.
    http://data.pymvpa.org/datasets/mnist/
    :param batch_size: batch size
    :param num_workers: number of loader worker processes
    :param same_da: use the test-time augmentation for the training set
    :param kwargs: pass ``skip_norm=True`` to skip mean/std normalisation and
        only rescale raw pixels instead
    :return: (train loader, test loader, mixup_active, mixup_fn)
    """
    MNIST_MEAN = 0.1307
    MNIST_STD = 0.3081
    if 'skip_norm' in kwargs and kwargs['skip_norm'] is True:
        # Un-normalised path — presumably used where raw pixel scale matters
        # (e.g. ANN->SNN conversion); verify against callers.
        train_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(rescale)
        ])
        test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(rescale)
        ])
    else:
        train_transform = transforms.Compose([transforms.RandomCrop(28, padding=4),
                                              # transforms.RandomRotation(10),
                                              transforms.ToTensor(),
                                              transforms.Normalize((MNIST_MEAN,), (MNIST_STD,))])
        test_transform = transforms.Compose([transforms.ToTensor(),
                                             transforms.Normalize((MNIST_MEAN,), (MNIST_STD,))])
    train_datasets = datasets.MNIST(
        root=DATA_DIR, train=True, transform=test_transform if same_da else train_transform, download=True)
    test_datasets = datasets.MNIST(
        root=DATA_DIR, train=False, transform=test_transform, download=True)
    train_loader = torch.utils.data.DataLoader(
        train_datasets, batch_size=batch_size,
        pin_memory=True, drop_last=True, shuffle=True, num_workers=num_workers
    )
    test_loader = torch.utils.data.DataLoader(
        test_datasets, batch_size=batch_size,
        pin_memory=True, drop_last=False, num_workers=num_workers
    )
    return train_loader, test_loader, False, None
def get_fashion_data(batch_size, num_workers=8, same_da=False, **kwargs):
    """
    Build Fashion-MNIST train/test loaders.
    http://arxiv.org/abs/1708.07747
    :param batch_size: batch size
    :param num_workers: number of loader worker processes
    :param same_da: use the test-time augmentation for the training set
    :param kwargs: unused
    :return: (train loader, test loader, mixup_active, mixup_fn)
    """
    test_transform = transforms.Compose([transforms.ToTensor()])
    train_transform = transforms.Compose([
        transforms.RandomCrop(28, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
    ])
    train_set = datasets.FashionMNIST(
        root=DATA_DIR, train=True,
        transform=test_transform if same_da else train_transform, download=True)
    test_set = datasets.FashionMNIST(
        root=DATA_DIR, train=False, transform=test_transform, download=True)
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, shuffle=True,
        pin_memory=True, drop_last=True, num_workers=num_workers)
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=batch_size,
        pin_memory=True, drop_last=False, num_workers=num_workers)
    return train_loader, test_loader, False, None
def get_cifar10_data(batch_size, num_workers=8, same_da=False, **kwargs):
    """
    Build CIFAR10 train/test loaders.
    https://www.cs.toronto.edu/~kriz/cifar.html
    :param batch_size: batch size
    :param num_workers: number of loader worker processes
    :param same_da: use the test-time augmentation for the training set
    :param kwargs: unused
    :return: (train loader, test loader, mixup_active, mixup_fn)
    """
    # NOTE(review): HSV conversion is deliberately disabled for plain CIFAR10
    # (use_hsv=False below).  The original computed a ``use_hsv`` flag from
    # kwargs but never used it, so that dead local was removed.
    train_datasets, _ = build_dataset(True, 32, 'CIFAR10', DATA_DIR, same_da, False)
    test_datasets, _ = build_dataset(False, 32, 'CIFAR10', DATA_DIR, same_da, False)
    train_loader = torch.utils.data.DataLoader(
        train_datasets, batch_size=batch_size,
        pin_memory=True, drop_last=True, shuffle=True,
        num_workers=num_workers
    )
    test_loader = torch.utils.data.DataLoader(
        test_datasets, batch_size=batch_size,
        pin_memory=True, drop_last=False,
        num_workers=num_workers
    )
    return train_loader, test_loader, None, None
def get_cifar100_data(batch_size, num_workers=8, same_data=False, *args, **kwargs):
    """
    Build CIFAR100 train/test loaders.
    https://www.cs.toronto.edu/~kriz/cifar.html
    :param batch_size: batch size
    :param num_workers: number of loader worker processes
    :param same_data: use the test-time augmentation for the training set
    :param kwargs: unused
    :return: (train loader, test loader, mixup_active, mixup_fn)
    """
    train_set, _ = build_dataset(True, 32, 'CIFAR100', DATA_DIR, same_data)
    test_set, _ = build_dataset(False, 32, 'CIFAR100', DATA_DIR, same_data)
    make_loader = torch.utils.data.DataLoader
    train_loader = make_loader(train_set, batch_size=batch_size, shuffle=True,
                               pin_memory=True, drop_last=True, num_workers=num_workers)
    test_loader = make_loader(test_set, batch_size=batch_size,
                              pin_memory=True, drop_last=False, num_workers=num_workers)
    return train_loader, test_loader, False, None
def get_transfer_cifar10_data(batch_size, num_workers=8, same_da=False, **kwargs):
    """
    Build a single class-sorted CIFAR10 loader for transfer learning.

    The 60k train+test images are grouped by class and flattened so samples
    of class ``i`` occupy indices ``[i*6000, (i+1)*6000)``; the loader then
    yields the whole dataset in that fixed order via ``TransferSampler``.

    :param batch_size: unused; the loader always yields all 60000 samples at once
    :param same_da: use the test-time augmentation for the training set
    :param kwargs: ``no_use_hsv`` disables the HSV conversion
    :return: (source loader, None, None, None)
    """
    use_hsv = not kwargs['no_use_hsv'] if 'no_use_hsv' in kwargs else True
    train_datasets, _ = build_dataset(True, 48, 'CIFAR10', DATA_DIR, same_da, use_hsv)
    test_datasets, _ = build_dataset(False, 48, 'CIFAR10', DATA_DIR, same_da, use_hsv)
    concat_dataset = ConcatDataset([train_datasets, test_datasets])  # concat dataset
    img_index = [[] for i in range(10)]
    label_index = [0] * 60000
    for idx, (img, label) in enumerate(concat_dataset):
        img_index[label].append(img)
    for i in range(10):
        img_index[i] = torch.stack(img_index[i], 0)
        # Bug fix: the original slice was [i*6000 : 2*i*6000], which inserted
        # extra elements for i=0, was only correct for i=1, and shrank the
        # list for i>=2.  Each class occupies a contiguous 6000-wide slot.
        label_index[i * 6000:(i + 1) * 6000] = [i] * 6000
    source_datasets = Transfer_DataSet(data=rearrange(torch.stack(img_index, dim=0), 'l b c w h -> (l b) c w h'),
                                       label=label_index)
    source_loader = torch.utils.data.DataLoader(
        source_datasets, batch_size=60000,
        sampler=TransferSampler(torch.arange(0, 60000).tolist()),
        pin_memory=True, drop_last=False, num_workers=16
    )
    return source_loader, None, None, None
def get_combined_cifar10_data(batch_size, num_workers=8, same_da=False, **kwargs):
    """
    Build one shuffled loader over the concatenated CIFAR10 train+test splits.

    :param batch_size: batch size
    :param same_da: use the test-time augmentation for the training set
    :param kwargs: ``no_use_hsv`` disables the HSV conversion
    :return: (source loader, None, None, None)
    """
    use_hsv = not kwargs.get('no_use_hsv', False)
    splits = [build_dataset(flag, 48, 'CIFAR10', DATA_DIR, same_da, use_hsv)[0]
              for flag in (True, False)]
    source_loader = torch.utils.data.DataLoader(
        ConcatDataset(splits), batch_size=batch_size, shuffle=True,
        pin_memory=True, drop_last=False, num_workers=8)
    return source_loader, None, None, None
def get_transfer_CALTECH101_data(batch_size, num_workers=8, same_da=False, **kwargs):
    """
    Build a single sequential loader over the full (static) Caltech101
    dataset for transfer learning.  (The original docstring said NCaltech101,
    but this loads the RGB Caltech101 source domain.)
    http://journal.frontiersin.org/Article/10.3389/fnins.2015.00437/abstract
    :param batch_size: unused; a fixed batch size of 10000 is used so one
        batch covers the whole dataset
    :param same_da: use the test-time augmentation for the training set
    :param kwargs: ``no_use_hsv`` disables the HSV conversion
    :return: (train loader, None, None, None)
    """
    use_hsv = not kwargs['no_use_hsv'] if 'no_use_hsv' in kwargs else True
    datasets, _ = build_dataset(False, 48, 'CALTECH101', DATA_DIR, same_da, use_hsv)
    # Hard-coded dataset length — presumably the number of usable Caltech101
    # images in this setup; confirm it matches the installed copy.
    dataset_length = 8299
    train_loader = torch.utils.data.DataLoader(
        datasets, batch_size=10000,
        sampler=TransferSampler(torch.arange(0, dataset_length).tolist()),
        pin_memory=True, drop_last=False, num_workers=4
    )
    return train_loader, None, None, None
def get_combined_CALTECH101_data(batch_size, num_workers=8, same_da=False, **kwargs):
    """
    Build a shuffled loader over the full (static) Caltech101 dataset.
    http://journal.frontiersin.org/Article/10.3389/fnins.2015.00437/abstract
    :param batch_size: batch size
    :param same_da: use the test-time augmentation for the training set
    :param kwargs: ``no_use_hsv`` disables the HSV conversion
    :return: (train loader, None, None, None)
    """
    use_hsv = not kwargs['no_use_hsv'] if 'no_use_hsv' in kwargs else True
    # Renamed the local from ``datasets`` to avoid shadowing the
    # torchvision.datasets module imported at file level; also removed the
    # unused ``dataset_length`` local.
    dataset, _ = build_dataset(False, 48, 'CALTECH101', DATA_DIR, same_da, use_hsv)
    train_loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size,
        pin_memory=True, drop_last=False,
        num_workers=4, shuffle=True
    )
    return train_loader, None, None, None
def get_TinyImageNet_data(batch_size, num_workers=8, same_da=False, *args, **kwargs):
    """
    Build TinyImageNet train/val loaders.

    :param batch_size: batch size
    :param num_workers: number of loader worker processes
    :param same_da: use the test-time augmentation for the training set
    :param kwargs: ``size`` overrides the crop size (default 224)
    :return: (train loader, test loader, mixup_active, mixup_fn)
    """
    size = kwargs.get("size", 224)
    normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose([
        transforms.Resize(size * 8 // 7),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        normalize,
    ])
    root = os.path.join(DATA_DIR, 'TinyImageNet')
    train_set = TinyImageNet(
        root=root, split="train",
        transform=test_transform if same_da else train_transform, download=True)
    test_set = TinyImageNet(
        root=root, split="val", transform=test_transform, download=True)
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, shuffle=True,
        pin_memory=True, drop_last=True, num_workers=num_workers)
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=batch_size,
        pin_memory=True, drop_last=False, num_workers=num_workers)
    return train_loader, test_loader, False, None
def get_transfer_imnet_data(args, _logger, data_config, num_aug_splits, **kwargs):
    '''
    load imagenet 2012
    we use images in train/ for training, and use images in val/ for testing
    https://github.com/pytorch/examples/tree/master/imagenet

    NOTE(review): despite the loader-style signature, this currently returns
    the raw training *dataset* (not a DataLoader); the val dataset and both
    loaders are commented out below.  ``args``, ``_logger``, ``data_config``
    and ``num_aug_splits`` are unused.
    '''
    IMAGENET_PATH = '/data/datasets/ILSVRC2012/'
    traindir = os.path.join(IMAGENET_PATH, 'train')
    valdir = os.path.join(IMAGENET_PATH, 'val')
    batch_size = kwargs['batch_size']
    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            ConvertHSV(),
            transforms.ToTensor()]))
    # val_dataset = datasets.ImageFolder(
    #     valdir,
    #     transforms.Compose([
    #         transforms.Resize(256),
    #         transforms.CenterCrop(224),
    #         ConvertHSV(),
    #         transforms.ToTensor()]))
    # train_loader = torch.utils.data.DataLoader(
    #     train_dataset,
    #     batch_size=batch_size, shuffle=False,
    #     num_workers=4, pin_memory=True, sampler=TransferSampler([0, 1300, 2599, 2600]))
    #
    # val_loader = torch.utils.data.DataLoader(
    #     val_dataset,
    #     batch_size=batch_size, shuffle=False,
    #     num_workers=4, pin_memory=True)
    return train_dataset, None, None, None
def get_dvsg_data(batch_size, step, **kwargs):
    """
    Build DVS Gesture train/test loaders.
    DOI: 10.1109/CVPR.2017.781
    :param batch_size: batch size
    :param step: number of time bins per sample (simulation steps)
    :param kwargs: size / rand_aug / randaug_n / randaug_m plus the mix-up
        options consumed by ``unpack_mix_param``
    :return: (train loader, test loader, mixup_active, mixup_fn)
    """
    sensor_size = tonic.datasets.DVSGesture.sensor_size
    size = kwargs['size'] if 'size' in kwargs else 48
    # Stage 1: raw events -> frame tensor with ``step`` time bins.
    train_transform = transforms.Compose([
        # tonic.transforms.Denoise(filter_time=10000),
        # tonic.transforms.DropEvent(p=0.1),
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step),
    ])
    test_transform = transforms.Compose([
        # tonic.transforms.Denoise(filter_time=10000),
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step),
    ])
    train_dataset = tonic.datasets.DVSGesture(os.path.join(DATA_DIR, 'DVS/DVSGesture'),
                                              transform=train_transform, train=True)
    test_dataset = tonic.datasets.DVSGesture(os.path.join(DATA_DIR, 'DVS/DVSGesture'),
                                             transform=test_transform, train=False)
    # Stage 2 (applied after disk caching): to-tensor, resize, channel check,
    # and a train-time random crop.
    train_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
        lambda x: dvs_channel_check_expend(x),
        transforms.RandomCrop(size, padding=size // 12),
        # transforms.RandomHorizontalFlip(),
        # transforms.RandomRotation(15)
    ])
    test_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
        lambda x: dvs_channel_check_expend(x),
    ])
    if 'rand_aug' in kwargs.keys():
        if kwargs['rand_aug'] is True:
            n = kwargs['randaug_n']
            m = kwargs['randaug_m']
            # Insert RandAugment right after the resize step.
            train_transform.transforms.insert(2, RandAugment(m=m, n=n))
    # if 'temporal_flatten' in kwargs.keys():
    #     if kwargs['temporal_flatten'] is True:
    #         train_transform.transforms.insert(-1, lambda x: temporal_flatten(x))
    #         test_transform.transforms.insert(-1, lambda x: temporal_flatten(x))
    # Cache framed events on disk so event->frame conversion runs only once
    # per sample; num_copies allows several augmented copies per cache entry.
    train_dataset = DiskCachedDataset(train_dataset,
                                      cache_path=os.path.join(DATA_DIR, 'DVS/DVSGesture/train_cache_{}'.format(step)),
                                      transform=train_transform, num_copies=3)
    test_dataset = DiskCachedDataset(test_dataset,
                                     cache_path=os.path.join(DATA_DIR, 'DVS/DVSGesture/test_cache_{}'.format(step)),
                                     transform=test_transform, num_copies=3)
    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)
    mixup_active = cut_mix | event_mix | mix_up
    if cut_mix:
        train_dataset = CutMix(train_dataset,
                               beta=beta,
                               prob=prob,
                               num_mix=num,
                               num_class=num_classes,
                               noise=noise)
    if event_mix:
        train_dataset = EventMix(train_dataset,
                                 beta=beta,
                                 prob=prob,
                                 num_mix=num,
                                 num_class=num_classes,
                                 noise=noise,
                                 gaussian_n=gaussian_n)
    if mix_up:
        train_dataset = MixUp(train_dataset,
                              beta=beta,
                              prob=prob,
                              num_mix=num,
                              num_class=num_classes,
                              noise=noise)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size,
        pin_memory=True, drop_last=True, num_workers=8,
        shuffle=True,
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size,
        pin_memory=True, drop_last=False, num_workers=2,
        shuffle=False,
    )
    return train_loader, test_loader, mixup_active, None
def get_dvsc10_data(batch_size, step, dvs_da=False, **kwargs):
    """
    Build CIFAR10-DVS train/test loaders.
    http://journal.frontiersin.org/article/10.3389/fnins.2017.00309/full
    :param batch_size: batch size
    :param step: number of time bins per sample (simulation steps)
    :param dvs_da: enable DVS-specific augmentation (crop/flip/rotate, plus
        additive noise when ``snr`` > 0)
    :param kwargs: size / snr / train_data_ratio / portion plus the mix-up
        options consumed by ``unpack_mix_param``
    :return: (train loader, test loader, mixup_active, mixup_fn)
    """
    size = kwargs['size'] if 'size' in kwargs else 48
    snr = kwargs['snr'] if 'snr' in kwargs else 0
    # Fraction of the per-class training split that is actually sampled.
    train_data_ratio = kwargs['train_data_ratio'] if 'train_data_ratio' in kwargs else 1.0
    sensor_size = tonic.datasets.CIFAR10DVS.sensor_size
    # Stage 1: raw events -> frame tensor with ``step`` time bins.
    train_transform = transforms.Compose([
        # tonic.transforms.Denoise(filter_time=10000),
        # tonic.transforms.DropEvent(p=0.1),
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])
    test_transform = transforms.Compose([
        # tonic.transforms.Denoise(filter_time=10000),
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])
    train_dataset = tonic.datasets.CIFAR10DVS(os.path.join(DATA_DIR, 'DVS/DVS_Cifar10'), transform=train_transform)
    test_dataset = tonic.datasets.CIFAR10DVS(os.path.join(DATA_DIR, 'DVS/DVS_Cifar10'), transform=test_transform)
    # Stage 2 (applied after disk caching): to-tensor, resize, and optional
    # augmentation depending on ``dvs_da`` / ``snr``.
    if dvs_da is True:
        print("use dvs_da")
        if snr > 0:
            # Additive Gaussian noise scaled so the signal-to-noise ratio
            # equals ``snr`` dB relative to the mean squared frame value.
            train_transform = transforms.Compose([
                lambda x: torch.tensor(x, dtype=torch.float),
                lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
                lambda x: x + torch.randn(x.shape) * math.sqrt(torch.mean(torch.pow(x, 2)) / math.pow(10, snr / 10)),
                transforms.RandomCrop(size, padding=size // 12),
                transforms.RandomHorizontalFlip(),
                transforms.RandomRotation(15)
            ])
        else:
            train_transform = transforms.Compose([
                lambda x: torch.tensor(x, dtype=torch.float),
                lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
                transforms.RandomCrop(size, padding=size // 12),
                transforms.RandomHorizontalFlip(),
                transforms.RandomRotation(15)
            ])
    else:
        train_transform = transforms.Compose([
            lambda x: torch.tensor(x, dtype=torch.float),
            lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
        ])
    test_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
    ])  # NOTE: the lambdas capture by reference; do not share one transform list between pipelines.
    train_dataset = DiskCachedDataset(train_dataset,
                                      cache_path=os.path.join(DATA_DIR, 'DVS/DVS_Cifar10/train_cache_{}'.format(step)),
                                      transform=train_transform)
    test_dataset = DiskCachedDataset(test_dataset,
                                     cache_path=os.path.join(DATA_DIR, 'DVS/DVS_Cifar10/test_cache_{}'.format(step)),
                                     transform=test_transform)
    num_train = len(train_dataset)
    num_per_cls = num_train // 10
    indices_train, indices_test = [], []
    # The first ``portion`` of each class becomes the training pool, which is
    # then randomly subsampled down to ``train_data_ratio``; the remainder of
    # each class forms the test split.
    portion = kwargs['portion'] if 'portion' in kwargs else .9
    for i in range(10):
        indices_train.extend(
            sample(list(range(i * num_per_cls, round(i * num_per_cls + num_per_cls * portion))), int(num_per_cls * portion * train_data_ratio)))
        indices_test.extend(
            list(range(round(i * num_per_cls + num_per_cls * portion), (i + 1) * num_per_cls)))
    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)
    mixup_active = cut_mix | event_mix | mix_up
    if cut_mix:
        # print('cut_mix', beta, prob, num, num_classes)
        train_dataset = CutMix(train_dataset,
                               beta=beta,
                               prob=prob,
                               num_mix=num,
                               num_class=num_classes,
                               indices=indices_train,
                               noise=noise)
    if event_mix:
        train_dataset = EventMix(train_dataset,
                                 beta=beta,
                                 prob=prob,
                                 num_mix=num,
                                 num_class=num_classes,
                                 indices=indices_train,
                                 noise=noise,
                                 gaussian_n=gaussian_n)
    if mix_up:
        train_dataset = MixUp(train_dataset,
                              beta=beta,
                              prob=prob,
                              num_mix=num,
                              num_class=num_classes,
                              indices=indices_train,
                              noise=noise)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices_train),
        pin_memory=True, drop_last=False, num_workers=8
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices_test),
        pin_memory=True, drop_last=False, num_workers=2
    )
    return train_loader, test_loader, mixup_active, None
def get_transfer_dvsc10_data(batch_size, step, dvs_da=False, **kwargs):
    """
    Build a DVS-CIFAR10 loader for transfer-learning use: the kept training
    indices are served through a ``SubsetRandomSampler`` and only 10 classes
    are assumed, with the first ``portion`` of each class used for training.

    :param batch_size: batch size
    :param step: number of simulation time steps (event frames per sample)
    :param dvs_da: unused here; kept for signature compatibility with the
        sibling ``get_*_data`` loaders
    :param kwargs: optional keys: ``size`` (spatial resolution, default 48),
        ``snr`` (read but unused below — NOTE(review): confirm whether noise
        injection was meant to be applied), ``train_data_ratio`` (fraction of
        the per-class train split randomly kept), ``portion`` (per-class
        train fraction, default 0.9), plus the mix-augmentation keys consumed
        by ``unpack_mix_param``
    :return: (train loader, test loader, mixup_active, mixup_fn=None)
    """
    size = kwargs['size'] if 'size' in kwargs else 48
    snr = kwargs['snr'] if 'snr' in kwargs else 0  # NOTE(review): never used in this function
    train_data_ratio = kwargs['train_data_ratio'] if 'train_data_ratio' in kwargs else 1.0
    sensor_size = tonic.datasets.CIFAR10DVS.sensor_size
    # Stage 1 (cached on disk below): raw event streams -> dense frame
    # stacks with `step` time bins.
    train_transform = transforms.Compose([
        # tonic.transforms.Denoise(filter_time=10000),
        # tonic.transforms.DropEvent(p=0.1),
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])
    test_transform = transforms.Compose([
        # tonic.transforms.Denoise(filter_time=10000),
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])
    train_dataset = tonic.datasets.CIFAR10DVS(os.path.join(DATA_DIR, 'DVS/DVS_Cifar10'), transform=train_transform)
    test_dataset = tonic.datasets.CIFAR10DVS(os.path.join(DATA_DIR, 'DVS/DVS_Cifar10'), transform=test_transform)
    # Stage 2 (applied after the disk cache): tensorize and bilinearly
    # resize every frame stack to `size` x `size`.
    train_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),])
    test_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
    ])  # The lambdas are captured by reference; do not share one transform list between Compose objects.
    train_dataset = DiskCachedDataset(train_dataset,
                                      cache_path=os.path.join(DATA_DIR, 'DVS/DVS_Cifar10/train_cache_{}'.format(step)),
                                      transform=train_transform)
    test_dataset = DiskCachedDataset(test_dataset,
                                     cache_path=os.path.join(DATA_DIR, 'DVS/DVS_Cifar10/test_cache_{}'.format(step)),
                                     transform=test_transform)
    num_train = len(train_dataset)
    num_per_cls = num_train // 10  # assumes 10 balanced classes stored contiguously
    indices_train, indices_test = [], []
    portion = kwargs['portion'] if 'portion' in kwargs else .9
    # Per class: first `portion` of the indices go to train, the rest to test.
    # NOTE(review): unlike the non-transfer loader, `train_data_ratio` is not
    # applied here — the whole per-class train range is used; confirm intent.
    for i in range(10):
        indices_train.extend(
            list(range(i * num_per_cls, round(i * num_per_cls + num_per_cls * portion))))
        indices_test.extend(
            list(range(round(i * num_per_cls + num_per_cls * portion), (i + 1) * num_per_cls)))
    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)
    mixup_active = cut_mix | event_mix | mix_up
    # At most one of the mix-style augmentation wrappers is expected to be on.
    if cut_mix:
        # print('cut_mix', beta, prob, num, num_classes)
        train_dataset = CutMix(train_dataset,
                               beta=beta,
                               prob=prob,
                               num_mix=num,
                               num_class=num_classes,
                               indices=indices_train,
                               noise=noise)
    if event_mix:
        train_dataset = EventMix(train_dataset,
                                 beta=beta,
                                 prob=prob,
                                 num_mix=num,
                                 num_class=num_classes,
                                 indices=indices_train,
                                 noise=noise,
                                 gaussian_n=gaussian_n)
    if mix_up:
        train_dataset = MixUp(train_dataset,
                              beta=beta,
                              prob=prob,
                              num_mix=num,
                              num_class=num_classes,
                              indices=indices_train,
                              noise=noise)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=len(indices_train),
        sampler=TransferSampler(indices_train),
        pin_memory=True, drop_last=True, num_workers=8
    )
    return train_loader, None, mixup_active, None
def get_NCALTECH101_data(batch_size, step, dvs_da=False, **kwargs):
    """
    Get the N-Caltech101 event-camera data loaders.
    http://journal.frontiersin.org/Article/10.3389/fnins.2015.00437/abstract

    :param batch_size: batch size
    :param step: number of simulation time steps (event frames per sample)
    :param dvs_da: if True, use the stronger DVS augmentation pipeline
        (random crop + horizontal flip + 15-degree rotation) for training
    :param kwargs: optional keys: ``portion`` (per-class train fraction,
        default 0.9), ``size`` (spatial resolution, default 48), ``snr``
        (additive Gaussian noise SNR in dB; 0 disables), ``train_data_ratio``
        (fraction of the train split randomly kept), plus the
        mix-augmentation keys consumed by ``unpack_mix_param``
    :return: (train loader, test loader, mixup_active, mixup_fn=None)
    """
    sensor_size = tonic.datasets.NCALTECH101.sensor_size
    cls_count = tonic.datasets.NCALTECH101.cls_count      # number of samples per class
    dataset_length = tonic.datasets.NCALTECH101.length    # total number of samples
    portion = kwargs['portion'] if 'portion' in kwargs else .9
    size = kwargs['size'] if 'size' in kwargs else 48
    snr = kwargs['snr'] if 'snr' in kwargs else 0
    train_data_ratio = kwargs['train_data_ratio'] if 'train_data_ratio' in kwargs else 1.0
    # print('portion', portion)
    # N-Caltech101 is class-imbalanced; the train sampler draws each kept
    # training sample with weight inversely proportional to its class size,
    # while test samples and dropped train samples get weight 0.
    train_sample_weight = []
    train_sample_index = []
    train_count = 0
    test_sample_index = []
    idx_begin = 0
    for count in cls_count:
        sample_weight = dataset_length / count
        train_sample = round(portion * count)   # first `portion` of the class -> train
        test_sample = count - train_sample      # remainder -> test
        train_count += int(train_sample * train_data_ratio)
        # Weights for the kept training samples of this class.
        train_sample_weight.extend(
            [sample_weight] * int(train_sample * train_data_ratio)
        )
        # Zero weight for train samples dropped by `train_data_ratio` ...
        train_sample_weight.extend(
            [0.] * (train_sample - int(train_sample * train_data_ratio))
        )
        # ... and for the test samples of this class.
        train_sample_weight.extend(
            [0.] * test_sample
        )
        # Randomly keep `train_data_ratio` of this class's train indices.
        train_sample_index.extend(
            sample(list(range(idx_begin, idx_begin + train_sample)), int(train_sample * train_data_ratio))
        )
        test_sample_index.extend(
            list(range(idx_begin + train_sample, idx_begin + train_sample + test_sample))
        )
        idx_begin += count
    train_sampler = torch.utils.data.sampler.WeightedRandomSampler(train_sample_weight, train_count)
    test_sampler = torch.utils.data.sampler.SubsetRandomSampler(test_sample_index)
    # Stage 1 (cached on disk below): raw events -> dense frame stacks.
    train_transform = transforms.Compose([
        # tonic.transforms.Denoise(filter_time=10000),
        # tonic.transforms.DropEvent(p=0.1),
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])
    test_transform = transforms.Compose([
        # tonic.transforms.Denoise(filter_time=10000),
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])
    train_dataset = tonic.datasets.NCALTECH101(os.path.join(DATA_DIR, 'DVS/NCALTECH101'), transform=train_transform)
    test_dataset = tonic.datasets.NCALTECH101(os.path.join(DATA_DIR, 'DVS/NCALTECH101'), transform=test_transform)
    # Stage 2 (applied after the disk cache): tensorize, resize, then augment.
    if dvs_da is True:
        print("use dvs_da")
        train_transform = transforms.Compose([
            lambda x: torch.tensor(x, dtype=torch.float),
            lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
            transforms.RandomCrop(size, padding=size // 12),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15)
        ])
    else:
        if snr > 0:
            # Inject Gaussian noise scaled so the signal-to-noise ratio is `snr` dB.
            train_transform = transforms.Compose([
                lambda x: torch.tensor(x, dtype=torch.float),
                lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
                lambda x: x + torch.randn(x.shape) * math.sqrt(torch.mean(torch.pow(x, 2)) / math.pow(10, snr / 10)),
            ])
        else:
            train_transform = transforms.Compose([
                lambda x: torch.tensor(x, dtype=torch.float),
                lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
                transforms.RandomCrop(size, padding=size // 12),
            ])
    test_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
    ])  # The lambdas are captured by reference; do not share one transform list between Compose objects.
    train_dataset = DiskCachedDataset(train_dataset,
                                      cache_path=os.path.join(DATA_DIR, 'DVS/NCALTECH101/train_cache_{}'.format(step)),
                                      transform=train_transform, num_copies=3)
    test_dataset = DiskCachedDataset(test_dataset,
                                     cache_path=os.path.join(DATA_DIR, 'DVS/NCALTECH101/test_cache_{}'.format(step)),
                                     transform=test_transform, num_copies=3)
    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)
    mixup_active = cut_mix | event_mix | mix_up
    if cut_mix:
        train_dataset = CutMix(train_dataset,
                               beta=beta,
                               prob=prob,
                               num_mix=num,
                               num_class=num_classes,
                               indices=train_sample_index,
                               noise=noise)
    if event_mix:
        train_dataset = EventMix(train_dataset,
                                 beta=beta,
                                 prob=prob,
                                 num_mix=num,
                                 num_class=num_classes,
                                 indices=train_sample_index,
                                 noise=noise,
                                 gaussian_n=gaussian_n)
    if mix_up:
        train_dataset = MixUp(train_dataset,
                              beta=beta,
                              prob=prob,
                              num_mix=num,
                              num_class=num_classes,
                              indices=train_sample_index,
                              noise=noise)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size,
        sampler=train_sampler,
        pin_memory=True, drop_last=True, num_workers=8
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size,
        sampler=test_sampler,
        pin_memory=True, drop_last=False, num_workers=2
    )
    return train_loader, test_loader, mixup_active, None
def get_transfer_NCALTECH101_data(batch_size, step, dvs_da=False, **kwargs):
    """
    Get the N-Caltech101 data as a single-batch "transfer" loader.
    http://journal.frontiersin.org/Article/10.3389/fnins.2015.00437/abstract

    The whole (optionally sub-sampled) training split is served as one batch
    through a :class:`TransferSampler`; no test loader is produced.

    :param batch_size: nominal batch size (ignored: the loader batches the
        entire training index list at once)
    :param step: number of simulation time steps (event frames per sample)
    :param dvs_da: unused; kept for signature compatibility with the sibling
        ``get_*_data`` loaders
    :param kwargs: optional keys: ``portion`` (per-class train fraction,
        default 0.9), ``size`` (spatial resolution, default 48),
        ``train_data_ratio`` (fraction of the train split randomly kept,
        default 1.0)
    :return: (train loader, None, None, None)
    """
    sensor_size = tonic.datasets.NCALTECH101.sensor_size
    cls_count = tonic.datasets.NCALTECH101.cls_count
    portion = kwargs['portion'] if 'portion' in kwargs else .9
    size = kwargs['size'] if 'size' in kwargs else 48
    train_data_ratio = kwargs['train_data_ratio'] if 'train_data_ratio' in kwargs else 1.0
    # Per class, the first round(portion * count) contiguous indices form the
    # train split; randomly keep `train_data_ratio` of them.
    train_sample_index = []
    idx_begin = 0
    for count in cls_count:
        train_sample = round(portion * count)
        train_sample_index.extend(
            sample(list(range(idx_begin, idx_begin + train_sample)), int(train_sample * train_data_ratio))
        )
        idx_begin += count
    # Stage 1 (cached on disk below): raw events -> dense frame stacks.
    train_transform = transforms.Compose([
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])
    train_dataset = tonic.datasets.NCALTECH101(os.path.join(DATA_DIR, 'DVS/NCALTECH101'), transform=train_transform)
    # Stage 2 (applied after the disk cache): tensorize and bilinearly
    # resize every frame stack to `size` x `size`.
    train_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
    ])
    train_dataset = DiskCachedDataset(train_dataset,
                                      cache_path=os.path.join(DATA_DIR, 'DVS/NCALTECH101/train_cache_{}'.format(step)),
                                      transform=train_transform, num_copies=3)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=len(train_sample_index),
        sampler=TransferSampler(train_sample_index),
        pin_memory=True, drop_last=True, num_workers=8
    )
    # The previous implementation also built a weighted train sampler, a test
    # sampler, a test DiskCachedDataset/DataLoader and the mix-up bookkeeping,
    # none of which affected the returned values; that dead work is removed.
    return train_loader, None, None, None
def get_NCARS_data(batch_size, step, **kwargs):
    """
    Get the N-Cars event-camera data loaders.
    https://ieeexplore.ieee.org/document/8578284/

    :param batch_size: batch size
    :param step: number of simulation time steps (event frames per sample)
    :param kwargs: optional keys: ``size`` (spatial resolution, default 48),
        ``rand_aug`` with ``randaug_n``/``randaug_m`` (enable RandAugment),
        plus the mix-augmentation keys consumed by ``unpack_mix_param``
    :return: (train loader, test loader, mixup_active, mixup_fn=None)
    """
    sensor_size = tonic.datasets.NCARS.sensor_size
    size = kwargs['size'] if 'size' in kwargs else 48
    # NOTE(review): ToFrame is called with sensor_size=None (presumably
    # because N-Cars recordings vary in size), so the class-level
    # `sensor_size` read above is unused — confirm intent.
    train_transform = transforms.Compose([
        # tonic.transforms.Denoise(filter_time=10000),
        # tonic.transforms.DropEvent(p=0.1),
        tonic.transforms.ToFrame(sensor_size=None, n_time_bins=step),
    ])
    test_transform = transforms.Compose([
        # tonic.transforms.Denoise(filter_time=10000),
        tonic.transforms.ToFrame(sensor_size=None, n_time_bins=step),
    ])
    train_dataset = tonic.datasets.NCARS(os.path.join(DATA_DIR, 'DVS/NCARS'), transform=train_transform, train=True)
    test_dataset = tonic.datasets.NCARS(os.path.join(DATA_DIR, 'DVS/NCARS'), transform=test_transform, train=False)
    # Post-cache transforms: tensorize, resize, pad/normalize the channel
    # dimension, then (train only) crop/flip/rotate augmentation.
    train_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
        lambda x: dvs_channel_check_expend(x),
        transforms.RandomCrop(size, padding=size // 12),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15)
    ])
    test_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
        lambda x: dvs_channel_check_expend(x),
    ])
    # Optionally splice RandAugment in after the resize step.
    if 'rand_aug' in kwargs.keys():
        if kwargs['rand_aug'] is True:
            n = kwargs['randaug_n']
            m = kwargs['randaug_m']
            train_transform.transforms.insert(2, RandAugment(m=m, n=n))
    # if 'temporal_flatten' in kwargs.keys():
    #     if kwargs['temporal_flatten'] is True:
    #         train_transform.transforms.insert(-1, lambda x: temporal_flatten(x))
    #         test_transform.transforms.insert(-1, lambda x: temporal_flatten(x))
    train_dataset = DiskCachedDataset(train_dataset,
                                      cache_path=os.path.join(DATA_DIR, 'DVS/NCARS/train_cache_{}'.format(step)),
                                      transform=train_transform, num_copies=3)
    test_dataset = DiskCachedDataset(test_dataset,
                                     cache_path=os.path.join(DATA_DIR, 'DVS/NCARS/test_cache_{}'.format(step)),
                                     transform=test_transform, num_copies=3)
    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)
    mixup_active = cut_mix | event_mix | mix_up
    if cut_mix:
        train_dataset = CutMix(train_dataset,
                               beta=beta,
                               prob=prob,
                               num_mix=num,
                               num_class=num_classes,
                               noise=noise)
    if event_mix:
        train_dataset = EventMix(train_dataset,
                                 beta=beta,
                                 prob=prob,
                                 num_mix=num,
                                 num_class=num_classes,
                                 noise=noise,
                                 gaussian_n=gaussian_n)
    if mix_up:
        train_dataset = MixUp(train_dataset,
                              beta=beta,
                              prob=prob,
                              num_mix=num,
                              num_class=num_classes,
                              noise=noise)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size,
        pin_memory=True, drop_last=True, num_workers=8,
        shuffle=True,
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size,
        pin_memory=True, drop_last=False, num_workers=2,
        shuffle=False,
    )
    return train_loader, test_loader, mixup_active, None
def get_nomni_data(batch_size, train_portion=1., **kwargs):
    """
    Get the N-Omniglot data loaders.

    :param batch_size: batch size
    :param train_portion: unused; kept for signature compatibility
    :param kwargs: optional keys:
        ``data_mode`` — one of ``"full"``, ``"nkks"``, ``"pair"`` (default
        "full"): full dataset, N-way-K-shot episodes, or pair mode;
        ``frames_num`` — number of frames per sample (default 4);
        ``data_type`` — ``"event"`` or ``"frequency"`` (default "event");
        ``n_way`` / ``k_shot`` / ``k_query`` — required for the "nkks" mode
        (``n_way``/``k_shot`` also for "pair")
    :return: (train loader, test loader, None, None)
    :raises ValueError: if ``data_mode`` is not a supported mode
    """
    data_mode = kwargs["data_mode"] if "data_mode" in kwargs else "full"
    frames_num = kwargs["frames_num"] if "frames_num" in kwargs else 4
    data_type = kwargs["data_type"] if "data_type" in kwargs else "event"
    train_transform = transforms.Compose([
        transforms.Resize((28, 28))])
    test_transform = transforms.Compose([
        transforms.Resize((28, 28))])
    if data_mode == "full":
        train_datasets = NOmniglotfull(root=os.path.join(DATA_DIR, 'DVS/NOmniglot'), train=True, frames_num=frames_num,
                                       data_type=data_type,
                                       transform=train_transform, use_npz=True)
        test_datasets = NOmniglotfull(root=os.path.join(DATA_DIR, 'DVS/NOmniglot'), train=False, frames_num=frames_num,
                                      data_type=data_type,
                                      transform=test_transform, use_npz=True)
    elif data_mode == "nkks":
        train_datasets = NOmniglotNWayKShot(os.path.join(DATA_DIR, 'DVS/NOmniglot'),
                                            n_way=kwargs["n_way"],
                                            k_shot=kwargs["k_shot"],
                                            k_query=kwargs["k_query"],
                                            train=True,
                                            frames_num=frames_num,
                                            data_type=data_type,
                                            transform=train_transform)
        test_datasets = NOmniglotNWayKShot(os.path.join(DATA_DIR, 'DVS/NOmniglot'),
                                           n_way=kwargs["n_way"],
                                           k_shot=kwargs["k_shot"],
                                           k_query=kwargs["k_query"],
                                           train=False,
                                           frames_num=frames_num,
                                           data_type=data_type,
                                           transform=test_transform)
    elif data_mode == "pair":
        train_datasets = NOmniglotTrainSet(root=os.path.join(DATA_DIR, 'DVS/NOmniglot'), use_frame=True,
                                           frames_num=frames_num, data_type=data_type,
                                           use_npz=False, resize=105)
        test_datasets = NOmniglotTestSet(root=os.path.join(DATA_DIR, 'DVS/NOmniglot'), time=2000, way=kwargs["n_way"],
                                         shot=kwargs["k_shot"], use_frame=True,
                                         frames_num=frames_num, data_type=data_type, use_npz=False, resize=105)
    else:
        # The previous implementation silently fell through here and crashed
        # later with an UnboundLocalError; fail fast with a clear message.
        raise ValueError(
            "unsupported data_mode: {!r}; expected 'full', 'nkks' or 'pair'".format(data_mode))
    train_loader = torch.utils.data.DataLoader(
        train_datasets, batch_size=batch_size, num_workers=4,
        pin_memory=True, drop_last=True, shuffle=True
    )
    test_loader = torch.utils.data.DataLoader(
        test_datasets, batch_size=batch_size, num_workers=4,
        pin_memory=True, drop_last=False
    )
    return train_loader, test_loader, None, None
def get_transfer_omni_data(batch_size, train_portion=1., **kwargs):
    """
    Get the static (image) Omniglot data as a single transfer loader.

    The "background" and "evaluation" splits are concatenated and served by
    one loader whose :class:`TransferSampler` enumerates every index.

    :param batch_size: unused — the loader batch size is hard-coded to
        35000 (NOTE(review): confirm ignoring the argument is intentional)
    :param train_portion: unused; kept for signature compatibility
    :return: (train loader, None, None, None)
    """
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor()])
    # NOTE(review): the dataset root is hard-coded to /data/datasets/ rather
    # than derived from DATA_DIR like every other loader in this module —
    # verify before running on a different machine.
    train_dataset = datasets.Omniglot(
        root="/data/datasets/", background=True, download=True, transform=transform
    )
    test_dataset = datasets.Omniglot(
        root="/data/datasets/", background=False, download=True, transform=transform
    )
    dataset = torch.utils.data.ConcatDataset([train_dataset, test_dataset])
    dataset_length = len(dataset)
    train_loader = torch.utils.data.DataLoader(
        dataset, batch_size=35000, num_workers=12,
        pin_memory=True, drop_last=False,
        sampler=TransferSampler(torch.arange(0, dataset_length).tolist())
    )
    return train_loader, None, None, None
def get_esimnet_data(batch_size, step, **kwargs):
    """
    Get the ES-ImageNet data loaders.
    DOI: 10.3389/fnins.2021.726582

    :param batch_size: batch size
    :param step: number of simulation time steps; must be 8, or 1 when
        ``reconstruct`` is set
    :param kwargs: ``reconstruct`` — if True, use the reconstructed 2D
        variant (single time step); plus the mix-augmentation keys consumed
        by ``unpack_mix_param``
    :return: (train loader, test loader, mixup_active, mixup_fn=None)
    :note: There is no automatic download; refer to spikingjelly for the
        download locations and md5 sums. The samplers default to
        ``DistributedSampler``, so this expects an initialized process group.
    """
    reconstruct = kwargs["reconstruct"] if "reconstruct" in kwargs else False
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15)
    ])
    # NOTE(review): only the test pipeline applies dvs_channel_check_expend —
    # confirm the train pipeline is meant to skip the channel check.
    test_transform = transforms.Compose([
        lambda x: dvs_channel_check_expend(x),
    ])
    if reconstruct:
        assert step == 1
        train_dataset = ESImagenet2D_Dataset(mode='train',
                                             data_set_path=os.path.join(DATA_DIR, 'DVS/ES-imagenet-0.18/extract/ES-imagenet-0.18/'),
                                             transform=train_transform)
        test_dataset = ESImagenet2D_Dataset(mode='test',
                                            data_set_path=os.path.join(DATA_DIR, 'DVS/ES-imagenet-0.18/extract/ES-imagenet-0.18/'),
                                            transform=test_transform)
    else:
        assert step == 8
        train_dataset = ESImagenet_Dataset(mode='train',
                                           data_set_path=os.path.join(DATA_DIR,
                                                                      'DVS/ES-imagenet-0.18/extract/ES-imagenet-0.18/'),
                                           transform=train_transform)
        test_dataset = ESImagenet_Dataset(mode='test',
                                          data_set_path=os.path.join(DATA_DIR,
                                                                     'DVS/ES-imagenet-0.18/extract/ES-imagenet-0.18/'),
                                          transform=test_transform)
    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)
    mixup_active = cut_mix | event_mix | mix_up
    if cut_mix:
        train_dataset = CutMix(train_dataset,
                               beta=beta,
                               prob=prob,
                               num_mix=num,
                               num_class=num_classes,
                               noise=noise)
    if event_mix:
        train_dataset = EventMix(train_dataset,
                                 beta=beta,
                                 prob=prob,
                                 num_mix=num,
                                 num_class=num_classes,
                                 noise=noise,
                                 gaussian_n=gaussian_n)
    if mix_up:
        train_dataset = MixUp(train_dataset,
                              beta=beta,
                              prob=prob,
                              num_mix=num,
                              num_class=num_classes,
                              noise=noise)
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size,
        pin_memory=True, drop_last=True, num_workers=8,
        sampler=train_sampler
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size,
        pin_memory=True, drop_last=False, num_workers=8,
        sampler=test_sampler
    )
    # train_loader = torch.utils.data.DataLoader(
    #     train_dataset, batch_size=batch_size,
    #     pin_memory=True, drop_last=True, num_workers=8,
    #     shuffle=True
    # )
    #
    # test_loader = torch.utils.data.DataLoader(
    #     test_dataset, batch_size=batch_size,
    #     pin_memory=True, drop_last=False, num_workers=1,
    #     shuffle=False
    # )
    return train_loader, test_loader, mixup_active, None
def get_CUB2002011_data(batch_size, num_workers=8, same_da=False, *args, **kwargs):
    """
    Build train/test DataLoaders for the CUB-200-2011 fine-grained bird dataset.

    :param batch_size: batch size for both loaders
    :param num_workers: number of DataLoader worker processes
    :param same_da: if True, apply the deterministic test transform to the
        training split as well (no augmentation)
    :return: (train loader, test loader, False, None)
    """
    # Standard ImageNet normalization statistics, shared by both pipelines.
    normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    aug_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    eval_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    data_root = os.path.join(DATA_DIR, 'CUB2002011')
    train_set = CUB2002011(root=data_root, train=True,
                           transform=eval_transform if same_da else aug_transform,
                           download=True)
    test_set = CUB2002011(root=data_root, train=False,
                          transform=eval_transform, download=True)
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, shuffle=True,
        pin_memory=True, drop_last=True, num_workers=num_workers,
    )
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=batch_size,
        pin_memory=True, drop_last=False, num_workers=num_workers,
    )
    return train_loader, test_loader, False, None
def get_StanfordCars_data(batch_size, num_workers=8, same_da=False, *args, **kwargs):
    """
    Build train/test DataLoaders for the Stanford Cars dataset.

    :param batch_size: batch size for both loaders
    :param num_workers: number of DataLoader worker processes
    :param same_da: if True, apply the deterministic test transform to the
        training split as well (no augmentation)
    :return: (train loader, test loader, False, None)
    """
    # Standard ImageNet normalization statistics, shared by both pipelines.
    normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    aug_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    eval_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    data_root = os.path.join(DATA_DIR, 'StanfordCars')
    train_set = datasets.StanfordCars(
        root=data_root, split="train",
        transform=eval_transform if same_da else aug_transform, download=True)
    test_set = datasets.StanfordCars(
        root=data_root, split="test", transform=eval_transform, download=True)
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, shuffle=True,
        pin_memory=True, drop_last=True, num_workers=num_workers,
    )
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=batch_size,
        pin_memory=True, drop_last=False, num_workers=num_workers,
    )
    return train_loader, test_loader, False, None
def get_StanfordDogs_data(batch_size, num_workers=8, same_da=False, *args, **kwargs):
    """
    Build train/test DataLoaders for the Stanford Dogs dataset.

    :param batch_size: batch size for both loaders
    :param num_workers: number of DataLoader worker processes
    :param same_da: if True, apply the deterministic test transform to the
        training split as well (no augmentation)
    :return: (train loader, test loader, False, None)
    """
    # Standard ImageNet normalization statistics, shared by both pipelines.
    normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    aug_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    eval_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    data_root = os.path.join(DATA_DIR, 'StanfordDogs')
    train_set = StanfordDogs(root=data_root, train=True,
                             transform=eval_transform if same_da else aug_transform,
                             download=True)
    test_set = StanfordDogs(root=data_root, train=False,
                            transform=eval_transform, download=True)
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, shuffle=True,
        pin_memory=True, drop_last=True, num_workers=num_workers,
    )
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=batch_size,
        pin_memory=True, drop_last=False, num_workers=num_workers,
    )
    return train_loader, test_loader, False, None
def get_FGVCAircraft_data(batch_size, num_workers=8, same_da=False, *args, **kwargs):
    """
    Build train/test DataLoaders for the FGVC-Aircraft dataset.

    :param batch_size: batch size for both loaders
    :param num_workers: number of DataLoader worker processes
    :param same_da: if True, apply the deterministic test transform to the
        training split as well (no augmentation)
    :return: (train loader, test loader, False, None)
    """
    # Standard ImageNet normalization statistics, shared by both pipelines.
    normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    aug_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    eval_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    data_root = os.path.join(DATA_DIR, 'FGVCAircraft')
    train_set = datasets.FGVCAircraft(
        root=data_root, split="train",
        transform=eval_transform if same_da else aug_transform, download=True)
    test_set = datasets.FGVCAircraft(
        root=data_root, split="test", transform=eval_transform, download=True)
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, shuffle=True,
        pin_memory=True, drop_last=True, num_workers=num_workers,
    )
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=batch_size,
        pin_memory=True, drop_last=False, num_workers=num_workers,
    )
    return train_loader, test_loader, False, None
def get_Flowers102_data(batch_size, num_workers=8, same_da=False, *args, **kwargs):
    """
    Build train/test DataLoaders for the Oxford Flowers-102 dataset.

    :param batch_size: batch size for both loaders
    :param num_workers: number of DataLoader worker processes
    :param same_da: if True, apply the deterministic test transform to the
        training split as well (no augmentation)
    :return: (train loader, test loader, False, None)
    """
    # Standard ImageNet normalization statistics, shared by both pipelines.
    normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    aug_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    eval_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    data_root = os.path.join(DATA_DIR, 'Flowers102')
    train_set = datasets.Flowers102(
        root=data_root, split="train",
        transform=eval_transform if same_da else aug_transform, download=True)
    test_set = datasets.Flowers102(
        root=data_root, split="test", transform=eval_transform, download=True)
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, shuffle=True,
        pin_memory=True, drop_last=True, num_workers=num_workers,
    )
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=batch_size,
        pin_memory=True, drop_last=False, num_workers=num_workers,
    )
    return train_loader, test_loader, False, None
================================================
FILE: examples/Perception_and_Learning/img_cls/transfer_for_dvs/main.py
================================================
# -*- coding: utf-8 -*-
# Time : 2023/4/19 14:58
# Author : Regulus
# FileName: main.py
# Explain:
# Software: PyCharm
import argparse
import time
import timm.models
import yaml
import os
import logging
from collections import OrderedDict
from contextlib import suppress
from datetime import datetime
from braincog.base.node.node import *
from braincog.utils import *
from braincog.base.utils.criterions import *
from braincog.datasets.datasets import *
from braincog.model_zoo.resnet import *
from braincog.model_zoo.convnet import *
from braincog.model_zoo.vgg_snn import VGG_SNN
from braincog.model_zoo.resnet19_snn import resnet19
from braincog.utils import save_feature_map, setup_seed
from braincog.base.utils.visualization import plot_tsne_3d, plot_tsne, plot_confusion_matrix
import torch
import torch.nn as nn
import torchvision.utils
from torch.nn.parallel import DistributedDataParallel as NativeDDP
from timm.data import ImageDataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
from timm.models import load_checkpoint, create_model, resume_checkpoint, convert_splitbn_model
from timm.utils import *
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
from timm.optim import create_optimizer
from timm.scheduler import create_scheduler
from timm.utils import ApexScaler, NativeScaler
# from ptflops import get_model_complexity_info
# from thop import profile, clever_format
# Enable cuDNN autotuner: fastest kernels are picked for fixed input shapes.
torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('train')

# The first arg parser parses out only the --config argument, this argument is used to
# load a yaml file containing key-values that override the defaults for the main parser below
config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)
parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
                    help='YAML config file specifying default arguments')

# NOTE: `parser` is rebound below, so `-c/--config` above is registered only on
# `config_parser`; `_parse_args()` relies on exactly this split.
parser = argparse.ArgumentParser(description='SNN Training and Evaluating')
# Model parameters
parser.add_argument('--dataset', default='cifar10', type=str)
parser.add_argument('--model', default='cifar_convnet', type=str, metavar='MODEL',
                    help='Name of model to train (default: "countception"')
parser.add_argument('--pretrained', action='store_true', default=False,
                    help='Start with pretrained version of specified network (if avail)')
parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
                    help='Initialize model from this checkpoint (default: none)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='Resume full model and optimizer state from checkpoint (default: none)')
parser.add_argument('--eval_checkpoint', default='', type=str, metavar='PATH',
                    help='path to eval checkpoint (default: none)')
parser.add_argument('--no-resume-opt', action='store_true', default=False,
                    help='prevent resume of optimizer state when resuming model')
# NOTE(review): actual default is 10 but the help text says 1000.
parser.add_argument('--num-classes', type=int, default=10, metavar='N',
                    help='number of label classes (default: 1000)')
parser.add_argument('--gp', default=None, type=str, metavar='POOL',
                    help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
# Dataset parameters for static datasets
parser.add_argument('--img-size', type=int, default=224, metavar='N',
                    help='Image patch size (default: None => model default)')
parser.add_argument('--crop-pct', default=None, type=float,
                    metavar='N', help='inputs image center crop percent (for validation only)')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                    help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
                    help='Override std deviation of of dataset')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
                    help='Image resize interpolation type (overrides model)')
# Dataloader parameters
parser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',
                    help='inputs batch size for training (default: 128)')
parser.add_argument('-vb', '--validation-batch-size-multiplier', type=int, default=1, metavar='N',
                    help='ratio of validation batch size to training batch size (default: 1)')
# Optimizer parameters
parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                    help='Optimizer (default: "adamw"')
parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
                    help='Optimizer Epsilon (default: None, use opt default)')
parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                    help='Optimizer Betas (default: None, use opt default)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='Optimizer momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=0.01,
                    help='weight decay (default: 0.01 for adamw)')
parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
                    help='Clip gradient norm (default: None, no clipping)')
parser.add_argument('--adam-epoch', type=int, default=1000, help='lamb switch to adamw')
# Learning rate schedule parameters
parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
                    help='LR scheduler (default: "cosine"')
# NOTE(review): actual default is 5e-3 but the help text says 0.01. The value
# is additionally rescaled in main() by batch size and world size.
parser.add_argument('--lr', type=float, default=5e-3, metavar='LR',
                    help='learning rate (default: 0.01)')
parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                    help='learning rate noise on/off epoch percentages')
parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                    help='learning rate noise limit percent (default: 0.67)')
parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                    help='learning rate noise std-dev (default: 1.0)')
parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
                    help='learning rate cycle len multiplier (default: 1.0)')
parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
                    help='learning rate cycle limit')
# NOTE(review): actual default is 1e-6 but the help text says 0.0001.
parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
                    help='warmup learning rate (default: 0.0001)')
parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                    help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
# NOTE(review): actual default is 600 but the help text says 2.
parser.add_argument('--epochs', type=int, default=600, metavar='N',
                    help='number of epochs to train (default: 2)')
parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
                    help='epoch interval to decay LR')
parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
                    help='epochs to warmup LR, if scheduler supports')
parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
                    help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                    help='patience epochs for Plateau LR scheduler (default: 10')
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                    help='LR decay rate (default: 0.1)')
parser.add_argument('--power', type=int, default=1, help='power')
# Augmentation & regularization parameters ONLY FOR IMAGE NET
parser.add_argument('--no-aug', action='store_true', default=False,
                    help='Disable all training augmentation, override other train aug args')
parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
                    help='Random resize scale (default: 0.08 1.0)')
parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
                    help='Random resize aspect ratio (default: 0.75 1.33)')
parser.add_argument('--hflip', type=float, default=0.5,
                    help='Horizontal flip training aug probability')
parser.add_argument('--vflip', type=float, default=0.,
                    help='Vertical flip training aug probability')
parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                    help='Color jitter factor (default: 0.4)')
# NOTE(review): stray trailing comma makes this statement a 1-tuple; harmless.
parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
                    help='Use AutoAugment policy. "v0" or "original". (default: None)'),
parser.add_argument('--aug-splits', type=int, default=0,
                    help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
parser.add_argument('--jsd', action='store_true', default=False,
                    help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
                    help='Random erase prob (default: 0.25)')
# NOTE(review): actual default is 'pixel' but the help text says "const".
parser.add_argument('--remode', type=str, default='pixel',
                    help='Random erase mode (default: "const")')
parser.add_argument('--recount', type=int, default=1,
                    help='Random erase count (default: 1)')
parser.add_argument('--resplit', action='store_true', default=False,
                    help='Do not random erase first (clean) augmentation split')
parser.add_argument('--mixup', type=float, default=0.8,
                    help='mixup alpha, mixup enabled if > 0. (default: 0.)')
parser.add_argument('--cutmix', type=float, default=1.0,
                    help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')
parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
                    help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
parser.add_argument('--mixup-prob', type=float, default=1.0,
                    help='Probability of performing mixup or cutmix when either/both is enabled')
parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
                    help='Probability of switching to cutmix when both mixup and cutmix enabled')
parser.add_argument('--mixup-mode', type=str, default='batch',
                    help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
                    help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
parser.add_argument('--smoothing', type=float, default=0.1,
                    help='Label smoothing (default: 0.1)')
parser.add_argument('--train-interpolation', type=str, default='random',
                    help='Training interpolation (random, bilinear, bicubic default: "random")')
parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                    help='Dropout rate (default: 0.0)')
parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT',
                    help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',
                    help='Drop path rate (default: None)')
parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
                    help='Drop block rate (default: None)')
parser.add_argument('--newton-maxiter', default=20, type=int,
                    help='max iterration in newton method')
parser.add_argument('--reset-drop', action='store_true', default=False,
                    help='whether to reset drop')
parser.add_argument('--kernel-method', type=str, default='cuda', choices=['torch', 'cuda'],
                    help='The implementation way of gaussian kernel method, choose from "cuda" and "torch"')
# Batch norm parameters (only works with gen_efficientnet based models currently)
parser.add_argument('--bn-tf', action='store_true', default=False,
                    help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')
parser.add_argument('--bn-momentum', type=float, default=None,
                    help='BatchNorm momentum override (if not None)')
parser.add_argument('--bn-eps', type=float, default=None,
                    help='BatchNorm epsilon override (if not None)')
parser.add_argument('--sync-bn', action='store_true',
                    help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
parser.add_argument('--dist-bn', type=str, default='',
                    help='Distribute BatchNorm stats between node after each epoch ("broadcast", "reduce", or "")')
parser.add_argument('--split-bn', action='store_true',
                    help='Enable separate BN layers per augmentation split.')
# Model Exponential Moving Average
parser.add_argument('--model-ema', action='store_true', default=False,
                    help='Enable tracking moving average of model weights')
parser.add_argument('--model-ema-force-cpu', action='store_true', default=False,
                    help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
# NOTE(review): actual default is 0.99996 but the help text says 0.9998.
parser.add_argument('--model-ema-decay', type=float, default=0.99996,
                    help='decay factor for model weights moving average (default: 0.9998)')
# Misc
parser.add_argument('--seed', type=int, default=42, metavar='S',
                    help='random seed (default: 42)')
parser.add_argument('--log-interval', type=int, default=50, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
                    help='how many batches to wait before writing recovery checkpoint')
# NOTE(review): actual default is 8 but the help text says 1.
parser.add_argument('-j', '--workers', type=int, default=8, metavar='N',
                    help='how many training processes to use (default: 1)')
parser.add_argument('--num-gpu', type=int, default=1,
                    help='Number of GPUS to use')
parser.add_argument('--save-images', action='store_true', default=False,
                    help='save images of inputs bathes every log interval for debugging')
parser.add_argument('--amp', action='store_true', default=False,
                    help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
parser.add_argument('--apex-amp', action='store_true', default=False,
                    help='Use NVIDIA Apex AMP mixed precision')
parser.add_argument('--native-amp', action='store_true', default=False,
                    help='Use Native Torch AMP mixed precision')
parser.add_argument('--channels-last', action='store_true', default=False,
                    help='Use channels_last memory layout')
parser.add_argument('--pin-mem', action='store_true', default=False,
                    help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--no-prefetcher', action='store_true', default=False,
                    help='disable fast prefetcher')
# NOTE(review): hard-coded absolute default path; override on other machines.
parser.add_argument('--output', default='/home/hexiang/TransferLearning_For_DVS/Results_lastest/', type=str, metavar='PATH',
                    help='path to output folder (default: none, current dir)')
parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
                    help='Best metric (default: "top1"')
parser.add_argument('--tta', type=int, default=0, metavar='N',
                    help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
parser.add_argument('--local_rank', default=0, type=int)
parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
                    help='use the multi-epochs-loader to save time at the beginning of every epoch')
parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
parser.add_argument('--device', type=int, default=0)
# Spike parameters
parser.add_argument('--step', type=int, default=10, help='Simulation time step (default: 10)')
parser.add_argument('--encode', type=str, default='direct', help='Input encode method (default: direct)')
parser.add_argument('--temporal-flatten', action='store_true',
                    help='Temporal flatten to channels. ONLY FOR EVENT DATA TRAINING BY ANN')
parser.add_argument('--adaptive-node', action='store_true')
parser.add_argument('--critical-loss', action='store_true')
# neuron type
parser.add_argument('--node-type', type=str, default='LIFNode', help='Node type in network (default: PLIF)')
parser.add_argument('--act-fun', type=str, default='GateGrad',
                    help='Surogate Function in node. Only for Surrogate nodes (default: AtanGrad)')
parser.add_argument('--threshold', type=float, default=.5, help='Firing threshold (default: 0.5)')
parser.add_argument('--tau', type=float, default=2., help='Attenuation coefficient (default: 2.)')
parser.add_argument('--requires-thres-grad', action='store_true')
parser.add_argument('--sigmoid-thres', action='store_true')
parser.add_argument('--loss-fn', type=str, default='ce', help='loss function (default: ce)')
parser.add_argument('--noisy-grad', type=float, default=0.,
                    help='Add noise to backward, sometime will make higher accuracy (default: 0.)')
parser.add_argument('--spike-output', action='store_true', default=False,
                    help='Using mem output or spike output (default: False)')
parser.add_argument('--n_groups', type=int, default=1)
# EventData Augmentation
parser.add_argument('--mix-up', action='store_true', help='Mix-up for event data (default: False)')
parser.add_argument('--cut-mix', action='store_true', help='CutMix for event data (default: False)')
parser.add_argument('--event-mix', action='store_true', help='EventMix for event data (default: False)')
parser.add_argument('--cutmix_beta', type=float, default=1.0, help='cutmix_beta (default: 1.)')
parser.add_argument('--cutmix_prob', type=float, default=0.5, help='cutmix_prib for event data (default: .5)')
parser.add_argument('--cutmix_num', type=int, default=1, help='cutmix_num for event data (default: 1)')
parser.add_argument('--cutmix_noise', type=float, default=0.,
                    help='Add Pepper noise after mix, sometimes work (default: 0.)')
parser.add_argument('--gaussian-n', type=int, default=3)
parser.add_argument('--rand-aug', action='store_true',
                    help='Rand Augment for Event data (default: False)')
parser.add_argument('--randaug_n', type=int, default=3,
                    help='Rand Augment times n (default: 3)')
parser.add_argument('--randaug_m', type=int, default=15,
                    help='Rand Augment times n (default: 15) (0-30)')
parser.add_argument('--train-portion', type=float, default=0.9,
                    help='Dataset portion, only for datasets which do not have validation set (default: 0.9)')
parser.add_argument('--event-size', default=48, type=int,
                    help='Event size. Resize event data before process (default: 48)')
parser.add_argument('--layer-by-layer', action='store_true',
                    help='forward step-by-step or layer-by-layer. '
                         'Larger Model with layer-by-layer will be faster (default: False)')
parser.add_argument('--node-resume', type=str, default='',
                    help='resume weights in node for adaptive node. (default: False)')
parser.add_argument('--node-trainable', action='store_true')
# visualize
parser.add_argument('--visualize', action='store_true',
                    help='Visualize spiking map for each layer, only for validate (default: False)')
parser.add_argument('--spike-rate', action='store_true',
                    help='Print spiking rate for each layer, only for validate(default: False)')
parser.add_argument('--tsne', action='store_true')
parser.add_argument('--conf-mat', action='store_true')
parser.add_argument('--suffix', type=str, default='',
                    help='Add an additional suffix to the save path (default: \'\')')
# for reconstructing es-imagenet
parser.add_argument('--reconstructed', action='store_true',
                    help='for ES-imagenet dataset')
parser.add_argument('--DVS-DA', action='store_true',
                    help='use DA on DVS')
# train data used ratio
parser.add_argument('--traindata-ratio', default=1.0, type=float,
                    help='training data ratio')
# use TET loss or not (all default False, do not use)
parser.add_argument('--TET-loss-first', action='store_true',
                    help='use TET loss one part')
parser.add_argument('--TET-loss-second', action='store_true',
                    help='use TET loss two part')

# Detect optional NVIDIA Apex (mixed precision + DDP helpers); fall back
# gracefully when it is not installed.
try:
    from apex import amp
    from apex.parallel import DistributedDataParallel as ApexDDP
    from apex.parallel import convert_syncbn_model

    has_apex = True
except ImportError:
    has_apex = False

# Native AMP (torch.cuda.amp.autocast) exists from PyTorch >= 1.6.
has_native_amp = False
try:
    if getattr(torch.cuda.amp, 'autocast') is not None:
        has_native_amp = True
except AttributeError:
    pass
def _parse_args():
    """Parse command-line arguments with optional YAML-file defaults.

    First extracts only ``-c/--config`` via the lightweight ``config_parser``;
    if a config file is given, its key/value pairs become the defaults of the
    main ``parser`` before the remaining argv is parsed.

    Returns:
        tuple: ``(args, args_text)`` where ``args`` is the parsed namespace and
        ``args_text`` is its YAML serialization (saved to the output dir later).
    """
    known, remaining = config_parser.parse_known_args()
    if known.config:
        # Config-file values override parser defaults, but explicit CLI flags
        # in `remaining` still win over both.
        with open(known.config, 'r') as cfg_file:
            defaults = yaml.safe_load(cfg_file)
            parser.set_defaults(**defaults)
    parsed = parser.parse_args(remaining)
    serialized = yaml.safe_dump(parsed.__dict__, default_flow_style=False)
    return parsed, serialized
def main():
    """Entry point: configure distributed/AMP state, build the SNN model,
    data loaders, optimizer, scheduler and loss, then run the train/eval loop.

    Side effects: creates an output directory (rank 0 only), configures
    logging, sets CUDA devices, may initialize a NCCL process group, and
    writes checkpoints/summary CSV during training.
    """
    args, args_text = _parse_args()
    # args.no_spike_output = args.no_spike_output | args.cut_mix
    # Membrane-potential (non-spike) output is forced on unconditionally here.
    args.no_spike_output = True
    output_dir = ''
    if args.local_rank == 0:
        # Rank 0 owns the output directory and file logging; the experiment
        # name encodes the key hyper-parameters for easy disambiguation.
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            args.model,
            args.dataset,
            str(args.step),
            "seed_{}".format(args.seed),
            "bs_{}".format(args.batch_size),
            "DA_{}".format(args.DVS_DA),
            "ls_{}".format(args.smoothing),
            "lr_{}".format(args.lr),
            "traindataratio_{}".format(args.traindata_ratio),
            "TET_first_{}".format(args.TET_loss_first),
            "TET_second_{}".format(args.TET_loss_second),
        ])
        output_dir = get_outdir(output_base, 'Baseline', exp_name)
        args.output_dir = output_dir
        setup_default_logging(log_path=os.path.join(output_dir, 'log.txt'))
    else:
        setup_default_logging()
    args.prefetcher = not args.no_prefetcher
    # Distributed mode is inferred from the launcher-provided WORLD_SIZE env var.
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
        if args.distributed and args.num_gpu > 1:
            _logger.warning(
                'Using more than one GPU per process in distributed mode is not allowed.Setting num_gpu to 1.')
            args.num_gpu = 1
    # args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.num_gpu = 1
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    else:
        # Single-process path: args.device is the integer GPU index from CLI.
        torch.cuda.set_device('cuda:%d' % args.device)
    assert args.rank >= 0
    if args.distributed:
        _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
                     % (args.rank, args.world_size))
    else:
        _logger.info('Training with a single process on %d GPUs.' % args.num_gpu)
    # torch.manual_seed(args.seed + args.rank)
    # Per-rank seed offset so data augmentation differs across processes.
    setup_seed(args.seed + args.rank)
    # NOTE(review): eval() turns the --node-type string into the node class
    # imported via `from braincog.base.node.node import *`.
    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        adaptive_node=args.adaptive_node,
        dataset=args.dataset,
        step=args.step,
        encode_type=args.encode,
        node_type=eval(args.node_type),
        threshold=args.threshold,
        tau=args.tau,
        sigmoid_thres=args.sigmoid_thres,
        requires_thres_grad=args.requires_thres_grad,
        spike_output=not args.no_spike_output,
        act_fun=args.act_fun,
        temporal_flatten=args.temporal_flatten,
        layer_by_layer=args.layer_by_layer,
        n_groups=args.n_groups,
        reconstruct=args.reconstructed,
        TET_loss=args.TET_loss_first or args.TET_loss_second
    )
    # Input channel count is derived from the dataset family:
    # event (DVS) data has 2 polarity channels, MNIST 1, RGB images 3.
    if 'dvs' in args.dataset:
        args.channels = 2
    elif 'mnist' in args.dataset:
        args.channels = 1
    else:
        args.channels = 3
    # flops, params = profile(model, inputs=(torch.randn(1, args.channels, args.event_size, args.event_size),), verbose=False)
    # _logger.info('flops = %fM', flops / 1e6)
    # _logger.info('param size = %fM', params / 1e6)
    # Linear LR scaling rule: lr scales with the global batch (base batch 1024).
    linear_scaled_lr = args.lr * args.batch_size * args.world_size / 1024.0
    args.lr = linear_scaled_lr
    _logger.info("learning rate is %f" % linear_scaled_lr)
    if args.local_rank == 0:
        _logger.info('Model %s created, param count: %d' %
                     (args.model, sum([m.numel() for m in model.parameters()])))
    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits
    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))
    # Resolve which AMP implementation (apex vs native) to use, if any.
    use_amp = None
    if args.amp:
        # for backwards compat, `--amp` arg tries apex before native amp
        if has_apex:
            args.apex_amp = True
        elif has_native_amp:
            args.native_amp = True
    if args.apex_amp and has_apex:
        use_amp = 'apex'
    elif args.native_amp and has_native_amp:
        use_amp = 'native'
    elif args.apex_amp or args.native_amp:
        _logger.warning("Neither APEX or native Torch AMP is available, using float32. "
                        "Install NVIDA apex or upgrade to PyTorch 1.6")
    if args.num_gpu > 1:
        # Multi-GPU single process => nn.DataParallel (incompatible with apex AMP).
        if use_amp == 'apex':
            _logger.warning(
                'Apex AMP does not work well with nn.DataParallel, disabling. Use DDP or Torch AMP.')
            use_amp = None
        model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
        assert not args.channels_last, "Channels last not supported with DP, use DDP."
    else:
        model = model.cuda()
        if args.channels_last:
            model = model.to(memory_format=torch.channels_last)
    optimizer = create_optimizer(args, model)
    amp_autocast = suppress  # do nothing
    loss_scaler = None
    if use_amp == 'apex':
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        loss_scaler = ApexScaler()
        if args.local_rank == 0:
            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif use_amp == 'native':
        amp_autocast = torch.cuda.amp.autocast
        loss_scaler = NativeScaler()
        if args.local_rank == 0:
            _logger.info('Using native Torch AMP. Training in mixed precision.')
    else:
        if args.local_rank == 0:
            _logger.info('AMP not enabled. Training in float32.')
    # optionally resume from a checkpoint
    resume_epoch = None
    if args.resume and args.eval_checkpoint == '':
        args.eval_checkpoint = args.resume
    if args.resume:
        # NOTE(review): resuming forces evaluation-only mode (args.eval = True),
        # so training is never continued from a checkpoint here — confirm intended.
        args.eval = True
        # checkpoint = torch.load(args.resume, map_location='cpu')
        # model.load_state_dict(checkpoint['state_dict'], False)
        resume_epoch = resume_checkpoint(
            model, args.resume,
            optimizer=None if args.no_resume_opt else optimizer,
            loss_scaler=None if args.no_resume_opt else loss_scaler,
            log_info=args.local_rank == 0)
    # print(model.get_attr('mu'))
    # print(model.get_attr('sigma'))
    if args.critical_loss or args.spike_rate:
        # Record firing patterns (needed by critical loss / spike-rate reporting).
        model.set_requires_fp(True)
    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(
            model,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else '',
            resume=args.resume)
    if args.node_resume:
        # Load pre-trained neuron-node weights (adaptive node) separately.
        ckpt = torch.load(args.node_resume, map_location='cpu')
        model.load_node_weight(ckpt, args.node_trainable)
    model_without_ddp = model
    if args.distributed:
        if args.sync_bn:
            assert not args.split_bn
            try:
                if has_apex and use_amp != 'native':
                    # Apex SyncBN preferred unless native amp is activated
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
                if args.local_rank == 0:
                    _logger.info(
                        'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                        'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
            except Exception as e:
                _logger.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')
        if has_apex and use_amp != 'native':
            # Apex DDP preferred unless native amp is activated
            if args.local_rank == 0:
                _logger.info("Using NVIDIA APEX DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                _logger.info("Using native Torch DistributedDataParallel.")
            model = NativeDDP(model, device_ids=[args.local_rank],
                              find_unused_parameters=True)  # can use device str in Torch >= 1.1
        model_without_ddp = model.module
    # NOTE: EMA model does not need to be wrapped by DDP
    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)
    if args.local_rank == 0:
        _logger.info('Scheduled epochs: {}'.format(num_epochs))
    # now config only for imnet
    data_config = resolve_data_config(vars(args), model=model, verbose=False)
    # Dispatch to the braincog dataset factory matching --dataset, e.g.
    # get_cifar10_data. NOTE(review): both `_logge` (typo kwarg) and `_logger`
    # are passed — presumably absorbed by **kwargs in the factory; verify.
    loader_train, loader_eval, mixup_active, mixup_fn = eval('get_%s_data' % args.dataset)(
        batch_size=args.batch_size,
        step=args.step,
        dvs_da=args.DVS_DA,
        args=args,
        _logge=_logger,
        data_config=data_config,
        num_aug_splits=num_aug_splits,
        size=args.event_size,
        mix_up=args.mix_up,
        cut_mix=args.cut_mix,
        event_mix=args.event_mix,
        beta=args.cutmix_beta,
        prob=args.cutmix_prob,
        gaussian_n=args.gaussian_n,
        num=args.cutmix_num,
        noise=args.cutmix_noise,
        num_classes=args.num_classes,
        rand_aug=args.rand_aug,
        randaug_n=args.randaug_n,
        randaug_m=args.randaug_m,
        portion=args.train_portion,
        reconstruct=args.reconstructed,
        _logger=_logger,
        train_data_ratio=args.traindata_ratio,
        data_mode="full",
        frames_num=12,
        data_type="frequency"
    )
    # Loss selection: MSE variant, or CE-family (JSD > mixup soft-target >
    # label smoothing > plain CE), optionally wrapped by MixLoss.
    if args.loss_fn == 'mse':
        train_loss_fn = UnilateralMse(1.)
        validate_loss_fn = UnilateralMse(1.)
    else:
        if args.jsd:
            assert num_aug_splits > 1  # JSD only valid with aug splits set
            train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()
        elif mixup_active:
            # smoothing is handled with mixup target transform
            train_loss_fn = SoftTargetCrossEntropy().cuda()
        elif args.smoothing:
            train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
        else:
            train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    if args.loss_fn == 'mix':
        train_loss_fn = MixLoss(train_loss_fn)
        validate_loss_fn = MixLoss(validate_loss_fn)
    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    if args.eval:  # evaluate the model
        if args.distributed:
            state_dict = torch.load(args.eval_checkpoint)['state_dict_ema']
            new_state_dict = OrderedDict()
            # add module prefix for DDP
            for k, v in state_dict.items():
                k = 'module.' + k
                new_state_dict[k] = v
            model.load_state_dict(new_state_dict)
        else:
            model.load_state_dict(torch.load(args.eval_checkpoint)['state_dict'])
        # NOTE(review): single-iteration loop — likely a leftover from repeated runs.
        for i in range(1):
            val_metrics = validate(start_epoch, model, loader_eval, validate_loss_fn, args,
                                   visualize=args.visualize, spike_rate=args.spike_rate,
                                   tsne=args.tsne, conf_mat=args.conf_mat)
            print(f"Top-1 accuracy of the model is: {val_metrics['top1']:.1f}%")
        return
    saver = None
    if args.local_rank == 0:
        # Rank 0 saves checkpoints (best-only, max_history=1) and the args YAML.
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(
            model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
            checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, max_history=1)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)
    try:  # train the model
        if args.reset_drop:
            model_without_ddp.reset_drop_path(0.0)
        for epoch in range(start_epoch, args.epochs):
            if epoch == 0 and args.reset_drop:
                model_without_ddp.reset_drop_path(args.drop_path)
            if args.distributed:
                # Re-shuffle shards per epoch for DistributedSampler.
                loader_train.sampler.set_epoch(epoch)
            train_metrics = train_epoch(
                epoch, model, loader_train, optimizer, train_loss_fn, args,
                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)
            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    _logger.info("Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
            eval_metrics = validate(epoch, model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast,
                                    visualize=args.visualize, spike_rate=args.spike_rate,
                                    tsne=args.tsne, conf_mat=args.conf_mat)
            if model_ema is not None and not args.model_ema_force_cpu:
                # EMA weights are validated too and take precedence for checkpointing.
                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                    distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
                ema_eval_metrics = validate(
                    epoch, model_ema.ema, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)',
                    visualize=args.visualize, spike_rate=args.spike_rate,
                    tsne=args.tsne, conf_mat=args.conf_mat)
                eval_metrics = ema_eval_metrics
            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
            update_summary(
                epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
                write_header=best_metric is None)
            # if saver is not None and epoch >= args.n_warm_up:
            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)
            # if epoch == 299:  # temporary
            #     break
    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
def train_epoch(
        epoch, model, loader, optimizer, loss_fn, args,
        lr_scheduler=None, saver=None, output_dir='', amp_autocast=suppress,
        loss_scaler=None, model_ema=None, mixup_fn=None):
    """Run one training epoch and return the averaged training metrics.

    Args:
        epoch: Index of the current epoch (used for logging, LR stepping and
            the mixup switch-off check).
        model: Network being trained. In non-distributed runs it is expected
            to expose SNN helpers such as ``get_fire_rate``/``get_threshold``.
        loader: Training data loader yielding ``(inputs, target)`` batches.
        optimizer: Optimizer stepped once per batch.
        loss_fn: Primary classification loss.
        args: Parsed CLI namespace (TET loss flags, mixup, logging, etc.).
        lr_scheduler: Optional scheduler stepped per update via ``step_update``.
        saver: Optional checkpoint saver used for periodic recovery saves.
        output_dir: Directory for debug image dumps (``--save-images``).
        amp_autocast: Context manager for mixed-precision forward passes.
        loss_scaler: Optional AMP loss scaler that handles backward + step.
        model_ema: Optional EMA weight tracker updated after each step.
        mixup_fn: Optional mixup/cutmix callable applied per batch.

    Returns:
        ``OrderedDict`` with the epoch-average training ``'loss'``.
    """
    # Disable mixup once the configured switch-off epoch is reached.
    if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
        if args.prefetcher and loader.mixup_enabled:
            loader.mixup_enabled = False
        elif mixup_fn is not None:
            mixup_fn.mixup_enabled = False

    # Some optimizers (e.g. Adahessian) require create_graph=True on backward.
    second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    losses_m = AverageMeter()
    closses_m = AverageMeter()  # "critical loss" meter; stays at zero in this script
    top1_m = AverageMeter()
    top5_m = AverageMeter()

    model.train()

    # t, k = adjust_surrogate_coeff(100, args.epochs)
    # model.set_attr('t', t)
    # model.set_attr('k', k)

    end = time.time()
    last_idx = len(loader) - 1
    num_updates = epoch * len(loader)
    for batch_idx, (inputs, target) in enumerate(loader):
        last_batch = batch_idx == last_idx
        data_time_m.update(time.time() - end)
        # Move the batch to GPU unless the (ImageNet-only) prefetcher already did.
        if not args.prefetcher or args.dataset != 'imnet':
            inputs, target = inputs.type(torch.FloatTensor).cuda(), target.cuda()
            if mixup_fn is not None:
                inputs, target = mixup_fn(inputs, target)
        if args.channels_last:
            inputs = inputs.contiguous(memory_format=torch.channels_last)

        with amp_autocast():
            output = model(inputs)

            tet_loss = 0.0
            loss = 0.0
            lamb = 1e-3  # TET regularization weight (lambda)
            # TET loss: the first (per-time-step CE) term is required; the
            # second (MSE-to-threshold) term can only be added on top of it.
            if args.TET_loss_first or args.TET_loss_second:
                # Here `output` is a sequence of per-time-step logits.
                for i in range(len(output)):
                    tet_loss += loss_fn(output[i], target)
                tet_loss /= len(output)
                loss = (1 - lamb) * tet_loss
            else:
                loss = loss_fn(output, target)
            if args.TET_loss_second:
                # Pull the last-step output towards the firing threshold.
                y = torch.zeros_like(output[-1]).fill_(args.threshold)
                secondLoss = torch.nn.MSELoss()
                tet_loss_second = secondLoss(output[-1], y)
                loss += lamb * tet_loss_second
            if args.TET_loss_first or args.TET_loss_second:
                # Collapse per-step outputs to their mean for accuracy/logging.
                output = sum(output) / len(output)

        # Accuracy is meaningless when labels were mixed (cutmix/mixup/eventmix);
        # note `|` is a bitwise OR of the boolean flags here.
        if not (args.cut_mix | args.mix_up | args.event_mix) and args.dataset != 'imnet':
            # print(output.shape, target.shape)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            # acc1, = accuracy(output, target)
        else:
            acc1, acc5 = torch.tensor([0.]), torch.tensor([0.])

        # Placeholder "critical loss"; always zero in this script, so the
        # 0.1 weighting below has no effect on the optimized loss.
        closs = torch.tensor([0.], device=loss.device)
        loss = loss + .1 * closs

        spike_rate_avg_layer_str = ''
        threshold_str = ''
        if not args.distributed:
            losses_m.update(loss.item(), inputs.size(0))
            top1_m.update(acc1.item(), inputs.size(0))
            top5_m.update(acc5.item(), inputs.size(0))
            closses_m.update(closs.item(), inputs.size(0))
            # Per-layer firing-rate / threshold snapshots for the log line below.
            spike_rate_avg_layer = model.get_fire_rate().tolist()
            spike_rate_avg_layer_str = ['{:.3f}'.format(i) for i in spike_rate_avg_layer]
            threshold = model.get_threshold()
            threshold_str = ['{:.3f}'.format(i) for i in threshold]

        optimizer.zero_grad()
        if loss_scaler is not None:
            # AMP path: the scaler handles backward, gradient clipping and the step.
            loss_scaler(
                loss, optimizer, clip_grad=args.clip_grad, parameters=model.parameters(), create_graph=second_order)
        else:
            loss.backward(create_graph=second_order)
            if args.noisy_grad != 0.:
                # Optionally perturb gradients with random noise (regularizer).
                random_gradient(model, args.noisy_grad)
            if args.clip_grad is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
            if args.opt == 'lamb':
                # The lamb optimizer here expects the epoch for its internal schedule.
                optimizer.step(epoch=epoch)
            else:
                optimizer.step()

        torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)
        num_updates += 1

        batch_time_m.update(time.time() - end)
        if last_batch or batch_idx % args.log_interval == 0:
            # Average LR across parameter groups purely for display.
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)

            mu_str = ''
            sigma_str = ''
            if not args.distributed:
                if 'Noise' in args.node_type:
                    # Noisy-node models expose learned noise parameters (mu, sigma).
                    mu, sigma = model.get_noise_param()
                    mu_str = ['{:.3f}'.format(i.detach()) for i in mu]
                    sigma_str = ['{:.3f}'.format(i.detach()) for i in sigma]

            if args.distributed:
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                losses_m.update(reduced_loss.item(), inputs.size(0))
                # NOTE(review): closses_m is fed the *total* reduced loss here,
                # not a critical-loss value — looks like a copy-paste; confirm intent.
                closses_m.update(reduced_loss.item(), inputs.size(0))

            if args.local_rank == 0:
                if args.distributed:
                    _logger.info(
                        'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
                        'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) '
                        'cLoss: {closs.val:>9.6f} ({closs.avg:>6.4f}) '
                        'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
                        'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f}) '
                        'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
                        '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                        'LR: {lr:.3e} '
                        'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
                            epoch,
                            batch_idx, len(loader),
                            100. * batch_idx / last_idx,
                            loss=losses_m,
                            closs=closses_m,
                            top1=top1_m,
                            top5=top5_m,
                            batch_time=batch_time_m,
                            rate=inputs.size(0) * args.world_size / batch_time_m.val,
                            rate_avg=inputs.size(0) * args.world_size / batch_time_m.avg,
                            lr=lr,
                            data_time=data_time_m
                        ))
                else:
                    # Non-distributed log additionally reports SNN-specific stats
                    # (per-layer fire rates, thresholds, noise parameters).
                    _logger.info(
                        'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
                        'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) '
                        'cLoss: {closs.val:>9.6f} ({closs.avg:>6.4f}) '
                        'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
                        'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f}) '
                        'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
                        '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                        'LR: {lr:.3e} '
                        'Data: {data_time.val:.3f} ({data_time.avg:.3f})\n'
                        'Fire_rate: {spike_rate}\n'
                        'Thres: {threshold}\n'
                        'Mu: {mu_str}\n'
                        'Sigma: {sigma_str}\n'.format(
                            epoch,
                            batch_idx, len(loader),
                            100. * batch_idx / last_idx,
                            loss=losses_m,
                            closs=closses_m,
                            top1=top1_m,
                            top5=top5_m,
                            batch_time=batch_time_m,
                            rate=inputs.size(0) * args.world_size / batch_time_m.val,
                            rate_avg=inputs.size(0) * args.world_size / batch_time_m.avg,
                            lr=lr,
                            data_time=data_time_m,
                            spike_rate=spike_rate_avg_layer_str,
                            threshold=threshold_str,
                            mu_str=mu_str,
                            sigma_str=sigma_str
                        ))

                if args.save_images and output_dir:
                    # Dump the current input batch for visual debugging.
                    torchvision.utils.save_image(
                        inputs,
                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
                        padding=0,
                        normalize=True)

        if saver is not None and args.recovery_interval and (
                last_batch or (batch_idx + 1) % args.recovery_interval == 0):
            saver.save_recovery(epoch, batch_idx=batch_idx)

        if lr_scheduler is not None:
            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)

        end = time.time()
        # end for

    if hasattr(optimizer, 'sync_lookahead'):
        optimizer.sync_lookahead()

    return OrderedDict([('loss', losses_m.avg)])
def validate(epoch, model, loader, loss_fn, args, amp_autocast=suppress,
             log_suffix='', visualize=False, spike_rate=False, tsne=False, conf_mat=False):
    """Evaluate the model on ``loader`` and return averaged metrics.

    Args:
        epoch: Current epoch index (only used when saving spike statistics).
        model: Network to evaluate; set to ``eval()`` mode here.
        loader: Validation data loader yielding ``(inputs, target)`` batches.
        loss_fn: Loss used for the reported validation loss.
        args: Parsed CLI namespace (TET flags, TTA factor, logging, etc.).
        amp_autocast: Context manager for mixed-precision inference.
        log_suffix: Suffix appended to the log name (e.g. ``' (EMA)'``).
        visualize: Dump per-layer feature maps to ``output_dir/feature_map``.
        spike_rate: Append per-layer spike statistics to ``spike_info.csv``.
        tsne: Collect last-layer features and plot 2D/3D t-SNE at the end.
        conf_mat: Collect logits/labels and plot a confusion matrix at the end.

    Returns:
        ``OrderedDict`` with epoch-average ``'loss'`` and ``'top1'``.
        (``top5`` is tracked but intentionally left out of the result.)
    """
    batch_time_m = AverageMeter()
    losses_m = AverageMeter()
    closses_m = AverageMeter()  # "critical loss" meter; stays at zero here
    top1_m = AverageMeter()
    top5_m = AverageMeter()

    model.eval()

    # Accumulators for the optional t-SNE / confusion-matrix plots.
    feature_vec = []
    feature_cls = []
    logits_vec = []
    labels_vec = []

    end = time.time()
    last_idx = len(loader) - 1
    with torch.no_grad():
        for batch_idx, (inputs, target) in enumerate(loader):
            # inputs = inputs.type(torch.float64)
            last_batch = batch_idx == last_idx
            # Move the batch to GPU unless the (ImageNet-only) prefetcher already did.
            if not args.prefetcher or args.dataset != 'imnet':
                inputs = inputs.type(torch.FloatTensor).cuda()
                target = target.cuda()
            if args.channels_last:
                inputs = inputs.contiguous(memory_format=torch.channels_last)

            if not args.distributed:
                # Ask the model to retain feature maps only when some kind of
                # introspection output was requested.
                if (visualize or spike_rate or tsne or conf_mat) and not args.critical_loss:
                    model.set_requires_fp(True)
            # if not args.critical_loss:
            #     model.set_requires_fp(False)

            with amp_autocast():
                output = model(inputs)
                if args.TET_loss_first or args.TET_loss_second:
                    # TET models emit per-time-step logits; average them.
                    output = sum(output) / len(output)
            if isinstance(output, (tuple, list)):
                output = output[0]

            if not args.distributed:
                if visualize:
                    x = model.get_fp()
                    feature_path = os.path.join(args.output_dir, 'feature_map')
                    if os.path.exists(feature_path) is False:
                        os.mkdir(feature_path)
                    save_feature_map(x, feature_path)
                    # if not args.critical_loss:
                    #     model_config.set_requires_fp(False)

                if tsne:
                    # Global-average-pool the last feature map into one vector per sample.
                    x = model.get_fp(temporal_info=False)[-1]
                    x = torch.nn.AdaptiveAvgPool2d((1, 1))(x)
                    x = x.reshape(x.shape[0], -1)
                    feature_vec.append(x)
                    feature_cls.append(target)

                if conf_mat:
                    logits_vec.append(output)
                    labels_vec.append(target)

                if spike_rate:
                    avg, var, spike, avg_per_step = model.get_spike_info()
                    save_spike_info(
                        os.path.join(args.output_dir, 'spike_info.csv'),
                        epoch, batch_idx,
                        args.step, avg, var,
                        spike, avg_per_step)

            # augmentation reduction
            # With test-time augmentation, fold the augmented copies back into
            # one prediction per original sample.
            reduce_factor = args.tta
            if reduce_factor > 1:
                output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
                target = target[0:target.size(0):reduce_factor]

            loss = loss_fn(output, target)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            # acc1, = accuracy(output, target)

            # Placeholder "critical loss"; always zero in this script.
            closs = torch.tensor([0.], device=loss.device)

            if not args.distributed:
                # SNN-specific stats for the non-distributed log line below.
                spike_rate_avg_layer = model.get_fire_rate().tolist()
                threshold = model.get_threshold()
                threshold_str = ['{:.3f}'.format(i) for i in threshold]
                spike_rate_avg_layer_str = ['{:.3f}'.format(i) for i in spike_rate_avg_layer]
                tot_spike = model.get_tot_spike()

            if args.distributed:
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                acc1 = reduce_tensor(acc1, args.world_size)
                acc5 = reduce_tensor(acc5, args.world_size)
            else:
                reduced_loss = loss.data

            torch.cuda.synchronize()

            losses_m.update(reduced_loss.item(), inputs.size(0))
            # NOTE(review): top-1/top-5 use output.size(0), which differs from
            # inputs.size(0) when TTA reduction shrank the batch — presumably
            # intentional; verify.
            top1_m.update(acc1.item(), output.size(0))
            top5_m.update(acc5.item(), output.size(0))
            closses_m.update(closs.item(), inputs.size(0))

            batch_time_m.update(time.time() - end)
            end = time.time()
            if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
                log_name = 'Test' + log_suffix
                mu_str = ''
                sigma_str = ''
                if not args.distributed:
                    if 'Noise' in args.node_type:
                        # Noisy-node models expose learned noise parameters (mu, sigma).
                        mu, sigma = model.get_noise_param()
                        mu_str = ['{:.3f}'.format(i.detach()) for i in mu]
                        sigma_str = ['{:.3f}'.format(i.detach()) for i in sigma]

                if args.distributed:
                    _logger.info(
                        '{0}: [{1:>4d}/{2}] '
                        'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                        'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                        'cLoss: {closs.val:>7.4f} ({closs.avg:>6.4f}) '
                        'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})'
                        'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
                            log_name,
                            batch_idx,
                            last_idx,
                            batch_time=batch_time_m,
                            loss=losses_m,
                            closs=closses_m,
                            top1=top1_m,
                            top5=top5_m,
                        ))
                else:
                    # Non-distributed log additionally reports SNN-specific stats.
                    _logger.info(
                        '{0}: [{1:>4d}/{2}] '
                        'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                        'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                        'cLoss: {closs.val:>7.4f} ({closs.avg:>6.4f}) '
                        'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})'
                        'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})\n'
                        'Fire_rate: {spike_rate}\n'
                        'Tot_spike: {tot_spike}\n'
                        'Thres: {threshold}\n'
                        'Mu: {mu_str}\n'
                        'Sigma: {sigma_str}\n'.format(
                            log_name,
                            batch_idx,
                            last_idx,
                            batch_time=batch_time_m,
                            loss=losses_m,
                            closs=closses_m,
                            top1=top1_m,
                            top5=top5_m,
                            spike_rate=spike_rate_avg_layer_str,
                            tot_spike=tot_spike,
                            threshold=threshold_str,
                            mu_str=mu_str,
                            sigma_str=sigma_str
                        ))

    # metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])
    metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg)])

    if not args.distributed:
        if tsne:
            feature_vec = torch.cat(feature_vec)
            feature_cls = torch.cat(feature_cls)
            plot_tsne(feature_vec, feature_cls, os.path.join(args.output_dir, 't-sne-2d.eps'))
            plot_tsne_3d(feature_vec, feature_cls, os.path.join(args.output_dir, 't-sne-3d.eps'))
        if conf_mat:
            logits_vec = torch.cat(logits_vec)
            labels_vec = torch.cat(labels_vec)
            plot_confusion_matrix(logits_vec, labels_vec, os.path.join(args.output_dir, 'confusion_matrix.eps'))

    return metrics
# Standard script entry point: run training/evaluation only when executed directly.
if __name__ == '__main__':
    main()
================================================
FILE: examples/Perception_and_Learning/img_cls/transfer_for_dvs/main_transfer.py
================================================
# -*- coding: utf-8 -*-
# Time : 2022/9/29 15:27
# Author : Regulus
# FileName: main_transfer.py
# Explain:
# Software: PyCharm
import argparse
import math
import time
import CKA
import numpy
import timm.models
import random as rd
import yaml
import os
import logging
from collections import OrderedDict
from contextlib import suppress
from datetime import datetime
from braincog.base.node.node import *
from braincog.utils import *
from braincog.base.utils.criterions import *
from braincog.datasets.datasets import *
from braincog.model_zoo.resnet import *
from braincog.model_zoo.convnet import *
from braincog.model_zoo.vgg_snn import VGG_SNN
from braincog.model_zoo.resnet19_snn import resnet19
from braincog.utils import save_feature_map, setup_seed
from braincog.base.utils.visualization import plot_tsne_3d, plot_tsne, plot_confusion_matrix
import torch
import torch.nn as nn
import torchvision.utils
from torch.nn.parallel import DistributedDataParallel as NativeDDP
from rgb_hsv import RGB_HSV
import matplotlib.pyplot as plt
from timm.data import ImageDataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
from timm.models import load_checkpoint, create_model, resume_checkpoint, convert_splitbn_model
from timm.utils import *
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
from timm.optim import create_optimizer
from timm.scheduler import create_scheduler
from timm.utils import ApexScaler, NativeScaler
# from ptflops import get_model_complexity_info
# from thop import profile, clever_format
torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('train')
# The first arg parser parses out only the --config argument, this argument is used to
# load a yaml file containing key-values that override the defaults for the main parser below
config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)
parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
                    help='YAML config file specifying default arguments')

# NOTE: `parser` is deliberately rebound here — `config_parser` keeps the
# minimal --config-only parser above, while all arguments added below go to
# this main parser. `_parse_args` uses both in two passes.
parser = argparse.ArgumentParser(description='SNN Training and Evaluating')
# Model parameters
parser.add_argument('--source-dataset', default='cifar10', type=str)
parser.add_argument('--target-dataset', default='dvsc10', type=str)
parser.add_argument('--model', default='cifar_convnet', type=str, metavar='MODEL',
help='Name of model to train (default: "countception"')
parser.add_argument('--pretrained', action='store_true', default=False,
help='Start with pretrained version of specified network (if avail)')
parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
help='Initialize model from this checkpoint (default: none)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
help='Resume full model and optimizer state from checkpoint (default: none)')
parser.add_argument('--eval_checkpoint', default='', type=str, metavar='PATH',
help='path to eval checkpoint (default: none)')
parser.add_argument('--no-resume-opt', action='store_true', default=False,
help='prevent resume of optimizer state when resuming model')
parser.add_argument('--num-classes', type=int, default=10, metavar='N',
help='number of label classes (default: 1000)')
parser.add_argument('--gp', default=None, type=str, metavar='POOL',
help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
# Dataset parameters for static datasets
parser.add_argument('--img-size', type=int, default=224, metavar='N',
help='Image patch size (default: None => model default)')
parser.add_argument('--crop-pct', default=None, type=float,
metavar='N', help='inputs image center crop percent (for validation only)')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
help='Override std deviation of of dataset')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
help='Image resize interpolation type (overrides model)')
# Dataloader parameters
parser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',
help='inputs batch size for training (default: 128)')
parser.add_argument('-vb', '--validation-batch-size-multiplier', type=int, default=1, metavar='N',
help='ratio of validation batch size to training batch size (default: 1)')
# Optimizer parameters
parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
help='Optimizer (default: "adamw"')
parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
help='Optimizer Epsilon (default: None, use opt default)')
parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
help='Optimizer Betas (default: None, use opt default)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
help='Optimizer momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=0.01,
help='weight decay (default: 0.01 for adamw)')
parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
help='Clip gradient norm (default: None, no clipping)')
parser.add_argument('--adam-epoch', type=int, default=1000, help='lamb switch to adamw')
# Learning rate schedule parameters
parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
help='LR scheduler (default: "cosine"')
parser.add_argument('--lr', type=float, default=5e-3, metavar='LR',
help='learning rate (default: 0.01)')
parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
help='learning rate noise on/off epoch percentages')
parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
help='learning rate noise limit percent (default: 0.67)')
parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
help='learning rate noise std-dev (default: 1.0)')
parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
help='learning rate cycle len multiplier (default: 1.0)')
parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
help='learning rate cycle limit')
parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
help='warmup learning rate (default: 0.0001)')
parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
parser.add_argument('--epochs', type=int, default=600, metavar='N',
help='number of epochs to train (default: 2)')
parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
help='manual epoch number (useful on restarts)')
parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
help='epoch interval to decay LR')
parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
help='epochs to warmup LR, if scheduler supports')
parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
help='patience epochs for Plateau LR scheduler (default: 10')
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
help='LR decay rate (default: 0.1)')
parser.add_argument('--power', type=int, default=1, help='power')
# Augmentation & regularization parameters ONLY FOR IMAGE NET
parser.add_argument('--no-aug', action='store_true', default=False,
help='Disable all training augmentation, override other train aug args')
parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
help='Random resize scale (default: 0.08 1.0)')
parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
help='Random resize aspect ratio (default: 0.75 1.33)')
parser.add_argument('--hflip', type=float, default=0.5,
help='Horizontal flip training aug probability')
parser.add_argument('--vflip', type=float, default=0.,
help='Vertical flip training aug probability')
parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
help='Color jitter factor (default: 0.4)')
parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
help='Use AutoAugment policy. "v0" or "original". (default: None)'),
parser.add_argument('--aug-splits', type=int, default=0,
help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
parser.add_argument('--jsd', action='store_true', default=False,
help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
help='Random erase prob (default: 0.25)')
parser.add_argument('--remode', type=str, default='pixel',
help='Random erase mode (default: "const")')
parser.add_argument('--recount', type=int, default=1,
help='Random erase count (default: 1)')
parser.add_argument('--resplit', action='store_true', default=False,
help='Do not random erase first (clean) augmentation split')
parser.add_argument('--mixup', type=float, default=0.8,
help='mixup alpha, mixup enabled if > 0. (default: 0.)')
parser.add_argument('--cutmix', type=float, default=1.0,
help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')
parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
parser.add_argument('--mixup-prob', type=float, default=1.0,
help='Probability of performing mixup or cutmix when either/both is enabled')
parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
help='Probability of switching to cutmix when both mixup and cutmix enabled')
parser.add_argument('--mixup-mode', type=str, default='batch',
help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
parser.add_argument('--smoothing', type=float, default=0.1,
help='Label smoothing (default: 0.1)')
parser.add_argument('--train-interpolation', type=str, default='random',
help='Training interpolation (random, bilinear, bicubic default: "random")')
parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
help='Dropout rate (default: 0.0)')
parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT',
help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',
help='Drop path rate (default: None)')
parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
help='Drop block rate (default: None)')
parser.add_argument('--newton-maxiter', default=20, type=int,
help='max iterration in newton method')
parser.add_argument('--reset-drop', action='store_true', default=False,
help='whether to reset drop')
parser.add_argument('--kernel-method', type=str, default='cuda', choices=['torch', 'cuda'],
help='The implementation way of gaussian kernel method, choose from "cuda" and "torch"')
# Batch norm parameters (only works with gen_efficientnet based models currently)
parser.add_argument('--bn-tf', action='store_true', default=False,
help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')
parser.add_argument('--bn-momentum', type=float, default=None,
help='BatchNorm momentum override (if not None)')
parser.add_argument('--bn-eps', type=float, default=None,
help='BatchNorm epsilon override (if not None)')
parser.add_argument('--sync-bn', action='store_true',
help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
parser.add_argument('--dist-bn', type=str, default='',
help='Distribute BatchNorm stats between node after each epoch ("broadcast", "reduce", or "")')
parser.add_argument('--split-bn', action='store_true',
help='Enable separate BN layers per augmentation split.')
# Model Exponential Moving Average
parser.add_argument('--model-ema', action='store_true', default=False,
help='Enable tracking moving average of model weights')
parser.add_argument('--model-ema-force-cpu', action='store_true', default=False,
help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
parser.add_argument('--model-ema-decay', type=float, default=0.99996,
help='decay factor for model weights moving average (default: 0.9998)')
# Misc
parser.add_argument('--seed', type=int, default=42, metavar='S',
help='random seed (default: 42)')
parser.add_argument('--log-interval', type=int, default=50, metavar='N',
help='how many batches to wait before logging training status')
parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
help='how many batches to wait before writing recovery checkpoint')
parser.add_argument('-j', '--workers', type=int, default=8, metavar='N',
help='how many training processes to use (default: 1)')
parser.add_argument('--num-gpu', type=int, default=1,
help='Number of GPUS to use')
parser.add_argument('--save-images', action='store_true', default=False,
help='save images of inputs bathes every log interval for debugging')
parser.add_argument('--amp', action='store_true', default=False,
help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
parser.add_argument('--apex-amp', action='store_true', default=False,
help='Use NVIDIA Apex AMP mixed precision')
parser.add_argument('--native-amp', action='store_true', default=False,
help='Use Native Torch AMP mixed precision')
parser.add_argument('--channels-last', action='store_true', default=False,
help='Use channels_last memory layout')
parser.add_argument('--pin-mem', action='store_true', default=False,
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--no-prefetcher', action='store_true', default=False,
help='disable fast prefetcher')
parser.add_argument('--output', default='/home/hexiang/TransferLearning_For_DVS/Results_lastest/', type=str, metavar='PATH',
help='path to output folder (default: none, current dir)')
parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
help='Best metric (default: "top1"')
parser.add_argument('--tta', type=int, default=0, metavar='N',
help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
parser.add_argument('--local_rank', default=0, type=int)
parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
help='use the multi-epochs-loader to save time at the beginning of every epoch')
parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
parser.add_argument('--device', type=int, default=0)
# Spike parameters
parser.add_argument('--step', type=int, default=10, help='Simulation time step (default: 10)')
parser.add_argument('--encode', type=str, default='direct', help='Input encode method (default: direct)')
parser.add_argument('--temporal-flatten', action='store_true',
help='Temporal flatten to channels. ONLY FOR EVENT DATA TRAINING BY ANN')
parser.add_argument('--adaptive-node', action='store_true')
parser.add_argument('--critical-loss', action='store_true')
# neuron type
parser.add_argument('--node-type', type=str, default='LIFNode', help='Node type in network (default: PLIF)')
parser.add_argument('--act-fun', type=str, default='GateGrad',
help='Surogate Function in node. Only for Surrogate nodes (default: AtanGrad)')
parser.add_argument('--threshold', type=float, default=.5, help='Firing threshold (default: 0.5)')
parser.add_argument('--tau', type=float, default=2., help='Attenuation coefficient (default: 2.)')
parser.add_argument('--requires-thres-grad', action='store_true')
parser.add_argument('--sigmoid-thres', action='store_true')
parser.add_argument('--loss-fn', type=str, default='ce', help='loss function (default: ce)')
parser.add_argument('--noisy-grad', type=float, default=0.,
help='Add noise to backward, sometime will make higher accuracy (default: 0.)')
parser.add_argument('--spike-output', action='store_true', default=False,
help='Using mem output or spike output (default: False)')
parser.add_argument('--n_groups', type=int, default=1)
# EventData Augmentation
parser.add_argument('--mix-up', action='store_true', help='Mix-up for event data (default: False)')
parser.add_argument('--cut-mix', action='store_true', help='CutMix for event data (default: False)')
parser.add_argument('--event-mix', action='store_true', help='EventMix for event data (default: False)')
parser.add_argument('--cutmix_beta', type=float, default=1.0, help='cutmix_beta (default: 1.)')
parser.add_argument('--cutmix_prob', type=float, default=0.5, help='cutmix_prib for event data (default: .5)')
parser.add_argument('--cutmix_num', type=int, default=1, help='cutmix_num for event data (default: 1)')
parser.add_argument('--cutmix_noise', type=float, default=0.,
help='Add Pepper noise after mix, sometimes work (default: 0.)')
parser.add_argument('--gaussian-n', type=int, default=3)
parser.add_argument('--rand-aug', action='store_true',
help='Rand Augment for Event data (default: False)')
parser.add_argument('--randaug_n', type=int, default=3,
help='Rand Augment times n (default: 3)')
parser.add_argument('--randaug_m', type=int, default=15,
help='Rand Augment times n (default: 15) (0-30)')
parser.add_argument('--train-portion', type=float, default=0.9,
help='Dataset portion, only for datasets which do not have validation set (default: 0.9)')
parser.add_argument('--event-size', default=48, type=int,
help='Event size. Resize event data before process (default: 48)')
parser.add_argument('--layer-by-layer', action='store_true',
help='forward step-by-step or layer-by-layer. '
'Larger Model with layer-by-layer will be faster (default: False)')
parser.add_argument('--node-resume', type=str, default='',
help='resume weights in node for adaptive node. (default: False)')
parser.add_argument('--node-trainable', action='store_true')
# visualize
parser.add_argument('--visualize', action='store_true',
help='Visualize spiking map for each layer, only for validate (default: False)')
parser.add_argument('--spike-rate', action='store_true',
help='Print spiking rate for each layer, only for validate(default: False)')
parser.add_argument('--tsne', action='store_true')
parser.add_argument('--conf-mat', action='store_true')
parser.add_argument('--suffix', type=str, default='',
help='Add an additional suffix to the save path (default: \'\')')
# Transfer Learning loss choice
parser.add_argument('--domain-loss', action='store_true',
help='add domain loss')
parser.add_argument('--semantic-loss', action='store_true',
help='add semantic loss')
parser.add_argument('--domain-loss-coefficient', type=float, default=1.0,
help='domain loss coefficient(default: 1.0)')
parser.add_argument('--semantic-loss-coefficient', type=float, default=1.0,
help='domain loss coefficient(default: 1.0)')
# use TET loss or not (all default False, do not use)
parser.add_argument('--TET-loss-first', action='store_true',
help='use TET loss one part')
parser.add_argument('--TET-loss-second', action='store_true',
help='use TET loss two part')
parser.add_argument('--DVS-DA', action='store_true',
help='use DA on DVS')
# train data used ratio
parser.add_argument('--traindata-ratio', default=1.0, type=float,
help='training data ratio')
# snr value
parser.add_argument('--snr', default=0, type=int,
help='random noise amplitude controled by snr, 0 means no noise')
# margin m
parser.add_argument('--m', default=-1.0, type=float,
help='margin')
# Module-level buffers for transfer-learning batches; they appear unused in
# this chunk — presumably filled/read elsewhere in the file (verify).
source_input_list, source_label_list = [], []
CALTECH101_list, ImageNet_list = [], []

# Optional NVIDIA Apex support: fall back cleanly when apex is not installed.
try:
    from apex import amp
    from apex.parallel import DistributedDataParallel as ApexDDP
    from apex.parallel import convert_syncbn_model
    has_apex = True
except ImportError:
    has_apex = False

# Detect native torch.cuda.amp autocast (available in newer PyTorch builds).
has_native_amp = False
try:
    if getattr(torch.cuda.amp, 'autocast') is not None:
        has_native_amp = True
except AttributeError:
    pass
def _parse_args():
    """Parse command-line arguments, honouring an optional YAML config file.

    First the lightweight ``config_parser`` extracts only ``--config``; if a
    file was given, its key-value pairs become the new defaults of the main
    ``parser``. The main parser then consumes the remaining argv, so explicit
    CLI flags still win over config-file values.

    Returns:
        Tuple ``(args, args_text)`` — the parsed namespace and its YAML dump
        (handy for archiving the run configuration in the output directory).
    """
    config_ns, leftover = config_parser.parse_known_args()
    if config_ns.config:
        # Config-file values override the parser defaults before the real parse.
        with open(config_ns.config, 'r') as cfg_file:
            overrides = yaml.safe_load(cfg_file)
            parser.set_defaults(**overrides)

    parsed = parser.parse_args(leftover)

    # Serialize the final configuration for later persistence alongside results.
    serialized = yaml.safe_dump(parsed.__dict__, default_flow_style=False)
    return parsed, serialized
def main():
    """Entry point for RGB->DVS transfer training.

    Parses args, builds the spiking model, configures (optionally
    distributed / mixed-precision) training state, constructs the source
    (RGB) and target (DVS) data loaders, precomputes per-class index tables
    for class-conditional source sampling, then runs the epoch loop of
    train_epoch()/validate() with checkpointing.
    """
    torch.set_num_threads(20)
    os.environ["OMP_NUM_THREADS"] = "20"  # thread count for the OpenMP compute library
    os.environ["MKL_NUM_THREADS"] = "20"  # thread count for the MKL-DNN CPU acceleration library
    args, args_text = _parse_args()
    args.no_spike_output = True
    output_dir = ''
    # Only rank 0 creates the experiment output directory and file logger.
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        # Experiment name encodes the hyper-parameters that matter for this study.
        exp_name = '-'.join([
            args.model,
            args.target_dataset,
            str(args.step),
            "bs_{}".format(args.batch_size),
            "seed_{}".format(args.seed),
            "DA_{}".format(args.DVS_DA),
            "ls_{}".format(args.smoothing),
            "lr_{}".format(args.lr),
            "m_{}".format(args.m),
            "domainLoss_{}".format(args.domain_loss),
            "semanticLoss_{}".format(args.semantic_loss),
            "domain_loss_coefficient{}".format(args.domain_loss_coefficient),
            "semantic_loss_coefficient{}".format(args.semantic_loss_coefficient),
            "traindataratio_{}".format(args.traindata_ratio),
            "TETfirst_{}".format(args.TET_loss_first),
            "TETsecond_{}".format(args.TET_loss_second),
        ])
        output_dir = get_outdir(output_base, 'train_TCKA_test_nop', exp_name)
        args.output_dir = output_dir
        setup_default_logging(log_path=os.path.join(output_dir, 'log.txt'))
    else:
        setup_default_logging()

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    # torchrun/launch sets WORLD_SIZE; >1 means multi-process training.
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
        if args.distributed and args.num_gpu > 1:
            _logger.warning(
                'Using more than one GPU per process in distributed mode is not allowed.Setting num_gpu to 1.')
            args.num_gpu = 1

    # args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.num_gpu = 1
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    else:
        # NOTE(review): assumes args.device is an int GPU index here — confirm
        # against the --device argument declared earlier in this file.
        torch.cuda.set_device('cuda:%d' % args.device)
    assert args.rank >= 0

    if args.distributed:
        _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
                     % (args.rank, args.world_size))
    else:
        _logger.info('Training with a single process on %d GPUs.' % args.num_gpu)

    # torch.manual_seed(args.seed + args.rank)
    setup_seed(args.seed + args.rank)

    # Build the SNN; node/act-fun classes are resolved from their string names.
    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        adaptive_node=args.adaptive_node,
        dataset=args.target_dataset,
        step=args.step,
        encode_type=args.encode,
        node_type=eval(args.node_type),
        threshold=args.threshold,
        tau=args.tau,
        sigmoid_thres=args.sigmoid_thres,
        requires_thres_grad=args.requires_thres_grad,
        spike_output=not args.no_spike_output,
        act_fun=args.act_fun,
        temporal_flatten=args.temporal_flatten,
        layer_by_layer=args.layer_by_layer,
        n_groups=args.n_groups,
    )

    # Input channel count depends on the target modality (DVS has 2 polarity channels).
    if 'dvs' in args.target_dataset:
        args.channels = 2
    elif 'mnist' in args.target_dataset:
        args.channels = 1
    else:
        args.channels = 3
    # flops, params = profile(model, inputs=(torch.randn(1, args.channels, args.event_size, args.event_size),), verbose=False)
    # _logger.info('flops = %fM', flops / 1e6)
    # _logger.info('param size = %fM', params / 1e6)

    # Linear LR scaling with global batch size (reference batch size 1024).
    linear_scaled_lr = args.lr * args.batch_size * args.world_size / 1024.0
    args.lr = linear_scaled_lr
    _logger.info("learning rate is %f" % linear_scaled_lr)

    if args.local_rank == 0:
        _logger.info('Model %s created, param count: %d' %
                     (args.model, sum([m.numel() for m in model.parameters()])))

    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    # Resolve which AMP implementation (if any) to use.
    use_amp = None
    if args.amp:
        # for backwards compat, `--amp` arg tries apex before native amp
        if has_apex:
            args.apex_amp = True
        elif has_native_amp:
            args.native_amp = True
    if args.apex_amp and has_apex:
        use_amp = 'apex'
    elif args.native_amp and has_native_amp:
        use_amp = 'native'
    elif args.apex_amp or args.native_amp:
        _logger.warning("Neither APEX or native Torch AMP is available, using float32. "
                        "Install NVIDA apex or upgrade to PyTorch 1.6")

    if args.num_gpu > 1:
        if use_amp == 'apex':
            _logger.warning(
                'Apex AMP does not work well with nn.DataParallel, disabling. Use DDP or Torch AMP.')
            use_amp = None
        model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
        assert not args.channels_last, "Channels last not supported with DP, use DDP."
    else:
        model = model.cuda()
        if args.channels_last:
            model = model.to(memory_format=torch.channels_last)

    optimizer = create_optimizer(args, model)

    amp_autocast = suppress  # do nothing
    loss_scaler = None
    if use_amp == 'apex':
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        loss_scaler = ApexScaler()
        if args.local_rank == 0:
            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif use_amp == 'native':
        amp_autocast = torch.cuda.amp.autocast
        loss_scaler = NativeScaler()
        if args.local_rank == 0:
            _logger.info('Using native Torch AMP. Training in mixed precision.')
    else:
        if args.local_rank == 0:
            _logger.info('AMP not enabled. Training in float32.')

    # optionally resume from a checkpoint
    resume_epoch = None
    if args.resume and args.eval_checkpoint == '':
        args.eval_checkpoint = args.resume
    if args.resume:
        # Note: --resume also forces evaluation mode in this script.
        args.eval = True
        resume_epoch = resume_checkpoint(
            model, args.resume,
            optimizer=None if args.no_resume_opt else optimizer,
            loss_scaler=None if args.no_resume_opt else loss_scaler,
            log_info=args.local_rank == 0)

    if args.critical_loss or args.spike_rate:
        model.set_requires_fp(True)

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(
            model,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else '',
            resume=args.resume)

    if args.node_resume:
        ckpt = torch.load(args.node_resume, map_location='cpu')
        model.load_node_weight(ckpt, args.node_trainable)

    model_without_ddp = model
    if args.distributed:
        if args.sync_bn:
            assert not args.split_bn
            try:
                if has_apex and use_amp != 'native':
                    # Apex SyncBN preferred unless native amp is activated
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
                if args.local_rank == 0:
                    _logger.info(
                        'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                        'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
            except Exception as e:
                _logger.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')
        if has_apex and use_amp != 'native':
            # Apex DDP preferred unless native amp is activated
            if args.local_rank == 0:
                _logger.info("Using NVIDIA APEX DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                _logger.info("Using native Torch DistributedDataParallel.")
            model = NativeDDP(model, device_ids=[args.local_rank],
                              find_unused_parameters=True)  # can use device str in Torch >= 1.1
        model_without_ddp = model.module
    # NOTE: EMA model does not need to be wrapped by DDP

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        _logger.info('Scheduled epochs: {}'.format(num_epochs))

    # now config only for imnet
    data_config = resolve_data_config(vars(args), model=model, verbose=False)
    # Loader factories are dispatched by dataset name via eval().
    source_loader_train, _, _, _ = eval('get_transfer_%s_data' % args.source_dataset)(
        batch_size=args.batch_size,
        step=args.step,
        args=args,
        _logge=_logger,
        data_config=data_config,
        num_aug_splits=num_aug_splits,
        size=args.event_size,
        mix_up=args.mix_up,
        cut_mix=args.cut_mix,
        event_mix=args.event_mix,
        beta=args.cutmix_beta,
        prob=args.cutmix_prob,
        gaussian_n=args.gaussian_n,
        num=args.cutmix_num,
        noise=args.cutmix_noise,
        num_classes=args.num_classes,
        rand_aug=args.rand_aug,
        randaug_n=args.randaug_n,
        randaug_m=args.randaug_m,
        portion=args.train_portion,
        _logger=_logger,
    )
    target_loader_train, target_loader_eval, mixup_active, mixup_fn = eval('get_%s_data' % args.target_dataset)(
        batch_size=args.batch_size,
        dvs_da=args.DVS_DA,
        step=args.step,
        args=args,
        _logge=_logger,
        data_config=data_config,
        num_aug_splits=num_aug_splits,
        size=args.event_size,
        mix_up=args.mix_up,
        cut_mix=args.cut_mix,
        event_mix=args.event_mix,
        beta=args.cutmix_beta,
        prob=args.cutmix_prob,
        gaussian_n=args.gaussian_n,
        num=args.cutmix_num,
        noise=args.cutmix_noise,
        num_classes=args.num_classes,
        rand_aug=args.rand_aug,
        randaug_n=args.randaug_n,
        randaug_m=args.randaug_m,
        portion=args.train_portion,
        _logger=_logger,
        train_data_ratio=args.traindata_ratio,
        snr=args.snr,
        data_mode="full",
        frames_num=12,
        data_type="frequency"
    )

    global source_input_list, source_label_list, CALTECH101_list, ImageNet_list
    # Cache one (large) batch of source data for class-conditional sampling.
    # For ImageNet the returned "loader" is actually the dataset; it is handled later.
    if args.target_dataset == "dvsc10" or args.target_dataset == "NCALTECH101" or args.target_dataset == "nomni":
        source_input_list, source_label_list = next(iter(source_loader_train))
    # for i in range(30001, 30005):
    #     # vis origin picture
    #     plt.figure()
    #     plt.imshow(source_input_list[i].permute(1, 2, 0).numpy())
    #     plt.savefig("./origin_image.jpg")
    #     plt.show()
    # vis HSV picture
    # for i in range(30001, 30005):  # 30001.i
    #     convertor = RGB_HSV()
    #     plt.figure()
    #     plt.imshow(convertor.rgb_to_hsv(source_input_list)[i, :, :, :].permute(1, 2, 0).numpy())
    #     plt.title("HSV image")
    #     plt.show()

    # Build per-class prefix-sum tables so that sample index ranges for class c
    # are [list[c], list[c + 1]).
    if args.source_dataset == "CALTECH101":
        cls_count = [438, 435, 200, 791, 49, 800, 41, 34, 45, 50, 45, 32, 128, 84, 38, 81, 86, 47, 40, 0, 45, 58, 61,
                     105, 47, 64, 70, 68, 50, 51, 54, 67, 51, 64, 65, 72, 62, 52, 60, 83, 65, 67, 45, 31, 34, 49, 99,
                     100, 42, 54, 86, 80, 30, 62, 86, 110, 61, 79, 77, 40, 65, 42, 35, 77, 31, 74, 49, 32, 39, 47, 35,
                     43, 52, 34, 54, 69, 58, 45, 38, 57, 34, 84, 57, 31, 54, 45, 82, 56, 63, 35, 85, 43, 82, 74, 239,
                     37, 53, 33, 55, 29, 42]
        CALTECH101_list = [0] * 102  # one extra slot to simplify the prefix-sum computation
        for i in range(1, len(cls_count) + 1):
            CALTECH101_list[i] = CALTECH101_list[i - 1] + cls_count[i - 1]
    if args.source_dataset == "NCALTECH101":
        cls_count = tonic.datasets.NCALTECH101.cls_count
        CALTECH101_list = [0] * 102  # one extra slot to simplify the prefix-sum computation
        for i in range(1, len(cls_count) + 1):
            CALTECH101_list[i] = CALTECH101_list[i - 1] + cls_count[i - 1]
    if args.source_dataset == "imnet":
        cls_count = [1300] * 1000  # 1000 classes
        # Some ImageNet classes have fewer than 1300 images; patch those counts in.
        cls_count_idx = [1117, 1266, 1071, 1141, 1272, 1150, 772, 860, 1136, 732, 1025, 754, 1290, 738, 1258, 1273, 977,
                         936, 1156, 1218, 969, 954, 1070, 755, 1206, 1165, 969, 1292, 1236, 1199, 1209, 1176, 1186,
                         1194,
                         1067, 1029, 1154, 1216, 1187, 889, 1211, 1136, 1153, 1222, 1282, 1283, 980, 1034, 891, 1285,
                         986,
                         1137, 1272, 1155, 1097, 1149, 1155, 1159, 1133, 1180, 1120, 1005, 1152, 1156, 962, 1157, 1282,
                         1117, 1118, 1270, 1069, 1053, 1254, 908, 1247, 1253, 1029, 1259, 1267, 1249, 1162, 1045, 1004,
                         1238, 1153, 1084, 1217, 931, 1264, 976, 1250, 1053, 1160, 1062, 1137, 1299, 1055, 1213, 1206,
                         1154,
                         1207, 1149, 1239, 1125, 1193]
        cls_idx = [43, 51, 62, 98, 103, 147, 152, 158, 164, 165, 166, 167, 168, 175, 181, 183, 188, 190, 194, 206, 221,
                   252, 262, 268, 335, 390, 392, 409, 418, 426, 439, 465, 481, 491, 499, 501, 503, 507, 521, 531, 536,
                   550, 551, 567, 577, 583, 585, 590, 596, 610, 623, 630, 631, 635, 653, 662, 663, 675, 676, 678, 686,
                   689, 706, 708, 712, 714, 722, 723, 724, 727, 728, 729, 731, 740, 747, 753, 771, 772, 782, 789, 798,
                   810, 811, 812, 821, 826, 838, 841, 854, 857, 860, 869, 872, 885, 891, 892, 901, 906, 914, 921, 925,
                   926, 940, 946, 969]
        for i in range(len(cls_count)):
            if i in cls_idx:
                cls_count[i] = cls_count_idx[cls_idx.index(i)]
        ImageNet_list = [0] * 1001  # one extra slot to simplify the prefix-sum computation
        for i in range(1, 1000 + 1):
            ImageNet_list[i] = ImageNet_list[i - 1] + cls_count[i - 1]

    # Select train/validation criteria.
    if args.loss_fn == 'mse':
        train_loss_fn = UnilateralMse(1.)
        validate_loss_fn = UnilateralMse(1.)
    else:
        if args.jsd:
            assert num_aug_splits > 1  # JSD only valid with aug splits set
            train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()
        elif mixup_active:
            # smoothing is handled with mixup target transform
            train_loss_fn = SoftTargetCrossEntropy().cuda()
        elif args.smoothing:
            train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
        else:
            train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()

    if args.loss_fn == 'mix':
        train_loss_fn = MixLoss(train_loss_fn)
        validate_loss_fn = MixLoss(validate_loss_fn)

    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None

    if args.eval:  # evaluate the model
        if args.distributed:
            state_dict = torch.load(args.eval_checkpoint)['state_dict_ema']
            new_state_dict = OrderedDict()
            # add module prefix for DDP
            for k, v in state_dict.items():
                k = 'module.' + k
                new_state_dict[k] = v
            model.load_state_dict(new_state_dict)
        else:
            model.load_state_dict(torch.load(args.eval_checkpoint)['state_dict'])
        for i in range(1):
            val_metrics = validate(start_epoch, model, target_loader_eval, validate_loss_fn, args,
                                   visualize=args.visualize, spike_rate=args.spike_rate,
                                   tsne=args.tsne, conf_mat=args.conf_mat)
            print(f"Top-1 accuracy of the model is: {val_metrics['top1']:.1f}%")
        return

    saver = None
    if args.local_rank == 0:
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(
            model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
            checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, max_history=1)
        # Persist the resolved args so the run is reproducible.
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    eval_top1 = 0.0
    try:  # train the model
        if args.reset_drop:
            model_without_ddp.reset_drop_path(0.0)
        for epoch in range(start_epoch, args.epochs):
            if epoch == 0 and args.reset_drop:
                model_without_ddp.reset_drop_path(args.drop_path)
            if args.distributed:
                target_loader_train.sampler.set_epoch(epoch)
            train_metrics = train_epoch(
                epoch, model, source_loader_train, target_loader_train, optimizer, train_loss_fn, args,
                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)
            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    _logger.info("Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
            eval_metrics = validate(epoch, model, target_loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast,
                                    visualize=args.visualize, spike_rate=args.spike_rate,
                                    tsne=args.tsne, conf_mat=args.conf_mat)
            eval_top1 = eval_metrics["top1"]
            if model_ema is not None and not args.model_ema_force_cpu:
                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                    distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
                # EMA-model metrics replace the raw-model metrics when EMA is on.
                ema_eval_metrics = validate(
                    epoch, model_ema.ema, target_loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)',
                    visualize=args.visualize, spike_rate=args.spike_rate,
                    tsne=args.tsne, conf_mat=args.conf_mat)
                eval_metrics = ema_eval_metrics
            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
            update_summary(
                epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
                write_header=best_metric is None)
            # if saver is not None and epoch >= args.n_warm_up:
            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)
            # if epoch == 299:  # temporary
            #     break
    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
def train_epoch(
        epoch, model, source_loader, target_loader, optimizer, loss_fn, args,
        lr_scheduler=None, saver=None, output_dir='', amp_autocast=suppress,
        loss_scaler=None, model_ema=None, mixup_fn=None):
    """Train one epoch of the RGB->DVS transfer model.

    For every target (DVS) batch, a class-matched source (RGB) batch is
    sampled from the module-level caches/prefix-sum tables; as training
    progresses, source samples are progressively replaced by the DVS inputs
    themselves (cubic schedule ``P_Replacement``). The model consumes both
    streams and the loss combines the DVS classification loss with optional
    CKA-based domain and semantic alignment losses and optional TET terms.

    Returns an OrderedDict of averaged loss components for the epoch.
    """
    if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
        if args.prefetcher and target_loader.mixup_enabled:
            target_loader.mixup_enabled = False
        elif mixup_fn is not None:
            mixup_fn.mixup_enabled = False

    second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    losses_m = AverageMeter()
    domain_losses_m = AverageMeter()
    semantic_losses_m = AverageMeter()
    rgb_losses_m = AverageMeter()
    dvs_losses_m = AverageMeter()
    closses_m = AverageMeter()
    top1_m = AverageMeter()
    top5_m = AverageMeter()

    model.train()

    end = time.time()
    last_idx = len(target_loader) - 1
    num_updates = epoch * len(target_loader)
    convertor = RGB_HSV()
    batch_len = len(target_loader)
    # Epoch by which replacement probability reaches 1 (both branches
    # currently use the same value; kept for per-dataset tuning).
    if args.target_dataset == "dvsc10":
        set_MaxReplacement_epoch = 0.5 * args.epochs
    else:
        set_MaxReplacement_epoch = 0.5 * args.epochs
    P_Replacement = 0.0
    global source_input_list, source_label_list, CALTECH101_list, ImageNet_list
    for batch_idx, (inputs, label) in enumerate(target_loader):
        # Cubic ramp of the probability of replacing a source sample with the
        # corresponding DVS input, capped at 1.0.
        P_Replacement = ((batch_idx + epoch * batch_len) / (set_MaxReplacement_epoch * batch_len)) ** 3
        P_Replacement = P_Replacement if P_Replacement <= 1.0 else 1.0
        # Build per-example source-sample indices that match each target label.
        sampler_list = label.tolist()
        if args.target_dataset == "dvsc10" and args.source_dataset == "cifar10":
            # CIFAR-10 cache is class-contiguous with 6000 samples per class.
            sampler_list = torch.tensor(sampler_list) * 6000 + torch.randint(0, 6000, (len(sampler_list),))
        elif args.target_dataset == "dvsc10" and args.source_dataset == "dvsc10":
            sampler_list = torch.tensor(sampler_list) * 900 + torch.randint(0, 900, (len(sampler_list),))
        elif args.target_dataset == "NCALTECH101":
            tmp_sampler_list = []
            idx_list = []
            for idx, label_sampler in enumerate(sampler_list):
                # Class 19 has zero source samples (see cls_count); remember its
                # positions so they can be patched with the DVS inputs below.
                if label_sampler == 19:
                    tmp_sampler_list.append(0)
                    idx_list.append(idx)
                else:
                    tmp_sampler_list.append(torch.randint(CALTECH101_list[label_sampler],
                                                          CALTECH101_list[label_sampler + 1], (1,)).item())
        elif args.target_dataset == "esimnet":
            tmp_sampler_list = []
            for idx, label_sampler in enumerate(sampler_list):  # label_sampler here comes from a list
                tmp_sampler_list.append(torch.randint(ImageNet_list[label_sampler],
                                                      ImageNet_list[label_sampler + 1], (1,)).item())
        elif args.target_dataset == "nomni":
            sampler_list = torch.tensor(sampler_list) * 20 + torch.randint(0, 20, (len(sampler_list),))

        # Fetch the class-matched source batch.
        source_input, source_label = [], []
        if args.target_dataset == "dvsc10":
            source_input, source_label = source_input_list[sampler_list], source_label_list[sampler_list]
        if args.target_dataset == "NCALTECH101":
            source_input, source_label = source_input_list[tmp_sampler_list], source_label_list[tmp_sampler_list]
        if args.target_dataset == "esimnet":
            train_dataset = source_loader  # give the returned object a proper name (it is a dataset here)
            # Build a throwaway loader that yields exactly the sampled indices.
            source_loader_used = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=args.batch_size, shuffle=False,
                num_workers=8, pin_memory=True, sampler=TransferSampler(tmp_sampler_list))
            source_input, source_label = next(iter(source_loader_used))
        if args.target_dataset == "nomni":
            source_input, source_label = source_input_list[sampler_list], source_label_list[sampler_list]
        # for i in range(128):
        #     # vis origin picture
        #     plt.figure()
        #     plt.imshow(source_input[i].permute(1, 2, 0))
        #     plt.title("origin image")
        #     plt.show()
        #     # vis HSV picture
        #     plt.figure()
        #     plt.imshow(convertor.rgb_to_hsv(inputs)[7, :, :, :].permute(1, 2, 0).numpy())
        #     plt.title("HSV image")
        #     plt.show()
        # source_input = convertor.rgb_to_hsv(source_input)[:, -1, :, :].unsqueeze(1).repeat(1, args.step * 2, 1, 1)
        if args.source_dataset == "dvsc10" or args.source_dataset == "NCALTECH101":
            pass
        else:
            # Take the last channel of the RGB image and tile it over
            # step * 2 channels to mimic the DVS (step, 2, H, W) layout.
            source_input = source_input[:, -1, :, :].unsqueeze(1).repeat(1, args.step * 2, 1, 1)
        source_input = rearrange(source_input, 'b (t c) h w -> b t c h w', t=args.step)
        # Curriculum replacement: with probability P_Replacement, feed the DVS
        # sample itself on the source branch.
        for b in range(source_input.shape[0]):
            if rd.uniform(0, 1) <= P_Replacement:
                source_input[b] = inputs[b, :, :, :, :]
        # for i in range(10):
        #     # vis HSV picture for v channel
        #     plt.figure()
        #     plt.imshow(source_input[i][0].permute(1, 2, 0)[:, :, -1].unsqueeze(2))
        #     plt.title("HSV image for v channel")
        #     plt.show()
        # Classes with no source samples always get the DVS input.
        if args.target_dataset == "NCALTECH101" and len(idx_list) > 0:
            for i in range(len(idx_list)):
                source_input[idx_list[i]] = inputs[idx_list[i]]
        last_batch = batch_idx == last_idx
        data_time_m.update(time.time() - end)
        if not args.prefetcher or args.target_dataset != 'imnet':
            inputs, label = inputs.type(torch.FloatTensor).cuda(), label.cuda()
            # NOTE(review): source_label is re-assigned from the *target* labels
            # here; source samples were drawn class-matched, so the labels
            # presumably coincide — confirm this is intended.
            source_input, source_label = source_input.type(torch.FloatTensor).cuda(), label.cuda()
            if mixup_fn is not None:
                inputs, label = mixup_fn(inputs, label)
                source_input, source_label = mixup_fn(source_input, source_label)
        if args.channels_last:
            inputs = inputs.contiguous(memory_format=torch.channels_last)
            source_input = source_input.contiguous(memory_format=torch.channels_last)

        with amp_autocast():
            # Model returns per-layer features for both branches plus
            # per-step outputs for RGB and DVS heads.
            domain_rbg_list, domain_dvs_list, output_rgb, output_dvs = model(source_input, inputs)
            # compute semantic loss: for each example, pick a random example of
            # a *different* class present in the batch.
            label_idx = [[] for i in range(args.num_classes)]
            semantic_label_list = []
            for idx, i in enumerate(label):
                label_idx[i.item()].append(idx)
            for i in label:
                while True:
                    label_tmp = torch.randint(0, args.num_classes, (1,)).item()
                    if i.item() != label_tmp and len(label_idx[label_tmp]) > 0:  # NCALTECH101 may have empty class lists, must check
                        break
                semantic_label_list.append(int(np.random.choice(label_idx[label_tmp], 1)))
            semantic_rbg_list = []
            semantic_loss = 0.
            for i in range(len(domain_rbg_list)):
                semantic_rbg_list.append(domain_rbg_list[i][semantic_label_list])
            # Semantic loss: CKA similarity between DVS features and
            # *mismatched-class* RGB features (to be pushed below margin m).
            for i in range(len(domain_rbg_list)):
                semantic_loss += torch.abs(CKA.linear_CKA(domain_dvs_list[i].view(args.batch_size, -1), semantic_rbg_list[i].view(args.batch_size, -1)))
            semantic_loss /= len(domain_rbg_list)
            # Per-dataset default margin, overridable via --m.
            if args.target_dataset == "dvsc10":
                m = 0.1
            elif args.target_dataset == "NCALTECH101":
                m = 0.3
            else:
                m = 0.2
            if args.m >= 0.0:
                m = args.m
            # Hinge at the margin: no gradient once similarity is below m.
            if semantic_loss.item() - m <= 0:
                semantic_loss = torch.tensor(0., device=semantic_loss.device)
            # if args.domain_loss_after:
            #     # compute domain loss
            #     for b in range(source_input.shape[0]):
            #         if rd.uniform(0, 1) <= P_Replacement:
            #             for i in range(len(domain_rbg_list)):
            #                 domain_rbg_list[i][b] = domain_dvs_list[i][b, :, :, :]
            # Domain loss: 1 - CKA similarity between matched RGB/DVS features.
            domain_loss = 0.
            for i in range(len(domain_rbg_list)):
                domain_loss += 1 - torch.abs(CKA.linear_CKA(domain_rbg_list[i].view(args.batch_size, -1), domain_dvs_list[i].view(args.batch_size, -1)))
            domain_loss /= len(domain_rbg_list)
            # compute cls loss
            lamb = 1e-3
            # TET branch: the first term is always required — i.e. the tested
            # variants are "first term only" and "first term plus second term".
            if args.TET_loss_first or args.TET_loss_second:
                loss_rgb = 0
                tet_loss_first = 0
                tet_loss_second = 0
                assert len(output_rgb) == len(output_dvs)
                # Average the per-timestep classification losses.
                for i in range(len(output_rgb)):
                    loss_rgb += loss_fn(output_rgb[i], label)
                    tet_loss_first += loss_fn(output_dvs[i], label)
                loss_rgb /= len(output_rgb)
                tet_loss_first /= len(output_dvs)
                if args.TET_loss_second:
                    # Regularize the last-step output toward the firing threshold.
                    y = torch.zeros_like(output_dvs[-1]).fill_(args.threshold)
                    secondLoss = torch.nn.MSELoss()
                    tet_loss_second = secondLoss(output_dvs[-1], y)
                else:
                    lamb = 0.0
                loss_dvs = (1 - lamb) * tet_loss_first + lamb * tet_loss_second
                output_rgb = sum(output_rgb) / len(output_rgb)
                output_dvs = sum(output_dvs) / len(output_dvs)
            else:
                output_rgb = sum(output_rgb) / len(output_rgb)
                output_dvs = sum(output_dvs) / len(output_dvs)
                loss_rgb = loss_fn(output_rgb, label)
                loss_dvs = loss_fn(output_dvs, label)
            # RGB classification loss is computed but zero-weighted.
            loss = 0 * loss_rgb + loss_dvs
            if args.domain_loss:
                loss += args.domain_loss_coefficient * domain_loss
            # Semantic loss is only applied during the replacement ramp-up
            # (and, for NCALTECH101, only after 66% of that ramp).
            if args.semantic_loss and epoch <= set_MaxReplacement_epoch:
                if args.target_dataset == "NCALTECH101" and epoch <= set_MaxReplacement_epoch * 0.66:
                    # loss += args.semantic_loss_coefficient * semantic_loss * math.pow(10, -1.0 * float(set_MaxReplacement_epoch / (epoch+1)))
                    pass
                else:
                    loss += args.semantic_loss_coefficient * semantic_loss

        # Accuracy is meaningless under mixup-style label smoothing; report 0.
        if not (args.cut_mix | args.mix_up | args.event_mix) and args.target_dataset != 'imnet':
            acc1, acc5 = accuracy(output_dvs, label, topk=(1, 5))
        else:
            acc1, acc5 = torch.tensor([0.]), torch.tensor([0.])

        closs = torch.tensor([0.], device=loss.device)
        loss = loss + .1 * closs

        spike_rate_avg_layer_str = ''
        threshold_str = ''
        if not args.distributed:
            losses_m.update(loss.item(), inputs.size(0))
            domain_losses_m.update(domain_loss.item(), inputs.size(0))
            semantic_losses_m.update(semantic_loss.item(), inputs.size(0))
            rgb_losses_m.update(loss_rgb.item(), inputs.size(0))
            dvs_losses_m.update(loss_dvs.item(), inputs.size(0))
            top1_m.update(acc1.item(), inputs.size(0))
            top5_m.update(acc5.item(), inputs.size(0))
            closses_m.update(closs.item(), inputs.size(0))
            spike_rate_avg_layer = model.get_fire_rate().tolist()
            spike_rate_avg_layer_str = ['{:.3f}'.format(i) for i in spike_rate_avg_layer]
            threshold = model.get_threshold()
            threshold_str = ['{:.3f}'.format(i.item()) for i in threshold]

        optimizer.zero_grad()
        if loss_scaler is not None:
            # AMP path: scaler handles backward, unscale, clipping and step.
            loss_scaler(
                loss, optimizer, clip_grad=args.clip_grad, parameters=model.parameters(), create_graph=second_order)
        else:
            loss.backward(create_graph=second_order)
            if args.noisy_grad != 0.:
                random_gradient(model, args.noisy_grad)
            if args.clip_grad is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
            if args.opt == 'lamb':
                optimizer.step(epoch=epoch)
            else:
                optimizer.step()

        torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)
        num_updates += 1

        batch_time_m.update(time.time() - end)
        if last_batch or batch_idx % args.log_interval == 0:
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)
            mu_str = ''
            sigma_str = ''
            if not args.distributed:
                if 'Noise' in args.node_type:
                    mu, sigma = model.get_noise_param()
                    mu_str = ['{:.3f}'.format(i.detach()) for i in mu]
                    sigma_str = ['{:.3f}'.format(i.detach()) for i in sigma]

            if args.distributed:
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                losses_m.update(reduced_loss.item(), inputs.size(0))
                closses_m.update(reduced_loss.item(), inputs.size(0))

            if args.local_rank == 0:
                if args.distributed:
                    _logger.info(
                        'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
                        'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) '
                        'cLoss: {closs.val:>9.6f} ({closs.avg:>6.4f}) '
                        'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
                        'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f}) '
                        'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
                        '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                        'LR: {lr:.3e} '
                        'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
                            epoch,
                            batch_idx, len(target_loader),
                            100. * batch_idx / last_idx,
                            loss=losses_m,
                            closs=closses_m,
                            top1=top1_m,
                            top5=top5_m,
                            batch_time=batch_time_m,
                            rate=inputs.size(0) * args.world_size / batch_time_m.val,
                            rate_avg=inputs.size(0) * args.world_size / batch_time_m.avg,
                            lr=lr,
                            data_time=data_time_m
                        ))
                else:
                    _logger.info(
                        'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
                        'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) '
                        'cLoss: {closs.val:>9.6f} ({closs.avg:>6.4f}) '
                        'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
                        'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f}) '
                        'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
                        '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                        'LR: {lr:.3e} '
                        'Data: {data_time.val:.3f} ({data_time.avg:.3f})\n'
                        'Fire_rate: {spike_rate}\n'
                        'Thres: {threshold}\n'
                        'Mu: {mu_str}\n'
                        'Sigma: {sigma_str}\n'
                        'P_Replacement: {P_Replacement}\n'.format(
                            epoch,
                            batch_idx, len(target_loader),
                            100. * batch_idx / last_idx,
                            loss=losses_m,
                            closs=closses_m,
                            top1=top1_m,
                            top5=top5_m,
                            batch_time=batch_time_m,
                            rate=inputs.size(0) * args.world_size / batch_time_m.val,
                            rate_avg=inputs.size(0) * args.world_size / batch_time_m.avg,
                            lr=lr,
                            data_time=data_time_m,
                            spike_rate=spike_rate_avg_layer_str,
                            threshold=threshold_str,
                            mu_str=mu_str,
                            sigma_str=sigma_str,
                            P_Replacement=P_Replacement,
                        ))

                if args.save_images and output_dir:
                    torchvision.utils.save_image(
                        inputs,
                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
                        padding=0,
                        normalize=True)

        if saver is not None and args.recovery_interval and (
                last_batch or (batch_idx + 1) % args.recovery_interval == 0):
            saver.save_recovery(epoch, batch_idx=batch_idx)

        if lr_scheduler is not None:
            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)

        end = time.time()
        # end for

    if hasattr(optimizer, 'sync_lookahead'):
        optimizer.sync_lookahead()

    return OrderedDict([('loss', losses_m.avg), ('domainLoss', domain_losses_m.avg), ('semanticLoss', semantic_losses_m.avg),
                        ('rgbLoss', rgb_losses_m.avg), ('dvsLoss', dvs_losses_m.avg)])
def validate(epoch, model, loader, loss_fn, args, amp_autocast=suppress,
             log_suffix='', visualize=False, spike_rate=False, tsne=False, conf_mat=False):
    """Evaluate the model on the target loader.

    The DVS input is fed to *both* model branches; only the DVS head's
    outputs are scored. Optional side effects (non-distributed only):
    feature-map dumps (``visualize``), spike statistics (``spike_rate``),
    t-SNE embeddings (``tsne``), and a confusion matrix (``conf_mat``),
    all written under ``args.output_dir``.

    Returns an OrderedDict with averaged 'loss' and 'top1'.
    """
    batch_time_m = AverageMeter()
    losses_m = AverageMeter()
    closses_m = AverageMeter()
    top1_m = AverageMeter()
    top5_m = AverageMeter()

    model.eval()

    # Accumulators for the optional t-SNE / confusion-matrix plots.
    feature_vec = []
    feature_cls = []
    logits_vec = []
    labels_vec = []

    end = time.time()
    last_idx = len(loader) - 1
    with torch.no_grad():
        for batch_idx, (inputs, target) in enumerate(loader):
            # inputs = inputs.type(torch.float64)
            last_batch = batch_idx == last_idx
            if not args.prefetcher or args.target_dataset != 'imnet':
                inputs = inputs.type(torch.FloatTensor).cuda()
                target = target.cuda()
            if args.channels_last:
                inputs = inputs.contiguous(memory_format=torch.channels_last)
            if not args.distributed:
                # Feature-map recording is only needed when a plot was requested.
                if (visualize or spike_rate or tsne or conf_mat) and not args.critical_loss:
                    model.set_requires_fp(True)
            # if not args.critical_loss:
            #     model.set_requires_fp(False)
            with amp_autocast():
                # Same tensor on both branches; RGB-branch output is discarded.
                _, _, output_rbg, output_dvs = model(inputs, inputs)
                # Average the per-timestep DVS outputs into one logit tensor.
                output = sum(output_dvs) / len(output_dvs)
            if isinstance(output, (tuple, list)):
                output = output[0]
            if not args.distributed:
                if visualize:
                    x = model.get_fp()
                    feature_path = os.path.join(args.output_dir, 'feature_map')
                    if os.path.exists(feature_path) is False:
                        os.mkdir(feature_path)
                    save_feature_map(x, feature_path)
                    # if not args.critical_loss:
                    #     model_config.set_requires_fp(False)
                if tsne:
                    # Global-average-pool the last recorded feature map.
                    x = model.get_fp(temporal_info=False)[-1]
                    x = torch.nn.AdaptiveAvgPool2d((1, 1))(x)
                    x = x.reshape(x.shape[0], -1)
                    feature_vec.append(x)
                    feature_cls.append(target)
                if conf_mat:
                    logits_vec.append(output)
                    labels_vec.append(target)
                if spike_rate:
                    avg, var, spike, avg_per_step = model.get_spike_info()
                    save_spike_info(
                        os.path.join(args.output_dir, 'spike_info.csv'),
                        epoch, batch_idx,
                        args.step, avg, var,
                        spike, avg_per_step)

            # augmentation reduction
            reduce_factor = args.tta
            if reduce_factor > 1:
                output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
                target = target[0:target.size(0):reduce_factor]
            loss = loss_fn(output, target)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            # acc1, = accuracy(output, target)
            closs = torch.tensor([0.], device=loss.device)

            if not args.distributed:
                spike_rate_avg_layer = model.get_fire_rate().tolist()
                threshold = model.get_threshold()
                threshold_str = ['{:.3f}'.format(i) for i in threshold]
                spike_rate_avg_layer_str = ['{:.3f}'.format(i) for i in spike_rate_avg_layer]
                tot_spike = model.get_tot_spike()

            if args.distributed:
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                acc1 = reduce_tensor(acc1, args.world_size)
                acc5 = reduce_tensor(acc5, args.world_size)
            else:
                reduced_loss = loss.data

            torch.cuda.synchronize()

            losses_m.update(reduced_loss.item(), inputs.size(0))
            top1_m.update(acc1.item(), output.size(0))
            top5_m.update(acc5.item(), output.size(0))
            closses_m.update(closs.item(), inputs.size(0))
            batch_time_m.update(time.time() - end)
            end = time.time()
            if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
                log_name = 'Test' + log_suffix
                mu_str = ''
                sigma_str = ''
                if not args.distributed:
                    if 'Noise' in args.node_type:
                        mu, sigma = model.get_noise_param()
                        mu_str = ['{:.3f}'.format(i.detach()) for i in mu]
                        sigma_str = ['{:.3f}'.format(i.detach()) for i in sigma]

                if args.distributed:
                    _logger.info(
                        '{0}: [{1:>4d}/{2}] '
                        'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                        'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                        'cLoss: {closs.val:>7.4f} ({closs.avg:>6.4f}) '
                        'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})'
                        'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
                            log_name,
                            batch_idx,
                            last_idx,
                            batch_time=batch_time_m,
                            loss=losses_m,
                            closs=closses_m,
                            top1=top1_m,
                            top5=top5_m,
                        ))
                else:
                    _logger.info(
                        '{0}: [{1:>4d}/{2}] '
                        'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                        'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                        'cLoss: {closs.val:>7.4f} ({closs.avg:>6.4f}) '
                        'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})'
                        'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})\n'
                        'Fire_rate: {spike_rate}\n'
                        'Tot_spike: {tot_spike}\n'
                        'Thres: {threshold}\n'
                        'Mu: {mu_str}\n'
                        'Sigma: {sigma_str}\n'.format(
                            log_name,
                            batch_idx,
                            last_idx,
                            batch_time=batch_time_m,
                            loss=losses_m,
                            closs=closses_m,
                            top1=top1_m,
                            top5=top5_m,
                            spike_rate=spike_rate_avg_layer_str,
                            tot_spike=tot_spike,
                            threshold=threshold_str,
                            mu_str=mu_str,
                            sigma_str=sigma_str
                        ))

    # metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])
    metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg)])

    if not args.distributed:
        if tsne:
            feature_vec = torch.cat(feature_vec)
            feature_cls = torch.cat(feature_cls)
            plot_tsne(feature_vec, feature_cls, os.path.join(args.output_dir, 't-sne-2d.eps'))
            plot_tsne_3d(feature_vec, feature_cls, os.path.join(args.output_dir, 't-sne-3d.eps'))
        if conf_mat:
            logits_vec = torch.cat(logits_vec)
            labels_vec = torch.cat(labels_vec)
            plot_confusion_matrix(logits_vec, labels_vec, os.path.join(args.output_dir, 'confusion_matrix.eps'))

    return metrics
# Script entry point: only run training when executed directly.
if __name__ == '__main__':
    main()
================================================
FILE: examples/Perception_and_Learning/img_cls/transfer_for_dvs/main_visual_losslandscape.py
================================================
# -*- coding: utf-8 -*-
# Time : 2023/2/14 11:52
# Author : Regulus
# FileName: main_visual_losslandscape.py
# Explain:
# Software: PyCharm
from loss_landscape.plot_surface import *
import argparse
import math
import time
import CKA
import numpy
import timm.models
import random as rd
import yaml
import os
import logging
from collections import OrderedDict
from contextlib import suppress
from datetime import datetime
from braincog.base.node.node import *
from braincog.utils import *
from braincog.base.utils.criterions import *
from braincog.datasets.datasets import *
from braincog.model_zoo.resnet import *
from braincog.model_zoo.convnet import *
from braincog.model_zoo.vgg_snn import VGG_SNN
from braincog.model_zoo.resnet19_snn import resnet19
from braincog.utils import save_feature_map, setup_seed
from braincog.base.utils.visualization import plot_tsne_3d, plot_tsne, plot_confusion_matrix
import torch
import torch.nn as nn
import torchvision.utils
from torch.nn.parallel import DistributedDataParallel as NativeDDP
from rgb_hsv import RGB_HSV
import matplotlib.pyplot as plt
from timm.data import ImageDataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
from timm.models import load_checkpoint, create_model, resume_checkpoint, convert_splitbn_model
from timm.utils import *
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
from timm.optim import create_optimizer
from timm.scheduler import create_scheduler
from timm.utils import ApexScaler, NativeScaler
torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('train')
# The first arg parser parses out only the --config argument, this argument is used to
# load a yaml file containing key-values that override the defaults for the main parser below
config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)
parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
help='YAML config file specifying default arguments')
parser = argparse.ArgumentParser(description='SNN Training and Evaluating')
# Model parameters
parser.add_argument('--source-dataset', default='cifar10', type=str)
parser.add_argument('--target-dataset', default='dvsc10', type=str)
parser.add_argument('--model', default='cifar_convnet', type=str, metavar='MODEL',
help='Name of model to train (default: "countception"')
parser.add_argument('--pretrained', action='store_true', default=False,
help='Start with pretrained version of specified network (if avail)')
parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
help='Initialize model from this checkpoint (default: none)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
help='Resume full model and optimizer state from checkpoint (default: none)')
parser.add_argument('--eval_checkpoint', default='', type=str, metavar='PATH',
help='path to eval checkpoint (default: none)')
parser.add_argument('--no-resume-opt', action='store_true', default=False,
help='prevent resume of optimizer state when resuming model')
parser.add_argument('--num-classes', type=int, default=10, metavar='N',
help='number of label classes (default: 1000)')
parser.add_argument('--gp', default=None, type=str, metavar='POOL',
help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
# Dataset parameters for static datasets
parser.add_argument('--img-size', type=int, default=224, metavar='N',
help='Image patch size (default: None => model default)')
parser.add_argument('--crop-pct', default=None, type=float,
metavar='N', help='inputs image center crop percent (for validation only)')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
help='Override std deviation of of dataset')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
help='Image resize interpolation type (overrides model)')
# Dataloader parameters
parser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',
help='inputs batch size for training (default: 128)')
parser.add_argument('-vb', '--validation-batch-size-multiplier', type=int, default=1, metavar='N',
help='ratio of validation batch size to training batch size (default: 1)')
# Optimizer parameters
parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
help='Optimizer (default: "adamw"')
parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
help='Optimizer Epsilon (default: None, use opt default)')
parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
help='Optimizer Betas (default: None, use opt default)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
help='Optimizer momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=0.01,
help='weight decay (default: 0.01 for adamw)')
parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
help='Clip gradient norm (default: None, no clipping)')
parser.add_argument('--adam-epoch', type=int, default=1000, help='lamb switch to adamw')
# Learning rate schedule parameters
parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
help='LR scheduler (default: "cosine"')
parser.add_argument('--lr', type=float, default=5e-3, metavar='LR',
help='learning rate (default: 0.01)')
parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
help='learning rate noise on/off epoch percentages')
parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
help='learning rate noise limit percent (default: 0.67)')
parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
help='learning rate noise std-dev (default: 1.0)')
parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
help='learning rate cycle len multiplier (default: 1.0)')
parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
help='learning rate cycle limit')
parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
help='warmup learning rate (default: 0.0001)')
parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
parser.add_argument('--epochs', type=int, default=600, metavar='N',
help='number of epochs to train (default: 2)')
parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
help='manual epoch number (useful on restarts)')
parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
help='epoch interval to decay LR')
parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
help='epochs to warmup LR, if scheduler supports')
parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
help='patience epochs for Plateau LR scheduler (default: 10')
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
help='LR decay rate (default: 0.1)')
parser.add_argument('--power', type=int, default=1, help='power')
# Augmentation & regularization parameters ONLY FOR IMAGE NET
parser.add_argument('--no-aug', action='store_true', default=False,
help='Disable all training augmentation, override other train aug args')
parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
help='Random resize scale (default: 0.08 1.0)')
parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
help='Random resize aspect ratio (default: 0.75 1.33)')
parser.add_argument('--hflip', type=float, default=0.5,
help='Horizontal flip training aug probability')
parser.add_argument('--vflip', type=float, default=0.,
help='Vertical flip training aug probability')
parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
help='Color jitter factor (default: 0.4)')
parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
help='Use AutoAugment policy. "v0" or "original". (default: None)'),
parser.add_argument('--aug-splits', type=int, default=0,
help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
parser.add_argument('--jsd', action='store_true', default=False,
help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
help='Random erase prob (default: 0.25)')
parser.add_argument('--remode', type=str, default='pixel',
help='Random erase mode (default: "const")')
parser.add_argument('--recount', type=int, default=1,
help='Random erase count (default: 1)')
parser.add_argument('--resplit', action='store_true', default=False,
help='Do not random erase first (clean) augmentation split')
parser.add_argument('--mixup', type=float, default=0.8,
help='mixup alpha, mixup enabled if > 0. (default: 0.)')
parser.add_argument('--cutmix', type=float, default=1.0,
help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')
parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
parser.add_argument('--mixup-prob', type=float, default=1.0,
help='Probability of performing mixup or cutmix when either/both is enabled')
parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
help='Probability of switching to cutmix when both mixup and cutmix enabled')
parser.add_argument('--mixup-mode', type=str, default='batch',
help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
parser.add_argument('--smoothing', type=float, default=0.1,
help='Label smoothing (default: 0.1)')
parser.add_argument('--train-interpolation', type=str, default='random',
help='Training interpolation (random, bilinear, bicubic default: "random")')
parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
help='Dropout rate (default: 0.0)')
parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT',
help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',
help='Drop path rate (default: None)')
parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
help='Drop block rate (default: None)')
parser.add_argument('--newton-maxiter', default=20, type=int,
help='max iterration in newton method')
parser.add_argument('--reset-drop', action='store_true', default=False,
help='whether to reset drop')
parser.add_argument('--kernel-method', type=str, default='cuda', choices=['torch', 'cuda'],
help='The implementation way of gaussian kernel method, choose from "cuda" and "torch"')
# Batch norm parameters (only works with gen_efficientnet based models currently)
parser.add_argument('--bn-tf', action='store_true', default=False,
help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')
parser.add_argument('--bn-momentum', type=float, default=None,
help='BatchNorm momentum override (if not None)')
parser.add_argument('--bn-eps', type=float, default=None,
help='BatchNorm epsilon override (if not None)')
parser.add_argument('--sync-bn', action='store_true',
help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
parser.add_argument('--dist-bn', type=str, default='',
help='Distribute BatchNorm stats between node after each epoch ("broadcast", "reduce", or "")')
parser.add_argument('--split-bn', action='store_true',
help='Enable separate BN layers per augmentation split.')
# Model Exponential Moving Average
parser.add_argument('--model-ema', action='store_true', default=False,
help='Enable tracking moving average of model weights')
parser.add_argument('--model-ema-force-cpu', action='store_true', default=False,
help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
parser.add_argument('--model-ema-decay', type=float, default=0.99996,
help='decay factor for model weights moving average (default: 0.9998)')
# Misc
parser.add_argument('--seed', type=int, default=42, metavar='S',
help='random seed (default: 42)')
parser.add_argument('--log-interval', type=int, default=50, metavar='N',
help='how many batches to wait before logging training status')
parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
help='how many batches to wait before writing recovery checkpoint')
parser.add_argument('-j', '--workers', type=int, default=8, metavar='N',
help='how many training processes to use (default: 1)')
parser.add_argument('--num-gpu', type=int, default=1,
help='Number of GPUS to use')
parser.add_argument('--save-images', action='store_true', default=False,
help='save images of inputs bathes every log interval for debugging')
parser.add_argument('--amp', action='store_true', default=False,
help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
parser.add_argument('--apex-amp', action='store_true', default=False,
help='Use NVIDIA Apex AMP mixed precision')
parser.add_argument('--native-amp', action='store_true', default=False,
help='Use Native Torch AMP mixed precision')
parser.add_argument('--channels-last', action='store_true', default=False,
help='Use channels_last memory layout')
parser.add_argument('--pin-mem', action='store_true', default=False,
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--no-prefetcher', action='store_true', default=False,
help='disable fast prefetcher')
parser.add_argument('--output', default='/home/hexiang/TransferLearning_For_DVS/Results_new_refined/', type=str, metavar='PATH',
help='path to output folder (default: none, current dir)')
parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
help='Best metric (default: "top1"')
parser.add_argument('--tta', type=int, default=0, metavar='N',
help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
parser.add_argument('--local_rank', default=0, type=int)
parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
help='use the multi-epochs-loader to save time at the beginning of every epoch')
parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
parser.add_argument('--device', type=int, default=0)
# Spike parameters
parser.add_argument('--step', type=int, default=10, help='Simulation time step (default: 10)')
parser.add_argument('--encode', type=str, default='direct', help='Input encode method (default: direct)')
parser.add_argument('--temporal-flatten', action='store_true',
help='Temporal flatten to channels. ONLY FOR EVENT DATA TRAINING BY ANN')
parser.add_argument('--adaptive-node', action='store_true')
parser.add_argument('--critical-loss', action='store_true')
# neuron type
parser.add_argument('--node-type', type=str, default='LIFNode', help='Node type in network (default: PLIF)')
parser.add_argument('--act-fun', type=str, default='GateGrad',
help='Surogate Function in node. Only for Surrogate nodes (default: AtanGrad)')
parser.add_argument('--threshold', type=float, default=.5, help='Firing threshold (default: 0.5)')
parser.add_argument('--tau', type=float, default=2., help='Attenuation coefficient (default: 2.)')
parser.add_argument('--requires-thres-grad', action='store_true')
parser.add_argument('--sigmoid-thres', action='store_true')
parser.add_argument('--loss-fn', type=str, default='ce', help='loss function (default: ce)')
parser.add_argument('--noisy-grad', type=float, default=0.,
help='Add noise to backward, sometime will make higher accuracy (default: 0.)')
parser.add_argument('--spike-output', action='store_true', default=False,
help='Using mem output or spike output (default: False)')
parser.add_argument('--n_groups', type=int, default=1)
# EventData Augmentation
parser.add_argument('--mix-up', action='store_true', help='Mix-up for event data (default: False)')
parser.add_argument('--cut-mix', action='store_true', help='CutMix for event data (default: False)')
parser.add_argument('--event-mix', action='store_true', help='EventMix for event data (default: False)')
parser.add_argument('--cutmix_beta', type=float, default=1.0, help='cutmix_beta (default: 1.)')
parser.add_argument('--cutmix_prob', type=float, default=0.5, help='cutmix_prib for event data (default: .5)')
parser.add_argument('--cutmix_num', type=int, default=1, help='cutmix_num for event data (default: 1)')
parser.add_argument('--cutmix_noise', type=float, default=0.,
help='Add Pepper noise after mix, sometimes work (default: 0.)')
parser.add_argument('--gaussian-n', type=int, default=3)
parser.add_argument('--rand-aug', action='store_true',
help='Rand Augment for Event data (default: False)')
parser.add_argument('--randaug_n', type=int, default=3,
help='Rand Augment times n (default: 3)')
parser.add_argument('--randaug_m', type=int, default=15,
help='Rand Augment times n (default: 15) (0-30)')
parser.add_argument('--train-portion', type=float, default=0.9,
help='Dataset portion, only for datasets which do not have validation set (default: 0.9)')
parser.add_argument('--event-size', default=48, type=int,
help='Event size. Resize event data before process (default: 48)')
parser.add_argument('--layer-by-layer', action='store_true',
help='forward step-by-step or layer-by-layer. '
'Larger Model with layer-by-layer will be faster (default: False)')
parser.add_argument('--node-resume', type=str, default='',
help='resume weights in node for adaptive node. (default: False)')
parser.add_argument('--node-trainable', action='store_true')
# visualize
parser.add_argument('--visualize', action='store_true',
help='Visualize spiking map for each layer, only for validate (default: False)')
parser.add_argument('--spike-rate', action='store_true',
help='Print spiking rate for each layer, only for validate(default: False)')
parser.add_argument('--tsne', action='store_true')
parser.add_argument('--conf-mat', action='store_true')
parser.add_argument('--suffix', type=str, default='',
help='Add an additional suffix to the save path (default: \'\')')
parser.add_argument('--DVS-DA', action='store_true',
help='use DA on DVS')
# train data used ratio
parser.add_argument('--traindata-ratio', default=1.0, type=float,
help='training data ratio')
# snr value
parser.add_argument('--snr', default=0, type=int,
help='random noise amplitude controled by snr, 0 means no noise')
# --------------------------------------------------------------------------
# Start the loss-landscape
# --------------------------------------------------------------------------
parser.add_argument('--mpi', '-m', action='store_true', help='use mpi')
parser.add_argument('--threads', default=2, type=int, help='number of threads')
parser.add_argument('--ngpu', type=int, default=1,
help='number of GPUs to use for each rank, useful for data parallel evaluation')
# model parameters
parser.add_argument('--model_folder', default='',
help='the common folder that contains model_file and model_file2')
parser.add_argument('--model_file', default='', help='path to the trained model file')
parser.add_argument('--model_file2', default='', help='use (model_file2 - model_file) as the xdirection')
parser.add_argument('--model_file3', default='', help='use (model_file3 - model_file) as the ydirection')
parser.add_argument('--loss_name', '-l', default='crossentropy', help='loss functions: crossentropy | mse')
# direction parameters
parser.add_argument('--dir_file', default='',
help='specify the name of direction file, or the path to an eisting direction file')
parser.add_argument('--dir_type', default='weights',
help='direction type: weights | states (including BN\'s running_mean/var)')
parser.add_argument('--x', default='-1:1:51', help='A string with format xmin:x_max:xnum')
parser.add_argument('--y', default=None, help='A string with format ymin:ymax:ynum')
parser.add_argument('--xnorm', default='', help='direction normalization: filter | layer | weight')
parser.add_argument('--ynorm', default='', help='direction normalization: filter | layer | weight')
parser.add_argument('--xignore', default='', help='ignore bias and BN parameters: biasbn')
parser.add_argument('--yignore', default='', help='ignore bias and BN parameters: biasbn')
parser.add_argument('--same_dir', action='store_true', default=False,
help='use the same random direction for both x-axis and y-axis')
parser.add_argument('--idx', default=0, type=int, help='the index for the repeatness experiment')
parser.add_argument('--surf_file', default='',
help='customize the name of surface file, could be an existing file.')
# plot parameters
parser.add_argument('--proj_file', default='', help='the .h5 file contains projected optimization trajectory.')
parser.add_argument('--loss_max', default=5, type=float, help='Maximum value to show in 1D plot')
parser.add_argument('--vmax', default=10, type=float, help='Maximum value to map')
parser.add_argument('--vmin', default=0.1, type=float, help='Miminum value to map')
parser.add_argument('--vlevel', default=1.0, type=float, help='plot contours every vlevel')
parser.add_argument('--show', action='store_true', default=False, help='show plotted figures')
parser.add_argument('--log', action='store_true', default=False, help='use log scale for loss values')
parser.add_argument('--plot', action='store_true', default=False, help='plot figures after computation')
# Optional NVIDIA Apex support: Apex AMP / DDP helpers are used only when the
# package is installed; otherwise the code falls back to native torch tooling.
try:
    from apex import amp
    from apex.parallel import DistributedDataParallel as ApexDDP
    from apex.parallel import convert_syncbn_model
    has_apex = True
except ImportError:
    has_apex = False

# Native (torch.cuda.amp) mixed precision is only available on sufficiently
# new PyTorch builds; probe for the autocast attribute instead of a version check.
has_native_amp = False
try:
    if getattr(torch.cuda.amp, 'autocast') is not None:
        has_native_amp = True
except AttributeError:
    pass
def _parse_args():
    """Parse CLI arguments, honoring an optional YAML config file.

    The lightweight ``config_parser`` first extracts only ``--config``; when a
    config file is supplied, its key/value pairs override the main parser's
    defaults before the remaining argv is parsed.

    Returns:
        tuple: ``(args, args_text)`` — the parsed namespace plus a YAML dump
        of it (handy for saving alongside checkpoints in the output dir).
    """
    config_args, leftover = config_parser.parse_known_args()
    if config_args.config:
        with open(config_args.config, 'r') as cfg_file:
            overrides = yaml.safe_load(cfg_file)
        # Values from the YAML file become the new defaults of the main parser.
        parser.set_defaults(**overrides)
    parsed = parser.parse_args(leftover)
    dumped = yaml.safe_dump(parsed.__dict__, default_flow_style=False)
    return parsed, dumped
def main():
    """Build model and dataloaders, load a trained checkpoint, and compute a
    loss-landscape surface around the trained weights.

    The x-range (and optional y-range) of the surface come from the ``--x`` /
    ``--y`` strings of the form ``min:max:num``. The surface file is written
    next to the evaluation checkpoint and optionally plotted at the end.
    """
    # Cap CPU threading (torch, OpenMP, MKL-DNN) so surface evaluation does
    # not oversubscribe the host.
    torch.set_num_threads(20)
    os.environ["OMP_NUM_THREADS"] = "20"  # OpenMP thread count
    os.environ["MKL_NUM_THREADS"] = "20"  # MKL-DNN thread count
    args, args_text = _parse_args()
    args.no_spike_output = True
    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    if args.distributed and args.num_gpu > 1:
        _logger.warning(
            'Using more than one GPU per process in distributed mode is not allowed.Setting num_gpu to 1.')
        args.num_gpu = 1
    # args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    assert args.rank >= 0
    if args.distributed:
        _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
                     % (args.rank, args.world_size))
    else:
        _logger.info('Training with a single process on %d GPUs.' % args.num_gpu)
    # torch.manual_seed(args.seed + args.rank)
    setup_seed(args.seed + args.rank)
    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        adaptive_node=args.adaptive_node,
        dataset=args.target_dataset,
        step=args.step,
        encode_type=args.encode,
        node_type=eval(args.node_type),
        threshold=args.threshold,
        tau=args.tau,
        sigmoid_thres=args.sigmoid_thres,
        requires_thres_grad=args.requires_thres_grad,
        spike_output=not args.no_spike_output,
        act_fun=args.act_fun,
        temporal_flatten=args.temporal_flatten,
        layer_by_layer=args.layer_by_layer,
        n_groups=args.n_groups,
        TET_loss=False
    )  # NOTE: TET_loss selects which loss the landscape is computed on.
    # Infer input channel count from the dataset family:
    # event (DVS) data has 2 polarity channels, MNIST-style is grayscale.
    if 'dvs' in args.target_dataset:
        args.channels = 2
    elif 'mnist' in args.target_dataset:
        args.channels = 1
    else:
        args.channels = 3
    # flops, params = profile(model, inputs=(torch.randn(1, args.channels, args.event_size, args.event_size),), verbose=False)
    # _logger.info('flops = %fM', flops / 1e6)
    # _logger.info('param size = %fM', params / 1e6)
    # Linear LR scaling rule (relative to a reference global batch of 1024).
    linear_scaled_lr = args.lr * args.batch_size * args.world_size / 1024.0
    args.lr = linear_scaled_lr
    _logger.info("learning rate is %f" % linear_scaled_lr)
    if args.local_rank == 0:
        _logger.info('Model %s created, param count: %d' %
                     (args.model, sum([m.numel() for m in model.parameters()])))
    # now config only for imnet
    data_config = resolve_data_config(vars(args), model=model, verbose=False)
    # Dataset loader factories are resolved by name, e.g. get_transfer_cifar10_data.
    # NOTE(review): both '_logge' and '_logger' are passed below — the former
    # looks like a typo absorbed by **kwargs; confirm against the factory API.
    source_loader_train, _, _, _ = eval('get_transfer_%s_data' % args.source_dataset)(
        batch_size=args.batch_size,
        step=args.step,
        args=args,
        _logge=_logger,
        data_config=data_config,
        size=args.event_size,
        mix_up=args.mix_up,
        cut_mix=args.cut_mix,
        event_mix=args.event_mix,
        beta=args.cutmix_beta,
        prob=args.cutmix_prob,
        gaussian_n=args.gaussian_n,
        num=args.cutmix_num,
        noise=args.cutmix_noise,
        num_classes=args.num_classes,
        rand_aug=args.rand_aug,
        randaug_n=args.randaug_n,
        randaug_m=args.randaug_m,
        portion=args.train_portion,
        _logger=_logger,
    )
    target_loader_train, target_loader_eval, mixup_active, mixup_fn = eval('get_%s_data' % args.target_dataset)(
        batch_size=args.batch_size,
        dvs_da=args.DVS_DA,
        step=args.step,
        args=args,
        _logge=_logger,
        data_config=data_config,
        size=args.event_size,
        mix_up=args.mix_up,
        cut_mix=args.cut_mix,
        event_mix=args.event_mix,
        beta=args.cutmix_beta,
        prob=args.cutmix_prob,
        gaussian_n=args.gaussian_n,
        num=args.cutmix_num,
        noise=args.cutmix_noise,
        num_classes=args.num_classes,
        rand_aug=args.rand_aug,
        randaug_n=args.randaug_n,
        randaug_m=args.randaug_m,
        portion=args.train_portion,
        _logger=_logger,
        train_data_ratio=args.traindata_ratio,
        snr=args.snr,
        data_mode="full",
        frames_num=12,
        data_type="frequency"
    )
    if args.eval:  # evaluate the model
        if args.distributed:
            state_dict = torch.load(args.eval_checkpoint)['state_dict_ema']
            new_state_dict = OrderedDict()
            # add module prefix for DDP
            for k, v in state_dict.items():
                k = 'module.' + k
                new_state_dict[k] = v
            model.load_state_dict(new_state_dict)
        else:
            model.load_state_dict(torch.load(args.eval_checkpoint, map_location=torch.device('cpu'))['state_dict'])
    # --------------------------------------------------------------------------
    # Show Acc
    # --------------------------------------------------------------------------
    print("load model finished!")
    # train_loss_fn = nn.CrossEntropyLoss()
    # for i in range(1):
    #     _, val_acc = validate(model, target_loader_train, train_loss_fn, args)
    #     print(f"Top-1 accuracy of the model is: {val_acc:.2f}%")
    # --------------------------------------------------------------------------
    # Environment setup
    # --------------------------------------------------------------------------
    if args.mpi:
        comm = mpi.setup_MPI()
        rank, nproc = comm.Get_rank(), comm.Get_size()
    else:
        comm, rank, nproc = None, 0, 1
    if True:
        if not torch.cuda.is_available():
            raise Exception('User selected cuda option, but cuda is not available on this machine')
        gpu_count = torch.cuda.device_count()
        # torch.cuda.set_device(rank % gpu_count)
        torch.cuda.set_device("cuda:{}".format(args.device))
        print('Rank %d use GPU %d of %d GPUs on %s' %
              (rank, torch.cuda.current_device(), gpu_count, socket.gethostname()))
    # --------------------------------------------------------------------------
    # Check plotting resolution
    # --------------------------------------------------------------------------
    try:
        args.xmin, args.xmax, args.xnum = [float(a) for a in args.x.split(':')]
        args.ymin, args.ymax, args.ynum = (None, None, None)
        if args.y:
            args.ymin, args.ymax, args.ynum = [float(a) for a in args.y.split(':')]
            # Compare against None explicitly: a legitimate bound of 0.0 is
            # falsy and used to trip this assertion (e.g. --y 0:1:51).
            assert args.ymin is not None and args.ymax is not None and args.ynum is not None, \
                'You specified some arguments for the y axis, but not all'
    except Exception as err:
        # Narrowed from a bare `except:` (which also swallowed
        # KeyboardInterrupt/SystemExit) and chained to keep the root cause.
        raise Exception('Improper format for x- or y-coordinates. Try something like -1:1:51') from err
    # --------------------------------------------------------------------------
    # Load models and extract parameters
    # --------------------------------------------------------------------------
    net = model
    w = net_plotter.get_weights(net)  # initial parameters
    s = copy.deepcopy(net.state_dict())  # deepcopy since state_dict are references
    if args.ngpu > 1:
        # data parallel with multiple GPUs on a single node
        net = nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
    # --------------------------------------------------------------------------
    # Setup the direction file and the surface file
    # --------------------------------------------------------------------------
    dir_file = net_plotter.name_direction_file(args)  # name the direction file
    # Place the direction file next to the evaluation checkpoint.
    dir_file = os.path.join(os.path.split(args.eval_checkpoint)[0], dir_file)
    if rank == 0:
        net_plotter.setup_direction(args, dir_file, net)
    surf_file = name_surface_file(args, dir_file)
    if rank == 0:
        setup_surface_file(args, surf_file, dir_file)
    # wait until master has setup the direction file and surface file
    mpi.barrier(comm)
    # load directions
    d = net_plotter.load_directions(dir_file)
    # calculate the consine similarity of the two directions
    if len(d) == 2 and rank == 0:
        similarity = proj.cal_angle(proj.nplist_to_tensor(d[0]), proj.nplist_to_tensor(d[1]))
        print('cosine similarity between x-axis and y-axis: %f' % similarity)
    mpi.barrier(comm)
    # --------------------------------------------------------------------------
    # Start the computation
    # --------------------------------------------------------------------------
    trainloader = target_loader_train
    crunch(surf_file, net, w, s, d, trainloader, 'train_loss', 'train_acc', comm, rank, args)
    # --------------------------------------------------------------------------
    # Plot figures
    # --------------------------------------------------------------------------
    if args.plot and rank == 0:
        if args.y and args.proj_file:
            plot_2D.plot_contour_trajectory(surf_file, dir_file, args.proj_file, 'train_loss', args.show)
        elif args.y:
            plot_2D.plot_2d_contour(surf_file, 'train_loss', args.vmin, args.vmax, args.vlevel, args.show)
        else:
            plot_1D.plot_1d_loss_err(surf_file, args.xmin, args.xmax, args.loss_max, args.log, args.show)
    return
# Script entry point: compute (and optionally plot) the loss landscape.
if __name__ == '__main__':
    main()
================================================
FILE: examples/Snn_safety/DPSNN/Readme.txt
================================================
The code for the differentially private spiking neural network (DPSNN).
================================================
FILE: examples/Snn_safety/DPSNN/load_data.py
================================================
import numpy as np
from torchvision import datasets, transforms
import torch
from torch.utils.data import Dataset
import tonic
from tonic import DiskCachedDataset
import torch.nn.functional as F
import os
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD_DEV = (0.2023, 0.1994, 0.2010)
cifar100_mean = [0.5071, 0.4865, 0.4409]
cifar100_std = [0.2673, 0.2563, 0.2761]
DVSCIFAR10_MEAN_16 = [0.3290, 0.4507]
DVSCIFAR10_STD_16 = [1.8398, 1.6549]
DATA_DIR = '/data/datasets'
class CustomDataset(Dataset):
    """A Dataset view exposing only a chosen subset of another dataset.

    Wraps ``dataset`` and re-indexes it through ``indices``, so item ``i``
    of this view is ``dataset[indices[i]]``. Indices are coerced to ``int``
    up front (they may arrive as floats or numpy scalars — TODO confirm
    against callers).
    """

    def __init__(self, dataset, indices):
        self.dataset = dataset
        # Coerce once so __getitem__ can index the wrapped dataset directly.
        self.indices = list(map(int, indices))

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, item):
        sample, label = self.dataset[self.indices[item]]
        return sample, label
def load_static_data(data_root, batch_size, dataset):
    """Load a static (frame-based) image dataset with its data loaders.

    :param data_root: directory where torchvision downloads/looks for the data
    :param batch_size: batch size for both loaders
    :param dataset: one of ``'cifar10'``, ``'MNIST'``, ``'FashionMNIST'``
    :return: (train_data, test_data, train_loader, test_loader)
    :raises ValueError: for an unsupported dataset name
    """
    # Map the dataset name onto its class and normalization statistics; the
    # train and test transforms were identical in all branches, so one
    # transform object is shared.
    if dataset == 'cifar10':
        dataset_cls = datasets.CIFAR10
        mean, std = CIFAR10_MEAN, CIFAR10_STD_DEV
    elif dataset == 'MNIST':
        dataset_cls = datasets.MNIST
        mean, std = MNIST_MEAN, MNIST_STD
    elif dataset == 'FashionMNIST':
        dataset_cls = datasets.FashionMNIST
        mean, std = MNIST_MEAN, MNIST_STD
    else:
        # BUG FIX: previously an unknown name fell through and returned
        # unbound locals (UnboundLocalError); fail loudly instead.
        raise ValueError(f'Unsupported dataset: {dataset!r}')
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std)])
    train_data = dataset_cls(data_root, train=True, transform=transform, download=True)
    test_data = dataset_cls(data_root, train=False, transform=transform, download=True)
    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=True
    )
    test_loader = torch.utils.data.DataLoader(
        test_data,
        batch_size=batch_size,
    )
    return train_data, test_data, train_loader, test_loader
def load_dvs10_data(batch_size, step, **kwargs):
    """Build train/test loaders for DVS-CIFAR10 with on-disk frame caching.

    CIFAR10-DVS ships no official train/test split, so samples are split per
    class: the first ``portion`` of each class range goes to train, the rest
    to test (samples are stored class-by-class on disk).

    :param batch_size: loader batch size
    :param step: number of time bins each event stream is sliced into
    :param kwargs: optional ``size`` (spatial resolution, default 48) and
        ``portion`` (train fraction per class, default 0.9)
    :return: (train_loader, test_loader, train_dataset, test_dataset)
    """
    size = kwargs['size'] if 'size' in kwargs else 48
    sensor_size = tonic.datasets.CIFAR10DVS.sensor_size
    # Event-level transform: bin the raw event stream into `step` frames.
    # (Denoise / DropEvent augmentations were disabled in the original code.)
    frame_transform = transforms.Compose([
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])
    train_dataset = tonic.datasets.CIFAR10DVS(os.path.join(DATA_DIR, 'DVS/DVS_Cifar10'), transform=frame_transform)
    test_dataset = tonic.datasets.CIFAR10DVS(os.path.join(DATA_DIR, 'DVS/DVS_Cifar10'), transform=frame_transform)
    # Frame-level transform applied after the disk cache: float tensor, resized.
    resize_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
    ])
    train_dataset = DiskCachedDataset(train_dataset,
                                      cache_path=f'./dataset/dvs_cifar10/train_cache_{step}',
                                      transform=resize_transform)
    # BUG FIX: the test cache previously wrapped `train_dataset` again; wrap
    # the dedicated `test_dataset` instance (same underlying data, but the
    # intent is explicit and the two caches stay independent).
    test_dataset = DiskCachedDataset(test_dataset,
                                     cache_path=f'./dataset/dvs_cifar10/test_cache_{step}',
                                     transform=resize_transform)
    num_train = len(train_dataset)
    num_per_cls = num_train // 10
    indices_train, indices_test = [], []
    portion = kwargs['portion'] if 'portion' in kwargs else .9
    # Split inside each contiguous per-class index range.
    for i in range(10):
        indices_train.extend(
            list(range(i * num_per_cls, round(i * num_per_cls + num_per_cls * portion))))
        indices_test.extend(
            list(range(round(i * num_per_cls + num_per_cls * portion), (i + 1) * num_per_cls)))
    train_dataset = CustomDataset(train_dataset, np.array(indices_train))
    test_dataset = CustomDataset(test_dataset, np.array(indices_test))
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        pin_memory=True, drop_last=False, num_workers=1
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size,
        pin_memory=True, drop_last=False, num_workers=1
    )
    return train_loader, test_loader, train_dataset, test_dataset
def load_nmnist_data(batch_size, step, **kwargs):
    """Build train/test loaders for N-MNIST, cached as binned frames on disk.

    :param batch_size: loader batch size
    :param step: number of time bins each event stream is sliced into
    :param kwargs: optional ``size`` (spatial resolution, default 28)
    :return: (train_loader, test_loader, train_dataset, test_dataset)
    """
    size = kwargs.get('size', 28)
    sensor_size = tonic.datasets.NMNIST.sensor_size
    # Event-level transform: bin raw events into `step` frames per sample.
    frame_transform = transforms.Compose([
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])
    data_root = os.path.join(DATA_DIR, 'DVS/NMNIST')
    train_dataset = tonic.datasets.NMNIST(data_root, transform=frame_transform, train=True)
    test_dataset = tonic.datasets.NMNIST(data_root, transform=frame_transform, train=False)
    # Post-cache transform: convert frames to float tensors and resize.
    post_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
    ])
    train_dataset = DiskCachedDataset(train_dataset,
                                      cache_path=f'./dataset/NMNIST/train_cache_{step}',
                                      transform=post_transform)
    test_dataset = DiskCachedDataset(test_dataset,
                                     cache_path=f'./dataset/NMNIST/test_cache_{step}',
                                     transform=post_transform)
    loader_opts = dict(pin_memory=True, drop_last=False, num_workers=1)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, **loader_opts)
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size, **loader_opts)
    return train_loader, test_loader, train_dataset, test_dataset
================================================
FILE: examples/Snn_safety/DPSNN/main_dpsnn.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from opacus import PrivacyEngine
from model import *
from braincog.base.node.node import *
import warnings
from load_data import *
from opacus.utils.batch_memory_manager import BatchMemoryManager
warnings.simplefilter("ignore")
# Experiment configuration for differentially private SNN training.
torch.cuda.manual_seed(3154)  # fix the CUDA RNG for reproducibility
batch_size = 512  # logical batch size seen by the privacy accountant
MAX_PHYSICAL_BATCH_SIZE = 32  # physical batch cap enforced by BatchMemoryManager
target_ep = 8  # target privacy budget epsilon
c = 6  # per-sample gradient clipping norm (max_grad_norm)
epochs = 40
step = 10  # number of SNN simulation time steps
delta = 1e-5  # target delta of (epsilon, delta)-DP
devices = 4  # CUDA device index to run on
r = 5  # number of independent repetitions of the experiment
device = torch.device(f'cuda:{devices}' if torch.cuda.is_available() else 'cpu')
# device = 'cpu'
disable_noise = False  # True: train with clipping but without DP noise
data_root = "./dataset"
kwargs = {"num_workers": 1, "pin_memory": True}
dataset = 'dvs_cifar10'
# NMNIST, cifar10, dvs_cifar10, MNIST, FashionMNIST
def train(model, device, train_loader, optimizer, epoch, privacy_engine):
    """Run one training epoch; print loss/accuracy (and epsilon under DP).

    :param model: network to train
    :param device: device to move batches to
    :param train_loader: iterable of (data, target) batches
    :param optimizer: optimizer stepping the model parameters
    :param epoch: current epoch number (for logging only)
    :param privacy_engine: Opacus engine, queried for epsilon when DP is on
    :return: training accuracy of this epoch in percent
    """
    criterion = nn.CrossEntropyLoss().to(device)
    model.train()
    batch_losses = []
    n_correct = 0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())
        # Index of the max logit is the predicted class.
        prediction = output.argmax(dim=1, keepdim=True)
        n_correct += prediction.eq(target.view_as(prediction)).sum().item()
    n_total = len(train_loader.dataset)
    accuracy = 100.0 * n_correct / n_total
    if disable_noise:
        print(f"Train Epoch: {epoch} \t Loss: {np.mean(batch_losses):.6f}")
    else:
        # Query how much privacy budget has been spent so far.
        epsilon = privacy_engine.get_epsilon(delta=delta)
        print(
            f"Train Epoch: {epoch} \t"
            f"Loss: {np.mean(batch_losses):.6f} "
        )
        print("Accuracy: {}/{} ({:.2f}%)\n".format(
            n_correct,
            n_total,
            accuracy, ))
        print(
            f"(ε = {epsilon:.2f}, δ = {delta})"
        )
    return accuracy
def test(model, device, test_loader):
model.eval()
criterion = nn.CrossEntropyLoss().to(device)
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += criterion(output, target).item() # sum up batch loss
pred = output.argmax(
dim=1, keepdim=True
) # get the index of the max log-probability
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader)
print(
"\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n".format(
test_loss,
correct,
len(test_loader.dataset),
100.0 * correct / len(test_loader.dataset),
)
)
return correct / len(test_loader.dataset)
def run():
    """Train an SNN ``r`` times (with or without DP noise) and save results.

    For each repetition a fresh model/optimizer/scheduler is built for the
    configured global ``dataset``, wrapped by Opacus' PrivacyEngine, trained
    for ``epochs`` epochs and evaluated after every epoch.  Per-epoch
    train/test accuracies of all repetitions are saved to
    ``./{dataset}/MP_train.npy`` and ``./{dataset}/MP_test.npy``.
    """
    import os  # local import: only needed here for the output directory

    if dataset == 'dvs_cifar10':
        train_loader, test_loader, train_data, test_data = load_dvs10_data(batch_size=batch_size, step=step)
    elif dataset == 'NMNIST':
        train_loader, test_loader, train_data, test_data = load_nmnist_data(batch_size=batch_size, step=step)
    else:
        train_data, test_data, train_loader, test_loader = load_static_data(data_root, batch_size, dataset)
    result = []
    result_train = []
    for _ in range(r):
        # ---- build a fresh model / optimizer / scheduler per repetition ----
        if dataset == 'cifar10':
            model = cifar_convnet(
                step=step,
                encode_type='direct',
                node_type=LIFNode,
                num_classes=10,
                spike_output=False,
                layer_by_layer=True,
                act_fun=QGateGrad
            )
            model.to(device)
            optimizer = optim.AdamW(model.parameters(), lr=0.001)
            scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[40], gamma=0.1, last_epoch=-1)
        elif dataset == 'dvs_cifar10':
            model = dvs_convnet(
                step=step,
                encode_type='direct',
                node_type=LIFNode,
                num_classes=10,
                spike_output=False,
                layer_by_layer=True,
                act_fun=QGateGrad
            )
            model.to(device)
            optimizer = optim.AdamW(model.parameters(), lr=0.001)
            scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[15], gamma=1, last_epoch=-1)
        elif dataset == 'NMNIST':
            model = SimpleSNN(
                channel=2,
                step=step,
                node_type=LIFNode,
                act_fun=QGateGrad,
                layer_by_layer=True,
            )
            model.to(device)
            optimizer = optim.AdamW(model.parameters(), lr=0.005)
            scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10], gamma=0.1, last_epoch=-1)
        elif dataset == 'MNIST' or dataset == 'FashionMNIST':
            model = SimpleSNN(
                channel=1,
                step=step,
                node_type=LIFNode,
                act_fun=QGateGrad,
                layer_by_layer=True,
            )
            model.to(device)
            optimizer = optim.AdamW(model.parameters(), lr=0.005)
            scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10], gamma=0.1, last_epoch=-1)
        else:
            # BUG FIX: an unknown dataset previously fell through and crashed
            # later with UnboundLocalError on `model`.
            raise ValueError(f'Unsupported dataset: {dataset!r}')
        privacy_engine = PrivacyEngine()
        if not disable_noise:
            # Calibrate the noise multiplier so the full run spends at most
            # (target_ep, delta) of privacy budget.
            model, optimizer, data_loader = privacy_engine.make_private_with_epsilon(
                module=model,
                optimizer=optimizer,
                data_loader=train_loader,
                max_grad_norm=c,
                epochs=epochs,
                target_delta=delta,
                target_epsilon=target_ep
            )
        else:
            # Ablation: keep per-sample clipping but add no DP noise.
            model, optimizer, data_loader = privacy_engine.make_private(
                module=model,
                optimizer=optimizer,
                data_loader=train_loader,
                max_grad_norm=c,
                noise_multiplier=0.0,
            )
        # BatchMemoryManager splits the logical batch into physical chunks of
        # at most MAX_PHYSICAL_BATCH_SIZE to bound GPU memory.
        with BatchMemoryManager(
                data_loader=data_loader,
                max_physical_batch_size=MAX_PHYSICAL_BATCH_SIZE,
                optimizer=optimizer
        ) as memory_safe_data_loader:
            for epoch in range(1, epochs + 1):
                # BUG FIX: the noise-free branch previously discarded the train
                # accuracy, leaving `result_train` empty and crashing the final
                # reshape; both branches now record it.
                result_train.append(train(model, device, memory_safe_data_loader, optimizer, epoch, privacy_engine))
                result.append(test(model, device, test_loader))
                scheduler.step()
    result = np.array(result).reshape((r, -1))
    result_train = np.array(result_train).reshape((r, -1))
    best_acc = np.mean(np.max(result, axis=1))
    print(best_acc)
    # BUG FIX: ensure the output directory exists before saving.
    os.makedirs(f'./{dataset}', exist_ok=True)
    np.save(file=f'./{dataset}/MP_test.npy', arr=result)
    np.save(file=f'./{dataset}/MP_train.npy', arr=result_train)
if __name__ == "__main__":
    run()
================================================
FILE: examples/Snn_safety/DPSNN/model.py
================================================
import abc
from functools import partial
import torch
from torch.nn import functional as F
import torchvision
from timm.models import register_model
from braincog.base.node.node import *
from braincog.base.connection.layer import *
from braincog.base.encoder.encoder import *
from braincog.model_zoo.base_module import BaseModule, BaseConvModule, BaseLinearModule
class TEP(nn.Module):
    """Temporal firing-rate modulation layer.

    Averages the input over the time dimension, normalizes the resulting
    per-neuron firing rate with GroupNorm (one group per channel, i.e.
    per-channel affine normalization), and rescales every time step by
    ``1 + normalized_rate``.

    :param step: number of simulation time steps folded into the batch dim
    :param channel: number of feature channels
    :param device: kept for interface compatibility (unused)
    :param dtype: kept for interface compatibility (unused)
    """
    def __init__(self, step, channel, device=None, dtype=None):
        # Removed an unused `factory_kwargs` local; device/dtype are kept in
        # the signature for backward compatibility.
        super(TEP, self).__init__()
        self.step = step
        self.gn = nn.GroupNorm(channel, channel)

    def forward(self, x):
        # (t*b, c, w, h) -> (t, b, c, w, h): unfold time from the batch dim.
        x = rearrange(x, '(t b) c w h -> t b c w h', t=self.step)
        fire_rate = torch.mean(x, dim=0)  # mean over time = per-neuron rate
        fire_rate = self.gn(fire_rate) + 1
        x = x * fire_rate  # modulate each step by the normalized rate
        x = rearrange(x, 't b c w h -> (t b) c w h')
        return x
class BaseConvNet(BaseModule, abc.ABC):
    """Base class for layer-wise convolutional SNN backbones.

    Subclasses must override ``_create_feature`` and ``_create_fc``.

    :param step: number of simulation time steps
    :param input_channels: number of input channels
    :param num_classes: number of output classes
    :param encode_type: input encoding type
    :param spike_output: if True, the head emits 10 spiking neurons per class
        aggregated by a ``VotingLayer``; otherwise a plain projection is used
    :param out_channels: output channels per conv stage (the classifier width
        is appended in-place)
    :param block_depth: number of conv modules per stage
    :param node_list: neuron class per stage (extended in-place like
        ``out_channels``)
    """
    def __init__(self,
                 step,
                 input_channels,
                 num_classes,
                 encode_type,
                 spike_output: bool,
                 out_channels: list,
                 block_depth: list,
                 node_list: list,
                 *args,
                 **kwargs):
        super().__init__(step, encode_type, *args, **kwargs)
        self.num_cls = num_classes
        self.spike_output = spike_output
        self.groups = kwargs['n_groups'] if 'n_groups' in kwargs else 1
        if not spike_output:
            # Non-spiking head: plain linear projection onto the classes.
            node_list.append(nn.Identity)
            out_channels.append(self.num_cls)
            self.vote = nn.Identity()
        else:
            # Spiking head: 10 output neurons per class, merged by voting.
            out_channels.append(10 * self.num_cls)
            self.vote = VotingLayer(10)
        # One neuron spec is required per stage (including the head).
        if len(node_list) != len(out_channels):
            raise ValueError
        self.input_channels = input_channels
        self.out_channels = out_channels
        self.block_depth = block_depth
        self.node_list = node_list
        self.feature = self._create_feature()
        self.fc = self._create_fc()
        if self.layer_by_layer:
            self.flatten = nn.Flatten(start_dim=1)
        else:
            self.flatten = nn.Flatten()

    # BUG FIX: these hooks were decorated with @staticmethod while taking
    # `self`, which made `self._create_feature()` uncallable on the base
    # class; plain methods express the intended template-method pattern and
    # leave subclass overrides untouched.
    def _create_feature(self):
        """Build the convolutional feature extractor (subclass hook)."""
        raise NotImplementedError

    def _create_fc(self):
        """Build the classifier head (subclass hook)."""
        raise NotImplementedError

    def forward(self, inputs):
        inputs = self.encoder(inputs)
        self.reset()
        if not self.training:
            self.fire_rate.clear()
        if not self.layer_by_layer:
            # Step-by-step simulation: run the net once per time step and
            # average the per-step predictions.
            outputs = []
            step = 1 if self.warm_up else self.step  # warm-up uses one step
            for t in range(step):
                x = inputs[t]
                x = self.feature(x)
                x = self.flatten(x)
                x = self.fc(x)
                x = self.vote(x)
                outputs.append(x)
            return sum(outputs) / len(outputs)
        else:
            # Layer-by-layer mode: time is folded into the batch dimension.
            x = self.feature(inputs)
            x = self.flatten(x)
            x = self.fc(x)
            if self.groups == 1:
                x = rearrange(x, '(t b) c -> t b c', t=self.step).mean(0)
            else:
                x = rearrange(x, 'b (c t) -> t b c', t=self.step).mean(0)
            x = self.vote(x)
            return x
class LayerWiseConvModule(nn.Module):
    """
    SNN convolution module applied step-by-step in layer-by-layer mode.

    :param in_channels: number of input channels
    :param out_channels: number of output channels
    :param kernel_size: kernel size
    :param stride: stride
    :param padding: padding
    :param bias: whether the convolution uses a bias term
    :param node: spiking neuron type
    :param step: number of simulation time steps
    :param kwargs: extra options; ``groups`` multiplies both channel counts
    """
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size=(3, 3),
                 stride=(1, 1),
                 padding=(1, 1),
                 bias=False,
                 node=LIFNode,
                 step=6,
                 **kwargs):
        super().__init__()
        if node is None:
            raise TypeError
        self.groups = kwargs.get('groups', 1)
        self.conv = nn.Conv2d(in_channels=in_channels * self.groups,
                              out_channels=out_channels * self.groups,
                              kernel_size=kernel_size,
                              padding=padding,
                              stride=stride,
                              bias=bias)
        self.gn = nn.GroupNorm(16, out_channels * self.groups)
        self.node = partial(node, **kwargs)()
        self.step = step
        self.activation = nn.Identity()

    def forward(self, x):
        # Unfold time from the batch dimension and run conv + GN per step.
        x = rearrange(x, '(t b) c w h -> t b c w h', t=self.step)
        per_step = [self.gn(self.conv(x[t])) for t in range(self.step)]
        stacked = torch.stack(per_step)  # t b c w h
        flat = rearrange(stacked, 't b c w h -> (t b) c w h')
        return self.node(flat)
class LayerWiseLinearModule(nn.Module):
    """
    Linear module applied step-by-step in layer-by-layer mode.

    :param in_features: input feature size
    :param out_features: output feature size
    :param bias: whether to use a bias term, default ``True``
    :param node: spiking neuron type, default ``LIFNode``
    :param step: number of simulation time steps
    :param spike: if True, pass the output through the neuron node
    :param args:
    :param kwargs: ``groups`` > 1 creates one independent Linear per group
    """
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias=True,
                 node=LIFNode,
                 step=6,
                 spike=False,
                 *args,
                 **kwargs):
        super().__init__()
        if node is None:
            raise TypeError
        self.groups = kwargs.get('groups', 1)
        if self.groups == 1:
            self.fc = nn.Linear(in_features=in_features,
                                out_features=out_features, bias=bias)
        else:
            # One independent projection per group.
            self.fc = nn.ModuleList(
                nn.Linear(in_features=in_features,
                          out_features=out_features,
                          bias=bias)
                for _ in range(self.groups)
            )
        self.node = partial(node, **kwargs)()
        self.step = step
        self.spike = spike

    def forward(self, x):
        if self.groups == 1:  # input layout: (t b) c
            x = rearrange(x, '(t b) c -> t b c', t=self.step)
            stacked = torch.stack([self.fc(x[t]) for t in range(self.step)])  # t b c
            outputs = rearrange(stacked, 't b c -> (t b) c')
        else:  # input layout: b (c t)
            x = rearrange(x, 'b (c t) -> t b c', t=self.groups)
            stacked = torch.stack([fc(frame) for fc, frame in zip(self.fc, x)])  # t b c
            outputs = rearrange(stacked, 't b c -> b (c t)')
        return self.node(outputs) if self.spike else outputs
class LayWiseConvNet(BaseConvNet):
    """Layer-by-layer convolutional SNN assembled from LayerWiseConvModule
    stages with TEP normalization and average pooling between stages.
    """
    def __init__(self,
                 step,
                 input_channels,
                 num_classes,
                 encode_type,
                 spike_output: bool,
                 out_channels: list,
                 node_list: list,
                 block_depth: list,
                 *args,
                 **kwargs):
        super().__init__(step,
                         input_channels,
                         num_classes,
                         encode_type,
                         spike_output,
                         out_channels,
                         block_depth,
                         node_list,
                         *args,
                         **kwargs)

    def _create_feature(self):
        """Build the conv trunk: stage 0, middle stages with TEP+AvgPool,
        final stage with adaptive pooling to 1x1.

        Note: the last entries of out_channels/node_list belong to the fc
        head, hence the -2/-3 index arithmetic below.
        """
        feature_depth = len(self.node_list) - 1
        # Stage 0 maps the (possibly encoder-multiplied) input channels.
        # NOTE(review): `self.init_channel_mul` comes from BaseModule — presumably
        # the channel multiplier of the encoder; confirm in braincog.
        feature = [LayerWiseConvModule(
            self.input_channels * self.init_channel_mul, self.out_channels[0], node=self.node_list[0],
            groups=self.groups, step=self.step)]
        if self.block_depth[0] != 1:
            # NOTE(review): list multiplication repeats the SAME module object,
            # so these extra depth layers share one set of parameters — confirm
            # whether weight sharing across depth is intended.
            feature.extend(
                [LayerWiseConvModule(self.out_channels[0], self.out_channels[0], node=self.node_list[0],
                                     groups=self.groups, step=self.step)] * (
                    self.block_depth[0] - 1),
            )
        feature.append(TEP(channel=self.out_channels[0], step=self.step))
        feature.append(nn.AvgPool2d(kernel_size=4, stride=2))
        # Middle stages: conv (+ repeated depth), TEP, then spatial pooling.
        for i in range(1, feature_depth - 1):
            feature.append(LayerWiseConvModule(
                self.out_channels[i - 1], self.out_channels[i], node=self.node_list[i], groups=self.groups,
                step=self.step))
            if self.block_depth[i] != 1:
                # NOTE(review): same shared-module caveat as above.
                feature.extend(
                    [LayerWiseConvModule(self.out_channels[i], self.out_channels[i], node=self.node_list[i],
                                         groups=self.groups,
                                         step=self.step)] * (
                        self.block_depth[i] - 1),
                )
            feature.append(TEP(channel=self.out_channels[i], step=self.step))
            feature.append(nn.AvgPool2d(kernel_size=4, stride=2))
        # Final conv stage (out_channels[-1] is the fc width, so use -3 -> -2).
        feature.append(LayerWiseConvModule(
            self.out_channels[-3], self.out_channels[-2], node=self.node_list[-2], groups=self.groups,
            step=self.step))
        if self.block_depth[feature_depth - 1] != 1:
            feature.extend(
                [LayerWiseConvModule(self.out_channels[-2], self.out_channels[-2], node=self.node_list[-2],
                                     groups=self.groups,
                                     step=self.step)] * (
                    self.block_depth[feature_depth - 1] - 1),
            )
        # Collapse the spatial dimensions before the classifier.
        feature.append(nn.AdaptiveAvgPool2d(1))
        return nn.Sequential(*feature)

    def _create_fc(self):
        """Build the classifier head: one layer-wise linear projection."""
        fc = nn.Sequential(
            # NDropout(.5),
            LayerWiseLinearModule(
                self.out_channels[-2], self.out_channels[-1], node=self.node_list[-1], groups=self.groups,
                step=self.step, spike=False)
        )
        return fc
@register_model
def cifar_convnet(step,
                  encode_type,
                  spike_output: bool,
                  node_type,
                  *args,
                  **kwargs):
    """Build a layer-wise convolutional SNN for 3-channel (CIFAR) input.

    :param step: number of simulation time steps
    :param encode_type: input encoding passed to the backbone
    :param spike_output: whether the classifier head emits spikes (voting)
    :param node_type: spiking neuron class used throughout the network
    """
    out_channels = [64, 128, 128, 256]
    block_depth = [2, 2, 2, 2]
    node_cls = partial(node_type, step=step, **kwargs)
    # One node spec per conv stage; a spiking head needs one extra.
    n_nodes = len(out_channels) + 1 if spike_output else len(out_channels)
    node_list = [node_cls] * n_nodes
    return LayWiseConvNet(step=step,
                          input_channels=3,
                          encode_type=encode_type,
                          node_list=node_list,
                          block_depth=block_depth,
                          out_channels=out_channels,
                          spike_output=spike_output,
                          **kwargs)
@register_model
def dvs_convnet(step,
                encode_type,
                spike_output: bool,
                node_type,
                num_classes,
                *args,
                **kwargs):
    """Build a layer-wise convolutional SNN for 2-channel (DVS polarity) input.

    :param step: number of simulation time steps
    :param encode_type: input encoding passed to the backbone
    :param spike_output: whether the classifier head emits spikes (voting)
    :param node_type: spiking neuron class used throughout the network
    :param num_classes: number of output classes
    """
    out_channels = [64, 128, 256, 512, 1024]
    block_depth = [2, 1, 2, 1, 2]
    node_cls = partial(node_type, step=step, **kwargs)
    # One node spec per conv stage; a spiking head needs one extra.
    n_nodes = len(out_channels) + 1 if spike_output else len(out_channels)
    node_list = [node_cls] * n_nodes
    return LayWiseConvNet(step=step,
                          input_channels=2,
                          num_classes=num_classes,
                          encode_type=encode_type,
                          node_list=node_list,
                          block_depth=block_depth,
                          out_channels=out_channels,
                          spike_output=spike_output,
                          **kwargs)
@register_model
class SimpleSNN(BaseModule, abc.ABC):
    """Small two-conv SNN for 28x28 inputs (MNIST / FashionMNIST / N-MNIST).

    :param channel: number of input channels (1 grayscale, 2 DVS polarity)
    :param num_classes: number of output classes
    :param step: number of simulation time steps
    :param node_type: spiking neuron class
    :param encode_type: input encoding type
    """
    def __init__(self,
                 channel=1,
                 num_classes=10,
                 step=8,
                 node_type=LIFNode,
                 encode_type='direct',
                 *args,
                 **kwargs):
        super().__init__(step, encode_type, *args, **kwargs)
        self.num_classes = num_classes
        self.node = node_type
        self.feature = nn.Sequential(
            LayerWiseConvModule(channel, 32, kernel_size=7, padding=0, node=self.node, step=step),
            TEP(step=step, channel=32),
            nn.AvgPool2d(kernel_size=2, stride=2),
            LayerWiseConvModule(32, 64, kernel_size=4, padding=0, node=self.node, step=step),
            TEP(step=step, channel=64),
            nn.AvgPool2d(kernel_size=2, stride=2),
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            LayerWiseLinearModule(64 * 4 * 4, self.num_classes, node=self.node, spike=False, step=step),
        )

    def forward(self, inputs):
        inputs = self.encoder(inputs)
        self.reset()
        if self.layer_by_layer:
            # Whole sequence at once: time is folded into the batch dim,
            # predictions are averaged over the time axis.
            out = self.fc(self.feature(inputs))
            return rearrange(out, '(t b) c -> t b c', t=self.step).mean(0)
        # Step-by-step simulation: average the per-step predictions.
        step_outputs = [self.fc(self.feature(inputs[t])) for t in range(self.step)]
        return sum(step_outputs) / len(step_outputs)
================================================
FILE: examples/Snn_safety/RandHet-SNN/README.md
================================================
* To train a SNN with AT on CIFAR-10:
```
python train.py --adv_training --attack_iters 1 --epsilon 4 --alpha 4 --network ResNet18 --batch_size 64 --worker 4 --node_type LIF --save_dir AT --device cuda:1 --time_step 8 --dataset cifar10
```
* To train a RHSNN with AT on CIFAR-10:
```
python train.py --adv_training --attack_iters 1 --epsilon 4 --alpha 4 --network ResNet18 --batch_size 64 --worker 4 --node_type RHLIF --save_dir AT_RH_1 --device cuda:1 --time_step 8 --dataset cifar10
```
* To train a RHSNN with RAT on CIFAR-10:
```
python train.py --adv_training --attack_iters 1 --epsilon 4 --alpha 4 --network ResNet18 --batch_size 64 --worker 4 --node_type RHLIF --save_dir RAT_RH_1 --device cuda:1 --parseval --beta 0.004 --time_step 8 --dataset cifar10
```
* To train a RHSNN with SR on CIFAR-10:
```
python train.py --adv_training --attack_iters 1 --epsilon 4 --alpha 4 --network ResNet18 --batch_size 64 --worker 4 --node_type RHLIF --save_dir SR_RH_1 --device cuda:1 --SR --time_step 8 --dataset cifar10
```
* To evaluate the performance of RHSNN on CIFAR-10:
```
python evaluate.py --network ResNet18 --attack_type all --batch_size 32 --worker 4 --node_type RHLIF --pretrain RAT_RH_1/weight_r.pth --save_dir RAT --device cuda:1 --time_step 8 --dataset cifar10
```
================================================
FILE: examples/Snn_safety/RandHet-SNN/evaluate.py
================================================
import argparse
import copy
import logging
import os
import sys
import time
from my_node import RHLIFNode, RHLIFNode2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sew_resnet import SEWResNet19, BasicBlock
from braincog.base.node.node import *
from utils import evaluate_standard
from utils import get_loaders
import torchattacks
from tqdm import tqdm
logger = logging.getLogger(__name__)
def get_args():
    """Parse command-line options for adversarial-robustness evaluation."""
    parser = argparse.ArgumentParser()
    add = parser.add_argument
    add('--batch_size', default=32, type=int)
    add('--data_dir', default='/mnt/data/datasets', type=str)
    add('--dataset', default='cifar10', choices=['cifar10', 'cifar100'])
    add('--network', default='ResNet18', type=str)
    add('--worker', default=4, type=int)
    add('--epsilon', default=8, type=int)
    add('--device', default='cuda:1', type=str)
    add('--pretrain', default=None, type=str, help='path to load the pretrained model')
    add('--save_dir', default=None, type=str, help='path to save log')
    add('--attack_type', default='pgd')
    add('--time_step', default=8, type=int)
    add('--node_type', default='LIF', type=str)
    return parser.parse_args()
def evaluate_attack(model, test_loader, args, atk, atk_name, logger):
    """Attack every test batch with ``atk`` and measure adversarial accuracy.

    :param model: network under attack (switched to eval mode)
    :param test_loader: loader of clean (X, y) batches
    :param args: parsed arguments; only ``args.device`` is used
    :param atk: callable crafting adversarial examples from (X, y)
    :param atk_name: name logged alongside the result
    :param logger: logger for results
    :return: (adversarial loss, adversarial accuracy)
    """
    model.eval()
    device = args.device
    total_loss = 0
    total_correct = 0
    seen = 0
    batches = iter(test_loader)
    bar_format = '{desc}[{elapsed}<{remaining},{rate_fmt}]'
    progress = tqdm(range(len(batches)), file=sys.stdout, bar_format=bar_format, ncols=80)
    for _ in progress:
        X, y = next(batches)
        X, y = X.to(device), y.to(device)
        X_adv = atk(X, y)  # craft adversarial examples (advtorch-style call)
        with torch.no_grad():
            output = model(X_adv)
            loss = F.cross_entropy(output, y)
            total_loss += loss.item() * y.size(0)
            total_correct += (output.max(1)[1] == y).sum().item()
            seen += y.size(0)
    pgd_acc = total_correct / seen
    pgd_loss = total_loss / seen
    logger.info(atk_name)
    logger.info('adv: %.4f \t', pgd_acc)
    return pgd_loss, pgd_acc
def main():
    """Entry point: set up logging, build and load the model, run attacks."""
    args = get_args()
    args.save_dir = os.path.join('logs', args.save_dir)
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
    # NOTE(review): 'output.log' is removed but logging writes to
    # 'output_test.log' below — confirm the stale-file cleanup target.
    logfile = os.path.join(args.save_dir, 'output.log')
    if os.path.exists(logfile):
        os.remove(logfile)
    log_path = os.path.join(args.save_dir, 'output_test.log')
    handlers = [logging.FileHandler(log_path, mode='a+'),
                logging.StreamHandler()]
    logging.basicConfig(
        format='[%(asctime)s] - %(message)s',
        datefmt='%Y/%m/%d %H:%M:%S',
        level=logging.INFO,
        handlers=handlers)
    logger.info(args)
    # assert type(args.pretrain) == str and os.path.exists(args.pretrain)
    if args.dataset == 'cifar10':
        args.num_classes = 10
    elif args.dataset == 'cifar100':
        args.num_classes = 100
    else:
        print('Wrong dataset:', args.dataset)
        exit()
    logger.info('Dataset: %s', args.dataset)
    train_loader, test_loader, dataset_normalization = get_loaders(args.data_dir, args.batch_size, dataset=args.dataset,
                                                                   worker=args.worker, norm=False)
    # Choose the neuron type; default LIF, optionally the randomized variants.
    node = LIFNode
    if args.node_type == 'RHLIF':
        node = RHLIFNode
    elif args.node_type == 'RHLIF2':
        node = RHLIFNode2
    # setup network
    model = SEWResNet19(BasicBlock, [3, 3, 2], cnf='ADD', node_type=node, step=args.time_step, num_classes=args.num_classes,
                        layer_by_layer=True, act_fun=AtanGrad, data_norm=dataset_normalization)
    # print(model)
    # load pretrained model
    path = os.path.join('./ckpt', args.dataset, args.network)
    args.pretrain = os.path.join(path, args.pretrain)
    pretrained_model = torch.load(args.pretrain, map_location=args.device, weights_only=False)
    model.load_state_dict(pretrained_model, strict=False)
    model.to(args.device)
    model.eval()
    # for name, param in model.named_parameters():
    #     if 'sigma' in name:  # find parameters whose name contains 'sigma'
    #         param.data = torch.as_tensor(0.0, device=args.device)
    #     if 'alpha' in name:  # find parameters whose name contains 'alpha'
    #         param.data = torch.as_tensor(2.0, device=args.device)
    logger.info('Evaluating with standard images...')
    _, nature_acc = evaluate_standard(test_loader, model, args)
    logger.info('Nature Acc: %.4f \t', nature_acc)
    # Dispatch on the requested attack(s); eps is fixed at 8/255 throughout.
    if args.attack_type == 'eotpgd':
        atk = torchattacks.EOTPGD(model, eps=8 / 255, alpha=(16/50) / 255, steps=50, random_start=True, eot_iter=10)
        evaluate_attack(model, test_loader, args, atk, 'eotpgd', logger)
    elif args.attack_type[0:3] == 'pgd':
        # e.g. 'pgd50' -> 50 steps; the step size alpha scales as 16/steps.
        steps = int(''.join(filter(str.isdigit, args.attack_type)))
        atk = torchattacks.PGD(model, eps=8 / 255, alpha=(16/steps) / 255, steps=steps, random_start=True)
        evaluate_attack(model, test_loader, args, atk, args.attack_type, logger)
    elif args.attack_type == 'apgd':
        atk = torchattacks.APGD(model, eps=8 / 255, steps=50, eot_iter=10)
        evaluate_attack(model, test_loader, args, atk, 'apgd', logger)
    elif args.attack_type == 'fgsm':
        atk = torchattacks.FGSM(model, eps=8/255)
        evaluate_attack(model, test_loader, args, atk, 'fgsm', logger)
    elif args.attack_type == 'mifgsm':
        atk = torchattacks.MIFGSM(model, eps=8 / 255, alpha=2 / 255, steps=5, decay=1.0)
        evaluate_attack(model, test_loader, args, atk, 'mifgsm', logger)
    elif args.attack_type == 'autoattack':
        atk = torchattacks.AutoAttack(model, norm='Linf', eps=8/255, version='standard', n_classes=args.num_classes)
        evaluate_attack(model, test_loader, args, atk, 'autoattack', logger)
    elif args.attack_type == 'all':
        # Run the full benchmark suite of attacks in sequence.
        atk = torchattacks.FGSM(model, eps=8 / 255)
        evaluate_attack(model, test_loader, args, atk, 'fgsm', logger)
        atk = torchattacks.APGD(model, eps=8 / 255, steps=10)
        evaluate_attack(model, test_loader, args, atk, 'apgd', logger)
        atk = torchattacks.PGD(model, eps=8 / 255, alpha=1.6 / 255, steps=10, random_start=True)
        evaluate_attack(model, test_loader, args, atk, 'pgd', logger)
        atk = torchattacks.MIFGSM(model, eps=8 / 255, alpha=2 / 255, steps=5, decay=1.0)
        evaluate_attack(model, test_loader, args, atk, 'mifgsm', logger)
        atk = torchattacks.AutoAttack(model, norm='Linf', eps=8 / 255, version='standard',
                                      n_classes=args.num_classes)
        evaluate_attack(model, test_loader, args, atk, 'autoattack', logger)
    elif args.attack_type == 'step_test':
        # Sweep robustness over the number of attack iterations.
        for steps in [10,30,50,70,90,110]:
            atk = torchattacks.PGD(model, eps=8 / 255, alpha=(16 / steps) / 255, steps=steps, random_start=True)
            pgd_loss, pgd_acc = evaluate_attack(model, test_loader, args, atk, f'pgd{steps}', logger)
            atk = torchattacks.APGD(model, eps=8 / 255, steps=steps)
            apgd_loss, apgd_acc = evaluate_attack(model, test_loader, args, atk, f'apgd{steps}', logger)
    elif args.attack_type == 'eot_test':
        # Sweep robustness over the number of EOT sampling iterations.
        for steps in [1,10,20,30]:
            atk = torchattacks.EOTPGD(model, eps=8 / 255, alpha=(16 / 10) / 255, steps=10, random_start=True, eot_iter=steps)
            pgd_loss, pgd_acc = evaluate_attack(model, test_loader, args, atk, f'eot{steps}_pgd', logger)
            atk = torchattacks.APGD(model, eps=8 / 255, steps=10, eot_iter=steps)
            apgd_loss, apgd_acc = evaluate_attack(model, test_loader, args, atk, f'eot{steps}_apgd', logger)
    elif args.attack_type == 'intensity_test':
        # Sweep robustness over the perturbation budget (eps = intensity/255).
        for intensity in [2, 4, 6, 8, 10, 12, 14, 16]:
            atk = torchattacks.APGD(model, eps=intensity / 255, steps=10)
            pgd_loss, pgd_acc = evaluate_attack(model, test_loader, args, atk, f'{intensity}_apgd', logger)
    logger.info('Testing done.')
if __name__ == "__main__":
    main()
================================================
FILE: examples/Snn_safety/RandHet-SNN/my_node.py
================================================
import torch
from braincog.base.node.node import *
class RHLIFNode(BaseNode):
    """
    Randomly Heterogeneous LIF neuron.

    At every integration step a fresh noise tensor ``rd = sigma * N(0, 1)``
    is sampled per element; ``sigmoid(rd)`` then acts as a random per-neuron
    decay factor, so the membrane dynamics differ across neurons and across
    time steps.  Neither ``sigma`` nor ``w`` is trained here
    (``requires_grad=False``).

    :param threshold: firing threshold of the membrane potential
    :param tau: kept for interface compatibility; stored in ``self.w`` but
        not used by the dynamics in this class
    :param sigma: standard deviation of the random decay logits
    :param act_fun: surrogate gradient function, default ``AtanGrad``
    :param args: extra positional arguments forwarded to ``BaseNode``
    :param kwargs: extra keyword arguments forwarded to ``BaseNode``
    """
    def __init__(self, threshold=0.5, tau=0., sigma=1.0, act_fun=AtanGrad, *args, **kwargs):
        super().__init__(threshold, *args, **kwargs)
        init_w = tau
        if isinstance(act_fun, str):
            act_fun = eval(act_fun)  # allow the surrogate to be given by name
        self.act_fun = act_fun(alpha=2., requires_grad=False)
        self.sigma = nn.Parameter(torch.as_tensor(sigma), requires_grad=False)
        self.w = nn.Parameter(torch.as_tensor(init_w), requires_grad=False)
        self.flag = 0  # reserved resample flag (always 0 in this class)
        self.rd = 0    # last sampled decay logits

    def integral(self, inputs):
        # Sample fresh random decay logits for every element at every step.
        # NOTE(review): assumes 4-D (B, C, H, W) input — confirm for linear layers.
        self.rd = self.sigma * torch.normal(0., 1., size=(inputs.shape[0], inputs.shape[1], inputs.shape[2], inputs.shape[3]), device=inputs.device)
        # Leaky integration with per-element random decay sigmoid(rd).
        self.mem = self.rd.sigmoid() * self.mem + (1 - self.rd.sigmoid()) * inputs

    def calc_spike(self):
        self.spike = self.act_fun(self.mem - self.threshold)
        # self.mem = self.mem - self.spike.detach() * self.threshold
        self.mem = self.mem * (1 - self.spike.detach())  # hard reset where spiked

    def n_reset(self):
        # Restore the resting state between forward passes.
        self.mem = self.v_reset
        self.spike = 0.
        self.feature_map = []
        self.mem_collect = []
        self.flag = 0
class RHLIFNode2(BaseNode):
    """
    Randomly Heterogeneous LIF neuron with cached noise (RandHet-SNN).

    Variant of ``RHLIFNode`` that samples the random decay noise ``rd`` only
    once (guarded by ``self.flag``) instead of at every call. When
    ``self.resample == 1`` (the default) the cached noise is invalidated on
    ``n_reset`` so each new sequence draws a fresh sample; otherwise the same
    noise is reused across sequences. As in ``RHLIFNode``, neither ``w`` nor
    ``sigma`` receives gradients.

    :param threshold: firing threshold of the neuron
    :param tau: initial value stored in ``self.w`` (kept for interface
        compatibility; not used by the integration itself)
    :param sigma: standard deviation of the Gaussian noise for the decay factor
    :param act_fun: surrogate-gradient function, default ``AtanGrad``
    :param args: forwarded to ``BaseNode``
    :param kwargs: forwarded to ``BaseNode``
    """

    def __init__(self, threshold=0.5, tau=0., sigma=1.0, act_fun=AtanGrad, *args, **kwargs):
        super().__init__(threshold, *args, **kwargs)
        init_w = tau
        if isinstance(act_fun, str):
            # NOTE(review): eval() on a config string — acceptable only for trusted input.
            act_fun = eval(act_fun)
        self.act_fun = act_fun(alpha=2., requires_grad=False)
        self.sigma = nn.Parameter(torch.as_tensor(sigma), requires_grad=False)
        self.w = nn.Parameter(torch.as_tensor(init_w), requires_grad=False)
        self.flag = 0      # 0 -> noise must be (re)sampled on the next integral()
        self.rd = 0
        self.resample = 1  # 1 -> draw fresh noise after every n_reset()

    def integral(self, inputs):
        """Leaky integration; the random decay factor is sampled lazily once."""
        if self.flag == 0:
            # Generalized from the original hard-coded 4-D size: works for any
            # input rank while producing identical samples for 4-D inputs.
            self.rd = self.sigma * torch.normal(0., 1., size=inputs.shape, device=inputs.device)
            self.flag = 1
        decay = self.rd.sigmoid()  # hoisted: original computed sigmoid twice
        self.mem = decay * self.mem + (1 - decay) * inputs

    def calc_spike(self):
        """Emit spikes via the surrogate-gradient function, then hard-reset."""
        self.spike = self.act_fun(self.mem - self.threshold)
        # Hard reset: the membrane is zeroed wherever a spike was emitted.
        self.mem = self.mem * (1 - self.spike.detach())

    def n_reset(self):
        """Clear per-sequence state; invalidate the noise cache iff resampling."""
        self.mem = self.v_reset
        self.spike = 0.
        self.feature_map = []
        self.mem_collect = []
        if self.resample == 1:
            self.flag = 0
        else:
            self.flag = 1
================================================
FILE: examples/Snn_safety/RandHet-SNN/sew_resnet.py
================================================
import torch
import torch.nn as nn
from copy import deepcopy
import random
try:
from torchvision.models.utils import load_state_dict_from_url
except ImportError:
from torchvision._internally_replaced_utils import load_state_dict_from_url
from braincog.base.node import *
from braincog.model_zoo.base_module import *
from braincog.datasets import is_dvs_data
from timm.models import register_model
def sew_function(x: torch.Tensor, y: torch.Tensor, cnf: str):
    """Spike-Element-Wise connect function g(x, y).

    'ADD' -> x + y, 'AND' -> x * y, 'IAND' -> x * (1 - y).
    Raises NotImplementedError for any other name.
    """
    ops = {
        'ADD': lambda a, b: a + b,
        'AND': lambda a, b: a * b,
        'IAND': lambda a, b: a * (1. - b),
    }
    if cnf not in ops:
        raise NotImplementedError
    return ops[cnf](x, y)
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution whose padding equals the dilation (keeps spatial size at stride 1)."""
    conv = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
    )
    return conv
def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    projection = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return projection
class BasicBlock(nn.Module):
    """Spiking SEW basic residual block (two 3x3 convs).

    Each conv is followed by BatchNorm and a fresh spiking node; the residual
    is merged with the branch output *after* the second node, using
    ``sew_function`` with the connect function ``cnf`` ('ADD' | 'AND' | 'IAND').

    :param node: zero-argument callable returning a spiking-neuron module;
        a fresh instance is created per activation site.
    """
    expansion = 1  # output channels = planes * expansion

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None, cnf: str = None, node: callable = None, **kwargs):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.node1 = node()
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.node2 = node()
        self.downsample = downsample
        if downsample is not None:
            # The shortcut gets its own spiking node so both merge operands are spikes.
            self.downsample_sn = node()
        self.stride = stride
        self.cnf = cnf

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.node1(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.node2(out)
        if self.downsample is not None:
            identity = self.downsample_sn(self.downsample(x))
        # NOTE(review): operand order is (identity, out) here but (out, identity)
        # in Bottleneck; order matters for the asymmetric 'IAND' connect
        # function — confirm against the reference SEW-ResNet implementation.
        out = sew_function(identity, out, self.cnf)
        return out

    def extra_repr(self) -> str:
        return super().extra_repr() + f'cnf={self.cnf}'
class Bottleneck(nn.Module):
    """Spiking SEW bottleneck residual block (1x1 -> 3x3 -> 1x1 convs).

    Each conv is followed by BatchNorm and a fresh spiking node; the residual
    is merged after the last node via ``sew_function`` with the connect
    function ``cnf`` ('ADD' | 'AND' | 'IAND').

    :param node: zero-argument callable returning a spiking-neuron module.
    """
    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
    # while original implementation places the stride at the first 1x1 convolution(self.conv1)
    # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
    expansion = 4  # output channels = planes * expansion

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None, cnf: str = None, node: callable = None, **kwargs):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.node1 = node()
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.node2 = node()
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.node3 = node()
        self.downsample = downsample
        if downsample is not None:
            # The shortcut gets its own spiking node before the SEW merge.
            self.downsample_sn = node()
        self.stride = stride
        self.cnf = cnf

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.node1(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.node2(out)
        out = self.conv3(out)
        out = self.bn3(out)
        out = self.node3(out)
        if self.downsample is not None:
            identity = self.downsample_sn(self.downsample(x))
        # NOTE(review): operand order (out, identity) differs from BasicBlock's
        # (identity, out); order matters for the asymmetric 'IAND' function.
        out = sew_function(out, identity, self.cnf)
        return out

    def extra_repr(self) -> str:
        return super().extra_repr() + f'cnf={self.cnf}'
class SEWResNet19(BaseModule):
    """
    SEW-ResNet-19 backbone for CIFAR-sized inputs (3x32x32): a 3x3 stem
    followed by three residual stages (128/256/512 channels) and one linear
    classifier head.

    :param block: residual block class (``BasicBlock`` or ``Bottleneck``)
    :param layers: number of blocks in each of the three stages
    :param num_classes: classifier output dimension
    :param step: number of simulation time steps
    :param encode_type: input-encoding mode handled by ``BaseModule``
    :param data_norm: optional normalization module applied to the raw input
    :param cnf: SEW connect-function name ('ADD' | 'AND' | 'IAND')
    :param kwargs: must contain ``node_type`` (the spiking-neuron class); may
        contain ``once`` (single untimed pass) and ``sum_output`` (average
        logits over time)
    """

    def __init__(self, block, layers, num_classes=1000, step=8, encode_type="direct", zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None, data_norm=None,
                 norm_layer=None, cnf: str = None, *args, **kwargs):
        super().__init__(
            step,
            encode_type,
            *args,
            **kwargs
        )
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer
        self.num_classes = num_classes
        self.normalize = data_norm
        self.node = kwargs['node_type']
        if issubclass(self.node, BaseNode):
            # node1 is an INSTANCE used directly in the stem; node2-4 remain
            # partials so every residual block builds its own node instances.
            self.node1 = partial(self.node, **kwargs, step=step)()
            self.node2 = partial(self.node, **kwargs, step=step)
            self.node3 = partial(self.node, **kwargs, step=step)
            self.node4 = partial(self.node, **kwargs, step=step)
        self.once = kwargs["once"] if "once" in kwargs else False
        self.sum_output = kwargs["sum_output"] if "sum_output" in kwargs else True
        init_channel = 3
        self.inplanes = 128
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(init_channel, self.inplanes, kernel_size=3, stride=1, padding=1,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.layer1 = self._make_layer(block, 128, layers[0], cnf=cnf, node=self.node2, **kwargs)
        self.layer2 = self._make_layer(block, 256, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0], cnf=cnf, node=self.node3, **kwargs)
        self.layer3 = self._make_layer(block, 512, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1], cnf=cnf, node=self.node4, **kwargs)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(512 * block.expansion, num_classes)
        # He-style init for convs, unit BN, scaled-normal init for the head.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                n = m.weight.size(1)
                m.weight.data.normal_(0, 1.0 / float(n))
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False, cnf: str = None, node: callable = None,
                    **kwargs):
        """Build one residual stage of ``blocks`` SEW blocks; the first block
        carries the stride/downsample, the rest keep the resolution."""
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer, cnf, node, **kwargs))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer, cnf=cnf, node=node, **kwargs))
        return nn.Sequential(*layers)

    def _forward_impl(self, inputs):
        """Temporally-unrolled forward.

        In layer_by_layer mode the input is repeated ``step`` times and time
        is folded into the batch dimension; logits are averaged over time
        when ``sum_output`` is set.
        """
        if self.normalize is not None:
            self.normalize.mean = self.normalize.mean.to(inputs.device)
            self.normalize.std = self.normalize.std.to(inputs.device)
            inputs = self.normalize(inputs)
        self.reset()  # clear all spiking-node state before a new sequence
        if self.layer_by_layer:
            inputs = repeat(inputs, 'b c w h -> t b c w h', t=self.step)
            inputs = rearrange(inputs, 't b c w h -> (t b) c w h')
        x = self.conv1(inputs)
        x = self.bn1(x)
        x = self.node1(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = rearrange(x, '(t b) c -> t b c', t=self.step)
        if self.sum_output:
            x = x.mean(0)
        return x

    def _forward_once(self, x):
        """Single pass through the backbone without temporal repetition.

        Fixed: the original referenced ``self.maxpool``, ``self.layer4`` and
        ``self.fc``, none of which exist on this class, so the ``once`` path
        always raised AttributeError. It now mirrors ``_forward_impl``'s
        layer sequence.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.node1(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        return x

    def forward(self, x):
        """Dispatch to the single-pass or temporally-unrolled forward."""
        if self.once:
            return self._forward_once(x)
        return self._forward_impl(x)
================================================
FILE: examples/Snn_safety/RandHet-SNN/train.py
================================================
import argparse
import copy
import logging
import os
import sys
import time
from evaluate import evaluate_attack
import torchattacks
from my_node import RHLIFNode, RHLIFNode2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sew_resnet import SEWResNet19, BasicBlock, Bottleneck
from braincog.base.node.node import *
from utils import (evaluate_standard, cifar10_std, cifar10_mean,
orthogonal_retraction)
from utils import (clamp, get_norm_stat,
get_loaders)
logger = logging.getLogger(__name__)
def get_args():
    """Parse command-line options for RandHet-SNN adversarial training."""
    p = argparse.ArgumentParser()
    # data / run setup
    p.add_argument('--batch_size', default=128, type=int)
    p.add_argument('--data_dir', default='/mnt/data/datasets', type=str)
    p.add_argument('--dataset', default='cifar10', type=str)
    p.add_argument('--epochs', default=100, type=int)
    p.add_argument('--network', default='ResNet18', type=str)
    p.add_argument('--device', default='cuda:3', type=str)
    p.add_argument('--worker', default=4, type=int)
    # optimizer / schedule
    p.add_argument('--lr_schedule', default='cosine', choices=['cyclic', 'multistep', 'cosine'])
    p.add_argument('--lr_min', default=0., type=float)
    p.add_argument('--lr_max', default=0.1, type=float)
    p.add_argument('--weight_decay', default=1e-4, type=float)
    p.add_argument('--momentum', default=0.9, type=float)
    # adversarial training
    p.add_argument('--epsilon', default=4, type=int)
    p.add_argument('--alpha', default=4, type=float, help='Step size')
    p.add_argument('--save_dir', default='ckpt', type=str, help='Output directory')
    p.add_argument('--seed', default=0, type=int, help='Random seed')
    p.add_argument('--attack_iters', default=1, type=int, help='Attack iterations')
    p.add_argument('--pretrain', default=None, type=str, help='path to load the pretrained model')
    p.add_argument('--beta', default=0.004, type=float)
    p.add_argument('--adv_training', action='store_true',
                   help='if adv training')
    # SNN-specific
    p.add_argument('--time_step', default=8, type=int)
    p.add_argument('--SR', action='store_true')
    p.add_argument('--node_type', default='LIF', type=str)
    p.add_argument('--parseval', action='store_true', help='if use different norm for different layers')
    return p.parse_args()
def main():
    """Train an (optionally adversarially-trained) SEW-ResNet19 SNN on CIFAR.

    Pipeline: parse args -> set up logging/seeds -> build data loaders ->
    build model/optimizer/scheduler -> epoch loop with optional PGD
    adversarial training and/or an SR smoothness regularizer -> clean
    evaluation every epoch and APGD robustness evaluation in the last 10
    epochs. Best clean/robust checkpoints are saved to ``args.save_dir``.
    """
    args = get_args()
    device = args.device
    torch.cuda.set_device(device)
    if args.dataset == 'cifar10' or args.dataset == 'svhn':
        args.num_classes = 10
    elif args.dataset == 'cifar100':
        args.num_classes = 100
    # Per-channel statistics and valid pixel bounds in normalized space.
    mu, std, upper_limit, lower_limit = get_norm_stat(cifar10_mean, cifar10_std)
    path = os.path.join('./ckpt', args.dataset, args.network)
    args.save_dir = os.path.join(path, args.save_dir)
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
    logfile = os.path.join(args.save_dir, 'output.log')
    if os.path.exists(logfile):
        os.remove(logfile)
    handlers = [logging.FileHandler(logfile, mode='a+'),
                logging.StreamHandler()]
    logging.basicConfig(
        format='[%(asctime)s] - %(message)s',
        datefmt='%Y/%m/%d %H:%M:%S',
        level=logging.INFO,
        handlers=handlers)
    logger.info(args)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    # get data loader
    # Two loader pairs: one with Normalize baked into the transforms, one
    # returning raw [0, 1] pixels plus a model-side normalization module
    # (the latter is what torchattacks expects).
    train_loader, test_loader, dataset_normalization = get_loaders(args.data_dir, args.batch_size, dataset=args.dataset,
                                                                   worker=args.worker)
    train_loader_e, test_loader_e, dataset_normalization_e = get_loaders(args.data_dir, args.batch_size, dataset=args.dataset,
                                                                         worker=args.worker, norm=False)
    # adv training attack setting
    # epsilon/alpha are given in pixel units and rescaled to normalized space.
    epsilon = ((args.epsilon / 255.) / std).to(device)
    alpha = ((args.alpha / 255.) / std).to(device)
    node = LIFNode
    if args.node_type == 'RHLIF':
        node = RHLIFNode
    elif args.node_type == 'RHLIF2':
        node = RHLIFNode2
    # setup network
    model = SEWResNet19(BasicBlock, [3, 3, 2], cnf='ADD', node_type=node, step=args.time_step, num_classes=args.num_classes,
                        layer_by_layer=True, act_fun=AtanGrad)
    model.to(device)
    # model = torch.nn.DataParallel(model)
    # logger.info(model)
    # setup optimizer, loss function, LR scheduler
    # opt = torch.optim.AdamW(model.parameters(), lr=args.lr_max, weight_decay=args.weight_decay)
    # With Parseval-style orthogonal retraction, weight decay is disabled.
    if args.parseval:
        opt = torch.optim.SGD(model.parameters(), lr=args.lr_max, momentum=0.9, weight_decay=0)
    else:
        opt = torch.optim.SGD(model.parameters(), lr=args.lr_max, momentum=0.9, weight_decay=args.weight_decay)
    criterion = nn.CrossEntropyLoss()
    if args.lr_schedule == 'cyclic':
        lr_steps = args.epochs
        scheduler = torch.optim.lr_scheduler.CyclicLR(opt, base_lr=args.lr_min, max_lr=args.lr_max,
                                                      step_size_up=lr_steps / 2, step_size_down=lr_steps / 2)
    elif args.lr_schedule == 'multistep':
        lr_steps = args.epochs
        scheduler = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=[lr_steps / 2, lr_steps * 3 / 4], gamma=0.1)
    elif args.lr_schedule == 'cosine':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=args.epochs)
    best_pgd_acc = 0
    best_clean_acc = 0
    test_acc_best_pgd = 0
    start_epoch = 0
    # Start training
    start_train_time = time.time()
    for epoch in range(start_epoch, args.epochs):
        # NOTE(review): scheduler.get_lr() is deprecated in newer PyTorch in
        # favor of get_last_lr() — confirm the installed version.
        logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])
        model.train()
        train_loss = 0
        train_acc = 0
        train_n = 0
        for i, (X, y) in enumerate(train_loader):
            _iters = epoch * len(train_loader) + i
            X, y = X.to(device), y.to(device)
            if args.adv_training:
                # init delta
                # Per-channel uniform init within epsilon/10, then projected
                # back into the valid (normalized) pixel range.
                delta = torch.zeros_like(X).to(device)
                for j in range(len(epsilon)):
                    delta[:, j, :, :].uniform_((-epsilon[j][0][0] / 10).item(), (epsilon[j][0][0]/10).item())
                delta.data = clamp(delta, lower_limit.to(device) - X, upper_limit.to(device) - X)
                delta.requires_grad = True
                # pgd attack
                for _ in range(args.attack_iters):
                    output = model(X + delta)
                    # model.random_reset_step = 0
                    loss = criterion(output, y)
                    loss.backward()
                    grad = delta.grad.detach()
                    delta.data = clamp(delta + alpha * torch.sign(grad), -epsilon, epsilon)
                    delta.data = clamp(delta, lower_limit.to(device) - X, upper_limit.to(device) - X)
                    delta.grad.zero_()
                delta = delta.detach()
                X_adv = X + delta[:X.size(0)]
            else:
                X_adv = X
            if args.SR:
                # SR regularizer: penalize the directional derivative of the
                # pairwise softmax margin f along its own sign direction
                # (finite difference with step 0.01).
                X_adv.requires_grad_(True)
                outputs = model(X_adv)
                out = outputs.gather(1, y.unsqueeze(1)).squeeze()  # choose
                batch = []
                inds = []
                for j in range(len(outputs)):
                    # Largest non-target logit and its index.
                    mm, ind = torch.cat([outputs[j, :y[j]], outputs[j, y[j] + 1:]], dim=0).max(0)
                    f = torch.exp(out[j]) / (torch.exp(out[j]) + torch.exp(mm))
                    batch.append(f)
                    inds.append(ind.item())
                f1 = torch.stack(batch, dim=0)
                loss1 = criterion(outputs, y)
                dx = torch.autograd.grad(f1, X_adv, grad_outputs=torch.ones_like(f1, device=device), retain_graph=True)[0]
                X_adv.requires_grad_(False)
                v = dx.detach().sign()
                x2 = X_adv + 0.01 * v
                outputs2 = model(x2)
                out = outputs2.gather(1, y.unsqueeze(1)).squeeze()  # choose
                batch = []
                for j in range(len(outputs2)):
                    # Same competitor index as in the first pass.
                    mm = torch.cat([outputs2[j, :y[j]], outputs2[j, y[j] + 1:]], dim=0)[inds[j]]
                    f = torch.exp(out[j]) / (torch.exp(out[j]) + torch.exp(mm))
                    batch.append(f)
                f2 = torch.stack(batch, dim=0)
                dl = (f2 - f1) / 0.01
                loss2 = dl.pow(2).mean()
                loss = loss1 + 0.001 * loss2
                loss = loss.mean()
            else:
                output = model(X_adv)
                loss = criterion(output, y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            if args.parseval:
                orthogonal_retraction(model, args.beta)
            # for name, param in model.named_parameters():
            #     if 'sigma' in name:  # find parameters whose name contains 'sigma'
            #         param.data = torch.clamp(param.data, 0.0, 1.5)  # constrain the parameter range
            #         if i==0:
            #             print(param)
            # NOTE(review): in the SR branch only `outputs` is assigned, so the
            # accuracy below uses a stale (or, without adv training, undefined)
            # `output` tensor — likely a bug; confirm.
            train_loss += loss.item() * y.size(0)
            train_acc += (output.max(1)[1] == y).sum().item()
            train_n += y.size(0)
            if i % 50 == 0:
                logger.info("Iter: [{:d}][{:d}/{:d}]\t"
                            "Loss {:.3f} ({:.3f})\t"
                            "Prec@1 {:.3f} ({:.3f})\t".format(
                                epoch,
                                i,
                                len(train_loader),
                                loss.item(),
                                train_loss / train_n,
                                (output.max(1)[1] == y).sum().item() / y.size(0),
                                train_acc / train_n)
                            )
        scheduler.step()
        logger.info('Evaluating with standard images...')
        test_loss, test_acc = evaluate_standard(test_loader, model, args)
        logger.info(
            'Test Loss: %.4f \t Test Acc: %.4f',
            test_loss, test_acc)
        if test_acc > best_clean_acc:
            best_clean_acc = (
                test_acc)
            torch.save(model.state_dict(), os.path.join(args.save_dir, 'weight_c.pth'))
        # pgd_loss, pgd_acc = evaluate_pgd(test_loader, model, 5, 1, args)
        # Robustness evaluation only in the final 10 epochs (APGD is slow).
        if epoch > args.epochs - 10:
            logger.info('Evaluating with APGD Attack...')
            # Swap in the model-side normalizer so the attack works on raw pixels.
            model.normalize = dataset_normalization_e
            atk = torchattacks.APGD(model, norm='Linf', eps=8 / 255, steps=10)
            pgd_loss, pgd_acc = evaluate_attack(model, test_loader_e, args, atk, 'APGD', logger)
            model.normalize = dataset_normalization
            if pgd_acc > best_pgd_acc:
                best_pgd_acc = pgd_acc
                test_acc_best_pgd = test_acc
                torch.save(model.state_dict(), os.path.join(args.save_dir, 'weight_r.pth'))
            logger.info(
                'PGD Loss: %.4f \t PGD Acc: %.4f \n Best PGD Acc: %.4f \t Test Acc of best PGD ckpt: %.4f',
                pgd_loss, pgd_acc, best_pgd_acc, test_acc_best_pgd)
    train_time = time.time()
    logger.info('Total train time: %.4f minutes', (train_time - start_train_time) / 60)
# Script entry point: run the full training pipeline when invoked directly.
if __name__ == "__main__":
    main()
================================================
FILE: examples/Snn_safety/RandHet-SNN/utils.py
================================================
# import apex.amp as amp
import os.path
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import numpy as np
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2471, 0.2435, 0.2616)
def get_norm_stat(mean, std):
    """Return (mu, std, upper_limit, lower_limit) as (3,1,1) tensors.

    The limits are the images of pixel values 1 and 0 under per-channel
    normalization, i.e. the valid range in normalized space.
    """
    mu = torch.tensor(mean).view(3, 1, 1)
    sigma = torch.tensor(std).view(3, 1, 1)
    upper_limit = (1 - mu) / sigma
    lower_limit = (0 - mu) / sigma
    return mu, sigma, upper_limit, lower_limit
def clamp(X, lower_limit, upper_limit):
    """Element-wise clamp of X into [lower_limit, upper_limit] (broadcasts).

    Applies the upper bound first, then the lower bound, matching the
    original min-then-max order.
    """
    capped = torch.min(X, upper_limit)
    return torch.max(capped, lower_limit)
def normalize_fn(tensor, mean, std):
    """Differentiable version of torchvision.functional.normalize"""
    # Broadcast the (C,) channel statistics over an (N, C, H, W) batch.
    mean_b = mean[None, :, None, None]
    std_b = std[None, :, None, None]
    return (tensor - mean_b) / std_b
class NormalizeByChannelMeanStd(nn.Module):
    """Per-channel input normalization packaged as an nn.Module.

    Registering mean/std as buffers lets the module move with the model
    (``.to(device)``) and be saved in its state dict; placing it in front of
    a network lets attack libraries operate on raw [0, 1] pixels.
    """

    def __init__(self, mean, std):
        super().__init__()
        mean_t = mean if isinstance(mean, torch.Tensor) else torch.tensor(mean)
        std_t = std if isinstance(std, torch.Tensor) else torch.tensor(std)
        self.register_buffer("mean", mean_t)
        self.register_buffer("std", std_t)

    def forward(self, tensor):
        # Differentiable normalization; channels are assumed at dim=1.
        mean = self.mean[None, :, None, None]
        std = self.std[None, :, None, None]
        return tensor.sub(mean).div(std)

    def extra_repr(self):
        return 'mean={}, std={}'.format(self.mean, self.std)
def get_loaders(dir_, batch_size, dataset='cifar10', worker=4, norm=True):
    """Build CIFAR train/test DataLoaders.

    :param dir_: dataset root directory (downloaded on first use)
    :param batch_size: batch size for both loaders
    :param dataset: 'cifar10' or 'cifar100'
    :param worker: number of DataLoader workers
    :param norm: if True, bake Normalize into the transforms and return
        ``None`` as the normalization module; if False, return raw [0, 1]
        tensors plus a ``NormalizeByChannelMeanStd`` module the caller can
        attach to the model (useful for attack libraries).
    :return: (train_loader, test_loader, dataset_normalization)
    :raises ValueError: if ``dataset`` is not supported (previously this
        surfaced later as a confusing UnboundLocalError).
    """
    # NOTE(review): CIFAR-10 statistics are also used for CIFAR-100 — confirm intended.
    if norm:
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(cifar10_mean, cifar10_std),
        ])
        test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(cifar10_mean, cifar10_std),
        ])
        dataset_normalization = None
    else:
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])
        test_transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        dataset_normalization = NormalizeByChannelMeanStd(
            mean=cifar10_mean, std=cifar10_std)
    if dataset == 'cifar10':
        train_dataset = datasets.CIFAR10(
            dir_, train=True, transform=train_transform, download=True)
        test_dataset = datasets.CIFAR10(
            dir_, train=False, transform=test_transform, download=True)
    elif dataset == 'cifar100':
        train_dataset = datasets.CIFAR100(
            dir_, train=True, transform=train_transform, download=True)
        test_dataset = datasets.CIFAR100(
            dir_, train=False, transform=test_transform, download=True)
    else:
        # Fail fast on unsupported datasets.
        raise ValueError("Unsupported dataset: {}".format(dataset))
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=worker,
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=batch_size,
        shuffle=False,
        pin_memory=True,
        num_workers=worker,
    )
    return train_loader, test_loader, dataset_normalization
# evaluate on clean images with single norm
def evaluate_standard(test_loader, model, args):
    """Evaluate on clean images; return (mean loss, accuracy)."""
    total_loss = 0.0
    correct = 0
    seen = 0
    model.eval()
    device = args.device
    with torch.no_grad():
        for X, y in test_loader:
            X, y = X.to(device), y.to(device)
            logits = model(X)
            batch_loss = F.cross_entropy(logits, y)
            total_loss += batch_loss.item() * y.size(0)
            correct += (logits.max(1)[1] == y).sum().item()
            seen += y.size(0)
    return total_loss / seen, correct / seen
def orthogonal_retraction(model, beta=0.002):
    """One soft orthogonal-retraction step (Parseval-style) on Conv2d/Linear weights.

    In-place update W <- (1 + beta) * W - beta * W @ W^T @ W on the weight
    matrix flattened to (rows, -1), which pushes rows toward orthonormality.
    For Linear layers with >= 200 rows, only a random 30% subset of rows is
    retracted per call to bound the cost. Runs under no_grad and mutates
    ``module.weight.data`` directly.

    :param model: network whose Conv2d/Linear weights are retracted
    :param beta: retraction step size
    """
    with torch.no_grad():
        for module in model.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                if isinstance(module, nn.Conv2d):
                    # Flatten (out, in, kh, kw) -> (out, in*kh*kw); retract all rows.
                    weight_ = module.weight.data
                    sz = weight_.shape
                    weight_ = weight_.reshape(sz[0],-1)
                    rows = list(range(module.weight.data.shape[0]))
                elif isinstance(module, nn.Linear):
                    if module.weight.data.shape[0] < 200: # set a sample threshold for row number
                        weight_ = module.weight.data
                        sz = weight_.shape
                        weight_ = weight_.reshape(sz[0], -1)
                        rows = list(range(module.weight.data.shape[0]))
                    else:
                        # Large layer: retract a random 30% of rows this call.
                        rand_rows = np.random.permutation(module.weight.data.shape[0])
                        rows = rand_rows[: int(module.weight.data.shape[0] * 0.3)]
                        weight_ = module.weight.data[rows,:]
                        sz = weight_.shape
                # Write the retracted rows back, restoring the original shape.
                module.weight.data[rows,:] = ((1 + beta) * weight_ - beta * weight_.matmul(weight_.t()).matmul(weight_)).reshape(sz)
================================================
FILE: examples/Social_Cognition/FOToM/algorithms/ToM_class.py
================================================
import torch
import torch.distributions as td
from utils.networks import MLPNetwork, SNNNetwork, LSTMClassifier
from utils.misc import soft_update, average_gradients, onehot_from_logits, gumbel_softmax
class ToM1(object):
    """
    Theory-of-Mind factory (simplification of the ToM model).

    Holds, for every agent, the ToM networks used to predict the actions of
    all OTHER agents, and blends zeroth-order (ToM0) and first-order (ToM1)
    action estimates into a belief (``c_function``).
    """
    def __init__(self, tom_base, alg_types, agent_types, num_lm, device, hidden_dim=64):
        # tom_base: presumably a nested mapping
        # {observed_type: {observer_type: network}} — TODO confirm with caller.
        self.device = device
        self.alg_types = alg_types
        self.agent_types = agent_types
        # Count of cooperative ('agent') entities among all agents.
        self.num_good_agents = len(self._get_index1(self.agent_types, 'agent'))
        self.nagents = len(alg_types)
        self.num_lm = num_lm
        '''
        Assume that ToM0 and ToM1 are equivalent
        '''
        self.tom1 = tom_base
        # Per-agent lists of the ToM nets of every other agent; filled below.
        self.other_tom1 = [0] * self.nagents
        self._agent_tom1_init()
        '''
        ToM0_policy
        '''
        # self.tom_PHI = [] #TODO
        self.hidden = None

    def _agent_tom1_init(self):
        """Populate self.other_tom1[i] with ToM nets for all agents except i
        (adversaries first, then good agents)."""
        other_alg_types_ = self.alg_types.copy()
        other_agent_types_ = self.agent_types.copy()
        for agent_i in range(self.nagents):
            other_alg_types = other_alg_types_.copy()
            other_agent_types = other_agent_types_.copy()
            other_alg_types.pop(agent_i)
            other_agent_types.pop(agent_i)
            adv_indx = self._get_index1(other_agent_types, 'adversary')
            good_indx = self._get_index1(other_agent_types, 'agent')
            # NOTE(review): list multiplication copies references, so the SAME
            # network object is shared by all other agents of a given type.
            self.other_tom1[agent_i] = [self.tom1['adversary'][self.agent_types[agent_i]]] * len(adv_indx) #TODO
            self.other_tom1[agent_i] += [self.tom1['agent'][self.agent_types[agent_i]]] * len(good_indx)

    def _get_index1(self, lst=None, item=''):
        # Indices of all occurrences of `item` in `lst`.
        return [index for (index, value) in enumerate(lst) if value == item]

    def c_function(self, tom0_actions_q, tom1_actions_q):
        """Blend ToM0 action values with ToM0/ToM1 argmax agreement.

        Where the one-hot argmax actions of ToM0 and ToM1 agree, the belief
        gains a fixed bonus of c1; elsewhere it is the scaled ToM0 value.
        """
        c1 = 0.7  # weight given to ToM0/ToM1 agreement
        # tom0_actions = torch.stack([gumbel_softmax(action_i, hard=True)
        #                             for action_i in tom0_actions_prob], 0)
        # tom1_actions = torch.stack([gumbel_softmax(action_i, hard=True)
        #                             for action_i in tom1_actions_prob], 0)
        '''
        batch, num_agent, ep, 1
        '''
        # One-hot masks marking each estimator's argmax action.
        tom0_actions = (tom0_actions_q == tom0_actions_q.max(dim=-1, keepdim=True)[0]).to(dtype=torch.int32)
        tom1_actions = (tom1_actions_q.unsqueeze(1) == tom1_actions_q.unsqueeze(0).max(dim=-1, keepdim=True)[0]).to(
            dtype=torch.int32)
        # Element-wise agreement between the two argmax masks.
        alig = tom0_actions.long().detach() & tom1_actions.long().detach()
        I_belief = tom0_actions_q * (1 - c1) + alig * c1
        # I_belief = [prob_i * (1 -c1) + alig[i] * c1 for i, prob_i in enumerate(tom0_actions_prob)]
        return I_belief

    def tom1_output(self, agent_i, adv_indx, good_indx, obs_, acs_pre_):
        """
        ToM1 <--> ToM1
        obs_self : obs of self, need to convert
        tom0_out : predict other-action (episode_num * self.args.episode_limit * 2, -1), need to convert
        tom0_out_q : predict other-action q_value (episode_num * self.args.episode_limit * 2, -1)
        device : interact with env (cpu) train (cuda)
        ToM0_policy
        """
        # Each other-agent's ToM net is fed the tail slice of the observation
        # (positions of good agents, landmarks and other agents — presumably;
        # TODO confirm the layout) concatenated with the first 5 entries of
        # the previous action.
        actions = []
        actions += [
            # gumbel_softmax(
            self.other_tom1[agent_i][j].to(self.device)(
                torch.cat((obs_[:, -(self.num_good_agents * 2 + self.num_lm * 2 + (self.nagents - 1) * 2):],
                           acs_pre_[:, :5]), 1))#.detach() #, hard=True
            for j in adv_indx
        ]
        actions += [
            # gumbel_softmax(
            self.other_tom1[agent_i][j].to(self.device)(
                torch.cat((obs_[:, -(self.num_good_agents * 2 + self.num_lm * 2 + (self.nagents - 1) * 2):],
                           acs_pre_[:, :5]), 1))#.detach() #, hard=True
            for j in good_indx
        ]
        # E_action = torch.cat(actions, 1)
        E_action = actions
        return E_action
================================================
FILE: examples/Social_Cognition/FOToM/algorithms/__init__.py
================================================
================================================
FILE: examples/Social_Cognition/FOToM/algorithms/maddpg.py
================================================
import torch
from torch.optim import Adam
import torch.nn.functional as F
from gym.spaces import Box, Discrete, MultiDiscrete
from multiagent.multi_discrete import MultiDiscrete
from utils.networks import MLPNetwork, SNNNetwork, LSTMClassifier
from utils.misc import soft_update, average_gradients, onehot_from_logits, gumbel_softmax
from utils.agents import DDPGAgent, DDPGAgent_RNN, DDPGAgent_SNN, DDPGAgent_ToM
# from commom.distributions import make_pdtype
from thop import profile
from thop import clever_format
import time
# Shared MSE criterion used for the critic's TD-target regression below.
MSELoss = torch.nn.MSELoss()
# reference:https://github.com/starry-sky6688/MADDPG.git
class MADDPG(object):
def __init__(self, agent_init_params, alg_types, device,
gamma=0.95, tau=0.01, lr=0.01, hidden_dim=64,
discrete_action=False):
"""
Inputs:
agent_init_params (list of dict): List of dicts with parameters to
initialize each agent
num_in_pol (int): Input dimensions to policy
num_out_pol (int): Output dimensions to policy
num_in_critic (int): Input dimensions to critic
alg_types (list of str): Learning algorithm for each agent (DDPG
or MADDPG)
gamma (float): Discount factor
tau (float): Target update rate
lr (float): Learning rate for policy and critic
hidden_dim (int): Number of hidden dimensions for networks
discrete_action (bool): Whether or not to use discrete action space
"""
self.device = device
self.nagents = len(alg_types)
self.alg_types = alg_types
self.agents = [DDPGAgent(lr=lr, discrete_action=discrete_action,
hidden_dim=hidden_dim,
**params)
for params in agent_init_params]
self.agent_init_params = agent_init_params
self.gamma = gamma
self.tau = tau
self.lr = lr
self.discrete_action = discrete_action
self.pol_dev = 'cpu' # device for policies
self.critic_dev = 'cpu' # device for critics
self.trgt_pol_dev = 'cpu' # device for target policies
self.trgt_critic_dev = 'cpu' # device for target critics
self.niter = 0
@property
def policies(self):
return [a.policy for a in self.agents]
@property
def target_policies(self):
return [a.target_policy for a in self.agents]
def scale_noise(self, scale):
"""
Scale noise for each agent
Inputs:
scale (float): scale of noise
"""
for a in self.agents:
a.scale_noise(scale)
def reset_noise(self):
for a in self.agents:
a.reset_noise()
def step(self, observations, explore=False):
"""
Take a step forward in environment with all agents
Inputs:
observations: List of observations for each agent
explore (boolean): Whether or not to add exploration noise
Outputs:
actions: List of actions for each agent
"""
return [a.step(obs, explore=explore) for a, obs in zip(self.agents,
observations)]
def update(self, sample, agent_i, parallel=False, logger=None):
"""
Update parameters of agent model based on sample from replay buffer
Inputs:
sample: tuple of (observations, actions, rewards, next
observations, and episode end masks) sampled randomly from
the replay buffer. Each is a list with entries
corresponding to each agent
agent_i (int): index of agent to update
parallel (bool): If true, will average gradients across threads
logger (SummaryWriter from Tensorboard-Pytorch):
If passed in, important quantities will be logged
"""
obs, acs, rews, next_obs, dones = sample
curr_agent = self.agents[agent_i]
curr_agent.critic_optimizer.zero_grad()
if self.alg_types[agent_i] == 'MADDPG':
if self.discrete_action: # one-hot encode action
all_trgt_acs = [onehot_from_logits(pi(nobs)) for pi, nobs in
zip(self.target_policies, next_obs)]
else:
all_trgt_acs = [pi(nobs) for pi, nobs in zip(self.target_policies,
next_obs)]
trgt_vf_in = torch.cat((*next_obs, *all_trgt_acs), dim=1)
else: # DDPG
if self.discrete_action:
trgt_vf_in = torch.cat((next_obs[agent_i],
onehot_from_logits(
curr_agent.target_policy(
next_obs[agent_i]))),
dim=1)
else:
trgt_vf_in = torch.cat((next_obs[agent_i],
curr_agent.target_policy(next_obs[agent_i])),
dim=1)
target_value = (rews[agent_i].view(-1, 1) + self.gamma *
curr_agent.target_critic(trgt_vf_in) *
(1 - dones[agent_i].view(-1, 1)))
if self.alg_types[agent_i] == 'MADDPG':
vf_in = torch.cat((*obs, *acs), dim=1)
else: # DDPG
vf_in = torch.cat((obs[agent_i], acs[agent_i]), dim=1)
actual_value = curr_agent.critic(vf_in)
vf_loss = MSELoss(actual_value, target_value.detach())
vf_loss.backward()
if parallel:
average_gradients(curr_agent.critic)
torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(), 0.5)
curr_agent.critic_optimizer.step()
curr_agent.policy_optimizer.zero_grad()
if self.discrete_action:
# Forward pass as if onehot (hard=True) but backprop through a differentiable
# Gumbel-Softmax sample. The MADDPG paper uses the Gumbel-Softmax trick to backprop
# through discrete categorical samples, but I'm not sure if that is
# correct since it removes the assumption of a deterministic policy for
# DDPG. Regardless, discrete policies don't seem to learn properly without it.
curr_pol_out = curr_agent.policy(obs[agent_i])
curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)
else:
curr_pol_out = curr_agent.policy(obs[agent_i])
curr_pol_vf_in = curr_pol_out
if self.alg_types[agent_i] == 'MADDPG':
all_pol_acs = []
for i, pi, ob in zip(range(self.nagents), self.policies, obs):
if i == agent_i:
all_pol_acs.append(curr_pol_vf_in)
elif self.discrete_action:
all_pol_acs.append(onehot_from_logits(pi(ob)))
else:
all_pol_acs.append(pi(ob))
vf_in = torch.cat((*obs, *all_pol_acs), dim=1)
else: # DDPG
vf_in = torch.cat((obs[agent_i], curr_pol_vf_in),
dim=1)
pol_loss = -curr_agent.critic(vf_in).mean()
pol_loss += (curr_pol_out**2).mean() * 1e-3
pol_loss.backward()
if parallel:
average_gradients(curr_agent.policy)
torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(), 0.5)
curr_agent.policy_optimizer.step()
if logger is not None:
logger.add_scalars('agent%i/losses' % agent_i,
{'vf_loss': vf_loss,
'pol_loss': pol_loss},
self.niter)
def update_all_targets(self):
"""
Update all target networks (called after normal updates have been
performed for each agent)
"""
for a in self.agents:
soft_update(a.target_critic, a.critic, self.tau)
soft_update(a.target_policy, a.policy, self.tau)
self.niter += 1
def prep_training(self, device='gpu'):
for a in self.agents:
a.policy.train()
a.critic.train()
a.target_policy.train()
a.target_critic.train()
if device == 'gpu':
fn = lambda x: x.to(torch.device(self.device))
else:
fn = lambda x: x.cpu()
if not self.pol_dev == device:
for a in self.agents:
a.policy = fn(a.policy)
self.pol_dev = device
if not self.critic_dev == device:
for a in self.agents:
a.critic = fn(a.critic)
self.critic_dev = device
if not self.trgt_pol_dev == device:
for a in self.agents:
a.target_policy = fn(a.target_policy)
self.trgt_pol_dev = device
if not self.trgt_critic_dev == device:
for a in self.agents:
a.target_critic = fn(a.target_critic)
self.trgt_critic_dev = device
def prep_rollouts(self, device='cpu'):
for a in self.agents:
a.policy.eval()
if device == 'gpu':
fn = lambda x: x.to(torch.device(self.device))
else:
fn = lambda x: x.cpu()
# only need main policy for rollouts
if not self.pol_dev == device:
for a in self.agents:
a.policy = fn(a.policy)
self.pol_dev = device
def save(self, filename):
"""
Save trained parameters of all agents into one file
"""
self.prep_training(device='cpu') # move parameters to CPU before saving
save_dict = {'init_dict': self.init_dict,
'agent_params': [a.get_params() for a in self.agents]}
torch.save(save_dict, filename)
@classmethod
def init_from_env(cls, env, device, agent_alg="MADDPG", adversary_alg="MADDPG",
gamma=0.95, tau=0.01, lr=0.01, hidden_dim=64):
"""
Instantiate instance of this class from multi-agent environment
"""
agent_init_params = []
alg_types = [adversary_alg if atype == 'adversary' else agent_alg for
atype in env.agent_types]
for acsp, obsp, algtype in zip(env.action_space, env.observation_space,
alg_types):
num_in_pol = obsp.shape[0]
if isinstance(acsp, Box):
discrete_action = False
get_shape = lambda x: x.shape[0]
elif isinstance(acsp, Discrete): # Discrete
discrete_action = True
get_shape = lambda x: x.n
elif isinstance(acsp, MultiDiscrete):
discrete_action = True
get_shape = lambda x: sum(x.high - x.low + 1)
num_out_pol = get_shape(acsp)
if algtype == "MADDPG":
num_in_critic = 0
for oobsp in env.observation_space:
num_in_critic += oobsp.shape[0]
for oacsp in env.action_space:
if isinstance(oacsp, Box):
discrete_action = False
get_shape = lambda x: x.shape[0]
elif isinstance(oacsp, Discrete): # Discrete
discrete_action = True
get_shape = lambda x: x.n
elif isinstance(oacsp, MultiDiscrete):
discrete_action = True
get_shape = lambda x: sum(x.high - x.low + 1)
num_in_critic += get_shape(oacsp)
else:
num_in_critic = obsp.shape[0] + get_shape(acsp)
agent_init_params.append({'num_in_pol': num_in_pol,
'num_out_pol': num_out_pol,
'num_in_critic': num_in_critic})
init_dict = {'gamma': gamma, 'tau': tau, 'lr': lr,
'hidden_dim': hidden_dim,
'alg_types': alg_types,
'agent_init_params': agent_init_params,
'discrete_action': discrete_action,
'device': device}
instance = cls(**init_dict)
instance.init_dict = init_dict
return instance
@classmethod
def init_from_save(cls, filename):
"""
Instantiate instance of this class from file created by 'save' method
"""
save_dict = torch.load(filename)
instance = cls(**save_dict['init_dict'])
instance.init_dict = save_dict['init_dict']
for a, params in zip(instance.agents, save_dict['agent_params']):
a.load_params(params)
return instance
class MADDPG_RNN(object):
    """
    Wrapper class for DDPG-esque (i.e. also MADDPG) agents with recurrent
    policies/critics in a multi-agent task.

    Networks are unrolled over the episode (time) dimension via
    :meth:`_compute_rnn`; every agent keeps explicit hidden states that are
    (re-)initialized through ``DDPGAgent_RNN.init_hidden``.
    """
    def __init__(self, agent_init_params, alg_types,
                 gamma=0.95, tau=0.01, lr=0.01, hidden_dim=64,
                 discrete_action=False, device='cuda:4'):
        """
        Inputs:
            agent_init_params (list of dict): List of dicts with parameters to
                                              initialize each agent
                num_in_pol (int): Input dimensions to policy
                num_out_pol (int): Output dimensions to policy
                num_in_critic (int): Input dimensions to critic
            alg_types (list of str): Learning algorithm for each agent (DDPG
                                       or MADDPG)
            gamma (float): Discount factor
            tau (float): Target update rate
            lr (float): Learning rate for policy and critic
            hidden_dim (int): Number of hidden dimensions for networks
            discrete_action (bool): Whether or not to use discrete action space
            device (str): torch device used when moving networks and hidden
                          states to the GPU. Defaults to 'cuda:4', which was
                          hard-coded throughout the original implementation.
        """
        self.nagents = len(alg_types)
        self.alg_types = alg_types
        self.device = device
        self.agents = [DDPGAgent_RNN(lr=lr, discrete_action=discrete_action,
                                     hidden_dim=hidden_dim,
                                     **params)
                       for params in agent_init_params]
        self.agent_init_params = agent_init_params
        self.gamma = gamma
        self.tau = tau
        self.lr = lr
        self.discrete_action = discrete_action
        self.pol_dev = 'cpu'  # device for policies
        self.critic_dev = 'cpu'  # device for critics
        self.trgt_pol_dev = 'cpu'  # device for target policies
        self.trgt_critic_dev = 'cpu'  # device for target critics
        self.niter = 0

    def _init_agent(self, n_rollout_threads):
        """Reset every hidden state (policy/critic and both targets)."""
        for agent in self.agents:
            agent.init_hidden(n_rollout_threads, policy_hidden=True,
                              policy_target_hidden=True,
                              critic_hidden=True, critic_target_hidden=True)

    def policies(self, len_ep):
        """Re-initialize policy hidden states; return (policies, hidden states)."""
        for a in self.agents:
            a.init_hidden(len_ep, policy_hidden=True, policy_target_hidden=False,
                          critic_hidden=False, critic_target_hidden=False)
        return ([a.policy for a in self.agents],
                [a.policy_hidden for a in self.agents])

    def target_policies(self, len_ep):
        """Re-initialize target-policy hidden states; return (policies, hiddens)."""
        for a in self.agents:
            a.init_hidden(len_ep, policy_hidden=False, policy_target_hidden=True,
                          critic_hidden=False, critic_target_hidden=False)
        return ([a.target_policy for a in self.agents],
                [a.policy_target_hidden for a in self.agents])

    def scale_noise(self, scale):
        """
        Scale noise for each agent
        Inputs:
            scale (float): scale of noise
        """
        for a in self.agents:
            a.scale_noise(scale)

    def reset_noise(self):
        """Reset the exploration-noise process of every agent."""
        for a in self.agents:
            a.reset_noise()

    def step(self, observations, explore=False):
        """
        Take a step forward in environment with all agents
        Inputs:
            observations: List of observations for each agent
            explore (boolean): Whether or not to add exploration noise
        Outputs:
            actions: List of actions for each agent
        """
        return [a.step(obs, explore=explore) for a, obs in zip(self.agents,
                                                               observations)]

    def _compute_rnn(self, fn, hidden, inputs, logit):
        """
        Unroll the recurrent network ``fn`` over the episode (time) dimension.

        Inputs:
            fn: recurrent module mapping (input_t, hidden) -> (output_t, hidden)
            hidden: initial hidden state
            inputs: tensor shaped (batch, time, features)
            logit: per-step post-processing -- ``onehot_from_logits``,
                   ``gumbel_softmax`` (hard) or None for the raw output
        Outputs:
            per-step outputs stacked along dim 1
        """
        num_ep = inputs.shape[1]
        outputs = []
        # BUGFIX: device was hard-coded to 'cuda:4'; use the configured one.
        hidden = hidden.to(torch.device(self.device))
        for step_id in range(num_ep):
            output, hidden = fn(inputs[:, step_id, :], hidden)
            if logit == onehot_from_logits:
                outputs.append(logit(output))
            elif logit == gumbel_softmax:
                outputs.append(logit(output, True))
            else:
                outputs.append(output)
        outputs = torch.stack(outputs, 1)
        return outputs

    def update(self, sample, agent_i, parallel=False, logger=None):
        """
        Update parameters of agent model based on sample from replay buffer
        Inputs:
            sample: tuple of (observations, actions, rewards, next
                    observations, and episode end masks) sampled randomly from
                    the replay buffer. Each is a list with entries
                    corresponding to each agent
            agent_i (int): index of agent to update
            parallel (bool): If true, will average gradients across threads
            logger (SummaryWriter from Tensorboard-Pytorch):
                If passed in, important quantities will be logged
        """
        obs, acs, rews, next_obs, dones = sample
        curr_agent = self.agents[agent_i]
        len_ep = obs[0].shape[0]
        curr_agent.init_hidden(len_ep, True, True, True, True)
        curr_agent.critic_optimizer.zero_grad()
        if self.alg_types[agent_i] == 'MADDPG_RNN':
            # Fetch the target policies and their freshly initialized hidden
            # states with a single call (the original called target_policies
            # twice inside the zip, re-initializing the hiddens redundantly).
            trgt_pis, trgt_hiddens = self.target_policies(len_ep)
            if self.discrete_action:  # one-hot encode action
                all_trgt_acs = [self._compute_rnn(pi, hid, nobs, onehot_from_logits)
                                for pi, hid, nobs in zip(trgt_pis, trgt_hiddens,
                                                         next_obs)]
            else:
                # BUGFIX: the original zipped over the bound method
                # ``self.target_policies`` itself; unroll the recurrent
                # target policies here as well.
                all_trgt_acs = [self._compute_rnn(pi, hid, nobs, None)
                                for pi, hid, nobs in zip(trgt_pis, trgt_hiddens,
                                                         next_obs)]
            trgt_vf_in = torch.cat((*next_obs, *all_trgt_acs), dim=-1)
            target_critic = self._compute_rnn(curr_agent.target_critic,
                                              curr_agent.critic_target_hidden,
                                              trgt_vf_in, None)
            target_value = (rews[agent_i].view(-1, 1) + self.gamma *
                            target_critic.view(-1, 1) *
                            (1 - dones[agent_i].view(-1, 1)))
        if self.alg_types[agent_i] == 'MADDPG_RNN':
            vf_in = torch.cat((*obs, *acs), dim=-1)
        actual_value = self._compute_rnn(curr_agent.critic,
                                         curr_agent.critic_hidden, vf_in, None)
        vf_loss = MSELoss(actual_value.view(-1, 1), target_value.detach())
        vf_loss.backward()
        if parallel:
            average_gradients(curr_agent.critic)
        # BUGFIX: clip_grad_norm (no underscore) is deprecated/removed.
        torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(), 0.5)
        curr_agent.critic_optimizer.step()
        curr_agent.policy_optimizer.zero_grad()
        if self.discrete_action:
            # Forward pass as if onehot (hard=True) but backprop through a differentiable
            # Gumbel-Softmax sample. The MADDPG paper uses the Gumbel-Softmax trick to backprop
            # through discrete categorical samples, but I'm not sure if that is
            # correct since it removes the assumption of a deterministic policy for
            # DDPG. Regardless, discrete policies don't seem to learn properly without it.
            curr_pol_out = self._compute_rnn(curr_agent.policy,
                                             curr_agent.policy_hidden,
                                             obs[agent_i], None)
            curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)
        else:
            # BUGFIX: the original had no continuous-action branch, leaving
            # ``curr_pol_out`` undefined for non-discrete action spaces.
            curr_pol_out = self._compute_rnn(curr_agent.policy,
                                             curr_agent.policy_hidden,
                                             obs[agent_i], None)
            curr_pol_vf_in = curr_pol_out
        if self.alg_types[agent_i] == 'MADDPG_RNN':
            all_pol_acs = []
            pis, pol_hiddens = self.policies(len_ep)  # single call (see above)
            for i, pi, policy_hidden, ob in zip(range(self.nagents), pis,
                                                pol_hiddens, obs):
                if i == agent_i:
                    all_pol_acs.append(curr_pol_vf_in)
                elif self.discrete_action:
                    all_pol_acs.append(self._compute_rnn(pi, policy_hidden, ob,
                                                         onehot_from_logits))
                else:
                    all_pol_acs.append(pi(ob))
            vf_in = torch.cat((*obs, *all_pol_acs), dim=-1)
            curr_agent.init_hidden(len_ep, False, False, True, False)
            # BUGFIX: the critic must be unrolled with its own (just reset)
            # hidden state; the original reused ``policy_hidden``, the stale
            # loop variable holding the last agent's policy hidden state.
            pol_loss = -self._compute_rnn(curr_agent.critic,
                                          curr_agent.critic_hidden,
                                          vf_in, None).mean()
        pol_loss += (curr_pol_out**2).mean() * 1e-3
        pol_loss.backward()
        if parallel:
            average_gradients(curr_agent.policy)
        torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(), 0.5)
        curr_agent.policy_optimizer.step()
        if logger is not None:
            logger.add_scalars('agent%i/losses' % agent_i,
                               {'vf_loss': vf_loss,
                                'pol_loss': pol_loss},
                               self.niter)

    def update_all_targets(self):
        """
        Update all target networks (called after normal updates have been
        performed for each agent)
        """
        for a in self.agents:
            soft_update(a.target_critic, a.critic, self.tau)
            soft_update(a.target_policy, a.policy, self.tau)
        self.niter += 1

    def prep_training(self, device='gpu'):
        """Put all networks in train mode and lazily move them to *device*."""
        for a in self.agents:
            a.policy.train()
            a.critic.train()
            a.target_policy.train()
            a.target_critic.train()
        if device == 'gpu':
            fn = lambda x: x.to(torch.device(self.device))  # was hard-coded 'cuda:4'
        else:
            fn = lambda x: x.cpu()
        if not self.pol_dev == device:
            for a in self.agents:
                a.policy = fn(a.policy)
            self.pol_dev = device
        if not self.critic_dev == device:
            for a in self.agents:
                a.critic = fn(a.critic)
            self.critic_dev = device
        if not self.trgt_pol_dev == device:
            for a in self.agents:
                a.target_policy = fn(a.target_policy)
            self.trgt_pol_dev = device
        if not self.trgt_critic_dev == device:
            for a in self.agents:
                a.target_critic = fn(a.target_critic)
            self.trgt_critic_dev = device

    def prep_rollouts(self, device='cpu'):
        """Put the policies in eval mode and move them to *device* if needed."""
        for a in self.agents:
            a.policy.eval()
        if device == 'gpu':
            fn = lambda x: x.to(torch.device(self.device))  # was hard-coded 'cuda:4'
        else:
            fn = lambda x: x.cpu()
        # only need main policy for rollouts
        if not self.pol_dev == device:
            for a in self.agents:
                a.policy = fn(a.policy)
            self.pol_dev = device

    def save(self, filename):
        """
        Save trained parameters of all agents into one file
        """
        self.prep_training(device='cpu')  # move parameters to CPU before saving
        save_dict = {'init_dict': self.init_dict,
                     'agent_params': [a.get_params() for a in self.agents]}
        torch.save(save_dict, filename)

    @classmethod
    def init_from_env(cls, env, agent_alg="MADDPG", adversary_alg="MADDPG_RNN",
                      gamma=0.95, tau=0.01, lr=0.01, hidden_dim=64):
        """
        Instantiate instance of this class from multi-agent environment
        """
        agent_init_params = []
        alg_types = [adversary_alg if atype == 'adversary' else agent_alg for
                     atype in env.agent_types]
        for acsp, obsp, algtype in zip(env.action_space, env.observation_space,
                                       alg_types):
            num_in_pol = obsp.shape[0]
            if isinstance(acsp, Box):
                discrete_action = False
                get_shape = lambda x: x.shape[0]
            else:  # Discrete
                discrete_action = True
                get_shape = lambda x: x.n
            num_out_pol = get_shape(acsp)
            if algtype == "MADDPG_RNN":
                # Centralized critic sees all observations and all actions.
                num_in_critic = 0
                for oobsp in env.observation_space:
                    num_in_critic += oobsp.shape[0]
                for oacsp in env.action_space:
                    num_in_critic += get_shape(oacsp)
            else:
                num_in_critic = obsp.shape[0] + get_shape(acsp)
            agent_init_params.append({'num_in_pol': num_in_pol,
                                      'num_out_pol': num_out_pol,
                                      'num_in_critic': num_in_critic})
        init_dict = {'gamma': gamma, 'tau': tau, 'lr': lr,
                     'hidden_dim': hidden_dim,
                     'alg_types': alg_types,
                     'agent_init_params': agent_init_params,
                     'discrete_action': discrete_action}
        instance = cls(**init_dict)
        instance.init_dict = init_dict
        return instance

    @classmethod
    def init_from_save(cls, filename):
        """
        Instantiate instance of this class from file created by 'save' method
        """
        save_dict = torch.load(filename)
        instance = cls(**save_dict['init_dict'])
        instance.init_dict = save_dict['init_dict']
        for a, params in zip(instance.agents, save_dict['agent_params']):
            a.load_params(params)
        return instance
================================================
FILE: examples/Social_Cognition/FOToM/algorithms/tom11.py
================================================
import torch
from torch.optim import Adam
import torch.nn.functional as F
from gym.spaces import Box, Discrete, MultiDiscrete
from multiagent.multi_discrete import MultiDiscrete
from utils.networks import MLPNetwork, SNNNetwork, LSTMClassifier
from utils.misc import soft_update, average_gradients, onehot_from_logits, gumbel_softmax
from utils.agents import DDPGAgent
from algorithms.ToM_class import ToM1
# from commom.distributions import make_pdtype
from thop import profile
from thop import clever_format
import time
# Module-level loss criteria shared by the classes below.
MSELoss = torch.nn.MSELoss()
KL_criterion = torch.nn.KLDivLoss(reduction='sum')  # summed (not averaged) KL divergence
CE_criterion = torch.nn.CrossEntropyLoss(reduction="sum")  # summed cross-entropy
class ToM_decision11(object):
    def __init__(self, agent_init_params, alg_types, agent_types, num_lm,
                 output_style, device, config, gamma=0.95, tau=0.01, lr=0.01,
                 hidden_dim=64, discrete_action=False):
        """
        Multi-agent DDPG wrapper with level-0/level-1 theory-of-mind modules.

        Inputs:
            agent_init_params (list of dict): per-agent network dimensions
                (num_in_pol, num_out_pol, num_in_critic)
            alg_types (list of str): learning algorithm per agent
            agent_types (list of str): 'agent' or 'adversary' per agent
            num_lm (int): number of landmarks in the environment
            output_style: forwarded to each agent's network construction
            device: torch device used for the ToM networks
            config: experiment configuration object (batch_size, env_id, ...)
            gamma (float): discount factor
            tau (float): target update rate
            lr (float): learning rate for policy and critic
            hidden_dim (int): hidden layer width for the agents' networks
            discrete_action (bool): whether the action space is discrete
        """
        self.config = config
        self.device = device
        self.num_lm = num_lm
        self.nagents = len(alg_types)
        self.alg_types = alg_types
        self.agent_types = agent_types
        self.num_good_agents = len(self._get_index1(self.agent_types, 'agent'))
        # NOTE(review): DDPGAgent_ToM is not among this file's visible imports
        # -- presumably defined elsewhere in this module; verify.
        self.agents = [DDPGAgent_ToM(lr=lr, discrete_action=discrete_action,
                                     hidden_dim=hidden_dim,
                                     **params, output_style=output_style,
                                     num_agents=self.nagents,
                                     device=self.device)
                       for params in agent_init_params]
        self.agent_init_params = agent_init_params
        # tom0: shared level-0 inference nets mapping (observation tail,
        # previous action's first 5 dims) -> predicted action logits.
        self.mle_base = [MLPNetwork(self.num_good_agents * 2 + self.num_lm * 2 + (self.nagents - 1) * 2 + 5,
                                    self.agent_init_params[-1]['num_out_pol'],
                                    hidden_dim=5, norm_in=False),  # infer good agent
                         MLPNetwork(self.num_good_agents * 2 + self.num_lm * 2 + (self.nagents - 1) * 2 + 5,
                                    self.agent_init_params[-1]['num_out_pol'],
                                    hidden_dim=5, norm_in=False),  # infer adversary
                         ]
        # Both observer types share the very same two nets (aliases, not copies).
        self.tom_base = {
            'agent': {'agent': self.mle_base[0],
                      'adversary': self.mle_base[1]
                      },
            'adversary': {'agent': self.mle_base[0],
                          'adversary': self.mle_base[1]
                          }
        }
        # Level-1 nets additionally consume the other agents' tom0 predictions
        # (5 logits per other agent appended to the observation tail).
        self.tom_PHI = [LSTMClassifier(self.num_good_agents * 2 + self.num_lm * 2 +
                                       (self.nagents - 1) * 2 + 5 * (self.nagents - 1),
                                       self.agent_init_params[-1]['num_out_pol'],
                                       hidden_size=64),  # infer good agent
                        LSTMClassifier(self.num_good_agents * 2 + self.num_lm * 2 +
                                       (self.nagents - 1) * 2 + 5 * (self.nagents - 1),
                                       self.agent_init_params[-1]['num_out_pol'],
                                       hidden_size=64),  # infer adversary
                        ]  # TODO
        self._agent_tom_init()  # attach shared nets to each agent's mle list  # TODO
        self.tom1 = ToM1(self.tom_base, alg_types, agent_types, num_lm, device)
        # Caches for the latest level-0/level-1 predictions (filled by
        # tom0_output / tom1_infer_other during training).
        self.actions_tom0 = []
        self.next_actions_tom0 = []
        self.actions_tom1 = []
        self.next_actions_tom1 = []
        self.mle_opts = [Adam(i.parameters(), lr=1e-4) for i in self.mle_base]
        self.PHI_opts = [Adam(i.parameters(), lr=1e-4) for i in self.tom_PHI]
        self.gamma = gamma
        self.tau = tau
        self.lr = lr
        self.discrete_action = discrete_action
        self.pol_dev = 'cpu'  # device for policies
        self.critic_dev = 'cpu'  # device for critics
        self.trgt_pol_dev = 'cpu'  # device for target policies
        self.trgt_critic_dev = 'cpu'  # device for target critics
        self.mle_dev = 'cpu'  # device for the ToM (mle) networks
        self.niter = 0
@property
def policies(self):
return [a.policy for a in self.agents]
@property
def target_policies(self):
return [a.target_policy for a in self.agents]
def scale_noise(self, scale):
"""
Scale noise for each agent
Inputs:
scale (float): scale of noise
"""
for a in self.agents:
a.scale_noise(scale)
def reset_noise(self):
for a in self.agents:
a.reset_noise()
def _get_index1(self, lst=None, item=''):
return [index for (index, value) in enumerate(lst) if value == item]
def _agent_tom_init(self):
# other_alg_types_ = self.alg_types.copy()
other_agent_types_ = self.agent_types.copy()
for agent_i in range(self.nagents):
# other_alg_types = other_alg_types_.copy()
other_agent_types = other_agent_types_.copy()
# other_alg_types.pop(agent_i)
other_agent_types.pop(agent_i)
adv_indx = self._get_index1(other_agent_types, 'adversary')
good_indx = self._get_index1(other_agent_types, 'agent')
self.agents[agent_i].mle += [self.tom_base[self.agent_types[agent_i]]['adversary']] * len(adv_indx) #TODO
self.agents[agent_i].mle += [self.tom_base[self.agent_types[agent_i]]['agent']] * len(good_indx)
    def step(self, observations, actions_pre, explore=False): #simple_tag
        """
        Take a step forward in environment with all agents.

        Each agent's observation is first augmented with level-1 ToM
        predictions of the other agents' actions before being fed to its
        policy.

        Inputs:
            observations: List of observations for each agent
            actions_pre: List of each agent's previous action (the first 5
                columns are fed to the ToM nets)
            explore (boolean): Whether or not to add exploration noise
        Outputs:
            actions: List of actions for each agent
        """
        # t1 = time.time()
        observations_ = observations.copy()
        actions_pre_ = actions_pre.copy()  # NOTE(review): unused after this line
        # other_alg_types_ = self.alg_types.copy()
        other_agent_types_ = self.agent_types.copy()
        adv_agent_indx = self._get_index1(self.agent_types, 'adversary')
        good_agent_indx = self._get_index1(self.agent_types, 'agent')
        '''
        tom0
        '''
        # Level-0 inference: predict each agent's action from the tail slice
        # of its observation plus its previous action (first 5 dims);
        # predictions are detached and discretized with hard Gumbel-Softmax.
        # Note the resulting list is ordered adversaries first, then good
        # agents.
        actions_tom0 = []
        actions_tom0 += [
            gumbel_softmax(
                self.mle_base[1].to(self.device)(
                    torch.cat((observations[j][:, -(self.num_good_agents * 2 + self.num_lm * 2 + (self.nagents - 1) * 2):],
                               actions_pre[j][:, :5]), 1).to(self.device)).detach(), hard=True
            ) for j in adv_agent_indx
        ]
        actions_tom0 += [
            gumbel_softmax(
                self.mle_base[0].to(self.device)(
                    torch.cat((observations[j][:, -(self.num_good_agents * 2 + self.num_lm * 2 + (self.nagents - 1) * 2):],
                               actions_pre[j][:, :5]), 1).to(self.device)).detach(), hard=True
            ) for j in good_agent_indx
        ]
        '''
        tom1
        '''
        # Level-1 inference: for each agent, feed its observation tail plus
        # the other agents' tom0 predictions into the type-specific LSTM.
        actions_tom1 = []
        for agent_i, obs in enumerate(observations):
            obs_ = observations_.copy()
            acs_other = actions_tom0.copy()
            other_agent_types = other_agent_types_.copy()
            obs_.pop(agent_i)
            acs_other.pop(agent_i)
            other_agent_types.pop(agent_i)
            if agent_i in adv_agent_indx:
                actions_tom1.append(
                    gumbel_softmax(
                        self.tom_PHI[1].to(self.device)(
                            torch.cat((obs[:, -(self.num_good_agents * 2 + self.num_lm * 2 +
                                                (self.nagents - 1) * 2):].to(self.device), torch.cat(acs_other, 1)), 1)), hard=True
                    ).cpu()
                )
            elif agent_i in good_agent_indx:
                actions_tom1.append(
                    gumbel_softmax(
                        self.tom_PHI[0].to(self.device)(
                            torch.cat((obs[:, -(self.num_good_agents * 2 + self.num_lm * 2 +
                                                (self.nagents - 1) * 2):].to(self.device), torch.cat(acs_other, 1)), 1)), hard=True
                    ).cpu()
                )
        # Augment each observation with the other agents' tom1 predictions.
        observations = self._get_obs(observations, actions_tom1)
        return [a.step(obs, explore=explore) for a, obs in zip(self.agents,
                                                               observations)]
def _get_obs(self, observations, action_tom):
observations_ = []
other_actions_tom_ = action_tom.copy()
for agent_i, obs in enumerate(observations):
other_action_tom = other_actions_tom_.copy()
other_action_tom.pop(agent_i)
actions = other_action_tom
observations_.append(torch.cat((obs, torch.cat(actions, 1)), 1))
return observations_
    def train_tom0(self, sample, agent_i):
        """
        One supervised update of the shared level-0 ToM networks.

        Each net learns to map (observation tail, previous action) -> the
        action the observed agent actually took, via an MSE loss over the
        first 5 action columns. Adversaries train ``mle_base[1]``; good
        agents train ``mle_base[0]``.

        Inputs:
            sample: (acs_pre, obs, acs, rews, next_obs, dones), each a list
                with one entry per agent
            agent_i (int): unused here; kept for call-site symmetry
        """
        acs_pre, obs, acs, rews, next_obs, dones = sample
        adv_agent_indx = self._get_index1(self.agent_types, 'adversary')
        good_agent_indx = self._get_index1(self.agent_types, 'agent')
        # self.agent_types[tom_agent_indx]
        '''
        data
        for with_tom
        for without_tom
        '''
        if adv_agent_indx != []:
            # Stack all adversaries' (obs, prev-action) rows into one batch.
            adv_input = torch.cat([torch.cat((obs[i], acs_pre[i][:, :5]), 1) for i in adv_agent_indx])
            label_adv_output = torch.cat([acs[i][:, :5] for i in adv_agent_indx])
            self.mle_base[1].zero_grad()
            # Only the tail slice (relative positions + prev action) is fed in.
            adv_output = self.mle_base[1](adv_input[:,
                                          -(self.num_good_agents * 2 + self.num_lm * 2 + (self.nagents - 1) * 2 + 5):])
            loss_adv = F.mse_loss(adv_output.float(), label_adv_output.float())
            loss_adv.backward(retain_graph=True)
            torch.nn.utils.clip_grad_norm_(self.mle_base[1].parameters(), 20)
            self.mle_opts[1].step()
        if good_agent_indx != []:
            good_input = torch.cat([torch.cat((obs[i], acs_pre[i][:, :5]), 1) for i in good_agent_indx])
            label_good_output = torch.cat([acs[i][:, :5] for i in good_agent_indx])
            self.mle_base[0].zero_grad()
            good_output = self.mle_base[0](good_input[:,
                                           -(self.num_good_agents * 2 + self.num_lm * 2 + (self.nagents - 1) * 2 + 5):])
            loss_good = F.mse_loss(good_output.float(), label_good_output.float())
            loss_good.backward(retain_graph=True)
            torch.nn.utils.clip_grad_norm_(self.mle_base[0].parameters(), 20)
            self.mle_opts[0].step()
        '''
        Only train agents with ToM (adversarys)
        '''
        '''
        adv-adv
        adv-good
        '''
    def tom1_infer_other(self, sample):
        """
        Produce level-1 ToM predictions for every agent and train the PHI nets.

        Builds, per agent, the (observation tail + other agents' tom0
        predictions) inputs for both the current and next observations, runs
        the type-specific ``tom_PHI`` LSTM on the stacked batches, regresses
        the outputs onto the detached tom0 predictions, and finally slices
        the stacked results back into per-agent caches
        ``self.actions_tom1`` / ``self.next_actions_tom1``.

        Requires ``tom0_output`` to have populated ``self.actions_tom0`` /
        ``self.next_actions_tom0`` first.
        """
        acs_pre, obs, acs, rews, next_obs, dones = sample
        self.actions_tom1 = []
        self.next_actions_tom1 = []
        actions_tom1 = []
        next_actions_tom1 = []
        other_actions_tom0_ = self.actions_tom0.copy()
        other_next_actions_tom0_ = self.next_actions_tom0.copy()
        adv_agent_indx = self._get_index1(self.agent_types, 'adversary')
        good_agent_indx = self._get_index1(self.agent_types, 'agent')
        good_in = []
        adv_in =[]
        good_in_next = []
        adv_in_next =[]
        for agent_i, (obs_i, next_obs_i) in enumerate(zip(obs, next_obs)):
            # Drop this agent's own tom0 prediction from the conditioning set.
            other_action_tom0 = other_actions_tom0_.copy()
            other_next_actions_tom0 = other_next_actions_tom0_.copy()
            other_action_tom0.pop(agent_i)
            other_next_actions_tom0.pop(agent_i)
            if agent_i in adv_agent_indx:
                adv_in.append(torch.cat(
                    (obs_i[:, -(self.num_good_agents * 2 + self.num_lm * 2 + (self.nagents - 1) * 2):],
                     torch.cat(other_action_tom0, 1)), 1))
                adv_in_next.append(torch.cat(
                    (next_obs_i[:, -(self.num_good_agents * 2 + self.num_lm * 2 + (self.nagents - 1) * 2):],
                     torch.cat(other_next_actions_tom0, 1)), 1))
            elif agent_i in good_agent_indx:
                good_in.append(torch.cat(
                    (obs_i[:, -(self.num_good_agents * 2 + self.num_lm * 2 + (self.nagents - 1) * 2):],
                     torch.cat(other_action_tom0, 1)), 1))
                good_in_next.append(torch.cat(
                    (next_obs_i[:, -(self.num_good_agents * 2 + self.num_lm * 2 + (self.nagents - 1) * 2):],
                     torch.cat(other_next_actions_tom0, 1)), 1))
        if adv_agent_indx != []:
            adv_in = torch.cat(adv_in, 0)
            adv_in_next = torch.cat(adv_in_next, 0)
            actions_tom1.append(gumbel_softmax(
                self.tom_PHI[1].to(self.device)(
                    adv_in), hard=True
            ))
            next_actions_tom1.append(gumbel_softmax(
                self.tom_PHI[1].to(self.device)(
                    adv_in_next), hard=True
            ))  # adv
            # tom0 predictions serve as (detached) regression targets.
            label_adv_output = torch.cat([self.actions_tom0[i] for i in adv_agent_indx]).detach()
            # label_adv_output = torch.cat([acs[i][:, :5] for i in adv_agent_indx])
            adv_output = actions_tom1[0]
            loss_adv = F.mse_loss(adv_output.float(), label_adv_output.float())
            loss_adv.backward(retain_graph=True)
            # NOTE(review): gradients of mle_base[1] are clipped here while
            # the optimizer stepped is PHI_opts[1]; clipping
            # tom_PHI[1].parameters() looks intended -- confirm.
            torch.nn.utils.clip_grad_norm_(self.mle_base[1].parameters(), 20)
            self.PHI_opts[1].step()
        if good_agent_indx != []:
            good_in = torch.cat(good_in, 0)
            good_in_next = torch.cat(good_in_next, 0)
            actions_tom1.append(gumbel_softmax(
                self.tom_PHI[0].to(self.device)(
                    good_in), hard=True
            ))
            next_actions_tom1.append(gumbel_softmax(
                self.tom_PHI[0].to(self.device)(
                    good_in_next), hard=True
            ))  # agent
            label_good_output = torch.cat([self.actions_tom0[i] for i in good_agent_indx]).detach()
            # label_good_output = torch.cat([acs[i][:, :5] for i in good_agent_indx])
            # In all-cooperative envs the adversary branch above did not run,
            # so the good-agent predictions sit at index 0 instead of 1.
            if self.config.env_id == 'simple_spread' or self.config.env_id == 'hetero_spread':
                good_output = actions_tom1[0]
            else:
                good_output = actions_tom1[1]
            loss_good = F.mse_loss(good_output.float(), label_good_output.float())
            loss_good.backward(retain_graph=True)
            # NOTE(review): same mismatch as above -- clips mle_base[0] but
            # steps PHI_opts[0]; confirm.
            torch.nn.utils.clip_grad_norm_(self.mle_base[0].parameters(), 20)
            self.PHI_opts[0].step()
        # Re-slice the stacked predictions into one (batch_size, 5) chunk per
        # agent. NOTE(review): this assumes adversaries precede good agents
        # in ``agent_types`` so chunk i matches agent i -- confirm.
        actions_tom1 = torch.cat(actions_tom1)#.detach()
        next_actions_tom1 = torch.cat(next_actions_tom1).detach()
        for i in range(self.nagents):
            self.actions_tom1.append(actions_tom1[i*self.config.batch_size:(i+1)*self.config.batch_size, :])
            self.next_actions_tom1.append(next_actions_tom1[i*self.config.batch_size:(i+1)*self.config.batch_size, :])
        # print(self.actions_tom1)
def tom0_output(self, sample):
acs_pre, obs, acs, rews, next_obs, dones = sample
self.actions_tom0 = []
self.next_actions_tom0 = []
adv_indx = self._get_index1(self.agent_types, 'adversary')
good_indx = self._get_index1(self.agent_types, 'agent')
self.actions_tom0 += [
gumbel_softmax(
self.mle_base[1].to(self.device)(
torch.cat((obs[j][:, -(self.num_good_agents * 2 + self.num_lm * 2 + (self.nagents - 1) * 2):],
acs_pre[j][:, :5]), 1)).detach(), hard=True
) for j in adv_indx
]
self.actions_tom0 += [
gumbel_softmax(
self.mle_base[0].to(self.device)(
torch.cat((obs[j][:, -(self.num_good_agents * 2 + self.num_lm * 2 + (self.nagents - 1) * 2):],
acs_pre[j][:, :5]), 1)).detach(), hard=True
) for j in good_indx
]
self.next_actions_tom0 += [
gumbel_softmax(
self.mle_base[1].to(self.device)(
torch.cat((next_obs[j][:, -(self.num_good_agents * 2 + self.num_lm * 2 + (self.nagents - 1) * 2):],
acs[j][:, :5]), 1)).detach(), hard=True
) for j in adv_indx
]
self.next_actions_tom0 += [
gumbel_softmax(
self.mle_base[0].to(self.device)(
torch.cat((next_obs[j][:, -(self.num_good_agents * 2 + self.num_lm * 2 + (self.nagents - 1) * 2):],
acs[j][:, :5]), 1)).detach(), hard=True
) for j in good_indx
]
# actions_tom0 = torch.cat(actions_tom0, 1)
# return actions_tom0, next_actions_tom0
    def update(self, sample, agent_i, parallel=False, logger=None, sample_r=None):
        """
        Update agent ``agent_i``'s critic and policy from a replay sample,
        shaping its reward with a ToM-based intrinsic term first.

        Inputs:
            sample: (acs_pre, obs, acs, rews, next_obs, dones), each a list
                with one entry per agent
            agent_i (int): index of the agent to update
            parallel (bool): if True, average gradients across threads
            logger: optional Tensorboard SummaryWriter for loss scalars
            sample_r: unused here; kept for call-site compatibility
        """
        # other_alg_types = self.alg_types.copy()
        other_agent_types = self.agent_types.copy()
        # other_alg_types.pop(agent_i)
        other_agent_types.pop(agent_i)
        adv_indx = self._get_index1(other_agent_types, 'adversary')
        good_indx = self._get_index1(other_agent_types, 'agent')
        agent_i_alg = self.alg_types[agent_i]
        acs_pre, obs, acs, rews, next_obs, dones = sample
        # Observations augmented with tom1 predictions (for the policies).
        next_obs_ = self._get_obs(next_obs, self.next_actions_tom1)
        obs_ = self._get_obs(obs, self.actions_tom1)
        curr_agent = self.agents[agent_i]
        '''
        distance between self and other
        '''
        # NOTE(review): the window [:, i: i+2] advances by 1 rather than 2,
        # so consecutive "agent" windows overlap; [2*i: 2*i+2] looks
        # intended if each agent contributes an (x, y) pair -- confirm.
        Euclidean_D = []
        for i in range(len(other_agent_types)):
            Euclidean_D.append(obs[agent_i][:,
                               -len(other_agent_types)*2:][:, i: i+2].pow(2).sum(1).sqrt())
        Euclidean_D_ = torch.stack(Euclidean_D, 1)
        '''
        distance between self and landmark
        '''
        # NOTE(review): same stride-1 window caveat as above.
        Euclidean_L = []
        for i in range(self.num_lm):
            Euclidean_L.append(obs[agent_i][:, -len(other_agent_types)*2-self.num_lm*2:
                               -len(other_agent_types)*2][:, i: i+2].pow(2).sum(1).sqrt())
        Euclidean_L = torch.stack(Euclidean_L, 1)
        # 0/1 mask selecting, per row, the closest other agent.
        close_agent_index = (Euclidean_D_ == Euclidean_D_.min(dim=1, keepdim=True)[0])\
            .to(dtype=torch.int32) #run11/run12 self-orgnization
        # close_agent_index = torch.ones((self.config.batch_size, len(other_agent_types))) \
        #     .to(dtype=torch.int32).to(self.device) #run13
        # The level-0 nets are shared, so one agent triggers their training.
        if agent_i == 0:
            self.train_tom0(sample, agent_i)
        E_action = self.tom1.tom1_output(agent_i, adv_indx,
                                         good_indx, obs[agent_i], acs_pre[agent_i])
        if agent_i_alg == 'with_tom':
            acs_other = acs.copy()
            acs_other.pop(agent_i)
            # KL loss
            # adv_loss = sum([KL_criterion(E_action[j], acs_other[j][:, :5].float()) for j in adv_indx]) #TODO
            # good_loss = sum([KL_criterion(E_action[j], acs_other[j][:, :5].float()) for j in good_indx])
            # L2 loss
            # Prediction error between the agent's own action and the tom1
            # expectation, broadcast to every other agent.
            action_loss = torch.norm(acs[agent_i][:, :5] - E_action[0], p=2, dim=1)
            loss_other = 0.1 * torch.stack([action_loss]*len(other_agent_types), 1)
            if agent_i in adv_indx:
                '''
                adv_loss : decrease
                good_loss : increase
                '''
                # NOTE(review): in-place float multiply (``*= 0.1``) on an
                # int32 tensor raises on recent PyTorch versions -- confirm
                # the intended dtype/torch version.
                close_agent_index[:, good_indx] *= -1
                close_agent_index[:, adv_indx] *= 0.1
                intri_rew = close_agent_index.mul(loss_other).mul(Euclidean_D_).sum(1)
                # intri_rew = close_agent_index.mul(loss_other).sum(1)
            else:
                '''
                adv_loss : increase
                good_loss : decrease
                '''
                # NOTE(review): same int32 ``*= 0.1`` caveat as above.
                close_agent_index[:, adv_indx] *= 1
                close_agent_index[:, good_indx] *= 0.1
                # if self.config.env_id == 'simple_adversary':
                #     intri_rew = close_agent_index.mul(loss_other).mul(Euclidean_D_).sum(1) - \
                #                 obs[agent_i][:, :2].pow(2).sum(1)
                # elif self.config.env_id == 'simple_spread_pre':
                #     intri_rew = close_agent_index.mul(loss_other).mul(Euclidean_D_).sum(1) - \
                #                 Euclidean_L.min(dim=1, keepdim=True)[0][:,0]
                # else:
                intri_rew = close_agent_index.mul(loss_other).mul(Euclidean_D_).sum(1)
                # intri_rew = close_agent_index.mul(loss_other).sum(1)
            # Reward shaping: extrinsic reward plus detached intrinsic term.
            rews[agent_i] = rews[agent_i] + intri_rew.detach()
        # center critic
        curr_agent.critic_optimizer.zero_grad()
        # NOTE(review): on the continuous-action path all_trgt_acs stays
        # empty, so trgt_vf_in would miss the action part -- confirm this
        # class is only used with discrete actions.
        all_trgt_acs = []
        if self.discrete_action:  # one-hot encode action
            # Target policies act on the tom1-augmented next observations.
            all_trgt_acs = [onehot_from_logits(pi(nobs)) for pi, nobs in
                            zip(self.target_policies, next_obs_)]
        trgt_vf_in = torch.cat((*next_obs, *all_trgt_acs), dim=1)
        target_value = (rews[agent_i].view(-1, 1) + self.gamma *
                        curr_agent.target_critic(trgt_vf_in) *
                        (1 - dones[agent_i].view(-1, 1)))
        vf_in = torch.cat((*obs, *acs), dim=1)
        actual_value = curr_agent.critic(vf_in)
        vf_loss = MSELoss(actual_value, target_value.detach())
        vf_loss.backward()
        if parallel:
            average_gradients(curr_agent.critic)
        torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(), 0.5)
        curr_agent.critic_optimizer.step()
        curr_agent.policy_optimizer.zero_grad()
        if self.discrete_action:
            # Forward pass as if onehot (hard=True) but backprop through a differentiable
            # Gumbel-Softmax sample. The MADDPG paper uses the Gumbel-Softmax trick to backprop
            # through discrete categorical samples, but I'm not sure if that is
            # correct since it removes the assumption of a deterministic policy for
            # DDPG. Regardless, discrete policies don't seem to learn properly without it.
            curr_pol_out = curr_agent.policy(obs_[agent_i].detach())
            curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)
        else:
            curr_pol_out = curr_agent.policy(obs[agent_i])
            curr_pol_vf_in = curr_pol_out
        all_pol_acs = []
        for i, pi, ob in zip(range(self.nagents), self.policies, obs_):
            # detach: do not backprop through the tom1 augmentation.
            ob = ob.detach()
            if i == agent_i:
                all_pol_acs.append(curr_pol_vf_in)
            elif self.discrete_action:
                all_pol_acs.append(onehot_from_logits(pi(ob)))
            else:
                all_pol_acs.append(pi(ob))
        vf_in = torch.cat((*obs, *all_pol_acs), dim=1)
        pol_loss = -curr_agent.critic(vf_in).mean()
        # Small action-magnitude regularizer on the raw logits.
        pol_loss += (curr_pol_out ** 2).mean() * 1e-3
        pol_loss.backward()
        if parallel:
            average_gradients(curr_agent.policy)
        torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(), 0.5)
        # actor
        curr_agent.policy_optimizer.step()
        # print('c_loss:',vf_loss, 'p_loss:', pol_loss)
        if logger is not None:
            logger.add_scalars('agent%i/losses' % agent_i,
                               {'vf_loss': vf_loss,
                                'pol_loss': pol_loss},
                               self.niter)
def update_all_targets(self):
"""
Update all target networks (called after normal updates have been
performed for each agent)
"""
for a in self.agents:
soft_update(a.target_critic, a.critic, self.tau)
soft_update(a.target_policy, a.policy, self.tau)
self.niter += 1
def prep_training(self, device='gpu'):
# for i in self.tom_base.values():
# for mle in i.values():
# mle.train()
for mle in self.mle_base:
mle.train()
for a in self.agents:
a.policy.train()
a.critic.train()
a.target_policy.train()
a.target_critic.train()
for mle_i in a.mle:
mle_i.train()
if device == 'gpu':
fn = lambda x: x.to(torch.device(self.device))
else:
fn = lambda x: x.cpu()
if not self.pol_dev == device:
for a in self.agents:
a.policy = fn(a.policy)
self.pol_dev = device
if not self.critic_dev == device:
for a in self.agents:
a.critic = fn(a.critic)
self.critic_dev = device
if not self.trgt_pol_dev == device:
for a in self.agents:
a.target_policy = fn(a.target_policy)
self.trgt_pol_dev = device
if not self.trgt_critic_dev == device:
for a in self.agents:
a.target_critic = fn(a.target_critic)
self.trgt_critic_dev = device
if not self.mle_dev == device:
# for i in self.tom_base.keys():
# for j in self.tom_base[i].keys():
# self.tom_base[i][j] = fn(mle)
for i, mle in enumerate(self.mle_base):
self.mle_base[i] = fn(mle)
for a in self.agents:
for i, mle_i in enumerate(a.mle):
a.mle[i] = fn(mle_i)
self.mle_dev = device
def prep_rollouts(self, device='cpu'):
    """Switch policies to eval mode and move them to the rollout device."""
    for agent in self.agents:
        agent.policy.eval()
    if device == 'gpu':
        move = lambda module: module.to(torch.device(self.device))
    else:
        move = lambda module: module.cpu()
    # Rollouts only sample from the main policies, so nothing else moves.
    if self.pol_dev != device:
        for agent in self.agents:
            agent.policy = move(agent.policy)
        self.pol_dev = device
def save(self, filename):
    """
    Save trained parameters of all agents into one file
    """
    # Parameters must live on CPU so the checkpoint loads on any machine.
    self.prep_training(device='cpu')
    save_dict = {
        'init_dict': self.init_dict,
        'agent_params': [agent.get_params() for agent in self.agents],
        'tom_params': [self.get_params()],
    }
    torch.save(save_dict, filename)
@classmethod
def init_from_env(cls, env, config, device, agent_alg, adversary_alg,
                  gamma=0.95, tau=0.01, lr=0.01, hidden_dim=64, output_style='sum'):
    """
    Instantiate instance of this class from multi-agent environment.

    Per-agent network sizes are derived from the env's observation/action
    spaces; each critic sees all observations and all actions (MADDPG-style).
    """
    agent_init_params = []
    # Adversaries may use a different algorithm than regular agents.
    alg_types = [adversary_alg if atype == 'adversary' else agent_alg for
                 atype in env.agent_types]
    num_lm = env.num_lm
    for acsp, obsp, algtype in zip(env.action_space, env.observation_space,
                                   alg_types):
        num_in_pol = obsp.shape[0]
        num_in_mle = obsp.shape[0]
        # NOTE(review): if acsp is none of Box/Discrete/MultiDiscrete,
        # get_shape stays unbound and the call below raises NameError.
        if isinstance(acsp, Box):
            discrete_action = False
            get_shape = lambda x: x.shape[0]
        elif isinstance(acsp, Discrete):  # Discrete
            discrete_action = True
            get_shape = lambda x: x.n
        elif isinstance(acsp, MultiDiscrete):
            discrete_action = True
            get_shape = lambda x: sum(x.high - x.low + 1)
        num_out_pol = get_shape(acsp)
        # if algtype == "with_tom":
        num_in_critic = 0
        # The policy additionally consumes 5 ToM features per *other* agent.
        num_in_pol += (len(env.agent_types)-1) * 5
        for oobsp in env.observation_space:
            num_in_critic += oobsp.shape[0]
        for oacsp in env.action_space:
            if isinstance(oacsp, Box):
                discrete_action = False
                get_shape = lambda x: x.shape[0]
            elif isinstance(oacsp, Discrete):  # Discrete
                discrete_action = True
                get_shape = lambda x: x.n
            elif isinstance(oacsp, MultiDiscrete):
                discrete_action = True
                get_shape = lambda x: sum(x.high - x.low + 1)
            num_in_critic += get_shape(oacsp)
        # else:
        #     num_in_critic = obsp.shape[0] + get_shape(acsp)
        agent_init_params.append({'num_in_pol': num_in_pol,
                                  'num_out_pol': num_out_pol,
                                  'num_in_critic': num_in_critic,
                                  'num_in_mle': num_in_mle,})
    # NOTE(review): discrete_action here reflects the *last* action space
    # examined in the loops above -- verify all spaces share one type.
    init_dict = {'gamma': gamma, 'tau': tau, 'lr': lr,
                 'device': device,
                 'config' : config,
                 'hidden_dim': hidden_dim,
                 'alg_types': alg_types,
                 'agent_types' : env.agent_types,
                 'num_lm' : num_lm,
                 'agent_init_params': agent_init_params,
                 'discrete_action': discrete_action,
                 'output_style': output_style}
    instance = cls(**init_dict)
    # Kept so `save` can serialize the constructor arguments verbatim.
    instance.init_dict = init_dict
    return instance
@classmethod
def init_from_save(cls, filename):
    """
    Instantiate instance of this class from file created by 'save' method
    """
    save_dict = torch.load(filename)
    instance = cls(**save_dict['init_dict'])
    instance.init_dict = save_dict['init_dict']
    for agent, params in zip(instance.agents, save_dict['agent_params']):
        agent.load_params(params)
    # ToM parameters were stored as a one-element list; restore them onto
    # the instance itself.
    for holder, params in zip([instance], save_dict['tom_params']):
        holder.load_params(params)
    return instance
def get_params(self):
    """Collect state dicts of all ToM modules and their optimizers."""
    params = {}
    for idx, mle in enumerate(self.mle_base):
        params['mle%d' % idx] = mle.state_dict()
        params['mle_optimizer%d' % idx] = self.mle_opts[idx].state_dict()
        params['tom_phi%d' % idx] = self.tom_PHI[idx].state_dict()
        params['phi_opt%d' % idx] = self.PHI_opts[idx].state_dict()
    return params
def load_params(self, params):
    """Restore the ToM module / optimizer state produced by `get_params`."""
    for idx in range(len(self.mle_base)):
        self.mle_base[idx].load_state_dict(params['mle%d' % idx])
        self.mle_opts[idx].load_state_dict(params['mle_optimizer%d' % idx])
        self.tom_PHI[idx].load_state_dict(params['tom_phi%d' % idx])
        self.PHI_opts[idx].load_state_dict(params['phi_opt%d' % idx])
================================================
FILE: examples/Social_Cognition/FOToM/common/__init__.py
================================================
================================================
FILE: examples/Social_Cognition/FOToM/common/distributions.py
================================================
# import tensorflow as tf
import tensorflow.compat.v1 as tf
tf.compat.v1.disable_eager_execution()
import numpy as np
import maddpg.common.tf_util as U
from tensorflow.python.ops import math_ops
from multiagent.multi_discrete import MultiDiscrete
from tensorflow.python.ops import nn
class Pd(object):
    """Abstract base for a particular probability distribution."""

    def flatparam(self):
        """Flat tensor of the distribution parameters."""
        raise NotImplementedError

    def mode(self):
        """Most likely value of the distribution."""
        raise NotImplementedError

    def logp(self, x):
        """Log probability (density) of `x` under the distribution."""
        raise NotImplementedError

    def kl(self, other):
        """KL divergence from this distribution to `other`."""
        raise NotImplementedError

    def entropy(self):
        """Entropy of the distribution."""
        raise NotImplementedError

    def sample(self):
        """Draw a sample from the distribution."""
        raise NotImplementedError
class PdType(object):
    """Parametrized family of probability distributions."""

    def pdclass(self):
        """Concrete Pd subclass this family instantiates."""
        raise NotImplementedError

    def pdfromflat(self, flat):
        # Build the concrete distribution from a flat parameter tensor.
        return self.pdclass()(flat)

    def param_shape(self):
        raise NotImplementedError

    def sample_shape(self):
        raise NotImplementedError

    def sample_dtype(self):
        raise NotImplementedError

    def param_placeholder(self, prepend_shape, name=None):
        shape = prepend_shape + self.param_shape()
        return tf.placeholder(dtype=tf.float32, shape=shape, name=name)

    def sample_placeholder(self, prepend_shape, name=None):
        shape = prepend_shape + self.sample_shape()
        return tf.placeholder(dtype=self.sample_dtype(), shape=shape, name=name)
class CategoricalPdType(PdType):
    """Family of (hard) categorical distributions over `ncat` choices."""

    def __init__(self, ncat):
        self.ncat = ncat

    def pdclass(self):
        return CategoricalPd

    def param_shape(self):
        # One logit per category.
        return [self.ncat]

    def sample_shape(self):
        # Samples are scalar class indices.
        return []

    def sample_dtype(self):
        return tf.int32
class SoftCategoricalPdType(PdType):
    """Family of soft (relaxed) categorical distributions."""

    def __init__(self, ncat):
        self.ncat = ncat

    def pdclass(self):
        return SoftCategoricalPd

    def param_shape(self):
        return [self.ncat]

    def sample_shape(self):
        # Soft samples are full probability vectors, not indices.
        return [self.ncat]

    def sample_dtype(self):
        return tf.float32
class MultiCategoricalPdType(PdType):
    """Family of joint categoricals; one component per (low, high) pair."""

    def __init__(self, low, high):
        self.low = low
        self.high = high
        # Number of categories in each component.
        self.ncats = high - low + 1

    def pdclass(self):
        return MultiCategoricalPd

    def pdfromflat(self, flat):
        # The joint distribution also needs the per-component bounds.
        return MultiCategoricalPd(self.low, self.high, flat)

    def param_shape(self):
        return [sum(self.ncats)]

    def sample_shape(self):
        return [len(self.ncats)]

    def sample_dtype(self):
        return tf.int32
class SoftMultiCategoricalPdType(PdType):
    """Family of joint soft categoricals; one component per (low, high) pair."""

    def __init__(self, low, high):
        self.low = low
        self.high = high
        # Number of categories in each component.
        self.ncats = high - low + 1

    def pdclass(self):
        return SoftMultiCategoricalPd

    def pdfromflat(self, flat):
        # The joint distribution also needs the per-component bounds.
        return SoftMultiCategoricalPd(self.low, self.high, flat)

    def param_shape(self):
        return [sum(self.ncats)]

    def sample_shape(self):
        # Soft samples concatenate one probability vector per component.
        return [sum(self.ncats)]

    def sample_dtype(self):
        return tf.float32
class DiagGaussianPdType(PdType):
    """Family of diagonal Gaussians over `size` dimensions."""

    def __init__(self, size):
        self.size = size

    def pdclass(self):
        return DiagGaussianPd

    def param_shape(self):
        # Mean and log-std for every dimension.
        return [2 * self.size]

    def sample_shape(self):
        return [self.size]

    def sample_dtype(self):
        return tf.float32
class BernoulliPdType(PdType):
    """Family of `size` independent Bernoulli distributions."""

    def __init__(self, size):
        self.size = size

    def pdclass(self):
        return BernoulliPd

    def param_shape(self):
        return [self.size]

    def sample_shape(self):
        return [self.size]

    def sample_dtype(self):
        return tf.int32
# WRONG SECOND DERIVATIVES
# class CategoricalPd(Pd):
# def __init__(self, logits):
# self.logits = logits
# self.ps = tf.nn.softmax(logits)
# @classmethod
# def fromflat(cls, flat):
# return cls(flat)
# def flatparam(self):
# return self.logits
# def mode(self):
# return U.argmax(self.logits, axis=1)
# def logp(self, x):
# return -tf.nn.sparse_softmax_cross_entropy_with_logits(self.logits, x)
# def kl(self, other):
# return tf.nn.softmax_cross_entropy_with_logits(other.logits, self.ps) \
# - tf.nn.softmax_cross_entropy_with_logits(self.logits, self.ps)
# def entropy(self):
# return tf.nn.softmax_cross_entropy_with_logits(self.logits, self.ps)
# def sample(self):
# u = tf.random_uniform(tf.shape(self.logits))
# return U.argmax(self.logits - tf.log(-tf.log(u)), axis=1)
class CategoricalPd(Pd):
    """Categorical distribution parameterized by unnormalized logits."""

    def __init__(self, logits):
        self.logits = logits

    def flatparam(self):
        return self.logits

    def mode(self):
        return U.argmax(self.logits, axis=1)

    def logp(self, x):
        # Negative cross-entropy is the log-probability of integer labels x.
        return -tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits, labels=x)

    def kl(self, other):
        # Shift logits by their max for numerical stability before exp().
        shifted_self = self.logits - U.max(self.logits, axis=1, keepdims=True)
        shifted_other = other.logits - U.max(other.logits, axis=1, keepdims=True)
        exp_self = tf.exp(shifted_self)
        exp_other = tf.exp(shifted_other)
        norm_self = U.sum(exp_self, axis=1, keepdims=True)
        norm_other = U.sum(exp_other, axis=1, keepdims=True)
        probs = exp_self / norm_self
        return U.sum(probs * (shifted_self - tf.log(norm_self) - shifted_other + tf.log(norm_other)), axis=1)

    def entropy(self):
        shifted = self.logits - U.max(self.logits, axis=1, keepdims=True)
        exp_shifted = tf.exp(shifted)
        norm = U.sum(exp_shifted, axis=1, keepdims=True)
        probs = exp_shifted / norm
        return U.sum(probs * (tf.log(norm) - shifted), axis=1)

    def sample(self):
        # Gumbel-max trick: argmax over logits perturbed by Gumbel noise.
        noise = tf.random_uniform(tf.shape(self.logits))
        return U.argmax(self.logits - tf.log(-tf.log(noise)), axis=1)

    @classmethod
    def fromflat(cls, flat):
        return cls(flat)
class SoftCategoricalPd(Pd):
    """Categorical over logits whose samples are soft probability vectors."""

    def __init__(self, logits):
        self.logits = logits

    def flatparam(self):
        return self.logits

    def mode(self):
        return U.softmax(self.logits, axis=-1)

    def logp(self, x):
        # x is a soft label vector, so dense softmax cross-entropy applies.
        return -tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=x)

    def kl(self, other):
        # Shift logits by their max for numerical stability before exp().
        shifted_self = self.logits - U.max(self.logits, axis=1, keepdims=True)
        shifted_other = other.logits - U.max(other.logits, axis=1, keepdims=True)
        exp_self = tf.exp(shifted_self)
        exp_other = tf.exp(shifted_other)
        norm_self = U.sum(exp_self, axis=1, keepdims=True)
        norm_other = U.sum(exp_other, axis=1, keepdims=True)
        probs = exp_self / norm_self
        return U.sum(probs * (shifted_self - tf.log(norm_self) - shifted_other + tf.log(norm_other)), axis=1)

    def entropy(self):
        shifted = self.logits - U.max(self.logits, axis=1, keepdims=True)
        exp_shifted = tf.exp(shifted)
        norm = U.sum(exp_shifted, axis=1, keepdims=True)
        probs = exp_shifted / norm
        return U.sum(probs * (tf.log(norm) - shifted), axis=1)

    def sample(self):
        # Gumbel-softmax style: softmax over logits plus Gumbel noise.
        noise = tf.random_uniform(tf.shape(self.logits))
        return U.softmax(self.logits - tf.log(-tf.log(noise)), axis=-1)

    @classmethod
    def fromflat(cls, flat):
        return cls(flat)
class MultiCategoricalPd(Pd):
    """Joint distribution of several independent categoricals.

    `flat` concatenates the logits of all components; `low`/`high` give each
    component's integer offset and upper bound.
    """
    def __init__(self, low, high, flat):
        self.flat = flat
        self.low = tf.constant(low, dtype=tf.int32)
        # One CategoricalPd per component, splitting the flat logits.
        self.categoricals = list(map(CategoricalPd, tf.split(flat, high - low + 1, axis=len(flat.get_shape()) - 1)))
    def flatparam(self):
        return self.flat
    def mode(self):
        return self.low + tf.cast(tf.stack([p.mode() for p in self.categoricals], axis=-1), tf.int32)
    def logp(self, x):
        # Sum of per-component log-probs; x is shifted back by `low` first.
        return tf.add_n([p.logp(px) for p, px in zip(self.categoricals, tf.unstack(x - self.low, axis=len(x.get_shape()) - 1))])
    def kl(self, other):
        return tf.add_n([
            p.kl(q) for p, q in zip(self.categoricals, other.categoricals)
        ])
    def entropy(self):
        return tf.add_n([p.entropy() for p in self.categoricals])
    def sample(self):
        return self.low + tf.cast(tf.stack([p.sample() for p in self.categoricals], axis=-1), tf.int32)
    @classmethod
    def fromflat(cls, flat):
        # NOTE(review): broken -- __init__ requires (low, high, flat) but this
        # passes only `flat`, so calling fromflat raises TypeError. Use
        # MultiCategoricalPdType.pdfromflat instead.
        return cls(flat)
class SoftMultiCategoricalPd(Pd):  # doesn't work yet
    """Joint distribution of several independent soft categoricals.

    `flat` concatenates the logits of all components; `low`/`high` give each
    component's offset and upper bound.
    """
    def __init__(self, low, high, flat):
        self.flat = flat
        self.low = tf.constant(low, dtype=tf.float32)
        # One SoftCategoricalPd per component, splitting the flat logits.
        self.categoricals = list(map(SoftCategoricalPd, tf.split(flat, high - low + 1, axis=len(flat.get_shape()) - 1)))
    def flatparam(self):
        return self.flat
    def mode(self):
        x = []
        for i in range(len(self.categoricals)):
            x.append(self.low[i] + self.categoricals[i].mode())
        return tf.concat(x, axis=-1)
    def logp(self, x):
        # Sum of per-component log-probs; x is shifted back by `low` first.
        return tf.add_n([p.logp(px) for p, px in zip(self.categoricals, tf.unstack(x - self.low, axis=len(x.get_shape()) - 1))])
    def kl(self, other):
        return tf.add_n([
            p.kl(q) for p, q in zip(self.categoricals, other.categoricals)
        ])
    def entropy(self):
        return tf.add_n([p.entropy() for p in self.categoricals])
    def sample(self):
        x = []
        for i in range(len(self.categoricals)):
            x.append(self.low[i] + self.categoricals[i].sample())
        return tf.concat(x, axis=-1)
    @classmethod
    def fromflat(cls, flat):
        # NOTE(review): broken -- __init__ requires (low, high, flat) but this
        # passes only `flat`, so calling fromflat raises TypeError. Use
        # SoftMultiCategoricalPdType.pdfromflat instead.
        return cls(flat)
class DiagGaussianPd(Pd):
    """Diagonal Gaussian; `flat` is [mean, logstd] concatenated on axis 1."""

    def __init__(self, flat):
        self.flat = flat
        mean, logstd = tf.split(axis=1, num_or_size_splits=2, value=flat)
        self.mean = mean
        self.logstd = logstd
        self.std = tf.exp(logstd)

    def flatparam(self):
        return self.flat

    def mode(self):
        # The mode of a Gaussian is its mean.
        return self.mean

    def logp(self, x):
        # Log density of a diagonal Gaussian, summed over dimensions.
        quadratic = 0.5 * U.sum(tf.square((x - self.mean) / self.std), axis=1)
        normalizer = 0.5 * np.log(2.0 * np.pi) * tf.to_float(tf.shape(x)[1])
        return -quadratic - normalizer - U.sum(self.logstd, axis=1)

    def kl(self, other):
        assert isinstance(other, DiagGaussianPd)
        ratio = (tf.square(self.std) + tf.square(self.mean - other.mean)) / (2.0 * tf.square(other.std))
        return U.sum(other.logstd - self.logstd + ratio - 0.5, axis=1)

    def entropy(self):
        return U.sum(self.logstd + .5 * np.log(2.0 * np.pi * np.e), 1)

    def sample(self):
        # Reparameterized sample: mean + std * unit-normal noise.
        return self.mean + self.std * tf.random_normal(tf.shape(self.mean))

    @classmethod
    def fromflat(cls, flat):
        return cls(flat)
class BernoulliPd(Pd):
    """Independent Bernoulli variables parameterized by logits."""

    def __init__(self, logits):
        self.logits = logits
        self.ps = tf.sigmoid(logits)

    def flatparam(self):
        return self.logits

    def mode(self):
        return tf.round(self.ps)

    def logp(self, x):
        xent = tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=tf.to_float(x))
        return -U.sum(xent, axis=1)

    def kl(self, other):
        cross = tf.nn.sigmoid_cross_entropy_with_logits(logits=other.logits, labels=self.ps)
        own = tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=self.ps)
        return U.sum(cross, axis=1) - U.sum(own, axis=1)

    def entropy(self):
        xent = tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=self.ps)
        return U.sum(xent, axis=1)

    def sample(self):
        p = tf.sigmoid(self.logits)
        u = tf.random_uniform(tf.shape(p))
        # 1.0 wherever the uniform draw falls below the success probability.
        return tf.to_float(math_ops.less(u, p))

    @classmethod
    def fromflat(cls, flat):
        return cls(flat)
def make_pdtype(ac_space):
    """Map a gym action space to the matching PdType family."""
    from gym import spaces
    if isinstance(ac_space, spaces.Box):
        assert len(ac_space.shape) == 1
        return DiagGaussianPdType(ac_space.shape[0])
    if isinstance(ac_space, spaces.Discrete):
        # return CategoricalPdType(ac_space.n)
        # Soft variant keeps actions differentiable for the policy gradient.
        return SoftCategoricalPdType(ac_space.n)
    if isinstance(ac_space, MultiDiscrete):
        # return MultiCategoricalPdType(ac_space.low, ac_space.high)
        return SoftMultiCategoricalPdType(ac_space.low, ac_space.high)
    if isinstance(ac_space, spaces.MultiBinary):
        return BernoulliPdType(ac_space.n)
    raise NotImplementedError
def shape_el(v, i):
    """Size of dimension `i` of tensor `v`: static if known, else a dynamic op."""
    static = v.get_shape()[i]
    if static is not None:
        return static
    return tf.shape(v)[i]
================================================
FILE: examples/Social_Cognition/FOToM/common/tile_images.py
================================================
import numpy as np
def tile_images(img_nhwc):
    """
    Tile N images into one big PxQ image
    (P,Q) are chosen to be as close as possible, and if N
    is square, then P=Q.
    input: img_nhwc, list or array of images, ndim=4 once turned into array
        n = batch index, h = height, w = width, c = channel
    returns:
        bigim_HWc, ndarray with ndim=3
    """
    imgs = np.asarray(img_nhwc)
    n, height, width, channels = imgs.shape
    rows = int(np.ceil(np.sqrt(n)))
    cols = int(np.ceil(float(n) / rows))
    # Pad with all-zero (black) images so the grid is exactly rows*cols.
    padding = [imgs[0] * 0 for _ in range(n, rows * cols)]
    imgs = np.array(list(imgs) + padding)
    grid = imgs.reshape(rows, cols, height, width, channels)
    # Interleave the grid axes with the pixel axes, then flatten to 2D+c.
    grid = grid.transpose(0, 2, 1, 3, 4)
    return grid.reshape(rows * height, cols * width, channels)
================================================
FILE: examples/Social_Cognition/FOToM/common/vec_env/__init__.py
================================================
================================================
FILE: examples/Social_Cognition/FOToM/common/vec_env/vec_env.py
================================================
import contextlib
import os
from abc import ABC, abstractmethod
from common.tile_images import tile_images
class AlreadySteppingError(Exception):
    """
    Raised when an asynchronous step is running while
    step_async() is called again.
    """

    def __init__(self):
        super().__init__('already running an async step')
class NotSteppingError(Exception):
    """
    Raised when an asynchronous step is not running but
    step_wait() is called.
    """

    def __init__(self):
        super().__init__('not running an async step')
class VecEnv(ABC):
    """
    An abstract asynchronous, vectorized environment.
    Used to batch data from multiple copies of an environment, so that
    each observation becomes an batch of observations, and expected action is a batch of actions to
    be applied per-environment.
    """
    # Class-level defaults; subclasses flip `closed` in close().
    closed = False
    viewer = None
    metadata = {
        'render.modes': ['human', 'rgb_array']
    }

    def __init__(self, num_envs, observation_space, action_space):
        self.num_envs = num_envs
        self.observation_space = observation_space
        self.action_space = action_space

    @abstractmethod
    def reset(self):
        """
        Reset all the environments and return an array of
        observations, or a dict of observation arrays.
        If step_async is still doing work, that work will
        be cancelled and step_wait() should not be called
        until step_async() is invoked again.
        """
        pass

    @abstractmethod
    def step_async(self, actions):
        """
        Tell all the environments to start taking a step
        with the given actions.
        Call step_wait() to get the results of the step.
        You should not call this if a step_async run is
        already pending.
        """
        pass

    @abstractmethod
    def step_wait(self):
        """
        Wait for the step taken with step_async().
        Returns (obs, rews, dones, infos):
         - obs: an array of observations, or a dict of
                arrays of observations.
         - rews: an array of rewards
         - dones: an array of "episode done" booleans
         - infos: a sequence of info objects
        """
        pass

    def close_extras(self):
        """
        Clean up the extra resources, beyond what's in this base class.
        Only runs when not self.closed.
        """
        pass

    def close(self):
        # Idempotent: subsequent calls are no-ops once closed.
        if self.closed:
            return
        if self.viewer is not None:
            self.viewer.close()
        self.close_extras()
        self.closed = True

    def step(self, actions):
        """
        Step the environments synchronously.
        This is available for backwards compatibility.
        """
        self.step_async(actions)
        return self.step_wait()

    def render(self, mode='human'):
        # Tile all per-env frames into one mosaic before displaying/returning.
        imgs = self.get_images()
        bigimg = tile_images(imgs)
        if mode == 'human':
            self.get_viewer().imshow(bigimg)
            return self.get_viewer().isopen
        elif mode == 'rgb_array':
            return bigimg
        else:
            raise NotImplementedError

    def get_images(self):
        """
        Return RGB images from each environment
        """
        raise NotImplementedError

    @property
    def unwrapped(self):
        # Peel off any VecEnvWrapper layers down to the base vec-env.
        if isinstance(self, VecEnvWrapper):
            return self.venv.unwrapped
        else:
            return self

    def get_viewer(self):
        # Lazily created so headless runs never import the gym renderer.
        if self.viewer is None:
            from gym.envs.classic_control import rendering
            self.viewer = rendering.SimpleImageViewer()
        return self.viewer
class VecEnvWrapper(VecEnv):
    """
    An environment wrapper that applies to an entire batch
    of environments at once.
    """

    def __init__(self, venv, observation_space=None, action_space=None):
        self.venv = venv
        # Spaces default to the wrapped env's unless explicitly overridden.
        super().__init__(num_envs=venv.num_envs,
                         observation_space=observation_space or venv.observation_space,
                         action_space=action_space or venv.action_space)

    def step_async(self, actions):
        self.venv.step_async(actions)

    @abstractmethod
    def reset(self):
        pass

    @abstractmethod
    def step_wait(self):
        pass

    def close(self):
        return self.venv.close()

    def render(self, mode='human'):
        return self.venv.render(mode=mode)

    def get_images(self):
        return self.venv.get_images()

    def __getattr__(self, name):
        # Delegate unknown public attributes to the wrapped env; private
        # names are blocked so internal state is never silently forwarded.
        if name.startswith('_'):
            raise AttributeError("attempted to get missing private attribute '{}'".format(name))
        return getattr(self.venv, name)
class VecEnvObservationWrapper(VecEnvWrapper):
    """VecEnvWrapper that transforms every batched observation via `process`."""

    @abstractmethod
    def process(self, obs):
        """Map a raw batched observation to the transformed one."""
        pass

    def reset(self):
        return self.process(self.venv.reset())

    def step_wait(self):
        obs, rews, dones, infos = self.venv.step_wait()
        return self.process(obs), rews, dones, infos
class CloudpickleWrapper(object):
    """
    Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
    """

    def __init__(self, x):
        self.x = x

    def __getstate__(self):
        import cloudpickle
        return cloudpickle.dumps(self.x)

    def __setstate__(self, ob):
        import pickle
        self.x = pickle.loads(ob)
@contextlib.contextmanager
def clear_mpi_env_vars():
    """
    from mpi4py import MPI will call MPI_Init by default. If the child process has MPI environment variables, MPI will think that the child process is an MPI process just like the parent and do bad things such as hang.
    This context manager is a hacky way to clear those environment variables temporarily such as when we are starting multiprocessing
    Processes.
    """
    stashed = {}
    for key in list(os.environ.keys()):
        if key.startswith(('OMPI_', 'PMI_')):
            stashed[key] = os.environ.pop(key)
    try:
        yield
    finally:
        # Restore the stripped variables even if the body raised.
        os.environ.update(stashed)
================================================
FILE: examples/Social_Cognition/FOToM/evaluate.py
================================================
import argparse
import torch
import time
import imageio
import numpy as np
from pathlib import Path
from torch.autograd import Variable
from utils.make_env import make_env
from algorithms.tom11 import ToM_decision11
from algorithms.maddpg import MADDPG
import matplotlib.pyplot as plt
from tqdm import tqdm
from utils.env_wrappers import SubprocVecEnv, DummyVecEnv
def display_frames_as_gif(frames):
    """Render the second frame of `frames` and save it as a static image."""
    plt.imshow(frames[1])
    plt.axis('off')
    plt.savefig('./images/comm2', bbox_inches='tight')
def make_parallel_env(env_id, n_rollout_threads, discrete_action, num_good_agents, num_adversaries):
    """Build a vectorized environment with `n_rollout_threads` copies."""
    def get_env_fn(rank):
        def init_env():
            return make_env(env_id, num_good_agents=num_good_agents,
                            num_adversaries=num_adversaries,
                            discrete_action=discrete_action)
        return init_env
    # One thread runs in-process; several fan out to worker subprocesses.
    if n_rollout_threads == 1:
        return DummyVecEnv([get_env_fn(0)])
    return SubprocVecEnv([get_env_fn(rank) for rank in range(n_rollout_threads)])
def run(config):
    """Evaluate saved policies for each run in config.run_num, repeated
    config.num times, and print per-agent / per-team reward statistics."""
    rew_ep = []
    for i in range(config.num):
        for run_num in config.run_num:
            pbar = tqdm(config.n_episodes)
            model_path = (Path('./models') / config.env_id / config.model_name /
                          ('run%i' % (run_num)))
            if config.incremental is not None:
                model_path = model_path / 'incremental' / ('model_ep%i.pt' %
                                                           config.incremental)
            else:
                model_path = model_path / 'model.pt'
            # if config.save_gifs:
            #     gif_path = model_path.parent / 'gifs'
            #     gif_path.mkdir(exist_ok=True)
            # NOTE(review): ToM_decision01 / ToM_decisionN1 are not imported in
            # this module (only tom11 and maddpg are); those branches raise
            # NameError if selected -- confirm intended imports.
            if config.alg == 'ToM1':
                maddpg = ToM_decision11.init_from_save(model_path)
            elif config.alg == 'ToM_SB01' or config.alg== 'ToM_SA01':
                maddpg = ToM_decision01.init_from_save(model_path)#.eval()
            elif config.alg == 'ToM_SBN1' or config.alg== 'ToM_SAN1':
                maddpg = ToM_decisionN1.init_from_save(model_path)#.eval()
            elif config.alg == 'MADDPG':
                maddpg = MADDPG.init_from_save(model_path)
            # env = make_env(config.env_id, num_good_agents=config.num_good_agents,
            #                num_adversaries=config.num_adversaries, discrete_action=maddpg.discrete_action)
            env = make_parallel_env(config.env_id, config.n_rollout_threads,
                                    config.discrete_action, config.num_good_agents, config.num_adversaries)
            maddpg.prep_rollouts(device='cpu')
            ifi = 1 / config.fps  # inter-frame interval
            for ep_i in range(0, config.n_episodes, config.n_rollout_threads):
                # Accumulates rewards per thread and per agent over one episode.
                rew = np.zeros((config.n_rollout_threads, config.num_good_agents + config.num_adversaries))
                torch_agent_actions = [torch.zeros((config.n_rollout_threads, 5))
                                       for i in range(maddpg.nagents)]
                # print("Episode %i of %i" % (ep_i + 1, config.n_episodes))
                obs = env.reset()
                for t_i in range(config.episode_length):
                    # calc_start = time.time()
                    # rearrange observations to be per agent, and convert to torch Variable
                    torch_obs = [Variable(torch.Tensor(np.vstack(obs[:, i])),
                                          requires_grad=False)
                                 for i in range(maddpg.nagents)]
                    # get actions as torch Variables
                    # t1 = time.time()
                    if config.alg == 'MADDPG':
                        torch_actions = maddpg.step(torch_obs, explore=False)
                    else:
                        torch_actions = maddpg.step(torch_obs, torch_agent_actions, explore=False)
                    actions = [ac.data.cpu().numpy() for ac in torch_actions]
                    # Re-group actions per rollout thread for the vec-env step.
                    actions = [[ac[i] for ac in actions] for i in range(config.n_rollout_threads)]
                    obs, rewards, dones, infos = env.step(actions)
                    rew += rewards
                rew_ep.append(rew)
                pbar.update(config.n_rollout_threads)
            pbar.close()
    # Aggregate statistics over all repetitions, runs, and episodes.
    rew_ep = np.concatenate(rew_ep, 0)
    rew_ep_agent = rew_ep.mean(0)
    std_ep_agent = rew_ep.std(0)
    print('mean:', rew_ep_agent, 'std:', std_ep_agent)
    # Good agents occupy the trailing columns, adversaries the leading ones.
    rew_ep_good = rew_ep[:, -config.num_good_agents:].sum(1).mean()
    rew_ep_adv = rew_ep[:, :config.num_adversaries].sum(1).mean()
    std_ep_good = rew_ep[:, -config.num_good_agents:].sum(1).std()
    std_ep_adv = rew_ep[:, :config.num_adversaries].sum(1).std()
    print('good:', rew_ep_good, 'std:', std_ep_good)
    print('adv:', rew_ep_adv, 'std:', std_ep_adv)
    rew_ep_all = rew_ep.sum(1).mean()
    std_ep_all = rew_ep.sum(1).std()
    print('all:', rew_ep_all, 'std:', std_ep_all)
    env.close()
if __name__ == '__main__':
    # Command-line entry point for policy evaluation.
    parser = argparse.ArgumentParser()
    parser.add_argument("--alg",
                        default="ToM_SAN1", type=str,
                        choices=['MADDPG', 'ToM_SB01', 'ToM_SA01', 'ToM_SBN1', 'ToM_SAN1',
                                 'ToM1'])
    parser.add_argument("--env_id", default='simple_adversary', type=str, help="Name of environment",
                        choices=['simple_tag', 'simple_world_comm', 'hetero_spread',
                                 'simple_adversary', 'simple_spread'
                                 ])
    parser.add_argument("--num_good_agents", default=None, type=int,
                        help="Num of Agent")
    parser.add_argument("--num_adversaries", default=None, type=int,
                        help="Num of Adversary")
    parser.add_argument("--model_name", default='4VS2_tomaAN1', type=str,
                        help="Name of model")  # ma2c, maddpg, maddpg_rnn
    # Bug fix: the default must be a list -- run() iterates config.run_num,
    # and nargs='+' only wraps values supplied on the command line, not the
    # default (a bare int is not iterable).
    parser.add_argument("--run_num", default=[2], type=int, nargs='+')
    parser.add_argument("--save_gifs", default=True, action="store_true",
                        help="Saves gif of each episode into model directory")
    parser.add_argument("--incremental", default=None, type=int,
                        help="Load incremental policy from given episode " +
                             "rather than final policy")
    parser.add_argument("--n_episodes", default=1000, type=int)
    parser.add_argument("--episode_length", default=25, type=int)
    parser.add_argument("--fps", default=30, type=int)
    parser.add_argument("--eval",
                        default=True, type=bool,
                        )
    parser.add_argument("--num", default=3, type=int)
    parser.add_argument("--n_rollout_threads", default=20, type=int)
    parser.add_argument("--discrete_action",
                        # default=False, type=bool,
                        action='store_true')
    config = parser.parse_args()
    # Derive which side (agents / adversaries) uses the ToM-augmented policy.
    if 'ToM_SB' in config.alg:
        config.agent_alg = 'without_tom'
        config.adversary_alg = 'with_tom'
    elif 'ToM_SA' in config.alg:
        config.agent_alg = 'with_tom'
        config.adversary_alg = 'without_tom'
    elif config.alg == 'ToM1':
        config.agent_alg = 'with_tom'
        config.adversary_alg = 'with_tom'
    else:
        config.agent_alg = config.alg
        config.adversary_alg = config.alg
    # Fill in scenario-specific agent counts when not given explicitly
    # (identity comparison `is None` instead of `== None`).
    if config.num_good_agents is None and config.num_adversaries is None:
        if config.env_id == 'simple_adversary':
            config.num_good_agents = 2
            config.num_adversaries = 1
        elif config.env_id == 'simple_tag':
            config.num_good_agents = 2
            config.num_adversaries = 2
        elif config.env_id == 'simple_world_comm':
            config.num_good_agents = 2
            config.num_adversaries = 4
        elif config.env_id == 'simple_spread':
            config.num_good_agents = 3
            config.num_adversaries = 0
    run(config)
================================================
FILE: examples/Social_Cognition/FOToM/main.py
================================================
import argparse
import torch
import time
import os
import numpy as np
from gym.spaces import Box, Discrete, MultiDiscrete
from pathlib import Path
from torch.autograd import Variable
from tensorboardX import SummaryWriter
from utils.make_env import make_env
from utils.buffer import ReplayBuffer, ReplayBuffer_pre
from utils.env_wrappers import SubprocVecEnv, DummyVecEnv
from algorithms.tomN1 import ToM_decisionN1
from algorithms.tom01 import ToM_decision01
from algorithms.tom11 import ToM_decision11
from algorithms.maddpg import MADDPG
from tqdm import tqdm
from thop import profile
from thop import clever_format
def get_common_args():
    """Parse the command-line options shared by the training entry points.

    ``--device`` defaults to ``cuda:<cuda_num>``, so parsing is two-phase:
    a first pass reads ``cuda_num``, then ``--device`` is registered with
    that default and the final namespace is parsed and returned.

    Returns:
        argparse.Namespace with all training options, including `device`.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--env_id", default='simple_tag', type=str,
                        choices=['simple_tag', 'simple_adversary', 'hetero_spread', 'simple_world_comm',
                                 'simple_spread'],
                        help="Name of environment")
    parser.add_argument("--num_good_agents", default=None, type=int,
                        help="Num of Agent")
    parser.add_argument("--num_adversaries", default=None, type=int,
                        help="Num of Adversary")
    parser.add_argument("--model_name", default='ann', type=str,
                        help="Name of directory to store " +
                             "model/training contents")  # ToM_SA
    parser.add_argument("--seed",
                        default=1, type=int,
                        help="Random seed")
    parser.add_argument("--cuda_num",
                        default=5, type=int,
                        help="device")
    parser.add_argument("--output_style",
                        default='sum', type=str,
                        choices=['sum', 'voltage'])
    parser.add_argument("--n_rollout_threads", default=20, type=int)
    parser.add_argument("--n_training_threads", default=6, type=int)
    parser.add_argument("--buffer_length", default=int(1e6), type=int)  # 1e6
    parser.add_argument("--n_episodes", default=25000, type=int)  # 25000
    parser.add_argument("--episode_length", default=25, type=int)
    parser.add_argument("--steps_per_update", default=100, type=int)
    parser.add_argument("--load_para", type=bool, default=False)
    parser.add_argument("--batch_size",
                        default=1024, type=int,  # 4
                        help="Batch size for model training")
    parser.add_argument("--n_exploration_eps", default=25000, type=int)
    parser.add_argument("--init_noise_scale", default=0.3, type=float)
    parser.add_argument("--final_noise_scale", default=0.0, type=float)
    parser.add_argument("--save_interval", default=1000, type=int)  # 1000
    parser.add_argument("--hidden_dim", default=64, type=int)
    parser.add_argument("--lr", default=1e-3, type=float)  # 0.01 #1e-3
    parser.add_argument("--tau", default=0.01, type=float)
    parser.add_argument("--alg",
                        default="ToM1", type=str,
                        choices=['MADDPG',
                                 'ToM1', 'ToM_SB01', 'ToM_SA01', 'ToM_SBN1', 'ToM_SAN1'])
    parser.add_argument("--discrete_action",
                        # default=False, type=bool,
                        action='store_true')
    # First pass: everything except --device (its default depends on cuda_num).
    first_pass = parser.parse_args()
    parser.add_argument('--device', type=str,
                        default='cuda:{}'.format(first_pass.cuda_num),
                        help='whether to use the GPU')  # 'cuda:1'
    # Second, final pass now that --device exists. (Fixes the original, which
    # reassigned the name `parser` to the parsed namespace, shadowing the
    # ArgumentParser object.)
    return parser.parse_args()
# True when PyTorch reports at least one visible CUDA device.
USE_CUDA = torch.cuda.is_available()
def make_parallel_env(env_id, n_rollout_threads, seed, discrete_action, num_good_agents, num_adversaries):
    """Build a seeded, vectorized environment with `n_rollout_threads` copies."""
    def get_env_fn(rank):
        def init_env():
            env = make_env(env_id, num_good_agents=num_good_agents,
                           num_adversaries=num_adversaries,
                           discrete_action=discrete_action)
            # Offset seeds per worker so parallel rollouts differ.
            env.seed(seed + rank * 1000)
            np.random.seed(seed + rank * 1000)
            return env
        return init_env
    # One thread runs in-process; several fan out to worker subprocesses.
    if n_rollout_threads == 1:
        return DummyVecEnv([get_env_fn(0)])
    return SubprocVecEnv([get_env_fn(rank) for rank in range(n_rollout_threads)])
def run(config):
    """Train the algorithm selected by ``config.alg`` on a particle env.

    Creates a fresh ``models/<env_id>/<model_name>/runN`` directory, dumps
    the hyper-parameters for reproducibility, builds the vectorized
    environment and the learner (MADDPG or a ToM variant), then runs the
    collect/update loop, checkpointing every ``config.save_interval``
    episodes.

    :param config: argparse namespace produced by ``get_common_args``.
    """
    # Fix: tqdm's first positional parameter is an iterable; the episode
    # count must be passed as ``total`` for the bar to show progress.
    pbar = tqdm(total=config.n_episodes)
    # ---- choose a fresh run directory models/<env>/<model>/runN ----
    model_dir = Path('./models') / config.env_id / config.model_name
    if not model_dir.exists():
        curr_run = 'run1'
    else:
        exst_run_nums = [int(str(folder.name).split('run')[1]) for folder in
                         model_dir.iterdir() if
                         str(folder.name).startswith('run')]
        if len(exst_run_nums) == 0:
            curr_run = 'run1'
        else:
            curr_run = 'run%i' % (max(exst_run_nums) + 1)
    run_dir = model_dir / curr_run
    log_dir = run_dir / 'logs'
    os.makedirs(log_dir)
    logger = SummaryWriter(str(log_dir))
    save_path = log_dir
    print(os.path.exists(save_path))
    # ---- dump all hyper-parameters next to the logs ----
    # Fix: the original used 'args_{}'.format(max(exst_run_nums) + 1) here,
    # which raises NameError on a brand-new model dir (exst_run_nums never
    # assigned) and ValueError when no 'runN' folder exists yet.  Deriving
    # the number from curr_run yields the same file name and never crashes.
    argsDict = config.__dict__
    with open(save_path / 'args_{}'.format(curr_run[3:]), 'w') as f:
        f.writelines('------------------ start ------------------' + '\n')
        for eachArg, value in argsDict.items():
            f.writelines(eachArg + ' : ' + str(value) + '\n')
        f.writelines('------------------- end -------------------')
    torch.manual_seed(config.seed)
    np.random.seed(config.seed)
    if not USE_CUDA:
        torch.set_num_threads(config.n_training_threads)
    env = make_parallel_env(config.env_id, config.n_rollout_threads, config.seed,
                            config.discrete_action, config.num_good_agents, config.num_adversaries)
    # ---- build (or reload) the learner selected by config.alg ----
    if config.alg == 'ToM1':
        print('_______Alg: ' + config.alg + '_______')
        if not config.load_para:
            maddpg = ToM_decision11.init_from_env(env, config, agent_alg=config.agent_alg,
                                                  adversary_alg=config.adversary_alg,
                                                  tau=config.tau,
                                                  lr=config.lr,
                                                  hidden_dim=config.hidden_dim,
                                                  output_style=config.output_style,
                                                  device=config.device)
        else:
            model_path = (Path('./models') / config.env_id / 'rs1' /  # 'self1' tag, maddpg_self_11 adv
                          ('run%i' % 1)) / 'model.pt'
            maddpg = ToM_decision11.init_from_save(model_path)
    elif config.alg == 'ToM_SB01' or config.alg == 'ToM_SA01':
        print('_______Alg: ' + config.alg + '_______')
        if not config.load_para:
            maddpg = ToM_decision01.init_from_env(env, config, agent_alg=config.agent_alg,
                                                  adversary_alg=config.adversary_alg,
                                                  tau=config.tau,
                                                  lr=config.lr,
                                                  hidden_dim=config.hidden_dim,
                                                  output_style=config.output_style,
                                                  device=config.device)
        else:
            model_path = (Path('./models') / config.env_id / 'am' /
                          ('run%i' % 3)) / 'model.pt'
            maddpg = ToM_decision01.init_from_save(model_path)
    elif config.alg == 'ToM_SBN1' or config.alg == 'ToM_SAN1':
        print('_______Alg: ' + config.alg + '_______')
        if not config.load_para:
            maddpg = ToM_decisionN1.init_from_env(env, config, agent_alg=config.agent_alg,
                                                  adversary_alg=config.adversary_alg,
                                                  tau=config.tau,
                                                  lr=config.lr,
                                                  hidden_dim=config.hidden_dim,
                                                  output_style=config.output_style,
                                                  device=config.device)
        else:
            model_path = (Path('./models') / config.env_id / 'am' /
                          ('run%i' % 3)) / 'model.pt'
            maddpg = ToM_decisionN1.init_from_save(model_path)
    elif config.alg == 'MADDPG':
        print('_______Alg: ' + config.alg + '_______')
        # NOTE(review): with --alg MADDPG and --load_para True no learner is
        # created and the code below fails with NameError — confirm whether a
        # MADDPG.init_from_save branch was intended here.
        if not config.load_para:
            maddpg = MADDPG.init_from_env(env, agent_alg=config.agent_alg,
                                          adversary_alg=config.adversary_alg,
                                          tau=config.tau,
                                          lr=config.lr,
                                          hidden_dim=config.hidden_dim,
                                          device=config.device)
    # ToM variants store the previous step's actions in the buffer as well.
    tom_algs = ('ToM1', 'ToM_SB01', 'ToM_SA01', 'ToM_SBN1', 'ToM_SAN1')
    is_tom = config.alg in tom_algs
    if is_tom:
        replay_buffer = ReplayBuffer_pre(config.buffer_length, maddpg.nagents,
                                         [obsp.shape[0] for obsp in env.observation_space],
                                         [acsp.n if isinstance(acsp, Discrete) else sum(acsp.high - acsp.low + 1)
                                          for acsp in env.action_space],
                                         device=config.device)
    else:
        replay_buffer = ReplayBuffer(config.buffer_length, maddpg.nagents,
                                     [obsp.shape[0] for obsp in env.observation_space],
                                     [acsp.n if isinstance(acsp, Discrete) else sum(acsp.high - acsp.low + 1)
                                      for acsp in env.action_space],
                                     device=config.device)
    t = 0
    # NOTE(review): total_reward is never appended to below, so the arrays
    # saved at the end of this function are always empty — confirm whether
    # per-episode rewards were meant to be collected here.
    total_reward = [[] for _ in range(maddpg.nagents)]
    for ep_i in range(0, config.n_episodes, config.n_rollout_threads):
        obs = env.reset()
        # obs.shape = (n_rollout_threads, nagent)(nobs), nobs differs per agent so not tensor
        maddpg.prep_rollouts(device='cpu')
        # previous-step actions start as zeros (one 5-d action per agent)
        torch_agent_actions = [torch.zeros((config.n_rollout_threads, 5)) for i in range(maddpg.nagents)]
        # linearly anneal exploration noise over n_exploration_eps episodes
        explr_pct_remaining = max(0, config.n_exploration_eps - ep_i) / config.n_exploration_eps
        maddpg.scale_noise(config.final_noise_scale + (config.init_noise_scale - config.final_noise_scale) * explr_pct_remaining)
        if config.alg == 'rnnD' or config.alg == 'rnn':
            maddpg._init_agent(config.n_rollout_threads)
        maddpg.reset_noise()
        obs_ep = []
        agent_actions_ep = []
        rewards_ep = []
        next_obs_ep = []
        dones_ep = []
        for et_i in range(config.episode_length):
            # keep the previous step's actions (ToM variants condition on them)
            torch_agent_actions_pre = torch_agent_actions
            torch_agent_actions_pre = [ac.data.numpy() for ac in torch_agent_actions_pre]
            # rearrange observations to be per agent, and convert to torch Variable
            torch_obs = [Variable(torch.Tensor(np.vstack(obs[:, i])),
                                  requires_grad=False)
                         for i in range(maddpg.nagents)]
            # get actions as torch Variables
            if is_tom:
                torch_agent_actions = maddpg.step(torch_obs, torch_agent_actions, explore=True)
            else:
                torch_agent_actions = maddpg.step(torch_obs, explore=True)
            agent_actions = [ac.data.numpy() for ac in torch_agent_actions]
            # rearrange actions to be per environment
            actions = [[ac[i] for ac in agent_actions] for i in range(config.n_rollout_threads)]
            next_obs, rewards, dones, infos = env.step(actions)
            obs_ep.append(obs)                # episode_id, process, n_agents, dim
            agent_actions_ep.append(actions)  # episode_id, n_agents, process, dim
            rewards_ep.append(rewards)        # episode_id, process, n_agents
            next_obs_ep.append(next_obs)      # episode_id, process, n_agents, dim
            dones_ep.append(dones)            # episode_id, process, n_agents
            if is_tom:
                replay_buffer.push(torch_agent_actions_pre, obs, agent_actions, rewards, next_obs, dones)
            else:
                replay_buffer.push(obs, agent_actions, rewards, next_obs, dones)
            obs = next_obs
            t += config.n_rollout_threads
            if (len(replay_buffer) >= config.batch_size and
                    (t % config.steps_per_update) < config.n_rollout_threads):
                maddpg.prep_training(device='gpu' if USE_CUDA else 'cpu')
                # cap the number of gradient rounds for long trainings
                if config.n_episodes > 300:
                    rollout = 2
                else:
                    rollout = config.n_rollout_threads
                for u_i in range(rollout):
                    sample = replay_buffer.sample(config.batch_size,
                                                  to_gpu=USE_CUDA)
                    # every ToM variant name contains '1': refresh the
                    # inferred models of the other agents before updating
                    if '1' in config.alg:
                        maddpg.tom0_output(sample)
                        maddpg.tom1_infer_other(sample)
                    for a_i in range(maddpg.nagents):
                        maddpg.update(sample, a_i, logger=logger)
                    maddpg.update_all_targets()
                maddpg.prep_rollouts(device='cpu')
                if config.alg == 'rnnD' or config.alg == 'rnn':
                    maddpg._init_agent(config.n_rollout_threads)
        # ---- per-episode logging and checkpointing ----
        ep_rews = replay_buffer.get_average_rewards(
            config.episode_length * config.n_rollout_threads)
        for a_i, a_ep_rew in enumerate(ep_rews):
            logger.add_scalar('agent%i/mean_episode_rewards' % a_i,
                              a_ep_rew,
                              ep_i)
        logger.add_scalar('agent_mean/mean_episode_rewards',
                          np.mean(ep_rews),
                          ep_i)
        if ep_i % config.save_interval < config.n_rollout_threads:
            os.makedirs(run_dir / 'incremental', exist_ok=True)
            maddpg.save(run_dir / 'incremental' / ('model_ep%i.pt' % (ep_i + 1)))
            maddpg.save(run_dir / 'model.pt')
        pbar.update(config.n_rollout_threads)
    pbar.close()
    maddpg.save(run_dir / 'model.pt')
    env.close()
    logger.export_scalars_to_json(str(log_dir / 'summary.json'))
    logger.close()
    for a_i, reward in enumerate(total_reward):
        reward_dir = str(log_dir) + '/agent{}/mean_episode_rewards'.format(a_i) + '/episode_rewards_{}'.format(config.cuda_num)
        # Fix: np.save writes ``reward_dir + '.npy'``; create only the parent
        # directory instead of an empty directory named after the file itself.
        os.makedirs(os.path.dirname(reward_dir), exist_ok=True)
        np.save(reward_dir, reward)
if __name__ == '__main__':
    config = get_common_args()
    # Map the algorithm name onto which side (good agents vs adversaries)
    # is equipped with the Theory-of-Mind module.
    if 'ToM_SB' in config.alg:
        config.agent_alg = 'without_tom'
        config.adversary_alg = 'with_tom'
    elif 'ToM_SA' in config.alg:
        config.agent_alg = 'with_tom'
        config.adversary_alg = 'without_tom'
    elif config.alg == 'ToM1':
        config.agent_alg = 'with_tom'
        config.adversary_alg = 'with_tom'
    else:
        config.agent_alg = config.alg
        config.adversary_alg = config.alg
    # Default team sizes per scenario when not supplied on the command line.
    # Fix: compare against None with ``is`` rather than ``==`` (PEP 8).
    if config.num_good_agents is None and config.num_adversaries is None:
        if config.env_id == 'simple_adversary':
            config.num_good_agents = 2
            config.num_adversaries = 1
        elif config.env_id == 'simple_tag':
            config.num_good_agents = 1
            config.num_adversaries = 3
        elif config.env_id == 'simple_world_comm':
            config.num_good_agents = 2
            config.num_adversaries = 4
        elif config.env_id == 'hetero_spread':
            config.num_good_agents = 4
            config.num_adversaries = 0
        elif config.env_id == 'simple_spread':  # cooperative scenario
            config.num_good_agents = 3
            config.num_adversaries = 0
    run(config)
================================================
FILE: examples/Social_Cognition/FOToM/multiagent/__init__.py
================================================
from gym.envs.registration import register
# Multiagent envs
# ----------------------------------------
# Register the two demo multi-agent environments with Gym so they can be
# created via gym.make(); both entry points live in multiagent.envs.
register(
    id='MultiagentSimple-v0',
    entry_point='multiagent.envs:SimpleEnv',
    # FIXME(cathywu) currently has to be exactly max_path_length parameters in
    # rllab run script
    max_episode_steps=100,
)
register(
    id='MultiagentSimpleSpeakerListener-v0',
    entry_point='multiagent.envs:SimpleSpeakerListenerEnv',
    max_episode_steps=100,
)
================================================
FILE: examples/Social_Cognition/FOToM/multiagent/core.py
================================================
import numpy as np
# physical/external base state of all entities
class EntityState(object):
    """Physical state every entity carries: position and velocity."""

    def __init__(self):
        self.p_pos = None  # physical position
        self.p_vel = None  # physical velocity
# state of agents (including communication and internal/mental state)
class AgentState(EntityState):
    """Entity state extended with the agent's communication utterance."""

    def __init__(self):
        super(AgentState, self).__init__()
        self.c = None  # communication utterance
# action of the agent
class Action(object):
    """Container for one agent action: physical control plus communication."""

    def __init__(self):
        self.u = None  # physical action (used as a control force)
        self.c = None  # communication action
# properties and state of physical world entity
class Entity(object):
    """Base physical entity: geometry, dynamics flags and physical state."""

    def __init__(self):
        self.name = ''      # identifier, filled in by the scenario
        self.size = 0.050   # radius used for collisions
        self.movable = False  # entity can move / be pushed
        self.collide = True   # entity collides with others
        self.density = 25.0   # material density (affects mass)
        self.color = None     # render color, set by the scenario
        # optional speed clamp and acceleration sensitivity
        self.max_speed = None
        self.accel = None
        self.state = EntityState()  # physical state (position / velocity)
        self.initial_mass = 1.0     # mass

    @property
    def mass(self):
        """Entity mass; constant, taken from ``initial_mass``."""
        return self.initial_mass
# properties of landmark entities
class Landmark(Entity):
    """A plain entity with no behaviour of its own (non-agent scenery)."""

    def __init__(self):
        super(Landmark, self).__init__()
# properties of agent entities
class Agent(Entity):
    """A controllable (or scripted) entity with actions and communication."""

    def __init__(self):
        super(Agent, self).__init__()
        self.movable = True   # agents are movable by default
        self.silent = False   # False => may emit communication signals
        self.blind = False    # False => may observe the world
        self.u_noise = None   # physical motor noise amount
        self.c_noise = None   # communication noise amount
        self.u_range = 1.0    # control range
        self.state = AgentState()  # comm-aware state replaces EntityState
        self.action = Action()     # current (physical, communication) action
        # scripted behaviour callback; None => driven by an external policy
        self.action_callback = None
# multi-agent world
class World(object):
    """Physical world holding agents and landmarks; advances their state."""

    def __init__(self):
        # list of agents and entities (can change at execution-time!)
        self.agents = []
        self.landmarks = []
        # communication channel dimensionality
        self.dim_c = 0
        # position dimensionality
        self.dim_p = 2
        # color dimensionality
        self.dim_color = 3
        # simulation timestep
        self.dt = 0.1
        # physical damping
        self.damping = 0.25
        # contact response parameters
        self.contact_force = 1e+2
        self.contact_margin = 1e-3

    @property
    def entities(self):
        """All entities: agents first, then landmarks."""
        return self.agents + self.landmarks

    @property
    def policy_agents(self):
        """Agents controllable by external policies (no scripted callback)."""
        return [a for a in self.agents if a.action_callback is None]

    @property
    def scripted_agents(self):
        """Agents driven by a world-script callback."""
        return [a for a in self.agents if a.action_callback is not None]

    def step(self):
        """Advance the world by one timestep of ``self.dt`` seconds."""
        # let scripted agents choose their actions first
        for scripted in self.scripted_agents:
            scripted.action = scripted.action_callback(scripted, self)
        # accumulate forces: agent controls, then environment collisions
        p_force = [None] * len(self.entities)
        p_force = self.apply_action_force(p_force)
        p_force = self.apply_environment_force(p_force)
        # integrate physics, then refresh agent communication state
        self.integrate_state(p_force)
        for ag in self.agents:
            self.update_agent_state(ag)

    def apply_action_force(self, p_force):
        """Write each movable agent's control force (plus optional noise)."""
        for idx, ag in enumerate(self.agents):
            if not ag.movable:
                continue
            noise = np.random.randn(*ag.action.u.shape) * ag.u_noise if ag.u_noise else 0.0
            p_force[idx] = ag.action.u + noise
        return p_force

    def apply_environment_force(self, p_force):
        """Add pairwise collision forces (simple but O(n^2))."""
        n_entities = len(self.entities)
        for a in range(n_entities):
            for b in range(a + 1, n_entities):
                f_a, f_b = self.get_collision_force(self.entities[a], self.entities[b])
                if f_a is not None:
                    p_force[a] = f_a + (p_force[a] if p_force[a] is not None else 0.0)
                if f_b is not None:
                    p_force[b] = f_b + (p_force[b] if p_force[b] is not None else 0.0)
        return p_force

    def integrate_state(self, p_force):
        """Euler step: damp velocity, apply force, clamp speed, move."""
        for idx, entity in enumerate(self.entities):
            if not entity.movable:
                continue
            entity.state.p_vel = entity.state.p_vel * (1 - self.damping)
            if p_force[idx] is not None:
                entity.state.p_vel += (p_force[idx] / entity.mass) * self.dt
            if entity.max_speed is not None:
                speed = np.sqrt(np.square(entity.state.p_vel[0]) + np.square(entity.state.p_vel[1]))
                if speed > entity.max_speed:
                    # rescale the velocity vector to the maximum speed
                    entity.state.p_vel = entity.state.p_vel / speed * entity.max_speed
            entity.state.p_pos += entity.state.p_vel * self.dt

    def update_agent_state(self, agent):
        """Set the agent's communication state from its action (plus noise)."""
        if agent.silent:
            agent.state.c = np.zeros(self.dim_c)
        else:
            noise = np.random.randn(*agent.action.c.shape) * agent.c_noise if agent.c_noise else 0.0
            agent.state.c = agent.action.c + noise

    def get_collision_force(self, entity_a, entity_b):
        """Return ``[force_on_a, force_on_b]`` for a soft contact, or Nones."""
        if (not entity_a.collide) or (not entity_b.collide):
            return [None, None]  # at least one of them is not a collider
        if entity_a is entity_b:
            return [None, None]  # an entity never collides with itself
        # actual and minimum-allowable separation
        delta_pos = entity_a.state.p_pos - entity_b.state.p_pos
        dist = np.sqrt(np.sum(np.square(delta_pos)))
        dist_min = entity_a.size + entity_b.size
        # softmax-style penetration depth gives a smooth contact response
        k = self.contact_margin
        penetration = np.logaddexp(0, -(dist - dist_min) / k) * k
        force = self.contact_force * delta_pos / dist * penetration
        force_a = +force if entity_a.movable else None
        force_b = -force if entity_b.movable else None
        return [force_a, force_b]
================================================
FILE: examples/Social_Cognition/FOToM/multiagent/environment.py
================================================
import gym
from gym import spaces
from gym.envs.registration import EnvSpec
import numpy as np
from multiagent.multi_discrete import MultiDiscrete
# environment for all agents in the multiagent world
# currently code assumes that no agents will be created/destroyed at runtime!
class MultiAgentEnv(gym.Env):
    """Gym environment wrapping a multi-agent ``World``.

    Scenario behaviour is injected via callbacks, each a function of
    (agent, world): reset, reward, observation, info and done.  One action
    space and one observation space are built per policy agent.
    """
    metadata = {
        'render.modes' : ['human', 'rgb_array']
    }
    def __init__(self, world, reset_callback=None, reward_callback=None,
                 observation_callback=None, info_callback=None,
                 done_callback=None, shared_viewer=True):
        """Store callbacks and build per-agent action/observation spaces."""
        self.world = world
        self.agents = self.world.policy_agents
        # set required vectorized gym env property
        self.n = len(world.policy_agents)
        # scenario callbacks
        self.reset_callback = reset_callback
        self.reward_callback = reward_callback
        self.observation_callback = observation_callback
        self.info_callback = info_callback
        self.done_callback = done_callback
        # environment parameters (discrete one-hot actions are the default)
        self.discrete_action_space = True
        # if true, action is a number 0...N, otherwise action is a one-hot N-dimensional vector
        self.discrete_action_input = False
        # if true, even the action is continuous, action will be performed discretely
        self.force_discrete_action = world.discrete_action if hasattr(world, 'discrete_action') else False
        # if true, every agent has the same reward
        self.shared_reward = world.collaborative if hasattr(world, 'collaborative') else False
        self.time = 0
        # configure spaces
        self.action_space = []
        self.observation_space = []
        for agent in self.agents:
            total_action_space = []
            # physical action space: 2*dim_p directions + a no-op, or a Box
            if self.discrete_action_space:
                u_action_space = spaces.Discrete(world.dim_p * 2 + 1)
            else:
                u_action_space = spaces.Box(low=-agent.u_range, high=+agent.u_range, shape=(world.dim_p,), dtype=np.float32)
            if agent.movable:
                total_action_space.append(u_action_space)
            # communication action space
            if self.discrete_action_space:
                c_action_space = spaces.Discrete(world.dim_c)
            else:
                c_action_space = spaces.Box(low=0.0, high=1.0, shape=(world.dim_c,), dtype=np.float32)
            if not agent.silent:
                total_action_space.append(c_action_space)
            # total action space
            if len(total_action_space) > 1:
                # all action spaces are discrete, so simplify to MultiDiscrete action space
                if all([isinstance(act_space, spaces.Discrete) for act_space in total_action_space]):
                    act_space = MultiDiscrete([[0, act_space.n - 1] for act_space in total_action_space])
                else:
                    act_space = spaces.Tuple(total_action_space)
                self.action_space.append(act_space)
            else:
                self.action_space.append(total_action_space[0])
            # observation space: size taken from one callback invocation
            obs_dim = len(observation_callback(agent, self.world))
            self.observation_space.append(spaces.Box(low=-np.inf, high=+np.inf, shape=(obs_dim,), dtype=np.float32))
            agent.action.c = np.zeros(self.world.dim_c)
        # rendering: one shared viewer, or one viewer per agent
        self.shared_viewer = shared_viewer
        if self.shared_viewer:
            self.viewers = [None]
        else:
            self.viewers = [None] * self.n
        self._reset_render()
    def step(self, action_n):
        """Apply one action per agent, advance physics one step and return
        per-agent lists (obs_n, reward_n, done_n, info_n)."""
        obs_n = []
        reward_n = []
        done_n = []
        info_n = {'n': []}
        self.agents = self.world.policy_agents
        # set action for each agent
        for i, agent in enumerate(self.agents):
            self._set_action(action_n[i], agent, self.action_space[i])
        # advance world state
        self.world.step()
        # record observation for each agent
        for agent in self.agents:
            obs_n.append(self._get_obs(agent))
            reward_n.append(self._get_reward(agent))
            done_n.append(self._get_done(agent))
            info_n['n'].append(self._get_info(agent))
        # all agents get total reward in cooperative case
        reward = np.sum(reward_n)
        if self.shared_reward:
            reward_n = [reward] * self.n
        return obs_n, reward_n, done_n, info_n
    def reset(self):
        """Reset the world via the scenario callback; return per-agent obs."""
        # reset world
        self.reset_callback(self.world)
        # reset renderer
        self._reset_render()
        # record observations for each agent
        obs_n = []
        self.agents = self.world.policy_agents
        for agent in self.agents:
            obs_n.append(self._get_obs(agent))
        return obs_n
    # get info used for benchmarking
    def _get_info(self, agent):
        if self.info_callback is None:
            return {}
        return self.info_callback(agent, self.world)
    # get observation for a particular agent
    def _get_obs(self, agent):
        if self.observation_callback is None:
            return np.zeros(0)
        return self.observation_callback(agent, self.world)
    # get dones for a particular agent
    # unused right now -- agents are allowed to go beyond the viewing screen
    def _get_done(self, agent):
        if self.done_callback is None:
            return False
        return self.done_callback(agent, self.world)
    # get reward for a particular agent
    def _get_reward(self, agent):
        if self.reward_callback is None:
            return 0.0
        return self.reward_callback(agent, self.world)
    # set env action for a particular agent
    def _set_action(self, action, agent, action_space, time=None):
        """Decode *action* into agent.action.u (physical) and .c (comm).

        MultiDiscrete actions are first split into one slice per sub-space;
        the movement slice is consumed before the communication slice.
        """
        agent.action.u = np.zeros(self.world.dim_p)
        agent.action.c = np.zeros(self.world.dim_c)
        # process action
        if isinstance(action_space, MultiDiscrete):
            act = []
            size = action_space.high - action_space.low + 1
            index = 0
            for s in size:
                act.append(action[index:(index+s)])
                index += s
            action = act
        else:
            action = [action]
        if agent.movable:
            # physical action
            if self.discrete_action_input:
                agent.action.u = np.zeros(self.world.dim_p)
                # process discrete action: 1/2 = -/+ x, 3/4 = -/+ y, 0 = no-op
                if action[0] == 1: agent.action.u[0] = -1.0
                if action[0] == 2: agent.action.u[0] = +1.0
                if action[0] == 3: agent.action.u[1] = -1.0
                if action[0] == 4: agent.action.u[1] = +1.0
            else:
                if self.force_discrete_action:
                    # snap a continuous vector to its argmax (one-hot)
                    d = np.argmax(action[0])
                    action[0][:] = 0.0
                    action[0][d] = 1.0
                if self.discrete_action_space:
                    # one-hot pairs: (index1 - index2) -> x, (index3 - index4) -> y
                    agent.action.u[0] += action[0][1] - action[0][2]
                    agent.action.u[1] += action[0][3] - action[0][4]
                else:
                    agent.action.u = action[0]
            # scale the control by the agent's acceleration sensitivity
            sensitivity = 5.0
            if agent.accel is not None:
                sensitivity = agent.accel
            agent.action.u *= sensitivity
            action = action[1:]
        if not agent.silent:
            # communication action
            if self.discrete_action_input:
                agent.action.c = np.zeros(self.world.dim_c)
                agent.action.c[action[0]] = 1.0
            else:
                agent.action.c = action[0]
            action = action[1:]
        # make sure we used all elements of action
        assert len(action) == 0
    # reset rendering assets
    def _reset_render(self):
        self.render_geoms = None
        self.render_geoms_xform = None
    # render environment
    def render(self, mode='human'):
        """Draw every viewer; in 'human' mode also print comm utterances.

        Geometries are (re)built lazily whenever ``render_geoms`` is None
        (i.e. after every reset); returns one render result per viewer.
        """
        if mode == 'human':
            alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
            message = ''
            for agent in self.world.agents:
                comm = []
                for other in self.world.agents:
                    if other is agent: continue
                    if np.all(other.state.c == 0):
                        word = '_'
                    else:
                        word = alphabet[np.argmax(other.state.c)]
                    message += (other.name + ' to ' + agent.name + ': ' + word + '   ')
            print(message)
        for i in range(len(self.viewers)):
            # create viewers (if necessary)
            if self.viewers[i] is None:
                # import rendering only if we need it (and don't import for headless machines)
                #from gym.envs.classic_control import rendering
                from multiagent import rendering
                self.viewers[i] = rendering.Viewer(700,700)
        # create rendering geometry
        if self.render_geoms is None:
            # import rendering only if we need it (and don't import for headless machines)
            #from gym.envs.classic_control import rendering
            from multiagent import rendering
            self.render_geoms = []
            self.render_geoms_xform = []
            for entity in self.world.entities:
                geom = rendering.make_circle(entity.size)
                xform = rendering.Transform()
                # agents are drawn semi-transparent, landmarks opaque
                if 'agent' in entity.name:
                    geom.set_color(*entity.color, alpha=0.5)
                else:
                    geom.set_color(*entity.color)
                geom.add_attr(xform)
                self.render_geoms.append(geom)
                self.render_geoms_xform.append(xform)
            # add geoms to viewer
            for viewer in self.viewers:
                viewer.geoms = []
                for geom in self.render_geoms:
                    viewer.add_geom(geom)
        results = []
        for i in range(len(self.viewers)):
            from multiagent import rendering
            # update bounds to center around agent
            cam_range = 1
            if self.shared_viewer:
                pos = np.zeros(self.world.dim_p)
            else:
                pos = self.agents[i].state.p_pos
            self.viewers[i].set_bounds(pos[0]-cam_range,pos[0]+cam_range,pos[1]-cam_range,pos[1]+cam_range)
            # update geometry positions
            for e, entity in enumerate(self.world.entities):
                self.render_geoms_xform[e].set_translation(*entity.state.p_pos)
            # render to display or array
            results.append(self.viewers[i].render(return_rgb_array = mode=='rgb_array'))
        return results
    # create receptor field locations in local coordinate frame
    def _make_receptor_locations(self, agent):
        """Return a list of 2-d offsets forming a receptive field (polar by
        default: 8 angles x 3 distances plus the origin)."""
        receptor_type = 'polar'
        range_min = 0.05 * 2.0
        range_max = 1.00
        dx = []
        # circular receptive field
        if receptor_type == 'polar':
            for angle in np.linspace(-np.pi, +np.pi, 8, endpoint=False):
                for distance in np.linspace(range_min, range_max, 3):
                    dx.append(distance * np.array([np.cos(angle), np.sin(angle)]))
            # add origin
            dx.append(np.array([0.0, 0.0]))
        # grid receptive field
        if receptor_type == 'grid':
            for x in np.linspace(-range_max, +range_max, 5):
                for y in np.linspace(-range_max, +range_max, 5):
                    dx.append(np.array([x,y]))
        return dx
# vectorized wrapper for a batch of multi-agent environments
# assumes all environments have the same observation and action space
class BatchMultiAgentEnv(gym.Env):
    """Concatenates several multi-agent envs into one flat env.

    NOTE(review): this wrapper forwards a ``time`` argument to env.step and a
    ``close`` argument to env.render that MultiAgentEnv above does not
    accept, and ``info_n['n']`` is never populated in step() — it appears to
    target an older env interface; confirm before relying on it.
    """
    metadata = {
        'runtime.vectorized': True,
        'render.modes' : ['human', 'rgb_array']
    }
    def __init__(self, env_batch):
        # the underlying list of environments
        self.env_batch = env_batch
    @property
    def n(self):
        # total number of agents across the whole batch
        return np.sum([env.n for env in self.env_batch])
    @property
    def action_space(self):
        # all envs are assumed identical, so the first one is representative
        return self.env_batch[0].action_space
    @property
    def observation_space(self):
        return self.env_batch[0].observation_space
    def step(self, action_n, time):
        """Slice the flat action list per env, step each, and re-concatenate."""
        obs_n = []
        reward_n = []
        done_n = []
        info_n = {'n': []}
        i = 0
        for env in self.env_batch:
            obs, reward, done, _ = env.step(action_n[i:(i+env.n)], time)
            i += env.n
            obs_n += obs
            # reward = [r / len(self.env_batch) for r in reward]
            reward_n += reward
            done_n += done
        return obs_n, reward_n, done_n, info_n
    def reset(self):
        # reset every env and concatenate the per-agent observations
        obs_n = []
        for env in self.env_batch:
            obs_n += env.reset()
        return obs_n
    # render environment
    def render(self, mode='human', close=True):
        results_n = []
        for env in self.env_batch:
            results_n += env.render(mode, close)
        return results_n
================================================
FILE: examples/Social_Cognition/FOToM/multiagent/multi_discrete.py
================================================
# An old version of OpenAI Gym's multi_discrete.py. (Was getting affected by Gym updates)
# (https://github.com/openai/gym/blob/1fb81d4e3fb780ccf77fec731287ba07da35eb84/gym/spaces/multi_discrete.py)
import numpy as np
import gym
# from gym.spaces import prng
class MultiDiscrete(gym.Space):
    """A series of discrete action sub-spaces with per-space bounds.

    Parametrized by a list of ``[min, max]`` pairs, one per sub-space; each
    sub-space accepts any integer from ``min`` to ``max`` (both inclusive).
    A value of 0 always needs to represent the NOOP action.

    e.g. a Nintendo game controller can be conceptualized as 3 sub-spaces:
        1) Arrow Keys: Discrete 5 - NOOP[0], UP[1], RIGHT[2], DOWN[3], LEFT[4]
        2) Button A:   Discrete 2 - NOOP[0], Pressed[1]
        3) Button B:   Discrete 2 - NOOP[0], Pressed[1]
    and initialized as ``MultiDiscrete([ [0,4], [0,1], [0,1] ])``.
    """

    def __init__(self, array_of_param_array):
        self.low = np.array([pair[0] for pair in array_of_param_array])
        self.high = np.array([pair[1] for pair in array_of_param_array])
        self.num_discrete_space = self.low.shape[0]

    def sample(self):
        """Return a list with one uniformly-drawn sample per sub-space."""
        rand = np.random.RandomState().rand(self.num_discrete_space)
        sampled = np.floor(rand * (self.high - self.low + 1.) + self.low)
        return [int(v) for v in sampled]

    def contains(self, x):
        """True when *x* holds one in-range value for every sub-space."""
        vec = np.array(x)
        return len(x) == self.num_discrete_space and (vec >= self.low).all() and (vec <= self.high).all()

    @property
    def shape(self):
        return self.num_discrete_space

    def __repr__(self):
        return "MultiDiscrete" + str(self.num_discrete_space)

    def __eq__(self, other):
        return np.array_equal(self.low, other.low) and np.array_equal(self.high, other.high)
================================================
FILE: examples/Social_Cognition/FOToM/multiagent/policy.py
================================================
import numpy as np
from pyglet.window import key
# individual agent policy
class Policy(object):
    """Abstract base for per-agent policies: map an observation to an action."""

    def __init__(self):
        pass

    def action(self, obs):
        """Return an action for observation *obs*; subclasses must override."""
        raise NotImplementedError()
# interactive policy based on keyboard input
# hard-coded to deal only with movement, not communication
class InteractivePolicy(Policy):
    """Keyboard-driven policy: arrow keys toggle movement flags which are
    converted into either discrete or one-hot actions at every step."""

    def __init__(self, env, agent_index):
        super(InteractivePolicy, self).__init__()
        self.env = env
        # movement flags (left, right, up, down) toggled by key events
        self.move = [False] * 4
        # communication flags (not used by action())
        self.comm = [False] * env.world.dim_c
        # register keyboard events with this environment's window
        # env.viewers[agent_index].window.on_key_press = self.key_press
        # env.viewers[agent_index].window.on_key_release = self.key_release

    def action(self, obs):
        """Ignore *obs*; build an action from the current keyboard flags."""
        if self.env.discrete_action_input:
            # discrete action id; later flags (in index order) win
            u = 0
            for flag_idx, code in ((0, 1), (1, 2), (2, 4), (3, 3)):
                if self.move[flag_idx]:
                    u = code
        else:
            u = np.zeros(5)  # 5-d because of the explicit no-move action
            for flag_idx, slot in ((0, 1), (1, 2), (3, 3), (2, 4)):
                if self.move[flag_idx]:
                    u[slot] += 1.0
            if not any(self.move):
                u[0] += 1.0
        return np.concatenate([u, np.zeros(self.env.world.dim_c)])

    # keyboard event callbacks
    def key_press(self, k, mod):
        for slot, code in enumerate((key.LEFT, key.RIGHT, key.UP, key.DOWN)):
            if k == code:
                self.move[slot] = True

    def key_release(self, k, mod):
        for slot, code in enumerate((key.LEFT, key.RIGHT, key.UP, key.DOWN)):
            if k == code:
                self.move[slot] = False
================================================
FILE: examples/Social_Cognition/FOToM/multiagent/rendering.py
================================================
"""
2D rendering framework
"""
from __future__ import division
import os
import six
import sys
if "Apple" in sys.version:
if 'DYLD_FALLBACK_LIBRARY_PATH' in os.environ:
os.environ['DYLD_FALLBACK_LIBRARY_PATH'] += ':/usr/lib'
# (JDS 2016/04/15): avoid bug on Anaconda 2.3.0 / Yosemite
# from gym.utils import reraise
from gym import error
import pyglet
from pyglet.gl import *
# try:
# import pyglet
# except ImportError as e:
# reraise(suffix="HINT: you can install pyglet directly via 'pip install pyglet'. But if you really just want to install all Gym dependencies and not have to think about it, 'pip install -e .[all]' or 'pip install gym[all]' will do it.")
# try:
# from pyglet.gl import *
# except ImportError as e:
# reraise(prefix="Error occured while running `from pyglet.gl import *`",suffix="HINT: make sure you have OpenGL install. On Ubuntu, you can run 'apt-get install python-opengl'. If you're running on a server, you may need a virtual frame buffer; something like this should work: 'xvfb-run -s \"-screen 0 1400x900x24\" python '")
import math
import numpy as np
RAD2DEG = 57.29577951308232
def get_display(spec):
    """Convert a display specification (such as :0) into an actual Display
    object, or None for the default display.

    Pyglet only supports multiple Displays on Linux.
    """
    if spec is None:
        return None
    if isinstance(spec, six.string_types):
        return pyglet.canvas.Display(spec)
    raise error.Error('Invalid display specification: {}. (Must be a string like :0 or None.)'.format(spec))
class Viewer(object):
    """Pyglet window that renders a list of geometries each frame.

    Geometries are added either permanently (``add_geom``) or for a single
    frame (``add_onetime``); ``set_bounds`` positions the camera transform.
    """
    def __init__(self, width, height, display=None):
        display = get_display(display)
        self.width = width
        self.height = height
        self.window = pyglet.window.Window(width=width, height=height, display=display)
        self.window.on_close = self.window_closed_by_user
        self.geoms = []          # persistent geometries
        self.onetime_geoms = []  # cleared after every render()
        self.transform = Transform()  # camera transform, set via set_bounds()
        # alpha-blended, anti-aliased line rendering
        glEnable(GL_BLEND)
        # glEnable(GL_MULTISAMPLE)
        glEnable(GL_LINE_SMOOTH)
        # glHint(GL_LINE_SMOOTH_HINT, GL_DONT_CARE)
        glHint(GL_LINE_SMOOTH_HINT, GL_NICEST)
        glLineWidth(2.0)
        glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA)
    def close(self):
        self.window.close()
    def window_closed_by_user(self):
        self.close()
    def set_bounds(self, left, right, bottom, top):
        """Map the world-space rectangle (left, right, bottom, top) onto the window."""
        assert right > left and top > bottom
        scalex = self.width/(right-left)
        scaley = self.height/(top-bottom)
        self.transform = Transform(
            translation=(-left*scalex, -bottom*scaley),
            scale=(scalex, scaley))
    def add_geom(self, geom):
        self.geoms.append(geom)
    def add_onetime(self, geom):
        self.onetime_geoms.append(geom)
    def render(self, return_rgb_array=False):
        """Draw one frame; optionally return it as an (H, W, 3) uint8 array."""
        glClearColor(1,1,1,1)
        self.window.clear()
        self.window.switch_to()
        self.window.dispatch_events()
        self.transform.enable()
        for geom in self.geoms:
            geom.render()
        for geom in self.onetime_geoms:
            geom.render()
        self.transform.disable()
        arr = None
        if return_rgb_array:
            buffer = pyglet.image.get_buffer_manager().get_color_buffer()
            image_data = buffer.get_image_data()
            # Fix: np.fromstring is deprecated (and its binary mode removed in
            # recent numpy); np.frombuffer is the documented replacement.
            arr = np.frombuffer(image_data._current_data, dtype=np.uint8)
            # In https://github.com/openai/gym-http-api/issues/2, we
            # discovered that someone using Xmonad on Arch was having
            # a window of size 598 x 398, though a 600 x 400 window
            # was requested. (Guess Xmonad was preserving a pixel for
            # the boundary.) So we use the buffer height/width rather
            # than the requested one.
            arr = arr.reshape(buffer.height, buffer.width, 4)
            arr = arr[::-1,:,0:3]
        self.window.flip()
        self.onetime_geoms = []
        return arr
    # Convenience
    def draw_circle(self, radius=10, res=30, filled=True, **attrs):
        geom = make_circle(radius=radius, res=res, filled=filled)
        _add_attrs(geom, attrs)
        self.add_onetime(geom)
        return geom
    def draw_polygon(self, v, filled=True, **attrs):
        geom = make_polygon(v=v, filled=filled)
        _add_attrs(geom, attrs)
        self.add_onetime(geom)
        return geom
    def draw_polyline(self, v, **attrs):
        geom = make_polyline(v=v)
        _add_attrs(geom, attrs)
        self.add_onetime(geom)
        return geom
    def draw_line(self, start, end, **attrs):
        geom = Line(start, end)
        _add_attrs(geom, attrs)
        self.add_onetime(geom)
        return geom
    def get_array(self):
        """Return the current frame buffer as an (H, W, 3) uint8 array."""
        self.window.flip()
        image_data = pyglet.image.get_buffer_manager().get_color_buffer().get_image_data()
        self.window.flip()
        # Fix: frombuffer instead of the deprecated fromstring (see render()).
        arr = np.frombuffer(image_data.data, dtype=np.uint8)
        arr = arr.reshape(self.height, self.width, 4)
        return arr[::-1,:,0:3]
def _add_attrs(geom, attrs):
if "color" in attrs:
geom.set_color(*attrs["color"])
if "linewidth" in attrs:
geom.set_linewidth(attrs["linewidth"])
class Geom(object):
    """Base drawable object: renders itself wrapped in its attribute stack."""

    def __init__(self):
        # Every geom owns a default black Color attribute.
        self._color = Color((0, 0, 0, 1.0))
        self.attrs = [self._color]

    def render(self):
        # Enable attributes in reverse order so the most recently added
        # attribute is applied innermost, then disable in insertion order.
        for attr in reversed(self.attrs):
            attr.enable()
        self.render1()
        for attr in self.attrs:
            attr.disable()

    def render1(self):
        """Actual drawing; subclasses must override."""
        raise NotImplementedError

    def add_attr(self, attr):
        """Append a render-state attribute (Transform, LineWidth, ...)."""
        self.attrs.append(attr)

    def set_color(self, r, g, b, alpha=1):
        """Set the geom's fill/stroke RGBA color."""
        self._color.vec4 = (r, g, b, alpha)
class Attr(object):
    """Abstract render-state attribute (color, transform, line style, ...)."""

    def enable(self):
        """Activate this attribute before drawing; subclasses must override."""
        raise NotImplementedError

    def disable(self):
        """Deactivate after drawing; a no-op unless overridden."""
        pass
class Transform(Attr):
    """Translate / rotate / scale attribute applied via the GL matrix stack."""

    def __init__(self, translation=(0.0, 0.0), rotation=0.0, scale=(1, 1)):
        self.set_translation(*translation)
        self.set_rotation(rotation)
        self.set_scale(*scale)

    def enable(self):
        # Push a matrix so disable() can restore the previous state;
        # order is translate -> rotate -> scale.
        glPushMatrix()
        glTranslatef(self.translation[0], self.translation[1], 0)
        glRotatef(RAD2DEG * self.rotation, 0, 0, 1.0)  # rotation stored in radians
        glScalef(self.scale[0], self.scale[1], 1)

    def disable(self):
        glPopMatrix()

    def set_translation(self, newx, newy):
        self.translation = (float(newx), float(newy))

    def set_rotation(self, new):
        self.rotation = float(new)

    def set_scale(self, newx, newy):
        self.scale = (float(newx), float(newy))
class Color(Attr):
    """Solid RGBA color attribute."""

    def __init__(self, vec4):
        # vec4 is an (r, g, b, a) tuple.
        self.vec4 = vec4

    def enable(self):
        glColor4f(*self.vec4)
class LineStyle(Attr):
    """Stipple-pattern attribute for dashed/dotted lines."""

    def __init__(self, style):
        # style is a 16-bit GL stipple pattern.
        self.style = style

    def enable(self):
        glEnable(GL_LINE_STIPPLE)
        glLineStipple(1, self.style)

    def disable(self):
        glDisable(GL_LINE_STIPPLE)
class LineWidth(Attr):
    """Stroke-width attribute, in pixels."""

    def __init__(self, stroke):
        self.stroke = stroke

    def enable(self):
        glLineWidth(self.stroke)
class Point(Geom):
    """A single point drawn at the local origin."""

    def __init__(self):
        Geom.__init__(self)

    def render1(self):
        glBegin(GL_POINTS)
        glVertex3f(0.0, 0.0, 0.0)
        glEnd()
class FilledPolygon(Geom):
    """Filled polygon geom; also strokes a darker outline over the fill."""

    def __init__(self, v):
        Geom.__init__(self)
        # v: sequence of (x, y) vertices.
        self.v = v

    def render1(self):
        # Choose the cheapest GL primitive that fits the vertex count.
        if len(self.v) == 4 : glBegin(GL_QUADS)
        elif len(self.v) > 4 : glBegin(GL_POLYGON)
        else: glBegin(GL_TRIANGLES)
        for p in self.v:
            glVertex3f(p[0], p[1],0) # draw each vertex
        glEnd()
        # Outline at half the brightness/alpha of the fill color.
        color = (self._color.vec4[0] * 0.5, self._color.vec4[1] * 0.5, self._color.vec4[2] * 0.5, self._color.vec4[3] * 0.5)
        glColor4f(*color)
        glBegin(GL_LINE_LOOP)
        for p in self.v:
            glVertex3f(p[0], p[1],0) # draw each vertex
        glEnd()
def make_circle(radius=10, res=30, filled=True):
    """Build a circle geom approximated by *res* evenly spaced vertices."""
    points = [
        (math.cos(2 * math.pi * k / res) * radius,
         math.sin(2 * math.pi * k / res) * radius)
        for k in range(res)
    ]
    return FilledPolygon(points) if filled else PolyLine(points, True)
def make_polygon(v, filled=True):
    """Wrap the vertex list *v* as a filled polygon or a closed polyline."""
    return FilledPolygon(v) if filled else PolyLine(v, True)
def make_polyline(v):
    """Open (unclosed) polyline through the vertices *v*."""
    return PolyLine(v, False)
def make_capsule(length, width):
    """Capsule geom: a length x width box capped by circles at both ends."""
    half = width / 2
    body = make_polygon([(0, -half), (0, half), (length, half), (length, -half)])
    left_cap = make_circle(half)
    right_cap = make_circle(half)
    right_cap.add_attr(Transform(translation=(length, 0)))
    return Compound([body, left_cap, right_cap])
class Compound(Geom):
    """Geom grouping several sub-geoms under one color/attribute stack.

    Individual Color attributes of the children are stripped so the
    compound's own color applies uniformly.
    """

    def __init__(self, gs):
        Geom.__init__(self)
        self.gs = gs
        for child in self.gs:
            child.attrs = [a for a in child.attrs if not isinstance(a, Color)]

    def render1(self):
        for child in self.gs:
            child.render()
class PolyLine(Geom):
    """Line strip through a vertex list, optionally closed into a loop."""

    def __init__(self, v, close):
        Geom.__init__(self)
        self.v = v
        self.close = close
        self.linewidth = LineWidth(1)
        self.add_attr(self.linewidth)

    def render1(self):
        mode = GL_LINE_LOOP if self.close else GL_LINE_STRIP
        glBegin(mode)
        for pt in self.v:
            glVertex3f(pt[0], pt[1], 0)
        glEnd()

    def set_linewidth(self, x):
        """Change the stroke width in pixels."""
        self.linewidth.stroke = x
class Line(Geom):
    """Straight line segment between two endpoints."""

    def __init__(self, start=(0.0, 0.0), end=(0.0, 0.0)):
        Geom.__init__(self)
        self.start = start
        self.end = end
        self.linewidth = LineWidth(1)
        self.add_attr(self.linewidth)

    def render1(self):
        glBegin(GL_LINES)
        glVertex2f(*self.start)
        glVertex2f(*self.end)
        glEnd()
class Image(Geom):
    """Geom that blits an image file, centered on the local origin."""

    def __init__(self, fname, width, height):
        # Loads the image from disk immediately (pyglet I/O in the ctor).
        Geom.__init__(self)
        self.width = width
        self.height = height
        img = pyglet.image.load(fname)
        self.img = img
        # NOTE(review): `flip` is set but never read in this class — it may
        # be consumed by external callers; confirm before removing.
        self.flip = False

    def render1(self):
        # Centered blit, stretched to (width, height).
        self.img.blit(-self.width/2, -self.height/2, width=self.width, height=self.height)
# ================================================================
class SimpleImageViewer(object):
    """Minimal pyglet window that displays (H, W, 3) uint8 RGB frames."""

    def __init__(self, display=None):
        self.window = None   # created lazily on first imshow()
        self.isopen = False
        self.display = display

    def imshow(self, arr):
        """Show one RGB frame; the window is sized from the first frame."""
        if self.window is None:
            height, width, channels = arr.shape
            self.window = pyglet.window.Window(width=width, height=height, display=self.display)
            self.width = width
            self.height = height
            self.isopen = True
        assert arr.shape == (self.height, self.width, 3), "You passed in an image with the wrong number shape"
        # Negative pitch: rows are supplied top-to-bottom.
        image = pyglet.image.ImageData(self.width, self.height, 'RGB', arr.tobytes(), pitch=self.width * -3)
        self.window.clear()
        self.window.switch_to()
        self.window.dispatch_events()
        image.blit(0,0)
        self.window.flip()

    def close(self):
        """Close the window (idempotent)."""
        if self.isopen:
            self.window.close()
            self.isopen = False

    def __del__(self):
        self.close()
================================================
FILE: examples/Social_Cognition/FOToM/multiagent/scenario.py
================================================
import numpy as np
# defines scenario upon which the world is built
class BaseScenario(object):
    """Abstract scenario upon which a multi-agent world is built."""

    def make_world(self):
        """Create and return all elements of the world; must be overridden."""
        raise NotImplementedError()

    def reset_world(self, world):
        """Set the initial conditions of *world*; must be overridden."""
        raise NotImplementedError()
================================================
FILE: examples/Social_Cognition/FOToM/multiagent/scenarios/__init__.py
================================================
import imp
import os.path as osp
def load(name):
    """Load the scenario module *name* (a filename relative to this
    package directory) and return the module object.

    Replaces the deprecated ``imp.load_source`` — the ``imp`` module was
    removed in Python 3.12 — with the equivalent ``importlib`` machinery.
    The module-level ``import imp`` above is now unused and can be dropped.
    """
    import importlib.util
    pathname = osp.join(osp.dirname(__file__), name)
    spec = importlib.util.spec_from_file_location('', pathname)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
================================================
FILE: examples/Social_Cognition/FOToM/multiagent/scenarios/hetero_spread.py
================================================
import numpy as np
from multiagent.core import World, Agent, Landmark
from multiagent.scenario import BaseScenario
class Scenario(BaseScenario):
    """Heterogeneous cooperative spread: agents split into two groups
    (large/slow and small/fast) must jointly cover all landmarks while
    avoiding collisions."""

    def make_world(self, num_good_agents=2, num_adversaries=0):
        """Build the world: one landmark per agent, agents split into two
        equally sized heterogeneous groups.

        NOTE(review): ``num_adversaries`` is accepted for API symmetry with
        the other scenarios but is unused here.
        """
        world = World()
        # set any world properties first
        world.dim_c = 2
        world.max_steps = 25
        num_agents = num_good_agents
        self.n_agent_a = num_agents // 2 # 2
        self.n_agent_b = num_agents // 2 # 2
        num_landmarks = num_good_agents
        world.collaborative = True
        self.agent_size = 0.10
        self.n_others = 3
        self.n_group = 2
        # add agents
        world.agents = [Agent() for i in range(num_agents)]
        for i, agent in enumerate(world.agents):
            agent.name = 'agent %d' % i
            agent.collide = True
            agent.silent = True
            agent.id = i
            if i < self.n_agent_a:
                # group A: larger, slower agents
                agent.size = self.agent_size
                agent.accel = 3.0
                agent.max_speed = 1.0
            else:
                # group B: smaller, faster agents
                agent.size = self.agent_size / 2
                agent.accel = 4.0
                agent.max_speed = 1.3
        # add landmarks
        world.landmarks = [Landmark() for i in range(num_landmarks)]
        for i, landmark in enumerate(world.landmarks):
            landmark.name = 'landmark %d' % i
            landmark.collide = False
            landmark.movable = False
        # make initial conditions
        self.reset_world(world)
        return world

    def reset_world(self, world):
        """Recolor agents by group, reset the step counter, and scatter all
        entities uniformly in [-1, 1]^dim_p."""
        world.num_steps = 0
        self.end_steps = world.max_steps
        # random properties for agents
        for i, agent in enumerate(world.agents):
            if i < self.n_agent_a:
                agent.color = np.array([0.35, 0.35, 0.85])
            else:
                agent.color = np.array([0.35, 0.85, 0.35])
        # random properties for landmarks
        for i, landmark in enumerate(world.landmarks):
            landmark.color = np.array([0.25, 0.25, 0.25])
        # set random initial states
        for agent in world.agents:
            agent.state.p_pos = np.random.uniform(-1, +1, world.dim_p)
            agent.state.p_vel = np.zeros(world.dim_p)
            agent.state.c = np.zeros(world.dim_c)
        for i, landmark in enumerate(world.landmarks):
            landmark.state.p_pos = np.random.uniform(-1, +1, world.dim_p)
            landmark.state.p_vel = np.zeros(world.dim_p)

    def benchmark_data(self, agent, world):
        """Return (reward, collisions, summed min distances, occupied
        landmarks) for benchmarking.

        NOTE(review): the collision loop also compares *agent* against
        itself (distance 0 < size sum), so one self-collision is always
        counted — same quirk as the upstream MPE scenarios.
        """
        rew = 0
        collisions = 0
        occupied_landmarks = 0
        min_dists = 0
        for l in world.landmarks:
            dists = [np.sqrt(np.sum(np.square(a.state.p_pos - l.state.p_pos))) for a in world.agents]
            min_dists += min(dists)
            rew -= min(dists)
            if min(dists) < 0.1:
                occupied_landmarks += 1
        if agent.collide:
            for a in world.agents:
                if self.is_collision(a, agent):
                    rew -= 1
                    collisions += 1
        return (rew, collisions, min_dists, occupied_landmarks)

    def is_collision(self, agent1, agent2):
        """True when the two entities' bodies overlap."""
        delta_pos = agent1.state.p_pos - agent2.state.p_pos
        dist = np.sqrt(np.sum(np.square(delta_pos)))
        dist_min = agent1.size + agent2.size
        return True if dist < dist_min else False

    def reward(self, agent, world):
        """Shared reward: +2 per landmark touched by at least one agent,
        minus an out-of-bounds penalty (the shaped distance-based variant
        is compiled out via ``shaped_reward = False``)."""
        # Agents are rewarded based on minimum agent distance to each landmark, penalized for collisions
        rew = 0
        shaped_reward = False
        if shaped_reward: # distance-based reward
            for l in world.landmarks:
                dists = [np.sqrt(np.sum(np.square(a.state.p_pos - l.state.p_pos))) for a in world.agents]
                rew -= min(dists)
            if agent.collide:
                for a in world.agents:
                    if self.is_collision(a, agent):
                        rew -= 1
            return rew
        else:
            win_agents = []
            for land in world.landmarks:
                for a in world.agents:
                    if self.is_collision(a, land):
                        win_agents.append(a)
                        break
            rew += 2 * len(set(win_agents))
            def bound(x):
                # soft exponential penalty once a coordinate leaves the unit box
                if x > 1.0:
                    return min(np.exp(2 * x - 2), 10)
                else:
                    return 0.0
            bound_rew = 0.0
            for p in range(world.dim_p):
                x = abs(agent.state.p_pos[p])
                bound_rew -= bound(x)
            rew += bound_rew
            return rew

    def observation(self, agent, world):
        """Egocentric observation: own velocity and position, others'
        communications, placeholder velocities, and relative positions of
        landmarks and other agents."""
        # get positions of all entities in this agent's reference frame
        entity_pos = []
        for entity in world.landmarks:  # world.entities:
            entity_pos.append(entity.state.p_pos - agent.state.p_pos)
        # entity colors
        entity_color = []
        other_vel = []
        for entity in world.landmarks:  # world.entities:
            entity_color.append(entity.color)
        # communication of all other agents
        comm = []
        other_pos = []
        for other in world.agents:
            if other is agent:
                # zero placeholder keeps the observation length fixed
                other_vel.append([0, 0])
                continue
            comm.append(other.state.c)
            other_pos.append(other.state.p_pos - agent.state.p_pos)
        return np.concatenate([agent.state.p_vel] + [agent.state.p_pos] + comm +
                              other_vel + entity_pos + other_pos)
================================================
FILE: examples/Social_Cognition/FOToM/multiagent/scenarios/simple.py
================================================
import numpy as np
from multiagent.core import World, Agent, Landmark
from multiagent.scenario import BaseScenario
class Scenario(BaseScenario):
    """Single-agent debug scenario: one agent, one landmark; reward is the
    negative squared distance between them."""

    def make_world(self):
        """Create a world holding exactly one agent and one landmark."""
        world = World()
        # one silent, non-colliding agent
        world.agents = [Agent()]
        for idx, agent in enumerate(world.agents):
            agent.name = 'agent %d' % idx
            agent.collide = False
            agent.silent = True
        # one static landmark
        world.landmarks = [Landmark()]
        for idx, landmark in enumerate(world.landmarks):
            landmark.name = 'landmark %d' % idx
            landmark.collide = False
            landmark.movable = False
        self.reset_world(world)
        return world

    def reset_world(self, world):
        """Recolor entities and scatter them uniformly in [-1, 1]^dim_p."""
        for agent in world.agents:
            agent.color = np.array([0.25, 0.25, 0.25])
        for landmark in world.landmarks:
            landmark.color = np.array([0.75, 0.75, 0.75])
        world.landmarks[0].color = np.array([0.75, 0.25, 0.25])
        for agent in world.agents:
            agent.state.p_pos = np.random.uniform(-1, +1, world.dim_p)
            agent.state.p_vel = np.zeros(world.dim_p)
            agent.state.c = np.zeros(world.dim_c)
        for landmark in world.landmarks:
            landmark.state.p_pos = np.random.uniform(-1, +1, world.dim_p)
            landmark.state.p_vel = np.zeros(world.dim_p)

    def reward(self, agent, world):
        """Negative squared Euclidean distance to the (only) landmark."""
        offset = agent.state.p_pos - world.landmarks[0].state.p_pos
        return -np.sum(np.square(offset))

    def observation(self, agent, world):
        """Own velocity followed by landmark positions relative to self."""
        relative = [lm.state.p_pos - agent.state.p_pos for lm in world.landmarks]
        return np.concatenate([agent.state.p_vel] + relative)
================================================
FILE: examples/Social_Cognition/FOToM/multiagent/scenarios/simple_adversary.py
================================================
import numpy as np
from multiagent.core import World, Agent, Landmark
from multiagent.scenario import BaseScenario
class Scenario(BaseScenario):
    """Physical deception ('simple_adversary'): cooperating agents must
    reach a secret goal landmark while adversaries, which do not know the
    goal, try to infer it from the good agents' behavior."""

    def make_world(self, num_good_agents=2, num_adversaries=1):
        """Build the world.

        :param num_good_agents: number of cooperating agents.
        :param num_adversaries: number of adversaries; they occupy the
            first indices of ``world.agents``.

        Fix: the original body unconditionally reassigned
        ``num_adversaries = 1`` (and had a no-op
        ``num_good_agents = num_good_agents``), silently ignoring the
        parameter. The parameter is now honored; behavior with the default
        arguments is unchanged.
        """
        world = World()
        # set any world properties first
        world.dim_c = 2
        num_agents = num_adversaries + num_good_agents
        world.num_agents = num_agents
        num_landmarks = num_agents - 1
        # add agents; adversaries come first
        world.agents = [Agent() for i in range(num_agents)]
        for i, agent in enumerate(world.agents):
            agent.name = 'agent %d' % i
            agent.collide = False
            agent.silent = True
            agent.adversary = True if i < num_adversaries else False
            agent.size = 0.15
        # add landmarks
        world.landmarks = [Landmark() for i in range(num_landmarks)]
        for i, landmark in enumerate(world.landmarks):
            landmark.name = 'landmark %d' % i
            landmark.collide = False
            landmark.movable = False
            landmark.size = 0.08
        # make initial conditions
        self.reset_world(world)
        return world

    def reset_world(self, world):
        """Recolor agents (agent 0 red = adversary), pick a random goal
        landmark shared by everyone, and randomize positions."""
        # random properties for agents
        world.agents[0].color = np.array([0.85, 0.35, 0.35])
        for i in range(1, world.num_agents):
            world.agents[i].color = np.array([0.35, 0.35, 0.85])
        # random properties for landmarks
        for i, landmark in enumerate(world.landmarks):
            landmark.color = np.array([0.15, 0.15, 0.15])
        # set goal landmark
        goal = np.random.choice(world.landmarks)
        goal.color = np.array([0.15, 0.65, 0.15])
        for agent in world.agents:
            agent.goal_a = goal
        # set random initial states
        for agent in world.agents:
            agent.state.p_pos = np.random.uniform(-1, +1, world.dim_p)
            agent.state.p_vel = np.zeros(world.dim_p)
            agent.state.c = np.zeros(world.dim_c)
        for i, landmark in enumerate(world.landmarks):
            landmark.state.p_pos = np.random.uniform(-1, +1, world.dim_p)
            landmark.state.p_vel = np.zeros(world.dim_p)

    def benchmark_data(self, agent, world):
        """Adversary: squared distance to goal. Good agent: tuple of squared
        distances to every landmark plus the goal."""
        if agent.adversary:
            return np.sum(np.square(agent.state.p_pos - agent.goal_a.state.p_pos))
        else:
            dists = []
            for l in world.landmarks:
                dists.append(np.sum(np.square(agent.state.p_pos - l.state.p_pos)))
            dists.append(np.sum(np.square(agent.state.p_pos - agent.goal_a.state.p_pos)))
            return tuple(dists)

    def good_agents(self, world):
        """All non-adversary agents."""
        return [agent for agent in world.agents if not agent.adversary]

    def adversaries(self, world):
        """All adversarial agents."""
        return [agent for agent in world.agents if agent.adversary]

    def reward(self, agent, world):
        """Dispatch to the role-specific reward."""
        return self.adversary_reward(agent, world) if agent.adversary else self.agent_reward(agent, world)

    def agent_reward(self, agent, world):
        """Reward good agents for proximity to the goal and for keeping the
        adversary away from it (binary proximity variants are active; the
        shaped distance variants are compiled out)."""
        shaped_reward = False
        shaped_adv_reward = False
        # Calculate negative reward for adversary
        adversary_agents = self.adversaries(world)
        if shaped_adv_reward:  # distance-based adversary reward
            adv_rew = sum([np.sqrt(np.sum(np.square(a.state.p_pos - a.goal_a.state.p_pos))) for a in adversary_agents])
        else:  # proximity-based adversary reward (binary)
            adv_rew = 0
            for a in adversary_agents:
                if np.sqrt(np.sum(np.square(a.state.p_pos - a.goal_a.state.p_pos))) < 2 * a.goal_a.size:
                    adv_rew -= 5
        # Calculate positive reward for agents
        good_agents = self.good_agents(world)
        if shaped_reward:  # distance-based agent reward
            pos_rew = -min(
                [np.sqrt(np.sum(np.square(a.state.p_pos - a.goal_a.state.p_pos))) for a in good_agents])
        else:  # proximity-based agent reward (binary)
            pos_rew = 0
            if min([np.sqrt(np.sum(np.square(a.state.p_pos - a.goal_a.state.p_pos))) for a in good_agents]) \
                    < 2 * agent.goal_a.size:
                pos_rew += 5
            pos_rew -= min(
                [np.sqrt(np.sum(np.square(a.state.p_pos - a.goal_a.state.p_pos))) for a in good_agents])
        return pos_rew + adv_rew

    def adversary_reward(self, agent, world):
        """Reward the adversary for proximity to the goal (binary variant)."""
        shaped_reward = False
        if shaped_reward:  # distance-based reward
            return -np.sum(np.square(agent.state.p_pos - agent.goal_a.state.p_pos))
        else:  # proximity-based reward (binary)
            adv_rew = 0
            if np.sqrt(np.sum(np.square(agent.state.p_pos - agent.goal_a.state.p_pos))) < 2 * agent.goal_a.size:
                adv_rew += 5
            return adv_rew

    def observation(self, agent, world):
        """Egocentric observation; good agents also see the goal offset,
        adversaries get a zero placeholder instead."""
        # get positions of all entities in this agent's reference frame
        entity_pos = []
        for entity in world.landmarks:
            entity_pos.append(entity.state.p_pos - agent.state.p_pos)
        # entity colors
        entity_color = []
        other_vel = []
        for entity in world.landmarks:
            entity_color.append(entity.color)
        # relative positions of all other agents
        other_pos = []
        for other in world.agents:
            if other is agent and not other.adversary:
                other_vel.append([0, 0])
                continue
            other_pos.append(other.state.p_pos - agent.state.p_pos)
            if not other.adversary:
                other_vel.append([0, 0])
        if not agent.adversary:
            return np.concatenate(
                [agent.goal_a.state.p_pos - agent.state.p_pos] +
                other_vel + entity_pos + other_pos)
        else:
            return np.concatenate([np.zeros(2)] +
                                  other_vel + entity_pos + other_pos)
================================================
FILE: examples/Social_Cognition/FOToM/multiagent/scenarios/simple_crypto.py
================================================
"""
Scenario:
1 speaker, 2 listeners (one of which is an adversary). Good agents rewarded for proximity to goal, and distance from
adversary to goal. Adversary is rewarded for its distance to the goal.
"""
import numpy as np
from multiagent.core import World, Agent, Landmark
from multiagent.scenario import BaseScenario
import random
class CryptoAgent(Agent):
    """Agent that may carry an encryption key (the speaker's shared secret)."""

    def __init__(self):
        super().__init__()
        # key is assigned by the scenario's reset; None until then
        self.key = None
class Scenario(BaseScenario):
    """Covert communication ('simple_crypto'): a speaker broadcasts the goal
    color encrypted with a shared key; the listener (who holds the key) must
    decode it while the adversary (who does not) tries to as well."""

    def make_world(self):
        """Three immobile agents — adversary (0), listener (1), speaker (2)
        — and two landmarks; communication dimension is 4."""
        world = World()
        # set any world properties first
        num_agents = 3
        num_adversaries = 1
        num_landmarks = 2
        world.dim_c = 4
        # add agents
        world.agents = [CryptoAgent() for i in range(num_agents)]
        for i, agent in enumerate(world.agents):
            agent.name = 'agent %d' % i
            agent.collide = False
            agent.adversary = True if i < num_adversaries else False
            agent.speaker = True if i == 2 else False
            agent.movable = False
        # add landmarks
        world.landmarks = [Landmark() for i in range(num_landmarks)]
        for i, landmark in enumerate(world.landmarks):
            landmark.name = 'landmark %d' % i
            landmark.collide = False
            landmark.movable = False
        # make initial conditions
        self.reset_world(world)
        return world

    def reset_world(self, world):
        """Give landmarks one-hot colors, pick a random goal and key, and
        scatter all entities uniformly in [-1, 1]^dim_p."""
        # random properties for agents
        for i, agent in enumerate(world.agents):
            agent.color = np.array([0.25, 0.25, 0.25])
            if agent.adversary:
                agent.color = np.array([0.75, 0.25, 0.25])
            agent.key = None
        # random properties for landmarks: one-hot colors in comm space
        color_list = [np.zeros(world.dim_c) for i in world.landmarks]
        for i, color in enumerate(color_list):
            color[i] += 1
        for color, landmark in zip(color_list, world.landmarks):
            landmark.color = color
        # set goal landmark; the listener is painted with the goal color
        goal = np.random.choice(world.landmarks)
        world.agents[1].color = goal.color
        # the speaker's key is a random landmark color
        world.agents[2].key = np.random.choice(world.landmarks).color
        for agent in world.agents:
            agent.goal_a = goal
        # set random initial states
        for agent in world.agents:
            agent.state.p_pos = np.random.uniform(-1, +1, world.dim_p)
            agent.state.p_vel = np.zeros(world.dim_p)
            agent.state.c = np.zeros(world.dim_c)
        for i, landmark in enumerate(world.landmarks):
            landmark.state.p_pos = np.random.uniform(-1, +1, world.dim_p)
            landmark.state.p_vel = np.zeros(world.dim_p)

    def benchmark_data(self, agent, world):
        """Return (utterance, goal color) for benchmarking purposes."""
        return (agent.state.c, agent.goal_a.color)

    # return all agents that are not adversaries
    def good_listeners(self, world):
        """Non-adversary, non-speaker agents (i.e. the listener, Bob)."""
        return [agent for agent in world.agents if not agent.adversary and not agent.speaker]

    # return all agents that are not adversaries
    def good_agents(self, world):
        """All non-adversary agents (Alice and Bob)."""
        return [agent for agent in world.agents if not agent.adversary]

    # return all adversarial agents
    def adversaries(self, world):
        """All adversary agents (Eve)."""
        return [agent for agent in world.agents if agent.adversary]

    def reward(self, agent, world):
        """Dispatch to the role-specific reward."""
        return self.adversary_reward(agent, world) if agent.adversary else self.agent_reward(agent, world)

    def agent_reward(self, agent, world):
        """Cooperative reward: negative reconstruction error of the good
        listeners, plus the adversary's reconstruction error (so agents are
        rewarded when Bob decodes the goal and Eve cannot)."""
        # Agents rewarded if Bob can reconstruct message, but adversary (Eve) cannot
        good_listeners = self.good_listeners(world)
        adversaries = self.adversaries(world)
        good_rew = 0
        adv_rew = 0
        for a in good_listeners:
            if (a.state.c == np.zeros(world.dim_c)).all():
                # listener has not uttered anything yet: no term either way
                continue
            else:
                good_rew -= np.sum(np.square(a.state.c - agent.goal_a.color))
        for a in adversaries:
            if (a.state.c == np.zeros(world.dim_c)).all():
                continue
            else:
                adv_l1 = np.sum(np.square(a.state.c - agent.goal_a.color))
                adv_rew += adv_l1
        return adv_rew + good_rew

    def adversary_reward(self, agent, world):
        """Adversary (Eve) is rewarded (less penalized) for reconstructing
        the original goal color."""
        rew = 0
        if not (agent.state.c == np.zeros(world.dim_c)).all():
            rew -= np.sum(np.square(agent.state.c - agent.goal_a.color))
        return rew

    def observation(self, agent, world):
        """Role-dependent observation: the speaker sees (goal color, key),
        the listener sees (key, speaker comm), the adversary sees only the
        speaker's comm."""
        # goal color
        goal_color = np.zeros(world.dim_color)
        if agent.goal_a is not None:
            goal_color = agent.goal_a.color
        # get positions of all entities in this agent's reference frame
        entity_pos = []
        for entity in world.landmarks:
            entity_pos.append(entity.state.p_pos - agent.state.p_pos)
        # communication of all other agents (only the speaker's is collected)
        comm = []
        for other in world.agents:
            if other is agent or (other.state.c is None) or not other.speaker: continue
            comm.append(other.state.c)
        confer = np.array([0])
        if world.agents[2].key is None:
            # no key yet: observations fall back to zeros
            confer = np.array([1])
            key = np.zeros(world.dim_c)
            goal_color = np.zeros(world.dim_c)
        else:
            key = world.agents[2].key
        prnt = False  # debug printing toggle
        # speaker
        if agent.speaker:
            if prnt:
                print('speaker')
                print(agent.state.c)
                print(np.concatenate([goal_color] + [key] + [confer] + [np.random.randn(1)]))
            return np.concatenate([goal_color] + [key])
        # listener
        if not agent.speaker and not agent.adversary:
            if prnt:
                print('listener')
                print(agent.state.c)
                print(np.concatenate([key] + comm + [confer]))
            return np.concatenate([key] + comm)
        if not agent.speaker and agent.adversary:
            if prnt:
                print('adversary')
                print(agent.state.c)
                print(np.concatenate(comm + [confer]))
            return np.concatenate(comm)
================================================
FILE: examples/Social_Cognition/FOToM/multiagent/scenarios/simple_push.py
================================================
import numpy as np
from multiagent.core import World, Agent, Landmark
from multiagent.scenario import BaseScenario
class Scenario(BaseScenario):
    """Keep-away ('simple_push'): good agents try to occupy their goal
    landmark while adversaries push them away from it."""

    def make_world(self, num_good_agents=2, num_adversaries=2):
        """Build the world; adversaries occupy the first agent indices and
        there are always exactly two landmarks."""
        world = World()
        # set any world properties first
        world.dim_c = 2
        num_agents = num_good_agents + num_adversaries
        # NOTE(review): self-assignment below is a no-op, kept verbatim.
        num_adversaries = num_adversaries
        num_landmarks = 2
        # add agents
        world.agents = [Agent() for i in range(num_agents)]
        for i, agent in enumerate(world.agents):
            agent.name = 'agent %d' % i
            agent.collide = True
            agent.silent = True
            if i < num_adversaries:
                agent.adversary = True
            else:
                agent.adversary = False
        # add landmarks
        world.landmarks = [Landmark() for i in range(num_landmarks)]
        for i, landmark in enumerate(world.landmarks):
            landmark.name = 'landmark %d' % i
            landmark.collide = False
            landmark.movable = False
        # make initial conditions
        self.reset_world(world)
        return world

    def reset_world(self, world):
        """Recolor entities, pick a random goal landmark shared by all
        agents, and randomize positions.

        NOTE(review): ``landmark.color[i + 1] += 0.8`` writes RGB channel
        i+1, which only works for at most two landmarks — confirm if
        num_landmarks ever changes.
        """
        # random properties for landmarks
        for i, landmark in enumerate(world.landmarks):
            landmark.color = np.array([0.1, 0.1, 0.1])
            landmark.color[i + 1] += 0.8
            landmark.index = i
        # set goal landmark
        goal = np.random.choice(world.landmarks)
        for i, agent in enumerate(world.agents):
            agent.goal_a = goal
            agent.color = np.array([0.25, 0.25, 0.25])
            if agent.adversary:
                agent.color = np.array([0.75, 0.25, 0.25])
            else:
                # tint good agents toward their goal's color channel
                j = goal.index
                agent.color[j + 1] += 0.5
        # set random initial states
        for agent in world.agents:
            agent.state.p_pos = np.random.uniform(-1, +1, world.dim_p)
            agent.state.p_vel = np.zeros(world.dim_p)
            agent.state.c = np.zeros(world.dim_c)
        for i, landmark in enumerate(world.landmarks):
            landmark.state.p_pos = np.random.uniform(-1, +1, world.dim_p)
            landmark.state.p_vel = np.zeros(world.dim_p)

    def reward(self, agent, world):
        """Dispatch to the role-specific reward."""
        # Agents are rewarded based on minimum agent distance to each landmark
        return self.adversary_reward(agent, world) if agent.adversary else self.agent_reward(agent, world)

    def good_agents(self, world):
        """All non-adversary agents."""
        return [agent for agent in world.agents if not agent.adversary]

    # return all adversarial agents
    def adversaries(self, world):
        """All adversarial agents."""
        return [agent for agent in world.agents if agent.adversary]

    def is_collision(self, agent1, agent2):
        """True when the two entities' bodies overlap."""
        delta_pos = agent1.state.p_pos - agent2.state.p_pos
        dist = np.sqrt(np.sum(np.square(delta_pos)))
        dist_min = agent1.size + agent2.size
        return dist < dist_min

    def agent_reward(self, agent, world):
        '''
        Rewrite
        '''
        # +5 for touching the goal, -5 per adversary contact, plus a soft
        # out-of-bounds penalty (shaped distance variant is compiled out).
        shaped_reward = False
        if shaped_reward: # distance-based reward
            # the distance to the goal
            return -np.sqrt(np.sum(np.square(
                agent.state.p_pos - agent.goal_a.state.p_pos)))
        else:
            pos_rew, adv_rew = 0.0, 0.0
            for a in self.adversaries(world):
                if self.is_collision(a, agent):
                    adv_rew -= 5.0
            if self.is_collision(agent, agent.goal_a):
                pos_rew += 5.0
            rew = pos_rew + adv_rew
            def bound(x):
                # soft exponential penalty outside the unit box
                if x > 1.0:
                    return min(np.exp(2 * x - 2), 10)
                else:
                    return 0.0
            bound_rew = 0.0
            for p in range(world.dim_p):
                x = abs(agent.state.p_pos[p])
                bound_rew -= bound(x)
            rew += bound_rew
            return rew

    def adversary_reward(self, agent, world):
        '''
        Rewrite
        '''
        # -5 per good agent sitting on its goal, +5 for occupying the goal
        # itself (shaped distance variant is compiled out).
        shaped_reward = False
        if shaped_reward: # distance-based reward
            # keep the nearest good agents away from the goal
            agent_dist = [np.sqrt(np.sum(np.square(a.state.p_pos -
                a.goal_a.state.p_pos))) for a in world.agents if not a.adversary]
            pos_rew = min(agent_dist)
            #nearest_agent = world.good_agents[np.argmin(agent_dist)]
            #neg_rew = np.sqrt(np.sum(np.square(nearest_agent.state.p_pos - agent.state.p_pos)))
            neg_rew = np.sqrt(np.sum(np.square(agent.goal_a.state.p_pos - agent.state.p_pos)))
            #neg_rew = sum([np.sqrt(np.sum(np.square(a.state.p_pos - agent.state.p_pos))) for a in world.good_agents])
            return pos_rew - neg_rew
        else:
            rew = 0.0
            for a in self.good_agents(world):
                if self.is_collision(a, a.goal_a):
                    rew -= 5.0
                # if self.is_collision(a, agent):
                #     rew += 5.0
            if self.is_collision(agent, agent.goal_a):
                rew += 5.0
            return rew

    def observation(self, agent, world):
        """Egocentric observation; good agents additionally see the goal
        offset and entity colors, adversaries do not."""
        # get positions of all entities in this agent's reference frame
        entity_pos = []
        for entity in world.landmarks:  # world.entities:
            entity_pos.append(entity.state.p_pos - agent.state.p_pos)
        # entity colors
        entity_color = []
        other_vel = []
        for entity in world.landmarks:  # world.entities:
            entity_color.append(entity.color)
        # communication of all other agents
        comm = []
        other_pos = []
        for other in world.agents:
            if other is agent and not other.adversary:
                other_vel.append([0, 0])
                continue
            comm.append(other.state.c)
            other_pos.append(other.state.p_pos - agent.state.p_pos)
            if not other.adversary:
                other_vel.append([0, 0])
        if not agent.adversary:
            return np.concatenate([agent.state.p_vel] +
                                  [agent.goal_a.state.p_pos - agent.state.p_pos] +
                                  [agent.color] + entity_color +
                                  other_vel + entity_pos + other_pos)
        else:
            #other_pos = list(reversed(other_pos)) if random.uniform(0,1) > 0.5 else other_pos # randomize position of other agents in adversary network
            return np.concatenate([agent.state.p_vel] +
                                  other_vel + entity_pos + other_pos)
================================================
FILE: examples/Social_Cognition/FOToM/multiagent/scenarios/simple_reference.py
================================================
import numpy as np
from multiagent.core import World, Agent, Landmark
from multiagent.scenario import BaseScenario
class Scenario(BaseScenario):
    """Cooperative reference game: each agent must steer the OTHER agent to
    a secret landmark, communicating over a 10-dim channel."""

    def make_world(self, num_good_agents=2, num_adversaries=0):
        """Build the collaborative world.

        NOTE(review): ``reset_world`` hard-codes agents[0]/agents[1] and
        landmarks[0..2], so only ``num_good_agents == 2`` actually works —
        confirm before passing other values. ``num_adversaries`` is unused.
        """
        world = World()
        # set any world properties first
        world.dim_c = 10
        world.collaborative = True  # whether agents share rewards
        # add agents
        world.agents = [Agent() for i in range(num_good_agents)]
        for i, agent in enumerate(world.agents):
            agent.name = 'agent %d' % i
            agent.collide = False
        # add landmarks
        world.landmarks = [Landmark() for i in range(num_good_agents+1)]
        for i, landmark in enumerate(world.landmarks):
            landmark.name = 'landmark %d' % i
            landmark.collide = False
            landmark.movable = False
        # make initial conditions
        self.reset_world(world)
        return world

    def reset_world(self, world):
        """Cross-assign goals (each agent's goal_a is the other agent),
        color landmarks, and randomize positions."""
        # assign goals to agents
        for agent in world.agents:
            agent.goal_a = None
            agent.goal_b = None
        # want other agent to go to the goal landmark
        world.agents[0].goal_a = world.agents[1]
        world.agents[0].goal_b = np.random.choice(world.landmarks)
        world.agents[1].goal_a = world.agents[0]
        world.agents[1].goal_b = np.random.choice(world.landmarks)
        # random properties for agents
        for i, agent in enumerate(world.agents):
            agent.color = np.array([0.25,0.25,0.25])
        # random properties for landmarks
        world.landmarks[0].color = np.array([0.75,0.25,0.25])
        world.landmarks[1].color = np.array([0.25,0.75,0.25])
        world.landmarks[2].color = np.array([0.25,0.25,0.75])
        # special colors for goals: paint each partner with its target color
        world.agents[0].goal_a.color = world.agents[0].goal_b.color
        world.agents[1].goal_a.color = world.agents[1].goal_b.color
        # set random initial states
        for agent in world.agents:
            agent.state.p_pos = np.random.uniform(-1,+1, world.dim_p)
            agent.state.p_vel = np.zeros(world.dim_p)
            agent.state.c = np.zeros(world.dim_c)
        for i, landmark in enumerate(world.landmarks):
            landmark.state.p_pos = np.random.uniform(-1,+1, world.dim_p)
            landmark.state.p_vel = np.zeros(world.dim_p)

    def reward(self, agent, world):
        """Negative squared distance from the partner (goal_a) to the
        landmark (goal_b) this agent wants the partner to reach."""
        if agent.goal_a is None or agent.goal_b is None:
            return 0.0
        dist2 = np.sum(np.square(agent.goal_a.state.p_pos - agent.goal_b.state.p_pos))
        return -dist2

    def observation(self, agent, world):
        """Own velocity, relative landmark positions, the partner's target
        color, and the other agents' communications."""
        # goal color
        goal_color = [np.zeros(world.dim_color), np.zeros(world.dim_color)]
        if agent.goal_b is not None:
            goal_color[1] = agent.goal_b.color
        # get positions of all entities in this agent's reference frame
        entity_pos = []
        for entity in world.landmarks:
            entity_pos.append(entity.state.p_pos - agent.state.p_pos)
        # entity colors
        entity_color = []
        for entity in world.landmarks:
            entity_color.append(entity.color)
        # communication of all other agents
        comm = []
        for other in world.agents:
            if other is agent: continue
            comm.append(other.state.c)
        return np.concatenate([agent.state.p_vel] + entity_pos + [goal_color[1]] + comm)
================================================
FILE: examples/Social_Cognition/FOToM/multiagent/scenarios/simple_speaker_listener.py
================================================
import numpy as np
from multiagent.core import World, Agent, Landmark
from multiagent.scenario import BaseScenario
class Scenario(BaseScenario):
    """Speaker/listener scenario.

    A stationary speaker observes the goal landmark's colour and must
    communicate it; a silent listener must navigate to that landmark.
    Both agents share the same (cooperative) reward.
    """

    def make_world(self):
        """Create a world with one speaker, one listener and three landmarks."""
        world = World()
        # set any world properties first
        world.dim_c = 3
        num_landmarks = 3
        world.collaborative = True
        # add agents
        world.agents = [Agent() for i in range(2)]
        for i, agent in enumerate(world.agents):
            agent.name = 'agent %d' % i
            agent.collide = False
            agent.size = 0.075
        # speaker: may talk but cannot move
        world.agents[0].movable = False
        # listener: may move but cannot talk
        world.agents[1].silent = True
        # add landmarks
        world.landmarks = [Landmark() for i in range(num_landmarks)]
        for i, landmark in enumerate(world.landmarks):
            landmark.name = 'landmark %d' % i
            landmark.collide = False
            landmark.movable = False
            landmark.size = 0.04
        # make initial conditions
        self.reset_world(world)
        return world

    def reset_world(self, world):
        """Pick a fresh goal landmark and randomise all positions."""
        # assign goals to agents
        for agent in world.agents:
            agent.goal_a = None
            agent.goal_b = None
        # want listener to go to the goal landmark
        world.agents[0].goal_a = world.agents[1]
        world.agents[0].goal_b = np.random.choice(world.landmarks)
        # random properties for agents
        for i, agent in enumerate(world.agents):
            agent.color = np.array([0.25, 0.25, 0.25])
        # random properties for landmarks
        world.landmarks[0].color = np.array([0.65, 0.15, 0.15])
        world.landmarks[1].color = np.array([0.15, 0.65, 0.15])
        world.landmarks[2].color = np.array([0.15, 0.15, 0.65])
        # special colors for goals: tint the listener with the goal landmark's colour
        world.agents[0].goal_a.color = world.agents[0].goal_b.color + np.array([0.45, 0.45, 0.45])
        # set random initial states
        for agent in world.agents:
            agent.state.p_pos = np.random.uniform(-1, +1, world.dim_p)
            agent.state.p_vel = np.zeros(world.dim_p)
            agent.state.c = np.zeros(world.dim_c)
        for i, landmark in enumerate(world.landmarks):
            landmark.state.p_pos = np.random.uniform(-1, +1, world.dim_p)
            landmark.state.p_vel = np.zeros(world.dim_p)

    def benchmark_data(self, agent, world):
        """Return data for benchmarking purposes (here simply the shared reward).

        Bug fix: the original passed the undefined name ``reward`` as the
        second argument, raising a NameError whenever benchmarking was on.
        """
        return self.reward(agent, world)

    def reward(self, agent, world):
        """Shared reward: negative squared distance from listener to goal landmark."""
        # squared distance from listener to landmark
        a = world.agents[0]
        dist2 = np.sum(np.square(a.goal_a.state.p_pos - a.goal_b.state.p_pos))
        return -dist2

    def observation(self, agent, world):
        """Speaker observes the goal colour; listener observes its velocity,
        landmark offsets and the speaker's communication.

        NOTE(review): an agent that is movable *and* not silent would fall
        through both branches and implicitly return None; in this scenario
        every agent matches exactly one branch.
        """
        # goal color
        goal_color = np.zeros(world.dim_color)
        if agent.goal_b is not None:
            goal_color = agent.goal_b.color
        # get positions of all entities in this agent's reference frame
        entity_pos = []
        for entity in world.landmarks:
            entity_pos.append(entity.state.p_pos - agent.state.p_pos)
        # communication of all other agents
        comm = []
        for other in world.agents:
            if other is agent or (other.state.c is None):
                continue
            comm.append(other.state.c)
        # speaker
        if not agent.movable:
            return np.concatenate([goal_color])
        # listener
        if agent.silent:
            return np.concatenate([agent.state.p_vel] + entity_pos + comm)
================================================
FILE: examples/Social_Cognition/FOToM/multiagent/scenarios/simple_spread.py
================================================
import numpy as np
from multiagent.core import World, Agent, Landmark
from multiagent.scenario import BaseScenario
class Scenario(BaseScenario):
    """Cooperative spread: agents must jointly cover the landmarks.

    The reward is shared (world.collaborative).  By default it is the
    sparse variant: +2 per distinct agent sitting on a landmark, minus an
    exponential penalty for leaving the arena.
    """

    def make_world(self, num_good_agents=2, num_adversaries=2):
        """Build a world with `num_good_agents` agents and one landmark each.

        NOTE(review): ``num_adversaries`` is accepted but never used here —
        presumably kept for signature parity with the competitive scenarios.
        """
        world = World()
        # set any world properties first
        world.dim_c = 2
        num_agents = num_good_agents
        num_landmarks = num_good_agents  # one landmark per agent
        world.collaborative = True
        # add agents
        world.agents = [Agent() for i in range(num_agents)]
        for i, agent in enumerate(world.agents):
            agent.name = 'agent %d' % i
            agent.collide = True
            agent.silent = True
            agent.size = 0.15
        # add landmarks (static, non-colliding targets)
        world.landmarks = [Landmark() for i in range(num_landmarks)]
        for i, landmark in enumerate(world.landmarks):
            landmark.name = 'landmark %d' % i
            landmark.collide = False
            landmark.movable = False
        # make initial conditions
        self.reset_world(world)
        return world

    def reset_world(self, world):
        """Re-colour all entities and randomise positions inside [-1, 1]^dim_p."""
        # random properties for agents
        for i, agent in enumerate(world.agents):
            agent.color = np.array([0.35, 0.35, 0.85])
        # random properties for landmarks
        for i, landmark in enumerate(world.landmarks):
            landmark.color = np.array([0.25, 0.25, 0.25])
        # set random initial states
        for agent in world.agents:
            agent.state.p_pos = np.random.uniform(-1, +1, world.dim_p)
            agent.state.p_vel = np.zeros(world.dim_p)
            agent.state.c = np.zeros(world.dim_c)
        for i, landmark in enumerate(world.landmarks):
            landmark.state.p_pos = np.random.uniform(-1, +1, world.dim_p)
            landmark.state.p_vel = np.zeros(world.dim_p)

    def benchmark_data(self, agent, world):
        """Return (reward, collisions, min_dists, occupied_landmarks) for logging.

        NOTE(review): the collision loop iterates over *all* agents including
        ``agent`` itself, and is_collision(agent, agent) is always True
        (distance 0 < summed sizes), so each call counts one self-"collision".
        The same quirk exists in upstream multiagent-particle-envs; left as-is.
        """
        rew = 0
        collisions = 0
        occupied_landmarks = 0
        min_dists = 0
        for l in world.landmarks:
            # distance from the closest agent to this landmark
            dists = [np.sqrt(np.sum(np.square(a.state.p_pos - l.state.p_pos))) for a in world.agents]
            min_dists += min(dists)
            rew -= min(dists)
            if min(dists) < 0.1:
                occupied_landmarks += 1
        if agent.collide:
            for a in world.agents:
                if self.is_collision(a, agent):
                    rew -= 1
                    collisions += 1
        return (rew, collisions, min_dists, occupied_landmarks)

    def is_collision(self, agent1, agent2):
        """True when the two entities overlap (centre distance < summed radii)."""
        delta_pos = agent1.state.p_pos - agent2.state.p_pos
        dist = np.sqrt(np.sum(np.square(delta_pos)))
        dist_min = agent1.size + agent2.size
        return True if dist < dist_min else False

    def reward(self, agent, world):
        # Agents are rewarded based on minimum agent distance to each landmark, penalized for collisions
        rew = 0
        shaped_reward = False  # hard-coded switch: dense (True) vs sparse (False) reward
        if shaped_reward:  # distance-based reward
            for l in world.landmarks:
                dists = [np.sqrt(np.sum(np.square(a.state.p_pos - l.state.p_pos))) for a in world.agents]
                rew -= min(dists)
            if agent.collide:
                for a in world.agents:
                    if self.is_collision(a, agent):
                        rew -= 1
            return rew
        else:
            # sparse variant: +2 for every distinct agent that occupies some landmark
            win_agents = []
            for land in world.landmarks:
                for a in world.agents:
                    if self.is_collision(a, land):
                        win_agents.append(a)
                        break
            rew += 2 * len(set(win_agents))

            def bound(x):
                # exponential out-of-bounds penalty; active only beyond |coord| > 1.0
                if x > 1.0:
                    return min(np.exp(2 * x - 2), 10)
                else:
                    return 0.0
            bound_rew = 0.0
            for p in range(world.dim_p):
                x = abs(agent.state.p_pos[p])
                bound_rew -= bound(x)
            rew += bound_rew
            return rew

    def observation(self, agent, world):
        """Observation: own velocity and position, others' communications,
        a zero velocity placeholder for self, landmark offsets and other
        agents' relative positions."""
        # get positions of all entities in this agent's reference frame
        entity_pos = []
        for entity in world.landmarks:  # world.entities:
            entity_pos.append(entity.state.p_pos - agent.state.p_pos)
        # entity colors (collected but not included in the returned vector)
        entity_color = []
        other_vel = []
        for entity in world.landmarks:  # world.entities:
            entity_color.append(entity.color)
        # communication of all other agents
        comm = []
        other_pos = []
        for other in world.agents:
            if other is agent:
                # placeholder keeps the observation length fixed
                other_vel.append([0, 0])
                continue
            comm.append(other.state.c)
            other_pos.append(other.state.p_pos - agent.state.p_pos)
        return np.concatenate([agent.state.p_vel] + [agent.state.p_pos] + comm +
                              other_vel + entity_pos + other_pos)
================================================
FILE: examples/Social_Cognition/FOToM/multiagent/scenarios/simple_tag.py
================================================
import numpy as np
from multiagent.core import World, Agent, Landmark
from multiagent.scenario import BaseScenario
class Scenario(BaseScenario):
    """Predator-prey ("simple tag"): larger, slower adversaries chase smaller,
    faster good agents around two static obstacle landmarks."""

    def make_world(self, num_good_agents=1, num_adversaries=3):
        """Create the tag world; the first ``num_adversaries`` agents are predators."""
        world = World()
        # set any world properties first
        world.dim_c = 2
        # no-op rebindings of the parameters (kept as in the original)
        num_good_agents = num_good_agents
        num_adversaries = num_adversaries
        num_agents = num_adversaries + num_good_agents
        num_landmarks = 2
        # add agents
        world.agents = [Agent() for i in range(num_agents)]
        for i, agent in enumerate(world.agents):
            agent.name = 'agent %d' % i
            agent.collide = True
            agent.silent = True
            agent.adversary = True if i < num_adversaries else False
            # predators are bigger and slower; prey smaller and faster
            agent.size = 0.075 if agent.adversary else 0.05
            agent.accel = 3.0 if agent.adversary else 4.0
            #agent.accel = 20.0 if agent.adversary else 25.0
            agent.max_speed = 1.0 if agent.adversary else 1.3
        # add landmarks (solid, immovable obstacles)
        world.landmarks = [Landmark() for i in range(num_landmarks)]
        for i, landmark in enumerate(world.landmarks):
            landmark.name = 'landmark %d' % i
            landmark.collide = True
            landmark.movable = False
            landmark.size = 0.2
            landmark.boundary = False
        # make initial conditions
        self.reset_world(world)
        return world

    def reset_world(self, world):
        """Colour agents by role and randomise all positions."""
        # random properties for agents: green = prey, red = predator
        for i, agent in enumerate(world.agents):
            agent.color = np.array([0.35, 0.85, 0.35]) if not agent.adversary else np.array([0.85, 0.35, 0.35])
        # random properties for landmarks
        for i, landmark in enumerate(world.landmarks):
            landmark.color = np.array([0.25, 0.25, 0.25])
        # set random initial states
        for agent in world.agents:
            agent.state.p_pos = np.random.uniform(-1, +1, world.dim_p)
            agent.state.p_vel = np.zeros(world.dim_p)
            agent.state.c = np.zeros(world.dim_c)
        for i, landmark in enumerate(world.landmarks):
            if not landmark.boundary:
                landmark.state.p_pos = np.random.uniform(-0.9, +0.9, world.dim_p)
                landmark.state.p_vel = np.zeros(world.dim_p)

    def benchmark_data(self, agent, world):
        """For predators, return how many prey they are currently touching; 0 for prey."""
        # returns data for benchmarking purposes
        if agent.adversary:
            collisions = 0
            for a in self.good_agents(world):
                if self.is_collision(a, agent):
                    collisions += 1
            return collisions
        else:
            return 0

    def is_collision(self, agent1, agent2):
        """True when the two entities overlap (centre distance < summed radii)."""
        delta_pos = agent1.state.p_pos - agent2.state.p_pos
        dist = np.sqrt(np.sum(np.square(delta_pos)))
        dist_min = agent1.size + agent2.size
        return True if dist < dist_min else False

    # return all agents that are not adversaries
    def good_agents(self, world):
        return [agent for agent in world.agents if not agent.adversary]

    # return all adversarial agents
    def adversaries(self, world):
        return [agent for agent in world.agents if agent.adversary]

    def reward(self, agent, world):
        """Dispatch to the role-specific reward function."""
        main_reward = self.adversary_reward(agent, world) if agent.adversary else self.agent_reward(agent, world)
        return main_reward

    def agent_reward(self, agent, world):
        """Prey reward: -10 per predator contact, plus an escalating
        out-of-bounds penalty so prey cannot flee off-screen."""
        # Agents are negatively rewarded if caught by adversaries
        rew = 0
        shape = False  # hard-coded: no distance shaping
        adversaries = self.adversaries(world)
        if shape:  # reward can optionally be shaped (increased reward for increased distance from adversary)
            for adv in adversaries:
                rew += 0.1 * np.sqrt(np.sum(np.square(agent.state.p_pos - adv.state.p_pos)))
        if agent.collide:
            for a in adversaries:
                if self.is_collision(a, agent):
                    rew -= 10

        # agents are penalized for exiting the screen, so that they can be caught by the adversaries
        def bound(x):
            if x < 0.9:
                return 0
            if x < 1.0:
                return (x - 0.9) * 10  # linear ramp in the border band
            return min(np.exp(2 * x - 2), 10)  # capped exponential outside
        for p in range(world.dim_p):
            x = abs(agent.state.p_pos[p])
            rew -= bound(x)
        return rew

    def adversary_reward(self, agent, world):
        """Predator reward: +10 for every (prey, predator) contact pair.

        NOTE(review): the reward counts contacts of *all* predators, not just
        ``agent`` — every predator shares the team's catches.
        """
        # Adversaries are rewarded for collisions with agents
        rew = 0
        shape = False  # hard-coded: no distance shaping
        agents = self.good_agents(world)
        adversaries = self.adversaries(world)
        if shape:  # reward can optionally be shaped (decreased reward for increased distance from agents)
            for adv in adversaries:
                rew -= 0.1 * min([np.sqrt(np.sum(np.square(a.state.p_pos - adv.state.p_pos))) for a in agents])
        if agent.collide:
            for ag in agents:
                for adv in adversaries:
                    if self.is_collision(ag, adv):
                        rew += 10
        return rew

    def observation(self, agent, world):
        """Observation: own vel/pos, others' velocities, landmark offsets and
        other agents' relative positions.

        NOTE(review): a prey agent skips itself but first appends its own
        velocity; a predator viewing itself falls through the guard and so
        contributes a zero relative position (and its comm) for itself.
        Verify this asymmetry is intended before changing it.
        """
        # get positions of all entities in this agent's reference frame
        entity_pos = []
        for entity in world.landmarks:
            if not entity.boundary:
                entity_pos.append(entity.state.p_pos - agent.state.p_pos)
        # communication of all other agents
        comm = []
        other_pos = []
        other_vel = []
        for other in world.agents:
            if other is agent and not other.adversary:
                other_vel.append(other.state.p_vel)
                continue
            comm.append(other.state.c)
            other_pos.append(other.state.p_pos - agent.state.p_pos)
            if not other.adversary:
                other_vel.append(other.state.p_vel)
        return np.concatenate([agent.state.p_vel] + [agent.state.p_pos] + other_vel +
                              entity_pos + other_pos)
================================================
FILE: examples/Social_Cognition/FOToM/multiagent/scenarios/simple_world_comm.py
================================================
import numpy as np
from multiagent.core import World, Agent, Landmark
from multiagent.scenario import BaseScenario
class Scenario(BaseScenario):
    """World-with-communication scenario: a leader adversary can broadcast to
    its (silent) pack while chasing good agents; food gives prey bonus reward
    and forests hide whoever stands inside them."""

    def make_world(self, num_good_agents=2, num_adversaries=4):
        """Create the world: agents (leader first), one obstacle, food and forests."""
        world = World()
        # set any world properties first
        world.dim_c = 4
        #world.damping = 1
        # no-op rebindings of the parameters (kept as in the original)
        num_good_agents = num_good_agents
        num_adversaries = num_adversaries
        num_agents = num_adversaries + num_good_agents
        num_landmarks = 1
        num_food = 2
        num_forests = 2
        # add agents
        world.agents = [Agent() for i in range(num_agents)]
        for i, agent in enumerate(world.agents):
            agent.name = 'agent %d' % i
            agent.collide = True
            # agent 0 is the leader and the only one allowed to communicate
            agent.leader = True if i == 0 else False
            agent.silent = True if i > 0 else False
            agent.adversary = True if i < num_adversaries else False
            agent.size = 0.075 if agent.adversary else 0.045
            agent.accel = 3.0 if agent.adversary else 4.0
            #agent.accel = 20.0 if agent.adversary else 25.0
            agent.max_speed = 1.0 if agent.adversary else 1.3
        # add landmarks (solid obstacle)
        world.landmarks = [Landmark() for i in range(num_landmarks)]
        for i, landmark in enumerate(world.landmarks):
            landmark.name = 'landmark %d' % i
            landmark.collide = True
            landmark.movable = False
            landmark.size = 0.2
            landmark.boundary = False
        # food: small, non-colliding pick-ups for the prey
        world.food = [Landmark() for i in range(num_food)]
        for i, landmark in enumerate(world.food):
            landmark.name = 'food %d' % i
            landmark.collide = False
            landmark.movable = False
            landmark.size = 0.03
            landmark.boundary = False
        # forests: large zones that hide agents standing inside them
        world.forests = [Landmark() for i in range(num_forests)]
        for i, landmark in enumerate(world.forests):
            landmark.name = 'forest %d' % i
            landmark.collide = False
            landmark.movable = False
            landmark.size = 0.3
            landmark.boundary = False
        world.landmarks += world.food
        world.landmarks += world.forests
        #world.landmarks += self.set_boundaries(world)  # world boundaries now penalized with negative reward
        # make initial conditions
        self.reset_world(world)
        return world

    def set_boundaries(self, world):
        """Build a ring of boundary landmarks around the arena (currently unused;
        boundaries are enforced through a reward penalty instead)."""
        boundary_list = []
        landmark_size = 1
        edge = 1 + landmark_size
        num_landmarks = int(edge * 2 / landmark_size)
        for x_pos in [-edge, edge]:
            for i in range(num_landmarks):
                l = Landmark()
                l.state.p_pos = np.array([x_pos, -1 + i * landmark_size])
                boundary_list.append(l)
        for y_pos in [-edge, edge]:
            for i in range(num_landmarks):
                l = Landmark()
                l.state.p_pos = np.array([-1 + i * landmark_size, y_pos])
                boundary_list.append(l)
        for i, l in enumerate(boundary_list):
            l.name = 'boundary %d' % i
            l.collide = True
            l.movable = False
            l.boundary = True
            l.color = np.array([0.75, 0.75, 0.75])
            l.size = landmark_size
            l.state.p_vel = np.zeros(world.dim_p)
        return boundary_list

    def reset_world(self, world):
        """Colour entities by role (leader darker) and randomise positions."""
        # random properties for agents: green = prey, red = predator
        for i, agent in enumerate(world.agents):
            agent.color = np.array([0.45, 0.95, 0.45]) if not agent.adversary else np.array([0.95, 0.45, 0.45])
            agent.color -= np.array([0.3, 0.3, 0.3]) if agent.leader else np.array([0, 0, 0])
        # random properties for landmarks
        for i, landmark in enumerate(world.landmarks):
            landmark.color = np.array([0.25, 0.25, 0.25])
        for i, landmark in enumerate(world.food):
            landmark.color = np.array([0.15, 0.15, 0.65])
        for i, landmark in enumerate(world.forests):
            landmark.color = np.array([0.6, 0.9, 0.6])
        # set random initial states
        for agent in world.agents:
            agent.state.p_pos = np.random.uniform(-1, +1, world.dim_p)
            agent.state.p_vel = np.zeros(world.dim_p)
            agent.state.c = np.zeros(world.dim_c)
        # NOTE: world.landmarks already contains food and forests, so those are
        # positioned twice; the later loops take effect
        for i, landmark in enumerate(world.landmarks):
            landmark.state.p_pos = np.random.uniform(-0.9, +0.9, world.dim_p)
            landmark.state.p_vel = np.zeros(world.dim_p)
        for i, landmark in enumerate(world.food):
            landmark.state.p_pos = np.random.uniform(-0.9, +0.9, world.dim_p)
            landmark.state.p_vel = np.zeros(world.dim_p)
        for i, landmark in enumerate(world.forests):
            landmark.state.p_pos = np.random.uniform(-0.9, +0.9, world.dim_p)
            landmark.state.p_vel = np.zeros(world.dim_p)

    def benchmark_data(self, agent, world):
        """For predators, return how many prey they are touching; 0 for prey."""
        if agent.adversary:
            collisions = 0
            for a in self.good_agents(world):
                if self.is_collision(a, agent):
                    collisions += 1
            return collisions
        else:
            return 0

    def is_collision(self, agent1, agent2):
        """True when the two entities overlap (centre distance < summed radii)."""
        delta_pos = agent1.state.p_pos - agent2.state.p_pos
        dist = np.sqrt(np.sum(np.square(delta_pos)))
        dist_min = agent1.size + agent2.size
        return True if dist < dist_min else False

    # return all agents that are not adversaries
    def good_agents(self, world):
        return [agent for agent in world.agents if not agent.adversary]

    # return all adversarial agents
    def adversaries(self, world):
        return [agent for agent in world.agents if agent.adversary]

    def reward(self, agent, world):
        """Dispatch to the role-specific reward function."""
        #boundary_reward = -10 if self.outside_boundary(agent) else 0
        main_reward = self.adversary_reward(agent, world) if agent.adversary else self.agent_reward(agent, world)
        return main_reward

    def outside_boundary(self, agent):
        """True when the agent has left the [-1, 1]^2 arena (helper, unused above)."""
        if agent.state.p_pos[0] > 1 or agent.state.p_pos[0] < -1 or agent.state.p_pos[1] > 1 or agent.state.p_pos[1] < -1:
            return True
        else:
            return False

    def agent_reward(self, agent, world):
        """Prey reward: -5 per predator contact, out-of-bounds penalty,
        +2 per touched food item plus a food-distance term.

        NOTE(review): the last term *adds* 0.05 x distance to the nearest food,
        i.e. it rewards being far from food, which looks inverted — but the
        same expression exists upstream in multiagent-particle-envs, so it is
        preserved here; verify intent before changing.
        """
        rew = 0
        shape = False  # hard-coded: no distance shaping
        adversaries = self.adversaries(world)
        if shape:
            for adv in adversaries:
                rew += 0.1 * np.sqrt(np.sum(np.square(agent.state.p_pos - adv.state.p_pos)))
        if agent.collide:
            for a in adversaries:
                if self.is_collision(a, agent):
                    rew -= 5

        def bound(x):
            # out-of-bounds penalty: linear ramp in [0.9, 1.0), capped exponential beyond
            if x < 0.9:
                return 0
            if x < 1.0:
                return (x - 0.9) * 10
            return min(np.exp(2 * x - 2), 10)  # 1 + (x - 1) * (x - 1)
        for p in range(world.dim_p):
            x = abs(agent.state.p_pos[p])
            rew -= 2 * bound(x)
        for food in world.food:
            if self.is_collision(agent, food):
                rew += 2
        rew += 0.05 * min([np.sqrt(np.sum(np.square(food.state.p_pos - agent.state.p_pos))) for food in world.food])
        return rew

    def adversary_reward(self, agent, world):
        """Predator reward: shaped by distance to the closest prey,
        +5 for every (prey, predator) contact pair (shared across the pack)."""
        rew = 0
        shape = True  # hard-coded: distance shaping enabled for predators
        agents = self.good_agents(world)
        adversaries = self.adversaries(world)
        if shape:
            rew -= 0.1 * min([np.sqrt(np.sum(np.square(a.state.p_pos - agent.state.p_pos))) for a in agents])
        if agent.collide:
            for ag in agents:
                for adv in adversaries:
                    if self.is_collision(ag, adv):
                        rew += 5
        return rew

    def observation2(self, agent, world):
        """Alternative full-visibility observation (no forest hiding); appears unused
        by the environment API, which calls observation() below."""
        # get positions of all entities in this agent's reference frame
        entity_pos = []
        for entity in world.landmarks:
            if not entity.boundary:
                entity_pos.append(entity.state.p_pos - agent.state.p_pos)
        food_pos = []
        for entity in world.food:
            if not entity.boundary:
                food_pos.append(entity.state.p_pos - agent.state.p_pos)
        # communication of all other agents
        comm = []
        other_pos = []
        other_vel = []
        for other in world.agents:
            if other is agent:
                continue
            comm.append(other.state.c)
            other_pos.append(other.state.p_pos - agent.state.p_pos)
            if not other.adversary:
                other_vel.append(other.state.p_vel)
        return np.concatenate([agent.state.p_vel] + [agent.state.p_pos] + entity_pos + other_pos + other_vel)

    def observation(self, agent, world):
        """Observation with forest-based partial observability.

        Other agents are visible only when observer and target share the same
        forest (or both are in the open), unless the observer is the leader,
        who sees everything.  Hidden agents contribute zeroed position and
        velocity placeholders so the vector length stays fixed.
        """
        # get positions of all entities in this agent's reference frame
        entity_pos = []
        for entity in world.landmarks:
            if not entity.boundary:
                entity_pos.append(entity.state.p_pos - agent.state.p_pos)
        # forest membership flags for the observing agent: +1 inside, -1 outside
        in_forest = [np.array([-1]), np.array([-1])]
        inf1 = False
        inf2 = False
        if self.is_collision(agent, world.forests[0]):
            in_forest[0] = np.array([1])
            inf1 = True
        if self.is_collision(agent, world.forests[1]):
            in_forest[1] = np.array([1])
            inf2 = True
        # food positions (collected but not part of the returned vector)
        food_pos = []
        for entity in world.food:
            if not entity.boundary:
                food_pos.append(entity.state.p_pos - agent.state.p_pos)
        # communication of all other agents
        comm = []
        other_pos = []
        other_vel = []
        for other in world.agents:
            # NOTE(review): a prey agent skips itself after appending its own
            # velocity; an adversary viewing itself falls through the guard and
            # is processed like any other agent (zero relative position)
            if other is agent and not other.adversary:
                other_vel.append(other.state.p_vel)  #
                continue
            comm.append(other.state.c)
            oth_f1 = self.is_collision(other, world.forests[0])
            oth_f2 = self.is_collision(other, world.forests[1])
            if (inf1 and oth_f1) or (inf2 and oth_f2) or \
               (not inf1 and not oth_f1 and not inf2 and not oth_f2) or \
               agent.leader:  # without forest vis
                other_pos.append(other.state.p_pos - agent.state.p_pos)
                if not other.adversary:
                    other_vel.append(other.state.p_vel)
            else:
                # target hidden by a forest: zero placeholders
                other_pos.append([0, 0])
                if not other.adversary:
                    other_vel.append([0, 0])
        # to tell the pred when the prey are in the forest
        # (computed but not included in any returned vector)
        prey_forest = []
        ga = self.good_agents(world)
        for a in ga:
            if any([self.is_collision(a, f) for f in world.forests]):
                prey_forest.append(np.array([1]))
            else:
                prey_forest.append(np.array([-1]))
        # to tell leader when pred are in forest (also unused in the returns)
        prey_forest_lead = []
        for f in world.forests:
            if any([self.is_collision(a, f) for a in ga]):
                prey_forest_lead.append(np.array([1]))
            else:
                prey_forest_lead.append(np.array([-1]))
        # NOTE(review): this overwrites the per-agent comm gathered above —
        # only the leader's channel is ever broadcast
        comm = [world.agents[0].state.c]
        if agent.adversary and not agent.leader:
            return np.concatenate(in_forest + comm +
                                  [agent.state.p_vel] + [agent.state.p_pos] + other_vel +
                                  entity_pos + other_pos)
        if agent.leader:
            return np.concatenate(in_forest + comm +
                                  [agent.state.p_vel] + [agent.state.p_pos] + other_vel +
                                  entity_pos + other_pos)
        else:
            # prey do not receive the leader's broadcast; zeros keep the shape
            return np.concatenate(in_forest + [np.zeros_like(world.agents[0].state.c)] +
                                  [agent.state.p_vel] + [agent.state.p_pos] + other_vel +
                                  entity_pos + other_pos)
================================================
FILE: examples/Social_Cognition/FOToM/readme.md
================================================
# Readme
This directory contains the code for the FOToM project.
================================================
FILE: examples/Social_Cognition/FOToM/utils/__init__.py
================================================
================================================
FILE: examples/Social_Cognition/FOToM/utils/agents.py
================================================
import torch
from torch import Tensor
from torch.autograd import Variable
from torch.optim import Adam
from .networks import MLPNetwork, RNN, SNNNetwork, LSTMClassifier
from .misc import hard_update, gumbel_softmax, onehot_from_logits
from .noise import OUNoise
import time
class DDPGAgent(object):
    """
    General class for DDPG agents (policy, critic, target policy, target
    critic, exploration noise).  All four networks are LSTMClassifier
    instances; the targets start as exact copies of the online networks.
    """
    def __init__(self, num_in_pol, num_out_pol, num_in_critic, hidden_dim=64,
                 lr=0.01, discrete_action=True):
        """
        Inputs:
            num_in_pol (int): number of dimensions for policy input
            num_out_pol (int): number of dimensions for policy output
            num_in_critic (int): number of dimensions for critic input
        """
        self.policy = LSTMClassifier(num_in_pol, num_out_pol, hidden_dim)
        self.critic = LSTMClassifier(num_in_critic, 1, hidden_dim)
        self.target_policy = LSTMClassifier(num_in_pol, num_out_pol, hidden_dim)
        self.target_critic = LSTMClassifier(num_in_critic, 1, hidden_dim)
        # synchronise target weights with the freshly built online networks
        hard_update(self.target_policy, self.policy)
        hard_update(self.target_critic, self.critic)
        self.policy_optimizer = Adam(self.policy.parameters(), lr=lr)
        self.critic_optimizer = Adam(self.critic.parameters(), lr=lr)
        # OU process for continuous actions; epsilon for eps-greedy otherwise
        self.exploration = 0.3 if discrete_action else OUNoise(num_out_pol)
        self.discrete_action = discrete_action

    def reset_noise(self):
        """Reset the OU noise process (no-op for discrete actions)."""
        if not self.discrete_action:
            self.exploration.reset()

    def scale_noise(self, scale):
        """Set the exploration magnitude (epsilon or OU noise scale)."""
        if self.discrete_action:
            self.exploration = scale
        else:
            self.exploration.scale = scale

    def step(self, obs, explore=False):
        """
        Take a step forward in environment for a minibatch of observations
        Inputs:
            obs (PyTorch Variable): Observations for this agent
            explore (boolean): Whether or not to add exploration noise
        Outputs:
            action (PyTorch Variable): Actions for this agent
        """
        raw = self.policy(obs)
        if not self.discrete_action:  # continuous action
            if explore:
                raw += Variable(Tensor(self.exploration.noise()),
                                requires_grad=False)
            return raw.clamp(-1, 1)
        # discrete: gumbel sample when exploring, argmax one-hot otherwise
        if explore:
            pick = lambda logits: gumbel_softmax(logits, hard=True)
        else:
            pick = onehot_from_logits
        if raw.shape[1] == 9:
            # 9-dim output is treated as two concatenated heads (5 + 4)
            return torch.cat((pick(raw[:, :5]), pick(raw[:, 5:])), 1)
        return pick(raw)

    def get_params(self):
        """Snapshot every network and optimizer state dict."""
        parts = {'policy': self.policy,
                 'critic': self.critic,
                 'target_policy': self.target_policy,
                 'target_critic': self.target_critic,
                 'policy_optimizer': self.policy_optimizer,
                 'critic_optimizer': self.critic_optimizer}
        return {name: obj.state_dict() for name, obj in parts.items()}

    def load_params(self, params):
        """Restore state produced by get_params()."""
        self.policy.load_state_dict(params['policy'])
        self.critic.load_state_dict(params['critic'])
        self.target_policy.load_state_dict(params['target_policy'])
        self.target_critic.load_state_dict(params['target_critic'])
        self.policy_optimizer.load_state_dict(params['policy_optimizer'])
        self.critic_optimizer.load_state_dict(params['critic_optimizer'])
class DDPGAgent_RNN(object):
    """
    General class for DDPG agents (policy, critic, target policy, target
    critic, exploration noise).  RNN variant: the networks are recurrent, so
    hidden states are carried between step() calls and (re)set via init_hidden().
    """
    def __init__(self, num_in_pol, num_out_pol, num_in_critic, hidden_dim=64,
                 lr=0.01, discrete_action=True):
        """
        Inputs:
            num_in_pol (int): number of dimensions for policy input
            num_out_pol (int): number of dimensions for policy output
            num_in_critic (int): number of dimensions for critic input
        """
        self.policy = RNN(num_in_pol, num_out_pol,
                          hidden_dim=hidden_dim,
                          constrain_out=True,
                          discrete_action=discrete_action)
        self.critic = RNN(num_in_critic, 1,
                          hidden_dim=hidden_dim,
                          constrain_out=False)
        self.target_policy = RNN(num_in_pol, num_out_pol,
                                 hidden_dim=hidden_dim,
                                 constrain_out=True,
                                 discrete_action=discrete_action)
        self.target_critic = RNN(num_in_critic, 1,
                                 hidden_dim=hidden_dim,
                                 constrain_out=False)
        # recurrent hidden states; None until init_hidden() is called
        self.policy_hidden = None
        self.policy_target_hidden = None
        self.critic_hidden = None
        self.critic_target_hidden = None
        # remembered dimensions for hidden-state (re)initialisation
        self.num_in_pol = num_in_pol
        self.num_out_pol = num_out_pol
        self.hidden_dim = hidden_dim
        # synchronise target weights with the online networks
        hard_update(self.target_policy, self.policy)
        hard_update(self.target_critic, self.critic)
        self.policy_optimizer = Adam(self.policy.parameters(), lr=lr)
        self.critic_optimizer = Adam(self.critic.parameters(), lr=lr)
        if not discrete_action:
            self.exploration = OUNoise(num_out_pol)
        else:
            self.exploration = 0.3  # epsilon for eps-greedy
        self.discrete_action = discrete_action

    def reset_noise(self):
        """Reset the OU noise process (no-op for discrete actions)."""
        if not self.discrete_action:
            self.exploration.reset()

    def scale_noise(self, scale):
        """Set the exploration magnitude (epsilon or OU noise scale)."""
        if self.discrete_action:
            self.exploration = scale
        else:
            self.exploration.scale = scale

    def step(self, obs, explore=False):
        """
        Take a step forward in environment for a minibatch of observations.
        Also advances self.policy_hidden as a side effect.
        Inputs:
            obs (PyTorch Variable): Observations for this agent
            explore (boolean): Whether or not to add exploration noise
        Outputs:
            action (PyTorch Variable): Actions for this agent
        """
        action, self.policy_hidden = self.policy(obs, self.policy_hidden)
        if self.discrete_action:
            if explore:
                action = gumbel_softmax(action, hard=True)
            else:
                action = onehot_from_logits(action)
        else:  # continuous action
            if explore:
                action += Variable(Tensor(self.exploration.noise()),
                                   requires_grad=False)
            action = action.clamp(-1, 1)
        return action

    def get_params(self):
        """Snapshot every network and optimizer state dict."""
        return {'policy': self.policy.state_dict(),
                'critic': self.critic.state_dict(),
                'target_policy': self.target_policy.state_dict(),
                'target_critic': self.target_critic.state_dict(),
                'policy_optimizer': self.policy_optimizer.state_dict(),
                'critic_optimizer': self.critic_optimizer.state_dict()}

    def load_params(self, params):
        """Restore state produced by get_params()."""
        self.policy.load_state_dict(params['policy'])
        self.critic.load_state_dict(params['critic'])
        self.target_policy.load_state_dict(params['target_policy'])
        self.target_critic.load_state_dict(params['target_critic'])
        self.policy_optimizer.load_state_dict(params['policy_optimizer'])
        self.critic_optimizer.load_state_dict(params['critic_optimizer'])

    def init_hidden(self, len_ep, policy_hidden=False, policy_target_hidden=False, \
                    critic_hidden=False, critic_target_hidden=False):
        # Initialise an eval_hidden / target_hidden for every agent in each episode
        # (zeros of shape (len_ep, hidden_dim)); each flag selects which state to reset
        if policy_hidden == True:
            self.policy_hidden = torch.zeros((len_ep, self.hidden_dim))
        if policy_target_hidden == True:
            self.policy_target_hidden = torch.zeros((len_ep, self.hidden_dim))
        if critic_hidden == True:
            self.critic_hidden = torch.zeros((len_ep, self.hidden_dim))
        if critic_target_hidden == True:
            self.critic_target_hidden = torch.zeros((len_ep, self.hidden_dim))
class DDPGAgent_SNN(object):
    """
    General class for DDPG agents (policy, critic, target policy, target
    critic, exploration noise).  SNN variant: all four networks are spiking
    neural networks (SNNNetwork) configured by `output_style`.
    """
    def __init__(self, num_in_pol, num_out_pol, num_in_critic, output_style, hidden_dim=64,
                 lr=0.01, discrete_action=True):
        """
        Inputs:
            num_in_pol (int): number of dimensions for policy input
            num_out_pol (int): number of dimensions for policy output
            num_in_critic (int): number of dimensions for critic input
            output_style: forwarded to SNNNetwork; controls how spikes are
                decoded into outputs (see networks.SNNNetwork)
        """
        self.policy = SNNNetwork(num_in_pol, num_out_pol,
                                 hidden_dim=hidden_dim,
                                 output_style=output_style)
        self.critic = SNNNetwork(num_in_critic, 1,
                                 hidden_dim=hidden_dim,
                                 output_style=output_style)
        self.target_policy = SNNNetwork(num_in_pol, num_out_pol,
                                        hidden_dim=hidden_dim,
                                        output_style=output_style)
        self.target_critic = SNNNetwork(num_in_critic, 1,
                                        hidden_dim=hidden_dim,
                                        output_style=output_style)
        # synchronise target weights with the online networks
        hard_update(self.target_policy, self.policy)
        hard_update(self.target_critic, self.critic)
        self.policy_optimizer = Adam(self.policy.parameters(), lr=lr)
        self.critic_optimizer = Adam(self.critic.parameters(), lr=lr)
        if not discrete_action:
            self.exploration = OUNoise(num_out_pol)
        else:
            self.exploration = 0.3  # epsilon for eps-greedy
        self.discrete_action = discrete_action

    def reset_noise(self):
        """Reset the OU noise process (no-op for discrete actions)."""
        if not self.discrete_action:
            self.exploration.reset()

    def scale_noise(self, scale):
        """Set the exploration magnitude (epsilon or OU noise scale)."""
        if self.discrete_action:
            self.exploration = scale
        else:
            self.exploration.scale = scale

    def step(self, obs, explore=False):
        """
        Take a step forward in environment for a minibatch of observations
        Inputs:
            obs (PyTorch Variable): Observations for this agent
            explore (boolean): Whether or not to add exploration noise
        Outputs:
            action (PyTorch Variable): Actions for this agent
        """
        action = self.policy(obs)
        if self.discrete_action:
            if explore:
                # a 9-dim output is treated as two concatenated heads (5 + 4),
                # each sampled independently
                if action.shape[1] == 9:
                    action = torch.cat(
                        (gumbel_softmax(action[:, :5], hard=True), gumbel_softmax(action[:, 5:], hard=True)), 1)
                else:
                    action = gumbel_softmax(action, hard=True)
            else:
                if action.shape[1] == 9:
                    action = torch.cat(
                        (onehot_from_logits(action[:, :5]), onehot_from_logits(action[:, 5:])), 1)
                else:
                    action = onehot_from_logits(action)
        else:  # continuous action
            if explore:
                action += Variable(Tensor(self.exploration.noise()),
                                   requires_grad=False)
            action = action.clamp(-1, 1)
        return action

    def get_params(self):
        """Snapshot every network and optimizer state dict."""
        return {'policy': self.policy.state_dict(),
                'critic': self.critic.state_dict(),
                'target_policy': self.target_policy.state_dict(),
                'target_critic': self.target_critic.state_dict(),
                'policy_optimizer': self.policy_optimizer.state_dict(),
                'critic_optimizer': self.critic_optimizer.state_dict()}

    def load_params(self, params):
        """Restore state produced by get_params()."""
        self.policy.load_state_dict(params['policy'])
        self.critic.load_state_dict(params['critic'])
        self.target_policy.load_state_dict(params['target_policy'])
        self.target_critic.load_state_dict(params['target_critic'])
        self.policy_optimizer.load_state_dict(params['policy_optimizer'])
        self.critic_optimizer.load_state_dict(params['critic_optimizer'])
class DDPGAgent_ToM(object):
    """
    DDPG agent with Theory-of-Mind components (policy, critic, target policy,
    target critic, exploration noise), backed by LSTM networks.
    """
    def __init__(self, num_in_pol, num_out_pol, num_in_critic, num_in_mle, output_style,
                 num_agents, device, hidden_dim=64, lr=0.01, discrete_action=True):
        """
        Inputs:
            num_in_pol (int): number of dimensions for policy input
            num_out_pol (int): number of dimensions for policy output
            num_in_critic (int): number of dimensions for critic input
            num_in_mle (int): input size for the (currently disabled) opponent
                models; kept for interface compatibility
            output_style: unused by the LSTM networks; kept for interface
                compatibility with the SNN variant
            num_agents (int): number of agents in the environment
            device: torch device used for policy inference in step()
            hidden_dim (int): hidden size of the LSTM networks
            lr (float): Adam learning rate for both policy and critic
            discrete_action (bool): discrete (one-hot) vs continuous actions
        """
        self.device = device
        self.policy = LSTMClassifier(num_in_pol, num_out_pol, hidden_dim)
        self.critic = LSTMClassifier(num_in_critic, 1, hidden_dim)
        self.target_policy = LSTMClassifier(num_in_pol, num_out_pol, hidden_dim)
        self.target_critic = LSTMClassifier(num_in_critic, 1, hidden_dim)
        # Opponent models (maximum-likelihood estimators of other agents'
        # policies) are currently disabled; kept as empty lists so the rest of
        # the training code can iterate over them safely.
        self.mle = []
        hard_update(self.target_policy, self.policy)
        hard_update(self.target_critic, self.critic)
        self.policy_optimizer = Adam(self.policy.parameters(), lr=lr)
        self.critic_optimizer = Adam(self.critic.parameters(), lr=lr)
        self.mle_optimizer = []
        if not discrete_action:
            self.exploration = OUNoise(num_out_pol)
        else:
            self.exploration = 0.3  # epsilon for eps-greedy
        self.discrete_action = discrete_action

    def reset_noise(self):
        """Reset the OU noise process (no-op for discrete actions)."""
        if not self.discrete_action:
            self.exploration.reset()

    def scale_noise(self, scale):
        """Set exploration magnitude: epsilon if discrete, OU scale otherwise."""
        if self.discrete_action:
            self.exploration = scale
        else:
            self.exploration.scale = scale

    def step(self, obs, explore=False):
        """
        Take a step forward in environment for a minibatch of observations.

        Inputs:
            obs (PyTorch Variable): Observations for this agent
            explore (boolean): Whether or not to add exploration noise
        Outputs:
            action (PyTorch Variable): Actions for this agent
        """
        action = self.policy.to(self.device)(obs.to(self.device))
        if self.discrete_action:
            if explore:
                if action.shape[1] == 9:
                    # Two logit heads (5 + 4): sample each independently.
                    action = torch.cat(
                        (gumbel_softmax(action[:, :5], hard=True),
                         gumbel_softmax(action[:, 5:], hard=True)), 1).cpu()
                else:
                    action = gumbel_softmax(action, hard=True).cpu()
            else:
                if action.shape[1] == 9:
                    # BUG FIX: onehot_from_logits has signature (logits, eps=0.0);
                    # the previous hard=True keyword raised a TypeError here.
                    # NOTE(review): unlike the sibling branches this result is not
                    # moved to CPU — confirm whether downstream code expects that.
                    action = torch.cat(
                        (onehot_from_logits(action[:, :5]),
                         onehot_from_logits(action[:, 5:])), 1)
                else:
                    action = onehot_from_logits(action).cpu()
        else:  # continuous action
            if explore:
                action += Variable(Tensor(self.exploration.noise()),
                                   requires_grad=False)
            action = action.clamp(-1, 1)
        return action

    def get_params(self):
        """Return state dicts for all networks and optimizers."""
        params = {'policy': self.policy.state_dict(),
                  'critic': self.critic.state_dict(),
                  'target_policy': self.target_policy.state_dict(),
                  'target_critic': self.target_critic.state_dict(),
                  'policy_optimizer': self.policy_optimizer.state_dict(),
                  'critic_optimizer': self.critic_optimizer.state_dict(),
                  }
        return params

    def load_params(self, params):
        """Restore all networks and optimizers from a get_params() snapshot."""
        self.policy.load_state_dict(params['policy'])
        self.critic.load_state_dict(params['critic'])
        self.target_policy.load_state_dict(params['target_policy'])
        self.target_critic.load_state_dict(params['target_critic'])
        self.policy_optimizer.load_state_dict(params['policy_optimizer'])
        self.critic_optimizer.load_state_dict(params['critic_optimizer'])
class rDDPGAgent_ToM(object):
    """
    Recurrent DDPG agent with Theory-of-Mind components (policy, critic,
    target policy, target critic, exploration noise), backed by RNN networks
    that carry explicit hidden state between steps.
    """
    def __init__(self, num_in_pol, num_out_pol, num_in_critic, num_in_mle, output_style,
                 num_agents, device, hidden_dim=64, lr=0.01, discrete_action=True):
        """
        Inputs:
            num_in_pol (int): number of dimensions for policy input
            num_out_pol (int): number of dimensions for policy output
            num_in_critic (int): number of dimensions for critic input
            num_in_mle (int): input size for the (currently disabled) opponent
                models; kept for interface compatibility
            output_style: unused by the RNN networks; kept for interface
                compatibility with the SNN variant
            num_agents (int): number of agents in the environment
            device: stored for callers; step() runs on the tensors it is given
            hidden_dim (int): RNN hidden size (also used by init_hidden)
            lr (float): Adam learning rate for both policy and critic
            discrete_action (bool): discrete (one-hot) vs continuous actions
        """
        self.device = device
        self.policy = RNN(num_in_pol, num_out_pol,
                          hidden_dim=hidden_dim,
                          constrain_out=True,
                          discrete_action=discrete_action)
        self.critic = RNN(num_in_critic, 1,
                          hidden_dim=hidden_dim,
                          constrain_out=True,
                          discrete_action=discrete_action)
        self.target_policy = RNN(num_in_pol, num_out_pol,
                                 hidden_dim=hidden_dim,
                                 constrain_out=True,
                                 discrete_action=discrete_action)
        self.target_critic = RNN(num_in_critic, 1,
                                 hidden_dim=hidden_dim,
                                 constrain_out=True,
                                 discrete_action=discrete_action)
        # Opponent models are currently disabled; kept as empty lists so the
        # rest of the training code can iterate over them safely.
        self.mle = []
        # Recurrent hidden states; populated by init_hidden() / step().
        self.policy_hidden = None
        self.policy_target_hidden = None
        self.critic_hidden = None
        self.critic_target_hidden = None
        self.hidden_dim = hidden_dim
        hard_update(self.target_policy, self.policy)
        hard_update(self.target_critic, self.critic)
        self.policy_optimizer = Adam(self.policy.parameters(), lr=lr)
        self.critic_optimizer = Adam(self.critic.parameters(), lr=lr)
        self.mle_optimizer = []
        if not discrete_action:
            self.exploration = OUNoise(num_out_pol)
        else:
            self.exploration = 0.3  # epsilon for eps-greedy
        self.discrete_action = discrete_action

    def reset_noise(self):
        """Reset the OU noise process (no-op for discrete actions)."""
        if not self.discrete_action:
            self.exploration.reset()

    def scale_noise(self, scale):
        """Set exploration magnitude: epsilon if discrete, OU scale otherwise."""
        if self.discrete_action:
            self.exploration = scale
        else:
            self.exploration.scale = scale

    def step(self, obs, explore=False):
        """
        Take a step forward in environment for a minibatch of observations,
        advancing the policy's recurrent hidden state.

        Inputs:
            obs (PyTorch Variable): Observations for this agent
            explore (boolean): Whether or not to add exploration noise
        Outputs:
            action (PyTorch Variable): Actions for this agent
        """
        action, self.policy_hidden = self.policy(obs, self.policy_hidden)
        if self.discrete_action:
            if explore:
                if action.shape[1] == 9:
                    # Two logit heads (5 + 4): sample each independently.
                    action = torch.cat(
                        (gumbel_softmax(action[:, :5], hard=True),
                         gumbel_softmax(action[:, 5:], hard=True)), 1).cpu()
                else:
                    action = gumbel_softmax(action, hard=True).cpu()
            else:
                if action.shape[1] == 9:
                    # BUG FIX: onehot_from_logits has signature (logits, eps=0.0);
                    # the previous hard=True keyword raised a TypeError here.
                    # NOTE(review): unlike the sibling branches this result is not
                    # moved to CPU — confirm whether downstream code expects that.
                    action = torch.cat(
                        (onehot_from_logits(action[:, :5]),
                         onehot_from_logits(action[:, 5:])), 1)
                else:
                    action = onehot_from_logits(action).cpu()
        else:  # continuous action
            if explore:
                action += Variable(Tensor(self.exploration.noise()),
                                   requires_grad=False)
            action = action.clamp(-1, 1)
        return action

    def get_params(self):
        """Return state dicts for all networks and optimizers."""
        params = {'policy': self.policy.state_dict(),
                  'critic': self.critic.state_dict(),
                  'target_policy': self.target_policy.state_dict(),
                  'target_critic': self.target_critic.state_dict(),
                  'policy_optimizer': self.policy_optimizer.state_dict(),
                  'critic_optimizer': self.critic_optimizer.state_dict(),
                  }
        return params

    def load_params(self, params):
        """Restore all networks and optimizers from a get_params() snapshot."""
        self.policy.load_state_dict(params['policy'])
        self.critic.load_state_dict(params['critic'])
        self.target_policy.load_state_dict(params['target_policy'])
        self.target_critic.load_state_dict(params['target_critic'])
        self.policy_optimizer.load_state_dict(params['policy_optimizer'])
        self.critic_optimizer.load_state_dict(params['critic_optimizer'])

    def init_hidden(self, len_ep, policy_hidden=False, policy_target_hidden=False, \
                    critic_hidden=False, critic_target_hidden=False):
        # Initialize zeroed eval/target hidden states for each requested
        # network at the start of an episode (one row per step in the episode).
        if policy_hidden == True:
            self.policy_hidden = torch.zeros((len_ep, self.hidden_dim))
        if policy_target_hidden == True:
            self.policy_target_hidden = torch.zeros((len_ep, self.hidden_dim))
        if critic_hidden == True:
            self.critic_hidden = torch.zeros((len_ep, self.hidden_dim))
        if critic_target_hidden == True:
            self.critic_target_hidden = torch.zeros((len_ep, self.hidden_dim))
class lDDPGAgent(object):
    """
    LSTM-based DDPG agent: a policy/critic pair, their target copies,
    per-network Adam optimizers, and an exploration source (eps-greedy for
    discrete actions, OU noise for continuous ones).
    """
    def __init__(self, num_in_pol, num_out_pol, num_in_critic, hidden_dim=64,
                 lr=0.01, discrete_action=True):
        """
        Inputs:
            num_in_pol (int): number of dimensions for policy input
            num_out_pol (int): number of dimensions for policy output
            num_in_critic (int): number of dimensions for critic input
            hidden_dim (int): hidden size of the LSTM networks
            lr (float): Adam learning rate for both policy and critic
            discrete_action (bool): discrete (one-hot) vs continuous actions
        """
        self.policy = LSTMClassifier(num_in_pol, num_out_pol,
                                     hidden_dim=hidden_dim,
                                     constrain_out=True,
                                     discrete_action=discrete_action)
        self.critic = LSTMClassifier(num_in_critic, 1,
                                     hidden_dim=hidden_dim,
                                     constrain_out=False)
        self.target_policy = LSTMClassifier(num_in_pol, num_out_pol,
                                            hidden_dim=hidden_dim,
                                            constrain_out=True,
                                            discrete_action=discrete_action)
        self.target_critic = LSTMClassifier(num_in_critic, 1,
                                            hidden_dim=hidden_dim,
                                            constrain_out=False)
        # Targets start as exact copies of the online networks.
        hard_update(self.target_policy, self.policy)
        hard_update(self.target_critic, self.critic)
        self.policy_optimizer = Adam(self.policy.parameters(), lr=lr)
        self.critic_optimizer = Adam(self.critic.parameters(), lr=lr)
        if discrete_action:
            self.exploration = 0.3  # epsilon for eps-greedy
        else:
            self.exploration = OUNoise(num_out_pol)
        self.discrete_action = discrete_action

    def reset_noise(self):
        """Reset the OU noise process (no-op for discrete actions)."""
        if not self.discrete_action:
            self.exploration.reset()

    def scale_noise(self, scale):
        """Set exploration magnitude: epsilon if discrete, OU scale otherwise."""
        if not self.discrete_action:
            self.exploration.scale = scale
        else:
            self.exploration = scale

    def step(self, obs, explore=False):
        """
        Take a step forward in environment for a minibatch of observations.

        Inputs:
            obs (PyTorch Variable): Observations for this agent
            explore (boolean): Whether or not to add exploration noise
        Outputs:
            action (PyTorch Variable): Actions for this agent
        """
        action = self.policy(obs)
        if not self.discrete_action:  # continuous action
            if explore:
                action += Variable(Tensor(self.exploration.noise()),
                                   requires_grad=False)
            return action.clamp(-1, 1)
        # Discrete actions: stochastic Gumbel-Softmax sample when exploring,
        # greedy one-hot otherwise.
        discretize = (lambda logits: gumbel_softmax(logits, hard=True)) if explore \
            else onehot_from_logits
        if action.shape[1] == 9:
            # Two logit heads (5 + 4): discretize each head independently.
            return torch.cat((discretize(action[:, :5]), discretize(action[:, 5:])), 1)
        return discretize(action)

    def get_params(self):
        """Snapshot the state dicts of every network and optimizer."""
        components = {'policy': self.policy,
                      'critic': self.critic,
                      'target_policy': self.target_policy,
                      'target_critic': self.target_critic,
                      'policy_optimizer': self.policy_optimizer,
                      'critic_optimizer': self.critic_optimizer}
        return {name: comp.state_dict() for name, comp in components.items()}

    def load_params(self, params):
        """Restore every network and optimizer from a get_params() snapshot."""
        for name in ('policy', 'critic', 'target_policy', 'target_critic',
                     'policy_optimizer', 'critic_optimizer'):
            getattr(self, name).load_state_dict(params[name])
================================================
FILE: examples/Social_Cognition/FOToM/utils/buffer.py
================================================
import numpy as np
import torch
from torch import Tensor
from torch.autograd import Variable
class ReplayBuffer(object):
    """
    Replay Buffer for multi-agent RL with parallel rollouts.

    Stores per-agent flat numpy arrays and behaves as a circular buffer:
    `curr_i` is the next write position and `filled_i` the number of valid
    entries (equal to `max_steps` once the buffer has wrapped).
    """
    def __init__(self, max_steps, num_agents, obs_dims, ac_dims, device):
        """
        Inputs:
            max_steps (int): Maximum number of timepoints to store in buffer
            num_agents (int): Number of agents in environment
            obs_dims (list of ints): number of observation dimensions for each
                                     agent
            ac_dims (list of ints): number of action dimensions for each agent
            device: torch device used when sampling with to_gpu=True
        """
        self.device = device
        self.max_steps = max_steps
        self.num_agents = num_agents
        # One buffer per agent: observations/actions are (step, dim),
        # rewards/dones are (step,).
        self.obs_buffs = []
        self.ac_buffs = []
        self.rew_buffs = []
        self.next_obs_buffs = []
        self.done_buffs = []
        for odim, adim in zip(obs_dims, ac_dims):
            self.obs_buffs.append(np.zeros((max_steps, odim)))
            self.ac_buffs.append(np.zeros((max_steps, adim)))
            self.rew_buffs.append(np.zeros(max_steps))
            self.next_obs_buffs.append(np.zeros((max_steps, odim)))
            self.done_buffs.append(np.zeros(max_steps))
        self.filled_i = 0  # index of first empty location in buffer (last index when full)
        self.curr_i = 0  # current index to write to (ovewrite oldest data)

    def __len__(self):
        # Number of valid (filled) entries currently stored.
        return self.filled_i

    def push(self, observations, actions, rewards, next_observations, dones):
        """
        Append one batch of transitions, one entry per parallel environment.

        observations / next_observations: (n_envs, n_agents) array of per-agent
        observation arrays; actions: indexed by agent first; rewards / dones:
        (n_envs, n_agents) arrays.
        """
        nentries = observations.shape[0]  # handle multiple parallel environments
        if self.curr_i + nentries > self.max_steps:
            rollover = self.max_steps - self.curr_i  # num of indices to roll over
            # The incoming batch would run past the end: shift every buffer so
            # the write can happen contiguously starting at index 0 (the
            # oldest `nentries - rollover` entries get overwritten below).
            for agent_i in range(self.num_agents):
                self.obs_buffs[agent_i] = np.roll(self.obs_buffs[agent_i],
                                                  rollover, axis=0)
                self.ac_buffs[agent_i] = np.roll(self.ac_buffs[agent_i],
                                                 rollover, axis=0)
                self.rew_buffs[agent_i] = np.roll(self.rew_buffs[agent_i],
                                                  rollover)
                self.next_obs_buffs[agent_i] = np.roll(
                    self.next_obs_buffs[agent_i], rollover, axis=0)
                self.done_buffs[agent_i] = np.roll(self.done_buffs[agent_i],
                                                   rollover)
            self.curr_i = 0
            self.filled_i = self.max_steps
        for agent_i in range(self.num_agents):
            self.obs_buffs[agent_i][self.curr_i:self.curr_i + nentries] = np.vstack(
                observations[:, agent_i])
            # actions are already batched by agent, so they are indexed differently
            self.ac_buffs[agent_i][self.curr_i:self.curr_i + nentries] = actions[agent_i]
            self.rew_buffs[agent_i][self.curr_i:self.curr_i + nentries] = rewards[:, agent_i]
            self.next_obs_buffs[agent_i][self.curr_i:self.curr_i + nentries] = np.vstack(
                next_observations[:, agent_i])
            self.done_buffs[agent_i][self.curr_i:self.curr_i + nentries] = dones[:, agent_i]
        self.curr_i += nentries
        if self.filled_i < self.max_steps:
            self.filled_i += nentries
        if self.curr_i == self.max_steps:
            self.curr_i = 0

    def sample(self, N, to_gpu=False, norm_rews=True):
        """
        Sample N transitions uniformly without replacement.

        Returns per-agent lists of (obs, actions, rewards, next_obs, dones)
        as torch tensors; rewards are z-normalized over the filled portion
        when norm_rews is True.
        """
        inds = np.random.choice(np.arange(self.filled_i), size=N,
                                replace=False)
        if to_gpu:
            cast = lambda x: Variable(Tensor(x), requires_grad=False).to(torch.device(self.device))
        else:
            cast = lambda x: Variable(Tensor(x), requires_grad=False)
        if norm_rews:
            # NOTE(review): if all stored rewards are identical, std() is 0 and
            # this division yields NaN/inf — confirm callers avoid that case.
            ret_rews = [cast((self.rew_buffs[i][inds] -
                              self.rew_buffs[i][:self.filled_i].mean()) /
                             self.rew_buffs[i][:self.filled_i].std())
                        for i in range(self.num_agents)]
        else:
            ret_rews = [cast(self.rew_buffs[i][inds]) for i in range(self.num_agents)]
        return ([cast(self.obs_buffs[i][inds]) for i in range(self.num_agents)],
                [cast(self.ac_buffs[i][inds]) for i in range(self.num_agents)],
                ret_rews,
                [cast(self.next_obs_buffs[i][inds]) for i in range(self.num_agents)],
                [cast(self.done_buffs[i][inds]) for i in range(self.num_agents)])

    def get_average_rewards(self, N):
        """Mean reward per agent over the most recent N stored steps."""
        if self.filled_i == self.max_steps:
            inds = np.arange(self.curr_i - N, self.curr_i)  # allow for negative indexing
        else:
            inds = np.arange(max(0, self.curr_i - N), self.curr_i)
        return [self.rew_buffs[i][inds].mean() for i in range(self.num_agents)]
class ReplayBuffer_pre(object):
    """
    Replay Buffer for multi-agent RL with parallel rollouts that additionally
    stores the previous action taken before each transition.

    Behaves as a circular buffer: `curr_i` is the next write position and
    `filled_i` the number of valid entries (== max_steps once wrapped).
    """
    def __init__(self, max_steps, num_agents, obs_dims, ac_dims, device,
                 ac_pre_dim=5):
        """
        Inputs:
            max_steps (int): Maximum number of timepoints to store in buffer
            num_agents (int): Number of agents in environment
            obs_dims (list of ints): number of observation dimensions for each
                                     agent
            ac_dims (list of ints): number of action dimensions for each agent
            device: torch device used when sampling with to_gpu=True
            ac_pre_dim (int): width of the stored previous-action vectors
                (defaults to 5, the historical hard-coded value)
        """
        self.device = device
        self.max_steps = max_steps
        self.num_agents = num_agents
        self.ac_pre_dim = ac_pre_dim
        self.ac_pre_buffs = []
        self.obs_buffs = []
        self.ac_buffs = []
        self.rew_buffs = []
        self.next_obs_buffs = []
        self.done_buffs = []
        for odim, adim in zip(obs_dims, ac_dims):
            self.ac_pre_buffs.append(np.zeros((max_steps, ac_pre_dim)))
            self.obs_buffs.append(np.zeros((max_steps, odim)))
            self.ac_buffs.append(np.zeros((max_steps, adim)))
            self.rew_buffs.append(np.zeros(max_steps))
            self.next_obs_buffs.append(np.zeros((max_steps, odim)))
            self.done_buffs.append(np.zeros(max_steps))
        self.filled_i = 0  # index of first empty location in buffer (last index when full)
        self.curr_i = 0  # current index to write to (ovewrite oldest data)

    def __len__(self):
        return self.filled_i

    def push(self, actions_pre, observations, actions, rewards, next_observations, dones):
        """
        Append one batch of transitions (one entry per parallel environment),
        including the previous actions (first ac_pre_dim columns are kept).
        """
        nentries = observations.shape[0]  # handle multiple parallel environments
        if self.curr_i + nentries > self.max_steps:
            rollover = self.max_steps - self.curr_i  # num of indices to roll over
            # Shift every buffer so the incoming batch can be written
            # contiguously starting at index 0.
            for agent_i in range(self.num_agents):
                self.ac_pre_buffs[agent_i] = np.roll(self.ac_pre_buffs[agent_i],
                                                     rollover, axis=0)
                self.obs_buffs[agent_i] = np.roll(self.obs_buffs[agent_i],
                                                  rollover, axis=0)
                self.ac_buffs[agent_i] = np.roll(self.ac_buffs[agent_i],
                                                 rollover, axis=0)
                self.rew_buffs[agent_i] = np.roll(self.rew_buffs[agent_i],
                                                  rollover)
                self.next_obs_buffs[agent_i] = np.roll(
                    self.next_obs_buffs[agent_i], rollover, axis=0)
                self.done_buffs[agent_i] = np.roll(self.done_buffs[agent_i],
                                                   rollover)
            self.curr_i = 0
            self.filled_i = self.max_steps
        for agent_i in range(self.num_agents):
            self.ac_pre_buffs[agent_i][self.curr_i:self.curr_i + nentries] = \
                actions_pre[agent_i][:, :self.ac_pre_dim]
            self.obs_buffs[agent_i][self.curr_i:self.curr_i + nentries] = np.vstack(
                observations[:, agent_i])
            # actions are already batched by agent, so they are indexed differently
            self.ac_buffs[agent_i][self.curr_i:self.curr_i + nentries] = actions[agent_i]
            self.rew_buffs[agent_i][self.curr_i:self.curr_i + nentries] = rewards[:, agent_i]
            self.next_obs_buffs[agent_i][self.curr_i:self.curr_i + nentries] = np.vstack(
                next_observations[:, agent_i])
            self.done_buffs[agent_i][self.curr_i:self.curr_i + nentries] = dones[:, agent_i]
        self.curr_i += nentries
        if self.filled_i < self.max_steps:
            self.filled_i += nentries
        if self.curr_i == self.max_steps:
            self.curr_i = 0

    def sample(self, N, to_gpu=False, norm_rews=True):
        """
        Sample N transitions uniformly without replacement.

        Returns per-agent lists of (prev_actions, obs, actions, rewards,
        next_obs, dones) as torch tensors.
        """
        inds = np.random.choice(np.arange(self.filled_i), size=N,
                                replace=False)
        if to_gpu:
            cast = lambda x: Variable(Tensor(x), requires_grad=False).to(torch.device(self.device))
        else:
            cast = lambda x: Variable(Tensor(x), requires_grad=False)
        # FIX: was `self.rew_buffs[0].sum() == False` — a float/bool comparison
        # that only worked because 0.0 == False. Skip normalization when agent
        # 0 has recorded no nonzero reward (avoids a 0/0 -> NaN below).
        if not np.any(self.rew_buffs[0]):
            norm_rews = False
        if norm_rews:
            ret_rews = [cast((self.rew_buffs[i][inds] -
                              self.rew_buffs[i][:self.filled_i].mean()) /
                             self.rew_buffs[i][:self.filled_i].std())
                        for i in range(self.num_agents)]
        else:
            ret_rews = [cast(self.rew_buffs[i][inds]) for i in range(self.num_agents)]
        return ([cast(self.ac_pre_buffs[i][inds]) for i in range(self.num_agents)],
                [cast(self.obs_buffs[i][inds]) for i in range(self.num_agents)],
                [cast(self.ac_buffs[i][inds]) for i in range(self.num_agents)],
                ret_rews,
                [cast(self.next_obs_buffs[i][inds]) for i in range(self.num_agents)],
                [cast(self.done_buffs[i][inds]) for i in range(self.num_agents)])

    def get_average_rewards(self, N):
        """Mean reward per agent over the most recent N stored steps."""
        if self.filled_i == self.max_steps:
            inds = np.arange(self.curr_i - N, self.curr_i)  # allow for negative indexing
        else:
            inds = np.arange(max(0, self.curr_i - N), self.curr_i)
        return [self.rew_buffs[i][inds].mean() for i in range(self.num_agents)]
class ReplayBuffer_RNN(object):
    """
    Replay Buffer for multi-agent RL with parallel rollouts that stores whole
    episodes (fixed length ep_dims) so recurrent networks can be trained on
    full sequences. Behaves as a circular buffer over episodes.
    """
    def __init__(self, max_steps, num_agents, obs_dims, ac_dims, ep_dims, device):
        """
        Inputs:
            max_steps (int): Maximum number of timepoints to store in buffer
            num_agents (int): Number of agents in environment
            obs_dims (list of ints): number of observation dimensions for each
                                     agent
            ac_dims (list of ints): number of action dimensions for each agent
            ep_dims (int): Number of steps in each episode
            device: torch device used when sampling with to_gpu=True
        """
        self.device = device
        self.max_steps = max_steps
        self.num_agents = num_agents
        # One buffer per agent: observations/actions are (episode, step, dim),
        # rewards/dones are (episode, step).
        self.obs_buffs = []
        self.ac_buffs = []
        self.rew_buffs = []
        self.next_obs_buffs = []
        self.done_buffs = []
        for odim, adim in zip(obs_dims, ac_dims):
            self.obs_buffs.append(np.zeros((max_steps, ep_dims, odim)))
            self.ac_buffs.append(np.zeros((max_steps, ep_dims, adim)))
            self.rew_buffs.append(np.zeros((max_steps, ep_dims)))
            self.next_obs_buffs.append(np.zeros((max_steps, ep_dims, odim)))
            self.done_buffs.append(np.zeros((max_steps, ep_dims)))
        self.filled_i = 0  # index of first empty location in buffer (last index when full)
        self.curr_i = 0  # current index to write to (ovewrite oldest data)

    def __len__(self):
        # Number of valid (filled) episodes currently stored.
        return self.filled_i

    def push(self, observations_ep, actions_ep, rewards_ep, next_observations_ep, dones_ep):
        """
        Append one batch of complete episodes.

        Inputs appear to be lists indexed by step first, then environment,
        then agent (they are transposed to (env, step, ...) before storage).
        NOTE(review): actions_ep is stored from index [:, :, 0, :] for every
        agent — confirm whether per-agent actions were intended here.
        """
        nentries = observations_ep[0].shape[0]  # handle multiple parallel environments
        observations_ep, actions_ep, rewards_ep, next_observations_ep, dones_ep = \
            np.array(observations_ep), np.array(actions_ep), np.array(rewards_ep),\
            np.array(next_observations_ep), np.array(dones_ep)
        if self.curr_i + nentries > self.max_steps:
            rollover = self.max_steps - self.curr_i  # num of indices to roll over
            # Shift every buffer so the incoming episodes can be written
            # contiguously starting at index 0.
            for agent_i in range(self.num_agents):
                self.obs_buffs[agent_i] = np.roll(self.obs_buffs[agent_i],
                                                  rollover, axis=0)
                self.ac_buffs[agent_i] = np.roll(self.ac_buffs[agent_i],
                                                 rollover, axis=0)
                self.rew_buffs[agent_i] = np.roll(self.rew_buffs[agent_i],
                                                  rollover)
                self.next_obs_buffs[agent_i] = np.roll(
                    self.next_obs_buffs[agent_i], rollover, axis=0)
                self.done_buffs[agent_i] = np.roll(self.done_buffs[agent_i],
                                                   rollover)
            self.curr_i = 0
            self.filled_i = self.max_steps
        for agent_i in range(self.num_agents):
            # Stack the per-step observation arrays of this agent into a
            # (step, env, odim) block, one step at a time.
            for i in range(observations_ep[:,:,agent_i].shape[0]):
                if i == 0:
                    ob_ep = np.expand_dims(np.vstack(observations_ep[:,:,agent_i][i]), 0)
                    ob_next_ep = np.expand_dims(np.vstack(next_observations_ep[:,:,agent_i][i]), 0)
                else:
                    ob_ep = np.vstack((ob_ep, np.expand_dims(np.vstack(observations_ep[:,:,agent_i][i]), 0)))
                    ob_next_ep = np.vstack((ob_next_ep, np.expand_dims(np.vstack(next_observations_ep[:,:,agent_i][i]), 0)))
            # Transpose (step, env, dim) -> (env/episode, step, dim) to match
            # the buffer layout.
            self.obs_buffs[agent_i][self.curr_i:self.curr_i + nentries, :] = ob_ep.transpose(1, 0, 2)
            # actions are already batched by agent, so they are indexed differently
            self.ac_buffs[agent_i][self.curr_i:self.curr_i + nentries, :] = actions_ep[:,:,0,:].transpose(1, 0, 2)
            self.rew_buffs[agent_i][self.curr_i:self.curr_i + nentries, :] = rewards_ep[:, :, agent_i].transpose(1, 0)
            self.next_obs_buffs[agent_i][self.curr_i:self.curr_i + nentries, :] = ob_next_ep.transpose(1, 0, 2)
            self.done_buffs[agent_i][self.curr_i:self.curr_i + nentries, :] = dones_ep[:, :, agent_i].transpose(1, 0)
        self.curr_i += nentries
        if self.filled_i < self.max_steps:
            self.filled_i += nentries
        if self.curr_i == self.max_steps:
            self.curr_i = 0

    def sample(self, N, to_gpu=False, norm_rews=True):
        """
        Sample N whole episodes uniformly without replacement.

        Returns per-agent lists of (obs, actions, rewards, next_obs, dones)
        tensors with a leading (episode, step) layout.
        """
        inds = np.random.choice(np.arange(self.filled_i), size=N,
                                replace=False)
        if to_gpu:
            cast = lambda x: Variable(Tensor(x), requires_grad=False).to(torch.device(self.device))
        else:
            cast = lambda x: Variable(Tensor(x), requires_grad=False)
        if norm_rews:
            # NOTE(review): std() over the filled portion can be 0 when all
            # rewards are identical, producing NaNs — confirm callers avoid it.
            ret_rews = [cast((self.rew_buffs[i][inds] -
                              self.rew_buffs[i][:self.filled_i].mean()) /
                             self.rew_buffs[i][:self.filled_i].std())
                        for i in range(self.num_agents)]
        else:
            ret_rews = [cast(self.rew_buffs[i][inds]) for i in range(self.num_agents)]
        return ([cast(self.obs_buffs[i][inds]) for i in range(self.num_agents)],
                [cast(self.ac_buffs[i][inds]) for i in range(self.num_agents)],
                ret_rews,
                [cast(self.next_obs_buffs[i][inds]) for i in range(self.num_agents)],
                [cast(self.done_buffs[i][inds]) for i in range(self.num_agents)])

    def get_average_rewards(self, N):
        """Mean reward per agent over the most recent N stored episodes."""
        if self.filled_i == self.max_steps:
            inds = np.arange(self.curr_i - N, self.curr_i)  # allow for negative indexing
        else:
            inds = np.arange(max(0, self.curr_i - N), self.curr_i)
        return [self.rew_buffs[i][inds].mean() for i in range(self.num_agents)]
================================================
FILE: examples/Social_Cognition/FOToM/utils/env_wrappers.py
================================================
"""
Modified from OpenAI Baselines code to work with multi-agent envs
"""
import numpy as np
from multiprocessing import Process, Pipe
from common.vec_env.vec_env import VecEnv, CloudpickleWrapper
def worker(remote, parent_remote, env_fn_wrapper):
    """
    Subprocess loop: build one environment and serve commands from a Pipe.

    Recognized commands: 'step', 'reset', 'reset_task', 'close',
    'get_spaces', 'get_agent_types', 'get_num_landmarks'. Runs until a
    'close' command arrives.
    """
    parent_remote.close()
    env = env_fn_wrapper.x()
    while True:
        cmd, payload = remote.recv()
        if cmd == 'step':
            ob, reward, done, info = env.step(payload)
            if all(done):
                # Episode over for every agent: start a fresh one so the
                # vectorized env never returns a terminal observation twice.
                ob = env.reset()
            remote.send((ob, reward, done, info))
        elif cmd == 'reset':
            remote.send(env.reset())
        elif cmd == 'reset_task':
            remote.send(env.reset_task())
        elif cmd == 'close':
            remote.close()
            break
        elif cmd == 'get_spaces':
            remote.send((env.observation_space, env.action_space))
        elif cmd == 'get_agent_types':
            if all(hasattr(a, 'adversary') for a in env.agents):
                types = ['adversary' if a.adversary else 'agent'
                         for a in env.agents]
            else:
                types = ['agent' for _ in env.agents]
            remote.send(types)
        elif cmd == 'get_num_landmarks':
            remote.send(len(env.world.landmarks))
        else:
            raise NotImplementedError
class SubprocVecEnv(VecEnv):
    """
    Vectorized environment that runs each sub-environment in its own
    subprocess, communicating over Pipes via the `worker` command loop.
    """
    def __init__(self, env_fns, spaces=None):
        """
        envs: list of gym environments to run in subprocesses
        """
        self.waiting = False
        self.closed = False
        nenvs = len(env_fns)
        # One Pipe per env: `remotes` stay in this process, `work_remotes`
        # are handed to the workers.
        self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])
        self.ps = [Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
                   for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]
        for p in self.ps:
            p.daemon = True  # if the main process crashes, we should not cause things to hang
            p.start()
        for remote in self.work_remotes:
            # Close the worker ends in the parent; the child keeps its copy.
            remote.close()
        # Query static metadata from the first worker (assumed identical
        # across all envs).
        self.remotes[0].send(('get_spaces', None))
        observation_space, action_space = self.remotes[0].recv()
        self.remotes[0].send(('get_agent_types', None))
        self.agent_types = self.remotes[0].recv()
        self.remotes[0].send(('get_num_landmarks', None))
        self.num_lm = self.remotes[0].recv()
        VecEnv.__init__(self, len(env_fns), observation_space, action_space)

    def step_async(self, actions):
        # Dispatch one action per worker without blocking on the replies.
        for remote, action in zip(self.remotes, actions):
            remote.send(('step', action))
        self.waiting = True

    def step_wait(self):
        # Collect the replies issued by step_async, in worker order.
        results = [remote.recv() for remote in self.remotes]
        self.waiting = False
        obs, rews, dones, infos = zip(*results)
        return np.stack(obs), np.stack(rews), np.stack(dones), infos

    def reset(self):
        """Reset every sub-environment and stack the initial observations."""
        for remote in self.remotes:
            remote.send(('reset', None))
        return np.stack([remote.recv() for remote in self.remotes])

    def reset_task(self):
        """Reset the task in every sub-environment."""
        for remote in self.remotes:
            remote.send(('reset_task', None))
        return np.stack([remote.recv() for remote in self.remotes])

    def close(self):
        """Shut down all workers; safe to call more than once."""
        if self.closed:
            return
        if self.waiting:
            # Drain any outstanding step replies before closing.
            for remote in self.remotes:
                remote.recv()
        for remote in self.remotes:
            remote.send(('close', None))
        for p in self.ps:
            p.join()
        self.closed = True
class DummyVecEnv(VecEnv):
    """
    In-process vectorized environment: steps each sub-environment
    sequentially in the current process (no subprocesses, no pipes).
    """
    def __init__(self, env_fns):
        self.envs = [fn() for fn in env_fns]
        probe = self.envs[0]
        VecEnv.__init__(self, len(env_fns), probe.observation_space, probe.action_space)
        if all(hasattr(a, 'adversary') for a in probe.agents):
            self.agent_types = ['adversary' if a.adversary else 'agent'
                                for a in probe.agents]
        else:
            self.agent_types = ['agent' for _ in probe.agents]
        self.ts = np.zeros(len(self.envs), dtype='int')  # per-env step counters
        self.actions = None

    def step_async(self, actions):
        # Sequential backend: just stash the actions until step_wait().
        self.actions = actions

    def step_wait(self):
        results = [env.step(act) for (act, env) in zip(self.actions, self.envs)]
        obs, rews, dones, infos = map(np.array, zip(*results))
        self.ts += 1
        for i, done in enumerate(dones):
            if all(done):
                # Episode over for every agent in env i: restart it.
                obs[i] = self.envs[i].reset()
                self.ts[i] = 0
        self.actions = None
        return np.array(obs), np.array(rews), np.array(dones), infos

    def reset(self):
        """Reset every sub-environment and return the stacked observations."""
        return np.array([env.reset() for env in self.envs])

    def close(self):
        """Nothing to release for in-process environments."""
        return
================================================
FILE: examples/Social_Cognition/FOToM/utils/make_env.py
================================================
"""
Code for creating a multiagent environment with one of the scenarios listed
in ./scenarios/.
Can be called by using, for example:
env = make_env('simple_speaker_listener')
After producing the env object, can be used similarly to an OpenAI gym
environment.
A policy using this environment must output actions in the form of a list
for all agents. Each element of the list should be a numpy array,
of size (env.world.dim_p + env.world.dim_c, 1). Physical actions precede
communication actions in this array. See environment.py for more details.
"""
def make_env(scenario_name, num_good_agents, num_adversaries, benchmark=False, discrete_action=False):
    '''
    Creates a MultiAgentEnv object as env. This can be used similar to a gym
    environment by calling env.reset() and env.step().
    Use env.render() to view the environment on the screen.
    Input:
        scenario_name   :   name of the scenario from ./scenarios/ to be used
                            (without the .py extension)
        benchmark       :   whether you want to produce benchmarking data
                            (usually only done during evaluation)
    Some useful env properties (see environment.py):
        .observation_space  :   Returns the observation space for each agent
        .action_space       :   Returns the action space for each agent
        .n                  :   Returns the number of Agents
    '''
    from multiagent.environment import MultiAgentEnv
    import multiagent.scenarios as scenarios

    # Instantiate the scenario class defined in ./scenarios/<name>.py.
    scenario = scenarios.load(scenario_name + ".py").Scenario()
    # Build the world with the requested team sizes.
    world = scenario.make_world(num_good_agents, num_adversaries)
    # Assemble the environment callbacks; benchmark adds a data callback.
    callbacks = [scenario.reset_world, scenario.reward, scenario.observation]
    if benchmark:
        callbacks.append(scenario.benchmark_data)
    return MultiAgentEnv(world, *callbacks)
================================================
FILE: examples/Social_Cognition/FOToM/utils/misc.py
================================================
import os
import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch.autograd import Variable
import numpy as np
# https://github.com/ikostrikov/pytorch-ddpg-naf/blob/master/ddpg.py#L11
def soft_update(target, source, tau):
    """
    Polyak soft update: move each target parameter toward the corresponding
    source parameter by interpolation factor tau (in place).

    Inputs:
        target (torch.nn.Module): Net to copy parameters to
        source (torch.nn.Module): Net whose parameters to copy
        tau (float, 0 < x < 1): Weight factor for update (1.0 copies source)
    """
    for tgt_param, src_param in zip(target.parameters(), source.parameters()):
        blended = tgt_param.data * (1.0 - tau) + src_param.data * tau
        tgt_param.data.copy_(blended)
# https://github.com/ikostrikov/pytorch-ddpg-naf/blob/master/ddpg.py#L15
def hard_update(target, source):
    """
    Copy network parameters from source into target (in place).

    Inputs:
        target (torch.nn.Module): Net to copy parameters to
        source (torch.nn.Module): Net whose parameters to copy
    """
    for tgt_param, src_param in zip(target.parameters(), source.parameters()):
        tgt_param.data.copy_(src_param.data)
# https://github.com/seba-1511/dist_tuto.pth/blob/gh-pages/train_dist.py
def average_gradients(model):
    """
    Average the gradients of `model` across all distributed workers, in place.

    Requires torch.distributed to be initialized. Each parameter's gradient is
    summed over the default (world) process group and divided by the world
    size, leaving every rank with the same averaged gradient.
    """
    size = float(dist.get_world_size())
    for param in model.parameters():
        # FIX: `dist.reduce_op` is a long-deprecated alias removed in recent
        # PyTorch; `dist.ReduceOp` is the supported spelling. The old
        # `group=0` id-based argument is replaced by the default group, which
        # is the world group — same semantics.
        dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
        param.grad.data /= size
# https://github.com/seba-1511/dist_tuto.pth/blob/gh-pages/train_dist.py
def init_processes(rank, size, fn, backend='gloo'):
    """
    Initialize the distributed environment for this rank, then run `fn`.

    Sets a fixed localhost rendezvous address/port, joins the process group
    with the given backend, and finally calls fn(rank, size).
    """
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size)
def onehot_from_logits(logits, eps=0.0):
    """
    Convert a batch of logits to one-hot actions, epsilon-greedily.

    With probability 1 - eps each row is the greedy one-hot (ties mark every
    maximal entry); with probability eps it is a uniformly random one-hot.
    """
    # Greedy one-hot: mark every entry equal to the row maximum.
    max_per_row = logits.max(1, keepdim=True)[0]
    greedy_acs = (logits == max_per_row).float()
    if eps == 0.0:
        return greedy_acs
    batch, n_actions = logits.shape[0], logits.shape[1]
    # Uniformly random one-hot rows.
    rand_acs = Variable(torch.eye(n_actions)[[np.random.choice(
        range(n_actions), size=batch)]], requires_grad=False)
    # Epsilon-greedy mix: keep the greedy row unless its draw falls below eps.
    return torch.stack([greedy_acs[i] if r > eps else rand_acs[i]
                        for i, r in enumerate(torch.rand(batch))])
# modified for PyTorch from https://github.com/ericjang/gumbel-softmax/blob/master/Categorical%20VAE.ipynb
def sample_gumbel(shape, eps=1e-20, tens_type=torch.FloatTensor):
    """Sample a tensor of the given shape from Gumbel(0, 1).

    ``eps`` guards both logs against log(0).
    """
    uniform = Variable(tens_type(*shape).uniform_(), requires_grad=False)
    # Inverse-CDF transform of a uniform sample.
    gumbel = -torch.log(-torch.log(uniform + eps) + eps)
    return gumbel
# modified for PyTorch from https://github.com/ericjang/gumbel-softmax/blob/master/Categorical%20VAE.ipynb
def gumbel_softmax_sample(logits, temperature):
    """Draw one sample from the Gumbel-Softmax distribution over ``logits``."""
    noise = sample_gumbel(logits.shape, tens_type=type(logits.data)).to(logits.device)
    perturbed = logits + noise
    return F.softmax(perturbed / temperature, dim=1)
# modified for PyTorch from https://github.com/ericjang/gumbel-softmax/blob/master/Categorical%20VAE.ipynb
def gumbel_softmax(logits, temperature=1.0, hard=False):
    """Sample from the Gumbel-Softmax distribution, optionally discretized.

    Args:
        logits: [batch_size, n_class] unnormalized log-probs
        temperature: non-negative scalar controlling sample sharpness
        hard: if True, return a one-hot sample but keep gradients w.r.t.
            the underlying soft sample (straight-through estimator)
    Returns:
        [batch_size, n_class] tensor; one-hot if ``hard`` else a probability
        distribution summing to 1 across classes.
    """
    soft = gumbel_softmax_sample(logits, temperature)
    if not hard:
        return soft
    onehot = onehot_from_logits(soft)
    # Straight-through: forward value is one-hot, gradient flows through soft.
    return (onehot - soft).detach() + soft
================================================
FILE: examples/Social_Cognition/FOToM/utils/multiprocessing.py
================================================
# This code is from openai baseline
# https://github.com/openai/baselines/tree/master/baselines/common/vec_env
import time
import matplotlib.pyplot as plt
import numpy as np
from multiprocessing import Process, Pipe
def _flatten_list(l):
assert isinstance(l, (list, tuple))
assert len(l) > 0
assert all([len(l_) > 0 for l_ in l])
return [l__ for l_ in l for l__ in l_]
def worker(remote, parent_remote, env_fn_wrapper):
    """Subprocess loop: build the env, then serve commands from ``remote``.

    Runs until a ``'close'`` command arrives. ``env_fn_wrapper`` is a
    CloudpickleWrapper holding the env factory in ``.x``.
    """
    # The parent keeps its own copy of this pipe end; close ours.
    parent_remote.close()
    env = env_fn_wrapper.x()
    while True:
        cmd, payload = remote.recv()
        if cmd == 'step':
            ob, reward, done, info = env.step(payload)
            # Auto-reset so callers always receive a valid observation.
            if done:
                ob = env.reset()
            remote.send((ob, reward, done, info))
        elif cmd == 'reset':
            remote.send(env.reset())
        elif cmd == 'reset_task':
            remote.send(env.reset_task())
        elif cmd == 'world':
            remote.send(env.world)
        elif cmd == 'render':
            remote.send(env.render(mode='rgb_array'))
        elif cmd == 'observe':
            remote.send(env.observe(payload))
        elif cmd == 'agents':
            remote.send(env.agents)
        elif cmd == 'spec':
            remote.send(env.spec)
        elif cmd == 'get_spaces':
            remote.send((env.observation_space, env.action_space))
        elif cmd == 'close':
            remote.close()
            break
        else:
            raise NotImplementedError
class VecEnv(object):
    """Abstract base for asynchronous, vectorized environments.

    Subclasses implement ``step_async``/``step_wait`` (and optionally
    ``get_images``); ``step`` combines the pair synchronously.
    """
    closed = False
    viewer = None
    metadata = {
        'render.modes': ['human', 'rgb_array']
    }

    def __init__(self, num_envs, observation_space, action_space):
        self.num_envs = num_envs
        self.observation_space = observation_space
        self.action_space = action_space

    def observe(self, agent):
        pass

    def reset(self):
        """Reset all environments and return their observations.

        If step_async is still doing work, that work is cancelled and
        step_wait() should not be called until step_async() is invoked again.
        """
        pass

    def step_async(self, actions):
        """Start a step in every environment with the given actions.

        Call step_wait() for the results. Do not call while a previous
        step_async is still pending.
        """
        pass

    def step_wait(self):
        """Block until the step launched by step_async() finishes.

        Returns (obs, rews, dones, infos): arrays of observations, rewards,
        "episode done" booleans, and a sequence of info objects.
        """
        pass

    def close(self):
        """Release the environments' resources."""
        pass

    def step(self, actions):
        # Synchronous convenience wrapper around the async pair.
        self.step_async(actions)
        return self.step_wait()

    def render(self, mode='human'):
        frames = self.get_images()
        composite = self.tile_images(frames)
        if mode == 'rgb_array':
            return composite
        if mode == 'human':
            self.get_viewer().imshow(composite)
            return self.get_viewer().isopen
        raise NotImplementedError

    def get_images(self):
        """Return an RGB image from each environment."""
        raise NotImplementedError

    def get_viewer(self):
        # Lazily created so headless use never imports the renderer.
        if self.viewer is None:
            from common import rendering
            self.viewer = rendering.SimpleImageViewer()
        return self.viewer

    def tile_images(self, img_nhwc):
        """Tile N images into one big PxQ mosaic.

        (P, Q) are chosen as close as possible; if N is square, P == Q.
        input: img_nhwc, list or array of images, ndim=4 once turned into
        an array (n = batch index, h = height, w = width, c = channel).
        returns: ndarray with ndim=3 (H*h, W*w, c).
        """
        imgs = np.asarray(img_nhwc)
        n, h, w, c = imgs.shape
        rows = int(np.ceil(np.sqrt(n)))
        cols = int(np.ceil(float(n) / rows))
        # Pad with black frames so the grid is exactly rows*cols.
        padding = [imgs[0] * 0 for _ in range(n, rows * cols)]
        grid = np.array(list(imgs) + padding)
        grid = grid.reshape(rows, cols, h, w, c)
        grid = grid.transpose(0, 2, 1, 3, 4)
        return grid.reshape(rows * h, cols * w, c)
class CloudpickleWrapper(object):
    """Serialize the wrapped object with cloudpickle instead of pickle.

    multiprocessing pickles arguments with the stdlib pickler, which cannot
    handle closures/lambdas; routing __getstate__ through cloudpickle can.
    """

    def __init__(self, x):
        self.x = x

    def __getstate__(self):
        import cloudpickle
        return cloudpickle.dumps(self.x)

    def __setstate__(self, state):
        # cloudpickle output is readable by the plain stdlib unpickler.
        import pickle
        self.x = pickle.loads(state)
class SubprocVecEnv(VecEnv):
    """VecEnv that runs each environment in its own subprocess.

    One ``worker`` process per env, connected by a ``multiprocessing.Pipe``;
    all commands are (cmd, payload) tuples sent over the parent's pipe ends.
    """
    def __init__(self, env_fns, spaces=None):
        """
        Args:
            env_fns: list of zero-arg env factories, one per subprocess.
            spaces: unused; kept for call compatibility.
        """
        # self.venv = venv
        self.waiting = False
        self.closed = False
        nenvs = len(env_fns)
        self.nenvs = nenvs
        # remotes: parent's pipe ends; work_remotes: ends handed to workers.
        self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])
        self.ps = [Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
                   for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]
        for p in self.ps:
            p.daemon = True  # if the main process crashes, we should not cause things to hang
            p.start()
        # Close the parent's copies of the worker-side pipe ends; the workers
        # keep their own (ordering matters for clean EOF detection).
        for remote in self.work_remotes:
            remote.close()
        # Query one worker for the (shared) spaces.
        self.remotes[0].send(('get_spaces', None))
        observation_space, action_space = self.remotes[0].recv()
        VecEnv.__init__(self, len(env_fns), observation_space, action_space)

    def step_async(self, actions):
        # Dispatch one action per subprocess; results are collected in step_wait.
        for remote, action in zip(self.remotes, actions):
            remote.send(('step', action))
        self.waiting = True

    def step_wait(self):
        results = [remote.recv() for remote in self.remotes]
        self.waiting = False
        obs, rews, dones, infos = zip(*results)
        return np.stack(obs), np.stack(rews), np.stack(dones), infos

    def step_wait_2(self):
        # Variant for envs whose step() returns (reward, done, cumulative_rewards).
        results = [remote.recv() for remote in self.remotes]
        self.waiting = False
        reward, done, _cumulative_rewards = zip(*results)
        return reward, done, _cumulative_rewards

    def step_wait_3(self):
        # NOTE(review): identical to step_wait; presumably kept for API symmetry.
        results = [remote.recv() for remote in self.remotes]
        self.waiting = False
        obs, rews, dones, infos = zip(*results)
        return np.stack(obs), np.stack(rews), np.stack(dones), infos

    def reset(self):
        for remote in self.remotes:
            remote.send(('reset', None))
        return np.stack([remote.recv() for remote in self.remotes])

    def agents(self):
        for remote in self.remotes:
            remote.send(('agents', None))
        return np.stack([remote.recv() for remote in self.remotes])

    def world(self):
        for remote in self.remotes:
            remote.send(('world', None))
        return np.stack([remote.recv() for remote in self.remotes])

    def reset_task(self):
        for remote in self.remotes:
            remote.send(('reset_task', None))
        return np.stack([remote.recv() for remote in self.remotes])

    def spec(self):
        for remote in self.remotes:
            remote.send(('spec', None))
        return np.stack([remote.recv() for remote in self.remotes])

    def get_images(self):
        # self._assert_not_closed()
        for pipe in self.remotes:
            pipe.send(('render', None))
        imgs = [pipe.recv() for pipe in self.remotes]
        # imgs = _flatten_list(imgs)
        return imgs

    def observe(self, agent):
        # NOTE(review): the loop variable shadows the ``agent`` parameter; it
        # works because zip pairs each remote with one element of the input
        # sequence, but renaming one of them would be clearer.
        for remote, agent in zip(self.remotes, agent):
            remote.send(('observe', agent))
        return np.stack([remote.recv() for remote in self.remotes])

    # def render(self, mode='human'):
    #     return self.venv.render(mode=mode)

    def close(self):
        if self.closed:
            return
        # Drain any pending step results so workers aren't blocked on send.
        if self.waiting:
            for remote in self.remotes:
                remote.recv()
        for remote in self.remotes:
            remote.send(('close', None))
        for p in self.ps:
            p.join()
        self.closed = True

    def __len__(self):
        return self.nenvs
def _flatten_list(l):
assert isinstance(l, (list, tuple))
assert len(l) > 0
assert all([len(l_) > 0 for l_ in l])
return [l__ for l_ in l for l__ in l_]
class DummyVecEnv(VecEnv):
    """In-process VecEnv: steps every environment serially, no subprocesses."""

    def __init__(self, env_fns):
        self.envs = [fn() for fn in env_fns]
        env = self.envs[0]
        VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space)
        # Label agents by adversary flag when the scenario defines one.
        if all(hasattr(a, 'adversary') for a in env.agents):
            self.agent_types = ['adversary' if a.adversary else 'agent'
                                for a in env.agents]
        else:
            self.agent_types = ['agent' for _ in env.agents]
        self.ts = np.zeros(len(self.envs), dtype='int')  # per-env step counters
        self.actions = None

    def step_async(self, actions):
        # No true asynchrony; just stash the actions for step_wait.
        self.actions = actions

    def step_wait(self):
        results = [env.step(act) for (act, env) in zip(self.actions, self.envs)]
        obs, rews, dones, infos = map(np.array, zip(*results))
        self.ts += 1
        for idx, done in enumerate(dones):
            # Auto-reset an env once every one of its agents is done.
            if all(done):
                obs[idx] = self.envs[idx].reset()
                self.ts[idx] = 0
        self.actions = None
        return np.array(obs), np.array(rews), np.array(dones), infos

    def reset(self):
        return np.array([env.reset() for env in self.envs])

    def close(self):
        return
================================================
FILE: examples/Social_Cognition/FOToM/utils/networks.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
class MLPNetwork(nn.Module):
    """Three-layer MLP, usable as a value or policy network."""

    def __init__(self, input_dim, out_dim, hidden_dim=64, nonlin=F.relu,
                 constrain_out=False, norm_in=True, discrete_action=True):
        """
        Inputs:
            input_dim (int): Number of dimensions in input
            out_dim (int): Number of dimensions in output
            hidden_dim (int): Number of hidden dimensions
            nonlin (PyTorch function): Nonlinearity to apply to hidden layers
            constrain_out (bool): squash continuous-action outputs with tanh
            norm_in (bool): batch-normalize inputs
            discrete_action (bool): logits output (softmaxed by the caller)
        """
        super(MLPNetwork, self).__init__()
        if norm_in:
            # Input batch-norm, initialized to the identity transform.
            self.in_fn = nn.BatchNorm1d(input_dim)
            self.in_fn.weight.data.fill_(1)
            self.in_fn.bias.data.fill_(0)
        else:
            self.in_fn = lambda x: x
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, out_dim)
        self.nonlin = nonlin
        if constrain_out and not discrete_action:
            # Small init keeps tanh away from saturation early in training.
            self.fc3.weight.data.uniform_(-3e-3, 3e-3)
            self.out_fn = F.tanh
        else:
            # Raw logits for discrete actions (softmax applied by the caller).
            self.out_fn = lambda x: x

    def forward(self, X):
        """
        Inputs:
            X (PyTorch Matrix): Batch of observations
        Outputs:
            out (PyTorch Matrix): Output of network (actions, values, etc)
        """
        hidden = self.nonlin(self.fc1(self.in_fn(X)))
        hidden = self.nonlin(self.fc2(hidden))
        return self.out_fn(self.fc3(hidden))
class RNN(nn.Module):
    """GRUCell-based recurrent network sharing MLPNetwork's constructor API.

    Because all the agents share the same network, the caller concatenates
    obs (+ last action + agent id) into ``input_dim``.
    """

    def __init__(self, input_dim, out_dim, hidden_dim=64, nonlin=F.relu,
                 constrain_out=False, norm_in=True, discrete_action=True):
        super(RNN, self).__init__()
        self.rnn_hidden_dim = hidden_dim
        if norm_in:
            # Input batch-norm, initialized to the identity transform.
            # NOTE(review): forward() never applies in_fn — kept unchanged to
            # preserve existing behavior; confirm whether that is intended.
            self.in_fn = nn.BatchNorm1d(input_dim)
            self.in_fn.weight.data.fill_(1)
            self.in_fn.bias.data.fill_(0)
        else:
            self.in_fn = lambda x: x
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.rnn = nn.GRUCell(hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self.nonlin = nonlin
        if constrain_out and not discrete_action:
            # BUG FIX: this branch referenced the nonexistent ``self.fc3``
            # (copied from MLPNetwork) and crashed; the output layer here
            # is ``fc2``. Small init prevents early tanh saturation.
            self.fc2.weight.data.uniform_(-3e-3, 3e-3)
            self.out_fn = torch.tanh
        else:
            # Raw logits for discrete actions (softmax applied by the caller).
            self.out_fn = lambda x: x

    def forward(self, obs, hidden_state):
        """One recurrent step.

        Inputs:
            obs: [batch, input_dim] observation batch
            hidden_state: [batch, hidden_dim] previous GRU hidden state
        Returns:
            (q, h): output of the readout layer and the new hidden state.
        """
        x = self.nonlin(self.fc1(obs))
        h = self.rnn(x, hidden_state)
        # FIX: apply out_fn, which the original computed but never used;
        # it is the identity unless constrain_out was requested (a path
        # that previously crashed in __init__).
        q = self.out_fn(self.fc2(h))
        return q, h
class BCNoSpikingLIFNode(LIFNode):
    """LIF neuron that integrates input but never fires.

    ``forward`` returns the raw membrane potential instead of spikes, so it
    can serve as a continuous-valued readout layer for an SNN.

    NOTE(review): ``LIFNode`` is not imported in this file's visible import
    block — presumably it comes from braincog's node module; verify.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, dv: torch.Tensor):
        # Accumulate the input current via the inherited integration step,
        # then expose the membrane potential directly — no threshold/reset.
        self.integral(dv)
        return self.mem
class SNNNetwork(nn.Module):
    """Three-layer spiking network (can be used as value, policy, or MLE).

    Runs the input through two spiking layers for ``time_window`` steps and
    reads out either the time-averaged sum of outputs ('sum') or the final
    membrane potential of a non-spiking LIF node ('voltage').
    """

    def __init__(self, input_dim, out_dim, hidden_dim=64, node=LIFNode, time_window=16,
                 norm_in=True, output_style='sum'):
        """
        Inputs:
            input_dim (int): Number of dimensions in input
            out_dim (int): Number of dimensions in output
            hidden_dim (int): Number of hidden dimensions
            node: spiking neuron class for the two hidden layers
            time_window (int): number of simulation steps per forward pass
            norm_in (bool): batch-normalize inputs
            output_style (str): 'sum' or 'voltage' readout
        """
        super(SNNNetwork, self).__init__()
        self._threshold = 0.5
        self.v_reset = 0.0
        self._time_window = time_window
        self.output_style = output_style
        self._node1 = node(threshold=self._threshold, v_reset=self.v_reset)
        self._node2 = node(threshold=self._threshold, v_reset=self.v_reset)
        if norm_in:
            # Input batch-norm, initialized to the identity transform.
            self.in_fn = nn.BatchNorm1d(input_dim)
            self.in_fn.weight.data.fill_(1)
            self.in_fn.bias.data.fill_(0)
        else:
            self.in_fn = lambda x: x
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, out_dim)
        if self.output_style == 'sum':
            self._out_node = lambda x: x
        elif self.output_style == 'voltage':
            self._out_node = BCNoSpikingLIFNode()

    def reset(self):
        """Reset every stateful neuron module before a new simulation."""
        for mod in self.modules():
            if hasattr(mod, 'n_reset'):
                mod.n_reset()

    def forward(self, X):
        """Simulate ``time_window`` steps on a constant input batch."""
        self.reset()
        readouts = []
        step_out = None
        for _ in range(self._time_window):
            # +0.5 bias shifts normalized inputs into the neurons' range.
            cur = self.fc1(self.in_fn(X) + 0.5)
            cur = self._node1(cur)
            cur = self.fc2(cur)
            cur = self._node2(cur)
            step_out = self._out_node(self.fc3(cur))
            readouts.append(step_out)
        if self.output_style == 'sum':
            return sum(readouts) / self._time_window
        elif self.output_style == 'voltage':
            # Final membrane potential of the non-spiking readout node.
            return step_out
class LSTMClassifier(nn.Module):
    """Single-layer LSTM followed by a linear readout.

    Each input vector is treated as a length-1 sequence, so this behaves
    like an MLP with LSTM-gated hidden dynamics.
    """

    def __init__(self, input_size, output_size, hidden_size=256):
        super(LSTMClassifier, self).__init__()
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
                            num_layers=1, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # [batch, features] -> [batch, seq_len=1, features] for batch_first LSTM.
        seq = x.unsqueeze(1)
        lstm_out, _ = self.lstm(seq)
        # Drop the singleton time dimension before the readout.
        return self.fc(lstm_out.squeeze(1))
================================================
FILE: examples/Social_Cognition/FOToM/utils/noise.py
================================================
import numpy as np
# from https://github.com/songrotek/DDPG/blob/master/ou_noise.py
class OUNoise:
    """Ornstein-Uhlenbeck process noise for continuous-action exploration.

    State drifts toward ``mu`` at rate ``theta`` with Gaussian diffusion of
    magnitude ``sigma``; ``noise()`` returns the state scaled by ``scale``.
    """

    def __init__(self, action_dimension, scale=0.1, mu=0, theta=0.15, sigma=0.2):
        self.action_dimension = action_dimension
        self.scale = scale
        self.mu = mu
        self.theta = theta
        self.sigma = sigma
        self.state = np.ones(self.action_dimension) * self.mu
        self.reset()

    def reset(self):
        """Re-center the process at its mean."""
        self.state = np.ones(self.action_dimension) * self.mu

    def noise(self):
        """Advance the process one step and return the scaled state."""
        drift = self.theta * (self.mu - self.state)
        diffusion = self.sigma * np.random.randn(len(self.state))
        self.state = self.state + drift + diffusion
        return self.state * self.scale
================================================
FILE: examples/Social_Cognition/Intention_Prediction/Intention_Prediction.py
================================================
import numpy as np
import torch,os,sys
from torch import nn
from torch.nn import Parameter
import abc
import math
from abc import ABC
import numpy as np
import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
import matplotlib.pyplot as plt
from BrainCog.base.strategy.surrogate import *
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import random
from BrainCog.base.node.node import *
from BrainCog.base.learningrule.STDP import MutliInputSTDP
class CustomLinear(nn.Module):
    """Element-wise weighted connection with a manually updatable weight."""

    def __init__(self, weight, mask=None):
        super().__init__()
        self.weight = nn.Parameter(weight, requires_grad=True)
        # Optional multiplicative mask restricting which weights may change.
        self.mask = mask

    def forward(self, x: torch.Tensor):
        # Hadamard (element-wise) product, NOT a matrix multiply.
        return x.mul(self.weight)  # Changed

    def update(self, dw):
        """Apply a weight delta in place, gated by the mask if present."""
        with torch.no_grad():
            if self.mask is not None:
                dw *= self.mask
            self.weight.data += dw
class DLPFCNet(nn.Module):
    """DLPFC->BG pathway built from two Izhikevich populations.

    ``connection[0]`` maps external input to DLPFC, ``connection[1]`` maps
    DLPFC to BG; forward returns a (pruned) BG spike map.
    """
    def __init__(self,connection):
        super().__init__()
        # DLPFC, BG
        self.node = []
        self.node.append(IzhNodeMU(threshold=30., a=0.02, b=0.60, c=-65., d=8., mem=-70.))
        self.node.append(IzhNodeMU(threshold=30., a=0.02, b=0.60, c=-65., d=8., mem=-70.))
        self.learning_rule = []
        self.connection = connection
        self.out_DLPFC=torch.zeros((self.connection[0].weight.shape[1]), dtype=torch.float) # Input-DLPFC
        self.out_BG=torch.zeros((self.connection[1].weight.shape[1]), dtype=torch.float) # DLPFC-BG

    def forward(self, input):
        # Drive DLPFC from the external input, then BG from DLPFC output.
        self.out_DLPFC=self.node[0](self.connection[0](input))
        self.out_BG=self.node[1](self.connection[1](self.out_DLPFC))
        BG_Spike = self.node[1].spike
        # Winner-take-all pruning: if several BG neurons fired, keep only the
        # argmax spike. Assumes a square num_neuron x num_neuron spike map —
        # the row/col are recovered from the flattened argmax index.
        if sum(sum(BG_Spike)).item() > 1:
            num_neuron = len(BG_Spike)
            BG_Spike_index = torch.argmax(BG_Spike)
            BG_Spike_index_x = torch.floor(BG_Spike_index/num_neuron)
            BG_Spike_index_y = BG_Spike_index - BG_Spike_index_x*num_neuron
            BG_Spike = torch.zeros([num_neuron, num_neuron], dtype=torch.float)
            BG_Spike[BG_Spike_index_x.long()][BG_Spike_index_y.long()] = 1
        return BG_Spike

    def reset(self):
        # Clear neuron membrane state and any learning-rule traces.
        for i in range(len(self.node)):
            self.node[i].n_reset()
        for i in range(len(self.learning_rule)):
            self.learning_rule[i].reset()

    def UpdateWeight(self, i, W):
        # Overwrite connection i's weights (used to share W learned elsewhere).
        self.connection[i].weight.data = W
class OFCNet(nn.Module):
    """Orbitofrontal-cortex circuit: OFC_1/OFC_2 inputs, MOFC/LOFC outputs.

    Reward sign selects which STDP-plastic pathway drives LOFC:
    positive reward routes MOFC->LOFC, negative reward routes OFC_1->LOFC.
    """
    def __init__(self,connection):
        super().__init__()
        # OFC, MOFC, LOFC
        self.node = []
        self.node.append(IzhNodeMU(threshold=30., a=0.02, b=0.60, c=-65., d=8., mem=-70.)) # OFC_1
        self.node.append(IzhNodeMU(threshold=30., a=0.02, b=0.60, c=-65., d=8., mem=-70.)) # OFC_2
        self.node.append(IzhNodeMU(threshold=30., a=0.02, b=0.60, c=-65., d=8., mem=-70.)) # MOFC
        self.node.append(IzhNodeMU(threshold=30., a=0.02, b=0.60, c=-65., d=8., mem=-70.)) # LOFC
        self.connection = connection
        self.learning_rule = []
        # Two plastic input pairs onto LOFC (node[3]); which is used per step
        # depends on the reward sign in forward().
        self.learning_rule.append(MutliInputSTDP(self.node[3], [self.connection[3],self.connection[4]])) # OFC_2-LOFC, MOFC-LOFC
        self.learning_rule.append(MutliInputSTDP(self.node[3], [self.connection[3],self.connection[5]])) # OFC_2-LOFC, OFC_1-LOFC
        self.out_OFC_1=torch.zeros((self.connection[0].weight.shape[1]), dtype=torch.float)
        self.out_OFC_2=torch.zeros((self.connection[1].weight.shape[1]), dtype=torch.float)
        self.out_MOFC=torch.zeros((self.connection[2].weight.shape[1]), dtype=torch.float)
        self.out_LOFC=torch.zeros((self.connection[5].weight.shape[1]), dtype=torch.float)

    def forward(self, Input_Tha, Input_SNc, Reward):
        """One step: thalamic + dopaminergic drive, reward-gated routing.

        Returns (MOFC_Spike, LOFC_Spike). The STDP weight delta ``dw_lofc``
        is computed but intentionally not applied here.
        """
        self.out_OFC_1 = self.node[0](self.connection[0](Input_Tha))
        self.out_OFC_2 = self.node[1](self.connection[1](Input_SNc))
        if Reward == 1:
            # Positive reward: MOFC is driven and co-activates LOFC with OFC_2.
            self.out_MOFC = self.node[2](self.connection[2](self.out_OFC_1))
            self.out_LOFC, dw_lofc = self.learning_rule[0](self.out_OFC_2, self.out_MOFC)
        else:
            # Negative reward: MOFC is silenced (zero drive); OFC_1 pairs with
            # OFC_2 onto LOFC instead.
            self.out_MOFC = self.node[2](self.connection[2](self.out_OFC_1*0))
            self.out_LOFC, dw_lofc = self.learning_rule[1](self.out_OFC_2, self.out_OFC_1)
        MOFC_Spike = self.node[2].spike
        LOFC_Spike = self.node[3].spike
        return MOFC_Spike, LOFC_Spike

    def reset(self):
        # Clear neuron membrane state and STDP traces between episodes.
        for i in range(len(self.node)):
            self.node[i].n_reset()
        for i in range(len(self.learning_rule)):
            self.learning_rule[i].reset()
class BGNet(nn.Module):
    """Basal-ganglia circuit: DLPFC plus StrD1/StrD2 striatal populations.

    The three inputs drive DLPFC, StrD1 and StrD2 respectively through
    ``connection[0..2]``; forward returns the DLPFC spikes and the combined
    striatal spikes.
    """

    def __init__(self, connection):
        super().__init__()
        # Izhikevich populations with distinct (a, b) dynamics per region.
        self.node = [
            IzhNodeMU(threshold=30., a=0.02, b=0.60, c=-65., d=8., mem=-70.),  # DLPFC
            IzhNodeMU(threshold=30., a=0.01, b=0.01, c=-65., d=8., mem=-70.),  # StrD1
            IzhNodeMU(threshold=30., a=0.1, b=0.5, c=-65., d=8., mem=-70.),    # StrD2
        ]
        self.connection = connection
        self.learning_rule = []
        self.out_DLPFC = torch.zeros((self.connection[0].weight.shape[1]), dtype=torch.float)
        self.out_StrD1 = torch.zeros((self.connection[1].weight.shape[1]), dtype=torch.float)
        self.out_StrD2 = torch.zeros((self.connection[2].weight.shape[1]), dtype=torch.float)

    def forward(self, input1, input2, input3):
        """One step; returns (DLPFC spikes, StrD1 + StrD2 spikes)."""
        self.out_DLPFC = self.node[0](self.connection[0](input1))
        self.out_StrD1 = self.node[1](self.connection[1](input2))
        self.out_StrD2 = self.node[2](self.connection[2](input3))
        return self.node[0].spike, self.node[1].spike + self.node[2].spike

    def reset(self):
        """Clear neuron state and any learning-rule traces."""
        for neuron in self.node:
            neuron.n_reset()
        for rule in self.learning_rule:
            rule.reset()

    def UpdateWeight(self, i, W):
        """Overwrite connection ``i``'s weight matrix with ``W``."""
        self.connection[i].weight.data = W
def STDP(Pre_mat, Post_mat, W):
    """Update weight matrix ``W`` by pair-based STDP over spike histories.

    ``Pre_mat`` / ``Post_mat`` are time-indexed lists of spike matrices.
    Whenever both populations have fired at least once (non-zero last spike
    times), a multiplicative update proportional to exp(dT/tau) is applied to
    the entries where pre and post spikes coincide. Returns the new W.
    """
    last_pre = 0
    last_post = 0
    for step in range(len(Pre_mat)):
        pre_now = Pre_mat[step]
        post_now = Post_mat[step]
        if sum(sum(pre_now)) > 0:
            last_pre = step
            pre_spikes = Pre_mat[last_pre]
        if sum(sum(post_now)) > 0:
            last_post = step
            post_spikes = Post_mat[last_post]
        if last_pre * last_post > 0:
            dt = last_pre - last_post
            # Bi & Poo-style asymmetric window constants.
            a_plus = 0.777
            a_minus = -0.237
            tau_plus = 16.8
            tau_minus = -33.7
            if dt < 0:
                # Pre before post: potentiation.
                dw = a_plus * math.exp(dt / tau_plus)
            else:
                # Post before (or with) pre: depression.
                dw = a_minus * math.exp(dt / tau_minus)
            # Consume this post spike time so one pairing triggers one update.
            last_post = 0
            coincidence = torch.mul(pre_spikes, post_spikes) * dw
            W = W + torch.mul(coincidence, W)
    return W
if __name__=="__main__":
    # Training script: learn the intention->action mapping W between DLPFC
    # and BG via reward-gated STDP, looping over every intention.
    # NOTE(review): indentation below is reconstructed from a
    # whitespace-stripped dump — verify loop nesting against the original.

    # number of neurons
    num_neuron = 6
    num_DLPFC = num_neuron
    num_BG = num_neuron
    num_StrD1 = num_neuron
    num_StrD2 = num_neuron
    num_Thalamus = num_neuron
    num_OFC = num_neuron
    num_SNc = num_neuron
    num_PMC = num_neuron

    ##############################
    # DLPFC
    ##############################
    WeightAdd = 20
    # DLPFC-BG
    DLPFC_BG_connection = []
    # Input-DLPFC
    con_matrix0 = torch.ones([num_DLPFC, num_DLPFC], dtype=torch.float)*WeightAdd
    DLPFC_BG_connection.append(CustomLinear(con_matrix0))
    # DLPFC-BG: W is the plastic matrix trained below and shared with BGNet.
    W = torch.ones([num_DLPFC, num_BG], dtype=torch.float)*WeightAdd
    DLPFC_BG_connection.append(CustomLinear(W))
    DLPFC = DLPFCNet(DLPFC_BG_connection)

    ##############################
    # OFC
    ##############################
    WeightAdd = 20
    OFC_connection = []
    # Tha-OFC_1 (Input1)
    con_matrix0 = torch.ones([num_Thalamus, num_OFC], dtype=torch.float)*WeightAdd
    OFC_connection.append(CustomLinear(con_matrix0))
    # SNc/VTA-OFC_2 (Input2)
    con_matrix1 = torch.ones([num_SNc, num_OFC], dtype=torch.float)*WeightAdd
    OFC_connection.append(CustomLinear(con_matrix1))
    # OFC_1-MOFC
    con_matrix2 = torch.ones([num_OFC, num_OFC], dtype=torch.float)*WeightAdd*5
    OFC_connection.append(CustomLinear(con_matrix2))
    # OFC_2-LOFC
    con_matrix3 = torch.ones([num_OFC, num_OFC], dtype=torch.float)*WeightAdd*5
    OFC_connection.append(CustomLinear(con_matrix3))
    # MOFC-LOFC (inhibitory: negative weights)
    con_matrix4 = torch.ones([num_OFC, num_OFC], dtype=torch.float)*WeightAdd*-10
    OFC_connection.append(CustomLinear(con_matrix4))
    # OFC_1-LOFC
    con_matrix5 = torch.ones([num_OFC, num_OFC], dtype=torch.float)*WeightAdd*5
    OFC_connection.append(CustomLinear(con_matrix5))
    OFC = OFCNet(OFC_connection)

    ##############################
    # BGNet
    ##############################
    BG_connection = []
    WeightAdd = 20
    # Input1-DLPFC
    con_matrix0 = torch.ones([num_DLPFC, num_DLPFC], dtype=torch.float)*WeightAdd
    BG_connection.append(CustomLinear(con_matrix0))
    # Input2-StrD1
    con_matrix1 = torch.ones([num_StrD1,num_StrD1], dtype=torch.float)*WeightAdd
    BG_connection.append(CustomLinear(con_matrix1))
    # Input3-StrD2
    con_matrix2 = torch.ones([num_StrD2,num_StrD2], dtype=torch.float)*WeightAdd
    BG_connection.append(CustomLinear(con_matrix2))
    # connection[3] shares the same plastic W as DLPFC's DLPFC-BG projection.
    BG_connection.append(CustomLinear(W))
    # StrD1-BG
    con_matrix4 = torch.ones([num_StrD1,num_BG], dtype=torch.float)*WeightAdd
    BG_connection.append(CustomLinear(con_matrix4))
    # StrD2-BG
    con_matrix5 = torch.ones([num_StrD2,num_BG], dtype=torch.float)*WeightAdd
    BG_connection.append(CustomLinear(con_matrix5))
    BG = BGNet(BG_connection)

    ##############################
    # Train
    ##############################
    # Intention-action corresponding rules: intention k must map to action k.
    Intention_mat = range(num_neuron)
    Action_mat = range(num_neuron)
    TrainNum = 0
    for k in range(len(Intention_mat)):
        Intention = Intention_mat[k]
        Intention_Action = Action_mat[k]
        for j in range (len(Intention_mat)+1):
            TrainNum = TrainNum + 1
            # Intention prediction: stimulate the intention row until BG fires,
            # then read the chosen action from the BG spike column.
            # NOTE(review): if BG never spikes within 10 steps, ``Action``
            # keeps its value from the previous iteration.
            for i in range(10):
                DLPFC_Input = torch.zeros([num_DLPFC, num_DLPFC], dtype=torch.float)
                DLPFC_Input[Intention,:] = 10
                BG_Spike = DLPFC(DLPFC_Input)
                if sum(sum(BG_Spike)).item() > 0:
                    Action = torch.nonzero(BG_Spike).numpy()[0][1]
                    break
            DLPFC.reset()
            # Encode the selected action in PMC and the (intention, action)
            # pair in the thalamic input.
            PMC = torch.zeros([1, num_PMC], dtype=torch.float)
            PMC[0][Action] = 1
            Thalamus = torch.zeros([num_Thalamus, num_Thalamus], dtype=torch.float)
            Thalamus[Intention][Action] = 10
            if Intention_Action == Action:
                # Positive reward
                # Tha-OFC_1-MOFC, SNc-OFC_2&MOFC-LOFC
                Reward = 1
                Input_Tha = Thalamus
                Input_SNc_Reward = torch.ones([num_OFC, num_OFC], dtype=torch.float)
                Input_SNc = torch.mul(Input_SNc_Reward, PMC)*10
                for t in range(10):
                    MOFC_Spike, LOFC_Spike = OFC(Input_Tha, Input_SNc, Reward)
            else:
                # Negative reward
                # Tha-OFC_1-LOFC, SNc-OFC_2-LOFC (SNC is zeros)
                Reward = -1
                Input_Tha = Thalamus
                Input_SNc = torch.zeros([num_OFC, num_OFC], dtype=torch.float)
                for t in range(10):
                    MOFC_Spike, LOFC_Spike = OFC(Input_Tha, Input_SNc, Reward)
            OFC.reset()
            # Replay OFC output through BG for 10 steps, then update the
            # shared DLPFC->BG weight matrix W with the STDP rule.
            for i in range(1):
                DLPFC_out_mat = []
                BG_out_mat = []
                State = 0
                for t in range(10):
                    DLPFC_Input = MOFC_Spike + LOFC_Spike
                    StrD1_Input = MOFC_Spike
                    StrD2_Input = LOFC_Spike
                    DLPFC_out, BG_out = BG(DLPFC_Input, StrD1_Input, StrD2_Input)
                    DLPFC_out_mat.append(DLPFC_out)
                    BG_out_mat.append(BG_out)
                W = STDP(DLPFC_out_mat, BG_out_mat, W)
                BG.reset()
                # Propagate the learned W into both networks' shared pathway.
                DLPFC.UpdateWeight(1, W)
                BG.UpdateWeight(3, W)
            # Correct action found for this intention: move to the next one.
            if Reward == 1:
                break
    print("Train End")
    print("W is: \n", W)
    print("TrainNum is: \n", TrainNum)
    print("*****************************")
================================================
FILE: examples/Social_Cognition/MAToM-SNN/LICENSE
================================================
GNU GENERAL PUBLIC LICENSE
Version 3, 29 June 2007
Copyright (C) 2007 Free Software Foundation, Inc.
Everyone is permitted to copy and distribute verbatim copies
of this license document, but changing it is not allowed.
Preamble
The GNU General Public License is a free, copyleft license for
software and other kinds of works.
The licenses for most software and other practical works are designed
to take away your freedom to share and change the works. By contrast,
the GNU General Public License is intended to guarantee your freedom to
share and change all versions of a program--to make sure it remains free
software for all its users. We, the Free Software Foundation, use the
GNU General Public License for most of our software; it applies also to
any other work released this way by its authors. You can apply it to
your programs, too.
When we speak of free software, we are referring to freedom, not
price. Our General Public Licenses are designed to make sure that you
have the freedom to distribute copies of free software (and charge for
them if you wish), that you receive source code or can get it if you
want it, that you can change the software or use pieces of it in new
free programs, and that you know you can do these things.
To protect your rights, we need to prevent others from denying you
these rights or asking you to surrender the rights. Therefore, you have
certain responsibilities if you distribute copies of the software, or if
you modify it: responsibilities to respect the freedom of others.
For example, if you distribute copies of such a program, whether
gratis or for a fee, you must pass on to the recipients the same
freedoms that you received. You must make sure that they, too, receive
or can get the source code. And you must show them these terms so they
know their rights.
Developers that use the GNU GPL protect your rights with two steps:
(1) assert copyright on the software, and (2) offer you this License
giving you legal permission to copy, distribute and/or modify it.
For the developers' and authors' protection, the GPL clearly explains
that there is no warranty for this free software. For both users' and
authors' sake, the GPL requires that modified versions be marked as
changed, so that their problems will not be attributed erroneously to
authors of previous versions.
Some devices are designed to deny users access to install or run
modified versions of the software inside them, although the manufacturer
can do so. This is fundamentally incompatible with the aim of
protecting users' freedom to change the software. The systematic
pattern of such abuse occurs in the area of products for individuals to
use, which is precisely where it is most unacceptable. Therefore, we
have designed this version of the GPL to prohibit the practice for those
products. If such problems arise substantially in other domains, we
stand ready to extend this provision to those domains in future versions
of the GPL, as needed to protect the freedom of users.
Finally, every program is threatened constantly by software patents.
States should not allow patents to restrict development and use of
software on general-purpose computers, but in those that do, we wish to
avoid the special danger that patents applied to a free program could
make it effectively proprietary. To prevent this, the GPL assures that
patents cannot be used to render the program non-free.
The precise terms and conditions for copying, distribution and
modification follow.
TERMS AND CONDITIONS
0. Definitions.
"This License" refers to version 3 of the GNU General Public License.
"Copyright" also means copyright-like laws that apply to other kinds of
works, such as semiconductor masks.
"The Program" refers to any copyrightable work licensed under this
License. Each licensee is addressed as "you". "Licensees" and
"recipients" may be individuals or organizations.
To "modify" a work means to copy from or adapt all or part of the work
in a fashion requiring copyright permission, other than the making of an
exact copy. The resulting work is called a "modified version" of the
earlier work or a work "based on" the earlier work.
A "covered work" means either the unmodified Program or a work based
on the Program.
To "propagate" a work means to do anything with it that, without
permission, would make you directly or secondarily liable for
infringement under applicable copyright law, except executing it on a
computer or modifying a private copy. Propagation includes copying,
distribution (with or without modification), making available to the
public, and in some countries other activities as well.
To "convey" a work means any kind of propagation that enables other
parties to make or receive copies. Mere interaction with a user through
a computer network, with no transfer of a copy, is not conveying.
An interactive user interface displays "Appropriate Legal Notices"
to the extent that it includes a convenient and prominently visible
feature that (1) displays an appropriate copyright notice, and (2)
tells the user that there is no warranty for the work (except to the
extent that warranties are provided), that licensees may convey the
work under this License, and how to view a copy of this License. If
the interface presents a list of user commands or options, such as a
menu, a prominent item in the list meets this criterion.
1. Source Code.
The "source code" for a work means the preferred form of the work
for making modifications to it. "Object code" means any non-source
form of a work.
A "Standard Interface" means an interface that either is an official
standard defined by a recognized standards body, or, in the case of
interfaces specified for a particular programming language, one that
is widely used among developers working in that language.
The "System Libraries" of an executable work include anything, other
than the work as a whole, that (a) is included in the normal form of
packaging a Major Component, but which is not part of that Major
Component, and (b) serves only to enable use of the work with that
Major Component, or to implement a Standard Interface for which an
implementation is available to the public in source code form. A
"Major Component", in this context, means a major essential component
(kernel, window system, and so on) of the specific operating system
(if any) on which the executable work runs, or a compiler used to
produce the work, or an object code interpreter used to run it.
The "Corresponding Source" for a work in object code form means all
the source code needed to generate, install, and (for an executable
work) run the object code and to modify the work, including scripts to
control those activities. However, it does not include the work's
System Libraries, or general-purpose tools or generally available free
programs which are used unmodified in performing those activities but
which are not part of the work. For example, Corresponding Source
includes interface definition files associated with source files for
the work, and the source code for shared libraries and dynamically
linked subprograms that the work is specifically designed to require,
such as by intimate data communication or control flow between those
subprograms and other parts of the work.
The Corresponding Source need not include anything that users
can regenerate automatically from other parts of the Corresponding
Source.
The Corresponding Source for a work in source code form is that
same work.
2. Basic Permissions.
All rights granted under this License are granted for the term of
copyright on the Program, and are irrevocable provided the stated
conditions are met. This License explicitly affirms your unlimited
permission to run the unmodified Program. The output from running a
covered work is covered by this License only if the output, given its
content, constitutes a covered work. This License acknowledges your
rights of fair use or other equivalent, as provided by copyright law.
You may make, run and propagate covered works that you do not
convey, without conditions so long as your license otherwise remains
in force. You may convey covered works to others for the sole purpose
of having them make modifications exclusively for you, or provide you
with facilities for running those works, provided that you comply with
the terms of this License in conveying all material for which you do
not control copyright. Those thus making or running the covered works
for you must do so exclusively on your behalf, under your direction
and control, on terms that prohibit them from making any copies of
your copyrighted material outside their relationship with you.
Conveying under any other circumstances is permitted solely under
the conditions stated below. Sublicensing is not allowed; section 10
makes it unnecessary.
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
No covered work shall be deemed part of an effective technological
measure under any applicable law fulfilling obligations under article
11 of the WIPO copyright treaty adopted on 20 December 1996, or
similar laws prohibiting or restricting circumvention of such
measures.
When you convey a covered work, you waive any legal power to forbid
circumvention of technological measures to the extent such circumvention
is effected by exercising rights under this License with respect to
the covered work, and you disclaim any intention to limit operation or
modification of the work as a means of enforcing, against the work's
users, your or third parties' legal rights to forbid circumvention of
technological measures.
4. Conveying Verbatim Copies.
You may convey verbatim copies of the Program's source code as you
receive it, in any medium, provided that you conspicuously and
appropriately publish on each copy an appropriate copyright notice;
keep intact all notices stating that this License and any
non-permissive terms added in accord with section 7 apply to the code;
keep intact all notices of the absence of any warranty; and give all
recipients a copy of this License along with the Program.
You may charge any price or no price for each copy that you convey,
and you may offer support or warranty protection for a fee.
5. Conveying Modified Source Versions.
You may convey a work based on the Program, or the modifications to
produce it from the Program, in the form of source code under the
terms of section 4, provided that you also meet all of these conditions:
a) The work must carry prominent notices stating that you modified
it, and giving a relevant date.
b) The work must carry prominent notices stating that it is
released under this License and any conditions added under section
7. This requirement modifies the requirement in section 4 to
"keep intact all notices".
c) You must license the entire work, as a whole, under this
License to anyone who comes into possession of a copy. This
License will therefore apply, along with any applicable section 7
additional terms, to the whole of the work, and all its parts,
regardless of how they are packaged. This License gives no
permission to license the work in any other way, but it does not
invalidate such permission if you have separately received it.
d) If the work has interactive user interfaces, each must display
Appropriate Legal Notices; however, if the Program has interactive
interfaces that do not display Appropriate Legal Notices, your
work need not make them do so.
A compilation of a covered work with other separate and independent
works, which are not by their nature extensions of the covered work,
and which are not combined with it such as to form a larger program,
in or on a volume of a storage or distribution medium, is called an
"aggregate" if the compilation and its resulting copyright are not
used to limit the access or legal rights of the compilation's users
beyond what the individual works permit. Inclusion of a covered work
in an aggregate does not cause this License to apply to the other
parts of the aggregate.
6. Conveying Non-Source Forms.
You may convey a covered work in object code form under the terms
of sections 4 and 5, provided that you also convey the
machine-readable Corresponding Source under the terms of this License,
in one of these ways:
a) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by the
Corresponding Source fixed on a durable physical medium
customarily used for software interchange.
b) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by a
written offer, valid for at least three years and valid for as
long as you offer spare parts or customer support for that product
model, to give anyone who possesses the object code either (1) a
copy of the Corresponding Source for all the software in the
product that is covered by this License, on a durable physical
medium customarily used for software interchange, for a price no
more than your reasonable cost of physically performing this
conveying of source, or (2) access to copy the
Corresponding Source from a network server at no charge.
c) Convey individual copies of the object code with a copy of the
written offer to provide the Corresponding Source. This
alternative is allowed only occasionally and noncommercially, and
only if you received the object code with such an offer, in accord
with subsection 6b.
d) Convey the object code by offering access from a designated
place (gratis or for a charge), and offer equivalent access to the
Corresponding Source in the same way through the same place at no
further charge. You need not require recipients to copy the
Corresponding Source along with the object code. If the place to
copy the object code is a network server, the Corresponding Source
may be on a different server (operated by you or a third party)
that supports equivalent copying facilities, provided you maintain
clear directions next to the object code saying where to find the
Corresponding Source. Regardless of what server hosts the
Corresponding Source, you remain obligated to ensure that it is
available for as long as needed to satisfy these requirements.
e) Convey the object code using peer-to-peer transmission, provided
you inform other peers where the object code and Corresponding
Source of the work are being offered to the general public at no
charge under subsection 6d.
A separable portion of the object code, whose source code is excluded
from the Corresponding Source as a System Library, need not be
included in conveying the object code work.
A "User Product" is either (1) a "consumer product", which means any
tangible personal property which is normally used for personal, family,
or household purposes, or (2) anything designed or sold for incorporation
into a dwelling. In determining whether a product is a consumer product,
doubtful cases shall be resolved in favor of coverage. For a particular
product received by a particular user, "normally used" refers to a
typical or common use of that class of product, regardless of the status
of the particular user or of the way in which the particular user
actually uses, or expects or is expected to use, the product. A product
is a consumer product regardless of whether the product has substantial
commercial, industrial or non-consumer uses, unless such uses represent
the only significant mode of use of the product.
"Installation Information" for a User Product means any methods,
procedures, authorization keys, or other information required to install
and execute modified versions of a covered work in that User Product from
a modified version of its Corresponding Source. The information must
suffice to ensure that the continued functioning of the modified object
code is in no case prevented or interfered with solely because
modification has been made.
If you convey an object code work under this section in, or with, or
specifically for use in, a User Product, and the conveying occurs as
part of a transaction in which the right of possession and use of the
User Product is transferred to the recipient in perpetuity or for a
fixed term (regardless of how the transaction is characterized), the
Corresponding Source conveyed under this section must be accompanied
by the Installation Information. But this requirement does not apply
if neither you nor any third party retains the ability to install
modified object code on the User Product (for example, the work has
been installed in ROM).
The requirement to provide Installation Information does not include a
requirement to continue to provide support service, warranty, or updates
for a work that has been modified or installed by the recipient, or for
the User Product in which it has been modified or installed. Access to a
network may be denied when the modification itself materially and
adversely affects the operation of the network or violates the rules and
protocols for communication across the network.
Corresponding Source conveyed, and Installation Information provided,
in accord with this section must be in a format that is publicly
documented (and with an implementation available to the public in
source code form), and must require no special password or key for
unpacking, reading or copying.
7. Additional Terms.
"Additional permissions" are terms that supplement the terms of this
License by making exceptions from one or more of its conditions.
Additional permissions that are applicable to the entire Program shall
be treated as though they were included in this License, to the extent
that they are valid under applicable law. If additional permissions
apply only to part of the Program, that part may be used separately
under those permissions, but the entire Program remains governed by
this License without regard to the additional permissions.
When you convey a copy of a covered work, you may at your option
remove any additional permissions from that copy, or from any part of
it. (Additional permissions may be written to require their own
removal in certain cases when you modify the work.) You may place
additional permissions on material, added by you to a covered work,
for which you have or can give appropriate copyright permission.
Notwithstanding any other provision of this License, for material you
add to a covered work, you may (if authorized by the copyright holders of
that material) supplement the terms of this License with terms:
a) Disclaiming warranty or limiting liability differently from the
terms of sections 15 and 16 of this License; or
b) Requiring preservation of specified reasonable legal notices or
author attributions in that material or in the Appropriate Legal
Notices displayed by works containing it; or
c) Prohibiting misrepresentation of the origin of that material, or
requiring that modified versions of such material be marked in
reasonable ways as different from the original version; or
d) Limiting the use for publicity purposes of names of licensors or
authors of the material; or
e) Declining to grant rights under trademark law for use of some
trade names, trademarks, or service marks; or
f) Requiring indemnification of licensors and authors of that
material by anyone who conveys the material (or modified versions of
it) with contractual assumptions of liability to the recipient, for
any liability that these contractual assumptions directly impose on
those licensors and authors.
All other non-permissive additional terms are considered "further
restrictions" within the meaning of section 10. If the Program as you
received it, or any part of it, contains a notice stating that it is
governed by this License along with a term that is a further
restriction, you may remove that term. If a license document contains
a further restriction but permits relicensing or conveying under this
License, you may add to a covered work material governed by the terms
of that license document, provided that the further restriction does
not survive such relicensing or conveying.
If you add terms to a covered work in accord with this section, you
must place, in the relevant source files, a statement of the
additional terms that apply to those files, or a notice indicating
where to find the applicable terms.
Additional terms, permissive or non-permissive, may be stated in the
form of a separately written license, or stated as exceptions;
the above requirements apply either way.
8. Termination.
You may not propagate or modify a covered work except as expressly
provided under this License. Any attempt otherwise to propagate or
modify it is void, and will automatically terminate your rights under
this License (including any patent licenses granted under the third
paragraph of section 11).
However, if you cease all violation of this License, then your
license from a particular copyright holder is reinstated (a)
provisionally, unless and until the copyright holder explicitly and
finally terminates your license, and (b) permanently, if the copyright
holder fails to notify you of the violation by some reasonable means
prior to 60 days after the cessation.
Moreover, your license from a particular copyright holder is
reinstated permanently if the copyright holder notifies you of the
violation by some reasonable means, this is the first time you have
received notice of violation of this License (for any work) from that
copyright holder, and you cure the violation prior to 30 days after
your receipt of the notice.
Termination of your rights under this section does not terminate the
licenses of parties who have received copies or rights from you under
this License. If your rights have been terminated and not permanently
reinstated, you do not qualify to receive new licenses for the same
material under section 10.
9. Acceptance Not Required for Having Copies.
You are not required to accept this License in order to receive or
run a copy of the Program. Ancillary propagation of a covered work
occurring solely as a consequence of using peer-to-peer transmission
to receive a copy likewise does not require acceptance. However,
nothing other than this License grants you permission to propagate or
modify any covered work. These actions infringe copyright if you do
not accept this License. Therefore, by modifying or propagating a
covered work, you indicate your acceptance of this License to do so.
10. Automatic Licensing of Downstream Recipients.
Each time you convey a covered work, the recipient automatically
receives a license from the original licensors, to run, modify and
propagate that work, subject to this License. You are not responsible
for enforcing compliance by third parties with this License.
An "entity transaction" is a transaction transferring control of an
organization, or substantially all assets of one, or subdividing an
organization, or merging organizations. If propagation of a covered
work results from an entity transaction, each party to that
transaction who receives a copy of the work also receives whatever
licenses to the work the party's predecessor in interest had or could
give under the previous paragraph, plus a right to possession of the
Corresponding Source of the work from the predecessor in interest, if
the predecessor has it or can get it with reasonable efforts.
You may not impose any further restrictions on the exercise of the
rights granted or affirmed under this License. For example, you may
not impose a license fee, royalty, or other charge for exercise of
rights granted under this License, and you may not initiate litigation
(including a cross-claim or counterclaim in a lawsuit) alleging that
any patent claim is infringed by making, using, selling, offering for
sale, or importing the Program or any portion of it.
11. Patents.
A "contributor" is a copyright holder who authorizes use under this
License of the Program or a work on which the Program is based. The
work thus licensed is called the contributor's "contributor version".
A contributor's "essential patent claims" are all patent claims
owned or controlled by the contributor, whether already acquired or
hereafter acquired, that would be infringed by some manner, permitted
by this License, of making, using, or selling its contributor version,
but do not include claims that would be infringed only as a
consequence of further modification of the contributor version. For
purposes of this definition, "control" includes the right to grant
patent sublicenses in a manner consistent with the requirements of
this License.
Each contributor grants you a non-exclusive, worldwide, royalty-free
patent license under the contributor's essential patent claims, to
make, use, sell, offer for sale, import and otherwise run, modify and
propagate the contents of its contributor version.
In the following three paragraphs, a "patent license" is any express
agreement or commitment, however denominated, not to enforce a patent
(such as an express permission to practice a patent or covenant not to
sue for patent infringement). To "grant" such a patent license to a
party means to make such an agreement or commitment not to enforce a
patent against the party.
If you convey a covered work, knowingly relying on a patent license,
and the Corresponding Source of the work is not available for anyone
to copy, free of charge and under the terms of this License, through a
publicly available network server or other readily accessible means,
then you must either (1) cause the Corresponding Source to be so
available, or (2) arrange to deprive yourself of the benefit of the
patent license for this particular work, or (3) arrange, in a manner
consistent with the requirements of this License, to extend the patent
license to downstream recipients. "Knowingly relying" means you have
actual knowledge that, but for the patent license, your conveying the
covered work in a country, or your recipient's use of the covered work
in a country, would infringe one or more identifiable patents in that
country that you have reason to believe are valid.
If, pursuant to or in connection with a single transaction or
arrangement, you convey, or propagate by procuring conveyance of, a
covered work, and grant a patent license to some of the parties
receiving the covered work authorizing them to use, propagate, modify
or convey a specific copy of the covered work, then the patent license
you grant is automatically extended to all recipients of the covered
work and works based on it.
A patent license is "discriminatory" if it does not include within
the scope of its coverage, prohibits the exercise of, or is
conditioned on the non-exercise of one or more of the rights that are
specifically granted under this License. You may not convey a covered
work if you are a party to an arrangement with a third party that is
in the business of distributing software, under which you make payment
to the third party based on the extent of your activity of conveying
the work, and under which the third party grants, to any of the
parties who would receive the covered work from you, a discriminatory
patent license (a) in connection with copies of the covered work
conveyed by you (or copies made from those copies), or (b) primarily
for and in connection with specific products or compilations that
contain the covered work, unless you entered into that arrangement,
or that patent license was granted, prior to 28 March 2007.
Nothing in this License shall be construed as excluding or limiting
any implied license or other defenses to infringement that may
otherwise be available to you under applicable patent law.
12. No Surrender of Others' Freedom.
If conditions are imposed on you (whether by court order, agreement or
otherwise) that contradict the conditions of this License, they do not
excuse you from the conditions of this License. If you cannot convey a
covered work so as to satisfy simultaneously your obligations under this
License and any other pertinent obligations, then as a consequence you may
not convey it at all. For example, if you agree to terms that obligate you
to collect a royalty for further conveying from those to whom you convey
the Program, the only way you could satisfy both those terms and this
License would be to refrain entirely from conveying the Program.
13. Use with the GNU Affero General Public License.
Notwithstanding any other provision of this License, you have
permission to link or combine any covered work with a work licensed
under version 3 of the GNU Affero General Public License into a single
combined work, and to convey the resulting work. The terms of this
License will continue to apply to the part which is the covered work,
but the special requirements of the GNU Affero General Public License,
section 13, concerning interaction through a network will apply to the
combination as such.
14. Revised Versions of this License.
The Free Software Foundation may publish revised and/or new versions of
the GNU General Public License from time to time. Such new versions will
be similar in spirit to the present version, but may differ in detail to
address new problems or concerns.
Each version is given a distinguishing version number. If the
Program specifies that a certain numbered version of the GNU General
Public License "or any later version" applies to it, you have the
option of following the terms and conditions either of that numbered
version or of any later version published by the Free Software
Foundation. If the Program does not specify a version number of the
GNU General Public License, you may choose any version ever published
by the Free Software Foundation.
If the Program specifies that a proxy can decide which future
versions of the GNU General Public License can be used, that proxy's
public statement of acceptance of a version permanently authorizes you
to choose that version for the Program.
Later license versions may give you additional or different
permissions. However, no additional obligations are imposed on any
author or copyright holder as a result of your choosing to follow a
later version.
15. Disclaimer of Warranty.
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
16. Limitation of Liability.
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
SUCH DAMAGES.
17. Interpretation of Sections 15 and 16.
If the disclaimer of warranty and limitation of liability provided
above cannot be given local legal effect according to their terms,
reviewing courts shall apply local law that most closely approximates
an absolute waiver of all civil liability in connection with the
Program, unless a warranty or assumption of liability accompanies a
copy of the Program in return for a fee.
END OF TERMS AND CONDITIONS
How to Apply These Terms to Your New Programs
If you develop a new program, and you want it to be of the greatest
possible use to the public, the best way to achieve this is to make it
free software which everyone can redistribute and change under these terms.
To do so, attach the following notices to the program. It is safest
to attach them to the start of each source file to most effectively
state the exclusion of warranty; and each file should have at least
the "copyright" line and a pointer to where the full notice is found.
Copyright (C) <year>  <name of author>
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
Also add information on how to contact you by electronic and paper mail.
If the program does terminal interaction, make it output a short
notice like this when it starts in an interactive mode:
<program>  Copyright (C) <year>  <name of author>
This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
This is free software, and you are welcome to redistribute it
under certain conditions; type `show c' for details.
The hypothetical commands `show w' and `show c' should show the appropriate
parts of the General Public License. Of course, your program's commands
might be different; for a GUI interface, you would use an "about box".
You should also get your employer (if you work as a programmer) or school,
if any, to sign a "copyright disclaimer" for the program, if necessary.
For more information on this, and how to apply and follow the GNU GPL, see
<https://www.gnu.org/licenses/>.
The GNU General Public License does not permit incorporating your program
into proprietary programs. If your program is a subroutine library, you
may consider it more useful to permit linking proprietary applications with
the library. If this is what you want to do, use the GNU Lesser General
Public License instead of this License. But first, please read
<https://www.gnu.org/licenses/why-not-lgpl.html>.
================================================
FILE: examples/Social_Cognition/MAToM-SNN/MPE/__init__.py
================================================
================================================
FILE: examples/Social_Cognition/MAToM-SNN/MPE/agents/__init__.py
================================================
================================================
FILE: examples/Social_Cognition/MAToM-SNN/MPE/agents/agents.py
================================================
import torch
from torch import Tensor
from torch.autograd import Variable
from torch.optim import Adam
from MPE.utils.networks import MLPNetwork, RNN, SNNNetwork
from MPE.utils.misc import hard_update, gumbel_softmax, onehot_from_logits
from MPE.utils.noise import OUNoise
import time
class DDPGAgent(object):
    """
    A single DDPG agent: an online policy/critic pair, their target
    copies, per-network Adam optimizers, and an exploration mechanism
    (epsilon scalar for discrete actions, OU noise for continuous ones).
    """
    def __init__(self, num_in_pol, num_out_pol, num_in_critic, hidden_dim=64,
                 lr=0.01, discrete_action=True):
        """
        Inputs:
            num_in_pol (int): number of dimensions for policy input
            num_out_pol (int): number of dimensions for policy output
            num_in_critic (int): number of dimensions for critic input
        """
        # Networks are created in a fixed order so weight initialization
        # consumes the RNG stream exactly as before.
        self.policy = MLPNetwork(num_in_pol, num_out_pol,
                                 hidden_dim=hidden_dim,
                                 constrain_out=True,
                                 discrete_action=discrete_action)
        self.critic = MLPNetwork(num_in_critic, 1,
                                 hidden_dim=hidden_dim,
                                 constrain_out=False)
        self.target_policy = MLPNetwork(num_in_pol, num_out_pol,
                                        hidden_dim=hidden_dim,
                                        constrain_out=True,
                                        discrete_action=discrete_action)
        self.target_critic = MLPNetwork(num_in_critic, 1,
                                        hidden_dim=hidden_dim,
                                        constrain_out=False)
        # Targets start as exact copies of the online networks.
        hard_update(self.target_policy, self.policy)
        hard_update(self.target_critic, self.critic)
        self.policy_optimizer = Adam(self.policy.parameters(), lr=lr)
        self.critic_optimizer = Adam(self.critic.parameters(), lr=lr)
        # Discrete agents keep an epsilon scalar; continuous agents use
        # an Ornstein-Uhlenbeck noise process.
        self.exploration = 0.3 if discrete_action \
            else OUNoise(num_out_pol)  # 0.3 = epsilon for eps-greedy
        self.discrete_action = discrete_action

    def reset_noise(self):
        """Reset the OU noise process; no-op for discrete agents."""
        if self.discrete_action:
            return
        self.exploration.reset()

    def scale_noise(self, scale):
        """Set exploration strength: epsilon (discrete) or OU scale."""
        if not self.discrete_action:
            self.exploration.scale = scale
            return
        self.exploration = scale

    def step(self, obs, explore=False):
        """
        Take a step forward in environment for a minibatch of observations
        Inputs:
            obs (PyTorch Variable): Observations for this agent
            explore (boolean): Whether or not to add exploration noise
        Outputs:
            action (PyTorch Variable): Actions for this agent
        """
        action = self.policy(obs)
        if not self.discrete_action:
            # Continuous control: optionally perturb with OU noise, then clip.
            if explore:
                action += Variable(Tensor(self.exploration.noise()),
                                   requires_grad=False)
            return action.clamp(-1, 1)
        # Discrete control: sample differentiably when exploring, otherwise
        # take the greedy one-hot action.
        transform = (lambda logits: gumbel_softmax(logits, hard=True)) \
            if explore else onehot_from_logits
        if action.shape[1] == 9:
            # 9 logits = two concatenated action heads (first 5, last 4);
            # each head is transformed independently.
            return torch.cat((transform(action[:, :5]),
                              transform(action[:, 5:])), 1)
        return transform(action)

    def get_params(self):
        """Snapshot all network and optimizer states for checkpointing."""
        names = ('policy', 'critic', 'target_policy', 'target_critic',
                 'policy_optimizer', 'critic_optimizer')
        return {name: getattr(self, name).state_dict() for name in names}

    def load_params(self, params):
        """Restore the state produced by :meth:`get_params`."""
        for name in ('policy', 'critic', 'target_policy', 'target_critic',
                     'policy_optimizer', 'critic_optimizer'):
            getattr(self, name).load_state_dict(params[name])
class DDPGAgent_RNN(object):
"""
General class for DDPG agents (policy, critic, target policy, target
critic, exploration noise)
"""
def __init__(self, num_in_pol, num_out_pol, num_in_critic, hidden_dim=64,
             lr=0.01, discrete_action=True):
    """
    Inputs:
        num_in_pol (int): number of dimensions for policy input
        num_out_pol (int): number of dimensions for policy output
        num_in_critic (int): number of dimensions for critic input
    """
    # Networks are created in the same order as the MLP agent so weight
    # initialization consumes the RNG stream identically.
    self.policy = RNN(num_in_pol, num_out_pol, hidden_dim=hidden_dim,
                      constrain_out=True, discrete_action=discrete_action)
    self.critic = RNN(num_in_critic, 1, hidden_dim=hidden_dim,
                      constrain_out=False)
    self.target_policy = RNN(num_in_pol, num_out_pol, hidden_dim=hidden_dim,
                             constrain_out=True,
                             discrete_action=discrete_action)
    self.target_critic = RNN(num_in_critic, 1, hidden_dim=hidden_dim,
                             constrain_out=False)
    # Recurrent hidden states; allocated later via init_hidden().
    self.policy_hidden = None
    self.policy_target_hidden = None
    self.critic_hidden = None
    self.critic_target_hidden = None
    # Dimensions kept around for sizing hidden states later.
    self.num_in_pol = num_in_pol
    self.num_out_pol = num_out_pol
    self.hidden_dim = hidden_dim
    # Targets start as exact copies of the online networks.
    hard_update(self.target_policy, self.policy)
    hard_update(self.target_critic, self.critic)
    self.policy_optimizer = Adam(self.policy.parameters(), lr=lr)
    self.critic_optimizer = Adam(self.critic.parameters(), lr=lr)
    # Discrete agents keep an epsilon scalar; continuous agents use an
    # Ornstein-Uhlenbeck noise process.
    self.exploration = 0.3 if discrete_action \
        else OUNoise(num_out_pol)  # 0.3 = epsilon for eps-greedy
    self.discrete_action = discrete_action
def reset_noise(self):
if not self.discrete_action:
self.exploration.reset()
def scale_noise(self, scale):
if self.discrete_action:
self.exploration = scale
else:
self.exploration.scale = scale
def step(self, obs, explore=False):
"""
Take a step forward in environment for a minibatch of observations
Inputs:
obs (PyTorch Variable): Observations for this agent
explore (boolean): Whether or not to add exploration noise
Outputs:
action (PyTorch Variable): Actions for this agent
"""
action, self.policy_hidden = self.policy(obs, self.policy_hidden)
if self.discrete_action:
if explore:
action = gumbel_softmax(action, hard=True)
else:
action = onehot_from_logits(action)
else: # continuous action
if explore:
action += Variable(Tensor(self.exploration.noise()),
requires_grad=False)
action = action.clamp(-1, 1)
return action
def get_params(self):
return {'policy': self.policy.state_dict(),
'critic': self.critic.state_dict(),
'target_policy': self.target_policy.state_dict(),
'target_critic': self.target_critic.state_dict(),
'policy_optimizer': self.policy_optimizer.state_dict(),
'critic_optimizer': self.critic_optimizer.state_dict()}
def load_params(self, params):
self.policy.load_state_dict(params['policy'])
self.critic.load_state_dict(params['critic'])
self.target_policy.load_state_dict(params['target_policy'])
self.target_critic.load_state_dict(params['target_critic'])
self.policy_optimizer.load_state_dict(params['policy_optimizer'])
self.critic_optimizer.load_state_dict(params['critic_optimizer'])
def init_hidden(self, len_ep, policy_hidden=False, policy_target_hidden=False, \
critic_hidden=False, critic_target_hidden=False):
# 为每个episode中的每个agent都初始化一个eval_hidden、target_hidden
if policy_hidden == True:
self.policy_hidden = torch.zeros((len_ep, self.hidden_dim))
if policy_target_hidden == True:
self.policy_target_hidden = torch.zeros((len_ep, self.hidden_dim))
if critic_hidden == True:
self.critic_hidden = torch.zeros((len_ep, self.hidden_dim))
if critic_target_hidden == True:
self.critic_target_hidden = torch.zeros((len_ep, self.hidden_dim))
class DDPGAgent_SNN(object):
    """
    DDPG agent whose policy and critic are spiking neural networks
    (policy, critic, target policy, target critic, exploration noise).
    """
    def __init__(self, num_in_pol, num_out_pol, num_in_critic, output_style, hidden_dim=64,
                 lr=0.01, discrete_action=True):
        """
        Inputs:
            num_in_pol (int): number of dimensions for policy input
            num_out_pol (int): number of dimensions for policy output
            num_in_critic (int): number of dimensions for critic input
            output_style (str): SNN readout mode (e.g. 'sum' or 'voltage')
            hidden_dim (int): hidden layer width
            lr (float): Adam learning rate for policy and critic
            discrete_action (bool): discrete (gumbel/onehot) vs continuous (OU noise)
        """
        self.policy = SNNNetwork(num_in_pol, num_out_pol,
                                 hidden_dim=hidden_dim,
                                 output_style=output_style)
        self.critic = SNNNetwork(num_in_critic, 1,
                                 hidden_dim=hidden_dim,
                                 output_style=output_style)
        self.target_policy = SNNNetwork(num_in_pol, num_out_pol,
                                        hidden_dim=hidden_dim,
                                        output_style=output_style)
        self.target_critic = SNNNetwork(num_in_critic, 1,
                                        hidden_dim=hidden_dim,
                                        output_style=output_style)
        # Targets start as exact copies of the learned networks.
        hard_update(self.target_policy, self.policy)
        hard_update(self.target_critic, self.critic)
        self.policy_optimizer = Adam(self.policy.parameters(), lr=lr)
        self.critic_optimizer = Adam(self.critic.parameters(), lr=lr)
        if not discrete_action:
            self.exploration = OUNoise(num_out_pol)
        else:
            self.exploration = 0.3  # epsilon for eps-greedy
        self.discrete_action = discrete_action

    def reset_noise(self):
        """Reset the OU noise process (no-op for discrete action spaces)."""
        if not self.discrete_action:
            self.exploration.reset()

    def scale_noise(self, scale):
        """Set exploration magnitude: epsilon (discrete) or OU scale (continuous)."""
        if self.discrete_action:
            self.exploration = scale
        else:
            self.exploration.scale = scale

    def step(self, obs, explore=False):
        """
        Take a step forward in environment for a minibatch of observations
        Inputs:
            obs (PyTorch Variable): Observations for this agent
            explore (boolean): Whether or not to add exploration noise
        Outputs:
            action (PyTorch Variable): Actions for this agent
        """
        action = self.policy(obs)
        if self.discrete_action:
            # A 9-dim output is split into two independent discrete heads
            # (first 5 dims, last 4) -- presumably movement + communication
            # actions of simple_world_comm; TODO confirm against the scenario.
            if explore:
                if action.shape[1] == 9:
                    action = torch.cat(
                        (gumbel_softmax(action[:, :5], hard=True), gumbel_softmax(action[:, 5:], hard=True)), 1)
                else:
                    action = gumbel_softmax(action, hard=True)
            else:
                if action.shape[1] == 9:
                    action = torch.cat(
                        (onehot_from_logits(action[:, :5]), onehot_from_logits(action[:, 5:])), 1)
                else:
                    action = onehot_from_logits(action)
        else:  # continuous action
            if explore:
                action += Variable(Tensor(self.exploration.noise()),
                                   requires_grad=False)
            action = action.clamp(-1, 1)
        return action

    def get_params(self):
        """Snapshot all network and optimizer state dicts for checkpointing."""
        return {'policy': self.policy.state_dict(),
                'critic': self.critic.state_dict(),
                'target_policy': self.target_policy.state_dict(),
                'target_critic': self.target_critic.state_dict(),
                'policy_optimizer': self.policy_optimizer.state_dict(),
                'critic_optimizer': self.critic_optimizer.state_dict()}

    def load_params(self, params):
        """Restore state saved by get_params()."""
        self.policy.load_state_dict(params['policy'])
        self.critic.load_state_dict(params['critic'])
        self.target_policy.load_state_dict(params['target_policy'])
        self.target_critic.load_state_dict(params['target_critic'])
        self.policy_optimizer.load_state_dict(params['policy_optimizer'])
        self.critic_optimizer.load_state_dict(params['critic_optimizer'])
class DDPGAgent_ToM(object):
    """
    DDPG agent with SNN networks and (optional) theory-of-mind models of the
    other agents (policy, critic, target policy, target critic, noise).
    """
    def __init__(self, num_in_pol, num_out_pol, num_in_critic, num_in_mle, output_style,
                 num_agents, device, hidden_dim=64, lr=0.01, discrete_action=True):
        """
        Inputs:
            num_in_pol (int): number of dimensions for policy input
            num_out_pol (int): number of dimensions for policy output
            num_in_critic (int): number of dimensions for critic input
            num_in_mle (int): input size for the (currently unused) ToM models
            output_style (str): SNN readout mode
            num_agents (int): total number of agents in the environment
            device: torch device the policy runs on during step()
            hidden_dim (int): hidden layer width
            lr (float): Adam learning rate
            discrete_action (bool): discrete vs continuous action space
        """
        self.device = device
        self.policy = SNNNetwork(num_in_pol, num_out_pol,
                                 hidden_dim=hidden_dim,
                                 output_style=output_style)
        self.critic = SNNNetwork(num_in_critic, 1,
                                 hidden_dim=hidden_dim,
                                 output_style=output_style)
        self.target_policy = SNNNetwork(num_in_pol, num_out_pol,
                                        hidden_dim=hidden_dim,
                                        output_style=output_style)
        self.target_critic = SNNNetwork(num_in_critic, 1,
                                        hidden_dim=hidden_dim,
                                        output_style=output_style)
        # Per-opponent ToM models; populated elsewhere (kept empty here).
        self.mle = []
        hard_update(self.target_policy, self.policy)
        hard_update(self.target_critic, self.critic)
        self.policy_optimizer = Adam(self.policy.parameters(), lr=lr)
        self.critic_optimizer = Adam(self.critic.parameters(), lr=lr)
        self.mle_optimizer = []
        if not discrete_action:
            self.exploration = OUNoise(num_out_pol)
        else:
            self.exploration = 0.3  # epsilon for eps-greedy
        self.discrete_action = discrete_action

    def reset_noise(self):
        """Reset the OU noise process (no-op for discrete action spaces)."""
        if not self.discrete_action:
            self.exploration.reset()

    def scale_noise(self, scale):
        """Set exploration magnitude: epsilon (discrete) or OU scale (continuous)."""
        if self.discrete_action:
            self.exploration = scale
        else:
            self.exploration.scale = scale

    def step(self, obs, explore=False):
        """
        Take a step forward in environment for a minibatch of observations
        Inputs:
            obs (PyTorch Variable): Observations for this agent
            explore (boolean): Whether or not to add exploration noise
        Outputs:
            action (PyTorch Variable): Actions for this agent (on CPU)
        """
        action = self.policy.to(self.device)(obs.to(self.device))
        if self.discrete_action:
            # A 9-dim output is split into two independent discrete heads
            # (first 5 dims, last 4), matching DDPGAgent_SNN.step().
            if explore:
                if action.shape[1] == 9:
                    action = torch.cat(
                        (gumbel_softmax(action[:, :5], hard=True), gumbel_softmax(action[:, 5:], hard=True)), 1).cpu()
                else:
                    action = gumbel_softmax(action, hard=True).cpu()
            else:
                if action.shape[1] == 9:
                    # BUGFIX: previously passed hard=True to onehot_from_logits,
                    # which does not accept that keyword (see the SNN agent's
                    # identical branch), and omitted the .cpu() every other
                    # branch performs.
                    action = torch.cat(
                        (onehot_from_logits(action[:, :5]), onehot_from_logits(action[:, 5:])), 1).cpu()
                else:
                    action = onehot_from_logits(action).cpu()
        else:  # continuous action
            if explore:
                action += Variable(Tensor(self.exploration.noise()),
                                   requires_grad=False)
            action = action.clamp(-1, 1)
        return action

    def get_params(self):
        """Snapshot network/optimizer state dicts (ToM models are not persisted)."""
        params = {'policy': self.policy.state_dict(),
                  'critic': self.critic.state_dict(),
                  'target_policy': self.target_policy.state_dict(),
                  'target_critic': self.target_critic.state_dict(),
                  'policy_optimizer': self.policy_optimizer.state_dict(),
                  'critic_optimizer': self.critic_optimizer.state_dict(),
                  }
        return params

    def load_params(self, params):
        """Restore state saved by get_params() (ToM models are not persisted)."""
        self.policy.load_state_dict(params['policy'])
        self.critic.load_state_dict(params['critic'])
        self.target_policy.load_state_dict(params['target_policy'])
        self.target_critic.load_state_dict(params['target_critic'])
        self.policy_optimizer.load_state_dict(params['policy_optimizer'])
        self.critic_optimizer.load_state_dict(params['critic_optimizer'])
================================================
FILE: examples/Social_Cognition/MAToM-SNN/MPE/common/__init__.py
================================================
================================================
FILE: examples/Social_Cognition/MAToM-SNN/MPE/common/distributions.py
================================================
# import tensorflow as tf
import tensorflow.compat.v1 as tf
tf.compat.v1.disable_eager_execution()
import numpy as np
import maddpg.common.tf_util as U
from tensorflow.python.ops import math_ops
from multiagent.multi_discrete import MultiDiscrete
from tensorflow.python.ops import nn
class Pd(object):
    """
    A particular probability distribution.

    Abstract interface; subclasses supply the parametrization and implement
    every method below.
    """

    def flatparam(self):
        """Flat parameter tensor of the distribution."""
        raise NotImplementedError

    def mode(self):
        """Most likely value."""
        raise NotImplementedError

    def logp(self, x):
        """Log-probability of x."""
        raise NotImplementedError

    def kl(self, other):
        """KL divergence KL(self || other)."""
        raise NotImplementedError

    def entropy(self):
        """Entropy of the distribution."""
        raise NotImplementedError

    def sample(self):
        """Draw one sample."""
        raise NotImplementedError
class PdType(object):
    """
    Parametrized family of probability distributions.

    Subclasses declare the concrete Pd class plus parameter/sample shapes and
    dtypes; this base provides construction and TF placeholder helpers.
    """

    def pdclass(self):
        """Concrete Pd subclass for this family."""
        raise NotImplementedError

    def pdfromflat(self, flat):
        """Build a distribution of this family from flat parameters."""
        return self.pdclass()(flat)

    def param_shape(self):
        """Shape (list) of the flat parameter vector."""
        raise NotImplementedError

    def sample_shape(self):
        """Shape (list) of one sample."""
        raise NotImplementedError

    def sample_dtype(self):
        """TF dtype of samples."""
        raise NotImplementedError

    def param_placeholder(self, prepend_shape, name=None):
        """TF placeholder for distribution parameters, batch dims prepended."""
        return tf.placeholder(dtype=tf.float32,
                              shape=prepend_shape + self.param_shape(),
                              name=name)

    def sample_placeholder(self, prepend_shape, name=None):
        """TF placeholder for samples, batch dims prepended."""
        return tf.placeholder(dtype=self.sample_dtype(),
                              shape=prepend_shape + self.sample_shape(),
                              name=name)
class CategoricalPdType(PdType):
    """Family of hard categorical distributions over `ncat` classes."""

    def __init__(self, ncat):
        self.ncat = ncat

    def pdclass(self):
        return CategoricalPd

    def param_shape(self):
        # One logit per category.
        return [self.ncat]

    def sample_shape(self):
        # Samples are scalar class indices.
        return []

    def sample_dtype(self):
        return tf.int32
class SoftCategoricalPdType(PdType):
    """Family of relaxed categoricals whose samples are probability vectors."""

    def __init__(self, ncat):
        self.ncat = ncat

    def pdclass(self):
        return SoftCategoricalPd

    def param_shape(self):
        # One logit per category.
        return [self.ncat]

    def sample_shape(self):
        # "Soft" samples keep the full ncat-dim simplex vector.
        return [self.ncat]

    def sample_dtype(self):
        return tf.float32
class MultiCategoricalPdType(PdType):
    """Family of joint distributions over several independent categoricals
    (MultiDiscrete action spaces with per-dimension [low, high] ranges)."""

    def __init__(self, low, high):
        self.low = low
        self.high = high
        # Number of categories per action dimension.
        self.ncats = high - low + 1

    def pdclass(self):
        return MultiCategoricalPd

    def pdfromflat(self, flat):
        # Needs low/high in addition to flat, so override the base helper.
        return MultiCategoricalPd(self.low, self.high, flat)

    def param_shape(self):
        # Concatenated logits of all dimensions.
        return [sum(self.ncats)]

    def sample_shape(self):
        # One integer per action dimension.
        return [len(self.ncats)]

    def sample_dtype(self):
        return tf.int32
class SoftMultiCategoricalPdType(PdType):
    """Relaxed variant of MultiCategoricalPdType: samples are concatenated
    per-dimension probability vectors instead of integer indices."""

    def __init__(self, low, high):
        self.low = low
        self.high = high
        # Number of categories per action dimension.
        self.ncats = high - low + 1

    def pdclass(self):
        return SoftMultiCategoricalPd

    def pdfromflat(self, flat):
        # Needs low/high in addition to flat, so override the base helper.
        return SoftMultiCategoricalPd(self.low, self.high, flat)

    def param_shape(self):
        return [sum(self.ncats)]

    def sample_shape(self):
        # Soft samples keep the full concatenated simplex vectors.
        return [sum(self.ncats)]

    def sample_dtype(self):
        return tf.float32
class DiagGaussianPdType(PdType):
    """Family of diagonal Gaussians over `size`-dimensional vectors."""

    def __init__(self, size):
        self.size = size

    def pdclass(self):
        return DiagGaussianPd

    def param_shape(self):
        # Mean and log-std concatenated.
        return [2 * self.size]

    def sample_shape(self):
        return [self.size]

    def sample_dtype(self):
        return tf.float32
class BernoulliPdType(PdType):
    """Family of `size` independent Bernoulli variables (MultiBinary spaces)."""

    def __init__(self, size):
        self.size = size

    def pdclass(self):
        return BernoulliPd

    def param_shape(self):
        # One logit per binary variable.
        return [self.size]

    def sample_shape(self):
        return [self.size]

    def sample_dtype(self):
        return tf.int32
# WRONG SECOND DERIVATIVES
# class CategoricalPd(Pd):
# def __init__(self, logits):
# self.logits = logits
# self.ps = tf.nn.softmax(logits)
# @classmethod
# def fromflat(cls, flat):
# return cls(flat)
# def flatparam(self):
# return self.logits
# def mode(self):
# return U.argmax(self.logits, axis=1)
# def logp(self, x):
# return -tf.nn.sparse_softmax_cross_entropy_with_logits(self.logits, x)
# def kl(self, other):
# return tf.nn.softmax_cross_entropy_with_logits(other.logits, self.ps) \
# - tf.nn.softmax_cross_entropy_with_logits(self.logits, self.ps)
# def entropy(self):
# return tf.nn.softmax_cross_entropy_with_logits(self.logits, self.ps)
# def sample(self):
# u = tf.random_uniform(tf.shape(self.logits))
# return U.argmax(self.logits - tf.log(-tf.log(u)), axis=1)
class CategoricalPd(Pd):
    """Categorical distribution over integer classes, parametrized by logits
    of shape (batch, ncat)."""
    def __init__(self, logits):
        self.logits = logits
    def flatparam(self):
        return self.logits
    def mode(self):
        return U.argmax(self.logits, axis=1)
    def logp(self, x):
        # log p(x) = -cross_entropy(logits, x) for integer labels x.
        return -tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits, labels=x)
    def kl(self, other):
        # KL(self || other); logits are max-shifted first (log-sum-exp trick)
        # so the exponentials below cannot overflow.
        a0 = self.logits - U.max(self.logits, axis=1, keepdims=True)
        a1 = other.logits - U.max(other.logits, axis=1, keepdims=True)
        ea0 = tf.exp(a0)
        ea1 = tf.exp(a1)
        z0 = U.sum(ea0, axis=1, keepdims=True)
        z1 = U.sum(ea1, axis=1, keepdims=True)
        p0 = ea0 / z0  # softmax probabilities of self
        return U.sum(p0 * (a0 - tf.log(z0) - a1 + tf.log(z1)), axis=1)
    def entropy(self):
        # Same max-shift stabilization as kl().
        a0 = self.logits - U.max(self.logits, axis=1, keepdims=True)
        ea0 = tf.exp(a0)
        z0 = U.sum(ea0, axis=1, keepdims=True)
        p0 = ea0 / z0
        return U.sum(p0 * (tf.log(z0) - a0), axis=1)
    def sample(self):
        # Gumbel-max trick: argmax over Gumbel-perturbed logits samples from
        # softmax(logits).
        u = tf.random_uniform(tf.shape(self.logits))
        return U.argmax(self.logits - tf.log(-tf.log(u)), axis=1)
    @classmethod
    def fromflat(cls, flat):
        return cls(flat)
class SoftCategoricalPd(Pd):
    """Relaxed ("soft") categorical: samples and the mode are probability
    vectors over the categories rather than integer indices."""
    def __init__(self, logits):
        self.logits = logits
    def flatparam(self):
        return self.logits
    def mode(self):
        # Softmax of the logits (a probability vector), not an argmax index.
        return U.softmax(self.logits, axis=-1)
    def logp(self, x):
        # x is a soft (one-hot-like) label vector.
        return -tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=x)
    def kl(self, other):
        # KL(self || other) with max-shifted logits for numerical stability.
        a0 = self.logits - U.max(self.logits, axis=1, keepdims=True)
        a1 = other.logits - U.max(other.logits, axis=1, keepdims=True)
        ea0 = tf.exp(a0)
        ea1 = tf.exp(a1)
        z0 = U.sum(ea0, axis=1, keepdims=True)
        z1 = U.sum(ea1, axis=1, keepdims=True)
        p0 = ea0 / z0
        return U.sum(p0 * (a0 - tf.log(z0) - a1 + tf.log(z1)), axis=1)
    def entropy(self):
        a0 = self.logits - U.max(self.logits, axis=1, keepdims=True)
        ea0 = tf.exp(a0)
        z0 = U.sum(ea0, axis=1, keepdims=True)
        p0 = ea0 / z0
        return U.sum(p0 * (tf.log(z0) - a0), axis=1)
    def sample(self):
        # Gumbel-softmax style: softmax over Gumbel-perturbed logits.
        u = tf.random_uniform(tf.shape(self.logits))
        return U.softmax(self.logits - tf.log(-tf.log(u)), axis=-1)
    @classmethod
    def fromflat(cls, flat):
        return cls(flat)
class MultiCategoricalPd(Pd):
    """Joint distribution of several independent categoricals (MultiDiscrete
    spaces); `flat` concatenates the per-dimension logits."""
    def __init__(self, low, high, flat):
        self.flat = flat
        self.low = tf.constant(low, dtype=tf.int32)
        # One CategoricalPd per action dimension: flat is split into chunks of
        # size high - low + 1 along the last axis.
        self.categoricals = list(map(CategoricalPd, tf.split(flat, high - low + 1, axis=len(flat.get_shape()) - 1)))
    def flatparam(self):
        return self.flat
    def mode(self):
        # Per-dimension argmax, shifted back into the [low, high] range.
        return self.low + tf.cast(tf.stack([p.mode() for p in self.categoricals], axis=-1), tf.int32)
    def logp(self, x):
        # Independent dimensions: joint log-prob is the sum over dimensions.
        return tf.add_n([p.logp(px) for p, px in zip(self.categoricals, tf.unstack(x - self.low, axis=len(x.get_shape()) - 1))])
    def kl(self, other):
        return tf.add_n([
            p.kl(q) for p, q in zip(self.categoricals, other.categoricals)
        ])
    def entropy(self):
        return tf.add_n([p.entropy() for p in self.categoricals])
    def sample(self):
        return self.low + tf.cast(tf.stack([p.sample() for p in self.categoricals], axis=-1), tf.int32)
    @classmethod
    def fromflat(cls, flat):
        # NOTE(review): __init__ requires (low, high, flat), so cls(flat) alone
        # looks inconsistent -- verify whether any caller uses this path.
        return cls(flat)
class SoftMultiCategoricalPd(Pd):  # doesn't work yet
    """Relaxed multi-categorical: per-dimension soft (probability-vector)
    outputs concatenated along the last axis.  Marked "doesn't work yet"
    by the original authors."""
    def __init__(self, low, high, flat):
        self.flat = flat
        self.low = tf.constant(low, dtype=tf.float32)
        # One SoftCategoricalPd per action dimension.
        self.categoricals = list(map(SoftCategoricalPd, tf.split(flat, high - low + 1, axis=len(flat.get_shape()) - 1)))
    def flatparam(self):
        return self.flat
    def mode(self):
        # Concatenate the per-dimension softmax vectors, offset by low.
        x = []
        for i in range(len(self.categoricals)):
            x.append(self.low[i] + self.categoricals[i].mode())
        return tf.concat(x, axis=-1)
    def logp(self, x):
        # Independent dimensions: joint log-prob is the sum over dimensions.
        return tf.add_n([p.logp(px) for p, px in zip(self.categoricals, tf.unstack(x - self.low, axis=len(x.get_shape()) - 1))])
    def kl(self, other):
        return tf.add_n([
            p.kl(q) for p, q in zip(self.categoricals, other.categoricals)
        ])
    def entropy(self):
        return tf.add_n([p.entropy() for p in self.categoricals])
    def sample(self):
        x = []
        for i in range(len(self.categoricals)):
            x.append(self.low[i] + self.categoricals[i].sample())
        return tf.concat(x, axis=-1)
    @classmethod
    def fromflat(cls, flat):
        # NOTE(review): __init__ requires (low, high, flat); see
        # MultiCategoricalPd.fromflat for the same inconsistency.
        return cls(flat)
class DiagGaussianPd(Pd):
    """Diagonal multivariate Gaussian; `flat` is [mean, logstd] concatenated
    along axis 1."""
    def __init__(self, flat):
        self.flat = flat
        mean, logstd = tf.split(axis=1, num_or_size_splits=2, value=flat)
        self.mean = mean
        self.logstd = logstd
        self.std = tf.exp(logstd)
    def flatparam(self):
        return self.flat
    def mode(self):
        # The mode of a Gaussian is its mean.
        return self.mean
    def logp(self, x):
        # Standard diagonal-Gaussian log-density.
        return - 0.5 * U.sum(tf.square((x - self.mean) / self.std), axis=1) \
               - 0.5 * np.log(2.0 * np.pi) * tf.to_float(tf.shape(x)[1]) \
               - U.sum(self.logstd, axis=1)
    def kl(self, other):
        # Closed-form KL divergence between diagonal Gaussians.
        assert isinstance(other, DiagGaussianPd)
        return U.sum(other.logstd - self.logstd + (tf.square(self.std) + tf.square(self.mean - other.mean)) / (2.0 * tf.square(other.std)) - 0.5, axis=1)
    def entropy(self):
        # Closed-form entropy of a diagonal Gaussian.
        return U.sum(self.logstd + .5 * np.log(2.0 * np.pi * np.e), 1)
    def sample(self):
        # Reparametrized sample: mean + std * standard normal noise.
        return self.mean + self.std * tf.random_normal(tf.shape(self.mean))
    @classmethod
    def fromflat(cls, flat):
        return cls(flat)
class BernoulliPd(Pd):
    """Vector of independent Bernoulli variables parametrized by logits."""
    def __init__(self, logits):
        self.logits = logits
        self.ps = tf.sigmoid(logits)
    def flatparam(self):
        return self.logits
    def mode(self):
        # Round each probability to the nearer of {0, 1}.
        return tf.round(self.ps)
    def logp(self, x):
        return - U.sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=tf.to_float(x)), axis=1)
    def kl(self, other):
        # KL(self||other) expressed as a difference of cross-entropies under
        # self's probabilities.
        return U.sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=other.logits, labels=self.ps), axis=1) - U.sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=self.ps), axis=1)
    def entropy(self):
        # Cross-entropy of the distribution with itself equals its entropy.
        return U.sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=self.ps), axis=1)
    def sample(self):
        # Sample 1 where uniform noise falls below the success probability.
        p = tf.sigmoid(self.logits)
        u = tf.random_uniform(tf.shape(p))
        return tf.to_float(math_ops.less(u, p))
    @classmethod
    def fromflat(cls, flat):
        return cls(flat)
def make_pdtype(ac_space):
    """Return the PdType family matching a gym action space."""
    from gym import spaces
    if isinstance(ac_space, spaces.Box):
        assert len(ac_space.shape) == 1
        return DiagGaussianPdType(ac_space.shape[0])
    if isinstance(ac_space, spaces.Discrete):
        # Soft (relaxed) categorical is used instead of CategoricalPdType.
        return SoftCategoricalPdType(ac_space.n)
    if isinstance(ac_space, MultiDiscrete):
        # Soft variant used instead of MultiCategoricalPdType.
        return SoftMultiCategoricalPdType(ac_space.low, ac_space.high)
    if isinstance(ac_space, spaces.MultiBinary):
        return BernoulliPdType(ac_space.n)
    raise NotImplementedError
def shape_el(v, i):
    """Size of dimension i of tensor v: static if known, else the dynamic shape op."""
    static = v.get_shape()[i]
    return static if static is not None else tf.shape(v)[i]
================================================
FILE: examples/Social_Cognition/MAToM-SNN/MPE/common/tile_images.py
================================================
import numpy as np
def tile_images(img_nhwc):
    """
    Tile N images into one big PxQ image.

    (P, Q) are chosen as close to each other as possible (P = Q when N is a
    perfect square); missing grid cells are filled with zeros.

    input: img_nhwc, list or array of images, ndim=4 once turned into array
        n = batch index, h = height, w = width, c = channel
    returns:
        bigim_HWc, ndarray with ndim=3
    """
    imgs = np.asarray(img_nhwc)
    n, h, w, c = imgs.shape
    grid_h = int(np.ceil(np.sqrt(n)))
    grid_w = int(np.ceil(float(n) / grid_h))
    # Pad with all-zero images so the batch fills the grid exactly.
    padding = [imgs[0] * 0 for _ in range(n, grid_h * grid_w)]
    imgs = np.array(list(imgs) + padding)
    # (GH, GW, h, w, c) -> (GH, h, GW, w, c) -> (GH*h, GW*w, c)
    grid = imgs.reshape(grid_h, grid_w, h, w, c)
    grid = grid.transpose(0, 2, 1, 3, 4)
    return grid.reshape(grid_h * h, grid_w * w, c)
================================================
FILE: examples/Social_Cognition/MAToM-SNN/MPE/common/vec_env/__init__.py
================================================
================================================
FILE: examples/Social_Cognition/MAToM-SNN/MPE/common/vec_env/vec_env.py
================================================
import contextlib
import os
from abc import ABC, abstractmethod
from common.tile_images import tile_images
class AlreadySteppingError(Exception):
    """
    Raised when an asynchronous step is running while
    step_async() is called again.
    """

    def __init__(self):
        super().__init__('already running an async step')
class NotSteppingError(Exception):
    """
    Raised when an asynchronous step is not running but
    step_wait() is called.
    """

    def __init__(self):
        super().__init__('not running an async step')
class VecEnv(ABC):
    """
    An abstract asynchronous, vectorized environment.

    Batches data from multiple copies of an environment so each observation is
    a batch of observations, and the expected action is a batch of actions to
    be applied per-environment.
    """
    closed = False
    viewer = None
    metadata = {
        'render.modes': ['human', 'rgb_array']
    }

    def __init__(self, num_envs, observation_space, action_space):
        self.num_envs = num_envs
        self.observation_space = observation_space
        self.action_space = action_space

    @abstractmethod
    def reset(self):
        """
        Reset all the environments and return an array (or dict of arrays) of
        observations.

        If step_async is still doing work, that work is cancelled and
        step_wait() should not be called until step_async() is invoked again.
        """
        pass

    @abstractmethod
    def step_async(self, actions):
        """
        Tell all the environments to start taking a step with the given
        actions; call step_wait() to collect the results.  Must not be called
        while a previous step_async run is still pending.
        """
        pass

    @abstractmethod
    def step_wait(self):
        """
        Wait for the step taken with step_async().

        Returns (obs, rews, dones, infos):
         - obs: array (or dict of arrays) of observations
         - rews: array of rewards
         - dones: array of "episode done" booleans
         - infos: sequence of info objects
        """
        pass

    def close_extras(self):
        """Subclass hook: clean up resources beyond this base class.

        Only runs when not self.closed.
        """
        pass

    def close(self):
        # Idempotent: subsequent calls are no-ops.
        if self.closed:
            return
        if self.viewer is not None:
            self.viewer.close()
        self.close_extras()
        self.closed = True

    def step(self, actions):
        """Step the environments synchronously (backwards compatibility)."""
        self.step_async(actions)
        return self.step_wait()

    def render(self, mode='human'):
        """Render all environments tiled into a single image."""
        big_image = tile_images(self.get_images())
        if mode == 'rgb_array':
            return big_image
        if mode == 'human':
            self.get_viewer().imshow(big_image)
            return self.get_viewer().isopen
        raise NotImplementedError

    def get_images(self):
        """Return RGB images from each environment."""
        raise NotImplementedError

    @property
    def unwrapped(self):
        # Peel off VecEnvWrapper layers down to the innermost VecEnv.
        return self.venv.unwrapped if isinstance(self, VecEnvWrapper) else self

    def get_viewer(self):
        # Viewer is created lazily so headless runs never import rendering.
        if self.viewer is None:
            from gym.envs.classic_control import rendering
            self.viewer = rendering.SimpleImageViewer()
        return self.viewer
class VecEnvWrapper(VecEnv):
    """
    An environment wrapper that applies to an entire batch of environments at
    once, delegating to the wrapped vectorized env by default.
    """

    def __init__(self, venv, observation_space=None, action_space=None):
        self.venv = venv
        # Spaces default to the wrapped env's spaces unless overridden.
        super().__init__(num_envs=venv.num_envs,
                         observation_space=observation_space or venv.observation_space,
                         action_space=action_space or venv.action_space)

    def step_async(self, actions):
        self.venv.step_async(actions)

    @abstractmethod
    def reset(self):
        pass

    @abstractmethod
    def step_wait(self):
        pass

    def close(self):
        return self.venv.close()

    def render(self, mode='human'):
        return self.venv.render(mode=mode)

    def get_images(self):
        return self.venv.get_images()

    def __getattr__(self, name):
        # Forward unknown public attributes to the wrapped env; private names
        # are never forwarded so missing internals fail loudly.
        if name.startswith('_'):
            raise AttributeError("attempted to get missing private attribute '{}'".format(name))
        return getattr(self.venv, name)
class VecEnvObservationWrapper(VecEnvWrapper):
    """Wrapper that transforms every batched observation through process()."""

    @abstractmethod
    def process(self, obs):
        """Map a raw batched observation to the transformed one."""
        pass

    def reset(self):
        return self.process(self.venv.reset())

    def step_wait(self):
        obs, rews, dones, infos = self.venv.step_wait()
        return self.process(obs), rews, dones, infos
class CloudpickleWrapper(object):
    """
    Uses cloudpickle to serialize contents (otherwise multiprocessing tries
    to use pickle, which cannot handle lambdas/closures).
    """

    def __init__(self, x):
        self.x = x

    def __getstate__(self):
        # Serialize with cloudpickle so env-constructor closures survive.
        import cloudpickle
        return cloudpickle.dumps(self.x)

    def __setstate__(self, state):
        # cloudpickle output is plain-pickle loadable, so stdlib pickle
        # suffices on the receiving side.
        import pickle
        self.x = pickle.loads(state)
@contextlib.contextmanager
def clear_mpi_env_vars():
    """
    Temporarily remove MPI-related environment variables (OMPI_*, PMI_*).

    `from mpi4py import MPI` calls MPI_Init by default; if a child process
    inherits these variables, MPI assumes it is an MPI process like the parent
    and can misbehave (e.g. hang).  This context manager is a hacky way to
    clear them while starting multiprocessing Processes, restoring the
    original values afterwards.
    """
    saved = {}
    for key in list(os.environ.keys()):
        if key.startswith(('OMPI_', 'PMI_')):
            saved[key] = os.environ.pop(key)
    try:
        yield
    finally:
        os.environ.update(saved)
================================================
FILE: examples/Social_Cognition/MAToM-SNN/MPE/main.py
================================================
import argparse
import torch
import time
import os
import numpy as np
from gym.spaces import Box, Discrete, MultiDiscrete
from pathlib import Path
from torch.autograd import Variable
from tensorboardX import SummaryWriter
from utils.make_env import make_env
from utils.buffer import ReplayBuffer, ReplayBuffer_pre
from utils.env_wrappers import SubprocVecEnv, DummyVecEnv
from policy.maddpg import MADDPG, MADDPG_SNN, MADDPG_ToM, ToM_SA, ToM_S, ToM_self
from tqdm import tqdm
def get_common_args():
    """Parse and return the experiment's command-line configuration.

    Returns an argparse.Namespace carrying all options below plus a computed
    `device` attribute ('cuda:<cuda_num>').

    BUGFIX: the original parsed argv twice and rebound `parser` to the parsed
    namespace; the --device option is now derived once after a single parse,
    producing an identical namespace.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--env_id", default='simple_world_comm', type=str,
                        choices=['simple_tag', 'simple_adversary', 'simple_push', 'simple_world_comm'],
                        help="Name of environment")
    parser.add_argument("--model_name", default='ann', type=str,
                        help="Name of directory to store " +
                             "model/training contents")  # ToM_SA
    parser.add_argument("--seed",
                        default=1, type=int,
                        help="Random seed")
    parser.add_argument("--cuda_num",
                        default=7, type=int,
                        help="device")
    parser.add_argument("--output_style",
                        default='sum', type=str,
                        choices=['sum', 'voltage'])
    parser.add_argument("--n_rollout_threads", default=20, type=int)
    parser.add_argument("--n_training_threads", default=6, type=int)
    parser.add_argument("--buffer_length", default=int(1e6), type=int)
    parser.add_argument("--n_episodes", default=15000, type=int)
    parser.add_argument("--episode_length", default=25, type=int)
    parser.add_argument("--steps_per_update", default=100, type=int)
    parser.add_argument("--batch_size",
                        default=1024, type=int,
                        help="Batch size for model training")
    parser.add_argument("--n_exploration_eps", default=25000, type=int)
    parser.add_argument("--init_noise_scale", default=0.3, type=float)
    parser.add_argument("--final_noise_scale", default=0.0, type=float)
    parser.add_argument("--save_interval", default=1000, type=int)
    parser.add_argument("--hidden_dim", default=64, type=int)
    parser.add_argument("--lr", default=0.01, type=float)
    parser.add_argument("--tau", default=0.01, type=float)
    parser.add_argument("--agent_alg",
                        default="MADDPG_ToM", type=str,
                        choices=['MADDPG', 'DDPG', 'MADDPG_SNN', 'MADDPG_ToM', 'ToM_SA', 'ToM_S', 'ToM_self'])
    parser.add_argument("--adversary_alg",
                        default="MADDPG_ToM", type=str,
                        choices=['MADDPG', 'DDPG', 'MADDPG_SNN', 'MADDPG_ToM', 'ToM_SA', 'ToM_S', 'ToM_self'])
    parser.add_argument("--discrete_action",
                        action='store_true')
    args = parser.parse_args()
    # Derived option: GPU device string based on the chosen CUDA index.
    args.device = 'cuda:{}'.format(args.cuda_num)
    return args
# True when a CUDA device is visible; selects training placement ('gpu'/'cpu')
# and whether to cap torch CPU threads in run().
USE_CUDA = torch.cuda.is_available()
def make_parallel_env(env_id, n_rollout_threads, seed, discrete_action):
    """Build a vectorized MPE env: in-process for 1 thread, subprocesses otherwise."""
    def get_env_fn(rank):
        # Each worker gets its own seed offset so rollouts are decorrelated.
        def init_env():
            env = make_env(env_id, discrete_action=discrete_action)
            env.seed(seed + rank * 1000)
            np.random.seed(seed + rank * 1000)
            return env
        return init_env
    if n_rollout_threads == 1:
        return DummyVecEnv([get_env_fn(0)])
    return SubprocVecEnv([get_env_fn(rank) for rank in range(n_rollout_threads)])
def run(config):
    """Train a MADDPG-family multi-agent system on an MPE environment.

    Creates a fresh run directory under ./models/<env_id>/<model_name>,
    instantiates the algorithm named by config.agent_alg, collects parallel
    rollouts, trains from a replay buffer, logs to TensorBoard, and
    periodically checkpoints the model.
    """
    pbar = tqdm(config.n_episodes)
    # Pick the next unused 'runN' directory for this env/model combination.
    model_dir = Path('./models') / config.env_id / config.model_name
    if not model_dir.exists():
        curr_run = 'run1'
    else:
        exst_run_nums = [int(str(folder.name).split('run')[1]) for folder in
                         model_dir.iterdir() if
                         str(folder.name).startswith('run')]
        if len(exst_run_nums) == 0:
            curr_run = 'run1'
        else:
            curr_run = 'run%i' % (max(exst_run_nums) + 1)
    run_dir = model_dir / curr_run
    log_dir = run_dir / 'logs'
    os.makedirs(log_dir)
    logger = SummaryWriter(str(log_dir))
    torch.manual_seed(config.seed)
    np.random.seed(config.seed)
    if not USE_CUDA:
        torch.set_num_threads(config.n_training_threads)
    env = make_parallel_env(config.env_id, config.n_rollout_threads, config.seed,
                            config.discrete_action)
    # The ToM-style algorithms also feed the previous joint action into
    # step()/the buffer; group their names once instead of re-testing four
    # string equalities at every use site.
    tom_algs = ('ToM_SA', 'ToM_S', 'ToM_self', 'ToM_SB')
    if config.agent_alg == 'MADDPG' or config.agent_alg == 'DDPG':
        print('_____MADDPG_____')
        maddpg = MADDPG.init_from_env(env, agent_alg=config.agent_alg,
                                      adversary_alg=config.adversary_alg,
                                      tau=config.tau,
                                      lr=config.lr,
                                      hidden_dim=config.hidden_dim,
                                      device=config.device)
    elif config.agent_alg == 'MADDPG_SNN':
        print('_____MADDPG_SNN_____')
        maddpg = MADDPG_SNN.init_from_env(env, agent_alg=config.agent_alg,
                                          adversary_alg=config.adversary_alg,
                                          tau=config.tau,
                                          lr=config.lr,
                                          hidden_dim=config.hidden_dim,
                                          output_style=config.output_style,
                                          device=config.device)
    elif config.agent_alg == 'MADDPG_ToM':
        print('_____MADDPG_ToM_____')
        maddpg = MADDPG_ToM.init_from_env(env, agent_alg=config.agent_alg,
                                          adversary_alg=config.adversary_alg,
                                          tau=config.tau,
                                          lr=config.lr,
                                          hidden_dim=config.hidden_dim,
                                          output_style=config.output_style,
                                          device=config.device)
    elif config.agent_alg == 'ToM_SA':
        print('_______ToM_SA_______')
        maddpg = ToM_SA.init_from_env(env, agent_alg=config.agent_alg,
                                      adversary_alg=config.adversary_alg,
                                      tau=config.tau,
                                      lr=config.lr,
                                      hidden_dim=config.hidden_dim,
                                      output_style=config.output_style,
                                      device=config.device)
    elif config.agent_alg == 'ToM_S':
        print('_______ToM_S_______')
        maddpg = ToM_S.init_from_env(env, agent_alg=config.agent_alg,
                                     adversary_alg=config.adversary_alg,
                                     tau=config.tau,
                                     lr=config.lr,
                                     hidden_dim=config.hidden_dim,
                                     output_style=config.output_style,
                                     device=config.device)
    else:
        # Catch-all branch (covers 'ToM_self' and the 'ToM_SB' value set by
        # __main__).  BUGFIX(review): in the original this block had no
        # else/elif guard, so ToM_self would unconditionally overwrite
        # whichever algorithm the earlier branches built.
        print('_______ToM_self_______')
        maddpg = ToM_self.init_from_env(env, agent_alg=config.agent_alg,
                                        adversary_alg=config.adversary_alg,
                                        tau=config.tau,
                                        lr=config.lr,
                                        hidden_dim=config.hidden_dim,
                                        output_style=config.output_style,
                                        device=config.device)
    if config.agent_alg in tom_algs:
        # ToM variants store the previous action alongside each transition.
        replay_buffer = ReplayBuffer_pre(config.buffer_length, maddpg.nagents,
                                         [obsp.shape[0] for obsp in env.observation_space],
                                         [acsp.n if isinstance(acsp, Discrete) else sum(acsp.high - acsp.low + 1)
                                          for acsp in env.action_space],
                                         device=config.device)
    else:
        replay_buffer = ReplayBuffer(config.buffer_length, maddpg.nagents,
                                     [obsp.shape[0] for obsp in env.observation_space],
                                     [acsp.n if isinstance(acsp, Discrete) else sum(acsp.high - acsp.low + 1)
                                      for acsp in env.action_space],
                                     device=config.device)
    t = 0
    # NOTE(review): total_reward is saved at the end but never appended to,
    # so the per-agent .npy files contain empty arrays -- confirm intent.
    total_reward = [[] for _ in range(maddpg.nagents)]
    for ep_i in range(0, config.n_episodes, config.n_rollout_threads):
        obs = env.reset()
        # obs.shape = (n_rollout_threads, nagent)(nobs); nobs differs per
        # agent, so obs is not a tensor.
        maddpg.prep_rollouts(device='cpu')
        # Previous joint action starts as zeros (5-dim per agent).
        torch_agent_actions = [torch.zeros((config.n_rollout_threads, 5))
                               for _ in range(maddpg.nagents)]
        # Linearly anneal exploration noise over n_exploration_eps episodes.
        explr_pct_remaining = max(0, config.n_exploration_eps - ep_i) / config.n_exploration_eps
        maddpg.scale_noise(config.final_noise_scale +
                           (config.init_noise_scale - config.final_noise_scale) * explr_pct_remaining)
        maddpg.reset_noise()
        # Per-episode trajectories (kept for parity with the original code;
        # currently written but not consumed).
        obs_ep = []
        agent_actions_ep = []
        rewards_ep = []
        next_obs_ep = []
        dones_ep = []
        for et_i in range(config.episode_length):
            torch_agent_actions_pre = [ac.data.numpy() for ac in torch_agent_actions]
            # Rearrange observations to be per agent, as torch Variables.
            torch_obs = [Variable(torch.Tensor(np.vstack(obs[:, i])),
                                  requires_grad=False)
                         for i in range(maddpg.nagents)]
            # ToM variants condition the step on the previous joint action.
            if config.agent_alg in tom_algs:
                torch_agent_actions = maddpg.step(torch_obs, torch_agent_actions, explore=True)
            else:
                torch_agent_actions = maddpg.step(torch_obs, explore=True)
            agent_actions = [ac.data.numpy() for ac in torch_agent_actions]
            # Rearrange actions to be per environment.
            actions = [[ac[i] for ac in agent_actions] for i in range(config.n_rollout_threads)]
            next_obs, rewards, dones, infos = env.step(actions)
            obs_ep.append(obs)                  # episode_id, process, n_agents, dim
            agent_actions_ep.append(actions)    # episode_id, n_agents, process, dim
            rewards_ep.append(rewards)          # episode_id, process, n_agents
            next_obs_ep.append(next_obs)        # episode_id, process, n_agents, dim
            dones_ep.append(dones)              # episode_id, process, n_agents
            if config.agent_alg in tom_algs:
                replay_buffer.push(torch_agent_actions_pre, obs, agent_actions, rewards, next_obs, dones)
            else:
                replay_buffer.push(obs, agent_actions, rewards, next_obs, dones)
            obs = next_obs
            t += config.n_rollout_threads
            # Train once enough data exists and the update cadence is hit.
            if (len(replay_buffer) >= config.batch_size and
                    (t % config.steps_per_update) < config.n_rollout_threads):
                if USE_CUDA:
                    maddpg.prep_training(device='gpu')
                else:
                    maddpg.prep_training(device='cpu')
                # Fewer update sweeps on long runs to bound wall-clock time.
                if config.n_episodes > 300:
                    rollout = 2
                else:
                    rollout = config.n_rollout_threads
                for u_i in range(rollout):
                    for a_i in range(maddpg.nagents):
                        sample = replay_buffer.sample(config.batch_size,
                                                      to_gpu=USE_CUDA)
                        maddpg.update(sample, a_i, logger=logger)
                    maddpg.update_all_targets()
                maddpg.prep_rollouts(device='cpu')
        # Log mean rewards over the episodes just collected.
        ep_rews = replay_buffer.get_average_rewards(
            config.episode_length * config.n_rollout_threads)
        for a_i, a_ep_rew in enumerate(ep_rews):
            logger.add_scalar('agent%i/mean_episode_rewards' % a_i,
                              a_ep_rew,
                              ep_i)
        logger.add_scalar('agent_mean/mean_episode_rewards',
                          np.mean(ep_rews),
                          ep_i)
        # Periodic checkpoint of the full model.
        if ep_i % config.save_interval < config.n_rollout_threads:
            os.makedirs(run_dir / 'incremental', exist_ok=True)
            maddpg.save(run_dir / 'incremental' / ('model_ep%i.pt' % (ep_i + 1)))
            maddpg.save(run_dir / 'model.pt')
        pbar.update(config.n_rollout_threads)
    pbar.close()
    maddpg.save(run_dir / 'model.pt')
    env.close()
    logger.export_scalars_to_json(str(log_dir / 'summary.json'))
    logger.close()
    for a_i, reward in enumerate(total_reward):
        reward_dir = str(log_dir) + '/agent{}/mean_episode_rewards'.format(a_i) + '/episode_rewards_{}'.format(config.cuda_num)
        # np.save appends '.npy', so the file lands next to the directory
        # created here (kept for behavioral parity with the original).
        os.makedirs(reward_dir)
        np.save(reward_dir, reward)
if __name__ == '__main__':
    config = get_common_args()
    # config.env_id = 'simple_tag'
    # # config.model_name = 'ma2c'
    # NOTE(review): the CLI values are overridden here; 'ToM_SB' is not in the
    # argparse choices list and is handled by run()'s ToM-group checks.
    config.agent_alg = 'ToM_SB'#
    config.adversary_alg = 'ToM_SB'
    #
    run(config)
================================================
FILE: examples/Social_Cognition/MAToM-SNN/MPE/multiagent/__init__.py
================================================
from gym.envs.registration import register

# Multiagent particle environments
# --------------------------------
# NOTE: max_episode_steps currently has to be exactly the max_path_length
# parameter in the rllab run script (original FIXME by cathywu).
_ENV_SPECS = [
    ('MultiagentSimple-v0', 'multiagent.envs:SimpleEnv'),
    ('MultiagentSimpleSpeakerListener-v0',
     'multiagent.envs:SimpleSpeakerListenerEnv'),
]
for _env_id, _entry_point in _ENV_SPECS:
    register(id=_env_id, entry_point=_entry_point, max_episode_steps=100)
================================================
FILE: examples/Social_Cognition/MAToM-SNN/MPE/multiagent/scenarios/__init__.py
================================================
import imp
import os.path as osp
def load(name):
    """Load a scenario module by file name from this package's directory.

    Args:
        name (str): file name of the scenario (e.g. ``'simple_tag.py'``),
            resolved relative to this package's directory.

    Returns:
        module: the freshly executed scenario module (callers typically
        instantiate ``module.Scenario()``).

    Note:
        Previously implemented with ``imp.load_source``; the ``imp`` module is
        deprecated since Python 3.4 and removed in 3.12, so this now uses
        ``importlib.util`` with identical observable behaviour.
    """
    import importlib.util
    import os.path as osp
    pathname = osp.join(osp.dirname(__file__), name)
    # Use the file's stem as the module name (imp.load_source used '').
    module_name = osp.splitext(osp.basename(name))[0]
    spec = importlib.util.spec_from_file_location(module_name, pathname)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
================================================
FILE: examples/Social_Cognition/MAToM-SNN/MPE/multiagent/scenarios/simple.py
================================================
import numpy as np
from multiagent.core import World, Agent, Landmark
from multiagent.scenario import BaseScenario
class Scenario(BaseScenario):
    """Minimal scenario: a single agent is rewarded for reaching a single landmark."""

    def make_world(self):
        """Create a world with one non-colliding agent and one static landmark."""
        world = World()
        # One non-colliding, non-communicating agent.
        world.agents = [Agent()]
        for idx, agent in enumerate(world.agents):
            agent.name = 'agent %d' % idx
            agent.collide = False
            agent.silent = True
        # One immovable, non-colliding landmark.
        world.landmarks = [Landmark()]
        for idx, landmark in enumerate(world.landmarks):
            landmark.name = 'landmark %d' % idx
            landmark.collide = False
            landmark.movable = False
        # Put everything into its initial state.
        self.reset_world(world)
        return world

    def reset_world(self, world):
        """Recolour entities and draw fresh random positions/velocities."""
        # Agents are dark grey.
        for agent in world.agents:
            agent.color = np.array([0.25, 0.25, 0.25])
        # Landmarks are light grey, except the goal landmark which is reddish.
        for landmark in world.landmarks:
            landmark.color = np.array([0.75, 0.75, 0.75])
        world.landmarks[0].color = np.array([0.75, 0.25, 0.25])
        # Random agent positions, zero velocities, silent channels.
        for agent in world.agents:
            agent.state.p_pos = np.random.uniform(-1, +1, world.dim_p)
            agent.state.p_vel = np.zeros(world.dim_p)
            agent.state.c = np.zeros(world.dim_c)
        for landmark in world.landmarks:
            landmark.state.p_pos = np.random.uniform(-1, +1, world.dim_p)
            landmark.state.p_vel = np.zeros(world.dim_p)

    def reward(self, agent, world):
        """Negative squared Euclidean distance to the (single) landmark."""
        delta = agent.state.p_pos - world.landmarks[0].state.p_pos
        return -np.sum(np.square(delta))

    def observation(self, agent, world):
        """Own velocity followed by each landmark's position relative to the agent."""
        relative_landmarks = [lm.state.p_pos - agent.state.p_pos
                              for lm in world.landmarks]
        return np.concatenate([agent.state.p_vel] + relative_landmarks)
================================================
FILE: examples/Social_Cognition/MAToM-SNN/MPE/multiagent/scenarios/simple_crypto.py
================================================
"""
Scenario:
1 speaker, 2 listeners (one of which is an adversary). Good agents rewarded for proximity to goal, and distance from
adversary to goal. Adversary is rewarded for its distance to the goal.
"""
import numpy as np
from multiagent.core import World, Agent, Landmark
from multiagent.scenario import BaseScenario
import random
class CryptoAgent(Agent):
    """Agent subclass that additionally carries an encryption key."""

    def __init__(self):
        super().__init__()
        # The key is assigned during reset_world; None means "no key yet".
        self.key = None
class Scenario(BaseScenario):
    """Covert-communication ("crypto") scenario.

    Alice (agent 2, the speaker) must communicate a goal colour to Bob
    (agent 1, the good listener) using a shared private key, while Eve
    (agent 0, the adversary) eavesdrops on the channel without the key.
    """

    def make_world(self):
        """Build a world with three stationary agents (one adversary, one
        speaker) and two landmarks, then reset it."""
        world = World()
        # set any world properties first
        num_agents = 3
        num_adversaries = 1
        num_landmarks = 2
        world.dim_c = 4  # width of the communication channel / colour code
        # add agents
        world.agents = [CryptoAgent() for i in range(num_agents)]
        for i, agent in enumerate(world.agents):
            agent.name = 'agent %d' % i
            agent.collide = False
            agent.adversary = True if i < num_adversaries else False
            agent.speaker = True if i == 2 else False
            agent.movable = False  # all agents are fixed; only communication matters
        # add landmarks
        world.landmarks = [Landmark() for i in range(num_landmarks)]
        for i, landmark in enumerate(world.landmarks):
            landmark.name = 'landmark %d' % i
            landmark.collide = False
            landmark.movable = False
        # make initial conditions
        self.reset_world(world)
        return world

    def reset_world(self, world):
        """Redraw the goal, the private key and all colours, and randomise
        every entity's state."""
        # random properties for agents (adversary rendered reddish)
        for i, agent in enumerate(world.agents):
            agent.color = np.array([0.25, 0.25, 0.25])
            if agent.adversary:
                agent.color = np.array([0.75, 0.25, 0.25])
            agent.key = None
        # random properties for landmarks: each landmark gets a one-hot colour
        color_list = [np.zeros(world.dim_c) for i in world.landmarks]
        for i, color in enumerate(color_list):
            color[i] += 1
        for color, landmark in zip(color_list, world.landmarks):
            landmark.color = color
        # set goal landmark; the listener (agent 1) is painted with the goal colour
        goal = np.random.choice(world.landmarks)
        world.agents[1].color = goal.color
        # the speaker (agent 2) receives a random landmark colour as private key
        world.agents[2].key = np.random.choice(world.landmarks).color
        for agent in world.agents:
            agent.goal_a = goal
        # set random initial states
        for agent in world.agents:
            agent.state.p_pos = np.random.uniform(-1, +1, world.dim_p)
            agent.state.p_vel = np.zeros(world.dim_p)
            agent.state.c = np.zeros(world.dim_c)
        for i, landmark in enumerate(world.landmarks):
            landmark.state.p_pos = np.random.uniform(-1, +1, world.dim_p)
            landmark.state.p_vel = np.zeros(world.dim_p)

    def benchmark_data(self, agent, world):
        # returns data for benchmarking purposes: (message sent, true goal colour)
        return (agent.state.c, agent.goal_a.color)

    # return all good agents that listen (neither adversary nor speaker)
    def good_listeners(self, world):
        return [agent for agent in world.agents if not agent.adversary and not agent.speaker]

    # return all agents that are not adversaries
    def good_agents(self, world):
        return [agent for agent in world.agents if not agent.adversary]

    # return all adversarial agents
    def adversaries(self, world):
        return [agent for agent in world.agents if agent.adversary]

    def reward(self, agent, world):
        """Dispatch to the role-specific reward function."""
        return self.adversary_reward(agent, world) if agent.adversary else self.agent_reward(agent, world)

    def agent_reward(self, agent, world):
        # Agents rewarded if Bob can reconstruct message, but adversary (Eve) cannot
        good_listeners = self.good_listeners(world)
        adversaries = self.adversaries(world)
        good_rew = 0
        adv_rew = 0
        for a in good_listeners:
            # an all-zero channel means "nothing said yet" — contributes nothing
            if (a.state.c == np.zeros(world.dim_c)).all():
                continue
            else:
                # penalise Bob's reconstruction error
                good_rew -= np.sum(np.square(a.state.c - agent.goal_a.color))
        for a in adversaries:
            if (a.state.c == np.zeros(world.dim_c)).all():
                continue
            else:
                # reward Eve's reconstruction error (she should stay confused)
                adv_l1 = np.sum(np.square(a.state.c - agent.goal_a.color))
                adv_rew += adv_l1
        return adv_rew + good_rew

    def adversary_reward(self, agent, world):
        # Adversary (Eve) is rewarded if it can reconstruct original goal
        rew = 0
        if not (agent.state.c == np.zeros(world.dim_c)).all():
            rew -= np.sum(np.square(agent.state.c - agent.goal_a.color))
        return rew

    def observation(self, agent, world):
        """Role-dependent observation: speaker sees goal + key, listener sees
        key + message, adversary sees only the message."""
        # goal color
        # NOTE(review): assumes World defines dim_color — confirm in multiagent.core
        goal_color = np.zeros(world.dim_color)
        if agent.goal_a is not None:
            goal_color = agent.goal_a.color
        # get positions of all entities in this agent's reference frame
        # (computed but not used by any of the returns below)
        entity_pos = []
        for entity in world.landmarks:
            entity_pos.append(entity.state.p_pos - agent.state.p_pos)
        # communication of all other agents (only the speaker's channel is heard)
        comm = []
        for other in world.agents:
            if other is agent or (other.state.c is None) or not other.speaker: continue
            comm.append(other.state.c)
        confer = np.array([0])
        if world.agents[2].key is None:
            # no key assigned yet: observe zeros and flag it via confer
            confer = np.array([1])
            key = np.zeros(world.dim_c)
            goal_color = np.zeros(world.dim_c)
        else:
            key = world.agents[2].key
        prnt = False  # debug-printing switch
        # speaker (Alice): sees the goal colour and the private key
        if agent.speaker:
            if prnt:
                print('speaker')
                print(agent.state.c)
                print(np.concatenate([goal_color] + [key] + [confer] + [np.random.randn(1)]))
            return np.concatenate([goal_color] + [key])
        # listener (Bob): sees the key and the transmitted message
        if not agent.speaker and not agent.adversary:
            if prnt:
                print('listener')
                print(agent.state.c)
                print(np.concatenate([key] + comm + [confer]))
            return np.concatenate([key] + comm)
        # adversary (Eve): hears the message but has no key
        if not agent.speaker and agent.adversary:
            if prnt:
                print('adversary')
                print(agent.state.c)
                print(np.concatenate(comm + [confer]))
            return np.concatenate(comm)
================================================
FILE: examples/Social_Cognition/MAToM-SNN/MPE/multiagent/scenarios/simple_push.py
================================================
import numpy as np
from multiagent.core import World, Agent, Landmark
from multiagent.scenario import BaseScenario
class Scenario(BaseScenario):
    """Push scenario: one adversary tries to keep the good agent away from its
    goal landmark while staying near the goal itself."""

    def make_world(self):
        """Two colliding agents (the first is the adversary) and two landmarks."""
        world = World()
        world.dim_c = 2
        n_agents, n_adversaries, n_landmarks = 2, 1, 2
        world.agents = [Agent() for _ in range(n_agents)]
        for idx, agent in enumerate(world.agents):
            agent.name = 'agent %d' % idx
            agent.collide = True
            agent.silent = True
            # the first n_adversaries agents are adversaries
            agent.adversary = idx < n_adversaries
        world.landmarks = [Landmark() for _ in range(n_landmarks)]
        for idx, landmark in enumerate(world.landmarks):
            landmark.name = 'landmark %d' % idx
            landmark.collide = False
            landmark.movable = False
        self.reset_world(world)
        return world

    def reset_world(self, world):
        """Pick a shared goal landmark, colour entities, and randomise states."""
        # Landmarks: dark base colour with one bright channel; remember index.
        for idx, landmark in enumerate(world.landmarks):
            landmark.color = np.array([0.1, 0.1, 0.1])
            landmark.color[idx + 1] += 0.8
            landmark.index = idx
        # Every agent shares the same randomly chosen goal landmark.
        goal = np.random.choice(world.landmarks)
        for agent in world.agents:
            agent.goal_a = goal
            agent.color = np.array([0.25, 0.25, 0.25])
            if agent.adversary:
                agent.color = np.array([0.75, 0.25, 0.25])
            else:
                # tint the good agent with its goal's bright channel
                agent.color[goal.index + 1] += 0.5
        # Random positions, zero velocity and communication.
        for agent in world.agents:
            agent.state.p_pos = np.random.uniform(-1, +1, world.dim_p)
            agent.state.p_vel = np.zeros(world.dim_p)
            agent.state.c = np.zeros(world.dim_c)
        for landmark in world.landmarks:
            landmark.state.p_pos = np.random.uniform(-1, +1, world.dim_p)
            landmark.state.p_vel = np.zeros(world.dim_p)

    def reward(self, agent, world):
        """Dispatch to the role-specific reward."""
        if agent.adversary:
            return self.adversary_reward(agent, world)
        return self.agent_reward(agent, world)

    def agent_reward(self, agent, world):
        """Negative Euclidean distance from the good agent to its goal."""
        return -np.sqrt(np.sum(np.square(agent.state.p_pos - agent.goal_a.state.p_pos)))

    def adversary_reward(self, agent, world):
        """Keep the nearest good agent far from the goal while staying close to it."""
        good_dists = [np.sqrt(np.sum(np.square(a.state.p_pos - a.goal_a.state.p_pos)))
                      for a in world.agents if not a.adversary]
        pos_rew = min(good_dists)
        # penalise the adversary's own distance to the goal
        neg_rew = np.sqrt(np.sum(np.square(agent.goal_a.state.p_pos - agent.state.p_pos)))
        return pos_rew - neg_rew

    def observation(self, agent, world):
        """Egocentric observation; good agents additionally see the goal offset
        and entity colours, adversaries see only positions."""
        entity_pos = [e.state.p_pos - agent.state.p_pos for e in world.landmarks]
        entity_color = [e.color for e in world.landmarks]
        comm, other_pos = [], []
        for other in world.agents:
            if other is agent:
                continue
            comm.append(other.state.c)
            other_pos.append(other.state.p_pos - agent.state.p_pos)
        if agent.adversary:
            return np.concatenate([agent.state.p_vel] + entity_pos + other_pos)
        return np.concatenate([agent.state.p_vel]
                              + [agent.goal_a.state.p_pos - agent.state.p_pos]
                              + [agent.color]
                              + entity_pos + entity_color + other_pos)
================================================
FILE: examples/Social_Cognition/MAToM-SNN/MPE/multiagent/scenarios/simple_reference.py
================================================
import numpy as np
from multiagent.core import World, Agent, Landmark
from multiagent.scenario import BaseScenario
class Scenario(BaseScenario):
    """Reference scenario: each agent must guide the *other* agent to a goal
    landmark through communication; rewards are shared."""

    def make_world(self):
        """Two communicating agents and three landmarks."""
        world = World()
        world.dim_c = 10
        world.collaborative = True  # agents share rewards
        world.agents = [Agent() for _ in range(2)]
        for idx, agent in enumerate(world.agents):
            agent.name = 'agent %d' % idx
            agent.collide = False
        world.landmarks = [Landmark() for _ in range(3)]
        for idx, landmark in enumerate(world.landmarks):
            landmark.name = 'landmark %d' % idx
            landmark.collide = False
            landmark.movable = False
        self.reset_world(world)
        return world

    def reset_world(self, world):
        """Assign crossed goals, colour entities, and randomise positions."""
        for agent in world.agents:
            agent.goal_a = None
            agent.goal_b = None
        # Each agent wants the *other* agent (goal_a) at a random landmark (goal_b).
        first, second = world.agents[0], world.agents[1]
        first.goal_a = second
        first.goal_b = np.random.choice(world.landmarks)
        second.goal_a = first
        second.goal_b = np.random.choice(world.landmarks)
        for agent in world.agents:
            agent.color = np.array([0.25, 0.25, 0.25])
        # red / green / blue landmarks
        world.landmarks[0].color = np.array([0.75, 0.25, 0.25])
        world.landmarks[1].color = np.array([0.25, 0.75, 0.25])
        world.landmarks[2].color = np.array([0.25, 0.25, 0.75])
        # paint each guided agent with its goal landmark's colour
        first.goal_a.color = first.goal_b.color
        second.goal_a.color = second.goal_b.color
        for agent in world.agents:
            agent.state.p_pos = np.random.uniform(-1, +1, world.dim_p)
            agent.state.p_vel = np.zeros(world.dim_p)
            agent.state.c = np.zeros(world.dim_c)
        for landmark in world.landmarks:
            landmark.state.p_pos = np.random.uniform(-1, +1, world.dim_p)
            landmark.state.p_vel = np.zeros(world.dim_p)

    def reward(self, agent, world):
        """Negative squared distance between the guided agent and its target
        landmark; 0.0 before goals are assigned."""
        if agent.goal_a is None or agent.goal_b is None:
            return 0.0
        delta = agent.goal_a.state.p_pos - agent.goal_b.state.p_pos
        return -np.sum(np.square(delta))

    def observation(self, agent, world):
        """Own velocity, landmark offsets, own goal colour, and the other
        agent's communication."""
        goal_color = [np.zeros(world.dim_color), np.zeros(world.dim_color)]
        if agent.goal_b is not None:
            goal_color[1] = agent.goal_b.color
        entity_pos = [e.state.p_pos - agent.state.p_pos for e in world.landmarks]
        comm = [other.state.c for other in world.agents if other is not agent]
        return np.concatenate([agent.state.p_vel] + entity_pos + [goal_color[1]] + comm)
================================================
FILE: examples/Social_Cognition/MAToM-SNN/MPE/multiagent/scenarios/simple_speaker_listener.py
================================================
import numpy as np
from multiagent.core import World, Agent, Landmark
from multiagent.scenario import BaseScenario
class Scenario(BaseScenario):
    """Speaker-listener scenario: an immobile speaker must communicate the goal
    landmark so the silent listener can navigate to it (fully cooperative)."""

    def make_world(self):
        """Create the speaker and listener agents plus three small landmarks.

        Returns:
            World: the initialised world (already reset once).
        """
        world = World()
        # set any world properties first
        world.dim_c = 3
        num_landmarks = 3
        world.collaborative = True
        # add agents
        world.agents = [Agent() for i in range(2)]
        for i, agent in enumerate(world.agents):
            agent.name = 'agent %d' % i
            agent.collide = False
            agent.size = 0.075
        # agent 0 is the speaker: it cannot move
        world.agents[0].movable = False
        # agent 1 is the listener: it cannot communicate
        world.agents[1].silent = True
        # add landmarks
        world.landmarks = [Landmark() for i in range(num_landmarks)]
        for i, landmark in enumerate(world.landmarks):
            landmark.name = 'landmark %d' % i
            landmark.collide = False
            landmark.movable = False
            landmark.size = 0.04
        # make initial conditions
        self.reset_world(world)
        return world

    def reset_world(self, world):
        """Re-assign the goal landmark, recolour entities, randomise states."""
        # assign goals to agents
        for agent in world.agents:
            agent.goal_a = None
            agent.goal_b = None
        # the speaker wants the listener (goal_a) to reach a random landmark (goal_b)
        world.agents[0].goal_a = world.agents[1]
        world.agents[0].goal_b = np.random.choice(world.landmarks)
        # random properties for agents
        for i, agent in enumerate(world.agents):
            agent.color = np.array([0.25, 0.25, 0.25])
        # random properties for landmarks
        world.landmarks[0].color = np.array([0.65, 0.15, 0.15])
        world.landmarks[1].color = np.array([0.15, 0.65, 0.15])
        world.landmarks[2].color = np.array([0.15, 0.15, 0.65])
        # a brighter version of the goal colour marks the listener
        world.agents[0].goal_a.color = world.agents[0].goal_b.color + np.array([0.45, 0.45, 0.45])
        # set random initial states
        for agent in world.agents:
            agent.state.p_pos = np.random.uniform(-1, +1, world.dim_p)
            agent.state.p_vel = np.zeros(world.dim_p)
            agent.state.c = np.zeros(world.dim_c)
        for i, landmark in enumerate(world.landmarks):
            landmark.state.p_pos = np.random.uniform(-1, +1, world.dim_p)
            landmark.state.p_vel = np.zeros(world.dim_p)

    def benchmark_data(self, agent, world):
        """Return the shared reward for benchmarking purposes.

        BUGFIX: previously called ``self.reward(agent, reward)`` where
        ``reward`` was an undefined name (NameError); the world must be passed.
        """
        return self.reward(agent, world)

    def reward(self, agent, world):
        """Shared reward: negative squared distance from listener to goal."""
        a = world.agents[0]
        dist2 = np.sum(np.square(a.goal_a.state.p_pos - a.goal_b.state.p_pos))
        return -dist2

    def observation(self, agent, world):
        """Speaker observes the goal colour; listener observes its velocity,
        landmark offsets and the speaker's communication."""
        # goal color
        # NOTE(review): assumes World defines dim_color — confirm in multiagent.core
        goal_color = np.zeros(world.dim_color)
        if agent.goal_b is not None:
            goal_color = agent.goal_b.color
        # get positions of all entities in this agent's reference frame
        entity_pos = []
        for entity in world.landmarks:
            entity_pos.append(entity.state.p_pos - agent.state.p_pos)
        # communication of all other agents
        comm = []
        for other in world.agents:
            if other is agent or (other.state.c is None):
                continue
            comm.append(other.state.c)
        # speaker
        if not agent.movable:
            return np.concatenate([goal_color])
        # listener
        if agent.silent:
            return np.concatenate([agent.state.p_vel] + entity_pos + comm)
================================================
FILE: examples/Social_Cognition/MAToM-SNN/MPE/multiagent/scenarios/simple_spread.py
================================================
import numpy as np
from multiagent.core import World, Agent, Landmark
from multiagent.scenario import BaseScenario
class Scenario(BaseScenario):
    """Cooperative spread: three agents must cover three landmarks while
    avoiding collisions with each other."""

    def make_world(self):
        """Three colliding, silent agents and three static landmarks."""
        world = World()
        world.dim_c = 2
        n_agents = 3
        n_landmarks = 3
        world.collaborative = True
        world.agents = [Agent() for _ in range(n_agents)]
        for idx, agent in enumerate(world.agents):
            agent.name = 'agent %d' % idx
            agent.collide = True
            agent.silent = True
            agent.size = 0.15
        world.landmarks = [Landmark() for _ in range(n_landmarks)]
        for idx, landmark in enumerate(world.landmarks):
            landmark.name = 'landmark %d' % idx
            landmark.collide = False
            landmark.movable = False
        self.reset_world(world)
        return world

    def reset_world(self, world):
        """Recolour entities and draw fresh random positions/velocities."""
        for agent in world.agents:
            agent.color = np.array([0.35, 0.35, 0.85])
        for landmark in world.landmarks:
            landmark.color = np.array([0.25, 0.25, 0.25])
        for agent in world.agents:
            agent.state.p_pos = np.random.uniform(-1, +1, world.dim_p)
            agent.state.p_vel = np.zeros(world.dim_p)
            agent.state.c = np.zeros(world.dim_c)
        for landmark in world.landmarks:
            landmark.state.p_pos = np.random.uniform(-1, +1, world.dim_p)
            landmark.state.p_vel = np.zeros(world.dim_p)

    def benchmark_data(self, agent, world):
        """Return (reward, collisions, summed min distances, occupied landmarks)."""
        rew = 0
        collisions = 0
        occupied_landmarks = 0
        min_dists = 0
        for landmark in world.landmarks:
            dists = [np.sqrt(np.sum(np.square(a.state.p_pos - landmark.state.p_pos)))
                     for a in world.agents]
            closest = min(dists)
            min_dists += closest
            rew -= closest
            if closest < 0.1:
                occupied_landmarks += 1
        if agent.collide:
            # NOTE: the loop includes the agent itself (original behaviour).
            for other in world.agents:
                if self.is_collision(other, agent):
                    rew -= 1
                    collisions += 1
        return (rew, collisions, min_dists, occupied_landmarks)

    def is_collision(self, agent1, agent2):
        """True when the two agents' bodies overlap."""
        delta_pos = agent1.state.p_pos - agent2.state.p_pos
        separation = np.sqrt(np.sum(np.square(delta_pos)))
        return bool(separation < agent1.size + agent2.size)

    def reward(self, agent, world):
        """Negative sum over landmarks of the closest agent's distance, minus a
        collision penalty."""
        rew = 0
        for landmark in world.landmarks:
            rew -= min(np.sqrt(np.sum(np.square(a.state.p_pos - landmark.state.p_pos)))
                       for a in world.agents)
        if agent.collide:
            # NOTE: includes self-collision, as in the original implementation.
            for other in world.agents:
                if self.is_collision(other, agent):
                    rew -= 1
        return rew

    def observation(self, agent, world):
        """Velocity, own position, landmark offsets, other agents' offsets and
        their communication."""
        entity_pos = [e.state.p_pos - agent.state.p_pos for e in world.landmarks]
        comm, other_pos = [], []
        for other in world.agents:
            if other is agent:
                continue
            comm.append(other.state.c)
            other_pos.append(other.state.p_pos - agent.state.p_pos)
        return np.concatenate([agent.state.p_vel] + [agent.state.p_pos]
                              + entity_pos + other_pos + comm)
================================================
FILE: examples/Social_Cognition/MAToM-SNN/MPE/multiagent/scenarios/simple_world_comm.py
================================================
import numpy as np
from multiagent.core import World, Agent, Landmark
from multiagent.scenario import BaseScenario
class Scenario(BaseScenario):
    """Predator-prey world with communication ("simple_world_comm").

    Four adversaries (one of them a non-silent leader) chase two faster good
    agents.  The world also contains one obstacle landmark, two food items
    that reward good agents on contact, and two forests that hide whoever
    stands inside them from outside observers.
    """

    def make_world(self):
        """Create agents, landmark, food and forests, then reset the world."""
        world = World()
        # set any world properties first
        world.dim_c = 4
        #world.damping = 1
        num_good_agents = 2
        num_adversaries = 4
        num_agents = num_adversaries + num_good_agents
        num_landmarks = 1
        num_food = 2
        num_forests = 2
        # add agents: index 0 is the (speaking) adversary leader
        world.agents = [Agent() for i in range(num_agents)]
        for i, agent in enumerate(world.agents):
            agent.name = 'agent %d' % i
            agent.collide = True
            agent.leader = True if i == 0 else False
            agent.silent = True if i > 0 else False
            agent.adversary = True if i < num_adversaries else False
            # adversaries are bigger but slower than good agents
            agent.size = 0.075 if agent.adversary else 0.045
            agent.accel = 3.0 if agent.adversary else 4.0
            #agent.accel = 20.0 if agent.adversary else 25.0
            agent.max_speed = 1.0 if agent.adversary else 1.3
        # add landmarks (a single large obstacle)
        world.landmarks = [Landmark() for i in range(num_landmarks)]
        for i, landmark in enumerate(world.landmarks):
            landmark.name = 'landmark %d' % i
            landmark.collide = True
            landmark.movable = False
            landmark.size = 0.2
            landmark.boundary = False
        # small food items (good agents are rewarded for touching them)
        world.food = [Landmark() for i in range(num_food)]
        for i, landmark in enumerate(world.food):
            landmark.name = 'food %d' % i
            landmark.collide = False
            landmark.movable = False
            landmark.size = 0.03
            landmark.boundary = False
        # large forests agents can hide in
        world.forests = [Landmark() for i in range(num_forests)]
        for i, landmark in enumerate(world.forests):
            landmark.name = 'forest %d' % i
            landmark.collide = False
            landmark.movable = False
            landmark.size = 0.3
            landmark.boundary = False
        world.landmarks += world.food
        world.landmarks += world.forests
        #world.landmarks += self.set_boundaries(world)  # world boundaries now penalized with negative reward
        # make initial conditions
        self.reset_world(world)
        return world

    def set_boundaries(self, world):
        """Build a ring of boundary landmarks around the arena (currently unused)."""
        boundary_list = []
        landmark_size = 1
        edge = 1 + landmark_size
        num_landmarks = int(edge * 2 / landmark_size)
        for x_pos in [-edge, edge]:
            for i in range(num_landmarks):
                l = Landmark()
                l.state.p_pos = np.array([x_pos, -1 + i * landmark_size])
                boundary_list.append(l)
        for y_pos in [-edge, edge]:
            for i in range(num_landmarks):
                l = Landmark()
                l.state.p_pos = np.array([-1 + i * landmark_size, y_pos])
                boundary_list.append(l)
        for i, l in enumerate(boundary_list):
            l.name = 'boundary %d' % i
            l.collide = True
            l.movable = False
            l.boundary = True
            l.color = np.array([0.75, 0.75, 0.75])
            l.size = landmark_size
            l.state.p_vel = np.zeros(world.dim_p)
        return boundary_list

    def reset_world(self, world):
        """Recolour every entity and draw fresh random positions/velocities."""
        # random properties for agents: green = good, red = adversary, darker = leader
        for i, agent in enumerate(world.agents):
            agent.color = np.array([0.45, 0.95, 0.45]) if not agent.adversary else np.array([0.95, 0.45, 0.45])
            agent.color -= np.array([0.3, 0.3, 0.3]) if agent.leader else np.array([0, 0, 0])
        # random properties for landmarks
        for i, landmark in enumerate(world.landmarks):
            landmark.color = np.array([0.25, 0.25, 0.25])
        for i, landmark in enumerate(world.food):
            landmark.color = np.array([0.15, 0.15, 0.65])
        for i, landmark in enumerate(world.forests):
            landmark.color = np.array([0.6, 0.9, 0.6])
        # set random initial states
        for agent in world.agents:
            agent.state.p_pos = np.random.uniform(-1, +1, world.dim_p)
            agent.state.p_vel = np.zeros(world.dim_p)
            agent.state.c = np.zeros(world.dim_c)
        # landmarks/food/forests are placed slightly inside the arena
        for i, landmark in enumerate(world.landmarks):
            landmark.state.p_pos = np.random.uniform(-0.9, +0.9, world.dim_p)
            landmark.state.p_vel = np.zeros(world.dim_p)
        for i, landmark in enumerate(world.food):
            landmark.state.p_pos = np.random.uniform(-0.9, +0.9, world.dim_p)
            landmark.state.p_vel = np.zeros(world.dim_p)
        for i, landmark in enumerate(world.forests):
            landmark.state.p_pos = np.random.uniform(-0.9, +0.9, world.dim_p)
            landmark.state.p_vel = np.zeros(world.dim_p)

    def benchmark_data(self, agent, world):
        """For adversaries, count collisions with good agents; otherwise 0."""
        if agent.adversary:
            collisions = 0
            for a in self.good_agents(world):
                if self.is_collision(a, agent):
                    collisions += 1
            return collisions
        else:
            return 0

    def is_collision(self, agent1, agent2):
        """True when the two entities' circles overlap."""
        delta_pos = agent1.state.p_pos - agent2.state.p_pos
        dist = np.sqrt(np.sum(np.square(delta_pos)))
        dist_min = agent1.size + agent2.size
        return True if dist < dist_min else False

    # return all agents that are not adversaries
    def good_agents(self, world):
        return [agent for agent in world.agents if not agent.adversary]

    # return all adversarial agents
    def adversaries(self, world):
        return [agent for agent in world.agents if agent.adversary]

    def reward(self, agent, world):
        """Dispatch to the role-specific reward."""
        # Agents are rewarded based on minimum agent distance to each landmark
        #boundary_reward = -10 if self.outside_boundary(agent) else 0
        main_reward = self.adversary_reward(agent, world) if agent.adversary else self.agent_reward(agent, world)
        return main_reward

    def outside_boundary(self, agent):
        """True when the agent has left the [-1, 1] x [-1, 1] arena."""
        if agent.state.p_pos[0] > 1 or agent.state.p_pos[0] < -1 or agent.state.p_pos[1] > 1 or agent.state.p_pos[1] < -1:
            return True
        else:
            return False

    def agent_reward(self, agent, world):
        """Good-agent reward: avoid capture, stay in bounds, touch food."""
        rew = 0
        shape = False  # set True to enable distance-based reward shaping
        adversaries = self.adversaries(world)
        if shape:
            for adv in adversaries:
                rew += 0.1 * np.sqrt(np.sum(np.square(agent.state.p_pos - adv.state.p_pos)))
        if agent.collide:
            # heavy penalty for being caught by any adversary
            for a in adversaries:
                if self.is_collision(a, agent):
                    rew -= 5

        def bound(x):
            # soft boundary penalty: zero below 0.9, linear to 1.0, then capped exp
            if x < 0.9:
                return 0
            if x < 1.0:
                return (x - 0.9) * 10
            return min(np.exp(2 * x - 2), 10)  # 1 + (x - 1) * (x - 1)
        for p in range(world.dim_p):
            x = abs(agent.state.p_pos[p])
            rew -= 2 * bound(x)
        # bonus for touching food
        for food in world.food:
            if self.is_collision(agent, food):
                rew += 2
        # NOTE(review): this term *adds* reward proportional to the distance to
        # the nearest food, which looks inverted — it matches upstream MPE, but confirm.
        rew += 0.05 * min([np.sqrt(np.sum(np.square(food.state.p_pos - agent.state.p_pos))) for food in world.food])
        return rew

    def adversary_reward(self, agent, world):
        """Adversary reward: approach good agents; team bonus for captures."""
        rew = 0
        shape = True  # shaped: penalise distance to the closest good agent
        agents = self.good_agents(world)
        adversaries = self.adversaries(world)
        if shape:
            rew -= 0.1 * min([np.sqrt(np.sum(np.square(a.state.p_pos - agent.state.p_pos))) for a in agents])
        if agent.collide:
            # every (good agent, adversary) collision pays +5 to this adversary
            for ag in agents:
                for adv in adversaries:
                    if self.is_collision(ag, adv):
                        rew += 5
        return rew

    def observation2(self, agent, world):
        """Observation variant without forest-based masking (not used by default)."""
        # get positions of all entities in this agent's reference frame
        entity_pos = []
        for entity in world.landmarks:
            if not entity.boundary:
                entity_pos.append(entity.state.p_pos - agent.state.p_pos)
        # food positions (computed but not included in the returned vector)
        food_pos = []
        for entity in world.food:
            if not entity.boundary:
                food_pos.append(entity.state.p_pos - agent.state.p_pos)
        # communication of all other agents
        comm = []
        other_pos = []
        other_vel = []
        for other in world.agents:
            if other is agent: continue
            comm.append(other.state.c)
            other_pos.append(other.state.p_pos - agent.state.p_pos)
            if not other.adversary:
                other_vel.append(other.state.p_vel)
        return np.concatenate([agent.state.p_vel] + [agent.state.p_pos] + entity_pos + other_pos + other_vel)

    def observation(self, agent, world):
        """Egocentric observation with forest-based visibility masking.

        Agents in a forest are invisible (zeroed out) to agents outside it,
        except to the adversary leader, who sees everything.
        """
        # get positions of all entities in this agent's reference frame
        entity_pos = []
        for entity in world.landmarks:
            if not entity.boundary:
                entity_pos.append(entity.state.p_pos - agent.state.p_pos)
        # flags indicating whether this agent stands in each forest
        in_forest = [np.array([-1]), np.array([-1])]
        inf1 = False
        inf2 = False
        if self.is_collision(agent, world.forests[0]):
            in_forest[0] = np.array([1])
            inf1 = True
        if self.is_collision(agent, world.forests[1]):
            in_forest[1] = np.array([1])
            inf2 = True
        # food positions (computed but not included in the returned vector)
        food_pos = []
        for entity in world.food:
            if not entity.boundary:
                food_pos.append(entity.state.p_pos - agent.state.p_pos)
        # communication of all other agents
        comm = []
        other_pos = []
        other_vel = []
        for other in world.agents:
            if other is agent: continue
            comm.append(other.state.c)
            oth_f1 = self.is_collision(other, world.forests[0])
            oth_f2 = self.is_collision(other, world.forests[1])
            # visible if both share a forest, both are outside all forests,
            # or the observer is the leader
            if (inf1 and oth_f1) or (inf2 and oth_f2) or (not inf1 and not oth_f1 and not inf2 and not oth_f2) or agent.leader:  #without forest vis
                other_pos.append(other.state.p_pos - agent.state.p_pos)
                if not other.adversary:
                    other_vel.append(other.state.p_vel)
            else:
                # hidden by a forest: report zeros instead
                other_pos.append([0, 0])
                if not other.adversary:
                    other_vel.append([0, 0])
        # to tell the pred when the prey are in the forest
        prey_forest = []
        ga = self.good_agents(world)
        for a in ga:
            if any([self.is_collision(a, f) for f in world.forests]):
                prey_forest.append(np.array([1]))
            else:
                prey_forest.append(np.array([-1]))
        # to tell leader when pred are in forest
        prey_forest_lead = []
        for f in world.forests:
            if any([self.is_collision(a, f) for a in ga]):
                prey_forest_lead.append(np.array([1]))
            else:
                prey_forest_lead.append(np.array([-1]))
        # only the leader's channel is broadcast to everyone
        comm = [world.agents[0].state.c]
        if agent.adversary and not agent.leader:
            return np.concatenate([agent.state.p_vel] + [agent.state.p_pos] + entity_pos + other_pos + other_vel + in_forest + comm)
        if agent.leader:
            return np.concatenate(
                [agent.state.p_vel] + [agent.state.p_pos] + entity_pos + other_pos + other_vel + in_forest + comm)
        else:
            return np.concatenate([agent.state.p_vel] + [agent.state.p_pos] + entity_pos + other_pos + in_forest + other_vel)
================================================
FILE: examples/Social_Cognition/MAToM-SNN/MPE/policy/__init__.py
================================================
================================================
FILE: examples/Social_Cognition/MAToM-SNN/MPE/policy/maddpg.py
================================================
import torch
from torch.optim import Adam
import torch.nn.functional as F
from gym.spaces import Box, Discrete, MultiDiscrete
from multiagent.multi_discrete import MultiDiscrete
from utils.networks import MLPNetwork, SNNNetwork
from utils.misc import soft_update, average_gradients, onehot_from_logits, gumbel_softmax
from agents.agents import DDPGAgent, DDPGAgent_RNN, DDPGAgent_SNN, DDPGAgent_ToM
# from commom.distributions import make_pdtype
import time
MSELoss = torch.nn.MSELoss()
class MADDPG(object):
"""
Wrapper class for DDPG-esque (i.e. also MADDPG) agents in multi-agent task
"""
    def __init__(self, agent_init_params, alg_types, device,
                 gamma=0.95, tau=0.01, lr=0.01, hidden_dim=64,
                 discrete_action=False):
        """
        Inputs:
            agent_init_params (list of dict): List of dicts with parameters to
                                              initialize each agent
                num_in_pol (int): Input dimensions to policy
                num_out_pol (int): Output dimensions to policy
                num_in_critic (int): Input dimensions to critic
            alg_types (list of str): Learning algorithm for each agent (DDPG
                                       or MADDPG)
            device: torch device the networks should eventually run on
            gamma (float): Discount factor
            tau (float): Target update rate
            lr (float): Learning rate for policy and critic
            hidden_dim (int): Number of hidden dimensions for networks
            discrete_action (bool): Whether or not to use discrete action space
        """
        self.device = device
        # one learner per entry in alg_types
        self.nagents = len(alg_types)
        self.alg_types = alg_types
        # NOTE(review): always builds plain DDPGAgent instances even though
        # RNN/SNN/ToM variants are imported — confirm this is intentional.
        self.agents = [DDPGAgent(lr=lr, discrete_action=discrete_action,
                                 hidden_dim=hidden_dim,
                                 **params)
                       for params in agent_init_params]
        self.agent_init_params = agent_init_params
        self.gamma = gamma
        self.tau = tau
        self.lr = lr
        self.discrete_action = discrete_action
        # bookkeeping of which device each network group currently lives on
        self.pol_dev = 'cpu'  # device for policies
        self.critic_dev = 'cpu'  # device for critics
        self.trgt_pol_dev = 'cpu'  # device for target policies
        self.trgt_critic_dev = 'cpu'  # device for target critics
        self.niter = 0
@property
def policies(self):
return [a.policy for a in self.agents]
@property
def target_policies(self):
return [a.target_policy for a in self.agents]
def scale_noise(self, scale):
"""
Scale noise for each agent
Inputs:
scale (float): scale of noise
"""
for a in self.agents:
a.scale_noise(scale)
def reset_noise(self):
for a in self.agents:
a.reset_noise()
def step(self, observations, explore=False):
"""
Take a step forward in environment with all agents
Inputs:
observations: List of observations for each agent
explore (boolean): Whether or not to add exploration noise
Outputs:
actions: List of actions for each agent
"""
return [a.step(obs, explore=explore) for a, obs in zip(self.agents,
observations)]
def update(self, sample, agent_i, parallel=False, logger=None):
"""
Update parameters of agent model based on sample from replay buffer
Inputs:
sample: tuple of (observations, actions, rewards, next
observations, and episode end masks) sampled randomly from
the replay buffer. Each is a list with entries
corresponding to each agent
agent_i (int): index of agent to update
parallel (bool): If true, will average gradients across threads
logger (SummaryWriter from Tensorboard-Pytorch):
If passed in, important quantities will be logged
"""
obs, acs, rews, next_obs, dones = sample
curr_agent = self.agents[agent_i]
curr_agent.critic_optimizer.zero_grad()
if self.alg_types[agent_i] == 'MADDPG':
if self.discrete_action: # one-hot encode action
all_trgt_acs = [onehot_from_logits(pi(nobs)) for pi, nobs in
zip(self.target_policies, next_obs)]
else:
all_trgt_acs = [pi(nobs) for pi, nobs in zip(self.target_policies,
next_obs)]
trgt_vf_in = torch.cat((*next_obs, *all_trgt_acs), dim=1)
else: # DDPG
if self.discrete_action:
trgt_vf_in = torch.cat((next_obs[agent_i],
onehot_from_logits(
curr_agent.target_policy(
next_obs[agent_i]))),
dim=1)
else:
trgt_vf_in = torch.cat((next_obs[agent_i],
curr_agent.target_policy(next_obs[agent_i])),
dim=1)
target_value = (rews[agent_i].view(-1, 1) + self.gamma *
curr_agent.target_critic(trgt_vf_in) *
(1 - dones[agent_i].view(-1, 1)))
if self.alg_types[agent_i] == 'MADDPG':
vf_in = torch.cat((*obs, *acs), dim=1)
else: # DDPG
vf_in = torch.cat((obs[agent_i], acs[agent_i]), dim=1)
actual_value = curr_agent.critic(vf_in)
vf_loss = MSELoss(actual_value, target_value.detach())
vf_loss.backward()
if parallel:
average_gradients(curr_agent.critic)
torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(), 0.5)
curr_agent.critic_optimizer.step()
curr_agent.policy_optimizer.zero_grad()
if self.discrete_action:
# Forward pass as if onehot (hard=True) but backprop through a differentiable
# Gumbel-Softmax sample. The MADDPG paper uses the Gumbel-Softmax trick to backprop
# through discrete categorical samples, but I'm not sure if that is
# correct since it removes the assumption of a deterministic policy for
# DDPG. Regardless, discrete policies don't seem to learn properly without it.
curr_pol_out = curr_agent.policy(obs[agent_i])
curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)
else:
curr_pol_out = curr_agent.policy(obs[agent_i])
curr_pol_vf_in = curr_pol_out
if self.alg_types[agent_i] == 'MADDPG':
all_pol_acs = []
for i, pi, ob in zip(range(self.nagents), self.policies, obs):
if i == agent_i:
all_pol_acs.append(curr_pol_vf_in)
elif self.discrete_action:
all_pol_acs.append(onehot_from_logits(pi(ob)))
else:
all_pol_acs.append(pi(ob))
vf_in = torch.cat((*obs, *all_pol_acs), dim=1)
else: # DDPG
vf_in = torch.cat((obs[agent_i], curr_pol_vf_in),
dim=1)
pol_loss = -curr_agent.critic(vf_in).mean()
pol_loss += (curr_pol_out**2).mean() * 1e-3
pol_loss.backward()
if parallel:
average_gradients(curr_agent.policy)
torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(), 0.5)
curr_agent.policy_optimizer.step()
if logger is not None:
logger.add_scalars('agent%i/losses' % agent_i,
{'vf_loss': vf_loss,
'pol_loss': pol_loss},
self.niter)
def update_all_targets(self):
"""
Update all target networks (called after normal updates have been
performed for each agent)
"""
for a in self.agents:
soft_update(a.target_critic, a.critic, self.tau)
soft_update(a.target_policy, a.policy, self.tau)
self.niter += 1
def prep_training(self, device='gpu'):
for a in self.agents:
a.policy.train()
a.critic.train()
a.target_policy.train()
a.target_critic.train()
if device == 'gpu':
fn = lambda x: x.to(torch.device(self.device))
else:
fn = lambda x: x.cpu()
if not self.pol_dev == device:
for a in self.agents:
a.policy = fn(a.policy)
self.pol_dev = device
if not self.critic_dev == device:
for a in self.agents:
a.critic = fn(a.critic)
self.critic_dev = device
if not self.trgt_pol_dev == device:
for a in self.agents:
a.target_policy = fn(a.target_policy)
self.trgt_pol_dev = device
if not self.trgt_critic_dev == device:
for a in self.agents:
a.target_critic = fn(a.target_critic)
self.trgt_critic_dev = device
def prep_rollouts(self, device='cpu'):
for a in self.agents:
a.policy.eval()
if device == 'gpu':
fn = lambda x: x.to(torch.device(self.device))
else:
fn = lambda x: x.cpu()
# only need main policy for rollouts
if not self.pol_dev == device:
for a in self.agents:
a.policy = fn(a.policy)
self.pol_dev = device
def save(self, filename):
"""
Save trained parameters of all agents into one file
"""
self.prep_training(device='cpu') # move parameters to CPU before saving
save_dict = {'init_dict': self.init_dict,
'agent_params': [a.get_params() for a in self.agents]}
torch.save(save_dict, filename)
@classmethod
def init_from_env(cls, env, device, agent_alg="MADDPG", adversary_alg="MADDPG",
gamma=0.95, tau=0.01, lr=0.01, hidden_dim=64):
"""
Instantiate instance of this class from multi-agent environment
"""
agent_init_params = []
alg_types = [adversary_alg if atype == 'adversary' else agent_alg for
atype in env.agent_types]
for acsp, obsp, algtype in zip(env.action_space, env.observation_space,
alg_types):
num_in_pol = obsp.shape[0]
if isinstance(acsp, Box):
discrete_action = False
get_shape = lambda x: x.shape[0]
elif isinstance(acsp, Discrete): # Discrete
discrete_action = True
get_shape = lambda x: x.n
elif isinstance(acsp, MultiDiscrete):
discrete_action = True
get_shape = lambda x: sum(x.high - x.low + 1)
num_out_pol = get_shape(acsp)
if algtype == "MADDPG":
num_in_critic = 0
for oobsp in env.observation_space:
num_in_critic += oobsp.shape[0]
for oacsp in env.action_space:
if isinstance(oacsp, Box):
discrete_action = False
get_shape = lambda x: x.shape[0]
elif isinstance(oacsp, Discrete): # Discrete
discrete_action = True
get_shape = lambda x: x.n
elif isinstance(oacsp, MultiDiscrete):
discrete_action = True
get_shape = lambda x: sum(x.high - x.low + 1)
num_in_critic += get_shape(oacsp)
else:
num_in_critic = obsp.shape[0] + get_shape(acsp)
agent_init_params.append({'num_in_pol': num_in_pol,
'num_out_pol': num_out_pol,
'num_in_critic': num_in_critic})
init_dict = {'gamma': gamma, 'tau': tau, 'lr': lr,
'hidden_dim': hidden_dim,
'alg_types': alg_types,
'agent_init_params': agent_init_params,
'discrete_action': discrete_action,
'device': device}
instance = cls(**init_dict)
instance.init_dict = init_dict
return instance
@classmethod
def init_from_save(cls, filename):
"""
Instantiate instance of this class from file created by 'save' method
"""
save_dict = torch.load(filename)
instance = cls(**save_dict['init_dict'])
instance.init_dict = save_dict['init_dict']
for a, params in zip(instance.agents, save_dict['agent_params']):
a.load_params(params)
return instance
class MADDPG_SNN(object):
    """
    Wrapper class for DDPG-esque (i.e. also MADDPG) agents in multi-agent task

    Variant of MADDPG whose agents use spiking policy/critic networks
    (DDPGAgent_SNN); `output_style` is forwarded to the SNN readout.
    """
    def __init__(self, agent_init_params, alg_types, output_style, device,
                 gamma=0.95, tau=0.01, lr=0.01, hidden_dim=64,
                 discrete_action=False):
        """
        Inputs:
            agent_init_params (list of dict): List of dicts with parameters to
                                              initialize each agent
                num_in_pol (int): Input dimensions to policy
                num_out_pol (int): Output dimensions to policy
                num_in_critic (int): Input dimensions to critic
            alg_types (list of str): Learning algorithm for each agent (DDPG
                                       or MADDPG)
            output_style: readout style passed to each spiking agent network
            device: torch device (or device string) used when networks are
                moved to GPU in prep_training / prep_rollouts
            gamma (float): Discount factor
            tau (float): Target update rate
            lr (float): Learning rate for policy and critic
            hidden_dim (int): Number of hidden dimensions for networks
            discrete_action (bool): Whether or not to use discrete action space
        """
        self.device = device
        self.nagents = len(alg_types)
        self.alg_types = alg_types
        # One spiking agent per alg_type.
        self.agents = [DDPGAgent_SNN(lr=lr, discrete_action=discrete_action,
                                     hidden_dim=hidden_dim,
                                     **params, output_style=output_style)
                       for params in agent_init_params]
        self.agent_init_params = agent_init_params
        self.gamma = gamma
        self.tau = tau
        self.lr = lr
        self.discrete_action = discrete_action
        # Device bookkeeping, same scheme as MADDPG: networks start on CPU
        # and are moved lazily by prep_training / prep_rollouts.
        self.pol_dev = 'cpu'  # device for policies
        self.critic_dev = 'cpu'  # device for critics
        self.trgt_pol_dev = 'cpu'  # device for target policies
        self.trgt_critic_dev = 'cpu'  # device for target critics
        self.niter = 0  # counts update_all_targets() calls; used as the logging step

    @property
    def policies(self):
        # Main (trained) policy of every agent, in agent order.
        return [a.policy for a in self.agents]

    @property
    def target_policies(self):
        # Target (slow-moving) policy of every agent, in agent order.
        return [a.target_policy for a in self.agents]

    def scale_noise(self, scale):
        """
        Scale noise for each agent
        Inputs:
            scale (float): scale of noise
        """
        for a in self.agents:
            a.scale_noise(scale)

    def reset_noise(self):
        # Reset each agent's exploration-noise process.
        for a in self.agents:
            a.reset_noise()

    def step(self, observations, explore=False):
        """
        Take a step forward in environment with all agents
        Inputs:
            observations: List of observations for each agent
            explore (boolean): Whether or not to add exploration noise
        Outputs:
            actions: List of actions for each agent
        """
        return [a.step(obs, explore=explore) for a, obs in zip(self.agents,
                                                               observations)]

    def update(self, sample, agent_i, parallel=False, logger=None):
        """
        Update parameters of agent model based on sample from replay buffer
        Inputs:
            sample: tuple of (observations, actions, rewards, next
                    observations, and episode end masks) sampled randomly from
                    the replay buffer. Each is a list with entries
                    corresponding to each agent
            agent_i (int): index of agent to update
            parallel (bool): If true, will average gradients across threads
            logger (SummaryWriter from Tensorboard-Pytorch):
                If passed in, important quantities will be logged
        """
        obs, acs, rews, next_obs, dones = sample
        curr_agent = self.agents[agent_i]
        # ---------- critic update ----------
        curr_agent.critic_optimizer.zero_grad()
        if self.alg_types[agent_i] == 'MADDPG_SNN':
            all_trgt_acs = []
            if self.discrete_action: # one-hot encode action
                all_trgt_acs = [onehot_from_logits(pi(nobs)) for pi, nobs in
                                zip(self.target_policies, next_obs)]
            else:
                # "Self-experience": unlike vanilla MADDPG, the continuous
                # target actions for ALL agents are produced by this agent's
                # own target policy applied to each agent's next observation.
                all_trgt_acs = [self.target_policies[agent_i](nobs) for nobs in next_obs] #self-experience
            trgt_vf_in = torch.cat((*next_obs, *all_trgt_acs), dim=1)
        else:  # DDPG
            if self.discrete_action:
                trgt_vf_in = torch.cat((next_obs[agent_i],
                                        onehot_from_logits(
                                            curr_agent.target_policy(
                                                next_obs[agent_i]))),
                                       dim=1)
            else:
                trgt_vf_in = torch.cat((next_obs[agent_i],
                                        curr_agent.target_policy(next_obs[agent_i])),
                                       dim=1)
        # One-step TD target; (1 - done) zeroes the bootstrap at episode end.
        target_value = (rews[agent_i].view(-1, 1) + self.gamma *
                        curr_agent.target_critic(trgt_vf_in) *
                        (1 - dones[agent_i].view(-1, 1)))
        if self.alg_types[agent_i] == 'MADDPG_SNN':
            vf_in = torch.cat((*obs, *acs), dim=1)
        else:  # DDPG
            vf_in = torch.cat((obs[agent_i], acs[agent_i]), dim=1)
        actual_value = curr_agent.critic(vf_in)
        # detach(): no gradient flows into the target networks.
        vf_loss = MSELoss(actual_value, target_value.detach())
        vf_loss.backward()
        if parallel:
            average_gradients(curr_agent.critic)
        torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(), 0.5)
        curr_agent.critic_optimizer.step()
        # ---------- policy update ----------
        curr_agent.policy_optimizer.zero_grad()
        if self.discrete_action:
            # Forward pass as if onehot (hard=True) but backprop through a differentiable
            # Gumbel-Softmax sample. The MADDPG paper uses the Gumbel-Softmax trick to backprop
            # through discrete categorical samples, but I'm not sure if that is
            # correct since it removes the assumption of a deterministic policy for
            # DDPG. Regardless, discrete policies don't seem to learn properly without it.
            curr_pol_out = curr_agent.policy(obs[agent_i])
            curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)
        else:
            curr_pol_out = curr_agent.policy(obs[agent_i])
            curr_pol_vf_in = curr_pol_out
        if self.alg_types[agent_i] == 'MADDPG_SNN':
            all_pol_acs = []
            for i, pi, ob in zip(range(self.nagents), self.policies, obs):
                if i == agent_i:
                    # Differentiable action for the agent being updated ...
                    all_pol_acs.append(curr_pol_vf_in)
                elif self.discrete_action:
                    # ... non-differentiable actions for all other agents.
                    all_pol_acs.append(onehot_from_logits(pi(ob)))
                else:
                    all_pol_acs.append(pi(ob))
            vf_in = torch.cat((*obs, *all_pol_acs), dim=1)
        else:  # DDPG
            vf_in = torch.cat((obs[agent_i], curr_pol_vf_in),
                              dim=1)
        # Deterministic policy gradient plus small action-magnitude penalty.
        pol_loss = -curr_agent.critic(vf_in).mean()
        pol_loss += (curr_pol_out**2).mean() * 1e-3
        pol_loss.backward()
        if parallel:
            average_gradients(curr_agent.policy)
        torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(), 0.5)
        curr_agent.policy_optimizer.step()
        if logger is not None:
            logger.add_scalars('agent%i/losses' % agent_i,
                               {'vf_loss': vf_loss,
                                'pol_loss': pol_loss},
                               self.niter)

    def update_all_targets(self):
        """
        Update all target networks (called after normal updates have been
        performed for each agent)
        """
        for a in self.agents:
            # Polyak averaging: target <- tau * main + (1 - tau) * target.
            soft_update(a.target_critic, a.critic, self.tau)
            soft_update(a.target_policy, a.policy, self.tau)
        self.niter += 1

    def prep_training(self, device='gpu'):
        # NOTE: `device` here is only the flag 'gpu' / 'cpu'; the actual
        # CUDA device used for the move is self.device (set in __init__).
        for a in self.agents:
            a.policy.train()
            a.critic.train()
            a.target_policy.train()
            a.target_critic.train()
        if device == 'gpu':
            fn = lambda x: x.to(torch.device(self.device))
        else:
            fn = lambda x: x.cpu()
        # Move each network group only if it is not already on `device`.
        if not self.pol_dev == device:
            for a in self.agents:
                a.policy = fn(a.policy)
            self.pol_dev = device
        if not self.critic_dev == device:
            for a in self.agents:
                a.critic = fn(a.critic)
            self.critic_dev = device
        if not self.trgt_pol_dev == device:
            for a in self.agents:
                a.target_policy = fn(a.target_policy)
            self.trgt_pol_dev = device
        if not self.trgt_critic_dev == device:
            for a in self.agents:
                a.target_critic = fn(a.target_critic)
            self.trgt_critic_dev = device

    def prep_rollouts(self, device='cpu'):
        # Put policies in eval mode and move them to the rollout device.
        for a in self.agents:
            a.policy.eval()
        if device == 'gpu':
            fn = lambda x: x.to(torch.device(self.device))
        else:
            fn = lambda x: x.cpu()
        # only need main policy for rollouts
        if not self.pol_dev == device:
            for a in self.agents:
                a.policy = fn(a.policy)
            self.pol_dev = device

    def save(self, filename):
        """
        Save trained parameters of all agents into one file
        """
        # NOTE(review): relies on self.init_dict, which is only assigned by
        # init_from_env / init_from_save -- confirm instances are always built
        # through those classmethods before save() is called.
        self.prep_training(device='cpu')  # move parameters to CPU before saving
        save_dict = {'init_dict': self.init_dict,
                     'agent_params': [a.get_params() for a in self.agents]}
        torch.save(save_dict, filename)

    @classmethod
    def init_from_env(cls, env, device, agent_alg="MADDPG_SNN", adversary_alg="MADDPG_SNN",
                      gamma=0.95, tau=0.01, lr=0.01, hidden_dim=64, output_style='sum'):
        """
        Instantiate instance of this class from multi-agent environment
        """
        agent_init_params = []
        alg_types = [adversary_alg if atype == 'adversary' else agent_alg for
                     atype in env.agent_types]
        for acsp, obsp, algtype in zip(env.action_space, env.observation_space,
                                       alg_types):
            num_in_pol = obsp.shape[0]
            # NOTE(review): if an action space is none of Box / Discrete /
            # MultiDiscrete, `discrete_action` and `get_shape` stay unbound
            # and the lines below raise -- only these three types are supported.
            if isinstance(acsp, Box):
                discrete_action = False
                get_shape = lambda x: x.shape[0]
            elif isinstance(acsp, Discrete):  # Discrete
                discrete_action = True
                get_shape = lambda x: x.n
            elif isinstance(acsp, MultiDiscrete):
                discrete_action = True
                get_shape = lambda x: sum(x.high - x.low + 1)
            num_out_pol = get_shape(acsp)
            if algtype == "MADDPG_SNN":
                # Centralized critic sees every observation and every action.
                num_in_critic = 0
                for oobsp in env.observation_space:
                    num_in_critic += oobsp.shape[0]
                for oacsp in env.action_space:
                    if isinstance(oacsp, Box):
                        discrete_action = False
                        get_shape = lambda x: x.shape[0]
                    elif isinstance(oacsp, Discrete):  # Discrete
                        discrete_action = True
                        get_shape = lambda x: x.n
                    elif isinstance(oacsp, MultiDiscrete):
                        discrete_action = True
                        get_shape = lambda x: sum(x.high - x.low + 1)
                    num_in_critic += get_shape(oacsp)
            else:
                # Decentralized (DDPG) critic: own observation + own action.
                num_in_critic = obsp.shape[0] + get_shape(acsp)
            agent_init_params.append({'num_in_pol': num_in_pol,
                                      'num_out_pol': num_out_pol,
                                      'num_in_critic': num_in_critic})
        # NOTE(review): `discrete_action` recorded here is whatever the last
        # loop iteration set, i.e. mixed action-space types are not supported.
        init_dict = {'gamma': gamma, 'tau': tau, 'lr': lr,
                     'hidden_dim': hidden_dim,
                     'alg_types': alg_types,
                     'agent_init_params': agent_init_params,
                     'discrete_action': discrete_action,
                     'output_style': output_style,
                     'device': device}
        instance = cls(**init_dict)
        instance.init_dict = init_dict
        return instance

    @classmethod
    def init_from_save(cls, filename):
        """
        Instantiate instance of this class from file created by 'save' method
        """
        # NOTE: torch.load unpickles arbitrary objects -- only load checkpoints
        # produced by save() from a trusted source.
        save_dict = torch.load(filename)
        instance = cls(**save_dict['init_dict'])
        instance.init_dict = save_dict['init_dict']
        for a, params in zip(instance.agents, save_dict['agent_params']):
            a.load_params(params)
        return instance
class MADDPG_ToM(object):
"""
Wrapper class for DDPG-esque (i.e. also MADDPG) agents in multi-agent task
"""
def __init__(self, agent_init_params, alg_types, output_style, device,
             gamma=0.95, tau=0.01, lr=0.01, hidden_dim=64,
             discrete_action=False):
    """
    Inputs:
        agent_init_params (list of dict): List of dicts with parameters to
                                          initialize each agent
            num_in_pol (int): Input dimensions to policy
            num_out_pol (int): Output dimensions to policy
            num_in_critic (int): Input dimensions to critic
        alg_types (list of str): Learning algorithm for each agent (DDPG
                                   or MADDPG)
        output_style: readout style passed to the SNN networks
        device: torch device (or device string) the ToM models run on
        gamma (float): Discount factor
        tau (float): Target update rate
        lr (float): Learning rate for policy and critic
        hidden_dim (int): Number of hidden dimensions for networks
        discrete_action (bool): Whether or not to use discrete action space
    """
    self.device = device
    self.nagents = len(alg_types)
    self.alg_types = alg_types
    # ToM agents: besides policy/critic each agent carries `mle` models that
    # predict other agents' actions (assigned from self.mle_base in step()).
    self.agents = [DDPGAgent_ToM(lr=lr, discrete_action=discrete_action,
                                 hidden_dim=hidden_dim,
                                 **params, output_style=output_style,
                                 num_agents=self.nagents,
                                 device=self.device)
                   for params in agent_init_params]
    self.agent_init_params = agent_init_params
    # The shared "mind-reading" (MLE / theory-of-mind) networks are wired per
    # scenario, keyed on the hard-coded agent count:
    #   6 -> simple_world_comm-style ("simple_com"), 4 -> simple_tag,
    #   3 -> simple_adv, 2 -> simple_push (per the original inline comments).
    # NOTE(review): the `num_in_mle` offsets (-14 / -2) mirror the observation
    # slices used in step()/_get_obs() -- confirm against the scenario's
    # observation layout before changing either place.
    if self.nagents == 6:
        self.mle_base = [SNNNetwork(self.agent_init_params[3]['num_in_mle'] - 14, #simple_com
                                    self.agent_init_params[3]['num_out_pol'], #adv self-self
                                    hidden_dim=hidden_dim, output_style=output_style),
                         SNNNetwork(self.agent_init_params[3]['num_in_mle'] - 14,
                                    self.agent_init_params[3]['num_out_pol'], #adv self-other
                                    hidden_dim=hidden_dim, output_style=output_style),
                         ]
    if self.nagents == 4:
        self.mle_base = [SNNNetwork(self.agent_init_params[0]['num_in_mle'] - 2, #simple_tag
                                    self.agent_init_params[0]['num_out_pol'], #adv self-self
                                    hidden_dim=hidden_dim, output_style=output_style),
                         SNNNetwork(self.agent_init_params[3]['num_in_mle'] - 2,
                                    self.agent_init_params[3]['num_out_pol'], #adv self-other
                                    hidden_dim=hidden_dim, output_style=output_style),
                         SNNNetwork(self.agent_init_params[3]['num_in_mle'] - 2,
                                    self.agent_init_params[3]['num_out_pol'],
                                    hidden_dim=hidden_dim, output_style=output_style), ##agent self-other
                         ]
    elif self.nagents == 3:
        self.mle_base = [SNNNetwork(self.agent_init_params[1]['num_in_mle'], #simple_adv
                                    self.agent_init_params[1]['num_out_pol'], #adv self-other
                                    hidden_dim=hidden_dim, output_style=output_style),
                         SNNNetwork(self.agent_init_params[1]['num_in_mle'],
                                    self.agent_init_params[1]['num_out_pol'], #agent self-self
                                    hidden_dim=hidden_dim, output_style=output_style),
                         SNNNetwork(self.agent_init_params[1]['num_in_mle'],
                                    self.agent_init_params[1]['num_out_pol'],
                                    hidden_dim=hidden_dim, output_style=output_style), ##agent self-other
                         ]
    elif self.nagents == 2:
        self.mle_base = [SNNNetwork(self.agent_init_params[0]['num_in_mle']-2, #simple_push
                                    self.agent_init_params[0]['num_out_pol'], #adv self-other
                                    hidden_dim=hidden_dim, output_style=output_style),
                         SNNNetwork(self.agent_init_params[1]['num_in_mle']-2,
                                    self.agent_init_params[1]['num_out_pol'],
                                    hidden_dim=hidden_dim, output_style=output_style), ##agent self-other
                         ]
    # One Adam optimizer per shared ToM network (trained in trian_* methods).
    self.mle_opts = [Adam(self.mle_base[i].parameters(), lr=lr) for i in range(len(self.mle_base))]
    self.gamma = gamma
    self.tau = tau
    self.lr = lr
    self.discrete_action = discrete_action
    # Device bookkeeping, same scheme as MADDPG (networks start on CPU).
    self.pol_dev = 'cpu'  # device for policies
    self.critic_dev = 'cpu'  # device for critics
    self.trgt_pol_dev = 'cpu'  # device for target policies
    self.trgt_critic_dev = 'cpu'  # device for target critics
    self.mle_dev = 'cpu'  # device for the shared ToM/MLE networks
    self.niter = 0  # counts target-update rounds; used as the logging step
@property
def policies(self):
return [a.policy for a in self.agents]
@property
def target_policies(self):
return [a.target_policy for a in self.agents]
def scale_noise(self, scale):
"""
Scale noise for each agent
Inputs:
scale (float): scale of noise
"""
for a in self.agents:
a.scale_noise(scale)
def reset_noise(self):
for a in self.agents:
a.reset_noise()
def step(self, observations, explore=False): #simple_tag
    """
    Take a step forward in environment with all agents
    Inputs:
        observations: List of observations for each agent
        explore (boolean): Whether or not to add exploration noise
    Outputs:
        actions: List of actions for each agent

    Before acting, each agent's observation is augmented (IN PLACE -- the
    caller's `observations` list is mutated) with hard one-hot actions that
    the shared ToM/MLE networks predict for every other agent.  The branch
    taken (and the observation slice fed to the MLE nets) is hard-coded per
    scenario via self.nagents: 6 -> simple_world_comm-style, 4 -> simple_tag,
    3 -> simple_adv, 2 -> simple_push.
    NOTE(review): the slices (e.g. [:, 4:24], [:, 2:], [:, 2:-2]) must match
    the scenario's observation layout and the num_in_mle offsets in __init__.
    """
    # Snapshot of the *original* observations; obs_ drops the acting agent's
    # own entry so the remaining ones belong to the "other" agents.
    observations_ = observations.copy()
    for agent_i, obs in enumerate(observations):
        obs_ = observations_.copy()
        obs_.pop(agent_i)
        if self.nagents == 6:
            if agent_i < 4:
                # Adversaries: model teammates with mle_base[0], opponents with mle_base[1].
                self.agents[agent_i].mle = [self.mle_base[0],self.mle_base[0],self.mle_base[0], self.mle_base[1], self.mle_base[1]]
                actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(self.device)(obs_j[:, 4:24].to(self.device)),
                                          hard=True).cpu()
                           for j, obs_j in enumerate(obs_)]
            else:
                self.agents[agent_i].mle = [self.mle_base[0],self.mle_base[0], self.mle_base[0], self.mle_base[0], self.mle_base[1]]
                actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(self.device)(obs_j[:, 4:24].to(self.device)),
                                          hard=True).cpu()
                           for j, obs_j in enumerate(obs_)]
        if self.nagents == 4:
            if agent_i < 3:
                # simple_tag: three adversaries (indices 0-2) and one prey (index 3).
                self.agents[agent_i].mle = [self.mle_base[0],self.mle_base[0], self.mle_base[1]]
                actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(self.device)(obs_j[:, 2:].to(self.device)),
                                          hard=True).cpu()
                           for j, obs_j in enumerate(obs_)]
            elif agent_i == 3:
                self.agents[agent_i].mle = [self.mle_base[2],self.mle_base[2], self.mle_base[2]]
                actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(self.device)(obs_j[:,2:-2].to(self.device)),
                                          hard=True).cpu()
                           for j, obs_j in enumerate(obs_)]
        elif self.nagents == 3: #simple_adv
            if agent_i < 1:
                # Adversary predicts from the other agent's position prefix
                # concatenated with its own observation.
                self.agents[agent_i].mle = [self.mle_base[0], self.mle_base[0]]
                actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(self.device)(torch.cat((obs_j[:, :2],
                                          observations_[agent_i]), 1).to(self.device)), hard=True).cpu()
                           for j, obs_j in enumerate(obs_)]
            elif agent_i >= 1:
                self.agents[agent_i].mle = [self.mle_base[2], self.mle_base[1]]
                actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(self.device)(observations_[agent_i].to(self.device)),
                                          hard=True).cpu() for j, obs_j in enumerate(obs_)]
        elif self.nagents == 2:
            if agent_i < 1:
                self.agents[agent_i].mle = [self.mle_base[0]]
                actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(self.device)(observations_[agent_i][:,2:].to(self.device)),
                                          hard=True).cpu() for j, obs_j in enumerate(obs_)]
            elif agent_i == 1:
                self.agents[agent_i].mle = [self.mle_base[1]]
                actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(self.device)(observations_[agent_i][:,2:].to(self.device)),
                                          hard=True).cpu() for j, obs_j in enumerate(obs_)]
        # Append the predicted one-hot actions of the others to this agent's
        # observation (mutates the caller-supplied list).
        observations[agent_i] = torch.cat((observations[agent_i], torch.cat(actions, 1)), 1)
    return [a.step(obs, explore=explore) for a, obs in zip(self.agents,
                                                           observations)]
def _get_obs(self, observations):
    """
    Return a new list of observations where each agent's observation is
    extended with the detached one-hot actions its ToM/MLE models predict
    for every other agent (same augmentation as step(), but without the
    device round-trips and without mutating the input list).

    Inputs:
        observations: list of per-agent observation tensors (batch-first)
    Outputs:
        list of augmented observation tensors, one per agent
    """
    observations_ = []
    for agent_i, obs in enumerate(observations):
        # All observations except the acting agent's own.
        obs_ = observations.copy()
        obs_.pop(agent_i)
        if self.nagents == 6:
            if agent_i < 4: #simple_tag
                self.agents[agent_i].mle = [self.mle_base[0], self.mle_base[0], self.mle_base[0], self.mle_base[1], self.mle_base[1]]
                actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(torch.device(self.device))(obs_j[:, 4:24]).detach(), hard=True)
                           for j, obs_j in enumerate(obs_)]
            else:
                # BUGFIX: this branch used to be `elif agent_i > 4:`, which
                # skipped agent_i == 4 entirely and silently reused the
                # previous iteration's `actions`.  step() splits the agents
                # as `agent_i < 4` vs everything else, so mirror that here.
                self.agents[agent_i].mle = [self.mle_base[0], self.mle_base[0], self.mle_base[0], self.mle_base[0], self.mle_base[1]]
                actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(torch.device(self.device))(obs_j[:, 4:24]).detach(), hard=True)
                           for j, obs_j in enumerate(obs_)]
        if self.nagents == 4:
            if agent_i < 3: #simple_tag
                self.agents[agent_i].mle = [self.mle_base[0], self.mle_base[0], self.mle_base[1]]
                actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(torch.device(self.device))(obs_j[:, 2:]).detach(), hard=True)
                           for j, obs_j in enumerate(obs_)]
            elif agent_i == 3:
                self.agents[agent_i].mle = [self.mle_base[2], self.mle_base[2], self.mle_base[2]]
                actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(torch.device(self.device))(obs_j[:, 2:-2]).detach(), hard=True)
                           for j, obs_j in enumerate(obs_)]
        elif self.nagents == 3:
            if agent_i < 1: #simple_adv
                self.agents[agent_i].mle = [self.mle_base[0],self.mle_base[0]]
                actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(torch.device(self.device))(torch.cat((obs_j[:,:2],observations[agent_i]),1)).detach(), hard=True)
                           for j, obs_j in enumerate(obs_)]
            elif agent_i >= 1:
                self.agents[agent_i].mle = [self.mle_base[2],self.mle_base[1]]
                actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(torch.device(self.device))(observations[agent_i]).detach(), hard=True)
                           for j, obs_j in enumerate(obs_)]
        elif self.nagents == 2:
            if agent_i < 1: #simple_push
                self.agents[agent_i].mle = [self.mle_base[0]]
                actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(torch.device(self.device))(observations[agent_i][:,2:]).detach(), hard=True)
                           for j, obs_j in enumerate(obs_)]
            elif agent_i == 1:
                self.agents[agent_i].mle = [self.mle_base[1]]
                actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(torch.device(self.device))(observations[agent_i][:,2:]).detach(), hard=True)
                           for j, obs_j in enumerate(obs_)]
        observations_.append(torch.cat((observations[agent_i], torch.cat(actions, 1)), 1))
    return observations_
def trian_tag(self, agent_i, KL_criterion, obs, parallel, acs):
    """
    One supervised update of the shared ToM/MLE networks for the
    simple_tag scenario (nagents == 4).  Runs only when agent_i == 0 so
    the shared models are trained once per update round, not once per agent.
    ('trian' (sic) is the original public name; kept for compatibility.)

    Inputs:
        agent_i (int): index of the agent currently being updated
        KL_criterion: torch.nn.KLDivLoss instance used as the fitting loss
        obs: list of per-agent observation batches from the replay buffer
        parallel (bool): if True, average gradients across threads
        acs: list of per-agent action batches (training targets)
    """
    if agent_i == 0:
        # mle_base[0]: predict adversary 0's action from its own observation.
        self.mle_opts[0].zero_grad()
        action_i = self.mle_base[0](obs[0][:, 2:])
        action_pre = gumbel_softmax(action_i, hard=True)
        # NOTE(review): KLDivLoss expects log-probabilities as its first
        # argument; a hard Gumbel-Softmax one-hot contains zeros, so this is
        # likely not a true KL term -- confirm the intended loss.
        loss = KL_criterion(action_pre.float(), acs[0].float())
        loss.backward(retain_graph=True)
        if parallel:
            average_gradients(self.mle_base[0])
        torch.nn.utils.clip_grad_norm_(self.mle_base[0].parameters(), 20)
        self.mle_opts[0].step()
        # mle_base[1]: predict the prey's (agent 3) action from its observation.
        self.mle_opts[1].zero_grad()
        action_i = self.mle_base[1](obs[3][:, 2:])
        action_pre = gumbel_softmax(action_i, hard=True)
        loss = KL_criterion(action_pre.float(), acs[3].float())
        loss.backward()
        if parallel:
            average_gradients(self.mle_base[1])
        torch.nn.utils.clip_grad_norm_(self.mle_base[1].parameters(), 20)
        self.mle_opts[1].step()
        # mle_base[2]: prey's model of an adversary, fed the truncated
        # observation slice it sees at rollout time ([:, 2:-2]).
        self.mle_opts[2].zero_grad()
        action_i = self.mle_base[2](obs[0][:, 2:-2])
        action_pre = gumbel_softmax(action_i, hard=True)
        loss = KL_criterion(action_pre.float(), acs[0].float())
        loss.backward()
        if parallel:
            average_gradients(self.mle_base[2])
        torch.nn.utils.clip_grad_norm_(self.mle_base[2].parameters(), 20)
        self.mle_opts[2].step()
def trian_adv(self, agent_i, KL_criterion, obs, parallel, acs):
    """
    One supervised update of the shared ToM/MLE networks for the
    simple_adv scenario (nagents == 3).  Runs only when agent_i == 0 so
    the shared models are trained once per update round.
    ('trian' (sic) is the original public name; kept for compatibility.)

    Inputs: see trian_tag.
    """
    if agent_i == 0:
        # mle_base[0]: adversary's model of agent 1, fed agent 1's position
        # prefix concatenated with the adversary's own observation (matches
        # the input built in step()/_get_obs()).
        self.mle_opts[0].zero_grad()
        action_i = self.mle_base[0](torch.cat((obs[1][:,:2],obs[agent_i]), 1))
        action_pre = gumbel_softmax(action_i, hard=True)
        # NOTE(review): KLDivLoss expects log-probabilities as its first
        # argument; see the note in trian_tag.
        loss = KL_criterion(action_pre.float(), acs[1].float())
        loss.backward(retain_graph=True)
        if parallel:
            average_gradients(self.mle_base[0])
        torch.nn.utils.clip_grad_norm_(self.mle_base[0].parameters(), 20)
        self.mle_opts[0].step()
        # mle_base[1]: predict agent 1's action from agent 1's observation.
        self.mle_opts[1].zero_grad()
        action_i = self.mle_base[1](obs[1])
        action_pre = gumbel_softmax(action_i, hard=True)
        loss = KL_criterion(action_pre.float(), acs[1].float())
        loss.backward()
        if parallel:
            average_gradients(self.mle_base[1])
        torch.nn.utils.clip_grad_norm_(self.mle_base[1].parameters(), 20)
        self.mle_opts[1].step()
        # mle_base[2]: fed agent 1's observation but fitted to agent 0's
        # actions.
        # NOTE(review): the obs[1] / acs[0] pairing looks asymmetric relative
        # to the other two heads -- confirm it is intentional.
        self.mle_opts[2].zero_grad()
        action_i = self.mle_base[2](obs[1])
        action_pre = gumbel_softmax(action_i, hard=True)
        loss = KL_criterion(action_pre.float(), acs[0].float())
        loss.backward()
        if parallel:
            average_gradients(self.mle_base[2])
        torch.nn.utils.clip_grad_norm_(self.mle_base[2].parameters(), 20)
        self.mle_opts[2].step()
def trian_push(self, agent_i, KL_criterion, obs, parallel, acs):
    """
    One supervised update of the shared ToM/MLE networks for the
    simple_push scenario (nagents == 2).  Runs only when agent_i == 0 so
    the shared models are trained once per update round.
    ('trian' (sic) is the original public name; kept for compatibility.)

    Inputs: see trian_tag.
    """
    if agent_i == 0:
        # mle_base[0]: agent 0's model of agent 1, fed agent 0's observation
        # slice and fitted to agent 1's actions.
        self.mle_opts[0].zero_grad()
        action_i = self.mle_base[0](obs[0][:, 2:])
        action_pre = gumbel_softmax(action_i, hard=True)
        # NOTE(review): KLDivLoss expects log-probabilities as its first
        # argument; see the note in trian_tag.
        loss = KL_criterion(action_pre.float(), acs[1].float())
        loss.backward(retain_graph=True)
        if parallel:
            average_gradients(self.mle_base[0])
        torch.nn.utils.clip_grad_norm_(self.mle_base[0].parameters(), 20)
        self.mle_opts[0].step()
        # mle_base[1]: agent 1's model of agent 0 (mirror of the above).
        self.mle_opts[1].zero_grad()
        action_i = self.mle_base[1](obs[1][:, 2:])
        action_pre = gumbel_softmax(action_i, hard=True)
        loss = KL_criterion(action_pre.float(), acs[0].float())
        loss.backward()
        if parallel:
            average_gradients(self.mle_base[1])
        torch.nn.utils.clip_grad_norm_(self.mle_base[1].parameters(), 20)
        self.mle_opts[1].step()
def trian_com(self, agent_i, KL_criterion, obs, parallel, acs):
    """
    One supervised update of the shared ToM/MLE networks for the
    six-agent communication scenario (nagents == 6).  Runs only when
    agent_i == 0 so the shared models are trained once per update round.
    ('trian' (sic) is the original public name; kept for compatibility.)

    Inputs: see trian_tag.
    """
    if agent_i == 0:
        # mle_base[0]: predict adversary 1's action from the [:, 4:24]
        # observation slice used at rollout time.
        self.mle_opts[0].zero_grad()
        action_i = self.mle_base[0](obs[1][:, 4:24])
        action_pre = gumbel_softmax(action_i, hard=True)
        # NOTE(review): KLDivLoss expects log-probabilities as its first
        # argument; see the note in trian_tag.
        loss = KL_criterion(action_pre.float(), acs[1].float())
        loss.backward(retain_graph=True)
        if parallel:
            average_gradients(self.mle_base[0])
        torch.nn.utils.clip_grad_norm_(self.mle_base[0].parameters(), 20)
        self.mle_opts[0].step()
        # mle_base[1]: predict agent 4's action from the same slice.
        self.mle_opts[1].zero_grad()
        action_i = self.mle_base[1](obs[4][:, 4:24])
        action_pre = gumbel_softmax(action_i, hard=True)
        loss = KL_criterion(action_pre.float(), acs[4].float())
        loss.backward()
        if parallel:
            average_gradients(self.mle_base[1])
        torch.nn.utils.clip_grad_norm_(self.mle_base[1].parameters(), 20)
        self.mle_opts[1].step()
def update(self, sample, agent_i, parallel=False, logger=None, sample_r=None):
    """
    Update parameters of agent model based on sample from replay buffer
    Inputs:
        sample: tuple of (observations, actions, rewards, next
                observations, and episode end masks) sampled randomly from
                the replay buffer. Each is a list with entries
                corresponding to each agent
        agent_i (int): index of agent to update
        parallel (bool): If true, will average gradients across threads
        logger (SummaryWriter from Tensorboard-Pytorch):
            If passed in, important quantities will be logged
        sample_r: unused; kept for call-site compatibility
    """
    obs, acs, rews, next_obs, dones = sample
    # Observations augmented with the ToM (mle) predictions of the other
    # agents' actions.
    next_obs_ = self._get_obs(next_obs)
    obs_ = self._get_obs(obs)
    curr_agent = self.agents[agent_i]
    # --- ToM (mle) networks: pick the trainer for the current scenario by
    # agent count (6: simple_comm, 4: simple_tag, 3: simple_adv,
    # 2: simple_push).
    KL_criterion = torch.nn.KLDivLoss(reduction='sum')
    if self.nagents == 6:
        self.trian_com(agent_i, KL_criterion, obs, parallel, acs)
    elif self.nagents == 4:
        # FIX: was a separate `if`; made `elif` for consistency with the
        # rest of the chain (branches are mutually exclusive, so behavior
        # is unchanged).
        self.trian_tag(agent_i, KL_criterion, obs, parallel, acs)
    elif self.nagents == 3:
        self.trian_adv(agent_i, KL_criterion, obs, parallel, acs)
    elif self.nagents == 2:
        self.trian_push(agent_i, KL_criterion, obs, parallel, acs)
    # --- centralized critic update
    curr_agent.critic_optimizer.zero_grad()
    if self.alg_types[agent_i] == 'MADDPG_ToM':
        all_trgt_acs = []
        if self.discrete_action:  # one-hot encode action
            all_trgt_acs = [onehot_from_logits(pi(nobs)) for pi, nobs in
                            zip(self.target_policies, next_obs_)]
        trgt_vf_in = torch.cat((*next_obs, *all_trgt_acs), dim=1)
    target_value = (rews[agent_i].view(-1, 1) + self.gamma *
                    curr_agent.target_critic(trgt_vf_in) *
                    (1 - dones[agent_i].view(-1, 1)))
    if self.alg_types[agent_i] == 'MADDPG_ToM':
        vf_in = torch.cat((*obs, *acs), dim=1)
        actual_value = curr_agent.critic(vf_in)
    vf_loss = MSELoss(actual_value, target_value.detach())
    vf_loss.backward()
    if parallel:
        average_gradients(curr_agent.critic)
    torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(), 0.5)
    curr_agent.critic_optimizer.step()
    # --- actor (policy) update
    curr_agent.policy_optimizer.zero_grad()
    if self.discrete_action:
        # Forward pass as if onehot (hard=True) but backprop through a
        # differentiable Gumbel-Softmax sample (standard MADDPG trick for
        # discrete actions).
        curr_pol_out = curr_agent.policy(obs_[agent_i])
        curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)
    else:
        curr_pol_out = curr_agent.policy(obs[agent_i])
        curr_pol_vf_in = curr_pol_out
    if self.alg_types[agent_i] == 'MADDPG_ToM':
        all_pol_acs = []
        for i, pi, ob in zip(range(self.nagents), self.policies, obs_):
            if i == agent_i:
                all_pol_acs.append(curr_pol_vf_in)
            elif self.discrete_action:
                all_pol_acs.append(onehot_from_logits(pi(ob)))
            else:
                all_pol_acs.append(pi(ob))
        vf_in = torch.cat((*obs, *all_pol_acs), dim=1)
    pol_loss = -curr_agent.critic(vf_in).mean()
    # Small action-magnitude regularizer on the raw policy output.
    pol_loss += (curr_pol_out ** 2).mean() * 1e-3
    pol_loss.backward()
    if parallel:
        average_gradients(curr_agent.policy)
    torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(), 0.5)
    curr_agent.policy_optimizer.step()
    if logger is not None:
        logger.add_scalars('agent%i/losses' % agent_i,
                           {'vf_loss': vf_loss,
                            'pol_loss': pol_loss},
                           self.niter)
def update_all_targets(self):
    """Soft-update every agent's target networks toward the live ones.

    Called once after all agents have completed their normal updates;
    also advances the internal iteration counter used for logging.
    """
    for agent in self.agents:
        soft_update(agent.target_critic, agent.critic, self.tau)
        soft_update(agent.target_policy, agent.policy, self.tau)
    self.niter += 1
def prep_training(self, device='gpu'):
    """Switch every network to train mode and move it to *device*.

    Args:
        device: 'gpu' moves modules to ``self.device``; any other value
            moves them to the CPU.  The ``*_dev`` attributes cache the
            current placement so already-placed families are skipped.
    """
    # Train mode for the shared ToM networks and all agent networks.
    for mle in self.mle_base:
        mle.train()
    for a in self.agents:
        a.policy.train()
        a.critic.train()
        a.target_policy.train()
        a.target_critic.train()
        for mle_i in a.mle:
            mle_i.train()
    if device == 'gpu':
        fn = lambda x: x.to(torch.device(self.device))
    else:
        fn = lambda x: x.cpu()
    if not self.pol_dev == device:
        for a in self.agents:
            a.policy = fn(a.policy)
        self.pol_dev = device
    if not self.critic_dev == device:
        for a in self.agents:
            a.critic = fn(a.critic)
        self.critic_dev = device
    if not self.trgt_pol_dev == device:
        for a in self.agents:
            a.target_policy = fn(a.target_policy)
        self.trgt_pol_dev = device
    if not self.trgt_critic_dev == device:
        for a in self.agents:
            a.target_critic = fn(a.target_critic)
        self.trgt_critic_dev = device
    if not self.mle_dev == device:
        # The mle networks are shared: move the base copies, then refresh
        # every per-agent reference to them.
        for i, mle in enumerate(self.mle_base):
            self.mle_base[i] = fn(mle)
        for a in self.agents:
            for i, mle_i in enumerate(a.mle):
                a.mle[i] = fn(mle_i)
        self.mle_dev = device
def prep_rollouts(self, device='cpu'):
    """Put each agent's policy in eval mode and move it to *device*.

    Rollouts only use the main policies, so critics, targets and mle
    networks are left untouched.
    """
    for agent in self.agents:
        agent.policy.eval()
    if device == 'gpu':
        move = lambda module: module.to(torch.device(self.device))
    else:
        move = lambda module: module.cpu()
    if self.pol_dev != device:
        for agent in self.agents:
            agent.policy = move(agent.policy)
        self.pol_dev = device
def save(self, filename):
    """
    Save trained parameters of all agents into one file
    """
    # Checkpoints are written device-agnostic: move everything to the CPU
    # before collecting state dicts.
    self.prep_training(device='cpu')
    save_dict = {
        'init_dict': self.init_dict,
        'agent_params': [agent.get_params() for agent in self.agents],
        'mle_params': [self.get_params()],
    }
    torch.save(save_dict, filename)
@classmethod
def init_from_env(cls, env, device, agent_alg="MADDPG_ToM", adversary_alg="MADDPG_ToM",
                  gamma=0.95, tau=0.01, lr=0.01, hidden_dim=64, output_style='sum'):
    """
    Instantiate instance of this class from multi-agent environment

    Args:
        env: multi-agent environment exposing ``agent_types``,
            ``action_space`` and ``observation_space`` lists (gym-style
            Box / Discrete / MultiDiscrete spaces).
        device: torch device string stored for later network placement.
        agent_alg / adversary_alg: algorithm name per agent type.
        gamma, tau, lr, hidden_dim, output_style: hyperparameters stored
            in the returned instance's ``init_dict``.
    """
    agent_init_params = []
    alg_types = [adversary_alg if atype == 'adversary' else agent_alg for
                 atype in env.agent_types]
    for acsp, obsp, algtype in zip(env.action_space, env.observation_space,
                                   alg_types):
        num_in_pol = obsp.shape[0]
        num_in_mle = obsp.shape[0]
        # NOTE(review): if acsp is none of Box/Discrete/MultiDiscrete,
        # get_shape (and discrete_action) is never bound and the
        # num_out_pol line below raises NameError.
        if isinstance(acsp, Box):
            discrete_action = False
            get_shape = lambda x: x.shape[0]
        elif isinstance(acsp, Discrete):  # Discrete
            discrete_action = True
            get_shape = lambda x: x.n
        elif isinstance(acsp, MultiDiscrete):
            discrete_action = True
            get_shape = lambda x: sum(x.high - x.low + 1)
        num_out_pol = get_shape(acsp)
        if algtype == "MADDPG_ToM":
            # Centralized critic sees every agent's obs and actions;
            # the policy input is widened by the predicted actions of the
            # other agents (5 values each — presumably the discrete action
            # dimension; confirm).
            num_in_critic = 0
            num_in_pol += (len(env.agent_types)-1) * 5
            for oobsp in env.observation_space:
                num_in_critic += oobsp.shape[0]
            for oacsp in env.action_space:
                if isinstance(oacsp, Box):
                    discrete_action = False
                    get_shape = lambda x: x.shape[0]
                elif isinstance(oacsp, Discrete):  # Discrete
                    discrete_action = True
                    get_shape = lambda x: x.n
                elif isinstance(oacsp, MultiDiscrete):
                    discrete_action = True
                    get_shape = lambda x: sum(x.high - x.low + 1)
                num_in_critic += get_shape(oacsp)
        else:
            num_in_critic = obsp.shape[0] + get_shape(acsp)
        agent_init_params.append({'num_in_pol': num_in_pol,
                                  'num_out_pol': num_out_pol,
                                  'num_in_critic': num_in_critic,
                                  'num_in_mle': num_in_mle,})
    # NOTE(review): discrete_action keeps whatever the LAST space set it
    # to; with mixed action spaces this flag would be unrepresentative.
    init_dict = {'gamma': gamma, 'tau': tau, 'lr': lr,
                 'device': device,
                 'hidden_dim': hidden_dim,
                 'alg_types': alg_types,
                 'agent_init_params': agent_init_params,
                 'discrete_action': discrete_action,
                 'output_style': output_style}
    instance = cls(**init_dict)
    instance.init_dict = init_dict
    return instance
@classmethod
def init_from_save(cls, filename):
    """
    Instantiate instance of this class from file created by 'save' method
    """
    save_dict = torch.load(filename)
    instance = cls(**save_dict['init_dict'])
    instance.init_dict = save_dict['init_dict']
    for agent, params in zip(instance.agents, save_dict['agent_params']):
        agent.load_params(params)
    # 'mle_params' holds a single entry with the shared mle networks'
    # state; restore it onto the instance itself.
    for params in save_dict['mle_params'][:1]:
        instance.load_params(params)
    return instance
def get_params(self):
    """Collect state dicts of all shared mle networks and their optimizers.

    Returns a dict keyed 'mle<i>' / 'mle_optimizer<i>' in network order,
    suitable for :meth:`load_params`.
    """
    params = {}
    for idx, (net, opt) in enumerate(zip(self.mle_base, self.mle_opts)):
        params['mle%d' % idx] = net.state_dict()
        params['mle_optimizer%d' % idx] = opt.state_dict()
    return params
def load_params(self, params):
    """Restore shared mle networks/optimizers from a get_params() dict."""
    for idx, (net, opt) in enumerate(zip(self.mle_base, self.mle_opts)):
        net.load_state_dict(params['mle%d' % idx])
        opt.load_state_dict(params['mle_optimizer%d' % idx])
class ToM_SA(object):
"""
Wrapper class for DDPG-esque (i.e. also MADDPG) agents in multi-agent task
"""
def __init__(self, agent_init_params, alg_types, output_style, device,
             gamma=0.95, tau=0.01, lr=0.01, hidden_dim=64,
             discrete_action=False):
    """
    Inputs:
        agent_init_params (list of dict): List of dicts with parameters to
                                          initialize each agent
            num_in_pol (int): Input dimensions to policy
            num_out_pol (int): Output dimensions to policy
            num_in_critic (int): Input dimensions to critic
        alg_types (list of str): Learning algorithm for each agent (DDPG
                                   or MADDPG)
        output_style: forwarded to the agents and SNN mle networks
        device: torch device string used for network placement
        gamma (float): Discount factor
        tau (float): Target update rate
        lr (float): Learning rate for policy and critic
        hidden_dim (int): Number of hidden dimensions for networks
        discrete_action (bool): Whether or not to use discrete action space
    """
    self.device = device
    self.nagents = len(alg_types)
    self.alg_types = alg_types
    self.agents = [DDPGAgent_ToM(lr=lr, discrete_action=discrete_action,
                                 hidden_dim=hidden_dim,
                                 **params, output_style=output_style,
                                 num_agents=self.nagents,
                                 device=self.device)
                   for params in agent_init_params]
    self.agent_init_params = agent_init_params
    # Shared ToM (mle) networks, sized per scenario.  Input widths drop
    # scenario-specific observation features and add 5 entries for the
    # other agent's previous action (5 is presumably the discrete action
    # dimension — confirm).
    # NOTE(review): the `nagents == 4` branch below is a plain `if`, not
    # `elif` like the 3- and 2-agent branches; harmless because the agent
    # counts are mutually exclusive, but stylistically inconsistent.
    if self.nagents == 6:
        self.mle_base = [SNNNetwork(self.agent_init_params[3]['num_in_mle'] - 14 + 5, #simple_com
                                    self.agent_init_params[3]['num_out_pol'], #adv self-self
                                    hidden_dim=hidden_dim, output_style=output_style),
                         SNNNetwork(self.agent_init_params[3]['num_in_mle'] - 14 + 5,
                                    self.agent_init_params[3]['num_out_pol'], #adv self-other
                                    hidden_dim=hidden_dim, output_style=output_style),
                         SNNNetwork(self.agent_init_params[3]['num_in_mle'] - 14 + 5,
                                    self.agent_init_params[3]['num_out_pol'], # adv self-other
                                    hidden_dim=hidden_dim, output_style=output_style),
                         SNNNetwork(self.agent_init_params[3]['num_in_mle'] - 14 + 5,
                                    self.agent_init_params[3]['num_out_pol'],
                                    hidden_dim=hidden_dim, output_style=output_style), ##agent self-other
                         ]
    if self.nagents == 4:
        self.mle_base = [SNNNetwork(self.agent_init_params[0]['num_in_mle'] - 2 + 5, #simple_tag
                                    self.agent_init_params[0]['num_out_pol'], #adv self-self
                                    hidden_dim=hidden_dim, output_style=output_style),
                         SNNNetwork(self.agent_init_params[3]['num_in_mle'] - 2 + 5,
                                    self.agent_init_params[3]['num_out_pol'], #adv self-other
                                    hidden_dim=hidden_dim, output_style=output_style),
                         SNNNetwork(self.agent_init_params[3]['num_in_mle'] - 2 + 5,
                                    self.agent_init_params[3]['num_out_pol'],
                                    hidden_dim=hidden_dim, output_style=output_style), ##agent self-other
                         ]
    elif self.nagents == 3:
        self.mle_base = [SNNNetwork(self.agent_init_params[1]['num_in_mle'] + 5, #simple_adv
                                    self.agent_init_params[1]['num_out_pol'], #adv self-other
                                    hidden_dim=hidden_dim, output_style=output_style),
                         SNNNetwork(self.agent_init_params[1]['num_in_mle'] + 5,
                                    self.agent_init_params[1]['num_out_pol'], #agent self-self
                                    hidden_dim=hidden_dim, output_style=output_style),
                         SNNNetwork(self.agent_init_params[1]['num_in_mle'] + 5,
                                    self.agent_init_params[1]['num_out_pol'],
                                    hidden_dim=hidden_dim, output_style=output_style), ##agent self-other
                         ]
    elif self.nagents == 2:
        self.mle_base = [SNNNetwork(self.agent_init_params[0]['num_in_mle']-2 + 5, #simple_push
                                    self.agent_init_params[0]['num_out_pol'], #adv self-other
                                    hidden_dim=hidden_dim, output_style=output_style),
                         SNNNetwork(self.agent_init_params[1]['num_in_mle']-2 + 5,
                                    self.agent_init_params[1]['num_out_pol'],
                                    hidden_dim=hidden_dim, output_style=output_style), ##agent self-other
                         ]
    # One Adam optimizer per shared mle network.
    self.mle_opts = [Adam(self.mle_base[i].parameters(), lr=lr) for i in range(len(self.mle_base))]
    self.gamma = gamma
    self.tau = tau
    self.lr = lr
    self.discrete_action = discrete_action
    # Device caches: where each family of networks currently lives.
    self.pol_dev = 'cpu'  # device for policies
    self.critic_dev = 'cpu'  # device for critics
    self.trgt_pol_dev = 'cpu'  # device for target policies
    self.trgt_critic_dev = 'cpu'  # device for target critics
    self.mle_dev = 'cpu'  # device for shared mle networks
    self.niter = 0  # update counter used for logging
@property
def policies(self):
return [a.policy for a in self.agents]
@property
def target_policies(self):
return [a.target_policy for a in self.agents]
def scale_noise(self, scale):
"""
Scale noise for each agent
Inputs:
scale (float): scale of noise
"""
for a in self.agents:
a.scale_noise(scale)
def reset_noise(self):
for a in self.agents:
a.reset_noise()
def step(self, observations, actions_pre, explore=False):  # simple_tag
    """
    Take a step forward in environment with all agents
    Inputs:
        observations: List of observations for each agent (modified in
            place: each entry is widened with the ToM-predicted actions
            of the other agents before being fed to the policy)
        actions_pre: List of each agent's previous action batch, fed to
            the ToM (mle) networks together with the observations
        explore (boolean): Whether or not to add exploration noise
    Outputs:
        actions: List of actions for each agent
    """
    observations_ = observations.copy()
    actions_pre_ = actions_pre.copy()
    for agent_i, obs in enumerate(observations):
        # Views of the OTHER agents' observations / previous actions.
        obs_ = observations_.copy()
        acs_pre_ = actions_pre_.copy()
        obs_.pop(agent_i)
        acs_pre_.pop(agent_i)
        if self.nagents == 6:  # simple_comm
            if agent_i < 4:
                # Adversary side: mle_base[0] models the 3 fellow
                # adversaries, mle_base[1] the 2 good agents.
                self.agents[agent_i].mle = [self.mle_base[0],self.mle_base[0],self.mle_base[0], self.mle_base[1], self.mle_base[1]]
                # Inputs: obs features [4:24] plus the first 5 entries of
                # the previous action (slice bounds are scenario-specific
                # — confirm against the env's observation layout).
                actions = [torch.cat((obs_j[:, 4:24], acs_pre_[j][:,:5]),1) for j, obs_j in enumerate(obs_)]
                # Two batched forward passes, then split back into the 5
                # per-agent predictions (chunks of 20 rows — presumably
                # the rollout batch size; confirm).
                b1 = gumbel_softmax(self.agents[agent_i].mle[0].to(self.device)(torch.cat(actions[:3]).to(self.device)), hard=True).cpu()
                b2 = gumbel_softmax(self.agents[agent_i].mle[1].to(self.device)(torch.cat(actions[3:]).to(self.device)), hard=True).cpu()
                actions = torch.cat((b1[:20], b1[20:40], b1[40:60], b2[:20], b2[20:40]), 1)
            else:
                # Good agents get zero placeholders instead of mle
                # predictions.
                self.agents[agent_i].mle = [self.mle_base[3],self.mle_base[3], self.mle_base[3], self.mle_base[3], self.mle_base[2]]
                actions = [torch.zeros((obs_j.shape[0],self.agent_init_params[1]['num_out_pol']))
                           for j, obs_j in enumerate(obs_)]
                actions = torch.cat(actions,1)
        if self.nagents == 4:  # simple_tag
            if agent_i < 3:
                # Adversaries predict each other (mle_base[0]) and the
                # good agent (mle_base[1]).
                self.agents[agent_i].mle = [self.mle_base[0],self.mle_base[0], self.mle_base[1]]
                actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(self.device)(torch.cat((obs_j[:, 2:], acs_pre_[j]),1).to(self.device)),
                                          hard=True).cpu()
                           for j, obs_j in enumerate(obs_)]
            elif agent_i == 3:
                # Good agent: zero placeholders.
                self.agents[agent_i].mle = [self.mle_base[2],self.mle_base[2], self.mle_base[2]]
                actions = [torch.zeros((obs_j.shape[0],self.agent_init_params[0]['num_out_pol']))
                           for j, obs_j in enumerate(obs_)]
        elif self.nagents == 3: #simple_adv
            actions = []
            if agent_i < 1:
                # Adversary: zero placeholders.
                self.agents[agent_i].mle = [self.mle_base[0], self.mle_base[0]]
                actions = [torch.zeros((obs_j.shape[0],self.agent_init_params[0]['num_out_pol']))
                          for j, obs_j in enumerate(obs_)]
            elif agent_i == 1:
                # Good agent 1 models the adversary and the other good
                # agent from its own observation plus their previous
                # actions.
                self.agents[agent_i].mle = [self.mle_base[2], self.mle_base[1]]
                actions.append(gumbel_softmax(self.agents[agent_i].mle[0].to(self.device)(torch.cat((observations_[agent_i],
                                                                                                     actions_pre[(0)]), 1).to(self.device)),
                                              hard=True).cpu() )
                actions.append(gumbel_softmax(self.agents[agent_i].mle[1].to(self.device)(torch.cat((observations_[agent_i],
                                                                                                     actions_pre[(2)]), 1).to(self.device)),
                                              hard=True).cpu() )
            elif agent_i == 2:
                self.agents[agent_i].mle = [self.mle_base[2], self.mle_base[1]]
                actions.append(gumbel_softmax(self.agents[agent_i].mle[0].to(self.device)(torch.cat((observations_[agent_i],
                                                                                                     actions_pre[(0)]), 1).to(self.device)),
                                              hard=True).cpu() )
                actions.append(gumbel_softmax(self.agents[agent_i].mle[1].to(self.device)(torch.cat((observations_[agent_i],
                                                                                                     actions_pre[(1)]), 1).to(self.device)),
                                              hard=True).cpu() )
        elif self.nagents == 2:  # simple_push
            if agent_i < 1:
                # Adversary: zero placeholders.
                self.agents[agent_i].mle = [self.mle_base[0]]
                actions = [torch.zeros((obs_j.shape[0], self.agent_init_params[0]['num_out_pol'])) for j, obs_j in enumerate(obs_)]
            elif agent_i == 1:
                # Good agent models the adversary's next action.
                self.agents[agent_i].mle = [self.mle_base[1]]
                actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(self.device)(torch.cat((observations_[agent_i][:,2:],
                                                                                                 actions_pre[(self.nagents -1 - agent_i)]), 1).to(self.device)),
                                          hard=True).cpu() for j, obs_j in enumerate(obs_)]
        # Append the predicted other-agent actions to this agent's
        # observation (6-agent case already concatenated them above).
        if self.nagents == 6:
            observations[agent_i] = torch.cat((observations[agent_i], actions), 1)
        else:
            observations[agent_i] = torch.cat((observations[agent_i], torch.cat(actions, 1)), 1)
    return [a.step(obs, explore=explore) for a, obs in zip(self.agents,
                                                           observations)]
def _get_obs(self, observations, actions_pre):
    """Augment each agent's observation with ToM-predicted other-agent actions.

    Mirrors :meth:`step`, but every mle forward pass is ``detach``-ed so the
    predictions act as constants during the critic/actor update.

    Args:
        observations: list of per-agent observation batches.
        actions_pre: list of per-agent previous-action batches.

    Returns:
        New list of augmented observation batches, one per agent
        (``observations`` itself is not replaced).
    """
    observations_ = []
    actions_pre_ = []
    for agent_i, obs in enumerate(observations):
        # Views of the OTHER agents' observations / previous actions.
        obs_ = observations.copy()
        obs_.pop(agent_i)
        actions_pre_ = actions_pre.copy()
        actions_pre_.pop(agent_i)
        if self.nagents == 6:  # simple_comm
            if agent_i < 4:
                self.agents[agent_i].mle = [self.mle_base[0], self.mle_base[0], self.mle_base[0], self.mle_base[1], self.mle_base[1]]
                actions = [torch.cat((obs_j[:, 4:24], actions_pre_[j][:, :5]), 1) for j, obs_j in enumerate(obs_)]
                # Two batched forward passes, split back into the five
                # per-agent predictions (chunks of 1024 rows — presumably
                # the replay batch size; confirm).
                b1 = gumbel_softmax(self.agents[agent_i].mle[0].to(self.device)(torch.cat(actions[:3]).to(self.device)).detach(), hard=True)
                b2 = gumbel_softmax(self.agents[agent_i].mle[1].to(self.device)(torch.cat(actions[3:]).to(self.device)).detach(), hard=True)
                actions = torch.cat((b1[:1024], b1[1024:2048], b1[2048:3072], b2[:1024], b2[1024:2048]), 1)
            else:
                # BUGFIX: this branch was `elif agent_i > 4`, which skipped
                # agent 4 entirely and reused the previous iteration's
                # `actions` tensor for it; `else` matches the equivalent
                # branch in step().
                self.agents[agent_i].mle = [self.mle_base[3], self.mle_base[3], self.mle_base[3], self.mle_base[3], self.mle_base[2]]
                actions = [torch.zeros((obs_j.shape[0], self.agent_init_params[1]['num_out_pol'])).to(
                    torch.device(self.device)).detach() for j, obs_j in enumerate(obs_)]
                actions = torch.cat(actions, 1)
        elif self.nagents == 4:  # simple_tag (was a plain `if`; behavior unchanged)
            if agent_i < 3:
                self.agents[agent_i].mle = [self.mle_base[0], self.mle_base[0], self.mle_base[1]]
                actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(torch.device(self.device))(torch.cat((obs_j[:, 2:], actions_pre_[j]), 1)).detach(), hard=True)
                           for j, obs_j in enumerate(obs_)]
            elif agent_i == 3:
                self.agents[agent_i].mle = [self.mle_base[2], self.mle_base[2], self.mle_base[2]]
                actions = [torch.zeros((obs_j.shape[0], self.agent_init_params[0]['num_out_pol'])).to(torch.device(self.device)).detach()
                           for j, obs_j in enumerate(obs_)]
        elif self.nagents == 3:  # simple_adv
            actions = []
            if agent_i < 1:
                # Adversary: zero placeholders (its mle training is disabled).
                actions = [torch.zeros((obs_j.shape[0], self.agent_init_params[0]['num_out_pol'])).to(torch.device(self.device)).detach()
                           for j, obs_j in enumerate(obs_)]
            elif agent_i == 1:
                self.agents[agent_i].mle = [self.mle_base[2], self.mle_base[1]]
                actions.append(
                    gumbel_softmax(self.agents[agent_i].mle[0].to(self.device)(torch.cat((observations[agent_i],
                                                                                          actions_pre[0]), 1).to(self.device)).detach(), hard=True))
                actions.append(
                    gumbel_softmax(self.agents[agent_i].mle[1].to(self.device)(torch.cat((observations[agent_i],
                                                                                          actions_pre[2]), 1).to(self.device)).detach(), hard=True))
            elif agent_i == 2:
                self.agents[agent_i].mle = [self.mle_base[2], self.mle_base[1]]
                actions.append(
                    gumbel_softmax(self.agents[agent_i].mle[0].to(self.device)(torch.cat((observations[agent_i],
                                                                                          actions_pre[0]), 1).to(self.device)).detach(), hard=True))
                actions.append(
                    gumbel_softmax(self.agents[agent_i].mle[1].to(self.device)(torch.cat((observations[agent_i],
                                                                                          actions_pre[1]), 1).to(self.device)).detach(), hard=True))
        elif self.nagents == 2:  # simple_push
            if agent_i < 1:
                self.agents[agent_i].mle = [self.mle_base[0]]
                actions = [torch.zeros((obs_j.shape[0], self.agent_init_params[0]['num_out_pol'])).to(torch.device(self.device)).detach()
                           for j, obs_j in enumerate(obs_)]
            elif agent_i == 1:
                self.agents[agent_i].mle = [self.mle_base[1]]
                actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(torch.device(self.device))(torch.cat((observations[agent_i][:, 2:],
                                                                                                               actions_pre[self.nagents - 1 - agent_i]), 1)).detach(), hard=True)
                           for j, obs_j in enumerate(obs_)]
        if self.nagents == 6:
            observations_.append(torch.cat((observations[agent_i], actions), 1))
        else:
            observations_.append(torch.cat((observations[agent_i], torch.cat(actions, 1)), 1))
    return observations_
def trian_tag(self, agent_i, KL_criterion, obs, acs_pre, parallel, acs):
    """Fit the shared ToM (mle) networks for the 4-agent simple_tag scenario.

    mle_base[0] is fitted to agent 0's actions and mle_base[1] to agent 3's
    actions; each input is that agent's observation minus its first two
    features, concatenated with its previous action.  Gated on
    ``agent_i == 0`` so the shared networks train once per update round.
    (The former mle_base[2] update was removed upstream.)
    """
    if agent_i == 0:
        self.mle_opts[0].zero_grad()
        action_i = self.mle_base[0](torch.cat((obs[0][:, 2:], acs_pre[0]), 1))
        action_pre = gumbel_softmax(action_i, hard=True)
        loss = KL_criterion(action_pre.float(), acs[0].float())
        # retain_graph=True so the second backward below can reuse any
        # shared graph parts — TODO(review) confirm still required.
        loss.backward(retain_graph=True)
        if parallel:
            average_gradients(self.mle_base[0])
        torch.nn.utils.clip_grad_norm_(self.mle_base[0].parameters(), 20)
        self.mle_opts[0].step()
        self.mle_opts[1].zero_grad()
        action_i = self.mle_base[1](torch.cat((obs[3][:, 2:], acs_pre[3]), 1))
        action_pre = gumbel_softmax(action_i, hard=True)
        loss = KL_criterion(action_pre.float(), acs[3].float())
        loss.backward()
        if parallel:
            average_gradients(self.mle_base[1])
        torch.nn.utils.clip_grad_norm_(self.mle_base[1].parameters(), 20)
        self.mle_opts[1].step()
def trian_adv(self, agent_i, KL_criterion, obs, acs_pre, parallel, acs):
    """Fit the ToM (mle) networks for the 3-agent simple_adv scenario.

    mle_base[0] (the adversary's model) is currently NOT trained — its
    update was commented out upstream.  mle_base[1] and mle_base[2] are
    trained on agent 1's observation concatenated with another agent's
    previous action.  Gated on ``agent_i == 0``.

    NOTE(review): mle_base[1] reads acs_pre[2] and targets acs[1], while
    mle_base[2] reads acs_pre[0] and targets acs[0] — confirm these index
    pairings are intended.
    """
    if agent_i == 0:
        self.mle_opts[1].zero_grad()
        action_i = self.mle_base[1](torch.cat((obs[1], acs_pre[2]), 1))
        action_pre = gumbel_softmax(action_i, hard=True)
        loss = KL_criterion(action_pre.float(), acs[1].float())
        loss.backward()
        if parallel:
            average_gradients(self.mle_base[1])
        torch.nn.utils.clip_grad_norm_(self.mle_base[1].parameters(), 20)
        self.mle_opts[1].step()
        self.mle_opts[2].zero_grad()
        action_i = self.mle_base[2](torch.cat((obs[1], acs_pre[0]), 1))
        action_pre = gumbel_softmax(action_i, hard=True)
        loss = KL_criterion(action_pre.float(), acs[0].float())
        loss.backward()
        if parallel:
            average_gradients(self.mle_base[2])
        torch.nn.utils.clip_grad_norm_(self.mle_base[2].parameters(), 20)
        self.mle_opts[2].step()
def trian_push(self, agent_i, KL_criterion, obs, acs_pre, parallel, acs):
    """Fit the ToM (mle) network for the 2-agent simple_push scenario.

    Only mle_base[1] (the good agent's model of the adversary) is trained;
    the mle_base[0] update was commented out upstream.  Input is agent 1's
    observation minus its first two features, concatenated with agent 0's
    previous action; the target is agent 0's action.  Gated on
    ``agent_i == 0``.
    """
    if agent_i == 0:
        self.mle_opts[1].zero_grad()
        action_i = self.mle_base[1](torch.cat((obs[1][:, 2:], acs_pre[0]), 1))
        action_pre = gumbel_softmax(action_i, hard=True)
        loss = KL_criterion(action_pre.float(), acs[0].float())
        loss.backward()
        if parallel:
            average_gradients(self.mle_base[1])
        torch.nn.utils.clip_grad_norm_(self.mle_base[1].parameters(), 20)
        self.mle_opts[1].step()
def trian_com(self, agent_i, KL_criterion, obs, acs_pre, parallel, acs):
    """Fit the shared ToM (mle) networks for the 6-agent simple_comm scenario.

    mle_base[0] is fitted to agent 1's actions and mle_base[1] to agent 4's
    actions; inputs are observation features [4:24] concatenated with that
    agent's previous action (slice bounds are scenario-specific — confirm
    against the env's observation layout).  Gated on ``agent_i == 0``.
    """
    if agent_i == 0:
        self.mle_opts[0].zero_grad()
        action_i = self.mle_base[0](torch.cat((obs[1][:, 4:24], acs_pre[1]), 1))
        action_pre = gumbel_softmax(action_i, hard=True)
        loss = KL_criterion(action_pre.float(), acs[1].float())
        # retain_graph=True so the second backward below can reuse any
        # shared graph parts — TODO(review) confirm still required.
        loss.backward(retain_graph=True)
        if parallel:
            average_gradients(self.mle_base[0])
        torch.nn.utils.clip_grad_norm_(self.mle_base[0].parameters(), 20)
        self.mle_opts[0].step()
        self.mle_opts[1].zero_grad()
        action_i = self.mle_base[1](torch.cat((obs[4][:, 4:24], acs_pre[4]), 1))
        action_pre = gumbel_softmax(action_i, hard=True)
        loss = KL_criterion(action_pre.float(), acs[4].float())
        loss.backward()
        if parallel:
            average_gradients(self.mle_base[1])
        torch.nn.utils.clip_grad_norm_(self.mle_base[1].parameters(), 20)
        self.mle_opts[1].step()
def update(self, sample, agent_i, parallel=False, logger=None, sample_r=None):
    """
    Update parameters of agent model based on sample from replay buffer
    Inputs:
        sample: tuple of (previous actions, observations, actions,
                rewards, next observations, and episode end masks) sampled
                randomly from the replay buffer. Each is a list with
                entries corresponding to each agent
        agent_i (int): index of agent to update
        parallel (bool): If true, will average gradients across threads
        logger (SummaryWriter from Tensorboard-Pytorch):
            If passed in, important quantities will be logged
        sample_r: unused; kept for call-site compatibility
    """
    acs_pre, obs, acs, rews, next_obs, dones = sample
    # Observations augmented with detached ToM action predictions.
    next_obs_ = self._get_obs(next_obs, acs)
    obs_ = self._get_obs(obs, acs_pre)
    curr_agent = self.agents[agent_i]
    # --- ToM (mle) networks: trainer chosen by scenario size
    # (6: simple_comm, 4: simple_tag, 3: simple_adv, 2: simple_push).
    KL_criterion = torch.nn.KLDivLoss(reduction='sum')
    if self.nagents == 6:
        self.trian_com(agent_i, KL_criterion, obs, acs_pre, parallel, acs)
    elif self.nagents == 4:
        self.trian_tag(agent_i, KL_criterion, obs, acs_pre, parallel, acs)
    elif self.nagents == 3:
        self.trian_adv(agent_i, KL_criterion, obs, acs_pre, parallel, acs)
    elif self.nagents == 2:
        self.trian_push(agent_i, KL_criterion, obs, acs_pre, parallel, acs)
    # --- centralized critic update
    curr_agent.critic_optimizer.zero_grad()
    all_trgt_acs = []
    if self.discrete_action:  # one-hot encode action
        all_trgt_acs = [onehot_from_logits(pi(nobs)) for pi, nobs in
                        zip(self.target_policies, next_obs_)]
    trgt_vf_in = torch.cat((*next_obs, *all_trgt_acs), dim=1)
    target_value = (rews[agent_i].view(-1, 1) + self.gamma *
                    curr_agent.target_critic(trgt_vf_in) *
                    (1 - dones[agent_i].view(-1, 1)))
    vf_in = torch.cat((*obs, *acs), dim=1)
    actual_value = curr_agent.critic(vf_in)
    vf_loss = MSELoss(actual_value, target_value.detach())
    vf_loss.backward()
    if parallel:
        average_gradients(curr_agent.critic)
    torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(), 0.5)
    curr_agent.critic_optimizer.step()
    # --- actor (policy) update
    curr_agent.policy_optimizer.zero_grad()
    if self.discrete_action:
        # Forward pass as if onehot (hard=True) but backprop through a
        # differentiable Gumbel-Softmax sample (standard MADDPG trick for
        # discrete actions).
        curr_pol_out = curr_agent.policy(obs_[agent_i])
        curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)
    else:
        curr_pol_out = curr_agent.policy(obs[agent_i])
        curr_pol_vf_in = curr_pol_out
    all_pol_acs = []
    for i, pi, ob in zip(range(self.nagents), self.policies, obs_):
        if i == agent_i:
            all_pol_acs.append(curr_pol_vf_in)
        elif self.discrete_action:
            all_pol_acs.append(onehot_from_logits(pi(ob)))
        else:
            all_pol_acs.append(pi(ob))
    vf_in = torch.cat((*obs, *all_pol_acs), dim=1)
    pol_loss = -curr_agent.critic(vf_in).mean()
    # Small action-magnitude regularizer on the raw policy output.
    pol_loss += (curr_pol_out ** 2).mean() * 1e-3
    pol_loss.backward()
    if parallel:
        average_gradients(curr_agent.policy)
    torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(), 0.5)
    curr_agent.policy_optimizer.step()
    if logger is not None:
        logger.add_scalars('agent%i/losses' % agent_i,
                           {'vf_loss': vf_loss,
                            'pol_loss': pol_loss},
                           self.niter)
def update_all_targets(self):
    """Soft-update each agent's target critic/policy toward the live nets.

    Intended to run once after every agent's normal update; increments
    the logging iteration counter as a side effect.
    """
    for agent in self.agents:
        soft_update(agent.target_critic, agent.critic, self.tau)
        soft_update(agent.target_policy, agent.policy, self.tau)
    self.niter += 1
def prep_training(self, device='gpu'):
    """Switch every network to train mode and move it to *device*.

    Args:
        device: 'gpu' moves modules to ``self.device``; any other value
            moves them to the CPU.  The ``*_dev`` attributes cache the
            current placement so already-placed families are skipped.
    """
    # Train mode for the shared ToM networks and all agent networks.
    for mle in self.mle_base:
        mle.train()
    for a in self.agents:
        a.policy.train()
        a.critic.train()
        a.target_policy.train()
        a.target_critic.train()
        for mle_i in a.mle:
            mle_i.train()
    if device == 'gpu':
        fn = lambda x: x.to(torch.device(self.device))
    else:
        fn = lambda x: x.cpu()
    if not self.pol_dev == device:
        for a in self.agents:
            a.policy = fn(a.policy)
        self.pol_dev = device
    if not self.critic_dev == device:
        for a in self.agents:
            a.critic = fn(a.critic)
        self.critic_dev = device
    if not self.trgt_pol_dev == device:
        for a in self.agents:
            a.target_policy = fn(a.target_policy)
        self.trgt_pol_dev = device
    if not self.trgt_critic_dev == device:
        for a in self.agents:
            a.target_critic = fn(a.target_critic)
        self.trgt_critic_dev = device
    if not self.mle_dev == device:
        # The mle networks are shared: move the base copies, then refresh
        # every per-agent reference to them.
        for i, mle in enumerate(self.mle_base):
            self.mle_base[i] = fn(mle)
        for a in self.agents:
            for i, mle_i in enumerate(a.mle):
                a.mle[i] = fn(mle_i)
        self.mle_dev = device
def prep_rollouts(self, device='cpu'):
for a in self.agents:
a.policy.eval()
if device == 'gpu':
fn = lambda x: x.to(torch.device(self.device))
else:
fn = lambda x: x.cpu()
# only need main policy for rollouts
if not self.pol_dev == device:
for a in self.agents:
a.policy = fn(a.policy)
self.pol_dev = device
def save(self, filename):
    """
    Save trained parameters of all agents into one file
    """
    # Move everything to the CPU first so the checkpoint loads anywhere.
    self.prep_training(device='cpu')
    save_dict = {
        'init_dict': self.init_dict,
        'agent_params': [agent.get_params() for agent in self.agents],
        'mle_params': [self.get_params()],
    }
    torch.save(save_dict, filename)
@classmethod
def init_from_env(cls, env, device, agent_alg="ToM_SA", adversary_alg="ToM_SA",
                  gamma=0.95, tau=0.01, lr=0.01, hidden_dim=64, output_style='sum'):
    """
    Instantiate instance of this class from multi-agent environment

    Args:
        env: multi-agent environment exposing ``agent_types``,
            ``action_space`` and ``observation_space`` lists (gym-style
            Box / Discrete / MultiDiscrete spaces).
        device: torch device string stored for later network placement.
        agent_alg / adversary_alg: algorithm name per agent type.
        gamma, tau, lr, hidden_dim, output_style: hyperparameters stored
            in the returned instance's ``init_dict``.
    """
    agent_init_params = []
    alg_types = [adversary_alg if atype == 'adversary' else agent_alg for
                 atype in env.agent_types]
    for acsp, obsp, algtype in zip(env.action_space, env.observation_space,
                                   alg_types):
        num_in_pol = obsp.shape[0]
        num_in_mle = obsp.shape[0]
        # NOTE(review): if acsp is none of Box/Discrete/MultiDiscrete,
        # get_shape (and discrete_action) is never bound and the
        # num_out_pol line below raises NameError.
        if isinstance(acsp, Box):
            discrete_action = False
            get_shape = lambda x: x.shape[0]
        elif isinstance(acsp, Discrete):  # Discrete
            discrete_action = True
            get_shape = lambda x: x.n
        elif isinstance(acsp, MultiDiscrete):
            discrete_action = True
            get_shape = lambda x: sum(x.high - x.low + 1)
        num_out_pol = get_shape(acsp)
        if algtype == "ToM_SA":
            # Centralized critic sees every agent's obs and actions;
            # the policy input is widened by the predicted actions of the
            # other agents (5 values each — presumably the discrete action
            # dimension; confirm).
            num_in_critic = 0
            num_in_pol += (len(env.agent_types)-1) * 5
            for oobsp in env.observation_space:
                num_in_critic += oobsp.shape[0]
            for oacsp in env.action_space:
                if isinstance(oacsp, Box):
                    discrete_action = False
                    get_shape = lambda x: x.shape[0]
                elif isinstance(oacsp, Discrete):  # Discrete
                    discrete_action = True
                    get_shape = lambda x: x.n
                elif isinstance(oacsp, MultiDiscrete):
                    discrete_action = True
                    get_shape = lambda x: sum(x.high - x.low + 1)
                num_in_critic += get_shape(oacsp)
        else:
            num_in_critic = obsp.shape[0] + get_shape(acsp)
        agent_init_params.append({'num_in_pol': num_in_pol,
                                  'num_out_pol': num_out_pol,
                                  'num_in_critic': num_in_critic,
                                  'num_in_mle': num_in_mle,})
    # NOTE(review): discrete_action keeps whatever the LAST space set it
    # to; with mixed action spaces this flag would be unrepresentative.
    init_dict = {'gamma': gamma, 'tau': tau, 'lr': lr,
                 'device': device,
                 'hidden_dim': hidden_dim,
                 'alg_types': alg_types,
                 'agent_init_params': agent_init_params,
                 'discrete_action': discrete_action,
                 'output_style': output_style}
    instance = cls(**init_dict)
    instance.init_dict = init_dict
    return instance
@classmethod
def init_from_save(cls, filename):
    """
    Rebuild an instance from a checkpoint file produced by the `save` method.
    """
    checkpoint = torch.load(filename)
    init_dict = checkpoint['init_dict']
    instance = cls(**init_dict)
    instance.init_dict = init_dict
    for agent, params in zip(instance.agents, checkpoint['agent_params']):
        agent.load_params(params)
    # `mle_params` is stored as a one-element list; load it onto the
    # instance itself (shared mle networks live on the wrapper, not agents).
    mle_params = checkpoint['mle_params']
    if mle_params:
        instance.load_params(mle_params[0])
    return instance
def get_params(self):
    """
    Return a dict with the state dicts of every shared mle network
    ('mle<i>') and its optimizer ('mle_optimizer<i>').
    """
    params = {}
    for idx, net in enumerate(self.mle_base):
        params['mle%d' % idx] = net.state_dict()
        params['mle_optimizer%d' % idx] = self.mle_opts[idx].state_dict()
    return params
def load_params(self, params):
    """
    Restore the shared mle networks and their optimizers from a dict in the
    format produced by `get_params`.
    """
    for idx, net in enumerate(self.mle_base):
        net.load_state_dict(params['mle%d' % idx])
        self.mle_opts[idx].load_state_dict(params['mle_optimizer%d' % idx])
class ToM_S(object):
    """
    Wrapper class for DDPG-esque (i.e. also MADDPG) agents in multi-agent task

    Variant that uses shared SNN "mle" (mind/action-estimation) networks to
    infer other agents' actions and appends the inferred actions to each
    agent's observation before acting (see `step` and `_get_obs`).  The set of
    shared mle networks is selected by the number of agents, which encodes the
    scenario (6: simple_comm, 4: simple_tag, 3: simple_adversary, 2:
    simple_push).
    """
    def __init__(self, agent_init_params, alg_types, output_style, device,
                 gamma=0.95, tau=0.01, lr=0.01, hidden_dim=64,
                 discrete_action=False):
        """
        Inputs:
            agent_init_params (list of dict): List of dicts with parameters to
                                              initialize each agent
                num_in_pol (int): Input dimensions to policy
                num_out_pol (int): Output dimensions to policy
                num_in_critic (int): Input dimensions to critic
            alg_types (list of str): Learning algorithm for each agent (DDPG
                                       or MADDPG)
            gamma (float): Discount factor
            tau (float): Target update rate
            lr (float): Learning rate for policy and critic
            hidden_dim (int): Number of hidden dimensions for networks
            discrete_action (bool): Whether or not to use discrete action space
        """
        self.device = device
        self.nagents = len(alg_types)
        self.alg_types = alg_types
        self.agents = [DDPGAgent_ToM(lr=lr, discrete_action=discrete_action,
                                     hidden_dim=hidden_dim,
                                     **params, output_style=output_style,
                                     num_agents=self.nagents,
                                     device=self.device)
                       for params in agent_init_params]
        self.agent_init_params = agent_init_params
        # Shared mle estimator networks, one set per scenario.  The input-size
        # offsets (e.g. `- 14 + 5`, `- 2 + 5`) must stay consistent with the
        # observation slicing done in `step`/`_get_obs` (observation subset
        # plus 5 previous-action units).
        if self.nagents == 6:
            self.mle_base = [SNNNetwork(self.agent_init_params[3]['num_in_mle'] - 14 + 5,  # simple_com
                                        self.agent_init_params[3]['num_out_pol'],  # adv self-self
                                        hidden_dim=hidden_dim, output_style=output_style),
                             SNNNetwork(self.agent_init_params[3]['num_in_mle'] - 14 + 5,
                                        self.agent_init_params[3]['num_out_pol'],  # adv self-other
                                        hidden_dim=hidden_dim, output_style=output_style),
                             SNNNetwork(self.agent_init_params[3]['num_in_mle'] - 14 + 5,
                                        self.agent_init_params[3]['num_out_pol'],  # adv self-other
                                        hidden_dim=hidden_dim, output_style=output_style),
                             SNNNetwork(self.agent_init_params[3]['num_in_mle'] - 14 + 5,
                                        self.agent_init_params[3]['num_out_pol'],
                                        hidden_dim=hidden_dim, output_style=output_style),  # agent self-other
                             ]
        if self.nagents == 4:
            self.mle_base = [SNNNetwork(self.agent_init_params[0]['num_in_mle'] - 2 + 5,  # simple_tag
                                        self.agent_init_params[0]['num_out_pol'],  # adv self-self
                                        hidden_dim=hidden_dim, output_style=output_style),
                             SNNNetwork(self.agent_init_params[3]['num_in_mle'] - 2 + 5,
                                        self.agent_init_params[3]['num_out_pol'],  # adv self-other
                                        hidden_dim=hidden_dim, output_style=output_style),
                             SNNNetwork(self.agent_init_params[3]['num_in_mle'] - 2 + 5,
                                        self.agent_init_params[3]['num_out_pol'],
                                        hidden_dim=hidden_dim, output_style=output_style),  # agent self-other
                             ]
        elif self.nagents == 3:
            self.mle_base = [SNNNetwork(self.agent_init_params[1]['num_in_mle'] + 5,  # simple_adv
                                        self.agent_init_params[1]['num_out_pol'],  # adv self-other
                                        hidden_dim=hidden_dim, output_style=output_style),
                             SNNNetwork(self.agent_init_params[1]['num_in_mle'] + 5,
                                        self.agent_init_params[1]['num_out_pol'],  # agent self-self
                                        hidden_dim=hidden_dim, output_style=output_style),
                             SNNNetwork(self.agent_init_params[1]['num_in_mle'] + 5,
                                        self.agent_init_params[1]['num_out_pol'],
                                        hidden_dim=hidden_dim, output_style=output_style),  # agent self-other
                             ]
        elif self.nagents == 2:
            self.mle_base = [SNNNetwork(self.agent_init_params[0]['num_in_mle'] - 2 + 5,  # simple_push
                                        self.agent_init_params[0]['num_out_pol'],  # adv self-other
                                        hidden_dim=hidden_dim, output_style=output_style),
                             SNNNetwork(self.agent_init_params[1]['num_in_mle'] - 2 + 5,
                                        self.agent_init_params[1]['num_out_pol'],
                                        hidden_dim=hidden_dim, output_style=output_style),  # agent self-other
                             ]
        # One Adam optimizer per shared mle network.
        self.mle_opts = [Adam(self.mle_base[i].parameters(), lr=lr) for i in range(len(self.mle_base))]
        self.gamma = gamma
        self.tau = tau
        self.lr = lr
        self.discrete_action = discrete_action
        self.pol_dev = 'cpu'  # device for policies
        self.critic_dev = 'cpu'  # device for critics
        self.trgt_pol_dev = 'cpu'  # device for target policies
        self.trgt_critic_dev = 'cpu'  # device for target critics
        self.mle_dev = 'cpu'  # device for shared mle networks
        self.niter = 0

    @property
    def policies(self):
        # Main policy network of every agent, in agent order.
        return [a.policy for a in self.agents]

    @property
    def target_policies(self):
        # Target policy network of every agent, in agent order.
        return [a.target_policy for a in self.agents]

    def scale_noise(self, scale):
        """
        Scale noise for each agent
        Inputs:
            scale (float): scale of noise
        """
        for a in self.agents:
            a.scale_noise(scale)

    def reset_noise(self):
        # Reset every agent's exploration-noise process.
        for a in self.agents:
            a.reset_noise()

    def step(self, observations, actions_pre, explore=False):  # simple_tag
        """
        Take a step forward in environment with all agents
        Inputs:
            observations: List of observations for each agent
            explore (boolean): Whether or not to add exploration noise
        Outputs:
            actions: List of actions for each agent

        NOTE: each agent's observation is augmented in place with the actions
        inferred for the other agents by the shared mle networks, so the
        `observations` list passed in is mutated.
        NOTE(review): the 6-agent branch slices the batched mle output with
        hard-coded offsets (b1[:20], ...) — this assumes a fixed rollout batch
        size of 20; confirm against the caller.
        """
        # t1 = time.time()
        observations_ = observations.copy()
        actions_pre_ = actions_pre.copy()
        for agent_i, obs in enumerate(observations):
            # Views of the other agents' observations / previous actions.
            obs_ = observations_.copy()
            acs_pre_ = actions_pre_.copy()
            obs_.pop(agent_i)
            acs_pre_.pop(agent_i)
            # actions = [self.agents[agent_i].mle[j].cpu()(observations[agent_i]) for j, obs_j in enumerate(obs_)]
            # observations[agent_i] = torch.cat((observations[agent_i], torch.cat(actions, 1)), 1)
            if self.nagents == 6:
                if agent_i < 4:
                    self.agents[agent_i].mle = [self.mle_base[0], self.mle_base[0], self.mle_base[0], self.mle_base[1], self.mle_base[1]]
                    # mle input: obs columns 4:24 plus first 5 units of the
                    # other agent's previous action.
                    actions = [torch.cat((obs_j[:, 4:24], acs_pre_[j][:, :5]), 1) for j, obs_j in enumerate(obs_)]
                    b1 = gumbel_softmax(self.agents[agent_i].mle[0].to(self.device)(torch.cat(actions[:3]).to(self.device)), hard=True).cpu()
                    b2 = gumbel_softmax(self.agents[agent_i].mle[1].to(self.device)(torch.cat(actions[3:]).to(self.device)), hard=True).cpu()
                    actions = torch.cat((b1[:20], b1[20:40], b1[40:60], b2[:20], b2[20:40]), 1)
                    # print(t1 - time.time())
                    # print()
                else:
                    self.agents[agent_i].mle = [self.mle_base[0], self.mle_base[0], self.mle_base[0], self.mle_base[0], self.mle_base[1]]
                    actions = [torch.cat((obs_j[:, 4:24], acs_pre_[j][:, :5]), 1) for j, obs_j in enumerate(obs_)]
                    b1 = gumbel_softmax(self.agents[agent_i].mle[0].to(self.device)(torch.cat(actions[:3]).to(self.device)), hard=True).cpu()
                    b2 = gumbel_softmax(self.agents[agent_i].mle[1].to(self.device)(torch.cat(actions[3:]).to(self.device)), hard=True).cpu()
                    actions = torch.cat((b1[:20], b1[20:40], b1[40:60], b2[:20], b2[20:40]), 1)
                    # actions = [torch.zeros((obs_j.shape[0],self.agent_init_params[1]['num_out_pol']))
                    #            for j, obs_j in enumerate(obs_)]
                    # actions = torch.cat(actions,1)
                    # print()
            if self.nagents == 4:
                if agent_i < 3:
                    self.agents[agent_i].mle = [self.mle_base[0], self.mle_base[0], self.mle_base[1]]
                    actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(self.device)(torch.cat((obs_j[:, 2:], acs_pre_[j]), 1).to(self.device)),
                                              hard=True).cpu()
                               for j, obs_j in enumerate(obs_)]
                elif agent_i == 3:
                    self.agents[agent_i].mle = [self.mle_base[2], self.mle_base[2], self.mle_base[2]]
                    # The prey's view drops the last 2 obs columns.
                    actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(self.device)(torch.cat((obs_j[:, 2:-2], acs_pre_[j]), 1).to(self.device)),
                                              hard=True).cpu()
                               for j, obs_j in enumerate(obs_)]
                    # actions = [gumbel_softmax(self.agents[agent_i].mle[j].cpu()(obs_j[:,2:-2]), hard=True)
                    #            for j, obs_j in enumerate(obs_)]
                    # actions = [torch.zeros((obs_j.shape[0],self.agent_init_params[0]['num_out_pol']))
                    #            for j, obs_j in enumerate(obs_)]
            elif self.nagents == 3:  # simple_adv
                actions = []
                if agent_i < 1:
                    self.agents[agent_i].mle = [self.mle_base[0], self.mle_base[0]]
                    # Adversary gets zero placeholders instead of inferred actions.
                    actions = [torch.zeros((obs_j.shape[0], self.agent_init_params[0]['num_out_pol']))
                               for j, obs_j in enumerate(obs_)]
                elif agent_i == 1:
                    self.agents[agent_i].mle = [self.mle_base[2], self.mle_base[1]]
                    actions.append(gumbel_softmax(self.agents[agent_i].mle[0].to(self.device)(torch.cat((observations_[agent_i],
                                                                                                         actions_pre[(0)]), 1).to(self.device)),
                                                  hard=True).cpu())
                    actions.append(gumbel_softmax(self.agents[agent_i].mle[1].to(self.device)(torch.cat((observations_[agent_i],
                                                                                                         actions_pre[(2)]), 1).to(self.device)),
                                                  hard=True).cpu())
                elif agent_i == 2:
                    self.agents[agent_i].mle = [self.mle_base[2], self.mle_base[1]]
                    # actions = [gumbel_softmax(self.agents[agent_i].mle[j].cpu()(observations_[agent_i]), hard=True)
                    #            for j, obs_j in enumerate(obs_)]
                    actions.append(gumbel_softmax(self.agents[agent_i].mle[0].to(self.device)(torch.cat((observations_[agent_i],
                                                                                                         actions_pre[(0)]), 1).to(self.device)),
                                                  hard=True).cpu())
                    actions.append(gumbel_softmax(self.agents[agent_i].mle[1].to(self.device)(torch.cat((observations_[agent_i],
                                                                                                         actions_pre[(1)]), 1).to(self.device)),
                                                  hard=True).cpu())
            elif self.nagents == 2:
                if agent_i < 1:
                    self.agents[agent_i].mle = [self.mle_base[0]]
                    # actions = [gumbel_softmax(self.agents[agent_i].mle[j].cpu()(observations_[agent_i][:,2:]), hard=True)
                    #            for j, obs_j in enumerate(obs_)]
                    actions = [torch.zeros((obs_j.shape[0], self.agent_init_params[0]['num_out_pol'])) for j, obs_j in enumerate(obs_)]
                elif agent_i == 1:
                    self.agents[agent_i].mle = [self.mle_base[1]]
                    # actions = [gumbel_softmax(self.agents[agent_i].mle[j].cpu()(observations_[agent_i][:, 2:]), hard=True)
                    #            for j, obs_j in enumerate(obs_)]
                    actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(self.device)(torch.cat((observations_[agent_i][:, 2:],
                                                                                                     actions_pre[(self.nagents - 1 - agent_i)]), 1).to(self.device)),
                                              hard=True).cpu() for j, obs_j in enumerate(obs_)]
            # Augment the agent's own observation with the inferred actions.
            if self.nagents == 6:
                observations[agent_i] = torch.cat((observations[agent_i], actions), 1)
            else:
                observations[agent_i] = torch.cat((observations[agent_i], torch.cat(actions, 1)), 1)
        # t2 = time.time()
        # print('step+time:', t2 - t1)
        return [a.step(obs, explore=explore) for a, obs in zip(self.agents,
                                                               observations)]

    def _get_obs(self, observations, actions_pre):
        """
        Return a new list of observations augmented with detached mle-inferred
        actions of the other agents (training-time counterpart of the in-place
        augmentation done in `step`; here the mle outputs are detached so no
        gradient flows into the estimators).

        NOTE(review): the 6-agent branch slices with hard-coded offsets
        (b1[:1024], ...) — this assumes a fixed replay batch size of 1024;
        confirm against the sampler.
        NOTE(review): for nagents == 6 the branches cover agent_i < 4 and
        agent_i > 4, so agent_i == 4 falls through and reuses the previous
        iteration's `actions` — looks like a bug; verify intent.
        """
        observations_ = []
        actions_pre_ = []
        for agent_i, obs in enumerate(observations):
            obs_ = observations.copy()
            obs_.pop(agent_i)
            actions_pre_ = actions_pre.copy()
            actions_pre_.pop(agent_i)
            if self.nagents == 6:
                if agent_i < 4:  # simple_comm
                    self.agents[agent_i].mle = [self.mle_base[0], self.mle_base[0], self.mle_base[0], self.mle_base[1], self.mle_base[1]]
                    # actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(torch.device(self.device))(torch.cat((obs_j[:, 4:24], actions_pre_[j][:,:5]),1)).detach(), hard=True)
                    #            for j, obs_j in enumerate(obs_)]
                    actions = [torch.cat((obs_j[:, 4:24], actions_pre_[j][:, :5]), 1) for j, obs_j in enumerate(obs_)]
                    b1 = gumbel_softmax(self.agents[agent_i].mle[0].to(self.device)(torch.cat(actions[:3]).to(self.device)).detach(), hard=True)
                    b2 = gumbel_softmax(self.agents[agent_i].mle[1].to(self.device)(torch.cat(actions[3:]).to(self.device)).detach(), hard=True)
                    actions = torch.cat((b1[:1024], b1[1024:2048], b1[2048:3072], b2[:1024], b2[1024:2048]), 1)
                    # print()
                elif agent_i > 4:
                    self.agents[agent_i].mle = [self.mle_base[0], self.mle_base[0], self.mle_base[0], self.mle_base[0], self.mle_base[1]]
                    actions = [torch.cat((obs_j[:, 4:24], actions_pre_[j][:, :5]), 1) for j, obs_j in enumerate(obs_)]
                    b1 = gumbel_softmax(self.agents[agent_i].mle[0].to(self.device)(torch.cat(actions[:3]).to(self.device)).detach(), hard=True)
                    b2 = gumbel_softmax(self.agents[agent_i].mle[1].to(self.device)(torch.cat(actions[3:]).to(self.device)).detach(), hard=True)
                    actions = torch.cat((b1[:1024], b1[1024:2048], b1[2048:3072], b2[:1024], b2[1024:2048]), 1)
                    # print()
            if self.nagents == 4:
                if agent_i < 3:  # simple_tag
                    self.agents[agent_i].mle = [self.mle_base[0], self.mle_base[0], self.mle_base[1]]
                    actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(torch.device(self.device))(torch.cat((obs_j[:, 2:], actions_pre_[j]), 1)).detach(), hard=True)
                               for j, obs_j in enumerate(obs_)]
                elif agent_i == 3:
                    self.agents[agent_i].mle = [self.mle_base[2], self.mle_base[2], self.mle_base[2]]
                    actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(torch.device(self.device))(torch.cat((obs_j[:, 2:-2], actions_pre_[j]), 1)).detach(), hard=True)
                               for j, obs_j in enumerate(obs_)]
            elif self.nagents == 3:
                actions = []
                if agent_i < 1:  # simple_adv
                    # self.agents[agent_i].mle = [self.mle_base[0],self.mle_base[0]]
                    # actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(torch.device(self.device))(torch.cat((obs_j[:,:2],observations[agent_i]),1)).detach(), hard=True)
                    #            for j, obs_j in enumerate(obs_)]
                    actions = [torch.zeros((obs_j.shape[0], self.agent_init_params[0]['num_out_pol'])).to(torch.device(self.device)).detach()
                               for j, obs_j in enumerate(obs_)]
                elif agent_i == 1:
                    self.agents[agent_i].mle = [self.mle_base[2], self.mle_base[1]]
                    # actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(torch.device(self.device))(observations[agent_i]).detach(), hard=True)
                    #            for j, obs_j in enumerate(obs_)]
                    actions.append(
                        gumbel_softmax(self.agents[agent_i].mle[0].to(self.device)(torch.cat((observations[agent_i],
                                                                                              actions_pre[(0)]), 1).to(self.device)).detach(), hard=True))
                    actions.append(gumbel_softmax(self.agents[agent_i].mle[1].to(self.device)(torch.cat((observations[agent_i],
                                                                                                         actions_pre[(2)]), 1).to(self.device)).detach(), hard=True))
                elif agent_i == 2:
                    self.agents[agent_i].mle = [self.mle_base[2], self.mle_base[1]]
                    actions.append(
                        gumbel_softmax(self.agents[agent_i].mle[0].to(self.device)(torch.cat((observations[agent_i],
                                                                                              actions_pre[(0)]), 1).to(self.device)).detach(), hard=True))
                    actions.append(gumbel_softmax(self.agents[agent_i].mle[1].to(self.device)(torch.cat((observations[agent_i],
                                                                                                         actions_pre[(1)]), 1).to(self.device)).detach(), hard=True))
            elif self.nagents == 2:
                if agent_i < 1:  # simple_push
                    self.agents[agent_i].mle = [self.mle_base[0]]
                    actions = [torch.zeros((obs_j.shape[0], self.agent_init_params[0]['num_out_pol'])).to(torch.device(self.device)).detach() for j, obs_j in
                               enumerate(obs_)]
                elif agent_i == 1:
                    self.agents[agent_i].mle = [self.mle_base[1]]
                    actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(torch.device(self.device))(torch.cat((observations[agent_i][:, 2:],
                                                                                                                   actions_pre[(self.nagents - 1 - agent_i)]), 1)).detach(), hard=True)
                               for j, obs_j in enumerate(obs_)]
            if self.nagents == 6:
                observations_.append(torch.cat((observations[agent_i], actions), 1))
            else:
                observations_.append(torch.cat((observations[agent_i], torch.cat(actions, 1)), 1))
        return observations_

    def trian_tag(self, agent_i, KL_criterion, obs, acs_pre, parallel, acs):
        """
        Train the shared mle networks on a simple_tag batch (KL between the
        inferred one-hot action and the actual action), run once per update
        cycle when agent_i == 0.  ("trian" spelling kept: it is the public
        method name used by `update`.)
        """
        if agent_i == 0:
            self.mle_opts[0].zero_grad()
            action_i = self.mle_base[0](torch.cat((obs[0][:, 2:], acs_pre[0]), 1))
            action_pre = gumbel_softmax(action_i, hard=True)
            loss = KL_criterion(action_pre.float(), acs[0].float())
            loss.backward(retain_graph=True)
            if parallel:
                average_gradients(self.mle_base[0])
            torch.nn.utils.clip_grad_norm_(self.mle_base[0].parameters(), 20)
            self.mle_opts[0].step()
            self.mle_opts[1].zero_grad()
            action_i = self.mle_base[1](torch.cat((obs[3][:, 2:], acs_pre[3]), 1))
            action_pre = gumbel_softmax(action_i, hard=True)
            loss = KL_criterion(action_pre.float(), acs[3].float())
            loss.backward()
            if parallel:
                average_gradients(self.mle_base[1])
            torch.nn.utils.clip_grad_norm_(self.mle_base[1].parameters(), 20)
            self.mle_opts[1].step()
            self.mle_opts[2].zero_grad()
            action_i = self.mle_base[2](torch.cat((obs[0][:, 2:-2], acs_pre[0]), 1))
            action_pre = gumbel_softmax(action_i, hard=True)
            loss = KL_criterion(action_pre.float(), acs[0].float())
            loss.backward()
            if parallel:
                average_gradients(self.mle_base[2])
            torch.nn.utils.clip_grad_norm_(self.mle_base[2].parameters(), 20)
            self.mle_opts[2].step()

    def trian_adv(self, agent_i, KL_criterion, obs, acs_pre, parallel, acs):
        """
        Train the shared mle networks on a simple_adversary batch; only mle 1
        and 2 are trained (mle 0 training is commented out), once per update
        cycle when agent_i == 0.
        """
        if agent_i == 0:
            # self.mle_opts[0].zero_grad()
            # action_i = self.mle_base[0](torch.cat((obs[1][:,:2],obs[agent_i]), 1))
            # action_pre = gumbel_softmax(action_i, hard=True)
            # loss = KL_criterion(action_pre.float(), acs[1].float())
            # loss.backward(retain_graph=True)
            # if parallel:
            #     average_gradients(self.mle_base[0])
            # torch.nn.utils.clip_grad_norm_(self.mle_base[0].parameters(), 20)
            # self.mle_opts[0].step()
            self.mle_opts[1].zero_grad()
            action_i = self.mle_base[1](torch.cat((obs[1], acs_pre[2]), 1))  # torch.cat((obs[1], acs_pre[2]), 1))
            action_pre = gumbel_softmax(action_i, hard=True)
            loss = KL_criterion(action_pre.float(), acs[1].float())
            loss.backward()
            if parallel:
                average_gradients(self.mle_base[1])
            torch.nn.utils.clip_grad_norm_(self.mle_base[1].parameters(), 20)
            self.mle_opts[1].step()
            self.mle_opts[2].zero_grad()
            action_i = self.mle_base[2](torch.cat((obs[1], acs_pre[0]), 1))
            action_pre = gumbel_softmax(action_i, hard=True)
            loss = KL_criterion(action_pre.float(), acs[0].float())
            loss.backward()
            if parallel:
                average_gradients(self.mle_base[2])
            torch.nn.utils.clip_grad_norm_(self.mle_base[2].parameters(), 20)
            self.mle_opts[2].step()

    def trian_push(self, agent_i, KL_criterion, obs, acs_pre, parallel, acs):
        """
        Train the shared mle network on a simple_push batch; only mle 1 is
        trained (mle 0 training is commented out), once per update cycle when
        agent_i == 0.
        """
        if agent_i == 0:
            # self.mle_opts[0].zero_grad()
            # action_i = self.mle_base[0](obs[0][:, 2:]) #torch.cat((obs[agent_i][:,2:], actions[(self.nagents -1 - agent_i)]), 1)
            # action_pre = gumbel_softmax(action_i, hard=True)
            # loss = KL_criterion(action_pre.float(), acs[1].float())
            # loss.backward(retain_graph=True)
            # if parallel:
            #     average_gradients(self.mle_base[0])
            # torch.nn.utils.clip_grad_norm_(self.mle_base[0].parameters(), 20)
            # self.mle_opts[0].step()
            self.mle_opts[1].zero_grad()
            action_i = self.mle_base[1](torch.cat((obs[1][:, 2:], acs_pre[(0)]), 1))  # obs[1][:, 2:]
            action_pre = gumbel_softmax(action_i, hard=True)
            loss = KL_criterion(action_pre.float(), acs[0].float())
            loss.backward()
            if parallel:
                average_gradients(self.mle_base[1])
            torch.nn.utils.clip_grad_norm_(self.mle_base[1].parameters(), 20)
            self.mle_opts[1].step()

    def trian_com(self, agent_i, KL_criterion, obs, acs_pre, parallel, acs):
        """
        Train the shared mle networks on a simple_comm batch (mle 0 on an
        adversary sample, mle 1 on a good-agent sample), once per update cycle
        when agent_i == 0.
        """
        if agent_i == 0:
            self.mle_opts[0].zero_grad()
            action_i = self.mle_base[0](torch.cat((obs[1][:, 4:24], acs_pre[(1)]), 1))
            action_pre = gumbel_softmax(action_i, hard=True)
            loss = KL_criterion(action_pre.float(), acs[1].float())
            loss.backward(retain_graph=True)
            if parallel:
                average_gradients(self.mle_base[0])
            torch.nn.utils.clip_grad_norm_(self.mle_base[0].parameters(), 20)
            self.mle_opts[0].step()
            self.mle_opts[1].zero_grad()
            action_i = self.mle_base[1](torch.cat((obs[4][:, 4:24], acs_pre[(4)]), 1))
            action_pre = gumbel_softmax(action_i, hard=True)
            loss = KL_criterion(action_pre.float(), acs[4].float())
            loss.backward()
            if parallel:
                average_gradients(self.mle_base[1])
            torch.nn.utils.clip_grad_norm_(self.mle_base[1].parameters(), 20)
            self.mle_opts[1].step()

    def update(self, sample, agent_i, parallel=False, logger=None, sample_r=None):
        """
        Update parameters of agent model based on sample from replay buffer
        Inputs:
            sample: tuple of (observations, actions, rewards, next
                    observations, and episode end masks) sampled randomly from
                    the replay buffer. Each is a list with entries
                    corresponding to each agent
            agent_i (int): index of agent to update
            parallel (bool): If true, will average gradients across threads
            logger (SummaryWriter from Tensorboard-Pytorch):
                If passed in, important quantities will be logged
        """
        # print('___update___')
        # Note: sample unpacks as (prev actions, obs, actions, rewards,
        # next obs, done masks).
        acs_pre, obs, acs, rews, next_obs, dones = sample
        # Augmented (mle-inferred-action) versions used as policy inputs.
        next_obs_ = self._get_obs(next_obs, acs)
        obs_ = self._get_obs(obs, acs_pre)
        curr_agent = self.agents[agent_i]
        # mle: train the shared estimators for the current scenario.
        KL_criterion = torch.nn.KLDivLoss(reduction='sum')
        # for i in range(len(curr_agent.mle)):
        #     curr_agent.mle_optimizer[i].zero_grad()
        #     action_i = curr_agent.mle[i](obs[agent_i]obs[agent_i])
        #     action_pre = gumbel_softmax(action_i, hard=True)
        #     loss = KL_criterion(action_pre.float(), acs[i].float())
        #     loss.backward()
        #     if parallel:
        #         average_gradients(curr_agent.mle[i])
        #     torch.nn.utils.clip_grad_norm_(curr_agent.mle[i].parameters(), 20)
        #     curr_agent.policy_optimizer.step()
        if self.nagents == 6:
            self.trian_com(agent_i, KL_criterion, obs, acs_pre, parallel, acs)
        elif self.nagents == 4:
            self.trian_tag(agent_i, KL_criterion, obs, acs_pre, parallel, acs)
        elif self.nagents == 3:
            self.trian_adv(agent_i, KL_criterion, obs, acs_pre, parallel, acs)
        elif self.nagents == 2:
            self.trian_push(agent_i, KL_criterion, obs, acs_pre, parallel, acs)
        # center critic: centralized critic sees raw observations of all
        # agents plus all actions; target actions come from the target
        # policies evaluated on the augmented next observations.
        curr_agent.critic_optimizer.zero_grad()
        all_trgt_acs = []
        if self.discrete_action:  # one-hot encode action
            all_trgt_acs = [onehot_from_logits(pi(nobs)) for pi, nobs in
                            zip(self.target_policies, next_obs_)]
        trgt_vf_in = torch.cat((*next_obs, *all_trgt_acs), dim=1)
        target_value = (rews[agent_i].view(-1, 1) + self.gamma *
                        curr_agent.target_critic(trgt_vf_in) *
                        (1 - dones[agent_i].view(-1, 1)))
        vf_in = torch.cat((*obs, *acs), dim=1)
        actual_value = curr_agent.critic(vf_in)
        vf_loss = MSELoss(actual_value, target_value.detach())
        vf_loss.backward()
        if parallel:
            average_gradients(curr_agent.critic)
        torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(), 0.5)
        curr_agent.critic_optimizer.step()
        curr_agent.policy_optimizer.zero_grad()
        if self.discrete_action:
            # Forward pass as if onehot (hard=True) but backprop through a differentiable
            # Gumbel-Softmax sample. The MADDPG paper uses the Gumbel-Softmax trick to backprop
            # through discrete categorical samples, but I'm not sure if that is
            # correct since it removes the assumption of a deterministic policy for
            # DDPG. Regardless, discrete policies don't seem to learn properly without it.
            curr_pol_out = curr_agent.policy(obs_[agent_i])
            curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)
        else:
            curr_pol_out = curr_agent.policy(obs[agent_i])
            curr_pol_vf_in = curr_pol_out
        all_pol_acs = []
        for i, pi, ob in zip(range(self.nagents), self.policies, obs_):
            if i == agent_i:
                all_pol_acs.append(curr_pol_vf_in)
            elif self.discrete_action:
                all_pol_acs.append(onehot_from_logits(pi(ob)))
            else:
                all_pol_acs.append(pi(ob))
        vf_in = torch.cat((*obs, *all_pol_acs), dim=1)
        pol_loss = -curr_agent.critic(vf_in).mean()
        # Small action-magnitude regularizer on the raw policy output.
        pol_loss += (curr_pol_out ** 2).mean() * 1e-3
        pol_loss.backward()
        if parallel:
            average_gradients(curr_agent.policy)
        torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(), 0.5)
        # actor
        curr_agent.policy_optimizer.step()
        if logger is not None:
            logger.add_scalars('agent%i/losses' % agent_i,
                               {'vf_loss': vf_loss,
                                'pol_loss': pol_loss},
                               self.niter)

    def update_all_targets(self):
        """
        Update all target networks (called after normal updates have been
        performed for each agent)
        """
        for a in self.agents:
            soft_update(a.target_critic, a.critic, self.tau)
            soft_update(a.target_policy, a.policy, self.tau)
        self.niter += 1

    def prep_training(self, device='gpu'):
        # Switch every network to train mode and move it to `device`;
        # the *_dev attributes cache current placement so repeated calls
        # with the same device are no-ops.
        for mle in self.mle_base:
            mle.train()
        for a in self.agents:
            a.policy.train()
            a.critic.train()
            a.target_policy.train()
            a.target_critic.train()
            for mle_i in a.mle:
                mle_i.train()
        if device == 'gpu':
            fn = lambda x: x.to(torch.device(self.device))
        else:
            fn = lambda x: x.cpu()
        if not self.pol_dev == device:
            for a in self.agents:
                a.policy = fn(a.policy)
            self.pol_dev = device
        if not self.critic_dev == device:
            for a in self.agents:
                a.critic = fn(a.critic)
            self.critic_dev = device
        if not self.trgt_pol_dev == device:
            for a in self.agents:
                a.target_policy = fn(a.target_policy)
            self.trgt_pol_dev = device
        if not self.trgt_critic_dev == device:
            for a in self.agents:
                a.target_critic = fn(a.target_critic)
            self.trgt_critic_dev = device
        if not self.mle_dev == device:
            for i, mle in enumerate(self.mle_base):
                self.mle_base[i] = fn(mle)
            for a in self.agents:
                for i, mle_i in enumerate(a.mle):
                    a.mle[i] = fn(mle_i)
            self.mle_dev = device

    def prep_rollouts(self, device='cpu'):
        # Policies in eval mode on `device`; rollouts use only the main policy.
        for a in self.agents:
            a.policy.eval()
        if device == 'gpu':
            fn = lambda x: x.to(torch.device(self.device))
        else:
            fn = lambda x: x.cpu()
        # only need main policy for rollouts
        if not self.pol_dev == device:
            for a in self.agents:
                a.policy = fn(a.policy)
            self.pol_dev = device

    def save(self, filename):
        """
        Save trained parameters of all agents into one file
        """
        self.prep_training(device='cpu')  # move parameters to CPU before saving
        save_dict = {'init_dict': self.init_dict,
                     'agent_params': [a.get_params() for a in self.agents],
                     'mle_params': [self.get_params()], }
        torch.save(save_dict, filename)

    @classmethod
    def init_from_env(cls, env, device, agent_alg="ToM_S", adversary_alg="ToM_S",
                      gamma=0.95, tau=0.01, lr=0.01, hidden_dim=64, output_style='sum'):
        """
        Instantiate instance of this class from multi-agent environment

        NOTE(review): action spaces other than Box/Discrete/MultiDiscrete fall
        through the isinstance chain and reuse `discrete_action`/`get_shape`
        from a previous iteration (NameError on the first) — verify only
        supported spaces are used.
        """
        agent_init_params = []
        alg_types = [adversary_alg if atype == 'adversary' else agent_alg for
                     atype in env.agent_types]
        for acsp, obsp, algtype in zip(env.action_space, env.observation_space,
                                       alg_types):
            num_in_pol = obsp.shape[0]
            num_in_mle = obsp.shape[0]
            if isinstance(acsp, Box):
                discrete_action = False
                get_shape = lambda x: x.shape[0]
            elif isinstance(acsp, Discrete):  # Discrete
                discrete_action = True
                get_shape = lambda x: x.n
            elif isinstance(acsp, MultiDiscrete):
                discrete_action = True
                get_shape = lambda x: sum(x.high - x.low + 1)
            num_out_pol = get_shape(acsp)
            if algtype == "ToM_S":
                # Centralized critic input: all observations + all actions;
                # policy also receives 5 inferred-action units per other agent.
                num_in_critic = 0
                num_in_pol += (len(env.agent_types) - 1) * 5
                for oobsp in env.observation_space:
                    num_in_critic += oobsp.shape[0]
                for oacsp in env.action_space:
                    if isinstance(oacsp, Box):
                        discrete_action = False
                        get_shape = lambda x: x.shape[0]
                    elif isinstance(oacsp, Discrete):  # Discrete
                        discrete_action = True
                        get_shape = lambda x: x.n
                    elif isinstance(oacsp, MultiDiscrete):
                        discrete_action = True
                        get_shape = lambda x: sum(x.high - x.low + 1)
                    num_in_critic += get_shape(oacsp)
            else:
                num_in_critic = obsp.shape[0] + get_shape(acsp)
            agent_init_params.append({'num_in_pol': num_in_pol,
                                      'num_out_pol': num_out_pol,
                                      'num_in_critic': num_in_critic,
                                      'num_in_mle': num_in_mle, })
        init_dict = {'gamma': gamma, 'tau': tau, 'lr': lr,
                     'device': device,
                     'hidden_dim': hidden_dim,
                     'alg_types': alg_types,
                     'agent_init_params': agent_init_params,
                     'discrete_action': discrete_action,
                     'output_style': output_style}
        instance = cls(**init_dict)
        instance.init_dict = init_dict
        return instance

    @classmethod
    def init_from_save(cls, filename):
        """
        Instantiate instance of this class from file created by 'save' method
        """
        save_dict = torch.load(filename)
        instance = cls(**save_dict['init_dict'])
        instance.init_dict = save_dict['init_dict']
        for a, params in zip(instance.agents, save_dict['agent_params']):
            a.load_params(params)
        # mle_params is a one-element list; this loads it onto the instance.
        for a, params in zip([instance], save_dict['mle_params']):
            a.load_params(params)
        return instance

    def get_params(self):
        # State dicts of shared mle networks and their optimizers,
        # keyed 'mle<i>' / 'mle_optimizer<i>'.
        params = {
        }
        for i in range(len(self.mle_base)):
            params['mle%d' % i] = self.mle_base[i].state_dict()
            params['mle_optimizer%d' % i] = self.mle_opts[i].state_dict()
        return params

    def load_params(self, params):
        # Inverse of get_params: restore shared mle networks and optimizers.
        for i in range(len(self.mle_base)):
            self.mle_base[i].load_state_dict(params['mle%d' % i])
            self.mle_opts[i].load_state_dict(params['mle_optimizer%d' % i])
class ToM_self(object):
"""
Wrapper class for DDPG-esque (i.e. also MADDPG) agents in multi-agent task
"""
def __init__(self, agent_init_params, alg_types, output_style, device,
gamma=0.95, tau=0.01, lr=0.01, hidden_dim=64,
discrete_action=False):
"""
Inputs:
agent_init_params (list of dict): List of dicts with parameters to
initialize each agent
num_in_pol (int): Input dimensions to policy
num_out_pol (int): Output dimensions to policy
num_in_critic (int): Input dimensions to critic
alg_types (list of str): Learning algorithm for each agent (DDPG
or MADDPG)
gamma (float): Discount factor
tau (float): Target update rate
lr (float): Learning rate for policy and critic
hidden_dim (int): Number of hidden dimensions for networks
discrete_action (bool): Whether or not to use discrete action space
"""
self.device = device
self.nagents = len(alg_types)
self.alg_types = alg_types
self.agents = [DDPGAgent_ToM(lr=lr, discrete_action=discrete_action,
hidden_dim=hidden_dim,
**params, output_style=output_style,
num_agents=self.nagents,
device=self.device)
for params in agent_init_params]
self.agent_init_params = agent_init_params
if self.nagents == 6:
self.mle_base = [SNNNetwork(self.agent_init_params[3]['num_in_mle'] - 14 + 5, #simple_com
self.agent_init_params[3]['num_out_pol'], #adv self-self
hidden_dim=hidden_dim, output_style=output_style),
SNNNetwork(self.agent_init_params[3]['num_in_mle'] - 14 + 5,
self.agent_init_params[3]['num_out_pol'], #adv self-other
hidden_dim=hidden_dim, output_style=output_style),
SNNNetwork(self.agent_init_params[3]['num_in_mle'] - 14 + 5,
self.agent_init_params[3]['num_out_pol'], # adv self-other
hidden_dim=hidden_dim, output_style=output_style),
SNNNetwork(self.agent_init_params[3]['num_in_mle'] - 14 + 5,
self.agent_init_params[3]['num_out_pol'],
hidden_dim=hidden_dim, output_style=output_style), ##agent self-other
]
if self.nagents == 4:
self.mle_base = [SNNNetwork(self.agent_init_params[0]['num_in_mle'] - 2 + 5, #simple_tag
self.agent_init_params[0]['num_out_pol'], #adv self-self
hidden_dim=hidden_dim, output_style=output_style),
SNNNetwork(self.agent_init_params[3]['num_in_mle'] - 2 + 5,
self.agent_init_params[3]['num_out_pol'], #adv self-other
hidden_dim=hidden_dim, output_style=output_style),
SNNNetwork(self.agent_init_params[3]['num_in_mle'] - 2 + 5,
self.agent_init_params[3]['num_out_pol'],
hidden_dim=hidden_dim, output_style=output_style), ##agent self-other
]
elif self.nagents == 3:
self.mle_base = [SNNNetwork(self.agent_init_params[1]['num_in_mle'] + 5, #simple_adv
self.agent_init_params[1]['num_out_pol'], #adv self-other
hidden_dim=hidden_dim, output_style=output_style),
SNNNetwork(self.agent_init_params[1]['num_in_mle'] + 5,
self.agent_init_params[1]['num_out_pol'], #agent self-self
hidden_dim=hidden_dim, output_style=output_style),
SNNNetwork(self.agent_init_params[1]['num_in_mle'] + 5,
self.agent_init_params[1]['num_out_pol'],
hidden_dim=hidden_dim, output_style=output_style), ##agent self-other
]
elif self.nagents == 2:
self.mle_base = [SNNNetwork(self.agent_init_params[0]['num_in_mle']-2 + 5, #simple_push
self.agent_init_params[0]['num_out_pol'], #adv self-other
hidden_dim=hidden_dim, output_style=output_style),
SNNNetwork(self.agent_init_params[1]['num_in_mle']-2 + 5,
self.agent_init_params[1]['num_out_pol'],
hidden_dim=hidden_dim, output_style=output_style), ##agent self-other
]
self.mle_opts = [Adam(self.mle_base[i].parameters(), lr=lr) for i in range(len(self.mle_base))]
self.gamma = gamma
self.tau = tau
self.lr = lr
self.discrete_action = discrete_action
self.pol_dev = 'cpu' # device for policies
self.critic_dev = 'cpu' # device for critics
self.trgt_pol_dev = 'cpu' # device for target policies
self.trgt_critic_dev = 'cpu' # device for target critics
self.mle_dev = 'cpu'
self.niter = 0
@property
def policies(self):
return [a.policy for a in self.agents]
@property
def target_policies(self):
return [a.target_policy for a in self.agents]
def scale_noise(self, scale):
"""
Scale noise for each agent
Inputs:
scale (float): scale of noise
"""
for a in self.agents:
a.scale_noise(scale)
def reset_noise(self):
for a in self.agents:
a.reset_noise()
    def step(self, observations, actions_pre, explore=False): #simple_tag
        """
        Take a step forward in environment with all agents.

        Before querying each agent's own policy, the shared Theory-of-Mind
        ("mle") networks predict the actions of every *other* agent from
        their observations and previous actions; those predictions are
        concatenated onto the agent's observation.

        Inputs:
            observations: List of observations for each agent
            actions_pre: List of each agent's previous action (extra input
                to the ToM networks)
            explore (boolean): Whether or not to add exploration noise
        Outputs:
            actions: List of actions for each agent
        """
        # t1 = time.time()
        observations_ = observations.copy()
        actions_pre_ = actions_pre.copy()
        for agent_i, obs in enumerate(observations):
            # "Everybody but me" views used as ToM-network inputs.
            obs_ = observations_.copy()
            acs_pre_ = actions_pre_.copy()
            obs_.pop(agent_i)
            acs_pre_.pop(agent_i)
            # actions = [self.agents[agent_i].mle[j].cpu()(observations[agent_i]) for j, obs_j in enumerate(obs_)]
            # observations[agent_i] = torch.cat((observations[agent_i], torch.cat(actions, 1)), 1)
            # Scenario is keyed on team size: 6 agents -> simple_comm-style,
            # 4 -> simple_tag, 3 -> simple_adv, 2 -> simple_push.  Slice
            # bounds like [:, 4:24] / [:, 2:14] encode each scenario's
            # observation layout -- TODO confirm against the env definitions.
            if self.nagents == 6:
                if agent_i < 4:
                    # First team: every other agent is modelled with the
                    # shared ToM network mle_base[0].
                    self.agents[agent_i].mle = [self.mle_base[0],self.mle_base[0],self.mle_base[0], self.mle_base[0], self.mle_base[0]]
                    actions = [torch.cat((obs_j[:, 4:24], acs_pre_[j][:,:5]),1) for j, obs_j in enumerate(obs_)]
                    # Batch the five "others" through two forward passes and
                    # split the results back apart; the hard-coded 20 is
                    # presumably the rollout batch size -- TODO confirm.
                    b1 = gumbel_softmax(self.agents[agent_i].mle[0].to(self.device)(torch.cat(actions[:3]).to(self.device)), hard=True).cpu()
                    b2 = gumbel_softmax(self.agents[agent_i].mle[1].to(self.device)(torch.cat(actions[3:]).to(self.device)), hard=True).cpu()
                    actions = torch.cat((b1[:20], b1[20:40], b1[40:60], b2[:20], b2[20:40]), 1)
                    # print(t1 - time.time())
                    # print()
                else:
                    # Second team: others are modelled with mle_base[1].
                    self.agents[agent_i].mle = [self.mle_base[1],self.mle_base[1], self.mle_base[1], self.mle_base[1], self.mle_base[1]]
                    actions = [torch.cat((obs_j[:, 4:24], acs_pre_[j][:,:5]),1) for j, obs_j in enumerate(obs_)]
                    b1 = gumbel_softmax(self.agents[agent_i].mle[0].to(self.device)(torch.cat(actions[:3]).to(self.device)), hard=True).cpu()
                    b2 = gumbel_softmax(self.agents[agent_i].mle[1].to(self.device)(torch.cat(actions[3:]).to(self.device)), hard=True).cpu()
                    actions = torch.cat((b1[:20], b1[20:40], b1[40:60], b2[:20], b2[20:40]), 1)
                    # actions = [torch.zeros((obs_j.shape[0],self.agent_init_params[1]['num_out_pol']))
                    #            for j, obs_j in enumerate(obs_)]
                    # actions = torch.cat(actions,1)
                    # print()
            # NOTE(review): plain `if` (not `elif`) below -- harmless since
            # the nagents values are mutually exclusive, but inconsistent.
            if self.nagents == 4:
                if agent_i < 3:
                    self.agents[agent_i].mle = [self.mle_base[1],self.mle_base[1], self.mle_base[1]]
                    actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(self.device)(torch.cat((obs_j[:, 2:14], acs_pre_[j]),1).to(self.device)),
                                              hard=True).cpu()
                               for j, obs_j in enumerate(obs_)]
                elif agent_i == 3:
                    self.agents[agent_i].mle = [self.mle_base[2],self.mle_base[2], self.mle_base[2]]
                    actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(self.device)(torch.cat((obs_j[:,2:14], acs_pre_[j]),1).to(self.device)),
                                              hard=True).cpu() for j, obs_j in enumerate(obs_)]
                    # actions = [gumbel_softmax(self.agents[agent_i].mle[j].cpu()(obs_j[:,2:-2]), hard=True)
                    #            for j, obs_j in enumerate(obs_)]
                    # actions = [torch.zeros((obs_j.shape[0],self.agent_init_params[0]['num_out_pol']))
                    #            for j, obs_j in enumerate(obs_)]
            elif self.nagents == 3: #simple_adv
                actions = []
                if agent_i < 1:
                    # The adversary gets zero placeholders instead of real
                    # predictions; this keeps the augmented width constant.
                    self.agents[agent_i].mle = [self.mle_base[0], self.mle_base[0]]
                    actions = [torch.zeros((obs_j.shape[0],self.agent_init_params[0]['num_out_pol']))
                               for j, obs_j in enumerate(obs_)]
                elif agent_i == 1:
                    self.agents[agent_i].mle = [self.mle_base[2], self.mle_base[1]]
                    actions.append(gumbel_softmax(self.agents[agent_i].mle[0].to(self.device)(torch.cat((observations_[agent_i],
                                   actions_pre[(0)]), 1).to(self.device)),
                                   hard=True).cpu() )
                    actions.append(gumbel_softmax(self.agents[agent_i].mle[1].to(self.device)(torch.cat((observations_[agent_i],
                                   actions_pre[(2)]), 1).to(self.device)),
                                   hard=True).cpu() )
                elif agent_i == 2:
                    self.agents[agent_i].mle = [self.mle_base[2], self.mle_base[1]]
                    # actions = [gumbel_softmax(self.agents[agent_i].mle[j].cpu()(observations_[agent_i]), hard=True)
                    #            for j, obs_j in enumerate(obs_)]
                    actions.append(gumbel_softmax(self.agents[agent_i].mle[0].to(self.device)(torch.cat((observations_[agent_i],
                                   actions_pre[(0)]), 1).to(self.device)),
                                   hard=True).cpu() )
                    actions.append(gumbel_softmax(self.agents[agent_i].mle[1].to(self.device)(torch.cat((observations_[agent_i],
                                   actions_pre[(1)]), 1).to(self.device)),
                                   hard=True).cpu() )
            elif self.nagents == 2:
                if agent_i < 1:
                    self.agents[agent_i].mle = [self.mle_base[0]]
                    # actions = [gumbel_softmax(self.agents[agent_i].mle[j].cpu()(observations_[agent_i][:,2:]), hard=True)
                    #            for j, obs_j in enumerate(obs_)]
                    actions = [torch.zeros((obs_j.shape[0], self.agent_init_params[0]['num_out_pol'])) for j, obs_j in enumerate(obs_)]
                elif agent_i == 1:
                    self.agents[agent_i].mle = [self.mle_base[1]]
                    # actions = [gumbel_softmax(self.agents[agent_i].mle[j].cpu()(observations_[agent_i][:, 2:]), hard=True)
                    #            for j, obs_j in enumerate(obs_)]
                    actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(self.device)(torch.cat((observations_[agent_i][:,2:],
                               actions_pre[(self.nagents -1 - agent_i)]), 1).to(self.device)),
                               hard=True).cpu() for j, obs_j in enumerate(obs_)]
            # Append ToM predictions to this agent's observation; in the
            # 6-agent case `actions` is already a single concatenated tensor.
            if self.nagents == 6:
                observations[agent_i] = torch.cat((observations[agent_i], actions), 1)
            else:
                observations[agent_i] = torch.cat((observations[agent_i], torch.cat(actions, 1)), 1)
        # t2 = time.time()
        # print('step+time:', t2 - t1)
        return [a.step(obs, explore=explore) for a, obs in zip(self.agents,
                                                               observations)]
def _get_obs(self, observations, actions_pre):
observations_ = []
actions_pre_ = []
for agent_i, obs in enumerate(observations):
obs_ = observations.copy()
obs_.pop(agent_i)
actions_pre_ = actions_pre.copy()
actions_pre_.pop(agent_i)
if self.nagents == 6:
if agent_i < 4: #simple_comm
self.agents[agent_i].mle = [self.mle_base[0], self.mle_base[0], self.mle_base[0], self.mle_base[0], self.mle_base[0]]
# actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(torch.device(self.device))(torch.cat((obs_j[:, 4:24], actions_pre_[j][:,:5]),1)).detach(), hard=True)
# for j, obs_j in enumerate(obs_)]
actions = [torch.cat((obs_j[:, 4:24], actions_pre_[j][:,:5]),1) for j, obs_j in enumerate(obs_)]
b1 = gumbel_softmax(self.agents[agent_i].mle[0].to(self.device)(torch.cat(actions[:3]).to(self.device)).detach(), hard=True)
b2 = gumbel_softmax(self.agents[agent_i].mle[1].to(self.device)(torch.cat(actions[3:]).to(self.device)).detach(), hard=True)
actions = torch.cat((b1[:1024], b1[1024:2048], b1[2048:3072], b2[:1024], b2[1024:2048]), 1)
# print()
elif agent_i > 4:
self.agents[agent_i].mle = [self.mle_base[1], self.mle_base[1], self.mle_base[1], self.mle_base[1], self.mle_base[1]]
actions = [torch.cat((obs_j[:, 4:24], actions_pre_[j][:,:5]),1) for j, obs_j in enumerate(obs_)]
b1 = gumbel_softmax(self.agents[agent_i].mle[0].to(self.device)(torch.cat(actions[:3]).to(self.device)).detach(), hard=True)
b2 = gumbel_softmax(self.agents[agent_i].mle[1].to(self.device)(torch.cat(actions[3:]).to(self.device)).detach(), hard=True)
actions = torch.cat((b1[:1024], b1[1024:2048], b1[2048:3072], b2[:1024], b2[1024:2048]), 1)
# print()
if self.nagents == 4:
if agent_i < 3: #simple_tag
self.agents[agent_i].mle = [self.mle_base[1], self.mle_base[1], self.mle_base[1]]
actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(torch.device(self.device))(torch.cat((obs_j[:, 2:14], actions_pre_[j]),1)).detach(), hard=True)
for j, obs_j in enumerate(obs_)]
elif agent_i == 3:
self.agents[agent_i].mle = [self.mle_base[2], self.mle_base[2], self.mle_base[2]]
actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(torch.device(self.device))(torch.cat((obs_j[:, 2:14], actions_pre_[j]),1)).detach(), hard=True)
for j, obs_j in enumerate(obs_)]
# actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(self.device)(torch.cat((obs_j[:,2:-2], acs_pre_[j]),1).to(self.device)),
# hard=True).cpu() for j, obs_j in enumerate(obs_)]
# actions = [torch.zeros((obs_j.shape[0],self.agent_init_params[0]['num_out_pol'])).to(torch.device(self.device)).detach()
# for j, obs_j in enumerate(obs_)]
elif self.nagents == 3:
actions = []
if agent_i < 1: #simple_adv
# self.agents[agent_i].mle = [self.mle_base[0],self.mle_base[0]]
# actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(torch.device(self.device))(torch.cat((obs_j[:,:2],observations[agent_i]),1)).detach(), hard=True)
# for j, obs_j in enumerate(obs_)]
actions = [torch.zeros((obs_j.shape[0],self.agent_init_params[0]['num_out_pol'])).to(torch.device(self.device)).detach()
for j, obs_j in enumerate(obs_)]
elif agent_i == 1:
self.agents[agent_i].mle = [self.mle_base[2],self.mle_base[1]]
# actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(torch.device(self.device))(observations[agent_i]).detach(), hard=True)
# for j, obs_j in enumerate(obs_)]
actions.append(
gumbel_softmax(self.agents[agent_i].mle[0].to(self.device)(torch.cat((observations[agent_i],
actions_pre[(0)]), 1).to(self.device)).detach(), hard=True))
actions.append(gumbel_softmax(self.agents[agent_i].mle[1].to(self.device)(torch.cat((observations[agent_i],
actions_pre[(2)]), 1).to(self.device)).detach(), hard=True))
elif agent_i == 2:
self.agents[agent_i].mle = [self.mle_base[2], self.mle_base[1]]
actions.append(
gumbel_softmax(self.agents[agent_i].mle[0].to(self.device)(torch.cat((observations[agent_i],
actions_pre[(0)]), 1).to(self.device)).detach(), hard=True))
actions.append(gumbel_softmax(self.agents[agent_i].mle[1].to(self.device)(torch.cat((observations[agent_i],
actions_pre[(1)]), 1).to(self.device)).detach(), hard=True))
elif self.nagents == 2:
if agent_i < 1: #simple_push
self.agents[agent_i].mle = [self.mle_base[0]]
actions = [torch.zeros((obs_j.shape[0], self.agent_init_params[0]['num_out_pol'])).to(torch.device(self.device)).detach() for j, obs_j in
enumerate(obs_)]
elif agent_i == 1:
self.agents[agent_i].mle = [self.mle_base[1]]
actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(torch.device(self.device))(torch.cat((observations[agent_i][:,2:],
actions_pre[(self.nagents -1 - agent_i)]), 1)).detach(), hard=True)
for j, obs_j in enumerate(obs_)]
if self.nagents == 6:
observations_.append(torch.cat((observations[agent_i], actions), 1))
else:
observations_.append(torch.cat((observations[agent_i], torch.cat(actions, 1)), 1))
return observations_
def trian_tag(self, agent_i, KL_criterion, obs, acs_pre, parallel, acs):
if agent_i == 0:
self.mle_opts[1].zero_grad()
action_i = self.mle_base[1](torch.cat((obs[0][:, 2:14], acs_pre[0]),1))#
action_pre = gumbel_softmax(action_i, hard=True)
loss = KL_criterion(action_pre.float(), acs[0].float())
loss.backward(retain_graph=True)
if parallel:
average_gradients(self.mle_base[1])
torch.nn.utils.clip_grad_norm_(self.mle_base[1].parameters(), 20)
self.mle_opts[1].step()
self.mle_opts[2].zero_grad()
action_i = self.mle_base[2](torch.cat((obs[3][:, 2:14], acs_pre[3]),1))
action_pre = gumbel_softmax(action_i, hard=True)
loss = KL_criterion(action_pre.float(), acs[3].float())
loss.backward()
if parallel:
average_gradients(self.mle_base[2])
torch.nn.utils.clip_grad_norm_(self.mle_base[2].parameters(), 20)
self.mle_opts[2].step()
def trian_adv(self, agent_i, KL_criterion, obs, acs_pre, parallel, acs):
if agent_i == 0:
self.mle_opts[1].zero_grad()
action_i = self.mle_base[1](torch.cat((obs[1], acs_pre[2]), 1)) #torch.cat((obs[1], acs_pre[2]), 1))
action_pre = gumbel_softmax(action_i, hard=True)
loss = KL_criterion(action_pre.float(), acs[1].float())
loss.backward()
if parallel:
average_gradients(self.mle_base[1])
torch.nn.utils.clip_grad_norm_(self.mle_base[1].parameters(), 20)
self.mle_opts[1].step()
self.mle_opts[2].zero_grad()
action_i = self.mle_base[2](torch.cat((obs[1], acs_pre[0]), 1))
action_pre = gumbel_softmax(action_i, hard=True)
loss = KL_criterion(action_pre.float(), acs[0].float())
loss.backward()
if parallel:
average_gradients(self.mle_base[2])
torch.nn.utils.clip_grad_norm_(self.mle_base[2].parameters(), 20)
self.mle_opts[2].step()
def trian_push(self, agent_i, KL_criterion, obs, acs_pre, parallel, acs):
if agent_i == 0:
self.mle_opts[1].zero_grad()
action_i = self.mle_base[1](torch.cat((obs[1][:,2:], acs_pre[(0)]), 1)) #obs[1][:, 2:]
action_pre = gumbel_softmax(action_i, hard=True)
loss = KL_criterion(action_pre.float(), acs[0].float())
loss.backward()
if parallel:
average_gradients(self.mle_base[1])
torch.nn.utils.clip_grad_norm_(self.mle_base[1].parameters(), 20)
self.mle_opts[1].step()
def trian_com(self, agent_i, KL_criterion, obs, acs_pre, parallel, acs):
if agent_i == 0:
self.mle_opts[0].zero_grad()
action_i = self.mle_base[0](torch.cat((obs[1][:, 4:24], acs_pre[(1)]), 1))
action_pre = gumbel_softmax(action_i, hard=True)
loss = KL_criterion(action_pre.float(), acs[1].float())
loss.backward(retain_graph=True)
if parallel:
average_gradients(self.mle_base[0])
torch.nn.utils.clip_grad_norm_(self.mle_base[0].parameters(), 20)
self.mle_opts[0].step()
self.mle_opts[1].zero_grad()
action_i = self.mle_base[1](torch.cat((obs[4][:, 4:24], acs_pre[(4)]), 1))
action_pre = gumbel_softmax(action_i, hard=True)
loss = KL_criterion(action_pre.float(), acs[4].float())
loss.backward()
if parallel:
average_gradients(self.mle_base[1])
torch.nn.utils.clip_grad_norm_(self.mle_base[1].parameters(), 20)
self.mle_opts[1].step()
    def update(self, sample, agent_i, parallel=False, logger=None, sample_r=None):
        """
        Update parameters of agent model based on sample from replay buffer.

        Three updates in sequence: (1) the shared ToM ("mle") networks via
        the scenario-specific trian_* helper, (2) the agent's centralized
        critic, (3) the agent's policy.

        Inputs:
            sample: tuple of (previous actions, observations, actions,
                rewards, next observations, and episode end masks) sampled
                randomly from the replay buffer. Each is a list with entries
                corresponding to each agent
            agent_i (int): index of agent to update
            parallel (bool): If true, will average gradients across threads
            logger (SummaryWriter from Tensorboard-Pytorch):
                If passed in, important quantities will be logged
            sample_r: unused here (kept for caller compatibility)
        """
        # print('___update___')
        acs_pre, obs, acs, rews, next_obs, dones = sample
        # ToM-augmented (next-)observations; predictions are detached inside
        # _get_obs, so no ToM gradients flow through the RL losses.
        next_obs_ = self._get_obs(next_obs, acs)
        obs_ = self._get_obs(obs, acs_pre)
        curr_agent = self.agents[agent_i]
        # --- ToM ("mle") network update ---
        KL_criterion = torch.nn.KLDivLoss(reduction='sum')
        # for i in range(len(curr_agent.mle)):
        #     curr_agent.mle_optimizer[i].zero_grad()
        #     action_i = curr_agent.mle[i](obs[agent_i]obs[agent_i])
        #     action_pre = gumbel_softmax(action_i, hard=True)
        #     loss = KL_criterion(action_pre.float(), acs[i].float())
        #     loss.backward()
        #     if parallel:
        #         average_gradients(curr_agent.mle[i])
        #     torch.nn.utils.clip_grad_norm_(curr_agent.mle[i].parameters(), 20)
        #     curr_agent.policy_optimizer.step()
        if self.nagents == 6:
            self.trian_com(agent_i, KL_criterion, obs, acs_pre, parallel, acs)
        elif self.nagents == 4:
            self.trian_tag(agent_i, KL_criterion, obs, acs_pre, parallel, acs)
        elif self.nagents == 3:
            self.trian_adv(agent_i, KL_criterion, obs, acs_pre, parallel, acs)
        elif self.nagents == 2:
            self.trian_push(agent_i, KL_criterion, obs, acs_pre, parallel, acs)
        # center critic
        curr_agent.critic_optimizer.zero_grad()
        all_trgt_acs = []
        if self.discrete_action: # one-hot encode action
            # Target actions from the target policies, evaluated on the
            # ToM-augmented next observations.
            all_trgt_acs = [onehot_from_logits(pi(nobs)) for pi, nobs in
                            zip(self.target_policies, next_obs_)]
        # NOTE(review): with discrete_action=False all_trgt_acs stays empty,
        # so the critic input below would lack the action block -- this path
        # appears to assume discrete actions; verify before using continuous.
        trgt_vf_in = torch.cat((*next_obs, *all_trgt_acs), dim=1)
        # One-step TD target; the dones mask removes the bootstrap term at
        # episode ends.
        target_value = (rews[agent_i].view(-1, 1) + self.gamma *
                        curr_agent.target_critic(trgt_vf_in) *
                        (1 - dones[agent_i].view(-1, 1)))
        vf_in = torch.cat((*obs, *acs), dim=1)
        actual_value = curr_agent.critic(vf_in)
        vf_loss = MSELoss(actual_value, target_value.detach())
        vf_loss.backward()
        if parallel:
            average_gradients(curr_agent.critic)
        torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(), 0.5)
        curr_agent.critic_optimizer.step()
        curr_agent.policy_optimizer.zero_grad()
        if self.discrete_action:
            # Forward pass as if onehot (hard=True) but backprop through a differentiable
            # Gumbel-Softmax sample. The MADDPG paper uses the Gumbel-Softmax trick to backprop
            # through discrete categorical samples, but I'm not sure if that is
            # correct since it removes the assumption of a deterministic policy for
            # DDPG. Regardless, discrete policies don't seem to learn properly without it.
            curr_pol_out = curr_agent.policy(obs_[agent_i])
            curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)
        else:
            curr_pol_out = curr_agent.policy(obs[agent_i])
            curr_pol_vf_in = curr_pol_out
        all_pol_acs = []
        # Actions of all agents under their current policies; only the
        # updating agent's action carries gradient.
        for i, pi, ob in zip(range(self.nagents), self.policies, obs_):
            if i == agent_i:
                all_pol_acs.append(curr_pol_vf_in)
            elif self.discrete_action:
                all_pol_acs.append(onehot_from_logits(pi(ob)))
            else:
                all_pol_acs.append(pi(ob))
        vf_in = torch.cat((*obs, *all_pol_acs), dim=1)
        # Deterministic policy-gradient loss plus a small logit regularizer.
        pol_loss = -curr_agent.critic(vf_in).mean()
        pol_loss += (curr_pol_out ** 2).mean() * 1e-3
        pol_loss.backward()
        if parallel:
            average_gradients(curr_agent.policy)
        torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(), 0.5)
        # actor
        curr_agent.policy_optimizer.step()
        if logger is not None:
            logger.add_scalars('agent%i/losses' % agent_i,
                               {'vf_loss': vf_loss,
                                'pol_loss': pol_loss},
                               self.niter)
def update_all_targets(self):
"""
Update all target networks (called after normal updates have been
performed for each agent)
"""
for a in self.agents:
soft_update(a.target_critic, a.critic, self.tau)
soft_update(a.target_policy, a.policy, self.tau)
self.niter += 1
def prep_training(self, device='gpu'):
for mle in self.mle_base:
mle.train()
for a in self.agents:
a.policy.train()
a.critic.train()
a.target_policy.train()
a.target_critic.train()
for mle_i in a.mle:
mle_i.train()
if device == 'gpu':
fn = lambda x: x.to(torch.device(self.device))
else:
fn = lambda x: x.cpu()
if not self.pol_dev == device:
for a in self.agents:
a.policy = fn(a.policy)
self.pol_dev = device
if not self.critic_dev == device:
for a in self.agents:
a.critic = fn(a.critic)
self.critic_dev = device
if not self.trgt_pol_dev == device:
for a in self.agents:
a.target_policy = fn(a.target_policy)
self.trgt_pol_dev = device
if not self.trgt_critic_dev == device:
for a in self.agents:
a.target_critic = fn(a.target_critic)
self.trgt_critic_dev = device
if not self.mle_dev == device:
for i, mle in enumerate(self.mle_base):
self.mle_base[i] = fn(mle)
for a in self.agents:
for i, mle_i in enumerate(a.mle):
a.mle[i] = fn(mle_i)
self.mle_dev = device
def prep_rollouts(self, device='cpu'):
for a in self.agents:
a.policy.eval()
if device == 'gpu':
fn = lambda x: x.to(torch.device(self.device))
else:
fn = lambda x: x.cpu()
# only need main policy for rollouts
if not self.pol_dev == device:
for a in self.agents:
a.policy = fn(a.policy)
self.pol_dev = device
def save(self, filename):
"""
Save trained parameters of all agents into one file
"""
self.prep_training(device='cpu') # move parameters to CPU before saving
save_dict = {'init_dict': self.init_dict,
'agent_params': [a.get_params() for a in self.agents],
'mle_params': [self.get_params()],}
torch.save(save_dict, filename)
    @classmethod
    def init_from_env(cls, env, device, agent_alg="ToM_self", adversary_alg="ToM_self",
                      gamma=0.95, tau=0.01, lr=0.01, hidden_dim=64, output_style='sum'):
        """
        Instantiate instance of this class from multi-agent environment.

        Derives per-agent network input/output sizes from the env's
        observation/action spaces and builds the constructor kwargs
        (`init_dict`), which is also stored on the instance so it can be
        re-created via init_from_save().
        """
        agent_init_params = []
        # Adversaries may run a different algorithm than regular agents.
        alg_types = [adversary_alg if atype == 'adversary' else agent_alg for
                     atype in env.agent_types]
        for acsp, obsp, algtype in zip(env.action_space, env.observation_space,
                                       alg_types):
            num_in_pol = obsp.shape[0]
            num_in_mle = obsp.shape[0]
            # NOTE(review): if acsp is none of Box/Discrete/MultiDiscrete,
            # `get_shape`/`discrete_action` remain unbound (or keep the
            # previous iteration's values) and the call below misbehaves.
            if isinstance(acsp, Box):
                discrete_action = False
                get_shape = lambda x: x.shape[0]
            elif isinstance(acsp, Discrete):  # Discrete
                discrete_action = True
                get_shape = lambda x: x.n
            elif isinstance(acsp, MultiDiscrete):
                discrete_action = True
                get_shape = lambda x: sum(x.high - x.low + 1)
            num_out_pol = get_shape(acsp)
            if algtype == "ToM_self":
                # Centralized critic sees every agent's obs and action; the
                # policy input is widened by 5 dims per other agent for the
                # ToM-predicted actions (5 presumably being the discrete
                # action width -- TODO confirm).
                num_in_critic = 0
                num_in_pol += (len(env.agent_types)-1) * 5
                for oobsp in env.observation_space:
                    num_in_critic += oobsp.shape[0]
                for oacsp in env.action_space:
                    if isinstance(oacsp, Box):
                        discrete_action = False
                        get_shape = lambda x: x.shape[0]
                    elif isinstance(oacsp, Discrete):  # Discrete
                        discrete_action = True
                        get_shape = lambda x: x.n
                    elif isinstance(oacsp, MultiDiscrete):
                        discrete_action = True
                        get_shape = lambda x: sum(x.high - x.low + 1)
                    num_in_critic += get_shape(oacsp)
            else:
                # Decentralized critic: own observation + own action only.
                num_in_critic = obsp.shape[0] + get_shape(acsp)
            agent_init_params.append({'num_in_pol': num_in_pol,
                                      'num_out_pol': num_out_pol,
                                      'num_in_critic': num_in_critic,
                                      'num_in_mle': num_in_mle,})
        init_dict = {'gamma': gamma, 'tau': tau, 'lr': lr,
                     'device': device,
                     'hidden_dim': hidden_dim,
                     'alg_types': alg_types,
                     'agent_init_params': agent_init_params,
                     'discrete_action': discrete_action,
                     'output_style': output_style}
        instance = cls(**init_dict)
        instance.init_dict = init_dict  # kept for save()/init_from_save()
        return instance
@classmethod
def init_from_save(cls, filename):
"""
Instantiate instance of this class from file created by 'save' method
"""
save_dict = torch.load(filename)
instance = cls(**save_dict['init_dict'])
instance.init_dict = save_dict['init_dict']
for a, params in zip(instance.agents, save_dict['agent_params']):
a.load_params(params)
for a, params in zip([instance], save_dict['mle_params']):
a.load_params(params)
return instance
def get_params(self):
params = {
}
for i in range(len(self.mle_base)):
params['mle%d'%i] = self.mle_base[i].state_dict()
params['mle_optimizer%d'%i] = self.mle_opts[i].state_dict()
return params
def load_params(self, params):
for i in range(len(self.mle_base)):
self.mle_base[i].load_state_dict(params['mle%d'%i])
self.mle_opts[i].load_state_dict(params['mle_optimizer%d'%i])
================================================
FILE: examples/Social_Cognition/MAToM-SNN/MPE/utils/__init__.py
================================================
================================================
FILE: examples/Social_Cognition/MAToM-SNN/MPE/utils/buffer.py
================================================
import numpy as np
import torch
from torch import Tensor
from torch.autograd import Variable
class ReplayBuffer(object):
    """
    Replay Buffer for multi-agent RL with parallel rollouts.

    Transitions are stored per agent in fixed-size numpy arrays used as a
    circular buffer: once `max_steps` entries are filled, the oldest data
    is overwritten.
    """
    def __init__(self, max_steps, num_agents, obs_dims, ac_dims, device):
        """
        Inputs:
            max_steps (int): Maximum number of timepoints to store in buffer
            num_agents (int): Number of agents in environment
            obs_dims (list of ints): number of observation dimensions for each
                                     agent
            ac_dims (list of ints): number of action dimensions for each agent
            device: device that sampled tensors are moved to when
                sample(..., to_gpu=True)
        """
        self.device = device
        self.max_steps = max_steps
        self.num_agents = num_agents
        # One numpy array per agent for each transition component.
        self.obs_buffs = []
        self.ac_buffs = []
        self.rew_buffs = []
        self.next_obs_buffs = []
        self.done_buffs = []
        for odim, adim in zip(obs_dims, ac_dims):
            self.obs_buffs.append(np.zeros((max_steps, odim)))
            self.ac_buffs.append(np.zeros((max_steps, adim)))
            self.rew_buffs.append(np.zeros(max_steps))
            self.next_obs_buffs.append(np.zeros((max_steps, odim)))
            self.done_buffs.append(np.zeros(max_steps))
        self.filled_i = 0  # index of first empty location in buffer (last index when full)
        self.curr_i = 0  # current index to write to (overwrite oldest data)
    def __len__(self):
        # Number of valid entries currently stored (not the capacity).
        return self.filled_i
    def push(self, observations, actions, rewards, next_observations, dones):
        """Append one batch of transitions (one row per parallel env)."""
        nentries = observations.shape[0]  # handle multiple parallel environments
        if self.curr_i + nentries > self.max_steps:
            # Not enough contiguous room at the tail: rotate the arrays so
            # the write region becomes contiguous at index 0, then wrap.
            rollover = self.max_steps - self.curr_i  # num of indices to roll over
            for agent_i in range(self.num_agents):
                self.obs_buffs[agent_i] = np.roll(self.obs_buffs[agent_i],
                                                  rollover, axis=0)
                self.ac_buffs[agent_i] = np.roll(self.ac_buffs[agent_i],
                                                 rollover, axis=0)
                self.rew_buffs[agent_i] = np.roll(self.rew_buffs[agent_i],
                                                  rollover)
                self.next_obs_buffs[agent_i] = np.roll(
                    self.next_obs_buffs[agent_i], rollover, axis=0)
                self.done_buffs[agent_i] = np.roll(self.done_buffs[agent_i],
                                                   rollover)
            self.curr_i = 0
            self.filled_i = self.max_steps
        for agent_i in range(self.num_agents):
            self.obs_buffs[agent_i][self.curr_i:self.curr_i + nentries] = np.vstack(
                observations[:, agent_i])
            # actions are already batched by agent, so they are indexed differently
            self.ac_buffs[agent_i][self.curr_i:self.curr_i + nentries] = actions[agent_i]
            self.rew_buffs[agent_i][self.curr_i:self.curr_i + nentries] = rewards[:, agent_i]
            self.next_obs_buffs[agent_i][self.curr_i:self.curr_i + nentries] = np.vstack(
                next_observations[:, agent_i])
            self.done_buffs[agent_i][self.curr_i:self.curr_i + nentries] = dones[:, agent_i]
        self.curr_i += nentries
        if self.filled_i < self.max_steps:
            self.filled_i += nentries
        if self.curr_i == self.max_steps:
            self.curr_i = 0
    def sample(self, N, to_gpu=False, norm_rews=True):
        """
        Sample N transitions uniformly without replacement; requires
        filled_i >= N.  Returns (obs, acs, rews, next_obs, dones) as lists
        of per-agent tensors.
        """
        inds = np.random.choice(np.arange(self.filled_i), size=N,
                                replace=False)
        if to_gpu:
            cast = lambda x: Variable(Tensor(x), requires_grad=False).to(torch.device(self.device))
        else:
            cast = lambda x: Variable(Tensor(x), requires_grad=False)
        if norm_rews:
            # Standardize rewards over the filled portion of the buffer.
            # NOTE(review): if all stored rewards are equal, std() is 0 and
            # this produces NaNs.
            ret_rews = [cast((self.rew_buffs[i][inds] -
                              self.rew_buffs[i][:self.filled_i].mean()) /
                             self.rew_buffs[i][:self.filled_i].std())
                        for i in range(self.num_agents)]
        else:
            ret_rews = [cast(self.rew_buffs[i][inds]) for i in range(self.num_agents)]
        return ([cast(self.obs_buffs[i][inds]) for i in range(self.num_agents)],
                [cast(self.ac_buffs[i][inds]) for i in range(self.num_agents)],
                ret_rews,
                [cast(self.next_obs_buffs[i][inds]) for i in range(self.num_agents)],
                [cast(self.done_buffs[i][inds]) for i in range(self.num_agents)])
    def get_average_rewards(self, N):
        """Mean reward per agent over the last N stored steps."""
        if self.filled_i == self.max_steps:
            inds = np.arange(self.curr_i - N, self.curr_i)  # allow for negative indexing
        else:
            inds = np.arange(max(0, self.curr_i - N), self.curr_i)
        return [self.rew_buffs[i][inds].mean() for i in range(self.num_agents)]
class ReplayBuffer_pre(object):
    """
    Replay Buffer for multi-agent RL with parallel rollouts that, in addition
    to the standard transition, stores each agent's *previous* action (used
    as input to the ToM networks).  Same circular-buffer scheme as
    ReplayBuffer.
    """
    def __init__(self, max_steps, num_agents, obs_dims, ac_dims, device):
        """
        Inputs:
            max_steps (int): Maximum number of timepoints to store in buffer
            num_agents (int): Number of agents in environment
            obs_dims (list of ints): number of observation dimensions for each
                                     agent
            ac_dims (list of ints): number of action dimensions for each agent
            device: device that sampled tensors are moved to when
                sample(..., to_gpu=True)
        """
        self.device = device
        self.max_steps = max_steps
        self.num_agents = num_agents
        # Previous-action buffer is hard-coded to width 5 -- presumably the
        # discrete action dimension in these scenarios; TODO confirm.
        self.ac_pre_buffs = []
        self.obs_buffs = []
        self.ac_buffs = []
        self.rew_buffs = []
        self.next_obs_buffs = []
        self.done_buffs = []
        for odim, adim in zip(obs_dims, ac_dims):
            self.ac_pre_buffs.append(np.zeros((max_steps, 5)))
            self.obs_buffs.append(np.zeros((max_steps, odim)))
            self.ac_buffs.append(np.zeros((max_steps, adim)))
            self.rew_buffs.append(np.zeros(max_steps))
            self.next_obs_buffs.append(np.zeros((max_steps, odim)))
            self.done_buffs.append(np.zeros(max_steps))
        self.filled_i = 0  # index of first empty location in buffer (last index when full)
        self.curr_i = 0  # current index to write to (overwrite oldest data)
    def __len__(self):
        # Number of valid entries currently stored (not the capacity).
        return self.filled_i
    def push(self, actions_pre, observations, actions, rewards, next_observations, dones):
        """Append one batch of transitions (one row per parallel env)."""
        nentries = observations.shape[0]  # handle multiple parallel environments
        if self.curr_i + nentries > self.max_steps:
            # Rotate arrays so the write region is contiguous at index 0.
            rollover = self.max_steps - self.curr_i  # num of indices to roll over
            for agent_i in range(self.num_agents):
                self.ac_pre_buffs[agent_i] = np.roll(self.ac_pre_buffs[agent_i][:,:5],
                                                     rollover, axis=0)
                self.obs_buffs[agent_i] = np.roll(self.obs_buffs[agent_i],
                                                  rollover, axis=0)
                self.ac_buffs[agent_i] = np.roll(self.ac_buffs[agent_i],
                                                 rollover, axis=0)
                self.rew_buffs[agent_i] = np.roll(self.rew_buffs[agent_i],
                                                  rollover)
                self.next_obs_buffs[agent_i] = np.roll(
                    self.next_obs_buffs[agent_i], rollover, axis=0)
                self.done_buffs[agent_i] = np.roll(self.done_buffs[agent_i],
                                                   rollover)
            self.curr_i = 0
            self.filled_i = self.max_steps
        for agent_i in range(self.num_agents):
            # Only the first 5 columns of the previous action are kept.
            self.ac_pre_buffs[agent_i][self.curr_i:self.curr_i + nentries] = actions_pre[agent_i][:,:5]
            self.obs_buffs[agent_i][self.curr_i:self.curr_i + nentries] = np.vstack(
                observations[:, agent_i])
            # actions are already batched by agent, so they are indexed differently
            self.ac_buffs[agent_i][self.curr_i:self.curr_i + nentries] = actions[agent_i]
            self.rew_buffs[agent_i][self.curr_i:self.curr_i + nentries] = rewards[:, agent_i]
            self.next_obs_buffs[agent_i][self.curr_i:self.curr_i + nentries] = np.vstack(
                next_observations[:, agent_i])
            self.done_buffs[agent_i][self.curr_i:self.curr_i + nentries] = dones[:, agent_i]
        self.curr_i += nentries
        if self.filled_i < self.max_steps:
            self.filled_i += nentries
        if self.curr_i == self.max_steps:
            self.curr_i = 0
    def sample(self, N, to_gpu=False, norm_rews=True):
        """
        Sample N transitions uniformly without replacement; requires
        filled_i >= N.  Returns (acs_pre, obs, acs, rews, next_obs, dones)
        as lists of per-agent tensors.
        """
        inds = np.random.choice(np.arange(self.filled_i), size=N,
                                replace=False)
        if to_gpu:
            cast = lambda x: Variable(Tensor(x), requires_grad=False).to(torch.device(self.device))
        else:
            cast = lambda x: Variable(Tensor(x), requires_grad=False)
        if norm_rews:
            # Standardize rewards over the filled portion of the buffer.
            # NOTE(review): std() of constant rewards is 0 -> NaNs.
            ret_rews = [cast((self.rew_buffs[i][inds] -
                              self.rew_buffs[i][:self.filled_i].mean()) /
                             self.rew_buffs[i][:self.filled_i].std())
                        for i in range(self.num_agents)]
        else:
            ret_rews = [cast(self.rew_buffs[i][inds]) for i in range(self.num_agents)]
        return ([cast(self.ac_pre_buffs[i][inds]) for i in range(self.num_agents)],
                [cast(self.obs_buffs[i][inds]) for i in range(self.num_agents)],
                [cast(self.ac_buffs[i][inds]) for i in range(self.num_agents)],
                ret_rews,
                [cast(self.next_obs_buffs[i][inds]) for i in range(self.num_agents)],
                [cast(self.done_buffs[i][inds]) for i in range(self.num_agents)])
    def get_average_rewards(self, N):
        """Mean reward per agent over the last N stored steps."""
        if self.filled_i == self.max_steps:
            inds = np.arange(self.curr_i - N, self.curr_i)  # allow for negative indexing
        else:
            inds = np.arange(max(0, self.curr_i - N), self.curr_i)
        return [self.rew_buffs[i][inds].mean() for i in range(self.num_agents)]
class ReplayBuffer_RNN(object):
"""
Replay Buffer for multi-agent RL with parallel rollouts
"""
def __init__(self, max_steps, num_agents, obs_dims, ac_dims, ep_dims):
"""
Inputs:
max_steps (int): Maximum number of timepoints to store in buffer
num_agents (int): Number of agents in environment
obs_dims (list of ints): number of obervation dimensions for each
agent
ac_dims (list of ints): number of action dimensions for each agent
ep_dims (int): Number of steps in each episode
"""
self.max_steps = max_steps
self.num_agents = num_agents
self.obs_buffs = []
self.ac_buffs = []
self.rew_buffs = []
self.next_obs_buffs = []
self.done_buffs = []
for odim, adim in zip(obs_dims, ac_dims):
self.obs_buffs.append(np.zeros((max_steps, ep_dims, odim)))
self.ac_buffs.append(np.zeros((max_steps, ep_dims, adim)))
self.rew_buffs.append(np.zeros((max_steps, ep_dims)))
self.next_obs_buffs.append(np.zeros((max_steps, ep_dims, odim)))
self.done_buffs.append(np.zeros((max_steps, ep_dims)))
self.filled_i = 0 # index of first empty location in buffer (last index when full)
self.curr_i = 0 # current index to write to (ovewrite oldest data)
def __len__(self):
return self.filled_i
def push(self, observations_ep, actions_ep, rewards_ep, next_observations_ep, dones_ep):
nentries = observations_ep[0].shape[0] # handle multiple parallel environments
observations_ep, actions_ep, rewards_ep, next_observations_ep, dones_ep = \
np.array(observations_ep), np.array(actions_ep), np.array(rewards_ep),\
np.array(next_observations_ep), np.array(dones_ep)
if self.curr_i + nentries > self.max_steps:
rollover = self.max_steps - self.curr_i # num of indices to roll over
for agent_i in range(self.num_agents):
self.obs_buffs[agent_i] = np.roll(self.obs_buffs[agent_i],
rollover, axis=0)
self.ac_buffs[agent_i] = np.roll(self.ac_buffs[agent_i],
rollover, axis=0)
self.rew_buffs[agent_i] = np.roll(self.rew_buffs[agent_i],
rollover)
self.next_obs_buffs[agent_i] = np.roll(
self.next_obs_buffs[agent_i], rollover, axis=0)
self.done_buffs[agent_i] = np.roll(self.done_buffs[agent_i],
rollover)
self.curr_i = 0
self.filled_i = self.max_steps
for agent_i in range(self.num_agents):
for i in range(observations_ep[:,:,agent_i].shape[0]):
if i == 0:
ob_ep = np.expand_dims(np.vstack(observations_ep[:,:,agent_i][i]), 0)
ob_next_ep = np.expand_dims(np.vstack(next_observations_ep[:,:,agent_i][i]), 0)
else:
ob_ep = np.vstack((ob_ep, np.expand_dims(np.vstack(observations_ep[:,:,agent_i][i]), 0)))
ob_next_ep = np.vstack((ob_next_ep, np.expand_dims(np.vstack(next_observations_ep[:,:,agent_i][i]), 0)))
self.obs_buffs[agent_i][self.curr_i:self.curr_i + nentries, :] = ob_ep.transpose(1, 0, 2)
# actions are already batched by agent, so they are indexed differently
self.ac_buffs[agent_i][self.curr_i:self.curr_i + nentries, :] = actions_ep[:,:,0,:].transpose(1, 0, 2)
self.rew_buffs[agent_i][self.curr_i:self.curr_i + nentries, :] = rewards_ep[:, :, agent_i].transpose(1, 0)
self.next_obs_buffs[agent_i][self.curr_i:self.curr_i + nentries, :] = ob_next_ep.transpose(1, 0, 2)
self.done_buffs[agent_i][self.curr_i:self.curr_i + nentries, :] = dones_ep[:, :, agent_i].transpose(1, 0)
self.curr_i += nentries
if self.filled_i < self.max_steps:
self.filled_i += nentries
if self.curr_i == self.max_steps:
self.curr_i = 0
def sample(self, N, to_gpu=False, norm_rews=True):
inds = np.random.choice(np.arange(self.filled_i), size=N,
replace=False)
if to_gpu:
cast = lambda x: Variable(Tensor(x), requires_grad=False).to(torch.device(self.device))
else:
cast = lambda x: Variable(Tensor(x), requires_grad=False)
if norm_rews:
ret_rews = [cast((self.rew_buffs[i][inds] -
self.rew_buffs[i][:self.filled_i].mean()) /
self.rew_buffs[i][:self.filled_i].std())
for i in range(self.num_agents)]
else:
ret_rews = [cast(self.rew_buffs[i][inds]) for i in range(self.num_agents)]
return ([cast(self.obs_buffs[i][inds]) for i in range(self.num_agents)],
[cast(self.ac_buffs[i][inds]) for i in range(self.num_agents)],
ret_rews,
[cast(self.next_obs_buffs[i][inds]) for i in range(self.num_agents)],
[cast(self.done_buffs[i][inds]) for i in range(self.num_agents)])
def get_average_rewards(self, N):
    """Mean reward per agent over the last N entries written to the buffer."""
    if self.filled_i == self.max_steps:
        # Buffer has wrapped: negative start indices reach the array's tail.
        inds = np.arange(self.curr_i - N, self.curr_i)
    else:
        inds = np.arange(max(0, self.curr_i - N), self.curr_i)
    means = []
    for i in range(self.num_agents):
        means.append(self.rew_buffs[i][inds].mean())
    return means
================================================
FILE: examples/Social_Cognition/MAToM-SNN/MPE/utils/env_wrappers.py
================================================
"""
Modified from OpenAI Baselines code to work with multi-agent envs
"""
import numpy as np
from multiprocessing import Process, Pipe
from common.vec_env.vec_env import VecEnv, CloudpickleWrapper
def worker(remote, parent_remote, env_fn_wrapper):
    """Subprocess loop: build one env and serve commands arriving on ``remote``.

    Supported commands: step / reset / reset_task / get_spaces /
    get_agent_types / close.  Runs until 'close' is received.
    """
    parent_remote.close()
    env = env_fn_wrapper.x()
    while True:
        cmd, data = remote.recv()
        if cmd == 'step':
            ob, reward, done, info = env.step(data)
            if all(done):
                # Every agent finished: restart so the parent always
                # receives a valid next observation.
                ob = env.reset()
            remote.send((ob, reward, done, info))
        elif cmd == 'reset':
            remote.send(env.reset())
        elif cmd == 'reset_task':
            remote.send(env.reset_task())
        elif cmd == 'get_spaces':
            remote.send((env.observation_space, env.action_space))
        elif cmd == 'get_agent_types':
            if all(hasattr(a, 'adversary') for a in env.agents):
                remote.send(['adversary' if a.adversary else 'agent'
                             for a in env.agents])
            else:
                remote.send(['agent' for _ in env.agents])
        elif cmd == 'close':
            remote.close()
            break
        else:
            raise NotImplementedError
class SubprocVecEnv(VecEnv):
    """Vectorized env running each sub-environment in its own process.

    The parent talks to each worker over a Pipe; see `worker` for the
    command protocol.
    """
    def __init__(self, env_fns, spaces=None):
        """
        envs: list of gym environments to run in subprocesses
        """
        self.waiting = False
        self.closed = False
        nenvs = len(env_fns)
        # One pipe per env: parent keeps `remotes`, workers keep `work_remotes`.
        self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])
        self.ps = [Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
                   for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]
        for p in self.ps:
            p.daemon = True  # if the main process crashes, we should not cause things to hang
            p.start()
        for remote in self.work_remotes:
            # Close the child ends in the parent so only workers hold them.
            remote.close()
        # Query spaces/agent types from the first worker (assumed identical envs).
        self.remotes[0].send(('get_spaces', None))
        observation_space, action_space = self.remotes[0].recv()
        self.remotes[0].send(('get_agent_types', None))
        self.agent_types = self.remotes[0].recv()
        VecEnv.__init__(self, len(env_fns), observation_space, action_space)

    def step_async(self, actions):
        # Fan the per-env actions out without blocking on results.
        for remote, action in zip(self.remotes, actions):
            remote.send(('step', action))
        self.waiting = True

    def step_wait(self):
        # Block until every worker has answered the pending 'step'.
        results = [remote.recv() for remote in self.remotes]
        self.waiting = False
        obs, rews, dones, infos = zip(*results)
        return np.stack(obs), np.stack(rews), np.stack(dones), infos

    def reset(self):
        """Reset all sub-environments; returns stacked observations."""
        for remote in self.remotes:
            remote.send(('reset', None))
        return np.stack([remote.recv() for remote in self.remotes])

    def reset_task(self):
        """Reset the task in all sub-environments; returns stacked observations."""
        for remote in self.remotes:
            remote.send(('reset_task', None))
        return np.stack([remote.recv() for remote in self.remotes])

    def close(self):
        """Drain pending replies, tell workers to exit, and join them."""
        if self.closed:
            return
        if self.waiting:
            # Consume outstanding 'step' replies before shutting down.
            for remote in self.remotes:
                remote.recv()
        for remote in self.remotes:
            remote.send(('close', None))
        for p in self.ps:
            p.join()
        self.closed = True
class DummyVecEnv(VecEnv):
    """In-process vectorized wrapper that steps each env sequentially.

    Useful for debugging or a single environment, since it avoids all
    multiprocessing overhead.
    """
    def __init__(self, env_fns):
        self.envs = [fn() for fn in env_fns]
        first = self.envs[0]
        VecEnv.__init__(self, len(env_fns), first.observation_space, first.action_space)
        if all(hasattr(a, 'adversary') for a in first.agents):
            self.agent_types = ['adversary' if a.adversary else 'agent'
                                for a in first.agents]
        else:
            self.agent_types = ['agent' for _ in first.agents]
        self.ts = np.zeros(len(self.envs), dtype='int')
        self.actions = None

    def step_async(self, actions):
        # Nothing is actually asynchronous here; just stash the actions.
        self.actions = actions

    def step_wait(self):
        results = [env.step(act) for (act, env) in zip(self.actions, self.envs)]
        obs, rews, dones, infos = map(np.array, zip(*results))
        self.ts += 1
        for idx, done in enumerate(dones):
            if all(done):
                # Every agent finished in this env: restart it in place.
                obs[idx] = self.envs[idx].reset()
                self.ts[idx] = 0
        self.actions = None
        return np.array(obs), np.array(rews), np.array(dones), infos

    def reset(self):
        return np.array([env.reset() for env in self.envs])

    def close(self):
        return
================================================
FILE: examples/Social_Cognition/MAToM-SNN/MPE/utils/make_env.py
================================================
"""
Code for creating a multiagent environment with one of the scenarios listed
in ./scenarios/.
Can be called by using, for example:
env = make_env('simple_speaker_listener')
After producing the env object, can be used similarly to an OpenAI gym
environment.
A policy using this environment must output actions in the form of a list
for all agents. Each element of the list should be a numpy array,
of size (env.world.dim_p + env.world.dim_c, 1). Physical actions precede
communication actions in this array. See environment.py for more details.
"""
def make_env(scenario_name, benchmark=False, discrete_action=False):
    '''
    Build a MultiAgentEnv for the named scenario from ./scenarios/.

    Input:
        scenario_name : scenario module name (without the .py extension)
        benchmark     : whether to record benchmarking data
                        (usually only during evaluation)
        discrete_action : accepted for API compatibility; not used here.
    Some useful env properties (see environment.py):
        .observation_space : observation space for each agent
        .action_space      : action space for each agent
        .n                 : number of agents
    '''
    from multiagent.environment import MultiAgentEnv
    import multiagent.scenarios as scenarios

    # Instantiate the scenario class from its script and build its world.
    scenario = scenarios.load(scenario_name + ".py").Scenario()
    world = scenario.make_world()
    # Assemble the callbacks the environment drives the scenario with.
    callbacks = [scenario.reset_world, scenario.reward, scenario.observation]
    if benchmark:
        # Benchmarking adds a callback that records evaluation data.
        callbacks.append(scenario.benchmark_data)
    return MultiAgentEnv(world, *callbacks)
================================================
FILE: examples/Social_Cognition/MAToM-SNN/MPE/utils/misc.py
================================================
import os
import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch.autograd import Variable
import numpy as np
# https://github.com/ikostrikov/pytorch-ddpg-naf/blob/master/ddpg.py#L11
def soft_update(target, source, tau):
    """Polyak-average ``source`` into ``target``: t <- (1 - tau)*t + tau*s.

    Inputs:
        target (torch.nn.Module): net whose parameters are moved
        source (torch.nn.Module): net providing the new parameter values
        tau (float, 0 < x < 1): interpolation weight for the update
    """
    for tgt, src in zip(target.parameters(), source.parameters()):
        tgt.data.copy_(tgt.data * (1.0 - tau) + src.data * tau)
# https://github.com/ikostrikov/pytorch-ddpg-naf/blob/master/ddpg.py#L15
def hard_update(target, source):
    """Overwrite every parameter of ``target`` with ``source``'s values.

    Inputs:
        target (torch.nn.Module): net whose parameters are replaced
        source (torch.nn.Module): net providing the parameter values
    """
    for tgt, src in zip(target.parameters(), source.parameters()):
        tgt.data.copy_(src.data)
# https://github.com/seba-1511/dist_tuto.pth/blob/gh-pages/train_dist.py
def average_gradients(model):
    """Average gradients of ``model`` across all distributed workers.

    All-reduces (sum) every parameter's gradient and divides by the world
    size, so each worker ends up with the mean gradient.  Parameters that
    received no gradient (frozen / unused branches) are skipped.
    Requires torch.distributed to be initialized.
    """
    size = float(dist.get_world_size())
    for param in model.parameters():
        if param.grad is None:
            continue
        # `dist.reduce_op` and integer group ids are deprecated/removed in
        # modern PyTorch; ReduceOp.SUM on the default (world) group is the
        # supported equivalent of the old `op=reduce_op.SUM, group=0`.
        dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
        param.grad.data /= size
# https://github.com/seba-1511/dist_tuto.pth/blob/gh-pages/train_dist.py
def init_processes(rank, size, fn, backend='gloo'):
    """Set up torch.distributed for one worker, then run ``fn(rank, size)``.

    The rendezvous point is hard-coded to localhost:29500.
    """
    os.environ.update(MASTER_ADDR='127.0.0.1', MASTER_PORT='29500')
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size)
def onehot_from_logits(logits, eps=0.0):
    """Epsilon-greedy one-hot actions from a batch of logits.

    With probability 1 - eps each row gets a one-hot of its greedy action;
    otherwise a uniformly random one-hot.  Note: ties in the row maximum
    yield a multi-hot row, matching the comparison-based argmax below.
    """
    batch, n_acts = logits.shape[0], logits.shape[1]
    # Greedy actions, encoded by comparing each row against its maximum.
    greedy = (logits == logits.max(1, keepdim=True)[0]).float()
    if eps == 0.0:
        return greedy
    # Random one-hot rows drawn uniformly over the action set.
    picks = np.random.choice(range(n_acts), size=batch)
    random_acs = Variable(torch.eye(n_acts)[[picks]], requires_grad=False)
    chooser = torch.rand(batch)
    # Per-row epsilon-greedy mixture of greedy and random one-hots.
    return torch.stack([greedy[i] if chooser[i] > eps else random_acs[i]
                        for i in range(batch)])
# modified for PyTorch from https://github.com/ericjang/gumbel-softmax/blob/master/Categorical%20VAE.ipynb
def sample_gumbel(shape, eps=1e-20, tens_type=torch.FloatTensor):
    """Draw i.i.d. Gumbel(0, 1) noise of the given shape.

    ``eps`` guards both logarithms against log(0).
    """
    uniform = Variable(tens_type(*shape).uniform_(), requires_grad=False)
    return -torch.log(-torch.log(uniform + eps) + eps)
# modified for PyTorch from https://github.com/ericjang/gumbel-softmax/blob/master/Categorical%20VAE.ipynb
def gumbel_softmax_sample(logits, temperature):
    """Draw one sample from the Gumbel-Softmax (relaxed categorical) distribution."""
    noisy = logits + sample_gumbel(logits.shape,
                                   tens_type=type(logits.data)).to(logits.device)
    return F.softmax(noisy / temperature, dim=1)
# modified for PyTorch from https://github.com/ericjang/gumbel-softmax/blob/master/Categorical%20VAE.ipynb
def gumbel_softmax(logits, temperature=1.0, hard=False):
    """Sample from the Gumbel-Softmax distribution, optionally discretized.

    Args:
        logits: [batch_size, n_class] unnormalized log-probs
        temperature: non-negative scalar
        hard: if True, take argmax, but differentiate w.r.t. soft sample y
    Returns:
        [batch_size, n_class] sample from the Gumbel-Softmax distribution.
        If hard=True the sample is one-hot; otherwise it is a probability
        distribution summing to 1 across classes.
    """
    soft_sample = gumbel_softmax_sample(logits, temperature)
    if not hard:
        return soft_sample
    # Straight-through estimator: the forward pass uses the one-hot sample,
    # gradients flow through the soft sample.
    hard_sample = onehot_from_logits(soft_sample)
    return (hard_sample - soft_sample).detach() + soft_sample
================================================
FILE: examples/Social_Cognition/MAToM-SNN/MPE/utils/multiprocessing.py
================================================
# This code is from openai baseline
# https://github.com/openai/baselines/tree/master/baselines/common/vec_env
import time
import matplotlib.pyplot as plt
import numpy as np
from multiprocessing import Process, Pipe
def _flatten_list(l):
assert isinstance(l, (list, tuple))
assert len(l) > 0
assert all([len(l_) > 0 for l_ in l])
return [l__ for l_ in l for l__ in l_]
def worker(remote, parent_remote, env_fn_wrapper):
    """Child-process loop: build one env and serve commands from ``remote``.

    Supported commands: step, reset, reset_task, render, observe, agents,
    spec, get_spaces, close.  Loops until 'close' is received.
    """
    parent_remote.close()  # child keeps only its own end of the pipe
    env = env_fn_wrapper.x()  # unpickle the factory and build the env here
    while True:
        cmd, data = remote.recv()
        if cmd == 'step':
            ob, reward, done, info = env.step(data)
            if done:
                # Episode over: reset so the parent gets a fresh observation.
                ob = env.reset()
            remote.send((ob, reward, done, info))
        elif cmd == 'reset':
            ob = env.reset()
            remote.send(ob)
        elif cmd == 'reset_task':
            ob = env.reset_task()
            remote.send(ob)
        elif cmd == 'render':
            ob = env.render(mode='rgb_array')
            # print(len(ob), 'len(frames)')
            # print(len(ob[0]), 'len(frames[0])')
            # print(len(ob[0][0]), 'len(frames[0][0])')
            remote.send(ob)  # rgb_array
        elif cmd == 'observe':
            # Per-agent observation; `data` identifies the agent.
            ob = env.observe(data)
            remote.send(ob)
        elif cmd == 'agents':
            remote.send(env.agents)
        elif cmd == 'spec':
            remote.send(env.spec)
        elif cmd == 'get_spaces':
            remote.send((env.observation_space, env.action_space))
        elif cmd == 'close':
            remote.close()
            break
        else:
            raise NotImplementedError
class VecEnv(object):
    """
    An abstract asynchronous, vectorized environment.
    """
    closed = False  # set once close() has completed
    viewer = None   # lazily-created image viewer for 'human' rendering
    metadata = {
        'render.modes': ['human', 'rgb_array']
    }

    def __init__(self, num_envs, observation_space, action_space):
        self.num_envs = num_envs
        self.observation_space = observation_space
        self.action_space = action_space

    def observe(self, agent):
        # Optional per-agent observation hook; subclasses may override.
        pass

    def reset(self):
        """
        Reset all the environments and return an array of
        observations, or a tuple of observation arrays.
        If step_async is still doing work, that work will
        be cancelled and step_wait() should not be called
        until step_async() is invoked again.
        """
        pass

    def step_async(self, actions):
        """
        Tell all the environments to start taking a step
        with the given actions.
        Call step_wait() to get the results of the step.
        You should not call this if a step_async run is
        already pending.
        """
        pass

    def step_wait(self):
        """
        Wait for the step taken with step_async().
        Returns (obs, rews, dones, infos):
         - obs: an array of observations, or a tuple of
                arrays of observations.
         - rews: an array of rewards
         - dones: an array of "episode done" booleans
         - infos: a sequence of info objects
        """
        pass

    def close(self):
        """
        Clean up the environments' resources.
        """
        pass

    def step(self, actions):
        """Synchronous convenience wrapper: step_async then step_wait."""
        self.step_async(actions)
        return self.step_wait()

    def render(self, mode='human'):
        """Render all envs as one tiled frame; show it or return the array."""
        imgs = self.get_images()
        bigimg = self.tile_images(imgs)
        if mode == 'human':
            self.get_viewer().imshow(bigimg)
            return self.get_viewer().isopen
        elif mode == 'rgb_array':
            return bigimg
        else:
            raise NotImplementedError

    def get_images(self):
        """
        Return RGB images from each environment
        """
        raise NotImplementedError

    def get_viewer(self):
        # Create the viewer on first use so headless runs never import it.
        if self.viewer is None:
            from common import rendering
            self.viewer = rendering.SimpleImageViewer()
        return self.viewer

    def tile_images(self, img_nhwc):
        """
        Tile N images into one big PxQ image
        (P,Q) are chosen to be as close as possible, and if N
        is square, then P=Q.
        input: img_nhwc, list or array of images, ndim=4 once turned into array
            n = batch index, h = height, w = width, c = channel
        returns:
            bigim_HWc, ndarray with ndim=3
        """
        img_nhwc = np.asarray(img_nhwc)
        N, h, w, c = img_nhwc.shape
        H = int(np.ceil(np.sqrt(N)))    # rows of the tile grid
        W = int(np.ceil(float(N) / H))  # columns of the tile grid
        # Pad with black frames so the HxW grid is completely filled.
        img_nhwc = np.array(list(img_nhwc) + [img_nhwc[0] * 0 for _ in range(N, H * W)])
        img_HWhwc = img_nhwc.reshape(H, W, h, w, c)
        img_HhWwc = img_HWhwc.transpose(0, 2, 1, 3, 4)
        img_Hh_Ww_c = img_HhWwc.reshape(H * h, W * w, c)
        return img_Hh_Ww_c
class CloudpickleWrapper(object):
    """Pickle proxy for arbitrary payloads (e.g. env factory closures).

    multiprocessing serializes with plain pickle, which cannot handle
    lambdas/closures; this wrapper serializes its payload with cloudpickle
    instead, while deserializing with the standard pickle module.
    """
    def __init__(self, x):
        self.x = x

    def __getstate__(self):
        import cloudpickle
        # Serialize with cloudpickle so closures survive the trip.
        return cloudpickle.dumps(self.x)

    def __setstate__(self, ob):
        import pickle
        # cloudpickle output is a plain pickle stream, so stdlib pickle loads it.
        self.x = pickle.loads(ob)
class SubprocVecEnv(VecEnv):
    """Vectorized env that runs each sub-environment in its own process.

    Commands travel over pipes; see `worker` for the command protocol.
    """
    def __init__(self, env_fns, spaces=None):
        """
        envs_sc: list of gym environments to run in subprocesses
        """
        # self.venv = venv
        self.waiting = False
        self.closed = False
        nenvs = len(env_fns)
        self.nenvs = nenvs
        # One pipe per env: parent keeps `remotes`, workers keep `work_remotes`.
        self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])
        self.ps = [Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
                   for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]
        for p in self.ps:
            p.daemon = True  # if the main process crashes, we should not cause things to hang
            p.start()
        for remote in self.work_remotes:
            # Close the child ends in the parent so only workers hold them.
            remote.close()
        # Query spaces from the first worker (assumed identical envs).
        self.remotes[0].send(('get_spaces', None))
        observation_space, action_space = self.remotes[0].recv()
        VecEnv.__init__(self, len(env_fns), observation_space, action_space)

    def step_async(self, actions):
        # Fan the per-env actions out without blocking on results.
        for remote, action in zip(self.remotes, actions):  # the input of step() : action
            remote.send(('step', action))
        self.waiting = True

    def step_wait(self):
        # Block until every worker has answered the pending 'step'.
        results = [remote.recv() for remote in self.remotes]  # the output of step() : zip(*results)
        self.waiting = False
        obs, rews, dones, infos = zip(*results)
        return np.stack(obs), np.stack(rews), np.stack(dones), infos

    def step_wait_2(self):
        # Variant for envs whose step() replies with
        # (reward, done, cumulative_rewards) triples.
        results = [remote.recv() for remote in self.remotes]
        self.waiting = False
        reward, done, _cumulative_rewards = zip(*results)
        return reward, done, _cumulative_rewards

    def step_wait_3(self):
        # NOTE(review): identical to step_wait(); kept for caller compatibility.
        results = [remote.recv() for remote in self.remotes]  # the output of step() : zip(*results)
        self.waiting = False
        obs, rews, dones, infos = zip(*results)
        return np.stack(obs), np.stack(rews), np.stack(dones), infos

    def reset(self):
        """Reset all sub-environments; returns stacked observations."""
        for remote in self.remotes:
            remote.send(('reset', None))
        return np.stack([remote.recv() for remote in self.remotes])

    def agents(self):
        """Collect the agent lists of all sub-environments."""
        for remote in self.remotes:
            remote.send(('agents', None))
        return np.stack([remote.recv() for remote in self.remotes])

    def reset_task(self):
        for remote in self.remotes:
            remote.send(('reset_task', None))
        return np.stack([remote.recv() for remote in self.remotes])

    def spec(self):
        """Collect the env specs of all sub-environments."""
        for remote in self.remotes:
            remote.send(('spec', None))
        return np.stack([remote.recv() for remote in self.remotes])

    def get_images(self):
        """Ask every worker for an rgb_array frame; returns a list of frames."""
        # self._assert_not_closed()
        for pipe in self.remotes:
            pipe.send(('render', None))
        imgs = [pipe.recv() for pipe in self.remotes]
        # imgs = _flatten_list(imgs)
        return imgs

    def observe(self, agent):
        # NOTE(review): the loop variable shadows the `agent` argument, so this
        # assumes `agent` is a sequence with one entry per env — confirm callers.
        for remote, agent in zip(self.remotes, agent):
            remote.send(('observe', agent))
        return np.stack([remote.recv() for remote in self.remotes])

    # def render(self, mode='human'):
    #     return self.venv.render(mode=mode)

    def close(self):
        """Drain pending replies, tell workers to exit, and join them."""
        if self.closed:
            return
        if self.waiting:
            # Consume outstanding 'step' replies before shutting down.
            for remote in self.remotes:
                remote.recv()
        for remote in self.remotes:
            remote.send(('close', None))
        for p in self.ps:
            p.join()
        self.closed = True

    def __len__(self):
        return self.nenvs
def _flatten_list(l):
assert isinstance(l, (list, tuple))
assert len(l) > 0
assert all([len(l_) > 0 for l_ in l])
return [l__ for l_ in l for l__ in l_]
class DummyVecEnv(VecEnv):
    """Sequential, single-process vectorized wrapper over a list of envs."""
    def __init__(self, env_fns):
        self.envs = [fn() for fn in env_fns]
        first = self.envs[0]
        VecEnv.__init__(self, len(env_fns), first.observation_space, first.action_space)
        if all(hasattr(a, 'adversary') for a in first.agents):
            self.agent_types = ['adversary' if a.adversary else 'agent'
                                for a in first.agents]
        else:
            self.agent_types = ['agent' for _ in first.agents]
        self.ts = np.zeros(len(self.envs), dtype='int')
        self.actions = None

    def step_async(self, actions):
        # No real asynchrony; just remember the actions for step_wait().
        self.actions = actions

    def step_wait(self):
        results = [env.step(act) for (act, env) in zip(self.actions, self.envs)]
        obs, rews, dones, infos = map(np.array, zip(*results))
        self.ts += 1
        for idx, done in enumerate(dones):
            if all(done):
                # All agents finished: restart that env in place.
                obs[idx] = self.envs[idx].reset()
                self.ts[idx] = 0
        self.actions = None
        return np.array(obs), np.array(rews), np.array(dones), infos

    def reset(self):
        return np.array([env.reset() for env in self.envs])

    def close(self):
        return
================================================
FILE: examples/Social_Cognition/MAToM-SNN/MPE/utils/networks.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
from braincog.base.node.node import LIFNode
class MLPNetwork(nn.Module):
    """
    MLP network (can be used as value or policy)
    """
    def __init__(self, input_dim, out_dim, hidden_dim=64, nonlin=F.relu,
                 constrain_out=False, norm_in=True, discrete_action=True):
        """
        Inputs:
            input_dim (int): Number of dimensions in input
            out_dim (int): Number of dimensions in output
            hidden_dim (int): Number of hidden dimensions
            nonlin (PyTorch function): Nonlinearity to apply to hidden layers
            constrain_out (bool): squash outputs to (-1, 1) with tanh
                (only effective for continuous actions)
            norm_in (bool): apply BatchNorm1d to the raw inputs
            discrete_action (bool): if True, outputs are raw logits
                (softmaxed later by the caller)
        """
        super(MLPNetwork, self).__init__()
        if norm_in:  # normalize inputs
            self.in_fn = nn.BatchNorm1d(input_dim)
            # Start as the identity transform (weight=1, bias=0).
            self.in_fn.weight.data.fill_(1)
            self.in_fn.bias.data.fill_(0)
        else:
            self.in_fn = lambda x: x
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, out_dim)
        self.nonlin = nonlin
        if constrain_out and not discrete_action:
            # initialize small to prevent saturation
            self.fc3.weight.data.uniform_(-3e-3, 3e-3)
            # F.tanh is deprecated; torch.tanh is the supported equivalent.
            self.out_fn = torch.tanh
        else:  # logits for discrete action (will softmax later)
            self.out_fn = lambda x: x

    def forward(self, X):
        """
        Inputs:
            X (PyTorch Matrix): Batch of observations
        Outputs:
            out (PyTorch Matrix): Output of network (actions, values, etc)
        """
        h1 = self.nonlin(self.fc1(self.in_fn(X)))
        h2 = self.nonlin(self.fc2(h1))
        out = self.out_fn(self.fc3(h2))
        return out
class BCNoSpikingLIFNode(LIFNode):
    """LIF neuron used as a non-spiking readout.

    ``forward`` integrates the input current but returns the membrane
    potential directly instead of emitting spikes, giving a continuous
    output suitable for value/voltage readout layers.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, dv: torch.Tensor):
        # Accumulate input into the membrane potential, then expose the
        # potential itself (no threshold/spike/reset step is applied here).
        self.integral(dv)
        return self.mem
class SNNNetwork(nn.Module):
    """
    SNN network (can be used as value or policy or MLE)
    """
    def __init__(self, input_dim, out_dim, hidden_dim=64, node=LIFNode, time_window=16,
                 norm_in=True, output_style='sum'):
        """
        Inputs:
            input_dim (int): Number of dimensions in input
            out_dim (int): Number of dimensions in output
            hidden_dim (int): Number of hidden dimensions
            node: spiking neuron class used for the two hidden layers
            time_window (int): number of simulation timesteps per forward pass
            norm_in (bool): apply BatchNorm1d to the raw inputs
            output_style (str): 'sum' averages the readout over time;
                'voltage' reads the membrane potential of a non-spiking node
        """
        super(SNNNetwork, self).__init__()
        self._threshold = 0.5   # firing threshold of the hidden neurons
        self.v_reset = 0.0      # membrane reset potential
        self._time_window = time_window
        self.output_style = output_style
        self._node1 = node(threshold=self._threshold, v_reset=self.v_reset)
        self._node2 = node(threshold=self._threshold, v_reset=self.v_reset)
        if norm_in:  # normalize inputs
            self.in_fn = nn.BatchNorm1d(input_dim)  # train
            # Start as the identity transform (weight=1, bias=0).
            self.in_fn.weight.data.fill_(1)
            self.in_fn.bias.data.fill_(0)
        else:
            self.in_fn = lambda x: x
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, out_dim)
        if self.output_style == 'sum':
            self._out_node = lambda x: x
        elif self.output_style == 'voltage':
            # Non-spiking LIF: its membrane potential becomes the output.
            self._out_node = BCNoSpikingLIFNode()

    def reset(self):
        # Clear neuron state before a new simulation window.
        for mod in self.modules():
            if hasattr(mod, 'n_reset'):
                mod.n_reset()

    def forward(self, X):
        """Simulate ``time_window`` steps and aggregate the readout.

        Returns the time-averaged output ('sum') or the output of the last
        timestep's non-spiking node ('voltage').
        """
        qs = []
        self.reset()
        for t in range(self._time_window):
            # The +0.5 offset shifts normalized inputs into the neurons'
            # operating range — assumption; confirm against the encoder.
            x = self.fc1((self.in_fn(X)+0.5))  # train
            # x = self.fc1((X + 0.5))  # test
            x = self._node1(x)
            x = self.fc2(x)
            x = self._node2(x)
            x = self.fc3(x)
            x = self._out_node(x)
            qs.append(x)
        if self.output_style == 'sum':
            outputs = sum(qs) / self._time_window
            return outputs
        elif self.output_style == 'voltage':
            outputs = x
            return outputs
================================================
FILE: examples/Social_Cognition/MAToM-SNN/MPE/utils/noise.py
================================================
import numpy as np
# from https://github.com/songrotek/DDPG/blob/master/ou_noise.py
class OUNoise:
    """Ornstein-Uhlenbeck process for temporally correlated exploration noise."""
    def __init__(self, action_dimension, scale=0.1, mu=0, theta=0.15, sigma=0.2):
        self.action_dimension = action_dimension
        self.scale = scale
        self.mu = mu          # long-run mean the process reverts to
        self.theta = theta    # mean-reversion rate
        self.sigma = sigma    # scale of the random (Wiener) term
        self.state = np.ones(self.action_dimension) * self.mu
        self.reset()

    def reset(self):
        # Restart the process at its mean.
        self.state = np.ones(self.action_dimension) * self.mu

    def noise(self):
        prev = self.state
        drift = self.theta * (self.mu - prev)
        diffusion = self.sigma * np.random.randn(len(prev))
        self.state = prev + drift + diffusion
        return self.state * self.scale
================================================
FILE: examples/Social_Cognition/MAToM-SNN/README.md
================================================
# MAToM-SNN
================================================
FILE: examples/Social_Cognition/MAToM-SNN/STAG/agents/__init__.py
================================================
================================================
FILE: examples/Social_Cognition/MAToM-SNN/STAG/agents/sagent.py
================================================
import numpy as np
import torch
from torch.distributions import Categorical
from braincog.base.encoder.population_coding import PEncoder
from spikingjelly.activation_based import functional
TIMESTEPS = 15
M = 5
# Agent
class Agents:
    """Wraps a multi-agent policy: action selection, training, bookkeeping.

    The concrete policy (PPO/IQL/VDN variants, spiking or not) is chosen
    from ``args.alg`` and imported lazily from the matching module.
    """
    def __init__(self, args):
        self.n_actions = args.n_actions
        self.n_agents = args.n_agents
        self.obs_shape = args.obs_shape
        # encoder: population coding of observations over TIMESTEPS steps
        self.pencoder = PEncoder(TIMESTEPS, 'population_voltage')
        if args.alg == 'ppo':
            from policy.ppo import PPO
            self.policy = PPO(args)
        if args.alg == 'iql':
            from policy.iql import IQL
            self.policy = IQL(args)
            if args.mode == 'test':
                # Evaluation run: restore the checkpoint saved at step 395000.
                self.policy.load_model(395000)
        if args.alg == 'svdn':
            from policy.svdn import SVDN
            self.policy = SVDN(args)
        if args.alg == 'scovdn':
            from policy.scovdn import SCOVDN
            self.policy = SCOVDN(args)
        if args.alg == 'stomvdn':
            from policy.stomvdn import SToMVDN
            self.policy = SToMVDN(args)
        if args.alg == 'scovdn_weight':
            from policy.scovdn_weight import SCOVDN_W
            self.policy = SCOVDN_W(args)
        if args.alg == 'siql':
            from policy.siql import SIQL
            self.policy = SIQL(args)
        if args.alg == 'scoiql':
            from policy.scoiql import SCOIQL
            self.policy = SCOIQL(args)
        if args.alg == 'siql_e':
            from policy.siql_encoder import SIQL_E
            self.policy = SIQL_E(args)
        if args.alg == 'siql_e2':
            from policy.siql_encoder2 import SIQL_EE
            self.policy = SIQL_EE(args)
        if args.alg == 'siql_no_rnn':
            from policy.siql_no_rnn import SIQLUR
            self.policy = SIQLUR(args)
        if args.alg == 'siql_no_rnn2':
            from policy.siql_no_rnn2 import SIQLUR2
            self.policy = SIQLUR2(args)
        self.args = args

    def choose_action(self, num_env, obs, last_action, agent_num, avail_actions, epsilon, maven_z=None, evaluate=False):
        """Epsilon-greedy action selection for one agent across parallel envs.

        Returns a tensor of action indices (one per parallel env).
        """
        inputs = obs.copy()
        avail_actions_ind = np.nonzero(avail_actions)[0]  # indices of currently available actions
        # transform agent_num to onehot vector
        agent_id = np.zeros((num_env, self.n_agents))
        agent_id[:, agent_num] = 1.
        if self.args.last_action:
            # The policy input includes the previous action.
            inputs = np.hstack((inputs, last_action))
        if self.args.reuse_network:
            # One shared network for all agents: append the agent-id one-hot.
            inputs = np.hstack((inputs, agent_id))
        # transform the shape of inputs from (42,) to (1,42)
        inputs = torch.tensor(inputs, dtype=torch.float32).unsqueeze(0)  # torch.Size([1, 17])
        # init hidden tensor
        if self.args.alg == 'siql_e' or self.args.alg == 'siql_e2':
            # Encoder variants carry extra population/time axes in the hidden
            # state — assumed layout; confirm against the policy definition.
            h1_mem = self.policy.eval_h1_mem[:, agent_num, :, :, :, :]
            h1_spike = self.policy.eval_h1_spike[:, agent_num, :, :, :, :]
            h2_mem = self.policy.eval_h2_mem[:, agent_num, :, :, :, :]
            h2_spike = self.policy.eval_h2_spike[:, agent_num, :, :, :, :]
            # Population-encode the observation with M neurons per input dim.
            inputs_, _ = self.pencoder(inputs=inputs, num_popneurons=M, VTH=0.99)
            inputs = torch.transpose(inputs_, 0, 3)
            inputs = inputs.squeeze().unsqueeze(0)
        else:
            h1_mem = self.policy.eval_h1_mem[:, agent_num, :, :]
            h1_spike = self.policy.eval_h1_spike[:, agent_num, :, :]
            h2_mem = self.policy.eval_h2_mem[:, agent_num, :, :]
            h2_spike = self.policy.eval_h2_spike[:, agent_num, :, :]
        avail_actions = torch.tensor(avail_actions, dtype=torch.float32).unsqueeze(0)
        if self.args.cuda:
            inputs = inputs.cuda(self.args.device)
            h1_mem = h1_mem.cuda(self.args.device)
            h1_spike = h1_spike.cuda(self.args.device)
            h2_mem = h2_mem.cuda(self.args.device)
            h2_spike = h2_spike.cuda(self.args.device)
        # get q value
        if self.args.alg == 'siql_no_rnn' or self.args.alg == 'siql_no_rnn2':
            # Stateless SNN: clear neuron state instead of threading hidden tensors.
            self.policy.eval_snn.reset()
            q_value = self.policy.eval_snn(inputs)
            # functional.reset_net(self.policy_sc.eval_snn)
        else:
            # Recurrent SNN: write the updated hidden state back in place.
            q_value, self.policy.eval_h1_mem[:, agent_num, :], self.policy.eval_h1_spike[:, agent_num, :],\
            self.policy.eval_h2_mem[:, agent_num, :], self.policy.eval_h2_spike[:, agent_num, :]= \
                self.policy.eval_snn(inputs, h1_mem, h1_spike, h2_mem, h2_spike)
        # choose action from q value
        # q_value[avail_actions == 0.0] = - float("inf")
        if self.args.alg == 'siql_e' or self.args.alg == 'siql_e2':
            # Collapse the extra population/time axes of the encoder output.
            q_value = q_value.sum(dim=2)
            q_value = q_value.sum(dim=2)
        if np.random.uniform() < epsilon:
            # action = np.random.choice(avail_actions_ind)  # action is a plain int
            # Explore: one uniformly random available action per parallel env.
            action = torch.tensor([[np.random.choice(avail_actions_ind) for i in range(num_env)]])
        else:
            # Exploit: greedy action along the action dimension.
            action = torch.argmax(q_value, 2)
        return action

    def _choose_action_from_softmax(self, inputs, avail_actions, epsilon, evaluate=False):
        """
        :param inputs: q_value of all actions
        """
        action_num = avail_actions.sum(dim=1, keepdim=True).float().repeat(1, avail_actions.shape[-1])  # num of avail_actions
        # First convert the network output into a probability distribution via softmax.
        prob = torch.nn.functional.softmax(inputs, dim=-1)
        # add noise of epsilon
        prob = ((1 - epsilon) * prob + torch.ones_like(prob) * epsilon / action_num)
        prob[avail_actions == 0] = 0.0  # unavailable actions get probability 0
        # Translation of the note below: after zeroing unavailable actions the
        # probabilities no longer sum to 1; no renormalization is needed here
        # because torch.distributions.Categorical normalizes internally.  Note
        # that training does not use Categorical, so probabilities of executed
        # actions must be renormalized during training.
        """
        不能执行的动作概率为0之后,prob中的概率和不为1,这里不需要进行正则化,因为torch.distributions.Categorical
        会将其进行正则化。要注意在训练的过程中没有用到Categorical,所以训练时取执行的动作对应的概率需要再正则化。
        """
        if epsilon == 0 and evaluate:
            action = torch.argmax(prob)
        else:
            action = Categorical(prob).sample().long()
        return action

    def _get_max_episode_len(self, batch):
        """Length of the longest episode in the batch (in transitions)."""
        terminated = batch['TERMINATE']
        episode_num = terminated.shape[0]
        max_episode_len = 0
        for episode_idx in range(episode_num):
            for transition_idx in range(self.args.episode_limit):
                if terminated[episode_idx, transition_idx, 0] == 1:
                    if transition_idx + 1 >= max_episode_len:
                        max_episode_len = transition_idx + 1
                    break
        if max_episode_len == 0:  # guard: no episode terminated, so `terminated` has no 1
            max_episode_len = self.args.episode_limit
        return max_episode_len

    def train(self, batch, train_step, epsilon=None):  # coma needs epsilon for training
        """Truncate the batch to its longest episode, run one learning step,
        and periodically save the model."""
        # different episode has different length, so we need to get max length of the batch
        max_episode_len = self._get_max_episode_len(batch)
        for key in batch.keys():
            if key != 'z':
                batch[key] = batch[key][:, :max_episode_len]
        self.policy.learn(batch, max_episode_len, train_step, epsilon)
        if train_step > 0 and train_step % self.args.save_cycle == 0:
            self.policy.save_model(train_step)
================================================
FILE: examples/Social_Cognition/MAToM-SNN/STAG/common_sr/__init__.py
================================================
================================================
FILE: examples/Social_Cognition/MAToM-SNN/STAG/common_sr/arguments.py
================================================
import argparse
def get_common_args():
    """Build and parse the common command-line arguments.

    Returns the parsed namespace with an extra derived attribute
    ``device`` set to ``cuda:<num_run>``.
    """
    parser = argparse.ArgumentParser()
    ## multiprocessing
    parser.add_argument('--process', type=int, default=5, help='multiprocessing')
    ## the environment setting 'CLASSIC', 'HUNT', 'HARVEST', 'ESCALATION'
    parser.add_argument('--ENV', type=str, default='HUNT', help='the version of the game, choose from ["CLASSIC", "HUNT", "HARVEST", "ESCALATION"]')
    parser.add_argument('--env_name', type=str, default='stag_stay', help='the version of the game, choose from ["CLASSIC", "HUNT", "HARVEST", "ESCALATION"]')
    parser.add_argument('--obs_type', type=str, default='coords', help='Can be "image" for pixel-array based observations, or "coords" for just the entity coordinates')
    parser.add_argument('--forage_quantity', type=int, default=2, help='the number of trees')
    parser.add_argument('--opponent_policy', type=str, default='random', help='the poliocy of opponent')
    parser.add_argument('--replay_dir', type=str, default='', help='absolute path to save the replay')
    ## The alternative policy ################################################
    parser.add_argument('--num_run', type=int, default='4', help='the number of run')
    ## 'svdn', 'stomvdn'
    parser.add_argument('--alg', type=str, default='svdn', help='the algorithm to train the agent')
    parser.add_argument('--mode', type=str, default='train', help='the mode')
    parser.add_argument('--n_steps', type=int, default=1000000, help='total time steps')#2000000
    parser.add_argument('--n_episodes', type=int, default=2, help='the number of episodes before once training')
    parser.add_argument('--last_action', type=bool, default=True, help='whether to use the last action to choose action')
    parser.add_argument('--reuse_network', type=bool, default=True, help='whether to use one network_sc for all agents_sc')
    parser.add_argument('--gamma', type=float, default=0.99, help='discount factor')
    parser.add_argument('--epsilon', type=float, default=1.0, help='epsilon factor')
    ############# "Adam"
    parser.add_argument('--optimizer', type=str, default="RMS", help='optimizer')
    parser.add_argument('--evaluate_cycle', type=int, default=10, help='how often to evaluate the model')#5000
    parser.add_argument('--evaluate_epoch', type=int, default=6, help='number of the epoch to evaluate the agent')#32
    ## save weights->model/args->log/reward->result/plot
    parser.add_argument('--model_dir', type=str, default='./model', help='model directory of the policy_base')
    parser.add_argument('--result_dir', type=str, default='./result', help='result directory of the policy_base')
    parser.add_argument('--log_dir', type=str, default='./log', help='args directory')
    parser.add_argument('--plot_dir', type=str, default='./plot', help='args directory')
    parser.add_argument('--exp_dir', type=str, default='/exp_vdn', help='result directory of the policy_base')
    parser.add_argument('--save_model_dir', type=str, default='/199_rnn_net_params_hunt1.pkl', help='load weights and bias')
    parser.add_argument('--load_model', type=bool, default=False, help='whether to load the pretrained model')
    parser.add_argument('--evaluate', type=bool, default=False, help='whether to evaluate the model')
    parser.add_argument('--cuda', type=bool, default=True, help='whether to use the GPU') #True
    parser.add_argument('--mini_batch_size', type=int, default=250, help='whether to use the GPU')
    args = parser.parse_args()
    # Derive the CUDA device from the run index.  The original code called
    # parse_args() twice, registering --device only for the second parse —
    # a --device flag on the command line would already abort the first
    # parse, so the flag was dead and the double parse redundant.
    args.device = 'cuda:{}'.format(args.num_run)
    return args
# arguments of coma
def get_coma_args(args):
    """Attach the COMA-specific hyper-parameters to *args* and return it."""
    coma_defaults = {
        # network sizes and learning rates
        'rnn_hidden_dim': 64,
        'critic_dim': 128,
        'lr_actor': 1e-4,
        'lr_critic': 1e-3,
        # epsilon-greedy schedule (initial epsilon comes from get_common_args)
        'anneal_epsilon': 0.00064,
        'min_epsilon': 0.02,
        'epsilon_anneal_scale': 'episode',
        # lambda of the TD(lambda) return
        'td_lambda': 0.8,
        # checkpoint / target-network update periods (in train steps)
        'save_cycle': 5000,
        'target_update_cycle': 200,
        # gradient clipping bound to prevent gradient explosion
        'grad_norm_clip': 10,
    }
    for name, value in coma_defaults.items():
        setattr(args, name, value)
    return args
# arguments of vnd、 qmix、 qtran
def get_mixer_args(args):
    """Attach VDN / QMIX / QTRAN (and MAVEN) hyper-parameters to *args* and return it."""
    fixed = {
        # network sizes
        'rnn_hidden_dim': 64,
        'qmix_hidden_dim': 32,
        'two_hyper_layers': False,
        'hyper_hidden_dim': 64,
        'qtran_hidden_dim': 64,
        'ppo_hidden_size': 64,
        'lr': 5e-4,
        # epsilon-greedy floor; decay rate is derived below
        'min_epsilon': 0.05,
        'epsilon_anneal_scale': 'step',
        # number of train steps in one epoch
        'train_steps': 1,
        # experience replay
        'batch_size': 32,
        'buffer_size': int(5e3),
        # checkpoint / target-network update periods
        'save_cycle': 5000,
        'target_update_cycle': 200,
        # QTRAN lambdas
        'lambda_opt': 1,
        'lambda_nopt': 1,
        # gradient clipping bound to prevent gradient explosion
        'grad_norm_clip': 10,
        # MAVEN
        'noise_dim': 16,
        'lambda_mi': 0.001,
        'lambda_ql': 1,
        'entropy_coefficient': 0.001,
    }
    for name, value in fixed.items():
        setattr(args, name, value)
    # Linear decay from the initial args.epsilon down to min_epsilon
    # over ANNEAL_STEPS environment steps.
    ANNEAL_STEPS = 50000
    args.anneal_epsilon = (args.epsilon - args.min_epsilon) / ANNEAL_STEPS
    return args
================================================
FILE: examples/Social_Cognition/MAToM-SNN/STAG/common_sr/dummy_vec_env.py
================================================
import numpy as np
from .vec_env import VecEnv
from .util import copy_obs_dict, dict_to_obs, obs_space_info
class DummyVecEnv(VecEnv):
    """
    VecEnv that runs multiple environments sequentially, that is,
    the step and reset commands are sent to one environment at a time.
    Useful when debugging and when num_env == 1 (in the latter case,
    avoids communication overhead)
    """

    def __init__(self, env_fns):
        """
        Arguments:

        env_fns: iterable of callables that build environments
        """
        self.envs = [fn() for fn in env_fns]
        env = self.envs[0]
        VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space)
        obs_space = env.observation_space
        self.keys, shapes, dtypes = obs_space_info(obs_space)
        # Pre-allocated, batched observation buffers: one array per obs key.
        self.buf_obs = {k: np.zeros((self.num_envs,) + tuple(shapes[k]), dtype=dtypes[k]) for k in self.keys}
        # BUGFIX: np.bool was deprecated in NumPy 1.20 and removed in 1.24;
        # the builtin `bool` produces the identical boolean dtype.
        self.buf_dones = np.zeros((self.num_envs,), dtype=bool)
        self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
        self.buf_infos = [{} for _ in range(self.num_envs)]
        self.actions = None
        self.spec = self.envs[0].spec

    def step_async(self, actions):
        # Accept either a batched action container (len == num_envs) or,
        # for a single environment, a bare action that gets wrapped.
        listify = True
        try:
            if len(actions) == self.num_envs:
                listify = False
        except TypeError:
            pass
        if not listify:
            self.actions = actions
        else:
            assert self.num_envs == 1, "actions {} is either not a list or has a wrong size - cannot match to {} environments".format(actions, self.num_envs)
            self.actions = [actions]

    def step_wait(self):
        for e in range(self.num_envs):
            action = self.actions[e]
            obs, self.buf_rews[e], self.buf_dones[e], self.buf_infos[e] = self.envs[e].step(action)
            if self.buf_dones[e]:
                # Auto-reset finished environments so the returned obs is
                # always valid for the next step.
                obs = self.envs[e].reset()
            self._save_obs(e, obs)
        return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones),
                self.buf_infos.copy())

    def reset(self):
        for e in range(self.num_envs):
            obs = self.envs[e].reset()
            self._save_obs(e, obs)
        return self._obs_from_buf()

    def _save_obs(self, e, obs):
        # A single None key means the observation space is not a Dict space
        # (see obs_space_info); the whole obs goes into one buffer.
        for k in self.keys:
            if k is None:
                self.buf_obs[k][e] = obs
            else:
                self.buf_obs[k][e] = obs[k]

    def _obs_from_buf(self):
        return dict_to_obs(copy_obs_dict(self.buf_obs))

    def get_images(self):
        return [env.render(mode='rgb_array') for env in self.envs]

    def render(self, mode='human'):
        if self.num_envs == 1:
            return self.envs[0].render(mode=mode)
        else:
            return super().render(mode=mode)
================================================
FILE: examples/Social_Cognition/MAToM-SNN/STAG/common_sr/multiprocessing_env.py
================================================
# This code is from openai baseline
# https://github.com/openai/baselines/tree/master/baselines/common/vec_env
import numpy as np
from multiprocessing import Process, Pipe
def _flatten_list(l):
assert isinstance(l, (list, tuple))
assert len(l) > 0
assert all([len(l_) > 0 for l_ in l])
return [l__ for l_ in l for l__ in l_]
def worker(remote, parent_remote, env_fn_wrapper):
    """Subprocess loop: build the env from its wrapped factory, then serve
    commands arriving on *remote* until a 'close' command is received."""
    parent_remote.close()
    env = env_fn_wrapper.x()
    while True:
        cmd, data = remote.recv()
        if cmd == 'step':
            ob, reward, done, info = env.step(data)
            # Auto-reset once every sub-agent reports done.
            if np.array(done).all():
                ob = env.reset()
            remote.send((ob, reward, done, info))
        elif cmd == 'reset':
            remote.send(env.reset())
        elif cmd == 'reset_task':
            remote.send(env.reset_task())
        elif cmd == 'render':
            remote.send(env.render())  # rgb_array
        elif cmd == 'close':
            remote.close()
            break
        elif cmd == 'get_spaces':
            remote.send((env.observation_space, env.action_space))
        else:
            raise NotImplementedError
class VecEnv(object):
    """
    An abstract asynchronous, vectorized environment.

    Subclasses implement the asynchronous protocol (step_async / step_wait);
    step() composes the two into a synchronous call.
    """

    def __init__(self, num_envs, observation_space, action_space):
        self.num_envs = num_envs
        self.observation_space = observation_space
        self.action_space = action_space

    def reset(self):
        """
        Reset every environment and return an array of observations,
        or a tuple of observation arrays.

        Any in-flight step_async work is cancelled; step_wait() must not
        be called until step_async() is invoked again.
        """
        pass

    def step_async(self, actions):
        """
        Ask every environment to begin a step with the given actions.

        Retrieve the results with step_wait(); do not call again while a
        previous step_async run is still pending.
        """
        pass

    def step_wait(self):
        """
        Block until the step started with step_async() finishes.

        Returns (obs, rews, dones, infos):
         - obs: an array of observations, or a tuple of arrays of observations
         - rews: an array of rewards
         - dones: an array of "episode done" booleans
         - infos: a sequence of info objects
        """
        pass

    def close(self):
        """
        Clean up the environments' resources.
        """
        pass

    def step(self, actions):
        """Synchronous convenience wrapper: dispatch the step, then wait."""
        self.step_async(actions)
        return self.step_wait()

    def render(self, mode='human'):
        # Frames are fetched but tiling/display is not implemented here, so
        # nothing is returned (this matches the original behaviour).
        imgs = self.get_images()

    def get_images(self):
        """
        Return RGB images from each environment
        """
        raise NotImplementedError
class CloudpickleWrapper(object):
    """
    Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
    """

    def __init__(self, x):
        # x is typically an environment-constructor callable
        # (see SubprocVecEnv, which wraps each env_fn with this class).
        self.x = x

    def __getstate__(self):
        # Serialize with cloudpickle so closures/lambdas survive the trip
        # to the worker process; plain pickle cannot handle them.
        import cloudpickle
        return cloudpickle.dumps(self.x)

    def __setstate__(self, ob):
        # cloudpickle output is readable by the standard pickle module.
        import pickle
        self.x = pickle.loads(ob)
class SubprocVecEnv(VecEnv):
    """VecEnv that runs each environment in its own subprocess, talking to
    the workers over multiprocessing Pipes (see `worker` above)."""

    def __init__(self, env_fns, spaces=None):
        """
        env_fns: list of zero-argument callables, each building one gym
        environment; one subprocess is spawned per callable.
        """
        self.waiting = False
        self.closed = False
        nenvs = len(env_fns)
        self.nenvs = nenvs
        # One (parent, child) pipe pair per environment.
        self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])
        self.ps = [Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
                   for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]
        for p in self.ps:
            p.daemon = True  # if the main process crashes, we should not cause things to hang
            p.start()
        # The parent keeps only its ends; close the worker ends here.
        for remote in self.work_remotes:
            remote.close()
        # Ask the first worker for the spaces shared by all envs.
        self.remotes[0].send(('get_spaces', None))
        observation_space, action_space = self.remotes[0].recv()
        VecEnv.__init__(self, len(env_fns), observation_space, action_space)

    def step_async(self, actions):
        # Dispatch one action per worker; results are collected in step_wait().
        for remote, action in zip(self.remotes, actions):
            remote.send(('step', action))
        self.waiting = True

    def step_wait(self):
        results = [remote.recv() for remote in self.remotes]
        self.waiting = False
        obs, rews, dones, infos = zip(*results)
        return np.stack(obs), np.stack(rews), np.stack(dones), infos

    def reset(self):
        for remote in self.remotes:
            remote.send(('reset', None))
        return np.stack([remote.recv() for remote in self.remotes])

    def reset_task(self):
        for remote in self.remotes:
            remote.send(('reset_task', None))
        return np.stack([remote.recv() for remote in self.remotes])

    def get_images(self):
        # self._assert_not_closed()
        for pipe in self.remotes:
            pipe.send(('render', None))
        imgs = [pipe.recv() for pipe in self.remotes]
        # imgs = _flatten_list(imgs)
        return imgs

    def close(self):
        if self.closed:
            return
        # Drain any pending step results before asking workers to exit,
        # otherwise the 'close' command would interleave with step replies.
        if self.waiting:
            for remote in self.remotes:
                remote.recv()
        for remote in self.remotes:
            remote.send(('close', None))
        for p in self.ps:
            p.join()
        self.closed = True

    def __len__(self):
        return self.nenvs
================================================
FILE: examples/Social_Cognition/MAToM-SNN/STAG/common_sr/replay_buffer.py
================================================
import numpy as np
import threading
class ReplayBuffer:
    """Circular, episode-level replay buffer for the MARL trainers.

    Stores whole padded episodes (produced by RolloutWorker) keyed by
    upper-case field names. Writes are guarded by a lock so multiple rollout
    threads can store concurrently; reads (sample) are uniform with
    replacement over the currently stored episodes.
    """

    def __init__(self, args):
        """
        Args:
            args: parsed arguments; must provide n_actions, n_agents,
                obs_shape, buffer_size, episode_limit and alg (plus
                noise_dim for MAVEN).
        """
        self.args = args
        self.n_actions = self.args.n_actions
        self.n_agents = self.args.n_agents
        self.obs_shape = self.args.obs_shape
        self.size = self.args.buffer_size
        self.episode_limit = self.args.episode_limit
        # memory management for the circular buffer
        self.current_idx = 0   # next write position
        self.current_size = 0  # number of valid episodes stored
        # One pre-allocated array per field, indexed [episode, step, ...].
        self.buffers = {'O': np.empty([self.size, self.episode_limit, self.n_agents, self.obs_shape]),
                        'U': np.empty([self.size, self.episode_limit, self.n_agents, 1]),
                        'R': np.empty([self.size, self.episode_limit, self.n_agents, 1]),
                        'O_NEXT': np.empty([self.size, self.episode_limit, self.n_agents, self.obs_shape]),
                        'AVAIL_U': np.empty([self.size, self.episode_limit, self.n_agents, self.n_actions]),
                        'AVAIL_U_NEXT': np.empty([self.size, self.episode_limit, self.n_agents, self.n_actions]),
                        'U_ONEHOT': np.empty([self.size, self.episode_limit, self.n_agents, self.n_actions]),
                        'PADDED': np.empty([self.size, self.episode_limit, 1]),
                        'TERMINATE': np.empty([self.size, self.episode_limit, 1])
                        }
        # BUGFIX: store_episode writes buffers['z'] when alg == 'maven', but
        # no 'z' buffer was ever allocated, which raised KeyError. Allocate
        # it here; noise_dim is set by get_mixer_args (default 16).
        if getattr(self.args, 'alg', None) == 'maven':
            self.buffers['z'] = np.empty([self.size, getattr(self.args, 'noise_dim', 16)])
        # thread lock protecting concurrent store_episode calls
        self.lock = threading.Lock()

    def store_episode(self, episode_batch):
        """Append a batch of episodes, overwriting the oldest slots when full."""
        batch_size = episode_batch['O'].shape[0]  # episode_number
        with self.lock:
            idxs = self._get_storage_idx(inc=batch_size)
            # store the information field by field
            self.buffers['O'][idxs] = episode_batch['O']
            self.buffers['U'][idxs] = episode_batch['U']
            self.buffers['R'][idxs] = episode_batch['R']
            self.buffers['O_NEXT'][idxs] = episode_batch['O_NEXT']
            self.buffers['AVAIL_U'][idxs] = episode_batch['AVAIL_U']
            self.buffers['AVAIL_U_NEXT'][idxs] = episode_batch['AVAIL_U_NEXT']
            self.buffers['U_ONEHOT'][idxs] = episode_batch['U_ONEHOT']
            self.buffers['PADDED'][idxs] = episode_batch['PADDED']
            self.buffers['TERMINATE'][idxs] = episode_batch['TERMINATE']
            if self.args.alg == 'maven':
                self.buffers['z'][idxs] = episode_batch['z']

    def sample(self, batch_size):
        """Uniformly sample `batch_size` episodes (with replacement).

        Requires current_size > 0; np.random.randint raises ValueError on an
        empty buffer.
        """
        temp_buffer = {}
        idx = np.random.randint(0, self.current_size, batch_size)
        for key in self.buffers.keys():
            temp_buffer[key] = self.buffers[key][idx]
        return temp_buffer

    def _get_storage_idx(self, inc=None):
        """Reserve `inc` consecutive circular slots and return their indices.

        Returns a scalar index when inc == 1, otherwise an index array.
        """
        inc = inc or 1
        if self.current_idx + inc <= self.size:
            idx = np.arange(self.current_idx, self.current_idx + inc)
            self.current_idx += inc
        elif self.current_idx < self.size:
            # wrap around: use the tail of the buffer plus the start
            overflow = inc - (self.size - self.current_idx)
            idx_a = np.arange(self.current_idx, self.size)
            idx_b = np.arange(0, overflow)
            idx = np.concatenate([idx_a, idx_b])
            self.current_idx = overflow
        else:
            idx = np.arange(0, inc)
            self.current_idx = inc
        self.current_size = min(self.size, self.current_size + inc)
        if inc == 1:
            idx = idx[0]
        return idx
================================================
FILE: examples/Social_Cognition/MAToM-SNN/STAG/common_sr/srollout.py
================================================
import numpy as np
import torch
class RolloutWorker:
    """Collects episodes from (possibly vectorized) environments with the
    current agents and packages them for the ReplayBuffer.

    generate_episode() works on a batched/vectorized env (obs has a leading
    num_env dimension); generate_episode_sample() is the single-env variant.
    """

    def __init__(self, env, agents, args):
        self.env = env
        self.agents = agents
        self.episode_limit = args.episode_limit
        self.n_actions = args.n_actions
        self.n_agents = args.n_agents
        self.obs_shape = args.obs_shape
        self.args = args
        # epsilon-greedy state, annealed across episodes/steps
        self.epsilon = args.epsilon
        self.anneal_epsilon = args.anneal_epsilon
        self.min_epsilon = args.min_epsilon
        print('Init RolloutWorker')

    def generate_episode(self, episode_num=None, evaluate=False):
        """Roll out n_episodes (or evaluate_epoch) episodes on the vectorized env.

        Returns (EPISODE dict of stacked arrays, per-agent cumulative reward,
        win flag, total env steps collected).
        """
        if self.args.replay_dir != '' and evaluate and episode_num == 0:  # prepare for save replay of evaluation
            self.env.close()
        # Accumulators over all episodes, keyed like the ReplayBuffer fields.
        EPISODE = dict(
            O=[],
            U=[],
            R=[],
            O_NEXT=[],
            U_ONEHOT=[],
            AVAIL_U=[],
            AVAIL_U_NEXT=[],
            PADDED=[],
            TERMINATE=[],
        )
        NUM_EPISODES = self.args.n_episodes if evaluate == False else self.args.evaluate_epoch
        episode_num = 0 if evaluate == False else self.args.evaluate_epoch
        episode_reward = np.zeros((self.args.process, self.n_agents))
        for episode_idx in range(NUM_EPISODES):
            # Per-episode accumulators (each entry batched over num_env).
            o, u, r, avail_u, u_onehot, terminate, padded = [], [], [], [], [], [], []
            obs = self.env.reset()
            # Build agent B's view by swapping the first two coordinate pairs
            # (assumes columns 0..3 are the two agents' coordinates — TODO confirm
            # against the env's "coords" observation layout).
            obs1 = obs.copy()
            obs1[:, 0], obs1[:, 1], obs1[:, 2], obs1[:, 3] = \
                obs[:, 2], obs[:, 3], obs[:, 0], obs[:, 1]
            obs_ = (obs, obs1)
            obs_ = np.stack((obs, obs1), axis=0).transpose(1, 0, 2)
            num_env = obs.shape[0]
            last_action = np.zeros((self.args.n_agents, num_env, self.args.n_actions))
            self.agents.policy.init_hidden(1, num_env)
            terminated = False
            win_tag = False
            step = 0
            # No exploration during evaluation.
            epsilon = 0 if evaluate else self.epsilon
            if self.args.epsilon_anneal_scale == 'episode':
                epsilon = epsilon - self.anneal_epsilon if epsilon > self.min_epsilon else epsilon
            # One episode: up to episode_limit steps across num_env parallel envs.
            while not terminated and step < self.episode_limit:
                obs = np.array(obs_)  # A perspective, B perspective
                avail_action = [1] * self.args.n_actions  # all actions always available
                actions, avail_actions, actions_onehot = [], [], []
                for agent_id in range(self.n_agents):
                    action = self.agents.choose_action(num_env, obs[:, agent_id, :], last_action[agent_id],
                                                       agent_id, avail_action, epsilon, evaluate)
                    # generate a one-hot vector of the chosen action per env
                    action_onehot = np.zeros((num_env, self.args.n_actions))
                    for i in range(num_env):
                        action_onehot[i, action[0, i]] = 1
                    actions.append(action[0].cpu().numpy().tolist())
                    actions_onehot.append(action_onehot)
                    avail_actions.append(avail_action)
                    last_action[agent_id] = action_onehot
                actions = np.array(actions).transpose(1, 0)  # [num_env, num_agent]
                obs_, reward, done, info = self.env.step(actions=actions)
                if self.args.load_model == True:
                    # Visualize pre-trained behaviour when replaying a model.
                    self.env.render(mode="human")
                    print(reward)
                o.append(obs)
                u.append(np.expand_dims(actions, self.n_agents))
                u_onehot.append(actions_onehot)
                avail_u.append(avail_actions)
                r.append(np.expand_dims(reward, 2))
                terminate.append(np.expand_dims(np.array([terminated] * num_env), 1))
                padded.append(np.expand_dims(np.array([0.] * num_env), 1))
                episode_reward = episode_reward + reward
                step += 1
                if self.args.epsilon_anneal_scale == 'step':
                    epsilon = epsilon - self.anneal_epsilon if epsilon > self.min_epsilon else epsilon
            # Record the final observation so o_next lines up one step ahead.
            obs = np.array(obs_)
            o.append(obs)
            o_next = o[1:]
            o = o[:-1]
            # get avail_action for the last obs, because target_q needs
            # avail_action during training
            avail_actions = []
            for agent_id in range(self.n_agents):
                avail_action = [1] * self.args.n_actions
                avail_actions.append(avail_action)
            avail_u.append(avail_actions)
            avail_u_next = avail_u[1:]
            avail_u = avail_u[:-1]
            # If terminated before episode_limit, pad every field with zeros
            # (PADDED/TERMINATE marked 1 so the loss can mask these steps).
            for i in range(step, self.episode_limit):
                o.append(np.zeros((self.n_agents, self.obs_shape)))
                u.append(np.zeros([self.n_agents, 1]))
                r.append(np.zeros([self.n_agents, 1]))
                o_next.append(np.zeros((self.n_agents, self.obs_shape)))
                u_onehot.append(np.zeros((self.n_agents, self.n_actions)))
                avail_u.append(np.zeros((self.n_agents, self.n_actions)))
                avail_u_next.append(np.zeros((self.n_agents, self.n_actions)))
                padded.append([1.] * num_env)
                terminate.append([1.] * num_env)
            # Stack per-step lists into [num_env, step, ...] arrays per episode.
            EPISODE['O'].append(np.stack(o, axis=0).transpose(1, 0, 2, 3))
            EPISODE['U'].append(np.stack(u, axis=0).transpose(1, 0, 2, 3).astype(int))
            EPISODE['R'].append(np.stack(r, axis=0).transpose(1, 0, 2, 3))
            EPISODE['O_NEXT'].append(np.stack(o_next, axis=0).transpose(1, 0, 2, 3))
            EPISODE['U_ONEHOT'].append(np.stack(u_onehot, axis=0).transpose(2, 0, 1, 3))
            EPISODE['AVAIL_U'].append(np.ones(EPISODE['U_ONEHOT'][0].shape))
            EPISODE['AVAIL_U_NEXT'].append(np.ones(EPISODE['U_ONEHOT'][0].shape))
            EPISODE['PADDED'].append(np.stack(padded, axis=0).transpose(1, 0, 2))
            EPISODE['TERMINATE'].append(np.stack(terminate, axis=0).transpose(1, 0, 2))
        episode_reward = episode_reward.sum(0)
        # Merge the per-episode batches along the episode dimension.
        for i in EPISODE.keys():
            EPISODE[i] = np.concatenate(EPISODE[i], axis=0)
        step = step * self.args.n_episodes * num_env
        if not evaluate:
            self.epsilon = epsilon
        if evaluate and episode_num == self.args.evaluate_epoch and self.args.replay_dir != '':
            self.env.save_replay()
            self.env.close()
        return EPISODE, episode_reward, win_tag, step

    def generate_episode_sample(self, episodes, steps, episode_num=None, evaluate=False):
        """Single-environment rollout; stores the episode into the shared
        `episodes`/`steps` containers under `episode_num` (no return value)."""
        if self.args.replay_dir != '' and evaluate and episode_num == 0:  # prepare for save replay of evaluation
            self.env.close()
        o, u, r, avail_u, u_onehot, terminate, padded = [], [], [], [], [], [], []
        obs = self.env.reset()
        # A perspective, B perspective (env helper swaps the coordinates)
        obs_ = (obs, self.env.game._flip_coord_observation_perspective(obs))
        terminated = False
        win_tag = False
        step = 0
        episode_reward = (0, 0)  # cumulative rewards per agent
        last_action = np.zeros((self.args.n_agents, self.args.n_actions))
        self.agents.policy.init_hidden(1)
        # epsilon-greedy: no exploration during evaluation
        epsilon = 0 if evaluate else self.epsilon
        if self.args.epsilon_anneal_scale == 'episode':
            epsilon = epsilon - self.anneal_epsilon if epsilon > self.min_epsilon else epsilon
        while not terminated and step < self.episode_limit:
            obs = np.array(obs_)  # A perspective, B perspective
            avail_action = [1] * self.args.n_actions
            actions, avail_actions, actions_onehot = [], [], []
            for agent_id in range(self.n_agents):
                action = self.agents.choose_action(obs[agent_id], last_action[agent_id], agent_id,
                                                   avail_action, epsilon, evaluate)
                # generate a one-hot vector of the chosen action
                action_onehot = np.zeros(self.args.n_actions)
                action_onehot[action] = 1
                # BUGFIX: np.int was deprecated in NumPy 1.20 and removed in
                # 1.24; the builtin int is the documented replacement.
                actions.append(int(action))
                actions_onehot.append(action_onehot)
                avail_actions.append(avail_action)
                last_action[agent_id] = action_onehot
            obs_, reward, done, info = self.env.step(actions=actions)
            win_tag = True if terminated else False
            # save obs, actions, avail_actions, reward at time t
            o.append(obs)
            u.append(np.reshape(actions, [self.n_agents, 1]))
            u_onehot.append(actions_onehot)
            avail_u.append(avail_actions)
            r.append(np.reshape(reward, [self.n_agents, 1]))
            terminate.append([terminated])
            padded.append([0.])
            episode_reward = [episode_reward[i] + reward[i] for i in range(min(len(episode_reward), len(reward)))]
            step += 1
            if self.args.epsilon_anneal_scale == 'step':
                epsilon = epsilon - self.anneal_epsilon if epsilon > self.min_epsilon else epsilon
            if self.args.load_model == True:
                self.env.render(mode="human")
        # Record the final observation so o_next lines up one step ahead.
        obs = np.array(obs_)
        o.append(obs)
        o_next = o[1:]
        o = o[:-1]
        # get avail_action for the last obs, because target_q needs
        # avail_action during training
        avail_actions = []
        for agent_id in range(self.n_agents):
            avail_action = [1] * self.args.n_actions
            avail_actions.append(avail_action)
        avail_u.append(avail_actions)
        avail_u_next = avail_u[1:]
        avail_u = avail_u[:-1]
        # If terminated before episode_limit, pad every field with zeros.
        for i in range(step, self.episode_limit):
            o.append(np.zeros((self.n_agents, self.obs_shape)))
            u.append(np.zeros([self.n_agents, 1]))
            r.append(np.zeros([self.n_agents, 1]))
            o_next.append(np.zeros((self.n_agents, self.obs_shape)))
            u_onehot.append(np.zeros((self.n_agents, self.n_actions)))
            avail_u.append(np.zeros((self.n_agents, self.n_actions)))
            avail_u_next.append(np.zeros((self.n_agents, self.n_actions)))
            padded.append([1.])
            terminate.append([1.])
        episode = dict(o=o.copy(),
                       u=u.copy(),
                       r=r.copy(),
                       avail_u=avail_u.copy(),
                       o_next=o_next.copy(),
                       avail_u_next=avail_u_next.copy(),
                       u_onehot=u_onehot.copy(),
                       padded=padded.copy(),
                       terminated=terminate.copy()
                       )
        # Hand the collected episode back through the shared containers.
        episodes[episode_num] = episode
        steps[episode_num] = step
        # add an episode dimension to each field
        for key in episode.keys():
            episode[key] = np.array([episode[key]])
        if not evaluate:
            self.epsilon = epsilon
        if evaluate and episode_num == self.args.evaluate_epoch - 1 and self.args.replay_dir != '':
            self.env.save_replay()
            self.env.close()
================================================
FILE: examples/Social_Cognition/MAToM-SNN/STAG/envs/Stag_Hunt_env.py
================================================
import gym
import gym_stag_hunt
from ray import tune
from ray.rllib.env.wrappers.pettingzoo_env import PettingZooEnv
from gym_stag_hunt.envs.pettingzoo.hunt import raw_env
if __name__ == "__main__":
    # Register the PettingZoo stag-hunt environment with RLlib under a
    # fixed name so tune can instantiate it from the config below.
    def env_creator(args):
        return PettingZooEnv(raw_env(**args))

    tune.register_env("StagHunt-Hunt-PZ-v0", env_creator)

    # Baseline experiment: independent DQN per player, run via ray.tune.
    model = tune.run(
        "DQN",
        name="stag_hunt",
        stop={"episodes_total": 10000},
        checkpoint_freq=100,
        checkpoint_at_end=True,
        config={
            "horizon": 100,
            "framework": "tf2",
            # Environment specific
            "env": "StagHunt-Hunt-PZ-v0",
            # General
            "num_workers": 2,
            # Method specific: one policy per player, each trained on its
            # own experience (agent id doubles as policy id).
            "multiagent": {
                "policies": {"player_0", "player_1"},
                "policy_mapping_fn": (lambda agent_id, episode, **kwargs: agent_id),
                "policies_to_train": ["player_0", "player_1"]
            },
            # Env Specific
            "env_config": {
                "obs_type": "coords",
                "forage_reward": 1.0,
                "stag_reward": 5.0,
                "stag_follows": True,
                "mauling_punishment": -.5,
                "enable_multiagent": True,
            }
        }
    )
================================================
FILE: examples/Social_Cognition/MAToM-SNN/STAG/envs/__init__.py
================================================
from ToM2.envs.grid_env1 import *
================================================
FILE: examples/Social_Cognition/MAToM-SNN/STAG/envs/abstract.py
================================================
"""
Implements abstract class for meta-reinforcement learning environments.
"""
from typing import Generic, TypeVar, Tuple
import abc
ObsType = TypeVar('ObsType')


class MetaEpisodicEnv(abc.ABC, Generic[ObsType]):
    """Abstract interface for meta-reinforcement-learning environments.

    A meta-episodic environment distinguishes resetting the *state*
    (reset) from resampling the environment's *structure* — its transition
    probabilities and/or reward function (new_env).
    """

    @property
    @abc.abstractmethod
    def max_episode_len(self) -> int:
        """Maximum number of steps in one episode."""

    @abc.abstractmethod
    def new_env(self) -> None:
        """Resample the environment's structure — the state transition
        probabilities and/or the reward function — from a prior
        distribution.

        Returns:
            None
        """

    @abc.abstractmethod
    def reset(self) -> ObsType:
        """Reset the environment's state to a designated initial state.

        This is distinct from resetting the environment's structure,
        which is done via self.new_env().

        Returns:
            the initial observation.
        """

    @abc.abstractmethod
    def step(
            self,
            action: int,
            auto_reset: bool = True
    ) -> Tuple[ObsType, float, bool, dict]:
        """Step the environment.

        Args:
            action: integer action indicating which action to take.
            auto_reset: whether to automatically reset the environment on
                done; if true, the next observation comes from self.reset().

        Returns:
            next observation, reward, done flag, and info dict.
        """
================================================
FILE: examples/Social_Cognition/MAToM-SNN/STAG/envs/constants.py
================================================
"""
Global constants for the gridworld env.
"""
# =============================================================================
# set the value of interface
# =============================================================================
FPS = 25
WinWidth = 340    # window width in pixels
WinHeight = 260   # window height in pixels (was mislabelled "window width")
BoxSize = 20      # size of one grid cell in pixels
GridWidth = 7     # number of cells along the x-axis
GridHeight = 7    # number of cells along the y-axis
XMargin = int((WinWidth - GridWidth * BoxSize) / 2)
# BUGFIX: int() was applied before the division, so TopMargin ended up a
# float (55.0) while XMargin is an int. Apply int() to the whole quotient,
# mirroring XMargin; the numeric value (55) is unchanged.
TopMargin = int((WinHeight - GridHeight * BoxSize) / 2) - 5
# =============================================================================
# set color (RGB)
# =============================================================================
White = (255, 255, 255)
Gray = (185, 185, 185)
Black = (0, 0, 0)
Red = (255, 0, 0)
Green = (0, 128, 0)
SpringGreen = (60, 179, 113)
DarkOrange = (255, 140, 0)
RoyalBlue = (65, 105, 225)
DarkVoilet = (148, 0, 211)
HotPink = (255, 105, 180)
BoardColor = White
BGColor = White
TextColor = White
Test = []  # kept for compatibility; unused here but importable elsewhere
# =============================================================================
# set maps
# =============================================================================
"""
'S' : starting point
'F' or '.': free space
'W' or 'x': wall
'H' or 'o': hole (terminates episode)
'G' : goal
"""
Start = 'S'
Free_space = 'F'
Wall = 'W'
Danger = 'H'
Goal = 'G'
Shadow = 'Sh'
# Map layouts: one string per row, one tile character per cell.
MAPs = {
    0: [
        "FFFFF",
        "FHFWF",
        "FFFFF",
        "WFFFF",
        "FFFGF"
    ],
    1: [
        "FFFFF",
        "FHWFF",
        "FFFFF",
        "WFGFF",
        "FFFFF"
    ],
    2: [
        "FFFF",
        "FWFW",
        "FFFW",
        "WFFG"
    ],
    3: [
        "FFFF",
        "FHFW",
        "FFFW",
        "GFFF"
    ],
    4: [
        "FFFF",
        "FHFW",
        "FFFW",
        "WGFF"
    ],
}
================================================
FILE: examples/Social_Cognition/MAToM-SNN/STAG/main_spiking.py
================================================
from common_sr.arguments import get_common_args, get_coma_args, get_mixer_args
from common_sr.multiprocessing_env import SubprocVecEnv
from runner import Runner
from time import sleep
from gym_stag_hunt.envs.gym.escalation import EscalationEnv
from gym_stag_hunt.envs.gym.harvest import HarvestEnv
from gym_stag_hunt.envs.gym.hunt import HuntEnv
from gym_stag_hunt.envs.gym.simple import SimpleEnv
from gym_stag_hunt.src.games.abstract_grid_game import UP, LEFT, DOWN, RIGHT, STAND
import json
import os
os.environ["SDL_VIDEODRIVER"] = "dummy"  # headless pygame: render without a display

if __name__ == '__main__':
    # Map the --ENV argument onto the gym_stag_hunt environment classes.
    ENVS = {
        "CLASSIC": SimpleEnv,
        "HUNT": HuntEnv,
        "HARVEST": HarvestEnv,
        "ESCALATION": EscalationEnv,
    }
    args = get_common_args()
    args = get_mixer_args(args)
    # Per-environment dimensions: 5 discrete actions and 2 agents everywhere;
    # the observation length depends on the environment variant.
    if args.ENV == 'HUNT':
        args.n_actions = 5  # up, down, left, right or stand
        args.n_agents = 2
        args.obs_shape = 6 + args.forage_quantity * 2
    elif args.ENV == 'ESCALATION':
        args.n_actions = 5  # up, down, left, right or stand
        args.n_agents = 2
        args.obs_shape = 6
    elif args.ENV == 'HARVEST':
        args.n_actions = 5  # up, down, left, right or stand
        args.n_agents = 2
        args.obs_shape = 6 + args.forage_quantity * 5
    args.episode_limit = 50
    args.train_steps = 100
    save_path = args.log_dir + '/' + args.alg + args.exp_dir
    print(os.path.exists(save_path))
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    # Dump all parsed arguments to a text file for this run.
    argsDict = args.__dict__
    with open(save_path + '/args_{}'.format(args.num_run), 'w') as f:
        f.writelines('------------------ start ------------------' + '\n')
        for eachArg, value in argsDict.items():
            f.writelines(eachArg + ' : ' + str(value) + '\n')
        f.writelines('------------------- end -------------------')

    def make_env():
        # Factory-of-factory: SubprocVecEnv needs zero-argument callables it
        # can pickle and ship to the worker processes.
        def _thunk():
            if args.ENV == 'HUNT':
                env = ENVS[args.ENV](obs_type="coords", enable_multiagent=True, opponent_policy="random", \
                                     forage_quantity=args.forage_quantity, run_away_after_maul=True)
            elif args.ENV == 'ESCALATION':
                env = ENVS[args.ENV](obs_type="coords", enable_multiagent=True)
            elif args.ENV == 'HARVEST':
                env = ENVS[args.ENV](obs_type="coords", enable_multiagent=True)
            return env
        return _thunk

    # Spawn args.process parallel environments and hand them to the Runner.
    envs = [make_env() for i in range(args.process)]
    envs = SubprocVecEnv(envs)
    runner = Runner(envs, args)
    if not args.evaluate:
        runner.run(args.num_run)
    else:
        win_rate, _ = runner.evaluate()
        print('The win rate of {} is {}'.format(args.alg, win_rate))
    envs.close()
================================================
FILE: examples/Social_Cognition/MAToM-SNN/STAG/network/__init__.py
================================================
================================================
FILE: examples/Social_Cognition/MAToM-SNN/STAG/network/spiking_net.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as f
from torch.distributions import Normal
from braincog.base.node.node import IFNode, LIFNode
from braincog.base.strategy.surrogate import AtanGrad
# Hand-tuned spiking-neuron hyper-parameters shared by this module.
thresh = 0.3  # firing threshold used by ActFun's forward/backward
lens = 0.25   # half-width of ActFun's rectangular surrogate-gradient window
decay = 0.3   # membrane-potential decay factor used by mem_update()
TIMESTEPS = 15  # extra hidden-state axis used by Critic when alg == 'siql_e'
M = 5           # extra hidden-state axis used by Critic when alg == 'siql_e'
# BrainCog
class BCNoSpikingLIFNode(LIFNode):
    """BrainCog LIF neuron used as a non-spiking readout: it integrates the
    input current but returns the raw membrane potential instead of spikes."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, dv: torch.Tensor):
        # Integrate dv into the membrane (LIFNode.integral) and expose the
        # membrane potential directly — no thresholding/spike emission.
        self.integral(dv)
        return self.mem
class BCNoSpikingIFNode(IFNode):
    """BrainCog IF neuron used as a non-spiking readout: it integrates the
    input current but returns the raw membrane potential instead of spikes."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, dv: torch.Tensor):
        # Integrate dv into the membrane (IFNode.integral) and expose the
        # membrane potential directly — no thresholding/spike emission.
        self.integral(dv)
        return self.mem
# Sug
class ActFun(torch.autograd.Function):
    """Heaviside spike function with a rectangular surrogate gradient.

    Forward emits 1.0 wherever the membrane potential exceeds the module
    constant `thresh`; backward passes gradient only within `lens` of the
    threshold, scaled so the window integrates to one.
    """

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        # Hard threshold: 1.0 where input > thresh, else 0.0.
        return input.gt(thresh).float()

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # Rectangular window around the threshold; width 2*lens, height 1/(2*lens).
        window = (input - thresh).abs() < lens
        return grad_output.clone() * window.float() / (2 * lens)
#act_fun = ActFun.apply
# Surrogate spike function from BrainCog: arctan-shaped gradient with alpha=2.
act_fun = AtanGrad(alpha=2.,requires_grad=False)


def mem_update(fc, x, mem, spike):
    """One recurrent LIF update: decay the membrane (reset-by-multiplication
    for neurons that spiked), add the projected input fc(x), then re-threshold.

    Returns the new (mem, spike) pair.
    """
    # Neurons that spiked last step have their membrane zeroed via (1 - spike).
    mem = mem * decay * (1 - spike) + fc(x)
    #spike = act_fun(mem)
    # Threshold is fixed at 1 here (mem - 1), independent of `thresh` above.
    spike = act_fun(x=mem-1)
    return mem, spike
class Critic(nn.Module):
    """Spiking Q-network: an input projection, two recurrent LIF layers whose
    (mem, spike) state is threaded through successive forward calls like an
    RNN hidden state, and a non-spiking readout of per-action values."""

    def __init__(self, input_shape, args):
        super(Critic, self).__init__()
        self.args = args
        self.fc1 = nn.Linear(input_shape, args.ppo_hidden_size)
        self.fc2 = nn.Linear(args.rnn_hidden_dim, args.rnn_hidden_dim, bias = True)
        self.fc3 = nn.Linear(args.rnn_hidden_dim, args.rnn_hidden_dim, bias = True)
        self.fc4 = nn.Linear(args.rnn_hidden_dim, args.n_actions)#
        self.req_grad = False

    def forward(self, inputs, h1_mem, h1_spike, h2_mem, h2_spike):
        """Return (q_values, h1_mem, h1_spike, h2_mem, h2_spike); callers must
        feed the returned hidden state back in on the next step."""
        # if self.req_grad == False:
        # [1, 17] -> [1, process, 64]
        x = self.fc1(inputs)
        # x = IFNode()(x)
        # Encode the projected input into spikes (a fresh LIFNode per call).
        x = LIFNode()(x)
        if self.args.alg == 'siql_e':
            # Ensemble variant: hidden state carries extra (M, TIMESTEPS) axes.
            h1_mem = h1_mem.reshape(-1, M,TIMESTEPS, self.args.rnn_hidden_dim)
            h1_spike = h1_spike.reshape(-1, M,TIMESTEPS, self.args.rnn_hidden_dim)
            h2_mem = h2_mem.reshape(-1, M, TIMESTEPS, self.args.rnn_hidden_dim)
            h2_spike = h2_spike.reshape(-1, M, TIMESTEPS, self.args.rnn_hidden_dim)
        else:
            # [1, 64] -> [process, 64]
            h1_mem = h1_mem.reshape(-1, self.args.rnn_hidden_dim)
            h1_spike = h1_spike.reshape(-1, self.args.rnn_hidden_dim)
            h2_mem = h2_mem.reshape(-1, self.args.rnn_hidden_dim)
            h2_spike = h2_spike.reshape(-1, self.args.rnn_hidden_dim)
        # Two recurrent LIF layers (see mem_update for the dynamics).
        h1_mem, h1_spike = mem_update(self.fc2, x, h1_mem, h1_spike)
        h2_mem, h2_spike = mem_update(self.fc3, h1_spike, h2_mem, h2_spike)
        # Non-spiking readout: Q-values are the readout neuron's membrane.
        value = BCNoSpikingLIFNode(tau=2.0)(self.fc4(h2_mem))
        return value, h1_mem, h1_spike, h2_mem, h2_spike
class VDNNet(nn.Module):
    """Value-Decomposition Network mixer: the joint Q-value is the plain sum
    of the per-agent Q-values along dim 2 (kept for broadcasting)."""

    def __init__(self):
        super(VDNNet, self).__init__()

    def forward(self, q_values):
        return q_values.sum(dim=2, keepdim=True)
class Linear_weight(nn.Module):
    """Learnable convex blend of two stacked Q-value branches.

    The last input dimension holds the two branch estimates; forward returns
    alpha * x[..., 0] + (1 - alpha) * x[..., 1], with the blended axis
    re-inserted so the output shape matches a single branch.
    """

    def __init__(self, input_shape, out_shape, args):
        """
        Args:
            input_shape: kept for interface compatibility (unused — only the
                mixing weight is learned).
            out_shape: length of the per-element mixing-weight vector.
            args: parsed arguments; args.alg selects the input rank.
        """
        super(Linear_weight, self).__init__()
        self.args = args
        # BUGFIX: nn.Parameter(torch.Tensor(out_shape)) left the weight
        # uninitialised (arbitrary memory contents, nondeterministic mixing).
        # Start from an even 0.5/0.5 blend instead.
        self.alpha = nn.Parameter(torch.full((out_shape,), 0.5))

    def forward(self, x):
        if self.args.alg == 'scovdn_weight':
            # 4-D input: [..., branch]; blend and restore the branch axis.
            mixed = x[:, :, :, 0] * self.alpha + x[:, :, :, 1] * (1 - self.alpha)
            return mixed.unsqueeze(3)
        elif self.args.alg == 'stomvdn':
            # 5-D input variant.
            mixed = x[:, :, :, :, 0] * self.alpha + x[:, :, :, :, 1] * (1 - self.alpha)
            return mixed.unsqueeze(4)
        # Previously this silently returned None for other algorithms, which
        # surfaced later as an opaque TypeError; fail loudly instead.
        raise ValueError('Linear_weight does not support alg={!r}'.format(self.args.alg))
class BiasNet(nn.Module):
    """Spiking network producing a state-dependent scalar bias.

    forward() concatenates the (scaled) global state with the per-agent
    hidden activity summed over agents, encodes it with an IF neuron, runs
    one recurrent LIF layer, and reads the bias from a non-spiking IF
    neuron's membrane potential. reset(episode_num) must be called before
    the first forward pass to allocate the recurrent state.
    """

    def __init__(self, args):
        super(BiasNet, self).__init__()
        self.args = args
        input_shape = self.args.obs_shape + self.args.rnn_hidden_dim
        self.fc1 = nn.Linear(input_shape, args.rnn_hidden_dim)
        self.fc2 = nn.Linear(args.rnn_hidden_dim, args.rnn_hidden_dim, bias=True)
        self.fc3 = nn.Linear(args.rnn_hidden_dim, 1)

    def reset(self, episode_num):
        """(Re)allocate the recurrent membrane/spike state for a batch of
        `episode_num` episodes, on GPU when args.cuda is set."""
        self.h1_mem = self.h1_spike = torch.zeros(episode_num,
                                                  self.args.episode_limit, self.args.rnn_hidden_dim)
        if self.args.cuda:
            self.h1_mem = self.h1_mem.cuda(self.args.device)
            self.h1_spike = self.h1_spike.cuda(self.args.device)

    def forward(self, state, hidden):
        episode_num, max_episode_len, n_agents, _ = hidden.shape
        state = state.reshape(episode_num * max_episode_len, -1)
        state = state * 0.2  # damp the raw state before mixing with hidden activity
        # Sum the hidden representations over the agent dimension.
        hidden = \
            hidden.reshape(episode_num * max_episode_len, n_agents, -1).sum(dim=-2)
        inputs = torch.cat([state, hidden], dim=-1)
        x = self.fc1(inputs)
        # BUGFIX: this line called neuron.IFNode(), but no `neuron` name
        # exists anywhere in this file (a spikingjelly leftover), so forward
        # raised NameError at runtime. Use the braincog IFNode imported at
        # the top of the file, matching the commented-out intent that was
        # already here and Critic's LIFNode usage.
        x = IFNode()(x)
        # x = LIFNode()(x) #bad
        # flatten the recurrent state to [batch, rnn_hidden_dim]
        self.h1_mem = self.h1_mem.reshape(-1, self.args.rnn_hidden_dim)
        self.h1_spike = self.h1_spike.reshape(-1, self.args.rnn_hidden_dim)
        self.h1_mem, self.h1_spike = mem_update(self.fc2, x, self.h1_mem, self.h1_spike)
        # Read the bias from the non-spiking IF readout's membrane potential.
        value = BCNoSpikingIFNode(tau=2.0)(self.fc3(self.h1_mem))
        return value
================================================
FILE: examples/Social_Cognition/MAToM-SNN/STAG/policy/__init__.py
================================================
================================================
FILE: examples/Social_Cognition/MAToM-SNN/STAG/policy/dqn.py
================================================
import torch
import os
from network.base_net import RNN
class DQN:
    """Independent deep Q-learning policy for one agent.

    Wraps an evaluation RNN and a target RNN created by the caller, handles
    checkpoint load/save, and performs TD-learning on batches of whole
    episodes. (Comments below translate the original Chinese notes.)
    """

    def __init__(self, args, model_eval, model_target, agent_id):
        self.n_actions = args.n_actions
        self.n_agents = args.n_agents
        self.obs_shape = args.obs_shape
        self.agent_id = agent_id
        input_shape = self.obs_shape
        # The RNN input dimension depends on which extra features are appended.
        if args.last_action:
            input_shape += self.n_actions
        # if args.reuse_network:
        #     input_shape += self.n_agents
        # Networks: the eval net is trained, the target net supplies stable targets.
        self.eval_rnn = model_eval
        self.target_rnn = model_target
        self.args = args
        if self.args.cuda:
            self.eval_rnn.cuda(self.args.device)
            self.target_rnn.cuda(self.args.device)
        # self.model_dir = args.model_dir + '/' + args.alg
        # NOTE(review): hard-coded absolute path — only valid on the author's
        # machine; presumably should be built from args.model_dir. Confirm.
        self.model_dir = '/home/zhaozhuoya/exp2/MARL_test_exp/model' + '/' + args.alg + args.exp_dir#+ args.model_dir
        # Load a pretrained model if one exists.
        if self.args.load_model:
            if os.path.exists(self.model_dir + args.save_model_dir):
                # path_snn = '/home/zhaozhuoya/exp2/ToM2/model/iql/199_rnn_net_params.pkl'
                map_location = self.args.device if self.args.cuda else 'cpu'
                self.eval_rnn.load_state_dict(torch.load(self.model_dir + args.save_model_dir, map_location=map_location))
                # self.eval_rnn.load_state_dict(torch.load(self.model_dir))
                print('Successfully load the model: {}'.format(self.model_dir + args.save_model_dir))
            else:
                raise Exception("No model!")
        # Start the target net with the same parameters as the eval net.
        self.target_rnn.load_state_dict(self.eval_rnn.state_dict())
        self.eval_parameters = list(self.eval_rnn.parameters())
        if args.optimizer == "RMS":
            self.optimizer = torch.optim.RMSprop(self.eval_parameters, lr=args.lr)
        # During execution each agent keeps one eval_hidden; during learning
        # each (episode, agent) keeps both an eval_hidden and a target_hidden.
        self.eval_hidden = None
        self.target_hidden = None
        print('Init alg DQN')

    def learn(self, batch, max_episode_len, train_step, epsilon=None):
        """One TD-learning update on a batch of episodes.

        ``train_step`` counts updates and controls when the target network is
        refreshed. Batch tensors are 4-D: (episode, transition, agent,
        feature). Because the hidden state depends on history, whole episodes
        are sampled and transitions are fed to the network in chronological
        order rather than drawn at random.
        """
        episode_num = batch['O'].shape[0]
        self.init_hidden_learn(episode_num)
        # hidden_state = self.policy.eval_hidden[:, self.agent_id, :, :]
        eval_hidden, target_hidden = \
            self.eval_hidden[:, self.agent_id, :], self.target_hidden[:, self.agent_id, :]
        # Convert the batch contents to tensors.
        # NOTE(review): casting the observations 'O' to torch.long truncates
        # them to integers — the sibling policies cast the actions 'U'
        # instead; this looks like a key mix-up. Confirm against the buffer.
        for key in batch.keys():
            if key == 'O':
                batch[key] = torch.tensor(batch[key], dtype=torch.long)
            else:
                batch[key] = torch.tensor(batch[key], dtype=torch.float32)
        u, r, avail_u, avail_u_next, terminated = batch['U'], batch['R'].squeeze(-1), batch['AVAIL_U'], \
                                                  batch['AVAIL_U_NEXT'], batch['TERMINATE'].repeat(1, 1, self.n_agents)
        mask = (1 - batch["PADDED"].float()).repeat(1, 1, self.n_agents)  # zero out the TD-error of padded transitions
        # Per-agent Q-values, shaped (episode, max_episode_len, n_agents, n_actions).
        q_evals, q_targets = self.get_q_values(batch, max_episode_len, eval_hidden, target_hidden)
        if self.args.cuda:
            u = u.cuda(self.args.device)
            r = r.cuda(self.args.device)
            terminated = terminated.cuda(self.args.device)
            mask = mask.cuda(self.args.device)
        # Pick the Q-value of the taken action and drop the trailing singleton dim.
        u = u.to(torch.int64)
        q_evals = torch.gather(q_evals, dim=3, index=u[:, :, self.agent_id, :].unsqueeze(3)).squeeze(3)
        # Build the TD target: mask out unavailable actions, then max over actions.
        q_targets[avail_u_next[:, :, self.agent_id, :].unsqueeze(2) == 0.0] = - 9999999
        q_targets = q_targets.max(dim=3)[0]
        targets = r[:, :, self.agent_id].unsqueeze(2) + self.args.gamma * q_targets * (1 - terminated[:, :, self.agent_id].unsqueeze(2))
        td_error = (q_evals - targets.detach())
        masked_td_error = mask[:, :, self.agent_id].unsqueeze(2) * td_error  # wipe TD-errors of padded steps
        # mean() would also count the padded entries, so divide the sum by the
        # number of real transitions instead.
        loss = (masked_td_error ** 2).sum() / mask.sum()
        # print('loss is ', loss)
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.eval_parameters, self.args.grad_norm_clip)
        self.optimizer.step()
        if train_step > 0 and train_step % self.args.target_update_cycle == 0:
            self.target_rnn.load_state_dict(self.eval_rnn.state_dict())
        return loss

    def _get_inputs(self, batch, transition_idx):
        """Build per-agent network inputs for one transition of every episode."""
        # Fetch this transition across all episodes; u_onehot is taken whole
        # because the previous step's action is also needed.
        obs, obs_next, u_onehot = batch['O'][:, transition_idx], \
                                  batch['O_NEXT'][:, transition_idx], batch['U_ONEHOT'][:]
        episode_num = obs.shape[0]
        inputs, inputs_next = [], []
        inputs.append(obs[:, self.agent_id, :])
        inputs_next.append(obs_next[:, self.agent_id, :])
        # Append the previous action to the observation.
        if self.args.last_action:
            if transition_idx == 0:  # first transition: the previous action is a zero vector
                inputs.append(torch.zeros_like(u_onehot[:, :, self.agent_id, :][:, transition_idx]))
            else:
                inputs.append(u_onehot[:, :, self.agent_id, :][:, transition_idx - 1])
            inputs_next.append(u_onehot[:, :, self.agent_id, :][:, transition_idx])
        # (The reuse_network agent-id one-hot branch is intentionally disabled;
        # each agent here has its own network instance.)
        # Concatenate the feature pieces into a flat input vector per episode.
        inputs = torch.cat([x for x in inputs], dim=1)
        inputs_next = torch.cat([x for x in inputs_next], dim=1)
        return inputs, inputs_next

    def get_q_values(self, batch, max_episode_len, eval_hidden, target_hidden):
        """Roll both networks through the episode, collecting per-step Q-values."""
        # episode_num = batch['o'].shape[0]
        episode_num = batch['O'].shape[0]
        q_evals, q_targets = [], []
        for transition_idx in range(max_episode_len):
            inputs, inputs_next = self._get_inputs(batch, transition_idx)  # obs + last_action (+ agent id)
            if self.args.cuda:
                inputs = inputs.cuda(self.args.device)
                inputs_next = inputs_next.cuda(self.args.device)
                eval_hidden = eval_hidden.cuda(self.args.device)
                target_hidden = target_hidden.cuda(self.args.device)
            # NOTE(review): the updated hidden state is stored on self.* while
            # the loop keeps feeding the stale local eval_hidden/target_hidden —
            # confirm whether the recurrent state is meant to carry across steps.
            q_eval, self.eval_hidden = self.eval_rnn(inputs, eval_hidden)  # inputs (40, 96) -> q_eval (40, n_actions)
            q_target, self.target_hidden = self.target_rnn(inputs_next, target_hidden)
            # Reshape q_eval back to (episode_num, 1, n_actions).
            q_eval = q_eval.view(episode_num, 1, -1)
            q_target = q_target.view(episode_num, 1, -1)
            q_evals.append(q_eval)
            q_targets.append(q_target)
        # Stack the per-step lists into (episode, max_episode_len, 1, n_actions).
        q_evals = torch.stack(q_evals, dim=1)
        q_targets = torch.stack(q_targets, dim=1)
        return q_evals, q_targets

    def init_hidden(self, episode_num, num_env):
        # One eval_hidden/target_hidden per (episode, agent, env) for rollout.
        self.eval_hidden = torch.zeros((episode_num, self.n_agents, num_env, self.args.rnn_hidden_dim))
        self.target_hidden = torch.zeros((episode_num, self.n_agents, num_env, self.args.rnn_hidden_dim))

    def init_hidden_learn(self, episode_num):
        # One eval_hidden/target_hidden per (episode, agent) for learning.
        self.eval_hidden = torch.zeros((episode_num, self.n_agents, self.args.rnn_hidden_dim))
        self.target_hidden = torch.zeros((episode_num, self.n_agents, self.args.rnn_hidden_dim))

    def save_model(self, train_step):
        """Checkpoint the eval net, numbered by save-cycle count."""
        num = str(train_step // self.args.save_cycle)
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        torch.save(self.eval_rnn.state_dict(), self.model_dir + '/' + num + '_rnn_net_params.pkl')

    def load_model(self, train_step):
        """Load the checkpoint that save_model() wrote for this cycle."""
        num = str(train_step // self.args.save_cycle)
        path = torch.load(self.model_dir + '/' + num + '_rnn_net_params.pkl')
        self.eval_rnn.load_state_dict(path)
================================================
FILE: examples/Social_Cognition/MAToM-SNN/STAG/policy/stomvdn.py
================================================
import torch
import os
from network.spiking_net import Critic, VDNNet, Linear_weight, BiasNet
import copy
class SToMVDN:
    """Spiking Theory-of-Mind VDN trainer.

    Trains per-agent spiking critics plus a VDN mixer, a bias (state-value)
    head and a trade-off weighting net, combining three losses:
    L_self/other (TD on own + other-agent Q), L_coma and L_sum.
    (Comments below translate the original Chinese notes.)
    """

    def __init__(self, args):
        self.n_actions = args.n_actions
        self.n_agents = args.n_agents
        self.obs_shape = args.obs_shape
        input_shape = self.obs_shape
        # The network input dimension depends on the configured features.
        if args.last_action:
            input_shape += self.n_actions
        if args.reuse_network:
            input_shape += self.n_agents
        self.loss_trade_off_target = 0
        self.loss_trade_off_eval = 0
        # Networks: eval nets are trained, target nets provide stable targets.
        self.eval_snn = Critic(input_shape, args)
        self.target_snn = Critic(input_shape, args)
        self.eval_vdn_snn = VDNNet()  # sums the agents' Q-values
        self.target_vdn_snn = VDNNet()
        self.bias_net = BiasNet(args)
        self.trade_off_net = Linear_weight(2, 1, args)
        self.args = args
        if self.args.cuda:
            self.eval_snn.cuda(self.args.device)
            self.target_snn.cuda(self.args.device)
            self.eval_vdn_snn.cuda(self.args.device)
            self.target_vdn_snn.cuda(self.args.device)
            self.trade_off_net.cuda(self.args.device)
            self.bias_net.cuda(self.args.device)
        self.model_dir = args.model_dir + '/' + args.alg + args.exp_dir + args.save_model_dir
        # Load a pretrained model if one exists.
        if self.args.load_model:
            if os.path.exists(self.model_dir):
                # path_snn = '/home/zhaozhuoya/exp2/ToM2_test/model/siql/199_snn_net_params.pkl'
                map_location = self.args.device if self.args.cuda else 'cpu'
                self.eval_snn.load_state_dict(torch.load(self.model_dir, map_location=map_location))
                print('Successfully load the model: {}'.format(self.model_dir))
            else:
                print(self.model_dir)
                raise Exception("No model!")
        # Start the target nets with the same parameters as the eval nets.
        self.target_snn.load_state_dict(self.eval_snn.state_dict())
        self.target_vdn_snn.load_state_dict(self.eval_vdn_snn.state_dict())
        self.eval_parameters = list(self.eval_snn.parameters()) + \
                               list(self.eval_vdn_snn.parameters()) + \
                               list(self.trade_off_net.parameters()) + \
                               list(self.bias_net.parameters())
        self.trade_off_parameters = list(self.trade_off_net.parameters())
        if args.optimizer == "RMS":
            self.optimizer = torch.optim.RMSprop(self.eval_parameters, lr=args.lr)
            self.optimizer_T = torch.optim.RMSprop(self.trade_off_parameters, lr=args.lr)
        # During execution each agent keeps its own membrane/spike state;
        # during learning there is one per (episode, agent).
        self.eval_h1_mem, self.eval_h1_spike = None, None
        self.target_h1_mem, self.target_h1_spike = None, None
        self.eval_h2_mem, self.eval_h2_spike = None, None
        self.target_h2_mem, self.target_h2_spike = None, None
        print('Init alg SCOVDN_ToM')

    def learn(self, batch, max_episode_len, train_step, epsilon=None):
        """One combined-loss update on a batch of episodes.

        ``train_step`` counts updates and controls when the target networks
        are refreshed. Batch tensors are 4-D: (episode, transition, agent,
        feature). Whole episodes are sampled so the spiking hidden state can
        be threaded through time in chronological order.
        """
        episode_num = batch['O'].shape[0]
        self.init_hidden_learn(episode_num)
        self.bias_net.reset(episode_num)
        # Convert the batch contents to tensors; actions must be integer-typed.
        for key in batch.keys():
            if key == 'U':  # 'O'
                batch[key] = torch.tensor(batch[key], dtype=torch.long)
            else:
                batch[key] = torch.tensor(batch[key], dtype=torch.float32)
        u, r, avail_u, avail_u_next, terminated = batch['U'], batch['R'].squeeze(-1), batch['AVAIL_U'], \
                                                  batch['AVAIL_U_NEXT'], batch['TERMINATE'].repeat(1, 1, self.n_agents)
        mask = (1 - batch["PADDED"].float()).repeat(1, 1, self.n_agents)  # zero out the TD-error of padded transitions
        if self.args.cuda:
            u = u.cuda(self.args.device)
            r = r.cuda(self.args.device)
            terminated = terminated.cuda(self.args.device)
            mask = mask.cuda(self.args.device)
            # self.bias_net.cuda(self.args.device)
        u = u.to(torch.int64)
        # ------------------------- independent Q net --------------------------
        # Per-agent Q-values, shaped (episode, max_episode_len, n_agents, n_actions).
        q_evals, q_targets, hidden_evals, hidden_targets = self.get_q_values(batch, max_episode_len)
        # ----------------------------- bias net --------------------------------
        # State-value baseline per (episode, step).
        v = self.get_bias(batch, hidden_evals, hidden_targets, episode_num)
        # ------------------------ Q_self + Q_other net -------------------------
        # Index [1, 0] swaps the two agents, pairing each agent's Q with the
        # other agent's estimate (the theory-of-mind term).
        q_other_evals, q_other_targets = q_evals[:, :, [1, 0], :].unsqueeze(4), q_targets[:, :, [1, 0], :].unsqueeze(4)
        q_evals_, q_targets_ = q_evals.unsqueeze(4), q_targets.unsqueeze(4)
        q_total_evals = torch.cat((q_evals_, q_other_evals), 4)
        q_total_targets = torch.cat((q_targets_, q_other_targets), 4)
        # q_total_evals = self.trade_off_net(q_total_evals) #([10, 50, 2, 5, 1])
        # q_total_targets = self.trade_off_net(q_total_targets) # ([10, 50, 2, 5, 1])
        q_total_evals = q_evals_ + q_other_targets
        q_total_targets = q_targets_ + q_other_targets
        # ----------------------------- L_self/other ----------------------------
        q_total_targets[avail_u_next == 0.0] = - 9999999
        q_total_targets = q_total_targets.max(dim=3)[0].squeeze()
        q_total_evals = torch.gather(q_total_evals.squeeze(4), dim=3, index=u).squeeze(3)
        y = r + self.args.gamma * q_total_targets * (1 - terminated)
        td_error = q_total_evals - y.detach()
        l_so = ((td_error * mask) ** 2).sum() / mask.sum()
        # ---------------------------- action_prob_Q ----------------------------
        # Probability of each action under the eval net (fixed epsilon 0.4).
        action_prob = self._get_action_prob(batch, max_episode_len, 0.4)
        pi_taken = torch.gather(action_prob, dim=3, index=u).squeeze(3)  # prob of the taken action
        pi_taken[mask == 0] = 1.0  # padded steps get prob 1 so log() stays finite
        log_pi_taken = torch.log(pi_taken)
        # ------------------------------- L_coma --------------------------------
        # q_evals = torch.gather(q_evals * action_prob, dim=3, index=u).squeeze(3)
        q_evals_coma = (q_evals * action_prob).sum(dim=3, keepdim=True).squeeze(3)
        coma_error = q_evals_coma.sum(dim=-1) - q_total_targets.detach().sum(dim=-1) + v
        l_coma = ((coma_error * mask[:, :, 0]) ** 2).sum() / mask[:, :, 0].sum()
        # ------------------------------- L_sum ---------------------------------
        q_evals_sum = self.eval_vdn_snn(q_evals)
        sum_error = q_evals_sum.sum(dim=-1).squeeze(2) - q_total_targets.detach().sum(dim=-1) + v
        l_sum = ((sum_error * mask[:, :, 0]) ** 2).sum() / mask[:, :, 0].sum()
        LOSS = l_so + l_coma + l_sum
        self.optimizer.zero_grad()
        LOSS.backward()
        # FIX: the gradients were computed but never applied — without the
        # optimizer step no parameter was ever updated. Clip + step mirrors
        # the sibling SVDN/DQN learn() implementations.
        torch.nn.utils.clip_grad_norm_(self.eval_parameters, self.args.grad_norm_clip)
        self.optimizer.step()
        if train_step > 0 and train_step % self.args.target_update_cycle == 0:
            self.target_snn.load_state_dict(self.eval_snn.state_dict())
            self.target_vdn_snn.load_state_dict(self.eval_vdn_snn.state_dict())
        return LOSS

    def _get_inputs(self, batch, transition_idx):
        """Assemble network inputs for one transition across all episodes."""
        # Fetch this transition across all episodes; u_onehot is taken whole
        # because the previous step's action is also needed.
        obs, obs_next, u_onehot = batch['O'][:, transition_idx], \
                                  batch['O_NEXT'][:, transition_idx], batch['U_ONEHOT'][:]
        episode_num = obs.shape[0]
        inputs, inputs_next = [], []
        inputs.append(obs)
        inputs_next.append(obs_next)
        # Append the previous action and, optionally, the agent id.
        if self.args.last_action:
            if transition_idx == 0:  # first transition: the previous action is a zero vector
                inputs.append(torch.zeros_like(u_onehot[:, transition_idx]))
            else:
                inputs.append(u_onehot[:, transition_idx - 1])
            inputs_next.append(u_onehot[:, transition_idx])
        if self.args.reuse_network:
            # All agents share one network, so each row is tagged with a
            # one-hot agent id — rows of the identity matrix.
            inputs.append(torch.eye(self.args.n_agents).unsqueeze(0).expand(episode_num, -1, -1))
            inputs_next.append(torch.eye(self.args.n_agents).unsqueeze(0).expand(episode_num, -1, -1))
        # Merge (episode, agent) into one batch axis and concat the features.
        inputs = torch.cat([x.reshape(episode_num * self.args.n_agents, -1) for x in inputs], dim=1)
        inputs_next = torch.cat([x.reshape(episode_num * self.args.n_agents, -1) for x in inputs_next], dim=1)
        return inputs, inputs_next

    def get_q_values(self, batch, max_episode_len):
        """Roll eval/target SNNs through the episode; also return the layer-2
        membrane potentials, which the bias net consumes as hidden features."""
        episode_num = batch['O'].shape[0]
        q_evals, q_targets, eval_h2_mems, target_h2_mems = [], [], [], []
        for transition_idx in range(max_episode_len):
            inputs, inputs_next = self._get_inputs(batch, transition_idx)  # obs + last_action + agent id
            if self.args.cuda:
                inputs = inputs.cuda(self.args.device)
                inputs_next = inputs_next.cuda(self.args.device)
                self.eval_h1_mem, self.eval_h1_spike, self.eval_h2_mem, self.eval_h2_spike = \
                    self.eval_h1_mem.cuda(self.args.device), self.eval_h1_spike.cuda(
                        self.args.device), self.eval_h2_mem.cuda(self.args.device), self.eval_h2_spike.cuda(
                        self.args.device)
                self.target_h1_mem, self.target_h1_spike, self.target_h2_mem, self.target_h2_spike = \
                    self.target_h1_mem.cuda(self.args.device), self.target_h1_spike.cuda(
                        self.args.device), self.target_h2_mem.cuda(self.args.device), self.target_h2_spike.cuda(
                        self.args.device)
            q_eval, self.eval_h1_mem, self.eval_h1_spike, self.eval_h2_mem, self.eval_h2_spike = \
                self.eval_snn(inputs, self.eval_h1_mem, self.eval_h1_spike, self.eval_h2_mem,
                              self.eval_h2_spike)  # inputs (40, 96) -> q_eval (40, n_actions)
            q_target, self.target_h1_mem, self.target_h1_spike, self.target_h2_mem, self.target_h2_spike = \
                self.target_snn(inputs_next, self.target_h1_mem, self.target_h1_spike, self.target_h2_mem,
                                self.target_h2_spike)
            # Reshape back to (episode, n_agents, n_actions).
            q_eval = q_eval.view(episode_num, self.n_agents, -1)
            q_target = q_target.view(episode_num, self.n_agents, -1)
            eval_h2_mem = self.eval_h2_mem.view(episode_num, self.n_agents, -1)
            target_h2_mem = self.target_h2_mem.view(episode_num, self.n_agents, -1)
            q_evals.append(q_eval)
            q_targets.append(q_target)
            eval_h2_mems.append(eval_h2_mem)
            target_h2_mems.append(target_h2_mem)
        # Stack the per-step lists into (episode, len, n_agents, n_actions).
        q_evals = torch.stack(q_evals, dim=1)
        q_targets = torch.stack(q_targets, dim=1)
        hidden_evals = torch.stack(eval_h2_mems, dim=1)
        hidden_targets = torch.stack(target_h2_mems, dim=1)
        return q_evals, q_targets, hidden_evals, hidden_targets

    def get_bias(self, batch, hidden_evals, hidden_targets, episode_num, hat=False):
        """State-value baseline ``v`` from the bias net (None when hat=True)."""
        # episode_num, max_episode_len, _, _ = hidden_targets.shape
        max_episode_len = self.args.episode_limit
        states = batch['O'][:, :max_episode_len]
        states_next = batch['O_NEXT'][:, :max_episode_len]
        u_onehot = batch['U_ONEHOT'][:, :max_episode_len]
        if self.args.cuda:
            # NOTE(review): agent-0's observation is used as the "state", but
            # the [:, :, 0, :] slice only happens on the CUDA path — on CPU
            # the full per-agent tensor would reach bias_net. Confirm.
            states = states.cuda(self.args.device)[:, :, 0, :]
            states_next = states_next.cuda(self.args.device)[:, :, 0, :]
            u_onehot = u_onehot.cuda(self.args.device)
            hidden_evals = hidden_evals.cuda(self.args.device)
            hidden_targets = hidden_targets.cuda(self.args.device)
        if hat:
            v = None
        else:
            v = self.bias_net(states, hidden_evals)
            # Reshape v back to (episode_num, max_episode_len).
            v = v.view(episode_num, -1, 1).squeeze(-1)
        return v

    def _get_actor_inputs(self, batch, transition_idx):
        """Assemble actor inputs (obs + last action + agent id) for one step."""
        # u_onehot is taken whole because the previous step's action is needed.
        obs, u_onehot = batch['O'][:, transition_idx], batch['U_ONEHOT'][:]
        episode_num = obs.shape[0]
        inputs = []
        inputs.append(obs)
        # Append the previous action and, optionally, the agent id.
        if self.args.last_action:
            if transition_idx == 0:  # first transition: the previous action is a zero vector
                inputs.append(torch.zeros_like(u_onehot[:, transition_idx]))
            else:
                inputs.append(u_onehot[:, transition_idx - 1])
        if self.args.reuse_network:
            # One shared network for all agents: tag each row with a one-hot
            # agent id (rows of the identity matrix).
            inputs.append(torch.eye(self.args.n_agents).unsqueeze(0).expand(episode_num, -1, -1))
        # Merge (episode, agent) into one batch axis and concat the features.
        inputs = torch.cat([x.reshape(episode_num * self.args.n_agents, -1) for x in inputs], dim=1)
        return inputs

    def _get_action_prob(self, batch, max_episode_len, epsilon):
        """Per-agent action probabilities from the eval net, epsilon-smoothed
        and masked to the available actions."""
        episode_num = batch['O'].shape[0]
        avail_actions = batch['AVAIL_U']
        action_prob = []
        for transition_idx in range(max_episode_len):
            inputs = self._get_actor_inputs(batch, transition_idx)  # obs + last_action + agent id
            if self.args.cuda:
                inputs = inputs.cuda(self.args.device)
                # self.eval_hidden = self.eval_hidden.cuda(self.args.device)
                self.eval_h1_mem, self.eval_h1_spike, self.eval_h2_mem, self.eval_h2_spike = \
                    self.eval_h1_mem.cuda(self.args.device), self.eval_h1_spike.cuda(
                        self.args.device), self.eval_h2_mem.cuda(self.args.device), self.eval_h2_spike.cuda(
                        self.args.device)
                self.target_h1_mem, self.target_h1_spike, self.target_h2_mem, self.target_h2_spike = \
                    self.target_h1_mem.cuda(self.args.device), self.target_h1_spike.cuda(
                        self.args.device), self.target_h2_mem.cuda(self.args.device), self.target_h2_spike.cuda(
                        self.args.device)
            # outputs, self.eval_hidden = self.eval_snn(inputs, self.eval_hidden)
            outputs, self.eval_h1_mem, self.eval_h1_spike, self.eval_h2_mem, self.eval_h2_spike = \
                self.eval_snn(inputs, self.eval_h1_mem, self.eval_h1_spike, self.eval_h2_mem,
                              self.eval_h2_spike)  # inputs (40, 96) -> outputs (40, n_actions)
            # Reshape back to (episode, n_agents, n_actions).
            outputs = outputs.view(episode_num, self.n_agents, -1)
            prob = torch.nn.functional.softmax(outputs, dim=-1)
            action_prob.append(prob)
        # Stack the per-step lists into (episode, len, n_agents, n_actions).
        action_prob = torch.stack(action_prob, dim=1).cpu()
        action_num = avail_actions.sum(dim=-1, keepdim=True).float().repeat(1, 1, 1,
                                                                            avail_actions.shape[-1])  # number of selectable actions
        # Epsilon-greedy smoothing over the available actions.
        action_prob = ((1 - epsilon) * action_prob + torch.ones_like(action_prob) * epsilon / action_num)
        action_prob[avail_actions == 0] = 0.0  # unavailable actions get probability 0
        # Zeroing unavailable actions breaks normalization, so renormalize.
        # (Categorical would renormalize on its own during execution.)
        action_prob = action_prob / action_prob.sum(dim=-1, keepdim=True)
        # Padded transitions have all-zero avail_actions, so the division above
        # yields NaN there — force those probabilities back to 0.
        action_prob[avail_actions == 0] = 0.0
        if self.args.cuda:
            action_prob = action_prob.cuda(self.args.device)
        return action_prob

    def init_hidden(self, episode_num, num_env):
        # One membrane/spike state per (episode, agent, env) for rollout.
        self.eval_h1_mem = self.eval_h1_spike = torch.zeros(episode_num, self.n_agents, num_env,
                                                            self.args.rnn_hidden_dim)
        self.target_h1_mem = self.target_h1_spike = torch.zeros(episode_num, self.n_agents, num_env,
                                                                self.args.rnn_hidden_dim)
        self.eval_h2_mem = self.eval_h2_spike = torch.zeros(episode_num, self.n_agents, num_env,
                                                            self.args.rnn_hidden_dim)
        self.target_h2_mem = self.target_h2_spike = torch.zeros(episode_num, self.n_agents, num_env,
                                                                self.args.rnn_hidden_dim)

    def init_hidden_learn(self, episode_num):
        # One membrane/spike state per (episode, agent) for learning.
        self.eval_h1_mem = self.eval_h1_spike = torch.zeros(episode_num, self.n_agents,
                                                            self.args.rnn_hidden_dim)
        self.target_h1_mem = self.target_h1_spike = torch.zeros(episode_num, self.n_agents,
                                                                self.args.rnn_hidden_dim)
        self.eval_h2_mem = self.eval_h2_spike = torch.zeros(episode_num, self.n_agents,
                                                            self.args.rnn_hidden_dim)
        self.target_h2_mem = self.target_h2_spike = torch.zeros(episode_num, self.n_agents,
                                                                self.args.rnn_hidden_dim)

    def save_model(self, train_step):
        """Checkpoint the eval SNN, numbered by save-cycle count and run index."""
        num = str(train_step // self.args.save_cycle)
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        torch.save(self.eval_snn.state_dict(),
                   self.model_dir + '/' + num + '_snn_net_params_{}.pkl'.format(self.args.num_run))

    def load_model(self, train_step):
        """Load the checkpoint that save_model() wrote for this cycle."""
        num = str(train_step // self.args.save_cycle)
        # FIX: the filename was '_snn_net_params.pkl'.format(...) — the format
        # call had no placeholder, so the run index was dropped and the path
        # never matched what save_model() writes. Use the same pattern.
        path = torch.load(self.model_dir + '/' + num + '_snn_net_params_{}.pkl'.format(self.args.num_run))
        self.eval_snn.load_state_dict(path)
================================================
FILE: examples/Social_Cognition/MAToM-SNN/STAG/policy/svdn.py
================================================
import torch
import os
from network.spiking_net import Critic, VDNNet
class SVDN:
def __init__(self, args):
    """Spiking VDN trainer: per-agent spiking critics plus a VDN mixer.

    (Comments below translate the original Chinese notes.)
    """
    self.n_actions = args.n_actions
    self.n_agents = args.n_agents
    self.obs_shape = args.obs_shape
    input_shape = self.obs_shape
    # The network input dimension depends on the configured features.
    if args.last_action:
        input_shape += self.n_actions
    if args.reuse_network:
        input_shape += self.n_agents
    # Networks: eval nets are trained, target nets provide stable TD targets.
    self.eval_snn = Critic(input_shape, args)
    self.target_snn = Critic(input_shape, args)
    self.eval_vdn_snn = VDNNet()  # sums the agents' Q-values
    self.target_vdn_snn = VDNNet()
    self.args = args
    if self.args.cuda:
        self.eval_snn.cuda(self.args.device)
        self.target_snn.cuda(self.args.device)
        self.eval_vdn_snn.cuda(self.args.device)
        self.target_vdn_snn.cuda(self.args.device)
    self.model_dir = args.model_dir + '/' + args.alg + args.exp_dir + args.save_model_dir
    # Load a pretrained model if one exists.
    if self.args.load_model:
        if os.path.exists(self.model_dir):
            # path_snn = '/home/zhaozhuoya/exp2/ToM2_test/model/siql/199_snn_net_params.pkl'
            map_location = self.args.device if self.args.cuda else 'cpu'
            self.eval_snn.load_state_dict(torch.load(self.model_dir, map_location=map_location))
            print('Successfully load the model: {}'.format(self.model_dir))
        else:
            print(self.model_dir)
            raise Exception("No model!")
    # Start the target nets with the same parameters as the eval nets.
    self.target_snn.load_state_dict(self.eval_snn.state_dict())
    self.target_vdn_snn.load_state_dict(self.eval_vdn_snn.state_dict())
    self.eval_parameters = list(self.eval_snn.parameters()) + list(self.eval_vdn_snn.parameters())
    if args.optimizer == "RMS":
        self.optimizer = torch.optim.RMSprop(self.eval_parameters, lr=args.lr)
    # During execution each agent keeps its own membrane/spike state;
    # during learning there is one per (episode, agent).
    self.eval_h1_mem, self.eval_h1_spike = None, None
    self.target_h1_mem, self.target_h1_spike = None, None
    self.eval_h2_mem, self.eval_h2_spike = None, None
    self.target_h2_mem, self.target_h2_spike = None, None
    print('Init alg SVDN')
def learn(self, batch, max_episode_len, train_step, epsilon=None):
    """One TD-learning step of spiking VDN.

    ``train_step`` counts updates and controls when the target networks are
    refreshed. Batch tensors are 4-D: (episode, transition, agent, feature).
    Whole episodes are sampled so the hidden (membrane) state can be threaded
    through time; transitions are fed in chronological order.
    """
    episode_num = batch['O'].shape[0]
    self.init_hidden_learn(episode_num)
    # Convert the batch contents to tensors; actions must be integer-typed.
    for key in batch.keys():
        if key == 'U':
            batch[key] = torch.tensor(batch[key], dtype=torch.long)
        else:
            batch[key] = torch.tensor(batch[key], dtype=torch.float32)
    u, r, avail_u, avail_u_next, terminated = batch['U'], batch['R'].squeeze(-1), batch['AVAIL_U'], \
                                              batch['AVAIL_U_NEXT'], batch['TERMINATE'].repeat(1, 1, self.n_agents)
    mask = (1 - batch["PADDED"].float()).repeat(1, 1, self.n_agents)  # zero out the TD-error of padded transitions
    # Per-agent Q-values, shaped (episode, max_episode_len, n_agents, n_actions).
    q_evals, q_targets = self.get_q_values(batch, max_episode_len)
    if self.args.cuda:
        u = u.cuda(self.args.device)
        r = r.cuda(self.args.device)
        terminated = terminated.cuda(self.args.device)
        mask = mask.cuda(self.args.device)
    # Pick the Q-value of the taken action and drop the trailing singleton dim.
    u = u.to(torch.int64)
    q_evals = torch.gather(q_evals, dim=3, index=u).squeeze(3)
    # Build the TD target: mask unavailable actions, then max over actions.
    q_targets[avail_u_next == 0.0] = - 9999999
    q_targets = q_targets.max(dim=3)[0]
    q_total_eval = self.eval_vdn_snn(q_evals)
    q_total_target = self.target_vdn_snn(q_targets)
    targets = r + self.args.gamma * q_total_target * (1 - terminated)
    td_error = targets.detach() - q_total_eval  # sign is irrelevant: squared below
    masked_td_error = mask * td_error  # wipe the TD-error of padded steps
    # mean() would also count the padded entries, so divide the sum by the
    # number of real transitions instead.
    loss = (masked_td_error ** 2).sum() / mask.sum()
    # print('loss is ', loss)
    self.optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(self.eval_parameters, self.args.grad_norm_clip)
    self.optimizer.step()
    if train_step > 0 and train_step % self.args.target_update_cycle == 0:
        self.target_snn.load_state_dict(self.eval_snn.state_dict())
        self.target_vdn_snn.load_state_dict(self.eval_vdn_snn.state_dict())
    return loss
def _get_inputs(self, batch, transition_idx):
    """Assemble network inputs for one transition across all episodes.

    Returns (inputs, inputs_next), each shaped
    (episode_num * n_agents, feature_dim).
    """
    obs = batch['O'][:, transition_idx]
    obs_next = batch['O_NEXT'][:, transition_idx]
    u_onehot = batch['U_ONEHOT'][:]  # whole sequence: the previous action is needed
    episode_num = obs.shape[0]
    feats, feats_next = [obs], [obs_next]
    if self.args.last_action:
        # The first transition has no predecessor, so its "last action" is zeros.
        if transition_idx == 0:
            prev_action = torch.zeros_like(u_onehot[:, transition_idx])
        else:
            prev_action = u_onehot[:, transition_idx - 1]
        feats.append(prev_action)
        feats_next.append(u_onehot[:, transition_idx])
    if self.args.reuse_network:
        # One shared network for all agents: tag each row with a one-hot
        # agent id — rows of the identity matrix.
        agent_ids = torch.eye(self.args.n_agents).unsqueeze(0).expand(episode_num, -1, -1)
        feats.append(agent_ids)
        feats_next.append(agent_ids)
    # Merge (episode, agent) into a single batch axis and concat the features.
    flat = episode_num * self.args.n_agents
    inputs = torch.cat([t.reshape(flat, -1) for t in feats], dim=1)
    inputs_next = torch.cat([t.reshape(flat, -1) for t in feats_next], dim=1)
    return inputs, inputs_next
def get_q_values(self, batch, max_episode_len):
    """Roll the eval/target SNNs over every transition index of the batch.

    Threads the SNN membrane/spike state through the self.*_h*_mem/spike
    attributes between transitions (they must have been initialized by
    init_hidden_learn beforehand).

    Returns:
        (q_evals, q_targets): tensors of shape
        (episode_num, max_episode_len, n_agents, n_actions).
    """
    episode_num = batch['O'].shape[0]
    q_evals, q_targets = [], []
    for transition_idx in range(max_episode_len):
        inputs, inputs_next = self._get_inputs(batch, transition_idx)  # append last_action / agent_id to obs
        if self.args.cuda:
            # Move both the inputs and the recurrent state to the GPU.
            inputs = inputs.cuda(self.args.device)
            inputs_next = inputs_next.cuda(self.args.device)
            self.eval_h1_mem, self.eval_h1_spike, self.eval_h2_mem, self.eval_h2_spike = \
                self.eval_h1_mem.cuda(self.args.device), self.eval_h1_spike.cuda(self.args.device), self.eval_h2_mem.cuda(self.args.device), self.eval_h2_spike.cuda(self.args.device)
            self.target_h1_mem, self.target_h1_spike, self.target_h2_mem, self.target_h2_spike = \
                self.target_h1_mem.cuda(self.args.device), self.target_h1_spike.cuda(self.args.device), self.target_h2_mem.cuda(self.args.device), self.target_h2_spike.cuda(self.args.device)
        # inputs is (episode_num*n_agents, input_dim); q_eval comes back as
        # (episode_num*n_agents, n_actions) along with the updated SNN state.
        q_eval, self.eval_h1_mem, self.eval_h1_spike, self.eval_h2_mem, self.eval_h2_spike = \
            self.eval_snn(inputs, self.eval_h1_mem, self.eval_h1_spike, self.eval_h2_mem, self.eval_h2_spike)
        q_target, self.target_h1_mem, self.target_h1_spike, self.target_h2_mem, self.target_h2_spike = \
            self.target_snn(inputs_next, self.target_h1_mem, self.target_h1_spike, self.target_h2_mem, self.target_h2_spike)
        # Reshape q values back to (episode_num, n_agents, n_actions).
        q_eval = q_eval.view(episode_num, self.n_agents, -1)
        q_target = q_target.view(episode_num, self.n_agents, -1)
        q_evals.append(q_eval)
        q_targets.append(q_target)
    # Stack the max_episode_len per-transition tensors into
    # (episode_num, max_episode_len, n_agents, n_actions).
    q_evals = torch.stack(q_evals, dim=1)
    q_targets = torch.stack(q_targets, dim=1)
    return q_evals, q_targets
def init_hidden(self, episode_num, num_env):
    """Zero-initialize the SNN membrane/spike buffers for rollouts.

    Each buffer has shape (episode_num, n_agents, num_env, rnn_hidden_dim).

    Bug fix: the previous chained assignment
    (`self.x_mem = self.x_spike = torch.zeros(...)`) bound each mem/spike
    pair to ONE shared tensor, so any in-place update of a membrane buffer
    would silently mutate its spike buffer too. Every buffer now gets its
    own zero tensor; the initial values are unchanged.
    """
    shape = (episode_num, self.n_agents, num_env, self.args.rnn_hidden_dim)
    self.eval_h1_mem = torch.zeros(shape)
    self.eval_h1_spike = torch.zeros(shape)
    self.target_h1_mem = torch.zeros(shape)
    self.target_h1_spike = torch.zeros(shape)
    self.eval_h2_mem = torch.zeros(shape)
    self.eval_h2_spike = torch.zeros(shape)
    self.target_h2_mem = torch.zeros(shape)
    self.target_h2_spike = torch.zeros(shape)
def init_hidden_learn(self, episode_num):
    """Zero-initialize the SNN membrane/spike buffers for a training pass.

    Each buffer has shape (episode_num, n_agents, rnn_hidden_dim) — same as
    init_hidden but without the vectorized-environment axis.

    Bug fix: as in init_hidden, the old chained assignment aliased each
    mem/spike pair to a single tensor; allocate one tensor per buffer so
    in-place updates cannot leak between them.
    """
    shape = (episode_num, self.n_agents, self.args.rnn_hidden_dim)
    self.eval_h1_mem = torch.zeros(shape)
    self.eval_h1_spike = torch.zeros(shape)
    self.target_h1_mem = torch.zeros(shape)
    self.target_h1_spike = torch.zeros(shape)
    self.eval_h2_mem = torch.zeros(shape)
    self.eval_h2_spike = torch.zeros(shape)
    self.target_h2_mem = torch.zeros(shape)
    self.target_h2_spike = torch.zeros(shape)
def save_model(self, train_step):
    """Persist the eval SNN weights, numbered by save cycle and run id.

    The file is written to
    ``{model_dir}/{train_step // save_cycle}_snn_net_params_{num_run}.pkl``.
    """
    num = str(train_step // self.args.save_cycle)
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(self.model_dir, exist_ok=True)
    fname = '{}_snn_net_params_{}.pkl'.format(num, self.args.num_run)
    torch.save(self.eval_snn.state_dict(), os.path.join(self.model_dir, fname))
def load_model(self, train_step):
    """Load eval SNN weights previously written by save_model.

    Bug fix: the filename template lacked the ``_{}`` run-id placeholder, so
    ``.format(self.args.num_run)`` was a no-op and the path never matched
    what save_model wrote (``..._snn_net_params_{num_run}.pkl``).
    """
    num = str(train_step // self.args.save_cycle)
    fname = '{}_snn_net_params_{}.pkl'.format(num, self.args.num_run)
    state_dict = torch.load(os.path.join(self.model_dir, fname))
    self.eval_snn.load_state_dict(state_dict)
================================================
FILE: examples/Social_Cognition/MAToM-SNN/STAG/preprocessoing/__init__.py
================================================
================================================
FILE: examples/Social_Cognition/MAToM-SNN/STAG/preprocessoing/common.py
================================================
"""
preprocess
"""
from typing import Union
import abc
import torch as tc
import numpy as np
class Preprocessing(abc.ABC, tc.nn.Module):
    """Abstract base for modules that assemble a meta-learning agent's input.

    Subclasses combine the current observation with the previous action,
    reward and done flag into a single float feature vector per batch item.
    """

    def forward(
            self,
            curr_obs: Union[tc.LongTensor, tc.FloatTensor],
            prev_action: tc.LongTensor,
            prev_reward: tc.FloatTensor,
            prev_done: tc.FloatTensor
    ) -> tc.FloatTensor:
        """Build the agent input vector.

        Args:
            curr_obs: tc.LongTensor or tc.FloatTensor of shape [B, ...].
            prev_action: tc.LongTensor of shape [B, ...].
            prev_reward: tc.FloatTensor of shape [B, ...].
            prev_done: tc.FloatTensor of shape [B, ...].

        Returns:
            tc.FloatTensor of shape [B, ..., ?].
        """
        pass
def one_hot_torch(ys: tc.LongTensor, depth: int, device) -> tc.FloatTensor:
    """Apply one-hot encoding to a batch of integer labels.

    Args:
        ys: tc.LongTensor of shape [B, ...] with values in [0, depth).
        depth: int specifying the number of possible y values.
        device: device on which the result tensor is placed.

    Returns:
        tc.FloatTensor of shape [B, ..., depth] — the one-hot encodings of ys.
    """
    # F.one_hot performs the scatter in a single pass; the old scatter_-based
    # version additionally allocated a full-size all-ones src tensor that was
    # immediately thrown away.
    return tc.nn.functional.one_hot(ys, num_classes=depth).to(device=device, dtype=tc.float32)
def one_hot(ys: int, depth: int) -> np.ndarray:
    """One-hot encode a single integer label, treating it as 1-based.

    Args:
        ys: 1-based label; slot ``ys - 1`` of the output is set to 1.
            NOTE(review): ``ys == 0`` wraps around and sets the LAST slot
            (Python negative indexing) — confirm callers always pass >= 1.
        depth: length of the encoding vector.

    Returns:
        np.ndarray of length ``depth`` with a single 1.
        (Fixes the previous signature, which claimed a ``list`` return and
        reused the tensor-version docstring.)
    """
    vec = np.zeros(depth, dtype=int)
    vec[ys - 1] = 1
    return vec
================================================
FILE: examples/Social_Cognition/MAToM-SNN/STAG/runner.py
================================================
import os
import matplotlib.pyplot as plt
import numpy as np
import torch.multiprocessing as mp
from common_sr.srollout import RolloutWorker
from agents.sagent import Agents
from common_sr.replay_buffer import ReplayBuffer
import time
from tqdm import tqdm
class Runner:
    """Drives rollout collection, training steps and periodic evaluation.

    Owns the agents, the rollout worker and (when training) the replay
    buffer; persists the evaluation reward curve under ``save_path``.
    """

    def __init__(self, env, args):
        self.env = env
        self.agents = Agents(args)
        self.rolloutWorker = RolloutWorker(env, self.agents, args)
        if not args.evaluate:
            self.buffer = ReplayBuffer(args)
        self.args = args
        self.win_rates = []
        self.episode_rewards = []
        # Directory for the saved reward arrays (and plots, if re-enabled).
        self.save_path = self.args.result_dir + '/' + args.alg + args.exp_dir
        if not os.path.exists(self.save_path):
            os.makedirs(self.save_path)

    def run(self, num):
        """Main loop: evaluate periodically, collect episodes, train.

        Args:
            num: index of this run, used to name the saved reward file.
        """
        time_steps, train_steps, evaluate_steps = 0, 0, -1
        # Bug fix: tqdm's first positional argument is an iterable; passing
        # the integer step budget left the bar without a total. Pass it as
        # total= so update(steps) actually shows progress.
        pbar = tqdm(total=self.args.n_steps)
        if not self.args.load_model:
            while time_steps < self.args.n_steps:
                if time_steps // self.args.evaluate_cycle > evaluate_steps:
                    win_rate, episode_reward = self.evaluate()
                    self.episode_rewards.append(episode_reward)
                    evaluate_steps += self.args.evaluate_epoch
                # Collect one batch of episodes and store it.
                episode_batch, _, _, steps = self.rolloutWorker.generate_episode()
                time_steps += steps
                pbar.update(steps)
                self.buffer.store_episode(episode_batch)
                for train_step in range(self.args.train_steps):
                    mini_batch = self.buffer.sample(min(self.buffer.current_size, self.args.batch_size))
                    # Algorithms whose name contains 'o' additionally take epsilon.
                    if self.args.alg.find('o') > -1:
                        self.agents.train(mini_batch, train_steps, self.args.epsilon)
                    else:
                        self.agents.train(mini_batch, train_steps)
                    train_steps += 1
        pbar.close()
        # Final evaluation after the step budget is exhausted.
        win_rate, episode_reward = self.evaluate()
        self.win_rates.append(win_rate)
        self.episode_rewards.append(episode_reward)
        if not self.args.load_model:
            self.plt(num)

    def evaluate(self):
        """Run one evaluation rollout.

        Returns:
            (win_rate, episode_rewards): win_rate is always 0 here
            (win_number is never incremented — kept for interface
            compatibility); episode_rewards is the per-agent mean reward.
        """
        win_number = 0  # NOTE(review): never incremented, so win_rate stays 0
        _, episode_rewards, win_tag, _ = self.rolloutWorker.generate_episode(evaluate=True)
        episode_rewards = [r / self.args.evaluate_epoch / self.args.process
                           for r in episode_rewards]
        return win_number / self.args.evaluate_epoch, episode_rewards

    def plt(self, num):
        """Persist the evaluation reward curve (plotting itself is disabled)."""
        np.save(self.save_path + '/episode_rewards_{}'.format(num), self.episode_rewards)
================================================
FILE: examples/Social_Cognition/ReadMe.md
================================================
================================================
FILE: examples/Social_Cognition/SmashVat/dqn.py
================================================
import os
import time
import random
from itertools import count
from collections import namedtuple, deque
import numpy as np
import pandas as pd
import torch
import imageio
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from side_effect_eval import *
from qnets import *
from environment import *
Transition = namedtuple('Transition', ('state', 'action', 'reward', 'next_state', 'done'))


class ReplayBuffer:
    """Fixed-capacity FIFO store of Transition tuples with uniform sampling."""

    def __init__(self, capacity):
        self.capacity = capacity
        # deque(maxlen=...) evicts the oldest entry automatically on
        # overflow, so push() needs no manual popleft() — the old length
        # check was dead code (len(memory) can never exceed capacity).
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Append one Transition built from the positional fields."""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        """Return batch_size transitions drawn uniformly without replacement."""
        return random.sample(self.memory, batch_size)
class AnseEmpDQN:
    """Double-DQN agent with an attainable-utility side-effect penalty on its
    own actions ("sep") plus an empathy penalty derived from each human's
    stepwise-inaction baseline. Setting both weights to 0 degenerates to a
    standard DQN.
    """

    def __init__(self, env, net_type='SNN',
                 init_buffer_size=10000, replay_buffer_size=100000, batch_size=100, target_update_interval=1000,
                 weight_sep=20, weight_emp_impact=None):
        self.env = env
        input_dim = 3  # channels produced by state2tensor: map / agent / humans
        output_dim = env.action_space.n
        assert net_type in ['SNN', 'ANN']
        self.net_type = net_type
        if self.net_type == 'SNN':
            self.policy_net = SNNQnet(input_dim, output_dim).cuda()
            self.target_net = SNNQnet(input_dim, output_dim).cuda()
        elif self.net_type == 'ANN':
            self.policy_net = CNNQnet(input_dim, output_dim).cuda()
            self.target_net = CNNQnet(input_dim, output_dim).cuda()
        self.init_buffer_size = init_buffer_size
        self.replay_buffer_size = replay_buffer_size
        self.replay_buffer = ReplayBuffer(replay_buffer_size)
        self.batch_size = batch_size
        self.target_update_interval = target_update_interval
        # empathy
        self.num_others = env.num_humans
        self.baseline = StepwiseInactionModel(noop_action=env.actions.noop)
        # Bug fix: `[StepwiseInactionModel(...)] * n` replicated ONE shared
        # instance, so every human's baseline aliased the same object and any
        # per-human state corrupted the others. Build an independent model
        # per human instead.
        self.baselines_others = [StepwiseInactionModel(noop_action=env.actions.noop)
                                 for _ in range(self.num_others)]
        self.deviation = AttainableUtilityMeasure(uf_num=30, uf_discount=0.99)
        self.weight_sep = weight_sep
        self.weight_emp_impact = weight_emp_impact if weight_emp_impact is not None else weight_sep

    def state2tensor(self, state):
        """Decode an encoded env state into a 3-channel float tensor on GPU."""
        state_arr = self.env._decode(state)
        array = np.zeros([3] + list(self.env.p.map_shape))
        array[0] = self.env.map  # channel 0: environment map
        for i, pos in enumerate(self.env.p.vat_pos):
            # Restore each vat from the decoded state: the state stored in the
            # replay buffer may not match the env's current map.
            array[0][pos] = self.env.cells.vat if state_arr[i] else self.env.cells.empty
        agent_pos = tuple(state_arr[self.env.num_vats:self.env.num_vats + 2])
        array[1][agent_pos] = 1  # channel 1: agent position
        for pos in self.env.human_pos:
            array[2][tuple(pos)] = 1  # channel 2: human positions
        return torch.Tensor(array).float().cuda()

    def epsilon_greedy(self, net, state, epsilon):
        """Sample an action: greedy w.p. 1-epsilon, else uniform at random."""
        num_actions = self.env.action_space.n
        p = np.ones(num_actions) * epsilon / num_actions
        state_tensor = self.state2tensor(state).unsqueeze(0)
        best_action = torch.argmax(net(state_tensor)).item()
        p[best_action] += 1 - epsilon
        action = np.random.choice(self.env.action_space.n, p=p)
        return action

    def train(self, lr=1e-3, num_episodes=10000, gamma=0.99,
              epsilon_start=1, epsilon_end=0.01, decay_start=0.05, decay_end=0.95,
              checkpoint_interval=1000,
              checkpoint_dir='./models/',
              log_dir='./log'):
        """Train the policy net with double-DQN targets and impact-shaped rewards.

        Epsilon is flat at epsilon_start until decay_start, decays linearly to
        epsilon_end until decay_end, then stays flat. Logs per-timestep and
        per-episode metrics to TensorBoard and CSV.
        """
        policy_net_opt = optim.Adam(self.policy_net.parameters(), lr=lr)
        epsilons = [epsilon_end] * num_episodes
        decay_start_episode = int(num_episodes * decay_start)
        decay_end_episode = int(num_episodes * decay_end)
        epsilons[0:decay_start_episode] = np.full(decay_start_episode, epsilon_start)
        epsilons[decay_start_episode:decay_end_episode] = np.linspace(epsilon_start, epsilon_end, decay_end_episode - decay_start_episode)
        tb_logger = SummaryWriter(log_dir=log_dir)
        pd_timestep_logger = pd.DataFrame(columns=['loss', 'reward', 'impact', 'aup_impact', 'empathy_impact'])
        pd_episode_logger = pd.DataFrame(columns=['step', 'ep_reward', 'ep_reward_mean',
                                                  'ep_impact', 'ep_aup_impact', 'ep_empathy_impact',
                                                  'num_vat_broken', 'num_human_saved'])

        def optimize_net():
            # One double-DQN gradient step on a uniformly sampled minibatch.
            transitions = self.replay_buffer.sample(self.batch_size)
            batch = Transition(*zip(*transitions))
            state_batch = torch.stack([self.state2tensor(state[0]) for state in batch.state]).float().cuda()
            action_batch = torch.tensor(batch.action).unsqueeze(-1).cuda()
            reward_batch = torch.tensor(batch.reward, dtype=torch.float).cuda()
            next_state_batch = torch.stack([self.state2tensor(state[0]) for state in batch.next_state]).float().cuda()
            done_batch = torch.tensor(batch.done).cuda()
            # Double DQN: policy net picks the action, target net values it.
            best_actions = self.policy_net(next_state_batch).max(1)[1].detach()
            next_state_values = self.target_net(next_state_batch).gather(1, best_actions.unsqueeze(1)).squeeze(1).detach()
            expected_state_action_values = reward_batch + gamma * next_state_values * torch.logical_not(done_batch)
            state_action_values = self.policy_net(state_batch).gather(1, action_batch)
            loss = F.mse_loss(state_action_values, expected_state_action_values.unsqueeze(1))
            policy_net_opt.zero_grad()
            loss.backward()
            policy_net_opt.step()
            tb_logger.add_scalar('loss', loss.item(), total_step)
            pd_timestep_logger.loc[total_step, 'loss'] = loss.item()

        def calculate_impact(prev_states, prev_action, current_states):
            # Returns (total, weighted_self_deviation, weighted_others_deviation);
            # index 0 of the state lists is the agent, the rest are humans.
            if self.weight_sep == 0 and self.weight_emp_impact == 0:
                return 0, 0, 0
            prev_state_agent = prev_states[0]
            current_state_agent = current_states[0]
            prev_states_others = prev_states[1:]
            current_states_others = current_states[1:]
            baseline_state_agent = self.baseline.calculate(prev_state_agent, prev_action, current_state_agent)
            self.deviation.update(prev_state_agent, prev_action, current_state_agent)
            # Own side effects: only decreases in attainable utility count.
            dev_self = self.deviation.calculate(current_state_agent, baseline_state_agent, lambda x: abs(np.minimum(0, x)))
            weighted_dev_self = -self.weight_sep * dev_self
            dev_others = []
            for prev_state, current_state, baseline in zip(prev_states_others, current_states_others, self.baselines_others):
                baseline_state = baseline.calculate(prev_state, prev_action, current_state)
                dev_others.append(self.deviation.calculate(current_state, baseline_state, lambda x: x))
            dev_others_mean = sum(dev_others) / len(dev_others) if len(dev_others) > 0 else 0
            weighted_dev_others = self.weight_emp_impact * dev_others_mean
            total_impact = weighted_dev_self + weighted_dev_others
            return total_impact, weighted_dev_self, weighted_dev_others

        # Pre-fill the replay buffer with fully exploratory experience.
        state = self.env.reset()
        for i in range(self.init_buffer_size):
            action = self.epsilon_greedy(self.policy_net, state[0], epsilon_start)
            next_state, reward, done, info = self.env.step(action)
            if self.weight_sep != 0 or self.weight_emp_impact != 0:  # both 0 degenerates to standard DQN
                impact, _, _ = calculate_impact(state, action, next_state)
                reward += impact
                # Normalize so the shaped reward's magnitude does not blow up
                # the network.
                reward /= (self.weight_sep + self.weight_emp_impact) / 2
            self.replay_buffer.push(state, action, reward, next_state, done)
            if done:
                state = self.env.reset()
            else:
                state = next_state

        # Main training loop.
        total_step = 0
        for episode in range(num_episodes):
            if episode % checkpoint_interval == 0 and episode != 0:
                torch.save(self.policy_net.state_dict(), checkpoint_dir + f"/checkpoint_{episode}.pth")
            tb_logger.add_scalar('epsilon', epsilons[episode], episode + 1)
            state = self.env.reset()
            if episode % 100 == 0:
                print('Episode {} of {}'.format(episode + 1, num_episodes))
            episode_reward = 0
            episode_impact = 0
            ep_aup_impact = 0
            ep_empathy_impact = 0
            for step in count():
                if total_step % self.target_update_interval == 0:
                    self.target_net.load_state_dict(self.policy_net.state_dict())
                action = self.epsilon_greedy(self.policy_net, state[0], epsilons[episode])
                next_state, reward, done, info = self.env.step(action)
                impact, aup_impact, empathy_impact = calculate_impact(state, action, next_state)
                if self.weight_sep != 0 or self.weight_emp_impact != 0:  # both 0 degenerates to standard DQN
                    reward += impact
                    reward /= (self.weight_sep + self.weight_emp_impact) / 2
                episode_reward += reward
                episode_impact += impact
                ep_aup_impact += aup_impact
                ep_empathy_impact += empathy_impact
                self.replay_buffer.push(state, action, reward, next_state, done)
                optimize_net()
                pd_timestep_logger.loc[total_step, 'reward'] = reward
                pd_timestep_logger.loc[total_step, 'impact'] = impact
                pd_timestep_logger.loc[total_step, 'aup_impact'] = aup_impact
                pd_timestep_logger.loc[total_step, 'empathy_impact'] = empathy_impact
                if done:
                    tb_logger.add_scalar('step', step + 1, episode + 1)
                    tb_logger.add_scalar('reward', episode_reward, episode + 1)
                    tb_logger.add_scalar('ep-reward-mean', episode_reward / (step + 1), episode + 1)
                    tb_logger.add_scalar('impact', episode_impact, episode + 1)
                    pd_episode_logger.loc[episode, 'step'] = step + 1
                    pd_episode_logger.loc[episode, 'ep_reward'] = episode_reward
                    pd_episode_logger.loc[episode, 'ep_reward_mean'] = episode_reward / (step + 1)
                    pd_episode_logger.loc[episode, 'ep_impact'] = episode_impact
                    pd_episode_logger.loc[episode, 'ep_aup_impact'] = ep_aup_impact
                    pd_episode_logger.loc[episode, 'ep_empathy_impact'] = ep_empathy_impact
                    # Episode summary: count broken vats and saved humans
                    # from the final map.
                    env_map = self.env.map
                    num_vat_broken = 0
                    num_human_saved = 0
                    for pos in self.env.p.vat_pos:
                        if env_map[pos] != self.env.cells.vat:
                            num_vat_broken += 1
                            if pos in self.env.p.human_pos:
                                num_human_saved += 1
                    pd_episode_logger.loc[episode, 'num_vat_broken'] = num_vat_broken
                    pd_episode_logger.loc[episode, 'num_human_saved'] = num_human_saved
                    break
                else:
                    state = next_state
                total_step += 1
        tb_logger.close()
        pd_timestep_logger.to_csv(log_dir + '/timestep_logger.csv', index=False)
        pd_episode_logger.to_csv(log_dir + '/episode_logger.csv', index=False)

    def save(self, path):
        """Save the policy net weights under ``path/policy_net.pth``."""
        if not os.path.exists(path):
            os.makedirs(path)
        torch.save(self.policy_net.state_dict(), path + "/policy_net.pth")

    def load(self, path):
        """Load the policy net weights from ``path/policy_net.pth``."""
        self.policy_net.load_state_dict(torch.load(path + "/policy_net.pth"))

    def run(self, gif_name=None):
        """Run one greedy episode, rendering each step; optionally save a GIF."""
        self.policy_net.eval()
        obs = self.env.reset()
        images = []
        for step in count():
            images.append(self.env.render(mode='rgb_array'))
            obs_tensor = self.state2tensor(obs[0]).unsqueeze(0)
            # Single forward pass; the old code computed the Q-values twice.
            q_values = self.policy_net(obs_tensor)
            action = torch.argmax(q_values).item()
            print(self.env.actions(action))
            next_state, reward, done, _ = self.env.step(action)
            time.sleep(1)
            if done:
                # Duplicate the final frame so the GIF lingers on it.
                images.append(self.env.render(mode='rgb_array'))
                images.append(self.env.render(mode='rgb_array'))
                break
            else:
                obs = next_state
        if gif_name is not None:
            imageio.mimsave(gif_name, images, 'GIF', duration=0.5)
def set_seed(seed=114514):
    """Seed every RNG the project relies on for reproducibility."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)  # e-greedy action sampling uses np.random.choice
    random.seed(seed)     # replay-buffer sampling uses random.sample
def main():
    """Entry point: train, save, reload and replay an ANN DQN on BasicVatGoalEnv."""
    set_seed(1919810)
    env = BasicVatGoalEnv()
    env.render()
    model_kwargs = dict(net_type='ANN', init_buffer_size=10000,
                        replay_buffer_size=100000, batch_size=100,
                        target_update_interval=1000)
    model = AnseEmpDQN(env, **model_kwargs)
    train_kwargs = dict(lr=1e-3, num_episodes=10000, gamma=0.99,
                        epsilon_start=1, epsilon_end=0.01,
                        decay_start=0.05, decay_end=0.95,
                        checkpoint_interval=1000,
                        checkpoint_dir='./models/ANN-BasicVatGoalEnv-test',
                        log_dir='./log/ANN-BasicVatGoalEnv-test')
    model.train(**train_kwargs)
    model.save("./models/ANN-BasicVatGoalEnv-test")
    model.load("./models/ANN-BasicVatGoalEnv-test")
    model.run("ANN-BasicVatGoalEnv-test.gif")


if __name__ == '__main__':
    main()
================================================
FILE: examples/Social_Cognition/SmashVat/environment.py
================================================
import copy
from enum import IntEnum
import numpy as np
import gymnasium as gym
import imageio
from window import Window
class HumanVatGoalEnv(gym.Env):
    """General HumanVatGoalEnv Class.

    A gridworld with an agent, stationary humans, breakable vats and goal
    cells. Observations are integers encoding (vat states, agent position);
    one observation per human is emitted as well (humans see their own
    position in place of the agent's).
    """

    class Actions(IntEnum):
        # Discrete action set; `smash` breaks every vat in the 4-neighborhood.
        noop = 0
        left = 1
        right = 2
        up = 3
        down = 4
        smash = 5  # destroying all surrounding vat(s) at same time
        pass

    class Cells(IntEnum):
        # Static cell types stored in self.map.
        empty = 0
        wall = 1
        goal = 2
        vat = 3
        pass

    class CellsRender(object):
        # Sprite sheet for "realistic" rendering; images are loaded once at
        # class-definition time from ./materials (RGBA PNGs, 64x64 expected).
        # empty = np.full(shape=(64, 64, 4), fill_value=255)
        wall = imageio.imread('./materials/wall.png')
        goal = imageio.imread('./materials/goal.png')
        vat = imageio.imread('./materials/vat.png')
        agent = imageio.imread('./materials/agent.png')
        human = imageio.imread('./materials/adult.png')

    class Params(object):
        """Params for the Environment"""

        def __init__(
                self,
                map_shape=(7, 5),
                agent_pos=(1, 2),
                human_pos=((2, 3),),
                vat_pos=((3, 2), (2, 3),),
                goal_pos=((-2, 2),),
                wall_pos=(),  # user-defined walls (default surrounding walls are NOT included here)
                max_steps=50,
                env_name=None,
                # human_policy = "noop"
        ):
            self.map_shape = map_shape
            self.agent_pos = agent_pos
            self.human_pos = human_pos
            self.vat_pos = vat_pos
            self.goal_pos = goal_pos
            self.wall_pos = wall_pos
            self.max_steps = max_steps
            self.env_name = env_name

    # NOTE(review): the default `Params()` instance is created once at class
    # definition and shared by all default-constructed envs — confirm no
    # caller mutates it.
    def __init__(self, env_params=Params()):
        super(HumanVatGoalEnv, self).__init__()
        self.p = env_params
        self.actions = HumanVatGoalEnv.Actions
        self.cells = HumanVatGoalEnv.Cells
        self.action_space = gym.spaces.Discrete(len(self.actions))
        # observation_dim: dimension of observation space
        # With fixed human_pos, currently we only consider: each vat state (broken or not) + agent_pos (x, y)
        # e.g. [2,2,7,5] means vat1*vat2*agent_y*agent_x
        self.observation_dim = [2] * len(self.p.vat_pos) + list(self.p.map_shape)
        self.observation_space = gym.spaces.Discrete(np.prod(self.observation_dim))
        self.window = None
        # Build a human-readable env description unless one was supplied.
        if self.p.env_name is not None:
            self.descr = self.p.env_name
        else:
            n_human = len(self.p.human_pos)
            n_vat = len(self.p.vat_pos)
            n_goal = len(self.p.goal_pos)
            self.descr = self.__class__.__name__.lower()
            self.descr = self.descr.replace("human", "human" + str(n_human) + "-")
            self.descr = self.descr.replace("vat", "vat" + str(n_vat) + "-")
            self.descr = self.descr.replace("goal", "goal" + str(n_goal) + "-")
        self.num_vats = len(self.p.vat_pos)
        self.num_humans = len(self.p.human_pos)  # for empathy qlearning
        self.reset()
        return

    def reset(self):
        """Restore the map, agent/human positions and episode statistics.

        Returns:
            [obs_agent, obs_human_0, ...]: one encoded observation per actor.
        """
        # reset env state
        self._gen_map()
        # reset agent & human state
        self.agent_pos = np.array(self.p.agent_pos)
        self.human_pos = [np.array(pos) for pos in self.p.human_pos]
        # reset episode statistics
        self.step_count = 0
        self.total_reward = 0
        # self.total_hidden_reward = 0 ##TODO
        # self.total_human_rewards = [0]*len(self.p.human_pos) ##TODO
        # generate observation from state
        obs = self._gen_obs()
        return obs

    def _gen_map(self):
        """Rebuild self.map: border walls, user walls, goals, then vats."""
        self.map = np.full(shape=self.p.map_shape, fill_value=self.cells.empty)
        # place default surrounding walls
        self.map[0] = self.cells.wall
        self.map[-1] = self.cells.wall
        self.map[:, 0] = self.cells.wall
        self.map[:, -1] = self.cells.wall
        # place user-defined walls
        for pos in self.p.wall_pos:
            self.map[pos] = self.cells.wall
        # place goals
        for pos in self.p.goal_pos:
            self.map[pos] = self.cells.goal
        # place vats
        for pos in self.p.vat_pos:
            self.map[pos] = self.cells.vat
        return

    def _gen_obs(self):
        """Encode (vat states + position) once for the agent and each human."""
        # internal state
        s_env = [(self.map[pos] == self.cells.vat) for pos in self.p.vat_pos]
        s_agent = list(self.agent_pos)
        s_human = [list(pos) for pos in self.human_pos]
        # external observation
        obs_agent = self._encode(s_env + s_agent)
        obs_human = [self._encode(s_env + s_h) for s_h in s_human]
        return [obs_agent, *obs_human]

    def _encode(self, obs):
        """Mixed-radix encode a state list into a single int (see observation_dim)."""
        i = 0
        for idx, dim in enumerate(self.observation_dim):
            i *= dim
            i += obs[idx]
        assert 0 <= i <= self.observation_space.n
        return i

    def _decode(self, i):
        """Inverse of _encode: int back to the list of state components."""
        out = []
        for dim in reversed(self.observation_dim):
            out.append(i % dim)
            i = i // dim
        assert i == 0
        return list(reversed(out))

    def render(self, mode="window", cell_size=64, style="realistic"):
        """Render as an RGB(A) array (`mode="rgb_array"`) or into a Window."""
        if mode == "rgb_array":
            return self._gen_img(cell_size, style)
        elif mode == "window":
            if not isinstance(self.window, Window):
                self.window = Window(self.descr)
            if self.window.is_open():
                img = self._gen_img(cell_size, style)
                self.window.show_img(img)
                self.window.show(block=False)
        return

    def _gen_img(self, cell_size, style):
        """Draw the whole grid; `style` selects flat colors or PNG sprites."""
        if style == "abstract":
            # Flat-color rendering: RGB rectangles per cell type.
            h, w = self.map.shape
            img = np.full(shape=(h * cell_size, w * cell_size, 3), fill_value=255)

            def draw_cell(cell_type, cell_pos, cell_size):
                # Fill one cell with its type's color (empty stays white).
                if cell_type == self.cells.empty:
                    pass
                elif cell_type == self.cells.wall:
                    x, y = np.array(cell_pos) * cell_size
                    img[x: (x + cell_size), y: (y + cell_size), :] = np.array(
                        [128, 128, 128]
                    )
                elif cell_type == self.cells.goal:
                    x, y = np.array(cell_pos) * cell_size
                    img[x: (x + cell_size), y: (y + cell_size), :] = np.array([0, 255, 0])
                elif cell_type == self.cells.vat:
                    x, y = np.array(cell_pos) * cell_size
                    img[x: (x + cell_size), y: (y + cell_size), :] = np.array([255, 0, 0])
                else:
                    pass

            def draw_agent(pos, cell_size):
                # draw rectangle
                x, y = np.array(pos) * cell_size
                img[
                    int(x + 0.2 * cell_size): int(x + 0.8 * cell_size + 1),
                    int(y + 0.2 * cell_size): int(y + 0.8 * cell_size + 1),
                    :,
                ] = np.array([0, 0, 0])
                # # draw cicle
                # def fill_circle(img, cx, cy, r, color):
                #     h, w = img.shape[0:2]
                #     X, Y = np.ogrid[:h, :w]
                #     mask = (X-cx)**2+(Y-cy)**2 <= r**2
                #     img[mask] = color
                #     # return img
                # x0, y0 = np.array(pos) * cell_size
                # sub_img = img[int(x0):int(x0+cell_size), int(y0):int(y0+cell_size), :]
                # cx, cy, r = np.array([0.5, 0.5, 0.3]) * cell_size
                # color = np.array([0,0,0])
                # fill_circle(sub_img, cx, cy, r, color)
                pass

            def draw_human(pos, cell_size):
                # # draw rectangle
                # x, y = np.array(pos) * cell_size
                # img[int(x+0.2*cell_size):int(x+0.8*cell_size+1),
                #     int(y+0.2*cell_size):int(y+0.8*cell_size+1),:] = np.array([255,255,0])
                # draw cicle
                def fill_circle(img, cx, cy, r, color):
                    h, w = img.shape[0:2]
                    X, Y = np.ogrid[:h, :w]
                    mask = (X - cx) ** 2 + (Y - cy) ** 2 <= r ** 2
                    img[mask] = color
                    # return img

                x0, y0 = np.array(pos) * cell_size
                sub_img = img[
                    int(x0): int(x0 + cell_size), int(y0): int(y0 + cell_size), :
                ]
                cx, cy, r = np.array([0.5, 0.5, 0.36]) * cell_size
                color = np.array([255, 255, 0])
                fill_circle(sub_img, cx, cy, r, color)
                pass

            def draw_gridline(cell_size):
                # 1-pixel white grid lines along both axes.
                img[::cell_size, :] = np.array([255, 255, 255])
                img[-1::-cell_size, :] = np.array([255, 255, 255])
                img[:, ::cell_size] = np.array([255, 255, 255])
                img[:, -1::-cell_size] = np.array([255, 255, 255])
                pass

            for i in range(h):
                for j in range(w):
                    draw_cell(self.map[i, j], (i, j), cell_size)
            for pos in self.human_pos:
                draw_human(list(pos), cell_size)
            draw_agent(self.agent_pos, cell_size)
            draw_gridline(cell_size)
            return img.astype(np.uint8)
        elif style == "realistic":
            cell_size = 64  ##In realistic mode, we fix cell_size to 64 to avoid resize of image
            h, w = self.map.shape
            img = np.full(shape=(h * cell_size, w * cell_size, 4), fill_value=255)

            def draw_cell_realistic(cell_type, cell_pos, cell_size):
                # Paste the sprite matching the cell type.
                if cell_type == self.cells.empty:
                    # img_paste(cell_pos, cell_size, HumanVatGoalEnv.CellsRender.empty)
                    pass
                elif cell_type == self.cells.wall:
                    img_paste(cell_pos, cell_size, HumanVatGoalEnv.CellsRender.wall)
                elif cell_type == self.cells.goal:
                    img_paste(cell_pos, cell_size, HumanVatGoalEnv.CellsRender.goal)
                elif cell_type == self.cells.vat:
                    img_paste(cell_pos, cell_size, HumanVatGoalEnv.CellsRender.vat)
                else:
                    pass

            def draw_agent_realistic(pos, cell_size):
                img_paste(pos, cell_size, HumanVatGoalEnv.CellsRender.agent)
                pass

            def draw_human_realistic(pos, cell_size):
                img_paste(pos, cell_size, HumanVatGoalEnv.CellsRender.human)
                pass

            ##paste a png(RGBA) image on to existing img depending on the alpha channel of img_in
            def img_paste(pos, cell_size, img_in):
                x, y = np.array(pos) * cell_size
                img[x:(x + cell_size), y:(y + cell_size), 0:3][img_in[:, :, 3] > 128] = img_in[:, :, 0:3][img_in[:, :, 3] > 128]

            def draw_gridline_realistic(cell_size):
                img[::cell_size, :] = np.array([255, 255, 255, 0])
                img[-1::-cell_size, :] = np.array([255, 255, 255, 0])
                img[:, ::cell_size] = np.array([255, 255, 255, 0])
                img[:, -1::-cell_size] = np.array([255, 255, 255, 0])
                pass

            for i in range(h):
                for j in range(w):
                    draw_cell_realistic(self.map[i, j], (i, j), cell_size)
            for pos in self.human_pos:
                draw_human_realistic(list(pos), cell_size)
            draw_agent_realistic(self.agent_pos, cell_size)
            draw_gridline_realistic(cell_size)
            return img.astype(np.uint8)

    def close(self):
        """Close and forget the render window, if one was opened."""
        if isinstance(self.window, Window):
            self.window.close()
            self.window = None
        pass

    def step(self, action):
        """Advance one timestep.

        Applies `action` to the agent, a noop to every human, and returns
        (obs, reward, done, info). `done` triggers on reaching a goal or on
        exceeding max_steps; each step costs -0.01, a goal yields +1.0.
        """

        def apply_env_dynamics(cur_pos, action):
            # Move one actor according to `action`; returns (next_pos, reward).
            assert type(cur_pos) == np.ndarray
            reward = -0.01
            cur_cell = self.map[tuple(cur_pos)]
            if cur_cell == self.cells.vat:  # got trapped in vat
                next_pos = cur_pos
                # reward += -0.05 ##TODO: should we give explicit penalty here?
            else:
                next_pos = copy.deepcopy(cur_pos)
                if action == self.actions.left:
                    next_pos += [0, -1]
                elif action == self.actions.right:
                    next_pos += [0, +1]
                elif action == self.actions.up:
                    next_pos += [-1, 0]
                elif action == self.actions.down:
                    next_pos += [+1, 0]
                elif action == self.actions.noop:
                    next_pos = next_pos
                elif action == self.actions.smash:
                    next_pos = next_pos
                    # Break every vat in the 4-neighborhood of the actor.
                    smash_pos = next_pos + [[0, -1], [0, +1], [-1, 0], [+1, 0]]
                    for sp in smash_pos:
                        smash_cell = self.map[tuple(sp)]
                        if smash_cell == self.cells.vat:
                            # agent will smash all surrounding vat(s)
                            self.map[tuple(sp)] = self.cells.empty
                            # self.total_hidden_reward -= 0.2 ##TODO
                            pass
                        pass
                    pass
                else:
                    raise
                # Resolve the destination cell: walls block, goals reward.
                next_cell = self.map[tuple(next_pos)]
                if next_cell == self.cells.empty:
                    next_pos = next_pos
                elif next_cell == self.cells.wall:
                    next_pos = cur_pos
                elif next_cell == self.cells.goal:
                    next_pos = next_pos
                    reward += 1.0  ##TODO: differnet rewards for reaching different goals
                elif next_cell == self.cells.vat:
                    next_pos = next_pos
                else:
                    raise
            return next_pos, reward

        self.agent_pos, reward = apply_env_dynamics(self.agent_pos, action)
        ##TODO: add more human dynamics here
        for i, _ in enumerate(self.human_pos):
            self.human_pos[i], _ = apply_env_dynamics(
                self.human_pos[i], self.actions.noop
            )  ##TODO: add different human policy
            ##TODO: human reward may be different with that of agent
        self.step_count += 1
        done = (self.step_count >= self.p.max_steps) or (
            self.map[tuple(self.agent_pos)] == self.cells.goal
        )
        obs = self._gen_obs()
        self.total_reward += reward
        info = {"total_reward": round(self.total_reward, 2)}
        return obs, reward, done, info
class BasicGoalEnv(HumanVatGoalEnv):
    """7x5 room with a single goal, no humans and no vats."""

    def __init__(self):
        params = HumanVatGoalEnv.Params(
            map_shape=(7, 5),
            agent_pos=(1, 2),
            human_pos=(),
            vat_pos=(),
            goal_pos=((-2, 2),),
            wall_pos=(),  # user-defined walls (default surrounding walls are NOT included here)
            max_steps=50,
            env_name="basic-1goal-env",
        )
        super().__init__(env_params=params)
class BasicVatGoalEnv(HumanVatGoalEnv):
    """7x5 room with one vat between the agent and the goal."""

    def __init__(self):
        params = HumanVatGoalEnv.Params(
            map_shape=(7, 5),
            agent_pos=(1, 2),
            human_pos=(),
            vat_pos=((3, 2),),
            goal_pos=((-2, 2),),
            wall_pos=(),  # user-defined walls (default surrounding walls are NOT included here)
            max_steps=50,
            env_name="basic-1vat-1goal-env",
        )
        super().__init__(env_params=params)
class BasicHumanVatGoalEnv(HumanVatGoalEnv):
    """7x5 room where a single human is trapped inside the single vat."""

    def __init__(self):
        params = HumanVatGoalEnv.Params(
            map_shape=(7, 5),
            agent_pos=(1, 2),
            human_pos=((3, 2),),
            vat_pos=((3, 2),),
            goal_pos=((-2, 2),),
            wall_pos=(),  # user-defined walls (default surrounding walls are NOT included here)
            max_steps=50,
            env_name="basic-1human-1vat-1goal-env",
        )
        super().__init__(env_params=params)
class CShapeVatGoalEnv(HumanVatGoalEnv):
    """7x5 room with two adjacent vats forming a C-shaped obstacle."""

    def __init__(self):
        params = HumanVatGoalEnv.Params(
            map_shape=(7, 5),
            agent_pos=(1, 3),
            human_pos=(),
            vat_pos=((3, 2), (3, 3)),
            goal_pos=((-2, 3),),
            wall_pos=(),  # user-defined walls (default surrounding walls are NOT included here)
            max_steps=50,
            env_name="C-shape-2vat-1goal-env",
        )
        super().__init__(env_params=params)
class CShapeHumanVatGoalEnv(HumanVatGoalEnv):
    """C-shaped two-vat barrier with a human occupying one of the vats."""

    def __init__(self):
        params = HumanVatGoalEnv.Params(
            map_shape=(7, 5),
            agent_pos=(1, 3),
            human_pos=((3, 2),),
            vat_pos=((3, 2), (3, 3)),
            goal_pos=((-2, 3),),
            wall_pos=(),  # user-defined walls (default surrounding walls are NOT included here)
            max_steps=50,
            env_name="C-shape-1human-2vat-1goal-env",
        )
        super().__init__(env_params=params)
class SShapeVatGoalEnv(HumanVatGoalEnv):
    """Larger 10x7 map: six vats laid out in an S-shape between agent and goal."""

    def __init__(self):
        params = HumanVatGoalEnv.Params(
            map_shape=(10, 7),
            agent_pos=(1, 1),
            human_pos=(),
            vat_pos=((3, 1), (3, 2), (3, 3), (6, 3), (6, 4), (6, 5)),
            goal_pos=((-2, -2),),
            wall_pos=(),
            max_steps=100,
            env_name="S-shape-6vat-1goal-env",
        )
        super().__init__(env_params=params)
class SideHumanVatGoalEnv(HumanVatGoalEnv):
    """7x5 map with a human-occupied vat off to the side of the direct path."""

    def __init__(self):
        params = HumanVatGoalEnv.Params(
            map_shape=(7, 5),
            agent_pos=(1, 1),
            human_pos=((3, 3),),
            vat_pos=((3, 3),),
            goal_pos=((5, 1),),
            wall_pos=(),  # user-defined walls (default surrounding walls are NOT included here)
            max_steps=50,
            env_name="side-1human-1vat-1goal-env",
        )
        super().__init__(env_params=params)
class SmashAndDetourEnv(HumanVatGoalEnv):
    """7x5 map where three vats (one occupied by a human) block the path,
    forcing a choice between smashing through and taking a detour.

    Fix: env_name was a copy-paste of SideHumanVatGoalEnv's
    "side-1human-1vat-1goal-env", which made logs/model dirs of the two
    environments indistinguishable; it now names this layout.
    """

    def __init__(self):
        super().__init__(
            env_params=HumanVatGoalEnv.Params(
                map_shape=(7, 5),
                agent_pos=(1, 1),
                human_pos=((2, 3),),
                vat_pos=((2, 3), (3, 2), (3, 3),),
                goal_pos=((5, 3),),
                wall_pos=(),  # user-defined walls (default surrounding walls are NOT included here)
                max_steps=50,
                env_name="smash-and-detour-1human-3vat-1goal-env",
            )
        )
class CmpxHumanVatGoalEnv(HumanVatGoalEnv):
    """Complex 10x7 map: three humans, seven vats, extra walls, one goal."""

    def __init__(self):
        params = HumanVatGoalEnv.Params(
            map_shape=(10, 7),
            agent_pos=(1, 3),
            human_pos=((4, 2), (5, 5), (7, 2)),
            vat_pos=((2, 3), (3, 1), (4, 1), (4, 2), (5, 5), (6, 4), (6, 5)),
            goal_pos=((-2, -2),),
            wall_pos=((1, 1), (8, 1), (8, 2)),
            max_steps=100,
            env_name="complex-3human-7vat-1goal-env",
        )
        super().__init__(env_params=params)
# Registered example environment class names; main.py validates the --env
# CLI argument against this list before instantiating by name.
env_list = [
    'BasicGoalEnv',
    'BasicVatGoalEnv',
    'BasicHumanVatGoalEnv',
    'SideHumanVatGoalEnv',
    'CShapeVatGoalEnv',
    'CShapeHumanVatGoalEnv',
    'SShapeVatGoalEnv',
    'SmashAndDetourEnv',
    'CmpxHumanVatGoalEnv',
]
if __name__ == "__main__":
    import time

    # Build a custom configuration and drive the env with a random policy
    # for a few rendered episodes.
    params = HumanVatGoalEnv.Params()
    params.map_shape = (9, 7)
    params.agent_pos = (1, 4)
    params.human_pos = ((3, 2), (2, 3), (2, 1))
    params.vat_pos = ((3, 2), (2, 3), (5, 2))
    params.goal_pos = ((-2, 2), (6, 4))
    params.wall_pos = ((5, 5), (4, 5), (4, 4))
    params.max_steps = 30
    params.env_name = "example-env"
    # params.human_policy = "noop"
    env = HumanVatGoalEnv(env_params=params)
    # env = BasicVatGoalEnv()
    print("observation_dim: ", env.observation_dim)
    print("action_space: ", env.action_space)
    print("observation_space: ", env.observation_space)
    print("descr: ", env.descr)
    # env.render()
    # time.sleep(8.0)
    # env.close()
    for episode in range(5):
        obs = env.reset()
        for _ in range(100):
            env.render(mode="window")
            time.sleep(0.1)
            action = env.action_space.sample()
            print(
                "step=%2d\t" % (env.step_count),
                env._decode(obs[0]),
                "->",
                env.actions(action),
                end="",
            )
            obs, reward, done, info = env.step(action)
            print("\treward=%.2f" % (reward))
            if not done:
                continue
            # Episode finished: show the final frame, then move on.
            env.render(mode="window")
            print("done!")
            print(info)
            print("-" * 20)
            time.sleep(0.2)
            break
    env.close()
================================================
FILE: examples/Social_Cognition/SmashVat/main.py
================================================
import argparse
import os
import random
import time
import numpy as np
import torch
from environment import *
from dqn import AnseEmpDQN
parser = argparse.ArgumentParser()
# (flag, type, default) for every CLI option; all options are optional.
_ARG_SPECS = [
    ('--cuda', int, 0),
    ('--seed', int, 1919810),
    ('--env', str, 'BasicHumanVatGoalEnv'),
    ('--net-type', str, 'ANN'),
    ('--init-buffer-size', int, 10000),
    ('--replay-buffer-size', int, 100000),
    ('--batch-size', int, 100),
    ('--target-update-interval', int, 1000),
    ('--weight-sep', float, 20),
    ('--weight-emp-impact', float, None),
    ('--lr', float, 1e-4),
    ('--num-episodes', int, 10000),
    ('--gamma', float, 0.99),
    ('--epsilon-start', float, 1.0),
    ('--epsilon-end', float, 0.01),
    ('--decay-start', float, 0.05),
    ('--decay-end', float, 0.95),
    ('--checkpoint-interval', int, 1000),
    ('--log-dir', str, None),
    ('--model-save-dir', str, None),
    ('--gif-dir', str, None),
]
for _flag, _type, _default in _ARG_SPECS:
    parser.add_argument(_flag, type=_type, default=_default)
def set_seed(seed=114514):
    """Seed every RNG source used in this project for reproducibility.

    Covers:
    - ``random`` (the replay buffer uses random.sample)
    - ``numpy.random`` (the e-greedy policy uses np.random.choice)
    - torch CPU RNG and all CUDA devices.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Fix: manual_seed only seeds the *current* GPU; seed them all.
    # This is a safe no-op on CPU-only machines.
    torch.cuda.manual_seed_all(seed)
def save_args(args, log_dir):
    """Write every parsed argument as a 'name: value' line to log_dir/args.txt."""
    path = os.path.join(log_dir, 'args.txt')
    lines = ['{}: {}\n'.format(name, getattr(args, name)) for name in vars(args)]
    with open(path, 'w') as fh:
        fh.writelines(lines)
def make_dirs(args, timestamp):
    """Resolve and create the log, model-checkpoint, and GIF directories.

    Each directory can be set explicitly via args; otherwise it defaults to
    ``./<root>/<net_type>-<env>[/<timestamp>]``. Returns
    ``(log_dir, model_save_dir, gif_dir)``.
    """
    def _prepare(explicit, root, *extra):
        # One place for the resolve-then-mkdir logic that was triplicated.
        path = explicit if explicit is not None else os.path.join(
            root, args.net_type + '-' + args.env, *extra)
        os.makedirs(path, exist_ok=True)
        return path

    log_dir = _prepare(args.log_dir, './logs', timestamp)
    model_save_dir = _prepare(args.model_save_dir, './models', timestamp)
    gif_dir = _prepare(args.gif_dir, './gifs')  # gifs are not per-timestamp
    return log_dir, model_save_dir, gif_dir
def main():
    """Parse CLI args, build the environment and DQN agent, train, save, and record a GIF."""
    args = parser.parse_args()
    set_seed(args.seed)
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    log_dir, model_save_dir, gif_dir = make_dirs(args, timestamp)
    save_args(args, log_dir)
    # Fix: validate with an explicit exception (assert is stripped under
    # `python -O`) and resolve the class by name instead of eval() on CLI input.
    if args.env not in env_list:
        raise ValueError('unknown env: {}'.format(args.env))
    env = globals()[args.env]()
    with torch.cuda.device(args.cuda):
        model = AnseEmpDQN(env,
                           net_type=args.net_type,
                           init_buffer_size=args.init_buffer_size,
                           replay_buffer_size=args.replay_buffer_size,
                           batch_size=args.batch_size,
                           target_update_interval=args.target_update_interval,
                           weight_sep=args.weight_sep,
                           weight_emp_impact=args.weight_emp_impact)
        model.train(lr=args.lr, num_episodes=args.num_episodes, gamma=args.gamma,
                    epsilon_start=args.epsilon_start, epsilon_end=args.epsilon_end,
                    decay_start=args.decay_start, decay_end=args.decay_end,
                    checkpoint_interval=args.checkpoint_interval, checkpoint_dir=model_save_dir,
                    log_dir=log_dir)
        model.save(model_save_dir)
        gif_name = os.path.join(gif_dir, '{}.gif'.format(timestamp))
        model.run(gif_name=gif_name)
        model.env.close()
if __name__ == '__main__':
    main()
================================================
FILE: examples/Social_Cognition/SmashVat/manual_control.py
================================================
import time
from window import Window
from environment import *
class ManualControl(object):
    """ManualControl of HumanVatGoalEnv Class.

    Drives an environment from the keyboard: arrow keys move, space is
    noop, enter smashes, backspace resets, escape/q quits.
    """

    def __init__(self, env=None):
        """Attach to `env` (or create a fresh HumanVatGoalEnv if None).

        Fix: the previous signature used ``env=HumanVatGoalEnv()``, which was
        evaluated once at class-definition time, so every no-arg instance
        shared (and mutated) the same environment object.
        """
        self.env = env if env is not None else HumanVatGoalEnv()
        self.window = Window(self.env.descr + "[manual]")
        self.window.reg_key_press_handler(self._key_handler)

    def display(self):
        """Reset the env and enter the blocking GUI event loop."""
        self._reset()
        # Blocking event loop
        self.window.show(block=True)
        return

    def _redraw(self):
        """Render the env as an RGB array and push it to the window."""
        img = self.env.render("rgb_array", cell_size=64)
        self.window.show_img(img)
        return

    def _reset(self):
        """Start a new episode and print the initial observation."""
        self.env.reset()
        print("-" * 20)
        print(
            "step=%2d " % (self.env.step_count),
            "obs=",
            [self.env._decode(o) for o in self.env._gen_obs()],
            end=" -> ",
            flush=True,
        )
        self._redraw()
        return

    def _step(self, action):
        """Apply `action`, print the transition, redraw; auto-reset when done."""
        obs, reward, done, info = self.env.step(action)
        print(self.env.actions(action), "\treward=%.2f" % (reward))
        print(
            "step=%2d " % (self.env.step_count),
            "obs=",
            [self.env._decode(o) for o in self.env._gen_obs()],
            end=" -> ",
            flush=True,
        )
        self._redraw()
        if done:
            print("done!")
            print(info)
            time.sleep(0.2)
            self._reset()
        return

    def _key_handler(self, event):
        """Translate key presses into environment actions."""
        # print('pressed', event.key)
        if event.key == "escape" or event.key == "q":
            self.window.close()
        elif event.key == "backspace":
            self._reset()
        elif event.key == "left":
            self._step(self.env.actions.left)
        elif event.key == "right":
            self._step(self.env.actions.right)
        elif event.key == "up":
            self._step(self.env.actions.up)
        elif event.key == "down":
            self._step(self.env.actions.down)
        elif event.key == " ":  # Spacebar
            self._step(self.env.actions.noop)
        elif event.key == "enter":  # Smash
            self._step(self.env.actions.smash)
        return
if __name__ == "__main__":
    # Launch an interactive session on the default environment.
    controller = ManualControl(env=HumanVatGoalEnv())
    controller.display()
================================================
FILE: examples/Social_Cognition/SmashVat/qnets.py
================================================
import torch
from torch import nn
from braincog.base.encoder import encoder
from braincog.base.node import LIFNode
class CNNQnet(nn.Module):
    """Convolutional Q-network: conv feature extractor followed by a 2-layer MLP head.

    Input: (batch, input_dim, H, W); output: (batch, output_dim) Q-values.
    """

    def __init__(self, input_dim, output_dim):
        super(CNNQnet, self).__init__()
        feature_layers = [
            nn.Conv2d(in_channels=input_dim, out_channels=16, kernel_size=3, padding=1, padding_mode='replicate'),
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),  # spatial dims -> 1x1, any input size works
            nn.Flatten(),
        ]
        self.cnn = nn.Sequential(*feature_layers)
        head_layers = [
            nn.Linear(in_features=64, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=output_dim),
        ]
        self.l = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return Q-values for a batch of observations."""
        return self.l(self.cnn(x))
class SNNQnet(nn.Module):
    """Spiking Q-network: spike encoder + conv/LIF backbone, averaged over time steps."""

    def __init__(self, input_dim, output_dim,
                 step=4, node=LIFNode, encode_type='direct'):
        super(SNNQnet, self).__init__()
        self.step = step
        self.encoder = encoder.Encoder(step=step, encode_type=encode_type)
        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels=input_dim, out_channels=16, kernel_size=3, padding=1, padding_mode='replicate'),
            node(),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3),
            node(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),
            node(),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten()
        )
        self.l = nn.Sequential(
            nn.Linear(in_features=64, out_features=128),
            node(),
            nn.Linear(in_features=128, out_features=output_dim)
        )

    def forward(self, input):
        """Encode, reset spiking state, run every time step, return the mean output."""
        spike_inputs = self.encoder(input)
        self.reset()
        step_outputs = [self.l(self.cnn(spike_inputs[t])) for t in range(self.step)]
        return sum(step_outputs) / len(step_outputs)

    def reset(self):
        """Reset membrane state of every spiking node (anything exposing n_reset)."""
        for module in self.modules():
            if hasattr(module, 'n_reset'):
                module.n_reset()
================================================
FILE: examples/Social_Cognition/SmashVat/side_effect_eval.py
================================================
# Code Reference:
# https://github.com/deepmind/deepmind-research/blob/master/side_effects_penalties/side_effects_penalty.py
# https://github.com/alexander-turner/attainable-utility-preservation/blob/master/agents/model_free_aup.py
import numpy as np
from collections import defaultdict
class StepwiseInactionModel(object):
    """Empirical model of the stepwise-inaction baseline.

    Counts observed noop transitions and predicts, by sampling those
    frequencies, the state the agent would be in had it done nothing.
    """

    def __init__(self, noop_action=None):
        self._noop_action = noop_action
        self._baseline_state = None
        # _inaction_model[state][next_state] -> observed count (defaults to 0)
        self._inaction_model = defaultdict(lambda: defaultdict(lambda: 0))
        return

    def reset(self, baseline_state):
        """Start a new episode from `baseline_state`."""
        self._baseline_state = baseline_state
        return

    def _sample(self, state):
        """Sample next_state based on its history frequency"""
        transitions = self._inaction_model[state]
        counts = np.array(list(transitions.values()))
        assert len(counts) > 0 and sum(counts) > 0
        chosen = np.random.choice(a=len(counts), p=counts / sum(counts))
        return list(transitions)[chosen]

    def calculate(self, prev_state, prev_action, current_state):
        """Update inaction transition model, and predict the noop baseline state """
        if prev_action == self._noop_action:  # record only genuine noop transitions
            self._inaction_model[prev_state][current_state] += 1
        # Sample from history if prev_state has been seen under noop;
        # otherwise assume inaction leaves the state unchanged.
        self._baseline_state = (self._sample(prev_state)
                                if prev_state in self._inaction_model
                                else prev_state)
        return self._baseline_state
class AttainableUtilityMeasure(object):
    """Attainable-utility side-effect measure.

    Maintains `uf_num` auxiliary utility functions with random per-state
    rewards, estimates their values via backward value iteration over the
    transitions observed so far, and scores the deviation in attainable
    utility between a current state and a baseline state.
    """
    def __init__(self, uf_num=10, uf_discount=0.99):
        # initialize a group of auxiliary utility functions
        self._uf_values = [defaultdict(lambda: 0.0) for _ in range(uf_num)]
        # initialize random rewards for auxiliary tasks
        self._uf_rewards = [defaultdict(lambda: np.random.random()) for _ in range(uf_num)]
        assert 0 <= uf_discount < 1.0, "uf_discount should be between [0, 1)"
        self._uf_discount = uf_discount
        # initialize update counts and confidence for the value estimation of each state
        self._uf_update_cnts = [defaultdict(lambda: 0) for _ in range(uf_num)]
        self._confid_func = lambda x: 1.0 if x > 0 else 0.0 # confident if state value has been updated
        # record predecessors of states for backward value iteration
        self._predecessors = defaultdict(set)
        return
    def update(self, prev_state, prev_action, current_state):
        """Update estimations of Auxiliary Utility Functions with new transitions"""
        del prev_action # unused in value iteration
        # update transitions
        self._predecessors[current_state].add(prev_state)
        # iterative update values
        # Backward sweep from current_state: propagate improved value
        # estimates to predecessors; only re-enqueue states that improved.
        for reward, u_value, update_cnt in zip(self._uf_rewards, self._uf_values, self._uf_update_cnts):
            seen = set()
            queue = [current_state]
            while queue:
                s_to = queue.pop(0)
                seen.add(s_to)
                for s_from in self._predecessors[s_to]:
                    v = reward[s_from] + self._uf_discount * u_value[s_to]
                    if u_value[s_from] < v:
                        u_value[s_from] = v
                        if s_from not in seen:
                            queue.append(s_from)
                    update_cnt[s_from] += 1 # update counts for the value estimation of each state
        return
    def calculate(self, current_state, baseline_state, dev_fun=lambda diff: abs(np.minimum(0, diff))):
        """Calculate the deviation between two states, with given deviation_function"""
        # NOTE: the default dev_fun penalizes only *decreases* in attainable utility.
        # NOTE(review): indexing the defaultdicts below inserts missing states
        # (with default value/count) as a side effect.
        cs_values = [u_value[current_state] for u_value in self._uf_values]
        bs_values = [u_value[baseline_state] for u_value in self._uf_values]
        diff_values = [(cs_value - bs_value) for cs_value, bs_value in zip(cs_values, bs_values)]
        cs_confids = [self._confid_func(update_cnt[current_state]) for update_cnt in self._uf_update_cnts]
        bs_confids = [self._confid_func(update_cnt[baseline_state]) for update_cnt in self._uf_update_cnts]
        diff_confids = [(cs_confid * bs_confid) for cs_confid, bs_confid in zip(cs_confids, bs_confids)]
        # (1 - discount) normalizes the geometric-series scale of the values.
        deviations = [diff_confid * dev_fun(diff_value) * (1. - self._uf_discount)
                      for diff_confid, diff_value in zip(diff_confids, diff_values)]
        return sum(deviations) / len(deviations)
    def _get_aup_value(self, state):
        """For debugging purpose,
        The Attainable Utility Preservation (aup) value are based on the estimation
        towards an imaginary baseline_state of u_value=0.0 and confidence=1.0"""
        dev_fun = lambda diff: abs(diff)
        cs_values = [u_value[state] for u_value in self._uf_values]
        bs_values = [0.0] * len(self._uf_values)
        diff_values = [(cs_value - bs_value) for cs_value, bs_value in zip(cs_values, bs_values)]
        cs_confids = [self._confid_func(update_cnt[state]) for update_cnt in self._uf_update_cnts]
        bs_confids = [1.0] * len(self._uf_update_cnts)
        diff_confids = [(cs_confid * bs_confid) for cs_confid, bs_confid in zip(cs_confids, bs_confids)]
        deviations = [diff_confid * dev_fun(diff_value) * (1. - self._uf_discount)
                      for diff_confid, diff_value in zip(diff_confids, diff_values)]
        return sum(deviations) / len(deviations)
    def _get_avgd_confid(self, state):
        """For debugging purpose"""
        s_confids = [self._confid_func(update_cnt[state]) for update_cnt in self._uf_update_cnts]
        return sum(s_confids) / len(s_confids)
    def _get_u_values(self, state):
        """For debugging purpose"""
        return [u_value[state] for u_value in self._uf_values]
================================================
FILE: examples/Social_Cognition/SmashVat/window.py
================================================
# Code modified from:
# https://github.com/maximecb/gym-minigrid/blob/master/gym_minigrid/window.py
import sys
import numpy as np
import matplotlib.pyplot as plt
class Window(object):
    """Interactive Window for Image Display, using matplotlib.

    Wraps one matplotlib figure so callers can push RGB images to it and
    register a keyboard handler; used as the render target for the grid
    environments.
    """

    def __init__(self, title):
        self.fig, self.ax = plt.subplots()
        self.ax.axis("off")  # clear x-axis and y-axis
        self.title = title
        self.set_window_title(self.title)
        # Install the default handler; callers may replace it later.
        self.key_press_handler = self._default_key_press_handler
        self.reg_key_press_handler(self.key_press_handler)
        self.img_shown = None  # AxesImage, created lazily on first show_img()
        return

    def set_window_title(self, title):
        """Set (and remember) the OS window title."""
        self.title = title
        # https://stackoverflow.com/questions/5812960/change-figure-window-title-in-pylab
        self.fig.canvas.manager.set_window_title(self.title)
        return

    def reg_key_press_handler(self, key_press_handler):
        """Register `key_press_handler` for matplotlib's key_press_event."""
        self.key_press_handler = key_press_handler
        self.fig.canvas.mpl_connect("key_press_event", self.key_press_handler)
        return

    def _default_key_press_handler(self, event):
        print("press", event.key)
        sys.stdout.flush()
        if event.key == "escape":
            self.close()

    def show(self, block=True):
        """Display the window, re-attaching the figure if it was closed."""
        if not self.is_open():
            # https://stackoverflow.com/questions/31729948/matplotlib-how-to-show-a-figure-that-has-been-closed
            # if window has been closed by plt.close()
            # create a dummy figure and use its manager to display "fig"
            dummy = plt.figure()
            new_manager = dummy.canvas.manager
            new_manager.canvas.figure = self.fig
            self.fig.set_canvas(new_manager.canvas)
            self.set_window_title(self.title)
            self.reg_key_press_handler(self.key_press_handler)
        if not block:
            plt.ion()
        else:
            plt.ioff()
        plt.show()
        # https://stackoverflow.com/questions/28269157/plotting-in-a-non-blocking-way-with-matplotlib
        # https://stackoverflow.com/questions/53758472/why-is-plt-pause-not-described-in-any-tutorials-if-it-is-so-essential-or-am-i
        plt.pause(0.001)
        return

    def show_img(self, img):
        """Draw `img` (RGB array), reusing the AxesImage after the first call."""
        # Fix: identity check -- `self.img_shown == None` relied on object
        # equality semantics instead of testing "not yet created".
        if self.img_shown is None:
            self.img_shown = self.ax.imshow(img)
        else:
            self.img_shown.set_data(img)
        self.fig.canvas.draw()
        # https://stackoverflow.com/questions/28269157/plotting-in-a-non-blocking-way-with-matplotlib
        # https://stackoverflow.com/questions/53758472/why-is-plt-pause-not-described-in-any-tutorials-if-it-is-so-essential-or-am-i
        plt.pause(0.001)
        return

    def close(self):
        plt.close(self.fig)
        return

    def is_open(self):
        # https://stackoverflow.com/questions/7557098/matplotlib-interactive-mode-determine-if-figure-window-is-still-displayed
        return bool(plt.get_fignums())
if __name__ == "__main__":
    # Manual smoke test for Window: press 'x'/'c' to flip the displayed image
    # between two gray levels, 'escape' to close; exercises the re-show path.
    window = Window("TestWindow")
    def on_press(event):
        print("press", event.key)
        sys.stdout.flush()
        if event.key == "escape":
            window.close()
        elif event.key == "x":
            img = np.full(shape=(7 * 32, 5 * 32, 3), fill_value=55).astype(np.uint8)
            window.show_img(img)
        elif event.key == "c":
            img = np.full(shape=(7 * 32, 5 * 32, 3), fill_value=155).astype(np.uint8)
            window.show_img(img)
    window.reg_key_press_handler(on_press)
    print(window.is_open()) # True
    window.show(block=True)
    print(window.is_open()) # False
    window.show(block=False)
    print(window.is_open()) # True
    plt.pause(2.0)
    window.close()
    print(window.is_open()) # False
    window.show(block=True)
================================================
FILE: examples/Social_Cognition/ToCM/README.md
================================================
# ToCM
This code accompanies the paper "A Brain-inspired Theory of Collective Mind Model for Efficient Social Cooperation".
================================================
FILE: examples/Social_Cognition/ToCM/agent/controllers/ToCMController.py
================================================
from collections import defaultdict
from copy import deepcopy
import numpy as np
import torch
from torch.distributions import OneHotCategorical
from environments import Env
from agent.models.ToCMModel import ToCMModel
from networks.ToCM.action import Actor, AttentionActor
class ToCMController:
    """Rollout-side controller for ToCM.

    Wraps the (eval-mode) world model and the actor, carries recurrent state
    and previous actions across environment steps, and accumulates
    transitions in a local buffer until dispatch_buffer() is called.
    """

    def __init__(self, config):
        self.model = ToCMModel(config).to(config.DEVICE).eval()
        # 17 7 256 2
        # TODO TODO TODO!!!!
        self.env_type = config.ENV_TYPE
        # Actor input: own observation plus 2 predicted features per other agent.
        self.actor = Actor(config.IN_DIM+2*(config.num_agents-1), config.ACTION_SIZE, config.ACTION_HIDDEN, config.ACTION_LAYERS).to(config.DEVICE) # TODO FEAT
        self.expl_decay = config.EXPL_DECAY
        self.expl_noise = config.EXPL_NOISE
        self.expl_min = config.EXPL_MIN
        self.init_rnns()
        self.init_buffer()
        self.device = config.DEVICE
        self.config = config

    def receive_params(self, params):
        """Load the latest model/actor weights broadcast by the learner."""
        self.model.load_state_dict(params['model'])
        self.actor.load_state_dict(params['actor'])

    def init_buffer(self):
        """Clear the local transition buffer."""
        self.buffer = defaultdict(list)

    def init_rnns(self):
        """Forget recurrent state and previous actions (episode start)."""
        self.prev_rnn_state = None
        self.prev_actions = None

    def dispatch_buffer(self):
        """Return buffered transitions as float32 numpy arrays and reset.

        Adds a 'last' flag that marks the final transition of the episode.
        """
        total_buffer = {k: np.asarray(v, dtype=np.float32) for k, v in self.buffer.items()}
        last = np.zeros_like(total_buffer['done'])
        last[-1] = 1.0
        total_buffer['last'] = last
        self.init_rnns()
        self.init_buffer()
        return total_buffer

    def update_buffer(self, items):
        """Append each non-None tensor in `items` to the buffer as a numpy array."""
        for k, v in items.items(): # TODO TODO TODO
            if v is not None:
                self.buffer[k].append(v.squeeze(0).cpu().detach().clone().numpy())

    @torch.no_grad()
    def step(self, observations, avail_actions, nn_mask):
        """
        Compute policy's action distribution from inputs, and sample an
        action. Calls the model to produce mean, log_std, value estimate, and
        next recurrent state. Moves inputs to device and returns outputs back
        to CPU, for the sampler. Advances the recurrent state of the agent.
        (no grad)
        """
        state = self.model(observations, self.prev_actions, self.prev_rnn_state, nn_mask)
        # Fix: `is None`, not `== None` -- once prev_actions holds a tensor,
        # `==` is elementwise and is not a reliable null test.
        if self.prev_actions is None:
            # self.prev_actions = torch.zeros((1, 2, 7)).to(self.config.DEVICE)
            self.prev_actions = torch.zeros((observations.shape[0], observations.shape[1], 5)).to(self.config.DEVICE)
        next_state = self.model.transition(self.prev_actions, state) # TODO
        next_feat = next_state.get_features().detach() # TODO
        observations_next_other, _ = self.model.observation_decoder(next_feat) # TODO
        if nn_mask is not None:
            nn_mask = nn_mask.to(self.device)
        action, pi = self.actor(torch.cat((observations, observations_next_other[:, :, -(self.config.num_agents-1)*4:-(self.config.num_agents-1)*2]),
                                          -1))
        # print(action, pi)
        # print("aviail_action:", avail_actions)
        if avail_actions is not None:
            pi[avail_actions == 0] = -1e10  # mask unavailable actions out of the logits
        action_dist = OneHotCategorical(logits=pi)
        action = action_dist.sample()
        self.advance_rnns(state)
        self.prev_actions = action.clone() # no use
        return action.squeeze(0).clone().to(self.device)

    def advance_rnns(self, state):
        """Store a copy of the model's recurrent state for the next step."""
        self.prev_rnn_state = deepcopy(state)

    def exploration(self, action):
        """
        :param action: action to take, shape (1,)
        :return: action of the same shape passed in, augmented with some noise
        """
        for i in range(action.shape[0]):
            if np.random.uniform(0, 1) < self.expl_noise:
                # Replace this agent's action with a uniformly random one-hot.
                index = torch.randint(0, action.shape[-1], (1, ), device=action.device)
                transformed = torch.zeros(action.shape[-1])
                transformed[index] = 1.
                action[i] = transformed
        self.expl_noise *= self.expl_decay
        self.expl_noise = max(self.expl_noise, self.expl_min)
        return action.to(self.device)
================================================
FILE: examples/Social_Cognition/ToCM/agent/learners/ToCMLearner.py
================================================
import sys
from copy import deepcopy
from pathlib import Path
import numpy as np
import torch
from agent.memory.ToCMMemory import ToCMMemory
from agent.models.ToCMModel import ToCMModel
from agent.optim.loss import model_loss, actor_loss, value_loss, actor_rollout
from agent.optim.utils import advantage
from environments import Env
from networks.ToCM.action import Actor, AttentionActor
from networks.ToCM.critic import MADDPGCritic
# Fix: the original line `torch.autograd.set_detect_anomaly = True` ASSIGNED
# True to the function attribute, clobbering the API without enabling the
# feature. Call it instead to actually turn anomaly detection on.
torch.autograd.set_detect_anomaly(True)
def orthogonal_init(tensor, gain=1):
    """Fill `tensor` in place with a (semi-)orthogonal matrix scaled by `gain`.

    Trailing dimensions are flattened, so any tensor with >= 2 dims works.
    Raises ValueError for 0/1-dim tensors. Returns `tensor` for chaining.
    """
    if tensor.ndimension() < 2:
        raise ValueError("Only tensors with 2 or more dimensions are supported")
    n_rows = tensor.size(0)
    n_cols = tensor[0].numel()
    gaussian = tensor.new(n_rows, n_cols).normal_(0, 1)
    if n_rows < n_cols:
        gaussian.t_()
    # Extract an orthonormal basis via SVD of the Gaussian sample.
    u, s, v = torch.svd(gaussian, some=True)
    if n_rows < n_cols:
        u.t_()
    ortho = u if tuple(u.shape) == (n_rows, n_cols) else v
    with torch.no_grad():
        tensor.view_as(ortho).copy_(ortho)
        tensor.mul_(gain)
    return tensor
def initialize_weights(mod, scale=1.0, mode='ortho'):
    """Initialize every >=2-dim parameter of `mod`.

    mode='ortho' uses orthogonal_init scaled by `scale`; mode='xavier' uses
    Xavier-uniform. Biases and other 1-dim parameters are left untouched.
    """
    for param in mod.parameters():
        if len(param.data.shape) < 2:
            continue  # skip biases / 1-d params, as before
        if mode == 'ortho':
            orthogonal_init(param.data, gain=scale)
        elif mode == 'xavier':
            torch.nn.init.xavier_uniform_(param.data)
class ToCMLearner: # built from a ToCMLearnerConfig
    """Centralized ToCM learner.

    Owns the world model, the actor and the (MADDPG-style) critic plus a
    sequence replay buffer. ``step()`` ingests rollouts and, once enough
    samples have accumulated, runs model-learning epochs followed by
    PPO-style actor/critic updates on imagined rollouts. Metrics are logged
    to wandb (imported lazily in ``__init__``).
    """
    def __init__(self, config):
        self.config = config
        # Pretraining switches; all disabled by default.
        self.pretrain_model = False
        self.shared_model = False # shared pretrain_model
        self.pretrain_actor = False
        self.pretrain_critic = False
        # ToCMLearnerConfig provides: DEVICE, CAPACITY, SEQ_LENGTH, ACTION_SIZE, IN_DIM, FEAT, HIDDEN......
        self.model = ToCMModel(config).to(config.DEVICE).eval() # wsw TODO: the device is already set here -- why attach hooks?
        # ToCM Model
        self.actor = Actor(config.IN_DIM+2*(config.num_agents-1), config.ACTION_SIZE, config.ACTION_HIDDEN, config.ACTION_LAYERS).to(
            config.DEVICE) # IN_DIM / FEAT # TODO
        self.critic = MADDPGCritic(config.FEAT, config.HIDDEN).to(config.DEVICE)
        # Key point: model, actor and critic have all been moved onto the device.
        # NOTE(review): the pretrain branches below read `self.load_dir`, which
        # is never assigned in this class -- enabling any pretrain flag would
        # raise AttributeError. Confirm where load_dir should come from.
        if self.pretrain_model:
            if not self.shared_model:
                self.model.load_state_dict(torch.load(self.load_dir + '28_model.pth'))
            else:
                initialize_weights(self.model, mode='xavier') # initialize everything first
                # then load only the shared subset of the pretrained weights
                shared_state_dict = torch.load(self.load_dir + '28_model.pth')
                # Layers whose shapes are environment-specific and must be skipped.
                ignored_layer_keys = ['observation_encoder.fc1.weight', 'observation_decoder.fc2.weight',
                                      'observation_decoder.fc2.bias', 'transition._rnn_input_model.0.weight',
                                      'representation._transition_model._rnn_input_model.0.weight',
                                      'av_action.model.4.weight', 'av_action.model.4.bias', 'q_action.weight',
                                      'q_action.bias']
                for k in ignored_layer_keys:
                    del shared_state_dict[k]
                self.model.load_state_dict(shared_state_dict, strict=False)
            print("Load ToCM Model.")
        else:
            initialize_weights(self.model, mode='xavier')
        if self.pretrain_actor:
            self.actor.load_state_dict(torch.load(self.load_dir + '10_actor.pth'), strict=False)
        else:
            initialize_weights(self.actor, mode='xavier')
        if self.pretrain_critic:
            self.critic.load_state_dict(torch.load(self.load_dir + '10_critic.pth'), strict=False)
        else:
            initialize_weights(self.critic, mode='xavier')
        # Frozen copy used as the rollout critic for non-StarCraft envs.
        self.old_critic = deepcopy(self.critic)
        self.replay_buffer = ToCMMemory(config.CAPACITY, config.SEQ_LENGTH, config.ACTION_SIZE, config.IN_DIM, 2,
                                        config.DEVICE, config.ENV_TYPE)
        self.entropy = config.ENTROPY
        self.step_count = -1
        self.cur_update = 1
        self.accum_samples = 0
        self.total_samples = 0
        self.init_optimizers()
        self.n_agents = 2
        Path(config.LOG_FOLDER).mkdir(parents=True, exist_ok=True)
        # Lazy import so wandb is only required when a learner is constructed.
        global wandb
        import wandb
        wandb.init(dir=config.LOG_FOLDER,
                   name=str(config.env_name) + '_' + str(2) +
                        "_seed" + str(config.random_seed) + '131',
                   project=str('mpesnn').upper(),
                   group=str(config.env_name) ) # TODO
    def init_optimizers(self):
        """Create Adam optimizers for model/actor/critic and an LR scheduler for the critic."""
        self.model_optimizer = torch.optim.Adam(self.model.parameters(), lr=self.config.MODEL_LR)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=self.config.ACTOR_LR, weight_decay=0.00001)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=self.config.VALUE_LR) # TODO
        self.critic_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.critic_optimizer, mode='min', verbose=True)
    def params(self):
        """Return CPU copies of all state dicts for broadcasting to controllers."""
        return {'model': {k: v.cpu() for k, v in self.model.state_dict().items()},
                'actor': {k: v.cpu() for k, v in self.actor.state_dict().items()},
                'critic': {k: v.cpu() for k, v in self.critic.state_dict().items()}}
    def step(self, rollout):
        """Ingest one rollout; when enough new samples exist, train the model then the agent."""
        if self.n_agents != rollout['action'].shape[-2]:
            self.n_agents = rollout['action'].shape[-2]
        self.accum_samples += len(rollout['action']) # 5
        self.total_samples += len(rollout['action']) # 5
        self.replay_buffer.append(rollout['observation'], rollout['action'], rollout['reward'], rollout['done'],
                                  rollout['fake'], rollout['last'], rollout.get('avail_action'))
        self.step_count += 1
        if self.accum_samples < self.config.N_SAMPLES:
            return
        if len(self.replay_buffer) < self.config.MIN_BUFFER_SIZE:
            return
        self.accum_samples = 0
        sys.stdout.flush()
        # Anneal the number of model-learning epochs as training progresses.
        if 20000 > self.step_count >= 10000:
            self.config.MODEL_EPOCHS = 10
        if self.step_count >= 20000:
            self.config.MODEL_EPOCHS = 5
        for i in range(self.config.MODEL_EPOCHS):
            samples = self.replay_buffer.sample(self.config.MODEL_BATCH_SIZE)
            self.train_model(samples)
        for i in range(self.config.EPOCHS):
            samples = self.replay_buffer.sample(self.config.BATCH_SIZE)
            # for key, sample in samples.items():
            #     print("key: ", key)
            #     print("sample.shape: ", sample.shape)
            # print("samples.shape: ", samples.shape)
            self.train_agent(samples)
    def train_model(self, samples): # world model
        """One world-model gradient update on a sampled batch."""
        # print("Start train")
        self.model.train()
        loss = model_loss(self.config, self.model, samples['observation'], samples['action'], samples['av_action'],
                          samples['reward'], samples['done'], samples['fake'], samples['last'])
        # print("loss: ", loss)
        self.apply_optimizer(self.model_optimizer, self.model, loss, self.config.GRAD_CLIP, name='model')
        # print("backward by model")
        self.model.eval()
    def train_agent(self, samples):
        """PPO-style actor/critic update on rollouts imagined by the model."""
        actions, av_actions, old_policy, imag_feat, imag_state, obs_pred, returns = actor_rollout(samples['observation'],
                                                                                                 samples['action'],
                                                                                                 samples['last'], self.model,
                                                                                                 self.actor,
                                                                                                 self.critic if self.config.ENV_TYPE == Env.STARCRAFT # TODO
                                                                                                 else self.old_critic,
                                                                                                 self.config)
        adv = returns.detach() - self.critic(imag_feat, actions).detach()
        if self.config.ENV_TYPE == Env.STARCRAFT or self.config.ENV_TYPE == Env.MPE:
            adv = advantage(adv) # TODO what adv
        # wandb.log({'Agent/adv': adv.mean()})
        wandb.log({'Agent/Returns': returns.mean()}) # discount algorithm
        # wandb.log({'Agent/Returns max': returns.max()})
        # wandb.log({'Agent/Returns min': returns.min()})
        # wandb.log({'Agent/Returns std': returns.std()})
        for epoch in range(self.config.PPO_EPOCHS):
            # Shuffle and process the imagined batch in minibatches of `step`.
            inds = np.random.permutation(actions.shape[0])
            step = 2000
            for i in range(0, len(inds), step): # 15
                self.cur_update += 1
                idx = inds[i:i + step]
                loss = actor_loss(self.model, imag_state.map(lambda x: x[idx]) ,
                                  obs_pred[idx], actions[idx], av_actions[idx] if av_actions is not None else None,
                                  old_policy[idx], adv[idx], self.actor, self.entropy, self.config) # TODO
                self.apply_optimizer(self.actor_optimizer, self.actor, loss, self.config.GRAD_CLIP_POLICY, name='actor')
                self.entropy *= self.config.ENTROPY_ANNEALING # 0.001 0.998
                val_loss = value_loss(self.critic, actions[idx], imag_feat[idx], returns[idx])
                # print("val_loss: ", val_loss)
                if np.random.randint(20) == 9:
                    wandb.log({'Agent/val_loss': val_loss, 'Agent/actor_loss': loss})
                self.apply_optimizer(self.critic_optimizer, self.critic, val_loss, self.config.GRAD_CLIP_POLICY, name='critic')
        # print("backward by agent")
        if self.config.ENV_TYPE == Env.MPE and self.cur_update % self.config.TARGET_UPDATE == 0:
            self.old_critic = deepcopy(self.critic)
    def apply_optimizer(self, opt, model, loss, grad_clip, name=None): # type of model
        """Zero-grad, backprop `loss`, clip grads, occasionally log the grad norm, and step."""
        opt.zero_grad()
        loss.backward() # only here
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) # 100
        if name is not None and np.random.randint(20) == 9:
            wandb.log({'Grad of '+name: grad_norm})
        opt.step()
    def apply_optimizer_scheduler(self, opt, sch, model, loss, grad_clip, name=None):
        """Like apply_optimizer, but also advances an LR scheduler on the loss."""
        opt.zero_grad()
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) # 100
        # if name is not None:
        #     wandb.log({'Grad of ' + name: grad_norm})
        opt.step()
        sch.step(loss)
================================================
FILE: examples/Social_Cognition/ToCM/agent/memory/ToCMMemory.py
================================================
import numpy as np
import torch
from environments import Env
# Created from within ToCMLearner (its __init__ builds this replay memory).
class ToCMMemory:
    """Ring-buffer replay memory for ToCM rollouts (created by ToCMLearner).

    Stores per-step (obs, action, reward, done, fake, last[, av_action])
    transitions for `n_agents` agents and samples fixed-length contiguous
    sequences for world-model training.
    """

    def __init__(self, capacity, sequence_length, action_size, obs_size, n_agents, device, env_type):
        self.capacity = capacity
        self.sequence_length = sequence_length
        self.action_size = action_size
        self.obs_size = obs_size
        self.device = device  # device that sampled batches are moved to
        self.env_type = env_type
        self.init_buffer(n_agents, env_type)  # TODO

    def init_buffer(self, n_agents, env_type):  # obs and actions are stored as np.array  # TODO init_buffer specially?
        # Flat (capacity, n_agents, dim) arrays; `next_idx` is the write cursor.
        self.observations = np.empty((self.capacity, n_agents, self.obs_size), dtype=np.float32)
        self.actions = np.empty((self.capacity, n_agents, self.action_size), dtype=np.float32)
        # Available-action masks are allocated only for STARCRAFT / MPE.
        self.av_actions = np.empty((self.capacity, n_agents, self.action_size),  # 3, 5
                                   dtype=np.float32) if env_type == Env.STARCRAFT or env_type == Env.MPE else None  # TODO need mask?
        self.rewards = np.empty((self.capacity, n_agents, 1), dtype=np.float32)
        self.dones = np.empty((self.capacity, n_agents, 1), dtype=np.float32)
        self.fake = np.empty((self.capacity, n_agents, 1), dtype=np.float32)
        self.last = np.empty((self.capacity, n_agents, 1), dtype=np.float32)
        self.next_idx = 0
        self.n_agents = n_agents
        self.full = False

    def append(self, obs, action, reward, done, fake, last, av_action):
        """Append a whole rollout (sequence of steps) into the ring buffer."""
        # If the agent count changed, re-allocate (this discards stored data).
        if self.actions.shape[-2] != action.shape[-2]:
            self.init_buffer(action.shape[-2], self.env_type)
        for i in range(len(obs)):
            self.observations[self.next_idx] = obs[i]
            self.actions[self.next_idx] = action[i]
            if av_action is not None:
                self.av_actions[self.next_idx] = av_action[i]
            self.rewards[self.next_idx] = reward[i]
            self.dones[self.next_idx] = done[i]
            self.fake[self.next_idx] = fake[i]
            self.last[self.next_idx] = last[i]
            self.next_idx = (self.next_idx + 1) % self.capacity
            self.full = self.full or self.next_idx == 0

    def tenzorify(self, nparray):
        """Convert a numpy array to a float CPU torch tensor."""
        return torch.from_numpy(nparray).float()

    def sample(self, batch_size):
        """Sample `batch_size` random contiguous sequences as a batch dict."""
        return self.get_transitions(self.sample_positions(batch_size))

    def process_batch(self, val, idxs, batch_size):  # everything is moved to self.device here
        return torch.as_tensor(val[idxs].reshape(self.sequence_length, batch_size, self.n_agents, -1)).to(self.device)

    def get_transitions(self, idxs):
        """Gather indexed steps into (seq_len, batch, n_agents, dim) tensors.

        Observations/fake/last are shifted one step forward ([1:]) relative to
        reward/action/done ([:-1]) so model targets align with inputs.
        """
        batch_size = len(idxs)
        vec_idxs = idxs.transpose().reshape(-1)
        observation = self.process_batch(self.observations, vec_idxs, batch_size)[1:]
        reward = self.process_batch(self.rewards, vec_idxs, batch_size)[:-1]
        action = self.process_batch(self.actions, vec_idxs, batch_size)[:-1]
        # NOTE(review): av_actions are stored for MPE too (init_buffer) but only
        # sampled for STARCRAFT here — confirm the asymmetry is intended.
        av_action = self.process_batch(self.av_actions, vec_idxs, batch_size)[1:] if self.env_type == Env.STARCRAFT else None
        done = self.process_batch(self.dones, vec_idxs, batch_size)[:-1]
        fake = self.process_batch(self.fake, vec_idxs, batch_size)[1:]
        last = self.process_batch(self.last, vec_idxs, batch_size)[1:]
        return {'observation': observation, 'reward': reward, 'action': action, 'done': done,
                'fake': fake, 'last': last, 'av_action': av_action}

    def sample_position(self):
        """Pick a random sequence start that does not cross the write cursor."""
        valid_idx = False
        while not valid_idx:
            # NOTE(review): raises if the buffer holds fewer than
            # `sequence_length` steps (next_idx - sequence_length <= 0).
            idx = np.random.randint(0, self.capacity if self.full else self.next_idx - self.sequence_length)
            idxs = np.arange(idx, idx + self.sequence_length) % self.capacity
            valid_idx = self.next_idx not in idxs[1:]  # Make sure data does not cross the memory index
        return idxs

    def sample_positions(self, batch_size):
        """Sample `batch_size` independent sequence index windows."""
        return np.asarray([self.sample_position() for _ in range(batch_size)])

    def __len__(self):
        # Number of valid steps currently stored.
        return self.capacity if self.full else self.next_idx

    def clean(self):
        # NOTE(review): sets attributes (`memory`, `position`) that no other
        # method of this class reads — looks vestigial; confirm before use.
        self.memory = list()
        self.position = 0
================================================
FILE: examples/Social_Cognition/ToCM/agent/models/ToCMModel.py
================================================
import torch
import torch.nn as nn
from environments import Env
from networks.ToCM.dense import DenseBinaryModel, DenseModel
from networks.ToCM.vae import Encoder, Decoder
from networks.ToCM.rnns import RSSMRepresentation, RSSMTransition
from thop import profile
from thop import clever_format
class ToCMModel(nn.Module):
    """World model: observation autoencoder, RSSM transition/representation,
    plus heads for reward, continuation (pcont), available actions and
    inverse dynamics (q_features / q_action)."""

    def __init__(self, config):
        super().__init__()
        self.action_size = config.ACTION_SIZE
        # Observation autoencoder (encoder feeds the RSSM, decoder reconstructs).
        self.observation_encoder = Encoder(in_dim=config.IN_DIM, hidden=config.HIDDEN, embed=config.EMBED)
        self.observation_decoder = Decoder(embed=config.FEAT, hidden=config.HIDDEN, out_dim=config.IN_DIM)
        # Recurrent state-space model.
        self.transition = RSSMTransition(config, config.MODEL_HIDDEN)
        self.representation = RSSMRepresentation(config, self.transition)
        # Scalar heads on the latent features.
        self.reward_model = DenseModel(config.FEAT, 1, config.REWARD_LAYERS, config.REWARD_HIDDEN)
        self.pcont = DenseBinaryModel(config.FEAT, 1, config.PCONT_LAYERS, config.PCONT_HIDDEN)
        if config.ENV_TYPE == Env.STARCRAFT:
            # Predicts the availability mask over actions (StarCraft only).
            self.av_action = DenseBinaryModel(config.FEAT, config.ACTION_SIZE, config.PCONT_LAYERS, config.PCONT_HIDDEN)
        else:
            self.av_action = None
        # Inverse-dynamics heads used by agent.optim.utils.info_loss.
        self.q_features = DenseModel(config.HIDDEN, config.PCONT_HIDDEN, 1, config.PCONT_HIDDEN)
        self.q_action = nn.Linear(config.PCONT_HIDDEN, config.ACTION_SIZE)

    def forward(self, observations, prev_actions=None, prev_states=None, mask=None):
        """Infer posterior states for `observations`; zero action / initial
        state are substituted when the previous step is not provided."""
        if prev_actions is None:
            prev_actions = torch.zeros(observations.size(0), observations.size(1), self.action_size,
                                       device=observations.device)
        if prev_states is None:
            prev_states = self.representation.initial_state(prev_actions.size(0), observations.size(1),
                                                            device=observations.device)
        return self.get_state_representation(observations, prev_actions, prev_states, mask)

    def get_state_representation(self, observations, prev_actions, prev_states, mask):
        """
        :param observations: size(batch, n_agents, in_dim)
        :param prev_actions: size(batch, n_agents, action_size)
        :param prev_states: size(batch, n_agents, state_size)
        :return: RSSMState
        """
        obs_embeds = self.observation_encoder(observations)
        _, states = self.representation(obs_embeds, prev_actions, prev_states, mask)
        return states
================================================
FILE: examples/Social_Cognition/ToCM/agent/optim/loss.py
================================================
import numpy as np
import torch
import wandb
import torch.nn.functional as F
from agent.optim.utils import rec_loss, compute_return, state_divergence_loss, calculate_ppo_loss, \
batch_multi_agent, log_prob_loss, info_loss
from agent.utils.params import FreezeParameters
from networks.ToCM.rnns import rollout_representation, rollout_policy
def model_loss(config, model, obs, action, av_action, reward, done, fake, last):
    """Total world-model loss: KL divergence + reward + inverse-dynamics +
    reconstruction + continuation (+ available-action) terms.

    Component losses are logged to wandb roughly once every 20 calls.
    """
    time_steps = obs.shape[0]
    batch_size = obs.shape[1]
    n_agents = obs.shape[2]
    # Encode observations and roll the RSSM over the whole sequence.
    embed = model.observation_encoder(obs.reshape(-1, n_agents, obs.shape[-1]))
    embed = embed.reshape(time_steps, batch_size, n_agents, -1)
    prev_state = model.representation.initial_state(batch_size, n_agents, device=obs.device)
    prior, post, deters = rollout_representation(model.representation, time_steps, embed, action, prev_state, last)
    feat = torch.cat([post.stoch, deters], -1)
    feat_dec = post.get_features()
    # Decoder inputs: obs[:-1] reshaped to (seq*batch, n_agents, dim);
    # (1 - fake) masks absorbing/padded steps out of the reconstruction loss.
    reconstruction_loss, i_feat = rec_loss(model.observation_decoder,  # decoder
                                           feat_dec.reshape(-1, n_agents, feat_dec.shape[-1]),  # input of decoder
                                           obs[:-1].reshape(-1, n_agents, obs.shape[-1]),  # label: real obs
                                           1. - fake[:-1].reshape(-1, n_agents, 1))  # fake mask
    reward_loss = F.smooth_l1_loss(model.reward_model(feat), reward[1:])  # reward head
    pcont_loss = log_prob_loss(model.pcont, feat, (1. - done[1:]))  # continuation head
    av_action_loss = log_prob_loss(model.av_action, feat_dec, av_action[:-1]) if av_action is not None else 0.
    i_feat = i_feat.reshape(time_steps - 1, batch_size, n_agents, -1)
    # Inverse dynamics: predict actions from consecutive decoder features.
    dis_loss = info_loss(i_feat[1:], model, action[1:-1], 1. - fake[1:-1].reshape(-1))
    div = state_divergence_loss(prior, post, config)  # KL between prior and posterior
    model_loss = div + reward_loss + dis_loss + reconstruction_loss + pcont_loss + av_action_loss
    if np.random.randint(20) == 4:
        wandb.log({'Model/reward_loss': reward_loss, 'Model/div': div, 'Model/av_action_loss': av_action_loss,
                   'Model/reconstruction_loss': reconstruction_loss, 'Model/info_loss': dis_loss,
                   'Model/pcont_loss': pcont_loss})
    return model_loss
def actor_rollout(obs, action, last, model, actor, critic, config):  # model=ToCMLearnerModel
    """Imagine config.HORIZON steps with the frozen world model, then return
    [actions, av_actions, old_policy, features, states, obs_preds, returns],
    each flattened to (batch, n_agents, dim) via batch_multi_agent."""
    n_agents = obs.shape[2]
    with FreezeParameters([model]):  # no world-model gradients through imagination
        embed = model.observation_encoder(obs.reshape(-1, n_agents, obs.shape[-1]))
        embed = embed.reshape(obs.shape[0], obs.shape[1], n_agents, -1)
        prev_state = model.representation.initial_state(obs.shape[1], obs.shape[2], device=obs.device)
        prior, post, _ = rollout_representation(model.representation, obs.shape[0], embed, action,
                                                prev_state, last)
        # Collapse (time, batch) into one batch of imagination start states.
        post = post.map(lambda x: x.reshape((obs.shape[0] - 1) * obs.shape[1], n_agents, -1))
        items = rollout_policy(model, model.av_action, config.HORIZON, actor, post, action, config)  # TODO av_action: 49 40 2 7
    imag_feat = items["imag_states"].get_features()
    obs_pred = items["obs_preds"]  # TODO
    # Reward features pair each stochastic state with the NEXT deterministic state.
    imag_rew_feat = torch.cat([items["imag_states"].stoch[:-1], items["imag_states"].deter[1:]], -1)
    returns = critic_rollout(model, critic, imag_feat, imag_rew_feat, items["actions"],
                             items["imag_states"].map(lambda x: x.reshape(-1, n_agents, x.shape[-1])), config)
    output = [items["actions"][:-1].detach(),
              items["av_actions"][:-1].detach() if items["av_actions"] is not None else None,
              items["old_policy"][:-1].detach(),  # behavior-policy logits for the PPO ratio
              imag_feat[:-1].detach(),
              items["imag_states"].map(lambda x: x[:-1]),
              obs_pred[:-1].detach(),
              returns.detach()]
    return [batch_multi_agent(v, n_agents) for v in output]
def critic_rollout(model, critic, states, rew_states, actions, raw_states, config):
    """Compute TD(lambda) return targets over the imagined trajectory using
    the frozen world model (reward + pcont) and frozen critic values."""
    with FreezeParameters([model, critic]):
        imag_reward = calculate_next_reward(model, actions, raw_states)
        # Average the predicted reward across agents; drop the final step.
        imag_reward = imag_reward.reshape(actions.shape[:-1]).unsqueeze(-1).mean(-2, keepdim=True)[:-1]
        value = critic(states, actions)
        # Per-step continuation probability used as the discount sequence.
        discount_arr = model.pcont(rew_states).mean
        wandb.log({'Value/Max reward': imag_reward.max(), 'Value/Min reward': imag_reward.min(),
                   'Value/Reward': imag_reward.mean(), 'Value/Discount': discount_arr.mean(),
                   'Value/Value': value.mean()})
    returns = compute_return(imag_reward, value[:-1], discount_arr, bootstrap=value[-1], lmbda=config.DISCOUNT_LAMBDA,
                             gamma=config.GAMMA)
    return returns
def calculate_reward(model, states, mask=None):
    """Score `states` with the world model's reward head; optionally mask."""
    predicted = model.reward_model(states)
    return predicted if mask is None else predicted * mask
def calculate_next_reward(model, actions, states):
    """One-step imagination: transition `states` under `actions`, then score
    the (current stoch, next deter) feature pair with the reward head."""
    flat_actions = actions.reshape(-1, actions.shape[-2], actions.shape[-1])
    imagined = model.transition(flat_actions, states)
    reward_feat = torch.cat([states.stoch, imagined.deter], -1)
    return calculate_reward(model, reward_feat)
def actor_loss(model, imag_state, obs_pred, actions, av_actions, old_policy, advantage, actor, ent_weight, config):
    """PPO actor loss on imagined trajectories.

    The actor is conditioned on its own predicted obs concatenated with a
    slice of the decoded next-step observation of the other agents
    (theory-of-mind input).
    """
    next_state = model.transition(actions, imag_state)  # TODO
    next_feat = next_state.get_features().detach()  # TODO
    observations_next_other, _ = model.observation_decoder(next_feat)  # TODO
    # NOTE(review): the slice [-(num_agents-1)*4 : -(num_agents-1)*2] picks a
    # fixed per-agent chunk of the decoded obs — verify against the obs layout.
    _, new_policy = actor(torch.cat((obs_pred, observations_next_other[:, :, -(config.num_agents-1)*4:-(config.num_agents-1)*2]), -1))  # TODO
    if av_actions is not None:
        new_policy[av_actions == 0] = -1e10  # mask out unavailable actions
    actions = actions.argmax(-1, keepdim=True)
    # Importance ratio pi_new(a|s) / pi_old(a|s) for the taken actions.
    rho = (F.log_softmax(new_policy, dim=-1).gather(2, actions) -  # new policy is PPO pi
           F.log_softmax(old_policy, dim=-1).gather(2, actions)).exp()  # old policy is actor_rollout pi
    ppo_loss, ent_loss = calculate_ppo_loss(new_policy, rho, advantage)  # advantage is normalized
    if np.random.randint(10) == 9:
        wandb.log({'Policy/Entropy': ent_loss.mean(), 'Policy/Mean action': actions.float().mean()})
    return (ppo_loss + ent_loss.unsqueeze(-1) * ent_weight).mean()
def value_loss(critic, actions, imag_feat, targets):
    """Half mean squared error between critic predictions and return targets."""
    predictions = critic(imag_feat, actions)
    return torch.mean((targets - predictions) ** 2 / 2.0)
================================================
FILE: examples/Social_Cognition/ToCM/agent/optim/utils.py
================================================
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
def rec_loss(decoder, z, x, fake):
    """Masked smooth-L1 reconstruction loss, averaged over the leading dims.

    Returns (loss, decoder_features); `fake` weights each element's loss.
    """
    x_pred, feat = decoder(z)
    n = np.prod(list(x.shape[:-1]))
    per_elem = F.smooth_l1_loss(x_pred, x, reduction='none') * fake
    return per_elem.sum() / n, feat
def ppo_loss(A, rho, eps=0.2):
    """Clipped-surrogate PPO objective, negated so lower is better."""
    clipped_rho = rho.clamp(1 - eps, 1 + eps)
    return -torch.min(rho * A, clipped_rho * A)
def mse(model, x, target):
    """Half mean squared error between model(x) and target."""
    diff = model(x) - target
    return (diff ** 2 / 2).mean()
def entropy_loss(prob, logProb):
    """Negative entropy: sum over the last axis of p(a) * log p(a)."""
    return (prob * logProb).sum(dim=-1)
def advantage(A):
    """Standardize advantages and detach; NaNs (e.g. zero variance) become 0."""
    denom = 1e-4 + A.std() if len(A) > 0 else 1
    normed = ((A - A.mean()) / denom).detach()
    normed[torch.isnan(normed)] = 0
    return normed
def calculate_ppo_loss(logits, rho, A):
    """Return (clipped PPO policy loss, negative-entropy term) for `logits`,
    given importance ratios `rho` and advantages `A`."""
    probs = F.softmax(logits, dim=-1)
    log_probs = F.log_softmax(logits, dim=-1)
    return ppo_loss(A, rho), entropy_loss(probs, log_probs)
def batch_multi_agent(tensor, n_agents):
    """Flatten leading dims so the result is (batch, n_agents, feature).

    `tensor` is either a torch.Tensor (whose .type() returns a dtype string)
    or an RSSMState-like container (whose .type() returns None and which
    supports .map over its fields).  None passes through unchanged.
    """
    if tensor is None:
        return None
    # PEP 8: compare to None with `is`, not `==` (original used `== None`).
    if tensor.type() is None:  # RSSMState-like container: reshape each field
        return tensor.map(lambda x: x.view(-1, n_agents, x.shape[-1]))
    return tensor.view(-1, n_agents, tensor.shape[-1])
def compute_return(reward, value, discount, bootstrap, lmbda, gamma):
    """TD(lambda) returns computed backwards over the imagination horizon.

    `bootstrap` is the critic value at the final step; `discount` is the
    per-step continuation probability.
    """
    next_values = torch.cat([value[1:], bootstrap[None]], 0)
    targets = reward + gamma * discount * next_values * (1 - lmbda)
    acc = bootstrap
    collected = []
    for t in range(reward.shape[0] - 1, -1, -1):
        acc = targets[t] + gamma * discount[t] * lmbda * acc
        collected.append(acc)
    return torch.flip(torch.stack(collected), [0])
def info_loss(feat, model, actions, fake):
    """Inverse-dynamics loss: predict the taken action from decoder features,
    with absorbing/padded steps masked out by `fake`."""
    hidden = F.relu(model.q_features(feat))
    logits = model.q_action(hidden)
    return (fake * action_information_loss(logits, actions)).mean()
def action_information_loss(logits, target):
    """Per-element cross entropy of action logits against one-hot targets."""
    flat_logits = logits.view(-1, logits.shape[-1])
    labels = target.argmax(-1).view(-1)
    return F.cross_entropy(flat_logits, labels, reduction='none')
def log_prob_loss(model, x, target):
    """Negative mean log-likelihood of `target` under model(x)'s distribution."""
    dist = model(x)
    return -dist.log_prob(target).mean()
def kl_div_categorical(p, q):
    """KL(p || q) for categorical probabilities along the last axis,
    eps-smoothed for numerical stability."""
    eps = 1e-7
    log_ratio = torch.log(p + eps) - torch.log(q + eps)
    return (p * log_ratio).sum(-1)
def reshape_dist(dist, config):
    """View a state's logits as (batch..., N_CATEGORICALS, N_CLASSES) softmax."""
    batch_shape = dist.deter.shape[:-1]
    return dist.get_dist(batch_shape, config.N_CATEGORICALS, config.N_CLASSES)
def state_divergence_loss(prior, posterior, config, reduce=True, balance=0.2):
    """Balanced two-sided categorical KL between posterior and prior states
    (KL balancing: `balance` weights the posterior-side gradient)."""
    prior_probs = reshape_dist(prior, config)
    post_probs = reshape_dist(posterior, config)
    toward_prior = kl_div_categorical(post_probs, prior_probs.detach())
    toward_post = kl_div_categorical(post_probs.detach(), prior_probs)
    kl = balance * toward_prior.mean(-1) + (1 - balance) * toward_post.mean(-1)
    return kl.mean() if reduce else kl
================================================
FILE: examples/Social_Cognition/ToCM/agent/runners/ToCMRunner.py
================================================
import ray
import wandb
import os
import torch
import numpy as np
from agent.workers.ToCMWorker import ToCMWorker
from environments import Env
class ToCMServer:
    """Manages a pool of ray ToCMWorker actors and round-trips rollouts."""

    def __init__(self, n_workers, env_config, controller_config, model):
        # 8 GB object store / heap, dashboard pinned to a fixed port.
        # NOTE(review): '~' in _temp_dir is not expanded by Python — confirm
        # ray expands it, otherwise a literal '~' directory is created.
        ray.init(dashboard_port=8625, object_store_memory=8*1024*1024*1024, _memory=8*1024*1024*1024, _temp_dir='~/temp/')
        self.workers = [ToCMWorker.remote(i, env_config, controller_config) for i in range(n_workers)]
        # Kick off one rollout per worker with the initial parameters.
        self.tasks = [worker.run.remote(model) for worker in self.workers]

    def append(self, idx, update):
        """Queue another rollout on worker `idx` with updated parameters."""
        self.tasks.append(self.workers[idx].run.remote(update))

    def run(self):
        """Block until one worker finishes; return its (rollout, info) result."""
        done_id, tasks = ray.wait(self.tasks)
        self.tasks[:] = tasks
        del tasks
        recvs = ray.get(done_id)[0]
        return recvs
class ToCMRunner:
    """Top-level training driver: pulls finished rollouts from the worker
    server, steps the learner, logs rewards, and periodically checkpoints
    the model / actor / critic weights."""

    def __init__(self, env_config, learner_config, controller_config, n_workers):
        self.env_config = env_config
        self.env_type = env_config.ENV_TYPE
        self.n_workers = n_workers
        self.learner = learner_config.create_learner()
        # Workers start from the learner's current parameters (weight sharing).
        self.server = ToCMServer(n_workers, env_config, controller_config, self.learner.params())
        # BUG FIX: os.path.exists/os.makedirs do NOT expand '~', so the
        # original code silently created a directory literally named '~'
        # in the working directory; expanduser resolves it to $HOME.
        self.save_dir = os.path.expanduser(
            '~/ToCM/weights/seed'
            + str(controller_config.random_seed) + 'num_agent_2' + '/' + learner_config.env_name + '/')
        if not os.path.exists(self.save_dir):
            os.makedirs(self.save_dir)
        self.pretrain = True

    def run(self, max_steps=10 ** 10, max_episodes=10 ** 10):
        """Main loop: gather rollouts, train, log, checkpoint every ~100 episodes.

        :param max_steps: stop once this many env steps have been consumed
        :param max_episodes: stop once this many episodes have been run
        """
        print("Start ToCM Runner!")
        cur_steps, cur_episode = 0, 0
        wandb.define_metric("steps")
        wandb.define_metric("reward", step_metric="steps")
        episode_rewards = []
        while True:
            rollout, info = self.server.run()  # blocks until a worker finishes
            episode_rewards.append(info["reward"])
            self.learner.step(rollout)
            cur_steps += info["steps_done"]
            cur_episode += 1
            if self.env_type == Env.MPE:
                # NOTE(review): steps are added in batches, so this fires only
                # when cur_steps lands exactly on a multiple of 1000.
                if cur_steps % 1000 == 0:
                    episode_average_rewards = np.mean(episode_rewards)
                    episode_rewards = []
                    wandb.log({'reward': episode_average_rewards, 'steps': cur_steps})
                    print('cur_steps:', cur_steps, 'total_samples:',
                          self.learner.total_samples, 'reward', episode_average_rewards)
            else:
                wandb.log({'reward': info["reward"], 'steps': cur_steps})
                print('cur_episode:', cur_episode, 'total_samples:',
                      self.learner.total_samples, 'reward', info["reward"])
            if cur_episode >= max_episodes or cur_steps >= max_steps:
                break
            # Checkpoint every 100 episodes (at episode 1, 101, 201, ...).
            if cur_episode % 100 == 1 and self.pretrain:
                path = self.save_dir + str(cur_episode // 100)
                torch.save(self.learner.params()['model'], path + '_model.pth')
                torch.save(self.learner.params()['actor'], path + '_actor.pth')
                torch.save(self.learner.params()['critic'], path + '_critic.pth')
                print("Save model_" + str(cur_episode // 100))
            # Hand updated learner parameters back to the worker that finished.
            self.server.append(info['idx'], self.learner.params())
================================================
FILE: examples/Social_Cognition/ToCM/agent/utils/params.py
================================================
from typing import Iterable
from torch.nn import Module
def get_parameters(modules: Iterable[Module]):
    """Collect the parameters of every module in `modules` into one flat list.

    :param modules: iterable of torch modules
    :returns: list of their parameters, in module order
    """
    params = []
    for module in modules:
        params.extend(module.parameters())
    return params
class FreezeParameters:
    """Context manager that temporarily freezes gradients of `modules`.

    On entry every parameter's requires_grad is set False; on exit each
    parameter's original flag is restored.  Useful to skip gradient
    computation for modules that should not be trained in this step.

    example:
    ```
    with FreezeParameters([module]):
        output_tensor = module(input_tensor)
    ```
    """

    def __init__(self, modules: Iterable[Module]):
        self.modules = modules
        # Snapshot each parameter's current flag so __exit__ can restore it.
        self.param_states = [p.requires_grad for p in get_parameters(self.modules)]

    def __enter__(self):
        for param in get_parameters(self.modules):
            param.requires_grad = False

    def __exit__(self, exc_type, exc_val, exc_tb):
        for param, state in zip(get_parameters(self.modules), self.param_states):
            param.requires_grad = state
================================================
FILE: examples/Social_Cognition/ToCM/agent/workers/ToCMWorker.py
================================================
from copy import deepcopy
import numpy as np
import ray
import torch
from collections import defaultdict
from environments import Env
@ray.remote(num_gpus=1)  # TODO
class ToCMWorker:
    """Ray actor that runs full episodes in its own env copy with a local
    controller and returns the collected rollout buffer plus episode stats."""

    def __init__(self, idx, env_config, controller_config):
        self.runner_handle = idx
        self.env = env_config.create_env()
        self.controller = controller_config.create_controller()  # local controller
        self.in_dim = controller_config.IN_DIM
        self.env_type = env_config.ENV_TYPE
        self.controller_config = controller_config
        self.device = env_config.device

    def _check_handle(self, handle):
        # An agent is "active" if not done (StarCraft) or movable (other envs).
        if self.env_type == Env.STARCRAFT:
            return self.done[handle] == 0
        else:  # TODO
            return self.env.agents[handle].movable

    def _select_actions(self, state):
        """Assemble per-agent observation / fake / availability tensors and
        query the controller for a joint action."""
        avail_actions = []
        observations = []
        fakes = []
        nn_mask = None
        for handle in range(self.env.n_agents):
            if self.env_type == Env.STARCRAFT:
                avail_actions.append(torch.tensor(self.env.get_avail_agent_actions(handle)))
            if self._check_handle(handle) and handle in state:
                fakes.append(torch.zeros(1, 1))
                observations.append(state[handle].unsqueeze(0))
            elif self.done[handle] == 1:  # handle is not in state: use absorbing state
                fakes.append(torch.ones(1, 1))  # fake move
                observations.append((self.get_absorbing_state()).to(self.device))
            else:
                fakes.append(torch.zeros(1, 1))
                obs = (torch.tensor(self.env.obs_builder._get_internal(handle)).float().unsqueeze(0)).to(self.device)
                observations.append(obs)
        observations = torch.cat(observations).unsqueeze(0)  # TODO
        av_action = torch.stack(avail_actions).unsqueeze(0).to(self.device) if len(avail_actions) > 0 else None
        # nn_mask is never assigned above, so this branch is currently dead.
        nn_mask = nn_mask.unsqueeze(0).repeat(8, 1, 1).to(self.device) if nn_mask is not None else None
        actions = self.controller.step(observations, av_action, nn_mask).to(self.device)
        return actions, observations, torch.cat(fakes).unsqueeze(0), av_action  # TODO use controller to model and pred

    def _wrap(self, d):
        """Convert every value of dict `d` to a float tensor on the device."""
        for key, value in d.items():
            d[key] = torch.tensor(value).to(self.controller_config.DEVICE).float()
        return d

    def get_absorbing_state(self):
        # All-zero observation used once an agent has terminated.
        state = torch.zeros(1, self.in_dim).to(self.device)  # TODO
        return state

    def augment(self, data, inverse=False):
        """Expand a per-handle dict to a dense (1, n_agents, dim) tensor,
        filling missing handles with ones (inverse=True) or zeros."""
        aug = []
        default = list(data.values())[0].reshape(1, -1)
        for handle in range(self.env.n_agents):
            if handle in data.keys():
                aug.append(data[handle].reshape(1, -1))
            else:
                aug.append(torch.ones_like(default) if inverse else torch.zeros_like(default))
        return torch.cat(aug).unsqueeze(0).to(self.device)  # TODO

    def _check_termination(self, info, steps_done):
        # True when the episode ended "for real" (not by hitting the limit).
        if self.env_type == Env.STARCRAFT or self.env_type == Env.MPE:
            return "episode_limit" not in info
        else:
            return steps_done < self.env.max_time_steps  # must not exceed the time limit

    def run(self, ToCM_params):
        f"""
        interact with environment
        :param ToCM_params:
        :return: rollout: dict reward steps_done
        """
        # NOTE(review): the stray 'f' prefix above makes this an f-string
        # expression rather than a real docstring.
        self.controller.receive_params(ToCM_params)
        # Share the parameters learned by the learner with the controller.
        # freeze the parameters
        state = self._wrap(self.env.reset())  # to device
        steps_done = 0
        self.done = defaultdict(lambda: False)
        episode_rewards = []
        while True:
            steps_done += 1
            actions, obs, fakes, av_actions = self._select_actions(state)  # use controller to select actions
            if self.env_type == Env.MPE:
                next_state, reward, done, info = self.env.step(actions)  # env update runs on CPU
                rewards = []
                for key, value in reward.items():
                    rewards.append(value)
                episode_rewards.append(rewards)
            else:
                next_state, reward, done, info = self.env.step([action.argmax() for i, action in enumerate(actions)])
            next_state, reward, done = self._wrap(deepcopy(next_state)), self._wrap(deepcopy(reward)), \
                                       self._wrap(deepcopy(done))  # to device
            self.done = done
            self.controller.update_buffer({"action": actions,
                                           "observation": obs,
                                           "reward": self.augment(reward),
                                           "done": self.augment(done),
                                           "fake": fakes,
                                           "avail_action": av_actions})
            state = next_state
            if all([done[key] == 1 for key in range(self.env.n_agents)]):
                if self._check_termination(info, steps_done):
                    # Episode truly ended: append an absorbing transition with
                    # zero reward, random one-hot actions and fake=done=1.
                    obs = torch.cat([self.get_absorbing_state() for i in range(self.env.n_agents)]).unsqueeze(0)
                    actions = torch.zeros(1, self.env.n_agents, actions.shape[-1])
                    index = torch.randint(0, actions.shape[-1], actions.shape[:-1], device=actions.device)
                    actions.scatter_(2, index.unsqueeze(-1), 1.)
                    items = {"observation": obs,
                             "action": actions,
                             "reward": torch.zeros(1, self.env.n_agents, 1),
                             "fake": torch.ones(1, self.env.n_agents, 1),
                             "done": torch.ones(1, self.env.n_agents, 1),
                             "avail_action": torch.ones_like(actions) if self.env_type == Env.STARCRAFT else None}
                    self.controller.update_buffer(items)
                    self.controller.update_buffer(items)  # NOTE(review): appended twice — confirm intended
                break
        if self.env_type == Env.MPE:
            reward = np.mean(np.sum(episode_rewards, axis=0))  # TODO
        else:
            reward = 1. if 'battle_won' in info and info['battle_won'] else 0.
        return self.controller.dispatch_buffer(), {"idx": self.runner_handle,
                                                   "reward": reward,  # scalar episode reward
                                                   "steps_done": steps_done}
================================================
FILE: examples/Social_Cognition/ToCM/configs/Config.py
================================================
from collections.abc import Iterable
# train->agent_configs = [ToCMControllerConfig(ToCMConfig),] -> class ToCMConfig(Config) -> Config
class Config:
    """Base class for all configs: to_dict flattens nested Config objects
    into one flat dict keyed by underscore-joined attribute paths."""

    def __init__(self):
        pass

    def to_dict(self, prefix=""):
        """Recursively flatten this config into {prefix + path: value}.

        Nested Configs recurse with '<key>_' prefixes; sequences whose first
        element is a Config recurse with '<key><i>_' prefixes; everything
        else is stored as-is.
        """
        flat = dict()
        for key, value in self.__dict__.items():
            name = prefix + str(key)
            if isinstance(value, Config):
                flat.update(value.to_dict(name + "_"))
                continue
            # Note: indexes value[0], so only indexable iterables of Configs
            # recurse (matches the original behavior).
            if isinstance(value, Iterable) and value and isinstance(value[0], Config):
                for i, sub in enumerate(value):
                    flat.update(sub.to_dict(name + str(i) + "_"))
                continue
            flat[name] = value
        return flat
================================================
FILE: examples/Social_Cognition/ToCM/configs/EnvConfigs.py
================================================
from configs.Config import Config
from env.starcraft.StarCraft import StarCraft
from env.mpe.MPE import MPE
class EnvConfig(Config):
    """Abstract base for environment configs; subclasses build concrete envs."""

    def __init__(self):
        pass

    def create_env(self):
        """Instantiate and return the environment; overridden by subclasses."""
        pass
# TODO
class MPEConfig(EnvConfig):
    """Factory config for the Multi-agent Particle Environment (MPE)."""

    def __init__(self, args):
        # Argparse-style namespace forwarded verbatim to the env constructor.
        self.args = args

    def create_env(self):
        # MPE wraps a MultiAgentEnv (a gym.Env subclass).
        return MPE(self.args)
class StarCraftConfig(EnvConfig):
    """Factory config for the StarCraft (SMAC) environment."""

    def __init__(self, env_name, random_seed):
        self.env_name = env_name
        self.random_seed = random_seed  # TODO

    def create_env(self):
        """Build the StarCraft env for the configured map and seed."""
        return StarCraft(self.env_name, self.random_seed)
class EnvCurriculumConfig(EnvConfig):
    """Schedules several env configs for fixed numbers of episodes each."""

    def __init__(self, env_configs, env_episodes, env_type, device, obs_builder_config=None, reward_config=None):
        self.env_configs = env_configs
        self.env_episodes = env_episodes  # e.g. (100,)
        self.ENV_TYPE = env_type
        self.device = device  # TODO
        if obs_builder_config is not None:
            self.set_obs_builder_config(obs_builder_config)
        if reward_config is not None:
            self.set_reward_config(reward_config)

    def update_random_seed(self):
        """Forward a seed refresh to every scheduled env config."""
        for cfg in self.env_configs:
            cfg.update_random_seed()

    def set_obs_builder_config(self, obs_builder_config):
        """Propagate the observation-builder config to all env configs."""
        for cfg in self.env_configs:
            cfg.set_obs_builder_config(obs_builder_config)

    def set_reward_config(self, reward_config):
        """Propagate the reward config to all env configs."""
        for cfg in self.env_configs:
            cfg.set_reward_config(reward_config)

    def create_env(self):
        # NOTE(review): EnvCurriculum is not imported in this file — confirm
        # it is made available in the module namespace elsewhere.
        return EnvCurriculum(self.env_configs, self.env_episodes)
class EnvCurriculumSampleConfig(EnvConfig):
    """Samples among several env configs with the given probabilities."""

    def __init__(self, env_configs, env_probs, obs_builder_config=None, reward_config=None):
        self.env_configs = env_configs
        self.env_probs = env_probs
        if obs_builder_config is not None:
            self.set_obs_builder_config(obs_builder_config)
        if reward_config is not None:
            self.set_reward_config(reward_config)

    def update_random_seed(self):
        """Forward a seed refresh to every candidate env config."""
        for cfg in self.env_configs:
            cfg.update_random_seed()

    def set_obs_builder_config(self, obs_builder_config):
        """Propagate the observation-builder config to all env configs."""
        for cfg in self.env_configs:
            cfg.set_obs_builder_config(obs_builder_config)

    def set_reward_config(self, reward_config):
        """Propagate the reward config to all env configs."""
        for cfg in self.env_configs:
            cfg.set_reward_config(reward_config)

    def create_env(self):
        # NOTE(review): EnvCurriculumSample is not imported in this file —
        # confirm it is made available in the module namespace elsewhere.
        return EnvCurriculumSample(self.env_configs, self.env_probs)
class EnvCurriculumPrioritizedSampleConfig(EnvConfig):
    """Prioritized sampling among env configs, re-seeding as configured."""

    def __init__(self, env_configs, repeat_random_seed, obs_builder_config=None, reward_config=None):
        self.env_configs = env_configs
        self.repeat_random_seed = repeat_random_seed
        if obs_builder_config is not None:
            self.set_obs_builder_config(obs_builder_config)
        if reward_config is not None:
            self.set_reward_config(reward_config)

    def update_random_seed(self):
        """Forward a seed refresh to every candidate env config."""
        for cfg in self.env_configs:
            cfg.update_random_seed()

    def set_obs_builder_config(self, obs_builder_config):
        """Propagate the observation-builder config to all env configs."""
        for cfg in self.env_configs:
            cfg.set_obs_builder_config(obs_builder_config)

    def set_reward_config(self, reward_config):
        """Propagate the reward config to all env configs."""
        for cfg in self.env_configs:
            cfg.set_reward_config(reward_config)

    def create_env(self):
        # NOTE(review): EnvCurriculumPrioritizedSample is not imported in this
        # file — confirm it is made available in the module namespace elsewhere.
        return EnvCurriculumPrioritizedSample(self.env_configs, self.repeat_random_seed)
================================================
FILE: examples/Social_Cognition/ToCM/configs/Experiment.py
================================================
from configs.Config import Config
class Experiment(Config):
    """Bundle of everything one run needs: step/episode budgets, the random
    seed and the three sub-configs (env, controller, learner)."""

    def __init__(self, steps, episodes, random_seed, env_config, controller_config, learner_config):
        super(Experiment, self).__init__()  # the device lives on env_config
        self.steps = steps
        self.episodes = episodes
        self.random_seed = random_seed
        self.env_config = env_config
        self.controller_config = controller_config
        self.learner_config = learner_config
================================================
FILE: examples/Social_Cognition/ToCM/configs/ToCM/ToCMAgentConfig.py
================================================
from dataclasses import dataclass
import torch
import torch.distributions as td
import torch.nn.functional as F
from configs.Config import Config
RSSM_STATE_MODE = 'discrete'
#
class ToCMConfig(Config):
    """Shared ToCM agent hyper-parameters: network widths, latent layout,
    per-head layer counts and discount factors."""

    def __init__(self):
        super().__init__()
        # Network widths.
        self.HIDDEN = 64          # hidden-layer units
        self.MODEL_HIDDEN = 64    # world-model hidden units
        self.EMBED = 64           # encoder embedding size
        # Discrete latent: N_CATEGORICALS categorical variables, N_CLASSES each.
        self.N_CATEGORICALS = 32
        self.N_CLASSES = 32
        self.STOCHASTIC = self.N_CATEGORICALS * self.N_CLASSES  # stochastic part
        self.DETERMINISTIC = 64                                  # deterministic part
        self.FEAT = self.STOCHASTIC + self.DETERMINISTIC         # full feature size
        self.GLOBAL_FEAT = self.FEAT + self.EMBED                # feature + embedding
        # Value head.
        self.VALUE_LAYERS = 2
        self.VALUE_HIDDEN = 64
        # Continuation (pcont) head.
        self.PCONT_LAYERS = 2
        self.PCONT_HIDDEN = 64
        # Action head.
        self.ACTION_SIZE = 9
        self.ACTION_LAYERS = 2
        self.ACTION_HIDDEN = 64
        # Reward head.
        self.REWARD_LAYERS = 2
        self.REWARD_HIDDEN = 64
        # Discounting.
        self.GAMMA = 0.99
        self.DISCOUNT = 0.99
        self.DISCOUNT_LAMBDA = 0.95
        self.IN_DIM = 30          # observation input dimension
        self.LOG_FOLDER = 'wandb/'
        self.num_agents = 2
@dataclass
class RSSMStateBase:
    """One recurrent latent state: stochastic plus deterministic parts."""
    stoch: torch.Tensor
    deter: torch.Tensor

    def map(self, func):
        """Apply `func` to every field and wrap the results in a new RSSMState."""
        mapped = {name: func(value) for name, value in self.__dict__.items()}
        return RSSMState(**mapped)

    def get_features(self):
        """Concatenate stochastic and deterministic parts along the last axis."""
        return torch.cat((self.stoch, self.deter), dim=-1)

    def get_dist(self, *input):
        # Overridden by concrete subclasses.
        pass

    def type(self):
        # Sentinel used by batch_multi_agent to tell states from raw tensors.
        return None
@dataclass
class RSSMStateDiscrete(RSSMStateBase):
    """Discrete latent: logits over n_categoricals x n_classes categories."""
    logits: torch.Tensor

    def get_dist(self, batch_shape, n_categoricals, n_classes):
        """Softmax over classes after reshaping logits to (*batch, cat, cls)."""
        shaped = self.logits.reshape(*batch_shape, n_categoricals, n_classes)
        return F.softmax(shaped, dim=-1)
@dataclass
class RSSMStateCont(RSSMStateBase):
    """Continuous latent: diagonal Gaussian with given mean and std."""
    mean: torch.Tensor
    std: torch.Tensor

    def get_dist(self, *input):
        """Independent Normal over the last event dimension."""
        base = td.Normal(self.mean, self.std)
        return td.independent.Independent(base, 1)
# Select the concrete state class once, at import time, by RSSM_STATE_MODE.
_STATE_BY_MODE = {'discrete': RSSMStateDiscrete,
                  'cont': RSSMStateCont}
RSSMState = _STATE_BY_MODE[RSSM_STATE_MODE]
================================================
FILE: examples/Social_Cognition/ToCM/configs/ToCM/ToCMControllerConfig.py
================================================
from agent.controllers.ToCMController import ToCMController
from configs.ToCM.ToCMAgentConfig import ToCMConfig
# train->agent_configs = [ToCMControllerConfig(ToCMConfig),] -> class ToCMConfig(Config) -> Config
class ToCMControllerConfig(ToCMConfig):
    """Controller-side config handed to workers: exploration schedule plus
    env name, seed and device."""

    def __init__(self, env_name, RANDOM_SEED, device):
        super().__init__()
        # Exploration schedule (noise fully disabled here).
        self.EXPL_DECAY = 0.9999
        self.EXPL_NOISE = 0.
        self.EXPL_MIN = 0.
        self.DEVICE = device  # TODO
        self.env_name = env_name  # TODO
        self.random_seed = RANDOM_SEED  # TODO

    def create_controller(self):
        """Build the ToCMController driven by this config."""
        return ToCMController(self)
================================================
FILE: examples/Social_Cognition/ToCM/configs/ToCM/ToCMLearnerConfig.py
================================================
from agent.learners.ToCMLearner import ToCMLearner
from configs.ToCM.ToCMAgentConfig import ToCMConfig
# train->agent_configs = [ToCMLearnerConfig(ToCMConfig),] -> class ToCMConfig(Config) -> Config
class ToCMLearnerConfig(ToCMConfig):
    """Learner-side hyper-parameters: optimizer learning rates, replay-buffer
    sizes and the PPO / world-model training schedule."""

    def __init__(self, env_name, RANDOM_SEED, device):
        super().__init__()
        # Learning rates.
        self.MODEL_LR = 2e-4
        self.ACTOR_LR = 7e-4  # TODO
        self.VALUE_LR = 7e-4  # TODO
        # Replay buffer.
        self.CAPACITY = 500000
        self.MIN_BUFFER_SIZE = 100
        # Training schedule.
        self.MODEL_EPOCHS = 20  # TODO
        self.EPOCHS = 4  # TODO
        self.PPO_EPOCHS = 10  # TODO
        self.MODEL_BATCH_SIZE = 30  # previously 40
        self.BATCH_SIZE = 40
        self.SEQ_LENGTH = 50
        self.N_SAMPLES = 1
        self.TARGET_UPDATE = 128
        self.GRAD_CLIP = 100.0
        self.HORIZON = 15
        self.ENTROPY = 0.001
        self.ENTROPY_ANNEALING = 0.99998
        self.GRAD_CLIP_POLICY = 100.
        self.DEVICE = device  # TODO
        self.env_name = env_name  # TODO
        self.random_seed = RANDOM_SEED  # TODO
        self.num_agents = 2

    def create_learner(self):
        """Build the ToCMLearner driven by this config."""
        return ToCMLearner(self)
================================================
FILE: examples/Social_Cognition/ToCM/configs/ToCM/optimal/starcraft/AgentConfig.py
================================================
from configs.Config import Config
class ToCMConfig(Config):
    """Shared ToCM agent hyper-parameters for the StarCraft 'optimal' setup."""

    def __init__(self):
        super().__init__()
        # network widths
        self.HIDDEN = 256
        self.MODEL_HIDDEN = 256
        self.EMBED = 256
        # discrete latent layout: 32 categoricals x 32 classes
        self.N_CATEGORICALS = 32
        self.N_CLASSES = 32
        self.STOCHASTIC = self.N_CATEGORICALS * self.N_CLASSES
        self.DETERMINISTIC = 256
        # feature sizes derived from the latent split
        self.FEAT = self.STOCHASTIC + self.DETERMINISTIC
        self.GLOBAL_FEAT = self.FEAT + self.EMBED
        # value head
        self.VALUE_LAYERS = 2
        self.VALUE_HIDDEN = 256
        # continuation-prediction head
        self.PCONT_LAYERS = 2
        self.PCONT_HIDDEN = 256
        # action head
        self.ACTION_SIZE = 9
        self.ACTION_LAYERS = 2
        self.ACTION_HIDDEN = 256
        # reward head
        self.REWARD_LAYERS = 2
        self.REWARD_HIDDEN = 256
        # discounting
        self.GAMMA = 0.99
        self.DISCOUNT = 0.99
        self.DISCOUNT_LAMBDA = 0.95
        # observation input dimensionality
        self.IN_DIM = 30
================================================
FILE: examples/Social_Cognition/ToCM/configs/ToCM/optimal/starcraft/LearnerConfig.py
================================================
from agent.learners.ToCMLearner import ToCMLearner
from configs.ToCM.ToCMAgentConfig import ToCMConfig
class ToCMLearnerConfig(ToCMConfig):
    """Learner hyper-parameters tuned for the StarCraft 'optimal' preset."""

    def __init__(self):
        super().__init__()
        # learning rates for world model, actor and critic
        self.MODEL_LR = 2e-4
        self.ACTOR_LR = 5e-4
        self.VALUE_LR = 5e-4
        # replay buffer sizing
        self.CAPACITY = 250000
        self.MIN_BUFFER_SIZE = 500
        # epochs per update phase
        self.MODEL_EPOCHS = 40
        self.EPOCHS = 4
        self.PPO_EPOCHS = 10
        # batch / sequence shapes
        self.MODEL_BATCH_SIZE = 40
        self.BATCH_SIZE = 40
        self.SEQ_LENGTH = 20
        self.N_SAMPLES = 1
        # target network sync period (in updates)
        self.TARGET_UPDATE = 1
        # NOTE(review): device is hard-coded here, unlike the other learner
        # config which takes it as a parameter — confirm intended.
        self.DEVICE = 'cuda:8'
        # gradient clipping and rollout settings
        self.GRAD_CLIP = 100.0
        self.HORIZON = 15
        # entropy regularisation and its annealing factor
        self.ENTROPY = 0.001
        self.ENTROPY_ANNEALING = 0.99998
        self.GRAD_CLIP_POLICY = 100.0

    def create_learner(self):
        """Factory: build the ToCMLearner driven by this config."""
        return ToCMLearner(self)
================================================
FILE: examples/Social_Cognition/ToCM/configs/__init__.py
================================================
from .Experiment import Experiment
================================================
FILE: examples/Social_Cognition/ToCM/env/mpe/MPE.py
================================================
from mpe.MPE_Env import MPEEnv
class MPE:
    """Dict-keyed wrapper around MPEEnv.

    Converts the list-per-agent interface of the underlying environment into
    ``{agent_index: value}`` dictionaries, which is what the rest of the
    ToCM pipeline consumes.
    """

    def __init__(self, args):
        # args supplies at least env_name (scenario) and seed
        self.env = MPEEnv(args)
        self.env.seed(args.seed)
        self.n_agents = self.env.num_agents
        self.agents = self.env.agents

    def to_dict(self, l):
        """Index a list by position: [a, b] -> {0: a, 1: b}."""
        return dict(enumerate(l))

    def step(self, action_dict):
        """Step the env; return per-agent obs/reward/done/info dicts."""
        obs, reward, done, info = self.env.step(action_dict)
        ids = range(self.n_agents)
        return ({i: obs[i] for i in ids},
                {i: reward[i] for i in ids},
                {i: done[i] for i in ids},
                {i: info[i] for i in ids})

    def reset(self):
        """Reset the env and return the initial per-agent observation dict."""
        return self.to_dict(self.env.reset())

    def close(self):
        self.env.close()

    def get_avail_agent_actions(self, handle):
        # NOTE(review): despite its name this delegates to _get_done (MPE has
        # no action masking) — preserved as-is; verify against callers.
        return self.env._get_done(handle)
================================================
FILE: examples/Social_Cognition/ToCM/env/starcraft/StarCraft.py
================================================
from smac.env import StarCraft2Env # import a package smac
class StarCraft:
    """Dict-keyed wrapper around smac's StarCraft2Env."""

    def __init__(self, env_name, random_seed):
        # env_name is the SMAC map name (e.g. '3s5z_vs_3s6z')
        self.env = StarCraft2Env(map_name=env_name, seed=random_seed,
                                 continuing_episode=True, difficulty="7")
        env_info = self.env.get_env_info()
        self.n_obs = env_info["obs_shape"]
        self.n_actions = env_info["n_actions"]
        self.n_agents = env_info["n_agents"]

    def to_dict(self, l):
        """Index a list by position: [a, b] -> {0: a, 1: b}."""
        return dict(enumerate(l))

    def step(self, action_dict):
        """Step SMAC; broadcast its single shared reward/done to every agent."""
        reward, done, info = self.env.step(action_dict)
        obs = self.to_dict(self.env.get_obs())
        rewards = {i: reward for i in range(self.n_agents)}
        dones = {i: done for i in range(self.n_agents)}
        return obs, rewards, dones, info

    def reset(self):
        """Reset the episode and return the per-agent observation dict."""
        self.env.reset()
        return self.to_dict(self.env.get_obs())

    def render(self):
        self.env.render()

    def close(self):
        self.env.close()

    def get_avail_agent_actions(self, handle):
        """Action mask for agent ``handle`` (SMAC provides one natively)."""
        return self.env.get_avail_agent_actions(handle)
================================================
FILE: examples/Social_Cognition/ToCM/environments.py
================================================
from enum import Enum
class Env(str, Enum):
    """Supported environment backends; str-valued so members compare and
    serialise as their plain string names."""
    STARCRAFT = "starcraft"
    MPE = "mpe"
# RANDOM_SEED = 23
# ENV_NAME = "5_agents"
================================================
FILE: examples/Social_Cognition/ToCM/mpe/MPE_Env.py
================================================
"""
Code for creating a multiagent environment with one of the scenarios listed
in ./scenarios/.
Can be called by using, for example:
env = make_env('simple_speaker_listener')
After producing the env object, can be used similarly to an OpenAI gym
environment.
A policy using this environment must output actions in the form of a list
for all agents. Each element of the list should be a numpy array,
of size (env.world.dim_p + env.world.dim_c, 1). Physical actions precede
communication actions in this array. See environment.py for more details.
"""
from .environment import MultiAgentEnv
from .scenarios import load
def MPEEnv(args):
    """Build a MultiAgentEnv for the scenario named by ``args.env_name``.

    The returned object behaves like a gym environment: call ``reset()`` and
    ``step()``, and ``render()`` to view it on screen. Useful properties
    (see environment.py): ``observation_space``, ``action_space`` and the
    number of agents.

    Input:
        args: parsed arguments; ``args.env_name`` is the scenario module name
              under ./scenarios/ (without the .py extension), and the rest is
              forwarded to the scenario's ``make_world``.
    """
    # import the scenario module from ./scenarios/ and instantiate it
    scenario = load(args.env_name + ".py").Scenario()
    # build the world the scenario describes (args forwarded wholesale)
    world = scenario.make_world(args)
    # wire the scenario callbacks into a gym-style multi-agent env
    return MultiAgentEnv(world, scenario.reset_world,
                         scenario.reward, scenario.observation)
================================================
FILE: examples/Social_Cognition/ToCM/mpe/__init__.py
================================================
================================================
FILE: examples/Social_Cognition/ToCM/mpe/core.py
================================================
import numpy as np
# import seaborn as sns
# physical/external base state of all entites
class EntityState(object):
    """Minimal physical state of an entity: 2-D position and velocity.

    Both fields start as None and are filled in by the scenario's
    ``reset_world`` callback.
    """

    def __init__(self):
        # physical position (unset until the world is reset)
        self.p_pos = None
        # physical velocity (unset until the world is reset)
        self.p_vel = None
# state of agents (including communication and internal/mental state)
class AgentState(EntityState):
    """Entity state extended with the agent's communication utterance."""

    def __init__(self):
        super(AgentState, self).__init__()
        # communication utterance, refreshed each step from the agent's action
        self.c = None
# action of the agent
class Action(object):
    """Per-step action container: physical control ``u`` and
    communication payload ``c``."""

    def __init__(self):
        # physical (movement) action vector
        self.u = None
        # communication action vector
        self.c = None
# properties of wall entities
class Wall(object):
    """Static axis-aligned wall segment, horizontal or vertical."""

    def __init__(self, orient='H', axis_pos=0.0, endpoints=(-1, 1), width=0.1,
                 hard=True):
        # orientation: 'H'orizontal or 'V'ertical
        self.orient = orient
        # position along the axis the wall lies on (y-axis for H, x-axis for V)
        self.axis_pos = axis_pos
        # endpoints along the wall's own axis (x-coords for H, y-coords for V)
        self.endpoints = np.array(endpoints)
        # wall thickness
        self.width = width
        # hard walls are impassable to every agent
        self.hard = hard
        # render color (black)
        self.color = np.array([0.0, 0.0, 0.0])
# properties and state of physical world entity
class Entity(object):
    """Base physical object in the world: geometry, dynamics flags and state."""

    def __init__(self):
        # index among all entities (important to set for distance caching)
        self.i = 0
        # display name
        self.name = ''
        # collision radius
        self.size = 0.050
        # whether the entity can move / be pushed
        self.movable = False
        # whether the entity collides with others
        self.collide = True
        # ghosts can pass through non-hard walls
        self.ghost = False
        # material density (affects mass)
        self.density = 25.0
        # render color
        self.color = None
        # caps on speed and acceleration (None = unlimited)
        self.max_speed = None
        self.accel = None
        # physical state: position and velocity
        self.state = EntityState()
        # base mass
        self.initial_mass = 1.0
        # communication channel payload (used by some scenarios)
        self.channel = None

    @property
    def mass(self):
        """The entity's mass; currently just ``initial_mass``."""
        return self.initial_mass
# properties of landmark entities
class Landmark(Entity):
    """Static (non-agent) entity; keeps every default from Entity."""

    def __init__(self):
        super(Landmark, self).__init__()
# properties of agent entities
class Agent(Entity):
    """Controllable (or scripted) entity with actions and communication."""

    def __init__(self):
        super(Agent, self).__init__()
        # role flags: adversary and dummy agents
        self.adversary = False
        self.dummy = False
        # agents can move by default
        self.movable = True
        # capability flags: silent agents cannot communicate,
        # blind agents cannot observe the world
        self.silent = False
        self.blind = False
        # noise amounts for motor and communication channels (None = noiseless)
        self.u_noise = None
        self.c_noise = None
        # magnitude limit on physical controls
        self.u_range = 1.0
        # state adds the communication utterance c to position/velocity
        self.state = AgentState()
        # current action: physical u and communication c
        self.action = Action()
        # script behavior to execute (None => controlled by external policy)
        self.action_callback = None
        # zoe 20200420
        self.goal = None
# multi-agent world
class World(object):
    """All entities (agents, landmarks, walls) plus a force-based physics step.

    ``step()`` gathers per-entity forces (agent controls, pairwise collision
    responses, wall collisions), Euler-integrates velocities/positions with
    damping, then refreshes each agent's communication state.
    """

    def __init__(self):
        # list of agents and entities (can change at execution-time!)
        self.agents = []
        self.landmarks = []
        self.walls = []
        # communication channel dimensionality
        self.dim_c = 0
        # position dimensionality
        self.dim_p = 2
        # color dimensionality
        self.dim_color = 3
        # simulation timestep
        self.dt = 0.1
        # physical damping applied to every velocity each step
        self.damping = 0.25
        # contact response parameters of the soft (logaddexp) penetration model
        self.contact_force = 1e+2
        self.contact_margin = 1e-3
        # cache distances between all agents (not calculated by default)
        self.cache_dists = False
        self.cached_dist_vect = None
        self.cached_dist_mag = None
        # zoe 20200420
        self.world_length = 25  # episode length in steps
        self.world_step = 0     # steps taken so far in this episode
        self.num_agents = 0
        self.num_landmarks = 0

    # return all entities in the world
    @property
    def entities(self):
        return self.agents + self.landmarks

    # return all agents controllable by external policies
    @property
    def policy_agents(self):
        return [agent for agent in self.agents if agent.action_callback is None]

    # return all agents controlled by world scripts
    @property
    def scripted_agents(self):
        return [agent for agent in self.agents if agent.action_callback is not None]

    def calculate_distances(self):
        """Refresh the pairwise displacement/distance/collision caches.

        The minimum-distance matrix is built lazily on the first call
        (entity sizes are assumed constant afterwards); displacements and
        magnitudes are recomputed on every call.
        """
        if self.cached_dist_vect is None:
            # initialize distance data structure
            self.cached_dist_vect = np.zeros((len(self.entities),
                                              len(self.entities),
                                              self.dim_p))
            # calculate minimum distance for a collision between all entities
            self.min_dists = np.zeros((len(self.entities), len(self.entities)))
            for ia, entity_a in enumerate(self.entities):
                for ib in range(ia + 1, len(self.entities)):
                    entity_b = self.entities[ib]
                    min_dist = entity_a.size + entity_b.size
                    self.min_dists[ia, ib] = min_dist
                    self.min_dists[ib, ia] = min_dist
        for ia, entity_a in enumerate(self.entities):
            for ib in range(ia + 1, len(self.entities)):
                entity_b = self.entities[ib]
                delta_pos = entity_a.state.p_pos - entity_b.state.p_pos
                self.cached_dist_vect[ia, ib, :] = delta_pos
                self.cached_dist_vect[ib, ia, :] = -delta_pos
        self.cached_dist_mag = np.linalg.norm(self.cached_dist_vect, axis=2)
        self.cached_collisions = (self.cached_dist_mag <= self.min_dists)

    def assign_agent_colors(self):
        """Color agents by role: dummies green, adversaries red, others blue."""
        n_dummies = 0
        if hasattr(self.agents[0], 'dummy'):
            n_dummies = len([a for a in self.agents if a.dummy])
        n_adversaries = 0
        if hasattr(self.agents[0], 'adversary'):
            n_adversaries = len([a for a in self.agents if a.adversary])
        n_good_agents = len(self.agents) - n_adversaries - n_dummies
        # r g b
        dummy_colors = [(0.25, 0.75, 0.25)] * n_dummies
        # sns.color_palette("OrRd_d", n_adversaries)
        adv_colors = [(0.75, 0.25, 0.25)] * n_adversaries
        # sns.color_palette("GnBu_d", n_good_agents)
        good_colors = [(0.25, 0.25, 0.75)] * n_good_agents
        # NOTE: assumes self.agents is ordered dummies, adversaries, good
        colors = dummy_colors + adv_colors + good_colors
        for color, agent in zip(colors, self.agents):
            agent.color = color

    # landmark color
    def assign_landmark_colors(self):
        for landmark in self.landmarks:
            landmark.color = np.array([0.25, 0.25, 0.25])

    # update state of the world
    def step(self):
        """Advance the simulation one tick: script actions, forces,
        integration, then communication-state update."""
        self.world_step += 1
        # set actions for scripted agents
        for agent in self.scripted_agents:
            agent.action = agent.action_callback(agent, self)
        # gather forces applied to entities
        p_force = [None] * len(self.entities)
        # apply agent physical controls
        p_force = self.apply_action_force(p_force)
        # apply environment forces
        p_force = self.apply_environment_force(p_force)
        # integrate physical state
        self.integrate_state(p_force)
        # update agent state
        for agent in self.agents:
            self.update_agent_state(agent)
        # calculate and store distances between all entities
        if self.cache_dists:
            self.calculate_distances()

    # gather agent action forces
    def apply_action_force(self, p_force):
        """Fill p_force with each movable agent's control force (plus
        optional motor noise); non-movable slots are left untouched."""
        # set applied forces
        for i, agent in enumerate(self.agents):
            if agent.movable:
                noise = np.random.randn(
                    *agent.action.u.shape) * agent.u_noise if agent.u_noise else 0.0
                # force = mass * a * action + n
                p_force[i] = (
                    agent.mass * agent.accel if agent.accel is not None else agent.mass) * agent.action.u + noise
        return p_force

    # gather physical forces acting on entities
    def apply_environment_force(self, p_force):
        """Accumulate entity-entity and entity-wall collision forces onto
        p_force. O(n^2) over entities by design."""
        # simple (but inefficient) collision response
        for a, entity_a in enumerate(self.entities):
            for b, entity_b in enumerate(self.entities):
                if(b <= a):
                    continue  # each unordered pair handled once
                [f_a, f_b] = self.get_entity_collision_force(a, b)
                if(f_a is not None):
                    if(p_force[a] is None):
                        p_force[a] = 0.0
                    p_force[a] = f_a + p_force[a]
                if(f_b is not None):
                    if(p_force[b] is None):
                        p_force[b] = 0.0
                    p_force[b] = f_b + p_force[b]
            if entity_a.movable:
                for wall in self.walls:
                    wf = self.get_wall_collision_force(entity_a, wall)
                    if wf is not None:
                        if p_force[a] is None:
                            p_force[a] = 0.0
                        p_force[a] = p_force[a] + wf
        return p_force

    # integrate physical state
    def integrate_state(self, p_force):
        """Semi-implicit Euler step: damp velocity, apply force/mass * dt,
        clamp to max_speed, then advance position by velocity * dt."""
        for i, entity in enumerate(self.entities):
            if not entity.movable:
                continue
            entity.state.p_vel = entity.state.p_vel * (1 - self.damping)
            if (p_force[i] is not None):
                entity.state.p_vel += (p_force[i] / entity.mass) * self.dt
            if entity.max_speed is not None:
                speed = np.sqrt(
                    np.square(entity.state.p_vel[0]) + np.square(entity.state.p_vel[1]))
                if speed > entity.max_speed:
                    # rescale velocity vector to max_speed, preserving direction
                    entity.state.p_vel = entity.state.p_vel / np.sqrt(np.square(entity.state.p_vel[0]) +
                                                                      np.square(entity.state.p_vel[1])) * entity.max_speed
            entity.state.p_pos += entity.state.p_vel * self.dt

    def update_agent_state(self, agent):
        """Copy the agent's communication action into its state (with
        optional channel noise); silent agents broadcast zeros."""
        # set communication state (directly for now)
        if agent.silent:
            agent.state.c = np.zeros(self.dim_c)
        else:
            noise = np.random.randn(*agent.action.c.shape) * \
                agent.c_noise if agent.c_noise else 0.0
            agent.state.c = agent.action.c + noise

    # get collision forces for any contact between two entities
    def get_entity_collision_force(self, ia, ib):
        """Soft contact force between entities ia and ib.

        Returns [force_on_a, force_on_b]; either element is None when the
        corresponding entity is non-colliding or immovable.
        """
        entity_a = self.entities[ia]
        entity_b = self.entities[ib]
        if (not entity_a.collide) or (not entity_b.collide):
            return [None, None]  # not a collider
        if (not entity_a.movable) and (not entity_b.movable):
            return [None, None]  # neither entity moves
        if (entity_a is entity_b):
            return [None, None]  # don't collide against itself
        if self.cache_dists:
            delta_pos = self.cached_dist_vect[ia, ib]
            dist = self.cached_dist_mag[ia, ib]
            dist_min = self.min_dists[ia, ib]
        else:
            # compute actual distance between entities
            delta_pos = entity_a.state.p_pos - entity_b.state.p_pos
            dist = np.sqrt(np.sum(np.square(delta_pos)))
            # minimum allowable distance
            dist_min = entity_a.size + entity_b.size
        # softmax penetration
        k = self.contact_margin
        penetration = np.logaddexp(0, -(dist - dist_min)/k)*k
        # NOTE(review): divides by dist — assumes entities never exactly
        # coincide (dist == 0 would divide by zero).
        force = self.contact_force * delta_pos / dist * penetration
        if entity_a.movable and entity_b.movable:
            # consider mass in collisions
            force_ratio = entity_b.mass / entity_a.mass
            force_a = force_ratio * force
            force_b = -(1 / force_ratio) * force
        else:
            force_a = +force if entity_a.movable else None
            force_b = -force if entity_b.movable else None
        return [force_a, force_b]

    # get collision forces for contact between an entity and a wall
    def get_wall_collision_force(self, entity, wall):
        """Soft contact force pushing ``entity`` away from ``wall``,
        or None when there is no interaction."""
        if entity.ghost and not wall.hard:
            return None  # ghost passes through soft walls
        if wall.orient == 'H':
            prll_dim = 0   # dimension parallel to the wall
            perp_dim = 1   # dimension perpendicular to the wall
        else:
            prll_dim = 1
            perp_dim = 0
        ent_pos = entity.state.p_pos
        if (ent_pos[prll_dim] < wall.endpoints[0] - entity.size or
                ent_pos[prll_dim] > wall.endpoints[1] + entity.size):
            return None  # entity is beyond endpoints of wall
        elif (ent_pos[prll_dim] < wall.endpoints[0] or
                ent_pos[prll_dim] > wall.endpoints[1]):
            # part of entity is beyond wall: effective contact distance
            # shrinks with the overlap angle theta
            if ent_pos[prll_dim] < wall.endpoints[0]:
                dist_past_end = ent_pos[prll_dim] - wall.endpoints[0]
            else:
                dist_past_end = ent_pos[prll_dim] - wall.endpoints[1]
            theta = np.arcsin(dist_past_end / entity.size)
            dist_min = np.cos(theta) * entity.size + 0.5 * wall.width
        else:  # entire entity lies within bounds of wall
            theta = 0
            dist_past_end = 0
            dist_min = entity.size + 0.5 * wall.width
        # only need to calculate distance in relevant dim
        delta_pos = ent_pos[perp_dim] - wall.axis_pos
        dist = np.abs(delta_pos)
        # softmax penetration
        k = self.contact_margin
        penetration = np.logaddexp(0, -(dist - dist_min)/k)*k
        force_mag = self.contact_force * delta_pos / dist * penetration
        force = np.zeros(2)
        force[perp_dim] = np.cos(theta) * force_mag
        force[prll_dim] = np.sin(theta) * np.abs(force_mag)
        return force
================================================
FILE: examples/Social_Cognition/ToCM/mpe/environment.py
================================================
import gym
from gym import spaces
from gym.envs.registration import EnvSpec
import numpy as np
from .multi_discrete import MultiDiscrete
# update bounds to center around agent
cam_range = 2
# environment for all agents in the multi agent world
# currently code assumes that no agents will be created/destroyed at runtime!
class MultiAgentEnv(gym.Env):
    """Gym-style environment over a multi-agent ``World``.

    Scenario behavior is injected via callbacks (reset/reward/observation/
    info/done/post_step). Code currently assumes that no agents are created
    or destroyed at runtime.
    """
    metadata = {
        'render.modes': ['human', 'rgb_array']
    }

    def __init__(self, world, reset_callback=None, reward_callback=None,
                 observation_callback=None, info_callback=None,
                 done_callback=None, post_step_callback=None,
                 shared_viewer=True, discrete_action=True):
        self.world = world
        self.world_length = self.world.world_length  # episode length (e.g. 25)
        self.current_step = 0
        self.agents = self.world.policy_agents
        # set required vectorized gym env property
        self.num_agents = len(world.policy_agents)
        # scenario callbacks
        self.reset_callback = reset_callback
        self.reward_callback = reward_callback
        self.observation_callback = observation_callback
        self.info_callback = info_callback
        self.done_callback = done_callback
        self.post_step_callback = post_step_callback
        # environment parameters
        # self.discrete_action_space = True
        self.discrete_action_space = discrete_action
        # if true, action is a number 0...N, otherwise action is a one-hot N-dimensional vector
        self.discrete_action_input = False
        # if true, even the action is continuous, action will be performed discretely
        self.force_discrete_action = world.discrete_action if hasattr(
            world, 'discrete_action') else False
        # in this env force_discrete_action == False, because World does not
        # define a discrete_action attribute
        # if true, every agent has the same reward
        self.shared_reward = world.collaborative if hasattr(
            world, 'collaborative') else False
        # self.shared_reward = False
        self.time = 0
        # configure spaces: one action/observation space per agent
        self.action_space = []
        self.observation_space = []
        self.share_observation_space = []
        share_obs_dim = 0
        for agent in self.agents:
            total_action_space = []
            # physical action space: 2*dim_p moves + no-op when discrete
            if self.discrete_action_space:
                u_action_space = spaces.Discrete(world.dim_p * 2 + 1)
            else:
                u_action_space = spaces.Box(
                    low=-agent.u_range, high=+agent.u_range, shape=(world.dim_p,), dtype=np.float32)  # [-1,1]
            if agent.movable:
                total_action_space.append(u_action_space)
            # communication action space
            if self.discrete_action_space:
                c_action_space = spaces.Discrete(world.dim_c)
            else:
                c_action_space = spaces.Box(low=0.0, high=1.0, shape=(
                    world.dim_c,), dtype=np.float32)  # [0,1]
                # c_action_space = spaces.Discrete(world.dim_c)
            if not agent.silent:
                total_action_space.append(c_action_space)
            # total action space
            if len(total_action_space) > 1:
                # all action spaces are discrete, so simplify to MultiDiscrete action space
                if all([isinstance(act_space, spaces.Discrete) for act_space in total_action_space]):
                    act_space = MultiDiscrete(
                        [[0, act_space.n - 1] for act_space in total_action_space])
                else:
                    act_space = spaces.Tuple(total_action_space)
                self.action_space.append(act_space)
            else:
                self.action_space.append(total_action_space[0])
            # observation space: dimensionality probed from the callback
            obs_dim = len(observation_callback(agent, self.world))
            share_obs_dim += obs_dim
            self.observation_space.append(spaces.Box(
                low=-np.inf, high=+np.inf, shape=(obs_dim,), dtype=np.float32))  # [-inf,inf]
            agent.action.c = np.zeros(self.world.dim_c)
        # shared observation: concatenation of all agents' observations
        self.share_observation_space = [spaces.Box(
            low=-np.inf, high=+np.inf, shape=(share_obs_dim,), dtype=np.float32)] * self.num_agents
        # rendering: one shared viewer, or one viewer per agent
        self.shared_viewer = shared_viewer
        if self.shared_viewer:
            self.viewers = [None]
        else:
            self.viewers = [None] * self.num_agents
        self._reset_render()

    def seed(self, seed=None):
        # NOTE: seed(None) deterministically seeds numpy with 1, it does not
        # leave the RNG unseeded
        if seed is None:
            np.random.seed(1)
        else:
            np.random.seed(seed)

    # step this is env.step()
    def step(self, action_n):
        """Apply one action per agent, advance the world, and return
        per-agent lists (obs_n, reward_n, done_n, info_n)."""
        self.current_step += 1
        obs_n = []
        reward_n = []
        done_n = []
        info_n = []
        self.agents = self.world.policy_agents
        # set action for each agent
        for i, agent in enumerate(self.agents):
            self._set_action(action_n[i], agent, self.action_space[i])
        # advance world state
        self.world.step()  # core.step()
        # record observation for each agent
        for i, agent in enumerate(self.agents):
            obs_n.append(self._get_obs(agent))
            reward_n.append([self._get_reward(agent)])
            done_n.append([self._get_done(agent)])
            info = {'individual_reward': self._get_reward(agent)}
            info_n.append(info)
        # all agents get total reward in cooperative case, if shared reward, all agents have the same reward,
        # and reward is sum
        reward = np.sum(reward_n)
        if self.shared_reward:
            reward_n = [[reward]] * self.num_agents
        if self.post_step_callback is not None:
            self.post_step_callback(self.world)
        return obs_n, reward_n, done_n, info_n

    def reset(self):
        """Reset episode counter, world and renderer; return per-agent obs."""
        self.current_step = 0
        # reset world
        self.reset_callback(self.world)
        # reset renderer
        self._reset_render()
        # record observations for each agent
        obs_n = []
        self.agents = self.world.policy_agents
        for agent in self.agents:
            obs_n.append(self._get_obs(agent))
        return obs_n

    # get info used for benchmarking
    def _get_info(self, agent):
        if self.info_callback is None:
            return {}
        return self.info_callback(agent, self.world)

    # get observation for a particular agent
    def _get_obs(self, agent):
        # accepts either an Agent object or its index
        if isinstance(agent, int):
            agent = self.agents[agent]
        if self.observation_callback is None:
            print("Unavailable:", np.zeros(0))
            return np.zeros(0)
        return self.observation_callback(agent, self.world)

    # get dones for a particular agent
    # unused right now -- agents are allowed to go beyond the viewing screen
    def _get_done(self, agent):
        if isinstance(agent, int):
            agent = self.agents[agent]
        if self.done_callback is None:
            # default: episode ends purely on the step limit
            if self.current_step >= self.world_length:
                return True
            else:
                return False
        return self.done_callback(agent, self.world)

    # get reward for a particular agent
    def _get_reward(self, agent):
        if self.reward_callback is None:
            return 0.0
        return self.reward_callback(agent, self.world)

    # set env action for a particular agent
    def _set_action(self, action, agent, action_space, time=None):
        """Decode ``action`` into agent.action.u (physical) and agent.action.c
        (communication). ``time`` is accepted for interface compatibility but
        unused."""
        agent.action.u = np.zeros(self.world.dim_p)
        agent.action.c = np.zeros(self.world.dim_c)
        # process action: split a MultiDiscrete action into its sub-actions,
        # otherwise wrap the single action in a list
        if isinstance(action_space, MultiDiscrete):
            act = []
            size = action_space.high - action_space.low + 1
            index = 0
            for s in size:
                act.append(action[index:(index + s)])
                index += s
            action = act
        else:
            action = [action]
        if agent.movable:
            # physical action
            if self.discrete_action_input:
                agent.action.u = np.zeros(self.world.dim_p)
                # process discrete action: 1/2 move along x, 3/4 along y
                if action[0] == 1:
                    agent.action.u[0] = -1.0
                if action[0] == 2:
                    agent.action.u[0] = +1.0
                if action[0] == 3:
                    agent.action.u[1] = -1.0
                if action[0] == 4:
                    agent.action.u[1] = +1.0
                d = self.world.dim_p
            else:
                if self.discrete_action_space:
                    # one-hot of 5: [noop, +x, -x, +y, -y] folded into a 2-D force
                    agent.action.u[0] += action[0][1] - action[0][2]
                    agent.action.u[1] += action[0][3] - action[0][4]
                    d = 5
                else:
                    if self.force_discrete_action:
                        # snap the continuous vector to its argmax component
                        p = np.argmax(action[0][0:self.world.dim_p])
                        action[0][:] = 0.0
                        action[0][p] = 1.0
                    agent.action.u = action[0][0:self.world.dim_p]
                    d = self.world.dim_p
            # scale controls by acceleration (default sensitivity 5.0)
            sensitivity = 5.0
            if agent.accel is not None:
                sensitivity = agent.accel
            agent.action.u *= sensitivity
            # consume the physical part of the action
            if (not agent.silent) and (not isinstance(action_space, MultiDiscrete)):
                action[0] = action[0][d:]
            else:
                action = action[1:]
        if not agent.silent:
            # communication action
            if self.discrete_action_input:
                agent.action.c = np.zeros(self.world.dim_c)
                agent.action.c[action[0]] = 1.0
            else:
                agent.action.c = action[0]
            action = action[1:]
        # make sure we used all elements of action
        assert len(action) == 0

    def _get_avail_action(self, handle):
        # NOTE(review): looks unfinished — fetches the agent then returns
        # None implicitly; MPE has no native action masking.
        agent = self.agents[handle]  # TODO

    # reset rendering assets
    def _reset_render(self):
        self.render_geoms = None
        self.render_geoms_xform = None

    def render(self, mode='human', close=True):
        """Render all viewers; NOTE the default close=True only closes
        existing viewers and returns [] — pass close=False to draw."""
        if close:
            # close any existic renderers
            for i, viewer in enumerate(self.viewers):
                if viewer is not None:
                    viewer.close()
                self.viewers[i] = None
            return []
        if mode == 'human':
            # print the communication each agent directs at every other agent
            alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
            message = ''
            for agent in self.world.agents:
                comm = []
                for other in self.world.agents:
                    if other is agent:
                        continue
                    if np.all(other.state.c == 0):
                        word = '_'
                    else:
                        word = alphabet[np.argmax(other.state.c)]
                    message += (other.name + ' to ' +
                                agent.name + ': ' + word + ' ')
            print(message)
        for i in range(len(self.viewers)):
            # create viewers (if necessary)
            if self.viewers[i] is None:
                # import rendering only if we need it (and don't import for headless machines)
                # from gym.envs.classic_control import rendering
                from . import rendering
                self.viewers[i] = rendering.Viewer(700, 700)
        # create rendering geometry (lazily, on the first draw)
        if self.render_geoms is None:
            # import rendering only if we need it (and don't import for headless machines)
            # from gym.envs.classic_control import rendering
            from . import rendering
            self.render_geoms = []
            self.render_geoms_xform = []
            self.comm_geoms = []
            for entity in self.world.entities:
                geom = rendering.make_circle(entity.size)
                xform = rendering.Transform()
                entity_comm_geoms = []
                if 'agent' in entity.name:
                    geom.set_color(*entity.color, alpha=0.5)
                    if not entity.silent:
                        dim_c = self.world.dim_c
                        # make circles to represent communication
                        for ci in range(dim_c):
                            comm = rendering.make_circle(entity.size / dim_c)
                            comm.set_color(1, 1, 1)
                            comm.add_attr(xform)
                            offset = rendering.Transform()
                            comm_size = (entity.size / dim_c)
                            offset.set_translation(ci * comm_size * 2 -
                                                   entity.size + comm_size, 0)
                            comm.add_attr(offset)
                            entity_comm_geoms.append(comm)
                else:
                    geom.set_color(*entity.color)
                    if entity.channel is not None:
                        dim_c = self.world.dim_c
                        # make circles to represent communication
                        for ci in range(dim_c):
                            comm = rendering.make_circle(entity.size / dim_c)
                            comm.set_color(1, 1, 1)
                            comm.add_attr(xform)
                            offset = rendering.Transform()
                            comm_size = (entity.size / dim_c)
                            offset.set_translation(ci * comm_size * 2 -
                                                   entity.size + comm_size, 0)
                            comm.add_attr(offset)
                            entity_comm_geoms.append(comm)
                geom.add_attr(xform)
                self.render_geoms.append(geom)
                self.render_geoms_xform.append(xform)
                self.comm_geoms.append(entity_comm_geoms)
            for wall in self.world.walls:
                corners = ((wall.axis_pos - 0.5 * wall.width, wall.endpoints[0]),
                           (wall.axis_pos - 0.5 *
                            wall.width, wall.endpoints[1]),
                           (wall.axis_pos + 0.5 *
                            wall.width, wall.endpoints[1]),
                           (wall.axis_pos + 0.5 * wall.width, wall.endpoints[0]))
                if wall.orient == 'H':
                    corners = tuple(c[::-1] for c in corners)
                geom = rendering.make_polygon(corners)
                if wall.hard:
                    geom.set_color(*wall.color)
                else:
                    geom.set_color(*wall.color, alpha=0.5)
                self.render_geoms.append(geom)
            # add geoms to viewer
            # for viewer in self.viewers:
            #     viewer.geoms = []
            #     for geom in self.render_geoms:
            #         viewer.add_geom(geom)
            for viewer in self.viewers:
                viewer.geoms = []
                for geom in self.render_geoms:
                    viewer.add_geom(geom)
                for entity_comm_geoms in self.comm_geoms:
                    for geom in entity_comm_geoms:
                        viewer.add_geom(geom)
        results = []
        for i in range(len(self.viewers)):
            from . import rendering
            # center the camera on the origin (shared viewer) or the agent
            if self.shared_viewer:
                pos = np.zeros(self.world.dim_p)
            else:
                pos = self.agents[i].state.p_pos
            self.viewers[i].set_bounds(
                pos[0] - cam_range, pos[0] + cam_range, pos[1] - cam_range, pos[1] + cam_range)
            # update geometry positions
            for e, entity in enumerate(self.world.entities):
                self.render_geoms_xform[e].set_translation(*entity.state.p_pos)
                if 'agent' in entity.name:
                    self.render_geoms[e].set_color(*entity.color, alpha=0.5)
                    if not entity.silent:
                        for ci in range(self.world.dim_c):
                            color = 1 - entity.state.c[ci]
                            self.comm_geoms[e][ci].set_color(
                                color, color, color)
                else:
                    self.render_geoms[e].set_color(*entity.color)
                    if entity.channel is not None:
                        for ci in range(self.world.dim_c):
                            color = 1 - entity.channel[ci]
                            self.comm_geoms[e][ci].set_color(
                                color, color, color)
            # render to display or array
            results.append(self.viewers[i].render(
                return_rgb_array=mode == 'rgb_array'))
        return results

    # create receptor field locations in local coordinate frame
    def _make_receptor_locations(self, agent):
        receptor_type = 'polar'
        range_min = 0.05 * 2.0
        range_max = 1.00
        dx = []
        # circular receptive field: 8 angles x 3 distances plus the origin
        if receptor_type == 'polar':
            for angle in np.linspace(-np.pi, +np.pi, 8, endpoint=False):
                for distance in np.linspace(range_min, range_max, 3):
                    dx.append(
                        distance * np.array([np.cos(angle), np.sin(angle)]))
            # add origin
            dx.append(np.array([0.0, 0.0]))
        # grid receptive field (unused: receptor_type is hard-coded to polar)
        if receptor_type == 'grid':
            for x in np.linspace(-range_max, +range_max, 5):
                for y in np.linspace(-range_max, +range_max, 5):
                    dx.append(np.array([x, y]))
        return dx
================================================
FILE: examples/Social_Cognition/ToCM/mpe/multi_discrete.py
================================================
# An old version of OpenAI Gym's multi_discrete.py. (Was getting affected by Gym updates)
# (https://github.com/openai/gym/blob/1fb81d4e3fb780ccf77fec731287ba07da35eb84/gym/spaces/multi_discrete.py)
import numpy as np
import gym
class MultiDiscrete(gym.Space):
    """
    - The multi-discrete action space consists of a series of discrete action spaces with different parameters
    - It can be adapted to both a Discrete action space or a continuous (Box) action space
    - It is useful to represent game controllers or keyboards where each key can be represented as a discrete action space
    - It is parametrized by passing an array of arrays containing [min, max] for each discrete action space
       where the discrete action space can take any integers from `min` to `max` (both inclusive)
    Note: A value of 0 always need to represent the NOOP action.
    e.g. Nintendo Game Controller
    - Can be conceptualized as 3 discrete action spaces:
        1) Arrow Keys: Discrete 5  - NOOP[0], UP[1], RIGHT[2], DOWN[3], LEFT[4]  - params: min: 0, max: 4
        2) Button A:   Discrete 2  - NOOP[0], Pressed[1] - params: min: 0, max: 1
        3) Button B:   Discrete 2  - NOOP[0], Pressed[1] - params: min: 0, max: 1
    - Can be initialized as
        MultiDiscrete([ [0,4], [0,1], [0,1] ])
    """

    def __init__(self, array_of_param_array):
        # per-dimension inclusive [min, max] bounds
        self.low = np.array([pair[0] for pair in array_of_param_array])
        self.high = np.array([pair[1] for pair in array_of_param_array])
        self.num_discrete_space = self.low.shape[0]

    def sample(self):
        """ Returns a array with one sample from each discrete action space """
        # For each row: round(random .* (max - min) + min, 0)
        # random_array = prng.np_random.rand(self.num_discrete_space)
        random_array = np.random.rand(self.num_discrete_space)
        scaled = np.multiply((self.high - self.low + 1.), random_array)
        return [int(x) for x in np.floor(scaled + self.low)]

    def contains(self, x):
        """True iff x has one in-bounds value per discrete sub-space."""
        if len(x) != self.num_discrete_space:
            return False
        arr = np.array(x)
        return (arr >= self.low).all() and (arr <= self.high).all()

    @property
    def shape(self):
        return self.num_discrete_space

    def __repr__(self):
        return "MultiDiscrete" + str(self.num_discrete_space)

    def __eq__(self, other):
        return np.array_equal(self.low, other.low) and np.array_equal(self.high, other.high)
================================================
FILE: examples/Social_Cognition/ToCM/mpe/rendering.py
================================================
"""
2D rendering framework
"""
from __future__ import division
import os
import six # TODO
import sys
if "Apple" in sys.version:
if 'DYLD_FALLBACK_LIBRARY_PATH' in os.environ:
os.environ['DYLD_FALLBACK_LIBRARY_PATH'] += ':/usr/lib'
# (JDS 2016/04/15): avoid bug on Anaconda 2.3.0 / Yosemite
from gym.utils import reraise
from gym import error
try:
import pyglet
except ImportError as e:
reraise(
suffix="HINT: you can install pyglet directly via 'pip install pyglet'. But if you really just want to install all Gym dependencies and not have to think about it, 'pip install -e .[all]' or 'pip install gym[all]' will do it.")
try:
from pyglet.gl import *
except ImportError as e:
reraise(prefix="Error occured while running `from pyglet.gl import *`",
suffix="HINT: make sure you have OpenGL install. On Ubuntu, you can run 'apt-get install python-opengl'. If you're running on a server, you may need a virtual frame buffer; something like this should work: 'xvfb-run -s \"-screen 0 1400x900x24\" python '")
import math
import numpy as np
RAD2DEG = 57.29577951308232
def get_display(spec):
    """Convert a display specification (such as :0) into an actual Display
    object.
    Pyglet only supports multiple Displays on Linux.
    """
    # Guard clauses: None passes through, strings become Displays,
    # anything else is rejected.
    if spec is None:
        return None
    if isinstance(spec, six.string_types):
        return pyglet.canvas.Display(spec)
    raise error.Error(
        'Invalid display specification: {}. (Must be a string like :0 or None.)'.format(spec))
class Viewer(object):
    """Pyglet-backed 2D viewer.

    Owns a window, a camera transform, and two geometry lists: `geoms`
    (persistent, drawn every frame) and `onetime_geoms` (cleared after each
    `render()` call).
    """

    def __init__(self, width, height, display=None):
        display = get_display(display)
        self.width = width
        self.height = height
        self.window = pyglet.window.Window(
            width=width, height=height, display=display)
        self.window.on_close = self.window_closed_by_user
        self.geoms = []          # drawn every frame
        self.onetime_geoms = []  # drawn once, then discarded
        self.transform = Transform()
        glEnable(GL_BLEND)
        # glEnable(GL_MULTISAMPLE)
        glEnable(GL_LINE_SMOOTH)
        # glHint(GL_LINE_SMOOTH_HINT, GL_DONT_CARE)
        glHint(GL_LINE_SMOOTH_HINT, GL_NICEST)
        glLineWidth(2.0)
        glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA)

    def close(self):
        self.window.close()

    def window_closed_by_user(self):
        self.close()

    def set_bounds(self, left, right, bottom, top):
        """Set the world-coordinate rectangle mapped onto the window."""
        assert right > left and top > bottom
        scalex = self.width/(right-left)
        scaley = self.height/(top-bottom)
        self.transform = Transform(
            translation=(-left*scalex, -bottom*scaley),
            scale=(scalex, scaley))

    def add_geom(self, geom):
        self.geoms.append(geom)

    def add_onetime(self, geom):
        self.onetime_geoms.append(geom)

    def render(self, return_rgb_array=False):
        """Draw all geometry; optionally return the frame as an (H, W, 3) array."""
        glClearColor(1, 1, 1, 1)
        self.window.clear()
        self.window.switch_to()
        self.window.dispatch_events()
        self.transform.enable()
        for geom in self.geoms:
            geom.render()
        for geom in self.onetime_geoms:
            geom.render()
        self.transform.disable()
        arr = None
        if return_rgb_array:
            buffer = pyglet.image.get_buffer_manager().get_color_buffer()
            image_data = buffer.get_image_data()
            # np.fromstring (binary mode) is deprecated and removed in modern
            # NumPy; frombuffer is the supported zero-copy equivalent.
            arr = np.frombuffer(image_data.data, dtype=np.uint8)
            # In https://github.com/openai/gym-http-api/issues/2, we
            # discovered that someone using Xmonad on Arch was having
            # a window of size 598 x 398, though a 600 x 400 window
            # was requested. (Guess Xmonad was preserving a pixel for
            # the boundary.) So we use the buffer height/width rather
            # than the requested one.
            arr = arr.reshape(buffer.height, buffer.width, 4)
            # Flip vertically (GL rows are bottom-up) and drop alpha.
            arr = arr[::-1, :, 0:3]
        self.window.flip()
        self.onetime_geoms = []
        return arr

    # Convenience
    def draw_circle(self, radius=10, res=30, filled=True, **attrs):
        geom = make_circle(radius=radius, res=res, filled=filled)
        _add_attrs(geom, attrs)
        self.add_onetime(geom)
        return geom

    def draw_polygon(self, v, filled=True, **attrs):
        geom = make_polygon(v=v, filled=filled)
        _add_attrs(geom, attrs)
        self.add_onetime(geom)
        return geom

    def draw_polyline(self, v, **attrs):
        geom = make_polyline(v=v)
        _add_attrs(geom, attrs)
        self.add_onetime(geom)
        return geom

    def draw_line(self, start, end, **attrs):
        geom = Line(start, end)
        _add_attrs(geom, attrs)
        self.add_onetime(geom)
        return geom

    def get_array(self):
        """Return the current frame as an (height, width, 3) uint8 array."""
        self.window.flip()
        image_data = pyglet.image.get_buffer_manager().get_color_buffer().get_image_data()
        self.window.flip()
        # frombuffer replaces the deprecated np.fromstring binary mode.
        arr = np.frombuffer(image_data.data, dtype=np.uint8)
        arr = arr.reshape(self.height, self.width, 4)
        return arr[::-1, :, 0:3]
def _add_attrs(geom, attrs):
if "color" in attrs:
geom.set_color(*attrs["color"])
if "linewidth" in attrs:
geom.set_linewidth(attrs["linewidth"])
class Geom(object):
    """Base class for drawables: render() wraps render1() in the attribute stack."""

    def __init__(self):
        self._color = Color((0, 0, 0, 1.0))
        self.attrs = [self._color]

    def render(self):
        # Enable attributes in reverse so the first-added one is innermost,
        # draw the primitive, then tear the stack down front-to-back.
        for attribute in reversed(self.attrs):
            attribute.enable()
        self.render1()
        for attribute in self.attrs:
            attribute.disable()

    def render1(self):
        # Subclasses emit their GL primitives here.
        raise NotImplementedError

    def add_attr(self, attr):
        self.attrs.append(attr)

    def set_color(self, r, g, b, alpha=1):
        self._color.vec4 = (r, g, b, alpha)
class Attr(object):
    """Togglable piece of render state; subclasses must implement enable()."""

    def enable(self):
        raise NotImplementedError

    def disable(self):
        # Most attributes need no teardown; subclasses override when they do.
        pass
class Transform(Attr):
    """Translation, rotation and scale applied via the GL matrix stack."""

    def __init__(self, translation=(0.0, 0.0), rotation=0.0, scale=(1, 1)):
        self.set_translation(*translation)
        self.set_rotation(rotation)
        self.set_scale(*scale)

    def enable(self):
        glPushMatrix()
        # Order matters: translate first, then rotate (radians converted to
        # degrees for GL), then scale.
        glTranslatef(self.translation[0], self.translation[1], 0)
        glRotatef(RAD2DEG * self.rotation, 0, 0, 1.0)
        glScalef(self.scale[0], self.scale[1], 1)

    def disable(self):
        glPopMatrix()

    def set_translation(self, newx, newy):
        self.translation = tuple(map(float, (newx, newy)))

    def set_rotation(self, new):
        self.rotation = float(new)

    def set_scale(self, newx, newy):
        self.scale = tuple(map(float, (newx, newy)))
class Color(Attr):
    """Sets the current GL draw color from an RGBA 4-tuple."""

    def __init__(self, vec4):
        self.vec4 = vec4

    def enable(self):
        r, g, b, a = self.vec4
        glColor4f(r, g, b, a)
class LineStyle(Attr):
    """Enables a GL stipple (dashed/dotted) line pattern while active."""

    def __init__(self, style):
        self.style = style

    def enable(self):
        glEnable(GL_LINE_STIPPLE)
        # Factor 1: apply the 16-bit pattern without stretching.
        glLineStipple(1, self.style)

    def disable(self):
        glDisable(GL_LINE_STIPPLE)
class LineWidth(Attr):
    """Sets the GL line width while active."""

    def __init__(self, stroke):
        self.stroke = stroke

    def enable(self):
        glLineWidth(self.stroke)
class Point(Geom):
    """A single point drawn at the local origin."""

    def __init__(self):
        super(Point, self).__init__()

    def render1(self):
        glBegin(GL_POINTS)
        glVertex3f(0.0, 0.0, 0.0)
        glEnd()
class FilledPolygon(Geom):
    """A filled polygon plus a darker outline in the same shape."""

    def __init__(self, v):
        Geom.__init__(self)
        self.v = v

    def render1(self):
        # Pick the primitive that matches the vertex count.
        count = len(self.v)
        if count == 4:
            mode = GL_QUADS
        elif count > 4:
            mode = GL_POLYGON
        else:
            mode = GL_TRIANGLES
        glBegin(mode)
        for p in self.v:
            glVertex3f(p[0], p[1], 0)
        glEnd()
        # Outline at half the fill color's brightness and alpha.
        half_color = tuple(component * 0.5 for component in self._color.vec4)
        glColor4f(*half_color)
        glBegin(GL_LINE_LOOP)
        for p in self.v:
            glVertex3f(p[0], p[1], 0)
        glEnd()
def make_circle(radius=10, res=30, filled=True):
    """Build a circle approximated by *res* evenly spaced points."""
    angles = (2 * math.pi * step / res for step in range(res))
    points = [(math.cos(a) * radius, math.sin(a) * radius) for a in angles]
    if filled:
        return FilledPolygon(points)
    return PolyLine(points, True)
def make_polygon(v, filled=True):
    """Wrap vertex list *v* as a filled polygon or a closed polyline."""
    return FilledPolygon(v) if filled else PolyLine(v, True)
def make_polyline(v):
    """Open (unclosed) polyline through the vertices *v*."""
    return PolyLine(v, False)
def make_capsule(length, width):
    """Capsule along +x: a length-by-width box capped with a circle at each end."""
    half = width / 2
    box = make_polygon([(0, -half), (0, half), (length, half), (length, -half)])
    cap_near = make_circle(half)
    cap_far = make_circle(half)
    cap_far.add_attr(Transform(translation=(length, 0)))
    return Compound([box, cap_near, cap_far])
class Compound(Geom):
    """Renders several geoms as one unit.

    Child Color attributes are stripped so the compound's own color applies
    to every part.
    """

    def __init__(self, gs):
        Geom.__init__(self)
        self.gs = gs
        for child in self.gs:
            child.attrs = [attr for attr in child.attrs
                           if not isinstance(attr, Color)]

    def render1(self):
        for child in self.gs:
            child.render()
class PolyLine(Geom):
    """Connected line through the vertices *v*; closed into a loop when *close*."""

    def __init__(self, v, close):
        Geom.__init__(self)
        self.v = v
        self.close = close
        self.linewidth = LineWidth(1)
        self.add_attr(self.linewidth)

    def render1(self):
        mode = GL_LINE_LOOP if self.close else GL_LINE_STRIP
        glBegin(mode)
        for p in self.v:
            glVertex3f(p[0], p[1], 0)
        glEnd()

    def set_linewidth(self, x):
        self.linewidth.stroke = x
class Line(Geom):
    """Straight segment from *start* to *end*."""

    def __init__(self, start=(0.0, 0.0), end=(0.0, 0.0)):
        Geom.__init__(self)
        self.start = start
        self.end = end
        self.linewidth = LineWidth(1)
        self.add_attr(self.linewidth)

    def render1(self):
        glBegin(GL_LINES)
        glVertex2f(*self.start)
        glVertex2f(*self.end)
        glEnd()
class Image(Geom):
    """Blits an image file, centered on the local origin, at width x height."""

    def __init__(self, fname, width, height):
        Geom.__init__(self)
        self.width = width
        self.height = height
        self.img = pyglet.image.load(fname)
        self.flip = False

    def render1(self):
        # Offset by half the size so the image is centered at the origin.
        self.img.blit(-self.width / 2, -self.height / 2,
                      width=self.width, height=self.height)
# ================================================================
class SimpleImageViewer(object):
    """Minimal frame viewer.

    Lazily opens a pyglet window sized to the first frame, then blits each
    (height, width, 3) RGB array passed to imshow().
    """

    def __init__(self, display=None):
        self.window = None
        self.isopen = False
        self.display = display

    def imshow(self, arr):
        """Display one RGB frame; the first call fixes the window size."""
        if self.window is None:
            height, width, channels = arr.shape
            self.window = pyglet.window.Window(
                width=width, height=height, display=self.display)
            self.width = width
            self.height = height
            self.isopen = True
        # Fixed garbled message ("wrong number shape" -> "wrong shape").
        assert arr.shape == (
            self.height, self.width, 3), "You passed in an image with the wrong shape"
        # Negative pitch: numpy rows are top-down, pyglet expects bottom-up.
        image = pyglet.image.ImageData(
            self.width, self.height, 'RGB', arr.tobytes(), pitch=self.width * -3)
        self.window.clear()
        self.window.switch_to()
        self.window.dispatch_events()
        image.blit(0, 0)
        self.window.flip()

    def close(self):
        if self.isopen:
            self.window.close()
            self.isopen = False

    def __del__(self):
        self.close()
================================================
FILE: examples/Social_Cognition/ToCM/mpe/scenario.py
================================================
import numpy as np
# defines scenario upon which the world is built
class BaseScenario(object):
    """Interface for MPE scenarios: build a world and (re)set its initial state."""

    def make_world(self):
        """Create and return the populated world. Subclasses must override."""
        raise NotImplementedError()

    def reset_world(self, world):
        """Reset *world* to its initial conditions. Subclasses must override."""
        raise NotImplementedError()
================================================
FILE: examples/Social_Cognition/ToCM/mpe/scenarios/__init__.py
================================================
import imp
import os.path as osp
def load(name):
    """Import and return the scenario module *name* located next to this file.

    Uses :mod:`importlib` instead of ``imp.load_source``: the ``imp`` module
    has been deprecated since Python 3.4 and was removed in Python 3.12, so
    the original implementation breaks on current interpreters.
    """
    import importlib.util
    pathname = osp.join(osp.dirname(__file__), name)
    module_name = osp.splitext(osp.basename(name))[0]
    spec = importlib.util.spec_from_file_location(module_name, pathname)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
================================================
FILE: examples/Social_Cognition/ToCM/mpe/scenarios/hetero_spread.py
================================================
import numpy as np
from mpe.core import World, Agent, Landmark
from mpe.scenario import BaseScenario
class Scenario(BaseScenario):
    """Heterogeneous cooperative spread.

    Two agent types share the task of covering one landmark per agent:
    type "a" agents are larger and slower, type "b" agents are half-size and
    faster. The reward is shared (world.collaborative): sparse occupation
    bonuses plus an out-of-bounds penalty.
    """

    def make_world(self, args):
        """Build the world: ``args.num_agents`` agents split evenly between
        the two physical types, plus one landmark per agent."""
        world = World()
        # set any world properties first
        world.dim_c = 2
        world.max_steps = 25
        num_agents = args.num_agents
        self.n_agent_a = num_agents // 2  # 2
        self.n_agent_b = num_agents // 2  # 2
        num_landmarks = args.num_agents
        world.collaborative = True
        self.agent_size = 0.10
        self.n_others = 3
        self.n_group = 2
        # add agents
        world.agents = [Agent() for i in range(num_agents)]
        for i, agent in enumerate(world.agents):
            agent.name = 'agent %d' % i
            agent.collide = True
            agent.silent = True
            agent.id = i
            if i < self.n_agent_a:
                # type "a": full size, slower
                agent.size = self.agent_size
                agent.accel = 3.0
                agent.max_speed = 1.0
            else:
                # type "b": half size, faster
                agent.size = self.agent_size / 2
                agent.accel = 4.0
                agent.max_speed = 1.3
        # add landmarks
        world.landmarks = [Landmark() for i in range(num_landmarks)]
        for i, landmark in enumerate(world.landmarks):
            landmark.name = 'landmark %d' % i
            landmark.collide = False
            landmark.movable = False
        # make initial conditions
        self.reset_world(world)
        return world

    def reset_world(self, world):
        """Reset step counters, recolor entities, and randomize positions."""
        world.num_steps = 0
        self.end_steps = world.max_steps
        # color agents by type (blue = type "a", green = type "b")
        for i, agent in enumerate(world.agents):
            if i < self.n_agent_a:
                agent.color = np.array([0.35, 0.35, 0.85])
            else:
                agent.color = np.array([0.35, 0.85, 0.35])
        # landmarks are all dark grey
        for i, landmark in enumerate(world.landmarks):
            landmark.color = np.array([0.25, 0.25, 0.25])
        # set random initial states
        for agent in world.agents:
            agent.state.p_pos = np.random.uniform(-1, +1, world.dim_p)
            agent.state.p_vel = np.zeros(world.dim_p)
            agent.state.c = np.zeros(world.dim_c)
        for i, landmark in enumerate(world.landmarks):
            landmark.state.p_pos = np.random.uniform(-1, +1, world.dim_p)
            landmark.state.p_vel = np.zeros(world.dim_p)

    def benchmark_data(self, agent, world):
        """Diagnostics tuple: (reward, collisions, summed min distances,
        occupied landmark count).

        NOTE(review): the collision loop includes *agent* itself (distance 0
        always counts), matching upstream MPE behavior.
        """
        rew = 0
        collisions = 0
        occupied_landmarks = 0
        min_dists = 0
        for l in world.landmarks:
            dists = [np.sqrt(np.sum(np.square(a.state.p_pos - l.state.p_pos))) for a in world.agents]
            min_dists += min(dists)
            rew -= min(dists)
            if min(dists) < 0.1:
                occupied_landmarks += 1
        if agent.collide:
            for a in world.agents:
                if self.is_collision(a, agent):
                    rew -= 1
                    collisions += 1
        return (rew, collisions, min_dists, occupied_landmarks)

    def is_collision(self, agent1, agent2):
        """True when the two entities' disks overlap (centers closer than
        the sum of their radii)."""
        delta_pos = agent1.state.p_pos - agent2.state.p_pos
        dist = np.sqrt(np.sum(np.square(delta_pos)))
        dist_min = agent1.size + agent2.size
        return True if dist < dist_min else False

    def reward(self, agent, world):
        # Agents are rewarded based on minimum agent distance to each landmark, penalized for collisions
        rew = 0
        shaped_reward = False
        if shaped_reward:  # distance-based reward (disabled by the flag above)
            for l in world.landmarks:
                dists = [np.sqrt(np.sum(np.square(a.state.p_pos - l.state.p_pos))) for a in world.agents]
                rew -= min(dists)
            if agent.collide:
                # NOTE(review): includes self-collision (distance 0), as upstream.
                for a in world.agents:
                    if self.is_collision(a, agent):
                        rew -= 1
            return rew
        else:
            # Sparse reward: +2 per landmark touched by at least one agent,
            # de-duplicated so one agent covering several landmarks counts once.
            win_agents = []
            for land in world.landmarks:
                for a in world.agents:
                    if self.is_collision(a, land):
                        win_agents.append(a)
                        break
            rew += 2 * len(set(win_agents))

            def bound(x):
                # Exponential penalty once a coordinate leaves the unit box,
                # capped at 10.
                if x > 1.0:
                    return min(np.exp(2 * x - 2), 10)
                else:
                    return 0.0
            bound_rew = 0.0
            for p in range(world.dim_p):
                x = abs(agent.state.p_pos[p])
                bound_rew -= bound(x)
            rew += bound_rew
            return rew

    def observation(self, agent, world):
        """Observation: own velocity and position, other agents' comm
        vectors, then landmark and other-agent positions relative to
        this agent."""
        # get positions of all entities in this agent's reference frame
        entity_pos = []
        for entity in world.landmarks:  # world.entities:
            entity_pos.append(entity.state.p_pos - agent.state.p_pos)
        # entity colors (collected but not included in the returned vector)
        entity_color = []
        other_vel = []
        for entity in world.landmarks:  # world.entities:
            entity_color.append(entity.color)
        # communication of all other agents
        comm = []
        other_pos = []
        for other in world.agents:
            if other is agent:
                # placeholder keeps one other_vel slot per agent
                other_vel.append([0, 0])
                continue
            comm.append(other.state.c)
            other_pos.append(other.state.p_pos - agent.state.p_pos)
        return np.concatenate([agent.state.p_vel] + [agent.state.p_pos] + comm +
                              other_vel + entity_pos + other_pos)
================================================
FILE: examples/Social_Cognition/ToCM/mpe/scenarios/simple_adversary.py
================================================
import numpy as np
from mpe.core import World, Agent, Landmark
from mpe.scenario import BaseScenario # TODO
import random
class Scenario(BaseScenario):
    """Physical deception ("simple adversary").

    One adversary, N-1 good agents, and N-1 landmarks, one of which is the
    secret goal. Good agents know the goal and try to cover it; the
    adversary (which does not observe the goal) tries to reach it too.
    """

    def make_world(self, args):
        """Create a world with ``args.num_agents`` agents (index 0 is the
        adversary) and one fewer landmarks."""
        world = World()  # from core
        # set any world properties first
        world.dim_c = 2
        num_agents = args.num_agents  # 3
        world.num_agents = num_agents
        num_adversaries = 1
        num_landmarks = num_agents - 1
        # add agents
        world.agents = [Agent() for i in range(num_agents)]
        for i, agent in enumerate(world.agents):
            agent.name = 'agent %d' % i
            agent.collide = False
            agent.silent = True
            agent.adversary = True if i < num_adversaries else False
            agent.size = 0.15
        # add landmarks
        world.landmarks = [Landmark() for i in range(num_landmarks)]
        for i, landmark in enumerate(world.landmarks):
            landmark.name = 'landmark %d' % i
            landmark.collide = False
            landmark.movable = False
            landmark.size = 0.08
        # make initial conditions
        self.reset_world(world)
        return world

    def reset_world(self, world):
        """Recolor entities, pick a fresh random goal landmark shared by all
        agents, and randomize every entity's position."""
        # random properties for agents
        world.assign_agent_colors()
        # random properties for landmarks
        world.assign_landmark_colors()
        # set goal landmark (green); all agents share the same goal_a
        goal = np.random.choice(world.landmarks)
        goal.color = np.array([0.15, 0.65, 0.15])
        for agent in world.agents:
            agent.goal_a = goal
        # set random initial states
        for agent in world.agents:
            agent.state.p_pos = np.random.uniform(-1, +1, world.dim_p)
            agent.state.p_vel = np.zeros(world.dim_p)
            agent.state.c = np.zeros(world.dim_c)
        for i, landmark in enumerate(world.landmarks):
            landmark.state.p_pos = np.random.uniform(-1, +1, world.dim_p)
            landmark.state.p_vel = np.zeros(world.dim_p)

    def benchmark_data(self, agent, world):
        """Benchmark: adversary -> squared distance to the goal; good agent
        -> tuple of squared distances to every landmark plus the goal."""
        # returns data for benchmarking purposes
        if agent.adversary:
            return np.sum(np.square(agent.state.p_pos - agent.goal_a.state.p_pos))
        else:
            dists = []
            for l in world.landmarks:
                dists.append(np.sum(np.square(agent.state.p_pos - l.state.p_pos)))
            dists.append(np.sum(np.square(agent.state.p_pos - agent.goal_a.state.p_pos)))
            return tuple(dists)

    # return all agents that are not adversaries
    def good_agents(self, world):
        return [agent for agent in world.agents if not agent.adversary]

    # return all adversarial agents
    def adversaries(self, world):
        return [agent for agent in world.agents if agent.adversary]

    def reward(self, agent, world):
        # Dispatch to the role-specific reward.
        return self.adversary_reward(agent, world) if agent.adversary else self.agent_reward(agent, world)

    def agent_reward(self, agent, world):
        """Good-agent reward: the adversary's distance to the goal (farther
        is better) minus the closest good agent's distance to the goal."""
        # Rewarded based on how close any good agent is to the goal landmark, and how far the adversary is from it
        shaped_reward = True
        shaped_adv_reward = True
        # Calculate negative reward for adversary
        adversary_agents = self.adversaries(world)
        if shaped_adv_reward:  # distance-based adversary reward
            adv_rew = sum([np.sqrt(np.sum(np.square(a.state.p_pos - a.goal_a.state.p_pos))) for a in adversary_agents])
        else:  # proximity-based adversary reward (binary)
            adv_rew = 0
            for a in adversary_agents:
                if np.sqrt(np.sum(np.square(a.state.p_pos - a.goal_a.state.p_pos))) < 2 * a.goal_a.size:
                    adv_rew -= 5
        # Calculate positive reward for agents
        good_agents = self.good_agents(world)
        if shaped_reward:  # distance-based agent reward
            pos_rew = -min(
                [np.sqrt(np.sum(np.square(a.state.p_pos - a.goal_a.state.p_pos))) for a in good_agents])
        else:  # proximity-based agent reward (binary bonus plus distance term)
            pos_rew = 0
            if min([np.sqrt(np.sum(np.square(a.state.p_pos - a.goal_a.state.p_pos))) for a in good_agents]) \
                    < 2 * agent.goal_a.size:
                pos_rew += 5
            pos_rew -= min(
                [np.sqrt(np.sum(np.square(a.state.p_pos - a.goal_a.state.p_pos))) for a in good_agents])
        return pos_rew + adv_rew

    def adversary_reward(self, agent, world):
        """Adversary reward: negative squared distance to the goal landmark."""
        # Rewarded based on proximity to the goal landmark
        shaped_reward = True
        if shaped_reward:  # distance-based reward
            return -np.sum(np.square(agent.state.p_pos - agent.goal_a.state.p_pos))
        else:  # proximity-based reward (binary)
            adv_rew = 0
            if np.sqrt(np.sum(np.square(agent.state.p_pos - agent.goal_a.state.p_pos))) < 2 * agent.goal_a.size:
                adv_rew += 5
            return adv_rew

    def observation(self, agent, world):
        """Observation: relative landmark and other-agent positions. Good
        agents additionally see the goal's relative position first; the
        adversary does not observe the goal."""
        # get positions of all entities in this agent's reference frame
        entity_pos = []
        for entity in world.landmarks:
            entity_pos.append(entity.state.p_pos - agent.state.p_pos)
        # entity colors (collected but not included in the returned vector)
        entity_color = []
        for entity in world.landmarks:
            entity_color.append(entity.color)
        # positions of all other agents relative to this one
        other_pos = []
        for other in world.agents:
            if other is agent: continue
            other_pos.append(other.state.p_pos - agent.state.p_pos)
        if not agent.adversary:
            return np.concatenate([agent.goal_a.state.p_pos - agent.state.p_pos] + entity_pos + other_pos)
        else:
            return np.concatenate(entity_pos + other_pos)
================================================
FILE: examples/Social_Cognition/ToCM/mpe/scenarios/simple_crypto.py
================================================
"""
Scenario:
1 speaker, 2 listeners (one of which is an adversary). Good agents rewarded for proximity to goal, and distance from
adversary to goal. Adversary is rewarded for its distance to the goal.
"""
import numpy as np
from mpe.core import World, Agent, Landmark
from mpe.scenario import BaseScenario
import random
class CryptoAgent(Agent):
    """Agent that additionally carries an (initially unset) encryption key."""

    def __init__(self):
        super(CryptoAgent, self).__init__()
        # The speaker's key is assigned in Scenario.reset_world.
        self.key = None
class Scenario(BaseScenario):
    """Covert communication ("simple crypto").

    Agent 0 is the adversary (Eve), agent 1 the listener (Bob), agent 2 the
    speaker (Alice). Alice broadcasts the goal color using a private key
    shared only with Bob; rewards favor Bob reconstructing the message while
    Eve cannot. All agents are immobile — only communication matters.
    """

    def make_world(self, args):
        """Create the static world of communicating agents plus landmarks
        whose one-hot colors serve as possible messages/keys."""
        world = World()
        # set any world properties first
        num_agents = args.num_agents  # 3
        num_adversaries = 1
        num_landmarks = args.num_landmarks  # 2
        world.dim_c = 4
        # add agents
        world.agents = [CryptoAgent() for i in range(num_agents)]
        for i, agent in enumerate(world.agents):
            agent.name = 'agent %d' % i
            agent.collide = False
            agent.adversary = True if i < num_adversaries else False
            agent.speaker = True if i == 2 else False
            agent.movable = False
        # add landmarks
        world.landmarks = [Landmark() for i in range(num_landmarks)]
        for i, landmark in enumerate(world.landmarks):
            landmark.name = 'landmark %d' % i
            landmark.collide = False
            landmark.movable = False
        # make initial conditions
        self.reset_world(world)
        return world

    def reset_world(self, world):
        """Assign one-hot landmark colors, pick a random goal and key, and
        randomize entity positions/states."""
        # agents: grey, adversary tinted red; clear any previous key
        for agent in world.agents:
            agent.color = np.array([0.25, 0.25, 0.25])
            if agent.adversary:
                agent.color = np.array([0.75, 0.25, 0.25])
            agent.key = None
        # one-hot color (length world.dim_c) per landmark
        color_list = [np.zeros(world.dim_c) for i in world.landmarks]
        for i, color in enumerate(color_list):
            color[i] += 1
        for color, landmark in zip(color_list, world.landmarks):
            landmark.color = color
        # set goal landmark; listener shows the goal color, speaker gets a
        # key drawn from the landmark colors (possibly equal to the goal's)
        goal = np.random.choice(world.landmarks)
        world.agents[1].color = goal.color
        world.agents[2].key = np.random.choice(world.landmarks).color
        for agent in world.agents:
            agent.goal_a = goal
        # set random initial states
        for agent in world.agents:
            agent.state.p_pos = np.random.uniform(-1, +1, world.dim_p)
            agent.state.p_vel = np.zeros(world.dim_p)
            agent.state.c = np.zeros(world.dim_c)
        for i, landmark in enumerate(world.landmarks):
            landmark.state.p_pos = np.random.uniform(-1, +1, world.dim_p)
            landmark.state.p_vel = np.zeros(world.dim_p)

    def benchmark_data(self, agent, world):
        # returns data for benchmarking purposes
        return (agent.state.c, agent.goal_a.color)

    # return all agents that are not adversaries and are not the speaker
    def good_listeners(self, world):
        return [agent for agent in world.agents if not agent.adversary and not agent.speaker]

    # return all agents that are not adversaries
    def good_agents(self, world):
        return [agent for agent in world.agents if not agent.adversary]

    # return all adversarial agents
    def adversaries(self, world):
        return [agent for agent in world.agents if agent.adversary]

    def reward(self, agent, world):
        # Dispatch to the role-specific reward.
        return self.adversary_reward(agent, world) if agent.adversary else self.agent_reward(agent, world)

    def agent_reward(self, agent, world):
        # Agents rewarded if Bob can reconstruct message, but adversary (Eve) cannot
        good_listeners = self.good_listeners(world)
        adversaries = self.adversaries(world)
        good_rew = 0
        adv_rew = 0
        for a in good_listeners:
            # all-zero comm means "no message sent yet"; skip it
            if (a.state.c == np.zeros(world.dim_c)).all():
                continue
            else:
                good_rew -= np.sum(np.square(a.state.c - agent.goal_a.color))
        for a in adversaries:
            if (a.state.c == np.zeros(world.dim_c)).all():
                continue
            else:
                # NOTE(review): named "l1" but this is a squared (L2) error
                adv_l1 = np.sum(np.square(a.state.c - agent.goal_a.color))
                adv_rew += adv_l1
        return adv_rew + good_rew

    def adversary_reward(self, agent, world):
        # Adversary (Eve) is rewarded if it can reconstruct original goal
        rew = 0
        if not (agent.state.c == np.zeros(world.dim_c)).all():
            rew -= np.sum(np.square(agent.state.c - agent.goal_a.color))
        return rew

    def observation(self, agent, world):
        """Role-dependent observation: speaker sees [goal color, key];
        listener sees [key, speaker comm]; adversary sees only [speaker comm]."""
        # goal color (falls back to zeros when no goal is assigned)
        goal_color = np.zeros(world.dim_color)
        if agent.goal_a is not None:
            goal_color = agent.goal_a.color
        # print('goal color in obs is {}'.format(goal_color))
        # get positions of all entities in this agent's reference frame
        # (computed but not included in the returned vectors)
        entity_pos = []
        for entity in world.landmarks:
            entity_pos.append(entity.state.p_pos - agent.state.p_pos)
        # communication of the speaker only (non-speakers are filtered out)
        comm = []
        for other in world.agents:
            if other is agent or (other.state.c is None) or not other.speaker: continue
            comm.append(other.state.c)
        confer = np.array([0])
        if world.agents[2].key is None:
            # no key yet: zero out key/goal and flag the condition via confer
            confer = np.array([1])
            key = np.zeros(world.dim_c)
            goal_color = np.zeros(world.dim_c)
        else:
            key = world.agents[2].key
        prnt = False  # True if train use False
        # speaker
        if agent.speaker:
            if prnt:
                print('speaker')
                print(agent.state.c)
                print(np.concatenate([goal_color] + [key] + [confer] + [np.random.randn(1)]))
            return np.concatenate([goal_color] + [key])
        # listener
        if not agent.speaker and not agent.adversary:
            if prnt:
                print('listener')
                print(agent.state.c)
                print(np.concatenate([key] + comm + [confer]))
            return np.concatenate([key] + comm)
        if not agent.speaker and agent.adversary:
            if prnt:
                print('adversary')
                print(agent.state.c)
                print(np.concatenate(comm + [confer]))
            return np.concatenate(comm)
================================================
FILE: examples/Social_Cognition/ToCM/mpe/scenarios/simple_crypto_display.py
================================================
"""
Scenario:
1 speaker, 2 listeners (one of which is an adversary). Good agents rewarded for proximity to goal, and distance from
adversary to goal. Adversary is rewarded for its distance to the goal.
"""
import numpy as np
from mpe.core import World, Agent, Landmark
from mpe.scenario import BaseScenario
import random
class CryptoAgent(Agent):
    """Agent that additionally carries an (initially unset) encryption key."""

    def __init__(self):
        super(CryptoAgent, self).__init__()
        # The speaker's key is assigned in Scenario.reset_world.
        self.key = None
class Scenario(BaseScenario):
    """Display variant of the crypto scenario.

    Same roles as simple_crypto (adversary Eve, listener Bob, speaker Alice)
    but with deterministic entity placement and verbose printing
    (``prnt = True``) for visualization rather than training.
    """

    def make_world(self, args):
        """Create speaker/listener/adversary agents and landmarks (all static)."""
        world = World()
        # set any world properties first
        num_agents = args.num_agents  # 3
        num_adversaries = 1
        num_landmarks = args.num_landmarks  # 2
        world.dim_c = 4
        # add agents
        world.agents = [CryptoAgent() for i in range(num_agents)]
        for i, agent in enumerate(world.agents):
            agent.name = 'agent %d' % i
            agent.collide = False
            agent.adversary = True if i < num_adversaries else False
            agent.speaker = True if i == 2 else False
            agent.movable = False
        # add landmarks
        world.landmarks = [Landmark() for i in range(num_landmarks)]
        for i, landmark in enumerate(world.landmarks):
            landmark.name = 'landmark %d' % i
            landmark.collide = False
            landmark.movable = False
        # make initial conditions
        self.reset_world(world)
        return world

    def reset_world(self, world):
        """Assign colors/channels, pick goal and key, and place every entity
        at a fixed display position (agents in a column, landmarks at right)."""
        # random properties for agents (speaker tinted green, key cleared)
        world.assign_agent_colors()
        for agent in world.agents:
            if agent.speaker:
                agent.color = np.array([0.25, 0.75, 0.25])
            agent.key = None
        # random properties for landmarks
        world.assign_landmark_colors()
        # one-hot communication channel per landmark
        channel_list = [np.zeros(world.dim_c) for i in world.landmarks]
        for i, channel in enumerate(channel_list):
            channel[i] += 1
        for channel, landmark in zip(channel_list, world.landmarks):
            landmark.channel = channel
        # set goal landmark; speaker's key drawn from the landmark channels
        goal = np.random.choice(world.landmarks)
        world.agents[1].channel = goal.channel
        world.agents[2].key = np.random.choice(world.landmarks).channel
        for agent in world.agents:
            agent.goal_a = goal
        # fixed initial states for display (randomized versions left commented)
        for i, agent in enumerate(world.agents):
            # agent.state.p_pos = np.random.uniform(-1, +1, world.dim_p)
            agent.state.p_pos = np.array([0.0, -0.5 + 1.0 / (len(world.agents) - 1) * i])
            agent.state.p_vel = np.zeros(world.dim_p)
            agent.state.c = np.zeros(world.dim_c)
        for i, landmark in enumerate(world.landmarks):
            if landmark is goal:
                # the goal landmark is highlighted blue
                landmark.color = np.array([0.15, 0.15, 0.75])
            # landmark.state.p_pos = np.random.uniform(-1, +1, world.dim_p)
            landmark.state.p_pos = np.array([0.5, 0.5 - 0.5 / (len(world.landmarks) - 1) * i])
            landmark.state.p_vel = np.zeros(world.dim_p)

    def benchmark_data(self, agent, world):
        # returns data for benchmarking purposes
        return (agent.state.c, agent.goal_a.channel)

    # return all agents that are not adversaries and are not the speaker
    def good_listeners(self, world):
        return [agent for agent in world.agents if not agent.adversary and not agent.speaker]

    # return all agents that are not adversaries
    def good_agents(self, world):
        return [agent for agent in world.agents if not agent.adversary]

    # return all adversarial agents
    def adversaries(self, world):
        return [agent for agent in world.agents if agent.adversary]

    def reward(self, agent, world):
        # Dispatch to the role-specific reward.
        return self.adversary_reward(agent, world) if agent.adversary else self.agent_reward(agent, world)

    def agent_reward(self, agent, world):
        # Agents rewarded if Bob can reconstruct message, but adversary (Eve) cannot
        good_listeners = self.good_listeners(world)
        adversaries = self.adversaries(world)
        good_rew = 0
        adv_rew = 0
        for a in good_listeners:
            # all-zero comm means "no message sent yet"; skip it
            if (a.state.c == np.zeros(world.dim_c)).all():
                continue
            else:
                good_rew -= np.sum(np.square(a.state.c - agent.goal_a.channel))
        for a in adversaries:
            if (a.state.c == np.zeros(world.dim_c)).all():
                continue
            else:
                # NOTE(review): named "l1" but this is a squared (L2) error
                adv_l1 = np.sum(np.square(a.state.c - agent.goal_a.channel))
                adv_rew += adv_l1
        return adv_rew + good_rew

    def adversary_reward(self, agent, world):
        # Adversary (Eve) is rewarded if it can reconstruct original goal
        rew = 0
        if not (agent.state.c == np.zeros(world.dim_c)).all():
            rew -= np.sum(np.square(agent.state.c - agent.goal_a.channel))
        return rew

    def observation(self, agent, world):
        """Role-dependent observation (see the crypto scenario); prints
        debug output because ``prnt`` is hard-coded True for display runs."""
        # goal channel (falls back to zeros when no goal is assigned)
        goal_channel = np.zeros(world.dim_color)
        if agent.goal_a is not None:
            goal_channel = agent.goal_a.channel
        print('goal channel in obs is {}'.format(goal_channel))
        # get positions of all entities in this agent's reference frame
        # (computed but not included in the returned vectors)
        entity_pos = []
        for entity in world.landmarks:
            entity_pos.append(entity.state.p_pos - agent.state.p_pos)
        # communication of the speaker only (non-speakers are filtered out)
        comm = []
        for other in world.agents:
            if other is agent or (other.state.c is None) or not other.speaker: continue
            comm.append(other.state.c)
        confer = np.array([0])
        if world.agents[2].key is None:
            # no key yet: zero out key/goal and flag the condition via confer
            confer = np.array([1])
            key = np.zeros(world.dim_c)
            goal_channel = np.zeros(world.dim_c)
        else:
            key = world.agents[2].key
        prnt = True  # if train use False
        # speaker
        if agent.speaker:
            if prnt:
                print('speaker')
                print(agent.state.c)
                # print(np.concatenate([goal_channel] + [key] + [confer] + [np.random.randn(1)]))
            return np.concatenate([goal_channel] + [key])
        # listener
        if not agent.speaker and not agent.adversary:
            if prnt:
                print('listener')
                print(agent.state.c)
                # print(np.concatenate([key] + comm + [confer]))
            return np.concatenate([key] + comm)
        if not agent.speaker and agent.adversary:
            if prnt:
                print('adversary')
                print(agent.state.c)
                # print(np.concatenate(comm + [confer]))
            return np.concatenate(comm)
================================================
FILE: examples/Social_Cognition/ToCM/mpe/scenarios/simple_push.py
================================================
import numpy as np
from mpe.core import World, Agent, Landmark
from mpe.scenario import BaseScenario
import random
#
# # the non-ensemble version of
#
#
class Scenario(BaseScenario):
    """Keep-away ("simple push").

    One adversary tries to keep the good agent(s) away from their goal
    landmark while staying near the goal itself; good agents are rewarded
    purely by proximity to the goal.
    """

    def make_world(self, args):
        """Create agents (index 0 is the adversary) and landmarks, then reset."""
        world = World()
        # set any world properties first
        world.dim_c = 2
        num_agents = args.num_agents  # 2
        num_adversaries = 1
        num_landmarks = args.num_landmarks  # 2
        # add agents
        world.agents = [Agent() for i in range(num_agents)]
        for i, agent in enumerate(world.agents):
            agent.name = 'agent %d' % i
            agent.collide = True
            agent.silent = True
            if i < num_adversaries:
                agent.adversary = True
            else:
                agent.adversary = False
            # agent.u_noise = 1e-1
            # agent.c_noise = 1e-1
        # add landmarks
        world.landmarks = [Landmark() for i in range(num_landmarks)]
        for i, landmark in enumerate(world.landmarks):
            landmark.name = 'landmark %d' % i
            landmark.collide = False
            landmark.movable = False
        # make initial conditions
        self.reset_world(world)
        return world

    def reset_world(self, world):
        """Recolor landmarks/agents, pick a random goal, randomize states.

        NOTE(review): ``landmark.color[i + 1]`` assumes at most two landmarks
        (an RGB color has 3 channels); more landmarks would index out of range.
        """
        # give each landmark a distinct color channel and remember its index
        for i, landmark in enumerate(world.landmarks):
            landmark.color = np.array([0.1, 0.1, 0.1])
            landmark.color[i + 1] += 0.8
            landmark.index = i
        # set goal landmark shared by all agents
        goal = np.random.choice(world.landmarks)
        for i, agent in enumerate(world.agents):
            agent.goal_a = goal
            agent.color = np.array([0.25, 0.25, 0.25])
            if agent.adversary:
                agent.color = np.array([0.75, 0.25, 0.25])
            else:
                # tint the good agent with its goal landmark's color channel
                j = goal.index
                agent.color[j + 1] += 0.5
        # set random initial states (landmarks kept inside a 0.8 box)
        for agent in world.agents:
            agent.state.p_pos = np.random.uniform(-1, +1, world.dim_p)
            agent.state.p_vel = np.zeros(world.dim_p)
            agent.state.c = np.zeros(world.dim_c)
        for i, landmark in enumerate(world.landmarks):
            landmark.state.p_pos = 0.8 * np.random.uniform(-1, +1, world.dim_p)
            landmark.state.p_vel = np.zeros(world.dim_p)

    def reward(self, agent, world):
        # Dispatch to the role-specific reward.
        return self.adversary_reward(agent, world) if agent.adversary else self.agent_reward(agent, world)

    def agent_reward(self, agent, world):
        # negative distance to the goal landmark
        return -np.sqrt(np.sum(np.square(agent.state.p_pos - agent.goal_a.state.p_pos)))

    def adversary_reward(self, agent, world):
        """Adversary gains when the nearest good agent is far from the goal
        and loses by its own distance to the goal."""
        # keep the nearest good agents away from the goal
        agent_dist = [np.sqrt(np.sum(np.square(a.state.p_pos - a.goal_a.state.p_pos))) for a in world.agents if
                      not a.adversary]
        pos_rew = min(agent_dist)
        # nearest_agent = world.good_agents[np.argmin(agent_dist)]
        # neg_rew = np.sqrt(np.sum(np.square(nearest_agent.state.p_pos - agent.state.p_pos)))
        neg_rew = np.sqrt(np.sum(np.square(agent.goal_a.state.p_pos - agent.state.p_pos)))
        # neg_rew = sum([np.sqrt(np.sum(np.square(a.state.p_pos - agent.state.p_pos))) for a in world.good_agents])
        return pos_rew - neg_rew

    def observation(self, agent, world):
        """Good agents see goal-relative position, own color, landmark
        positions/colors, and other agents; the adversary only sees its
        velocity plus relative entity/agent positions (no goal info)."""
        # get positions of all entities in this agent's reference frame
        entity_pos = []
        for entity in world.landmarks:  # world.entities:
            entity_pos.append(entity.state.p_pos - agent.state.p_pos)
        # entity colors
        entity_color = []
        for entity in world.landmarks:  # world.entities:
            entity_color.append(entity.color)
        # communication of all other agents (comm collected but not returned)
        comm = []
        other_pos = []
        for other in world.agents:
            if other is agent: continue
            comm.append(other.state.c)
            other_pos.append(other.state.p_pos - agent.state.p_pos)
        if not agent.adversary:
            return np.concatenate([agent.state.p_vel] + [agent.goal_a.state.p_pos - agent.state.p_pos] + [
                agent.color] + entity_pos + entity_color + other_pos)
        else:
            # other_pos = list(reversed(other_pos)) if random.uniform(0,1) > 0.5 else other_pos  # randomize position of other agents in adversary network
            return np.concatenate([agent.state.p_vel] + entity_pos + other_pos)
================================================
FILE: examples/Social_Cognition/ToCM/mpe/scenarios/simple_reference.py
================================================
import numpy as np
from mpe.core import World, Agent, Landmark
from mpe.scenario import BaseScenario
class Scenario(BaseScenario):
    """Reference scenario: two agents, each knowing only the *other* agent's
    target landmark, must communicate so both reach their targets."""

    def make_world(self, args):
        """Build a cooperative 2-agent world with ``args.num_landmarks`` landmarks."""
        world = World()
        # set any world properties first
        # world.world_length = args.episode_length
        world.dim_c = 10
        world.collaborative = True  # whether agents share rewards
        # add agents
        world.num_agents = args.num_agents  # 2
        assert world.num_agents == 2, (
            "only 2 agents is supported, check the config.py.")
        world.agents = [Agent() for _ in range(world.num_agents)]
        for idx, agent in enumerate(world.agents):
            agent.name = 'agent %d' % idx
            agent.collide = False
            # agent.u_noise = 1e-1
            # agent.c_noise = 1e-1
        # add static, non-colliding landmarks
        world.num_landmarks = args.num_landmarks  # 3
        world.landmarks = [Landmark() for _ in range(world.num_landmarks)]
        for idx, landmark in enumerate(world.landmarks):
            landmark.name = 'landmark %d' % idx
            landmark.collide = False
            landmark.movable = False
        # make initial conditions
        self.reset_world(world)
        return world

    def reset_world(self, world):
        """Cross-assign goals, recolor entities and scatter positions."""
        # clear goals, then make each agent's target the other agent + a random landmark
        for agent in world.agents:
            agent.goal_a = None
            agent.goal_b = None
        world.agents[0].goal_a = world.agents[1]
        world.agents[0].goal_b = np.random.choice(world.landmarks)
        world.agents[1].goal_a = world.agents[0]
        world.agents[1].goal_b = np.random.choice(world.landmarks)
        # random properties for agents
        world.assign_agent_colors()
        # fixed palette for landmarks
        # NOTE(review): indexes 0..2 assume exactly 3 landmarks — confirm config
        world.landmarks[0].color = np.array([0.75, 0.25, 0.25])
        world.landmarks[1].color = np.array([0.25, 0.75, 0.25])
        world.landmarks[2].color = np.array([0.25, 0.25, 0.75])
        # paint each partner with its target landmark's color
        world.agents[0].goal_a.color = world.agents[0].goal_b.color
        world.agents[1].goal_a.color = world.agents[1].goal_b.color
        # set random initial states
        for agent in world.agents:
            agent.state.p_pos = np.random.uniform(-1, +1, world.dim_p)
            agent.state.p_vel = np.zeros(world.dim_p)
            agent.state.c = np.zeros(world.dim_c)
        for landmark in world.landmarks:
            landmark.state.p_pos = 0.8 * np.random.uniform(-1, +1, world.dim_p)
            landmark.state.p_vel = np.zeros(world.dim_p)

    def reward(self, agent, world):
        """Negative squared distance from the partner to its target (0 before goals exist)."""
        if agent.goal_a is None or agent.goal_b is None:
            return 0.0
        offset = agent.goal_a.state.p_pos - agent.goal_b.state.p_pos
        return -np.sum(np.square(offset))  # np.exp(-dist2) was considered upstream

    def observation(self, agent, world):
        """Own velocity, relative landmark positions, partner's goal color, and comms."""
        # color of the landmark this agent wants its partner to reach
        goal_color = [np.zeros(world.dim_color), np.zeros(world.dim_color)]
        if agent.goal_b is not None:
            goal_color[1] = agent.goal_b.color
        # positions of all landmarks relative to this agent
        entity_pos = [lm.state.p_pos - agent.state.p_pos for lm in world.landmarks]
        # landmark colors are gathered for parity with upstream but not concatenated
        entity_color = [lm.color for lm in world.landmarks]
        # communication from the other agent(s)
        comm = [other.state.c for other in world.agents if other is not agent]
        return np.concatenate([agent.state.p_vel] + entity_pos + [goal_color[1]] + comm)
================================================
FILE: examples/Social_Cognition/ToCM/mpe/scenarios/simple_speaker_listener.py
================================================
import numpy as np
from mpe.core import World, Agent, Landmark
from mpe.scenario import BaseScenario
class Scenario(BaseScenario):
    """Speaker/listener scenario: agent 0 (immobile speaker) observes the goal
    color and must communicate it; agent 1 (mute listener) must navigate to the
    goal landmark. Fully cooperative: both share the listener-to-goal reward."""

    def make_world(self, args):
        """Create a 2-agent world (speaker + listener) with ``args.num_landmarks`` landmarks."""
        world = World()
        world.world_length = args.episode_length
        # set any world properties first
        world.dim_c = 3
        world.num_landmarks = args.num_landmarks  # 3
        world.collaborative = True
        # add agents
        world.num_agents = args.num_agents  # 2
        assert world.num_agents == 2, (
            "only 2 agents is supported, check the config.py.")
        world.agents = [Agent() for i in range(world.num_agents)]
        for i, agent in enumerate(world.agents):
            agent.name = 'agent %d' % i
            agent.collide = False
            agent.size = 0.075
        # speaker cannot move
        world.agents[0].movable = False
        # listener cannot talk
        world.agents[1].silent = True
        # add landmarks
        world.landmarks = [Landmark() for i in range(world.num_landmarks)]
        for i, landmark in enumerate(world.landmarks):
            landmark.name = 'landmark %d' % i
            landmark.collide = False
            landmark.movable = False
            landmark.size = 0.04
        # make initial conditions
        self.reset_world(world)
        return world

    def reset_world(self, world):
        """Pick a goal landmark for the listener, recolor entities, scatter positions."""
        # assign goals to agents
        for agent in world.agents:
            agent.goal_a = None
            agent.goal_b = None
        # want listener to go to the goal landmark
        world.agents[0].goal_a = world.agents[1]
        world.agents[0].goal_b = np.random.choice(world.landmarks)
        # random properties for agents
        for i, agent in enumerate(world.agents):
            agent.color = np.array([0.25, 0.25, 0.25])
        # fixed landmark palette
        # NOTE(review): indexes 0..2 assume exactly 3 landmarks — confirm config
        world.landmarks[0].color = np.array([0.65, 0.15, 0.15])
        world.landmarks[1].color = np.array([0.15, 0.65, 0.15])
        world.landmarks[2].color = np.array([0.15, 0.15, 0.65])
        # highlight the listener with a brightened version of the goal color
        world.agents[0].goal_a.color = world.agents[0].goal_b.color + \
            np.array([0.45, 0.45, 0.45])
        # set random initial states
        for agent in world.agents:
            agent.state.p_pos = np.random.uniform(-1, +1, world.dim_p)
            agent.state.p_vel = np.zeros(world.dim_p)
            agent.state.c = np.zeros(world.dim_c)
        for i, landmark in enumerate(world.landmarks):
            landmark.state.p_pos = np.random.uniform(-1, +1, world.dim_p)
            landmark.state.p_vel = np.zeros(world.dim_p)

    def benchmark_data(self, agent, world):
        """Return the shared reward for benchmarking.

        FIX: previously ``return reward(agent, reward)`` which raised NameError
        (``reward`` is a method, not a module-level name) and passed the function
        itself instead of the world.
        """
        return self.reward(agent, world)

    def reward(self, agent, world):
        """Shared reward: negative squared distance from the listener to the goal."""
        a = world.agents[0]
        dist2 = np.sum(np.square(a.goal_a.state.p_pos - a.goal_b.state.p_pos))
        return -dist2

    def observation(self, agent, world):
        """Speaker sees only the goal color; listener sees velocity, relative
        landmark positions and the speaker's communication."""
        # goal color (zeros until a goal is assigned)
        goal_color = np.zeros(world.dim_color)
        if agent.goal_b is not None:
            goal_color = agent.goal_b.color
        # get positions of all entities in this agent's reference frame
        entity_pos = []
        for entity in world.landmarks:
            entity_pos.append(entity.state.p_pos - agent.state.p_pos)
        # communication of all other agents
        comm = []
        for other in world.agents:
            if other is agent or (other.state.c is None):
                continue
            comm.append(other.state.c)
        # speaker
        if not agent.movable:
            return np.concatenate([goal_color])
        # listener
        if agent.silent:
            return np.concatenate([agent.state.p_vel] + entity_pos + comm)
        # NOTE(review): a movable, non-silent agent would fall through and return
        # None — unreachable with the speaker/listener roles set in make_world
================================================
FILE: examples/Social_Cognition/ToCM/mpe/scenarios/simple_spread.py
================================================
import numpy as np
from mpe.core import World, Agent, Landmark
from mpe.scenario import BaseScenario
class Scenario(BaseScenario):
    """Cooperative navigation (simple_spread): agents must jointly cover all
    landmarks while avoiding collisions with each other."""

    def make_world(self, args):
        """Build a cooperative world with ``args.num_agents`` agents and
        ``args.num_landmarks`` landmarks."""
        world = World()
        # world.world_length = args.episode_length
        # set any world properties first
        world.dim_c = 2
        world.num_agents = args.num_agents
        world.num_landmarks = args.num_landmarks  # 3
        world.collaborative = True
        # add colliding, silent agents
        world.agents = [Agent() for _ in range(world.num_agents)]
        for idx, agent in enumerate(world.agents):
            agent.name = 'agent %d' % idx
            agent.collide = True
            agent.silent = True
            agent.size = 0.15
        # add static, non-colliding landmarks
        world.landmarks = [Landmark() for _ in range(world.num_landmarks)]
        for idx, landmark in enumerate(world.landmarks):
            landmark.name = 'landmark %d' % idx
            landmark.collide = False
            landmark.movable = False
        # make initial conditions
        self.reset_world(world)
        return world

    def reset_world(self, world):
        """Recolor entities and scatter agents/landmarks uniformly."""
        world.assign_agent_colors()
        world.assign_landmark_colors()
        # agents anywhere in [-1, 1]; landmarks inside [-0.8, 0.8]
        for agent in world.agents:
            agent.state.p_pos = np.random.uniform(-1, +1, world.dim_p)
            agent.state.p_vel = np.zeros(world.dim_p)
            agent.state.c = np.zeros(world.dim_c)
        for landmark in world.landmarks:
            landmark.state.p_pos = 0.8 * np.random.uniform(-1, +1, world.dim_p)
            landmark.state.p_vel = np.zeros(world.dim_p)

    def benchmark_data(self, agent, world):
        """Diagnostics tuple: (reward, collisions, summed min distances, occupied landmarks)."""
        rew = 0
        collisions = 0
        occupied_landmarks = 0
        min_dists = 0
        for landmark in world.landmarks:
            dists = [np.sqrt(np.sum(np.square(a.state.p_pos - landmark.state.p_pos)))
                     for a in world.agents]
            closest = min(dists)
            min_dists += closest
            rew -= closest
            if closest < 0.1:
                occupied_landmarks += 1
        if agent.collide:
            # NOTE(review): the loop includes `agent` itself, which always counts as
            # a collision (zero distance) — kept for parity with the original metric
            for a in world.agents:
                if self.is_collision(a, agent):
                    rew -= 1
                    collisions += 1
        return (rew, collisions, min_dists, occupied_landmarks)

    def is_collision(self, agent1, agent2):
        # two agents collide when their centers are closer than the sum of radii
        gap = agent1.state.p_pos - agent2.state.p_pos
        return bool(np.sqrt(np.sum(np.square(gap))) < agent1.size + agent2.size)

    def reward(self, agent, world):
        """Shared coverage reward minus a collision penalty for this agent."""
        rew = 0
        for landmark in world.landmarks:
            rew -= min(np.sqrt(np.sum(np.square(a.state.p_pos - landmark.state.p_pos)))
                       for a in world.agents)
        if agent.collide:
            for a in world.agents:
                if self.is_collision(a, agent):
                    rew -= 1
        return rew

    def observation(self, agent, world):
        """Own velocity/position, relative landmark and agent positions, comms."""
        rel_landmarks = [lm.state.p_pos - agent.state.p_pos for lm in world.landmarks]
        # landmark colors are gathered for parity with upstream but not concatenated
        lm_colors = [lm.color for lm in world.landmarks]
        comm = []
        rel_others = []
        for other in world.agents:
            if other is agent:
                continue
            comm.append(other.state.c)
            rel_others.append(other.state.p_pos - agent.state.p_pos)
        return np.concatenate([agent.state.p_vel] + [agent.state.p_pos]
                              + rel_landmarks + rel_others + comm)
================================================
FILE: examples/Social_Cognition/ToCM/mpe/scenarios/simple_tag.py
================================================
import numpy as np
from mpe.core import World, Agent, Landmark
from mpe.scenario import BaseScenario
class Scenario(BaseScenario):
    """Predator-prey (simple_tag): slower adversaries chase faster good agents
    around a few large obstacle landmarks."""

    def make_world(self, args):
        """Build the world: adversaries first, then good agents, plus obstacles."""
        world = World()
        # communication channel (unused — agents are silent)
        world.dim_c = 2
        n_good = args.num_good_agents  # 1
        n_adv = args.num_adversaries  # 3
        n_landmarks = args.num_landmarks  # 2
        # add agents; the first n_adv are the predators: bigger, slower
        world.agents = [Agent() for _ in range(n_adv + n_good)]
        for idx, agent in enumerate(world.agents):
            agent.name = 'agent %d' % idx
            agent.collide = True
            agent.silent = True
            agent.adversary = idx < n_adv
            agent.size = 0.075 if agent.adversary else 0.05
            agent.accel = 3.0 if agent.adversary else 4.0
            # agent.accel = 20.0 if agent.adversary else 25.0
            agent.max_speed = 1.0 if agent.adversary else 1.3
        # add large, collidable obstacle landmarks
        world.landmarks = [Landmark() for _ in range(n_landmarks)]
        for idx, landmark in enumerate(world.landmarks):
            landmark.name = 'landmark %d' % idx
            landmark.collide = True
            landmark.movable = False
            landmark.size = 0.2
            landmark.boundary = False
        # make initial conditions
        self.reset_world(world)
        return world

    def reset_world(self, world):
        """Recolor entities and scatter positions; boundary landmarks keep theirs."""
        world.assign_agent_colors()
        world.assign_landmark_colors()
        for agent in world.agents:
            agent.state.p_pos = np.random.uniform(-1, +1, world.dim_p)
            agent.state.p_vel = np.zeros(world.dim_p)
            agent.state.c = np.zeros(world.dim_c)
        for landmark in world.landmarks:
            if not landmark.boundary:
                landmark.state.p_pos = 0.8 * np.random.uniform(-1, +1, world.dim_p)
                landmark.state.p_vel = np.zeros(world.dim_p)

    def benchmark_data(self, agent, world):
        """For adversaries, count good agents currently being touched; 0 otherwise."""
        if not agent.adversary:
            return 0
        return sum(1 for a in self.good_agents(world) if self.is_collision(a, agent))

    def is_collision(self, agent1, agent2):
        # entities collide when closer than the sum of their radii
        gap = agent1.state.p_pos - agent2.state.p_pos
        return bool(np.sqrt(np.sum(np.square(gap))) < agent1.size + agent2.size)

    def good_agents(self, world):
        """All non-adversary agents."""
        return [a for a in world.agents if not a.adversary]

    def adversaries(self, world):
        """All adversary agents."""
        return [a for a in world.agents if a.adversary]

    def reward(self, agent, world):
        """Dispatch to the role-specific reward."""
        if agent.adversary:
            return self.adversary_reward(agent, world)
        return self.agent_reward(agent, world)

    def agent_reward(self, agent, world):
        """Prey reward: -10 per catch, plus a soft-wall penalty for leaving the arena."""
        rew = 0
        shape = False  # different from openai
        adversaries = self.adversaries(world)
        if shape:  # optional shaping: reward distance from each predator
            for adv in adversaries:
                rew += 0.1 * np.sqrt(np.sum(np.square(agent.state.p_pos - adv.state.p_pos)))
        if agent.collide:
            for adv in adversaries:
                if self.is_collision(adv, agent):
                    rew -= 10

        # soft wall: 0 inside |x| < 0.9, linear to 1 at |x| = 1, capped exponential beyond
        def bound(x):
            if x < 0.9:
                return 0
            if x < 1.0:
                return (x - 0.9) * 10
            return min(np.exp(2 * x - 2), 10)

        for axis in range(world.dim_p):
            rew -= bound(abs(agent.state.p_pos[axis]))
        return rew

    def adversary_reward(self, agent, world):
        """Predator reward: +10 whenever any predator is touching any prey."""
        rew = 0
        shape = False  # different from openai
        agents = self.good_agents(world)
        adversaries = self.adversaries(world)
        if shape:  # optional shaping: penalize distance of each predator to nearest prey
            for adv in adversaries:
                rew -= 0.1 * min(np.sqrt(np.sum(np.square(a.state.p_pos - adv.state.p_pos)))
                                 for a in agents)
        if agent.collide:
            for prey in agents:
                for adv in adversaries:
                    if self.is_collision(prey, adv):
                        rew += 10
        return rew

    def observation(self, agent, world):
        """Own state, relative positions of non-boundary landmarks and other agents,
        plus the velocities of the (faster) prey only."""
        entity_pos = [lm.state.p_pos - agent.state.p_pos
                      for lm in world.landmarks if not lm.boundary]
        # comm is gathered for parity with upstream but not concatenated
        comm = []
        other_pos = []
        other_vel = []
        for other in world.agents:
            if other is agent:
                continue
            comm.append(other.state.c)
            other_pos.append(other.state.p_pos - agent.state.p_pos)
            # only prey velocities are observable
            if not other.adversary:
                other_vel.append(other.state.p_vel)
        return np.concatenate([agent.state.p_vel] + [agent.state.p_pos]
                              + entity_pos + other_pos + other_vel)
================================================
FILE: examples/Social_Cognition/ToCM/mpe/scenarios/simple_world_comm.py
================================================
import numpy as np
from mpe.core import World, Agent, Landmark
from mpe.scenario import BaseScenario
class Scenario(BaseScenario):
    """World-comm scenario: a leader adversary (the only speaking agent)
    coordinates silent adversaries chasing good agents; good agents forage food
    and may hide in forests, which block visibility for non-leader observers."""

    def make_world(self, args):
        """Create agents (leader first, adversaries, then good agents), obstacle
        landmarks, food and forests."""
        world = World()
        # set any world properties first
        world.dim_c = 4
        # world.damping = 1
        num_good_agents = args.num_good_agents  # 2
        num_adversaries = args.num_adversaries  # 4
        num_agents = num_adversaries + num_good_agents
        num_landmarks = args.num_landmarks  # 1
        num_food = 2
        num_forests = 2
        # add agents: index 0 is the leader and the only non-silent agent
        world.agents = [Agent() for i in range(num_agents)]
        for i, agent in enumerate(world.agents):
            agent.name = 'agent %d' % i
            agent.collide = True
            agent.leader = True if i == 0 else False
            agent.silent = True if i > 0 else False
            agent.adversary = True if i < num_adversaries else False
            agent.size = 0.075 if agent.adversary else 0.045
            agent.accel = 3.0 if agent.adversary else 4.0
            # agent.accel = 20.0 if agent.adversary else 25.0
            agent.max_speed = 1.0 if agent.adversary else 1.3
        # add obstacle landmarks
        world.landmarks = [Landmark() for i in range(num_landmarks)]
        for i, landmark in enumerate(world.landmarks):
            landmark.name = 'landmark %d' % i
            landmark.collide = True
            landmark.movable = False
            landmark.size = 0.2
            landmark.boundary = False
        # add food (collectible by good agents)
        world.food = [Landmark() for i in range(num_food)]
        for i, landmark in enumerate(world.food):
            landmark.name = 'food %d' % i
            landmark.collide = False
            landmark.movable = False
            landmark.size = 0.03
            landmark.boundary = False
        # add forests (hiding spots)
        world.forests = [Landmark() for i in range(num_forests)]
        for i, landmark in enumerate(world.forests):
            landmark.name = 'forest %d' % i
            landmark.collide = False
            landmark.movable = False
            landmark.size = 0.3
            landmark.boundary = False
        world.landmarks += world.food
        world.landmarks += world.forests
        # world.landmarks += self.set_boundaries(world)  # world boundaries now penalized with negative reward
        # make initial conditions
        self.reset_world(world)
        return world

    def set_boundaries(self, world):
        """Build a ring of large boundary landmarks around the arena.

        Currently unused (the call in make_world is commented out).
        """
        boundary_list = []
        landmark_size = 1
        edge = 1 + landmark_size
        num_landmarks = int(edge * 2 / landmark_size)
        for x_pos in [-edge, edge]:
            for i in range(num_landmarks):
                l = Landmark()
                l.state.p_pos = np.array([x_pos, -1 + i * landmark_size])
                boundary_list.append(l)
        for y_pos in [-edge, edge]:
            for i in range(num_landmarks):
                l = Landmark()
                l.state.p_pos = np.array([-1 + i * landmark_size, y_pos])
                boundary_list.append(l)
        for i, l in enumerate(boundary_list):
            l.name = 'boundary %d' % i
            # FIX: was `l.collide == True` — a no-op comparison, so boundary
            # landmarks were never actually made collidable
            l.collide = True
            l.movable = False
            l.boundary = True
            l.color = np.array([0.75, 0.75, 0.75])
            l.size = landmark_size
            l.state.p_vel = np.zeros(world.dim_p)
        return boundary_list

    def reset_world(self, world):
        """Recolor entities (green = good, red = adversary, darker = leader) and
        scatter positions."""
        # random properties for agents
        for i, agent in enumerate(world.agents):
            agent.color = np.array([0.45, 0.95, 0.45]) if not agent.adversary else np.array([0.95, 0.45, 0.45])
            agent.color -= np.array([0.3, 0.3, 0.3]) if agent.leader else np.array([0, 0, 0])
        # random properties for landmarks
        for i, landmark in enumerate(world.landmarks):
            landmark.color = np.array([0.25, 0.25, 0.25])
        for i, landmark in enumerate(world.food):
            landmark.color = np.array([0.15, 0.15, 0.65])
        for i, landmark in enumerate(world.forests):
            landmark.color = np.array([0.6, 0.9, 0.6])
        # set random initial states
        for agent in world.agents:
            agent.state.p_pos = np.random.uniform(-1, +1, world.dim_p)
            agent.state.p_vel = np.zeros(world.dim_p)
            agent.state.c = np.zeros(world.dim_c)
        # NOTE: food/forests are also in world.landmarks, so their positions are
        # sampled here and then re-sampled below (kept for RNG-stream parity)
        for i, landmark in enumerate(world.landmarks):
            landmark.state.p_pos = 0.8 * np.random.uniform(-1, +1, world.dim_p)
            landmark.state.p_vel = np.zeros(world.dim_p)
        for i, landmark in enumerate(world.food):
            landmark.state.p_pos = 0.8 * np.random.uniform(-1, +1, world.dim_p)
            landmark.state.p_vel = np.zeros(world.dim_p)
        for i, landmark in enumerate(world.forests):
            landmark.state.p_pos = 0.8 * np.random.uniform(-1, +1, world.dim_p)
            landmark.state.p_vel = np.zeros(world.dim_p)

    def benchmark_data(self, agent, world):
        """For adversaries, count collisions with good agents; 0 otherwise."""
        if agent.adversary:
            collisions = 0
            for a in self.good_agents(world):
                if self.is_collision(a, agent):
                    collisions += 1
            return collisions
        else:
            return 0

    def is_collision(self, agent1, agent2):
        """True when the two entities are closer than the sum of their radii."""
        delta_pos = agent1.state.p_pos - agent2.state.p_pos
        dist = np.sqrt(np.sum(np.square(delta_pos)))
        dist_min = agent1.size + agent2.size
        return True if dist < dist_min else False

    # return all agents that are not adversaries
    def good_agents(self, world):
        return [agent for agent in world.agents if not agent.adversary]

    # return all adversarial agents
    def adversaries(self, world):
        return [agent for agent in world.agents if agent.adversary]

    def reward(self, agent, world):
        """Dispatch to the role-specific reward."""
        # boundary_reward = -10 if self.outside_boundary(agent) else 0
        main_reward = self.adversary_reward(agent, world) if agent.adversary else self.agent_reward(agent, world)
        return main_reward

    def outside_boundary(self, agent):
        """True when the agent has left the [-1, 1]^2 arena."""
        if agent.state.p_pos[0] > 1 or agent.state.p_pos[0] < -1 or agent.state.p_pos[1] > 1 or agent.state.p_pos[
                1] < -1:
            return True
        else:
            return False

    def agent_reward(self, agent, world):
        """Good agents: -5 per catch, soft-wall penalty for leaving the arena,
        +2 per food touched, plus a distance-to-food shaping term."""
        rew = 0
        shape = False
        adversaries = self.adversaries(world)
        if shape:  # optional shaping: reward distance from predators
            for adv in adversaries:
                rew += 0.1 * np.sqrt(np.sum(np.square(agent.state.p_pos - adv.state.p_pos)))
        if agent.collide:
            for a in adversaries:
                if self.is_collision(a, agent):
                    rew -= 5

        # soft wall: 0 inside |x| < 0.9, linear to 1 at |x| = 1, capped exponential beyond
        def bound(x):
            if x < 0.9:
                return 0
            if x < 1.0:
                return (x - 0.9) * 10
            return min(np.exp(2 * x - 2), 10)  # 1 + (x - 1) * (x - 1)

        for p in range(world.dim_p):
            x = abs(agent.state.p_pos[p])
            rew -= 2 * bound(x)
        for food in world.food:
            if self.is_collision(agent, food):
                rew += 2
        # NOTE(review): this term *increases* with distance to the nearest food;
        # it matches upstream MPE, so it is kept as-is — confirm intent before changing
        rew += 0.05 * min([np.sqrt(np.sum(np.square(food.state.p_pos - agent.state.p_pos))) for food in world.food])
        return rew

    def adversary_reward(self, agent, world):
        """Adversaries: shaped toward the nearest prey, +5 whenever any
        predator is touching any prey."""
        rew = 0
        shape = True
        agents = self.good_agents(world)
        adversaries = self.adversaries(world)
        if shape:
            rew -= 0.1 * min([np.sqrt(np.sum(np.square(a.state.p_pos - agent.state.p_pos))) for a in agents])
            # for adv in adversaries:
            #     rew -= 0.1 * min([np.sqrt(np.sum(np.square(a.state.p_pos - adv.state.p_pos))) for a in agents])
        if agent.collide:
            for ag in agents:
                for adv in adversaries:
                    if self.is_collision(ag, adv):
                        rew += 5
        return rew

    def observation2(self, agent, world):
        """Full-visibility observation (no forest occlusion); not the default."""
        # get positions of all entities in this agent's reference frame
        entity_pos = []
        for entity in world.landmarks:  # world.entities:
            if not entity.boundary:
                entity_pos.append(entity.state.p_pos - agent.state.p_pos)
        food_pos = []
        for entity in world.food:  # world.entities:
            if not entity.boundary:
                food_pos.append(entity.state.p_pos - agent.state.p_pos)
        # communication of all other agents
        comm = []
        other_pos = []
        other_vel = []
        for other in world.agents:
            if other is agent:
                continue
            comm.append(other.state.c)
            other_pos.append(other.state.p_pos - agent.state.p_pos)
            if not other.adversary:
                other_vel.append(other.state.p_vel)
        return np.concatenate([agent.state.p_vel] + [agent.state.p_pos] + entity_pos + other_pos + other_vel)

    def observation(self, agent, world):
        """Role-dependent observation with forest occlusion: agents in a forest
        are hidden from agents outside it (the leader sees everything)."""
        # get positions of all non-boundary entities relative to this agent
        entity_pos = []
        for entity in world.landmarks:
            if not entity.boundary:
                entity_pos.append(entity.state.p_pos - agent.state.p_pos)
        # flags for whether this agent stands in either forest
        in_forest = [np.array([-1]), np.array([-1])]
        inf1 = False
        inf2 = False
        if self.is_collision(agent, world.forests[0]):
            in_forest[0] = np.array([1])
            inf1 = True
        if self.is_collision(agent, world.forests[1]):
            in_forest[1] = np.array([1])
            inf2 = True
        food_pos = []
        for entity in world.food:
            if not entity.boundary:
                food_pos.append(entity.state.p_pos - agent.state.p_pos)
        # communication of all other agents
        comm = []
        other_pos = []
        other_vel = []
        for other in world.agents:
            if other is agent:
                continue
            comm.append(other.state.c)
            oth_f1 = self.is_collision(other, world.forests[0])
            oth_f2 = self.is_collision(other, world.forests[1])
            # visible when both are in the same forest, both outside all forests,
            # or the observer is the leader
            if (inf1 and oth_f1) or (inf2 and oth_f2) or (
                    not inf1 and not oth_f1 and not inf2 and not oth_f2) or agent.leader:
                other_pos.append(other.state.p_pos - agent.state.p_pos)
                if not other.adversary:
                    other_vel.append(other.state.p_vel)
            else:
                # occluded: position/velocity masked with zeros
                other_pos.append([0, 0])
                if not other.adversary:
                    other_vel.append([0, 0])
        # to tell the pred when the prey are in the forest
        prey_forest = []
        ga = self.good_agents(world)
        for a in ga:
            if any([self.is_collision(a, f) for f in world.forests]):
                prey_forest.append(np.array([1]))
            else:
                prey_forest.append(np.array([-1]))
        # to tell leader when pred are in forest
        prey_forest_lead = []
        for f in world.forests:
            if any([self.is_collision(a, f) for a in ga]):
                prey_forest_lead.append(np.array([1]))
            else:
                prey_forest_lead.append(np.array([-1]))
        # only the leader's channel is broadcast to everyone
        comm = [world.agents[0].state.c]
        if agent.adversary and not agent.leader:
            return np.concatenate(
                [agent.state.p_vel] + [agent.state.p_pos] + entity_pos + other_pos + other_vel + in_forest + comm)
        if agent.leader:
            return np.concatenate(
                [agent.state.p_vel] + [agent.state.p_pos] + entity_pos + other_pos + other_vel + in_forest + comm)
        else:
            return np.concatenate(
                [agent.state.p_vel] + [agent.state.p_pos] + entity_pos + other_pos + in_forest + other_vel)
================================================
FILE: examples/Social_Cognition/ToCM/networks/ToCM/action.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
import os
# if '/home/zhaofeifei/.local/lib/python3.8/site-packages' in sys.path:
# sys.path.remove('/home/zhaofeifei/.local/lib/python3.8/site-packages')
# sys.path.append('/home/zhaofeifei/mambaSNN_Mpe/networks/ToCM/')
# sys.path.append("/home/zhaofeifei/mambaSNN_Mpe/")
from torch.distributions import OneHotCategorical
from networks.transformer.layers import AttentionEncoder, AttentionActorEncoder
from networks.ToCM.utils import build_model_snn, build_model
from braincog.base.node.node import LIFNode, BaseNode, PLIFNode, DoubleSidePLIFNode
from braincog.base.strategy.surrogate import AtanGrad
class BCNoSpikingLIFNode(LIFNode):
    """LIF neuron variant that never fires: it integrates the input current and
    exposes the raw membrane potential as its output (used as a readout layer
    on top of spiking networks)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, dv: torch.Tensor) -> torch.Tensor:
        """Integrate ``dv`` into the membrane state (via ``LIFNode.integral``)
        and return the membrane potential instead of emitting spikes."""
        # print("dv: ", dv)
        # print("dv.shape: ", dv.shape)
        self.integral(dv)
        return self.mem
#SNN
# class Actor(nn.Module):
# def __init__(self, in_dim, out_dim, hidden_size, layers, node='LIFNode', time_window=8,
# norm_in=True, output_style='voltage'): # 1.激活函数需要改成node # voltage
# super().__init__()
# # 1.加入SNN的脉冲参数
# self._threshold = 0.5
# self.v_reset = 0.0
# self.tau = 0.5
# self._time_window = time_window
# # 2.设置输出格式
# self.output_style = output_style
# # 3.ffn是否归一化
# self.norm = norm_in
# self.activation = node
# self.feedforward_model = build_model_snn(in_dim, out_dim, layers, hidden_size, # kkkk TODO!!!
# th=self._threshold, re=self.v_reset, tau=self.tau,
# activation=self.activation, normalize=lambda x: x) # TODO
# if self.output_style == 'ann':
# self.out_node = lambda x: x
# elif self.output_style == 'voltage':
# self.out_node = BCNoSpikingLIFNode(tau=1.0)
#
# def forward(self, state_features):
# # 5.加入脉冲仿真步长
# # print("state.shape", state_features.shape)
# self.reset() # why
# for t in range(self._time_window):
# x = self.feedforward_model(state_features)
# x = self.out_node(x)
# # print("x", x.shape)
# action_dist = OneHotCategorical(logits=x)
# action = action_dist.sample() # 长度为x,一行默认 tensor([0., 1., 0., 0.])
# return action, x
#
# # 调用modules里面node的n_reset
# def reset(self):
# for mod in self.modules():
# if hasattr(mod, 'n_reset'):
# mod.n_reset()
#ANN
class Actor(nn.Module):
    """Plain (non-spiking) policy head: an MLP produces logits and a one-hot
    categorical action is sampled from them."""

    def __init__(self, in_dim, out_dim, hidden_size, layers, activation=nn.ReLU):
        super().__init__()
        # feed-forward stack produced by the shared helper
        self.feedforward_model = build_model(in_dim, out_dim, layers, hidden_size, activation)

    def forward(self, state_features):
        """Return ``(action, logits)`` where action is a one-hot sample."""
        logits = self.feedforward_model(state_features)
        dist = OneHotCategorical(logits=logits)
        return dist.sample(), logits
class AttentionActor(nn.Module):
    """Spiking policy head: embeds state features, encodes them as spikes, runs a
    spiking MLP for ``time_window`` steps and samples a one-hot action.

    output_style:
        'voltage' (default) -- logits are the membrane potential of a
            non-spiking LIF readout at the final step
        'sum' -- logits are the mean readout over the time window
    """

    def __init__(self, in_dim, out_dim, hidden_size, layers, node='LIFNode', time_window=16,
                 norm_in=True, output_style='voltage'):
        super().__init__()
        # spiking hyper-parameters
        self._threshold = 0.5
        self.v_reset = 0.0
        self._time_window = time_window
        # readout style ('voltage' or 'sum')
        self.output_style = output_style
        # whether the input should be normalized (informational only here)
        self.norm = norm_in
        # activation name passed down to the spiking MLP builder
        self.activation = node
        # hint: hidden_size doubles as the in_dim of the downstream networks
        self.feedforward_model = build_model_snn(hidden_size, out_dim, 2, hidden_size,
                                                 th=self._threshold, re=self.v_reset,
                                                 activation=self.activation, normalize=lambda x: x)  # TODO
        self._attention_stack = AttentionActorEncoder(1, hidden_size, hidden_size)
        self.embed = nn.Linear(in_dim, hidden_size)
        self.node1 = LIFNode(threshold=self._threshold, v_reset=self.v_reset)
        # node2 and the attention stack are currently unused in forward(); kept so
        # that existing checkpoints keep loading
        self.node2 = LIFNode(threshold=self._threshold, v_reset=self.v_reset)
        # readout neuron.
        # FIX: previously out_node was only defined when node == 'LIFNode' AND
        # output_style == 'voltage'; any other configuration raised
        # AttributeError inside forward().
        if self.output_style == 'voltage':
            self.out_node = BCNoSpikingLIFNode(tau=2.0)
        else:
            # 'sum' (rate readout) needs no extra neuron
            self.out_node = lambda x: x

    def forward(self, state_features):
        """Return ``(action, logits)`` where action is a one-hot sample."""
        qs = []
        # clear membrane state in all child nodes before the rollout
        self.reset()
        attn_embeds = self.embed(state_features)  # linear projection
        for t in range(self._time_window):
            embeds = self.node1(attn_embeds)  # spike encoding
            x = self.feedforward_model(embeds)
            x = self.out_node(x)
            qs.append(x)
        if self.output_style == "sum":
            # mean activity over the window
            p = sum(qs) / self._time_window
        else:
            # 'voltage' (default): membrane potential at the final step
            p = qs[-1]
        action_dist = OneHotCategorical(logits=p)
        action = action_dist.sample()
        return action, p

    def reset(self):
        # invoke BrainCog's `n_reset` on every stateful node in the module tree
        for mod in self.modules():
            if hasattr(mod, 'n_reset'):
                mod.n_reset()
# aa = AttentionActor(16, 8, 64, 3) # in_dim, out_dim, hidden_size, layers,
# state_feature = torch.randn([8, 8, 2, 16]) # 输入变量维度
# out, x = aa(state_feature)
# print(aa)
# print(out)
# print(out.shape)
================================================
FILE: examples/Social_Cognition/ToCM/networks/ToCM/critic.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
sys.path.append('/home/zhaofeifei/mambaSNN/networks/ToCM/')
from networks.ToCM.utils import build_model_snn, build_model
from networks.transformer.layers import AttentionEncoder
from braincog.base.node.node import LIFNode
from braincog.base.strategy.surrogate import AtanGrad
# Hyper-parameters for the hand-rolled spiking update (`mem_update` below).
decay = 0.3  # membrane potential decay factor per step (used by mem_update)
thresh = 0.3  # firing threshold — NOTE(review): apparently unused in this file; confirm
lens = 0.25  # surrogate-gradient window width — NOTE(review): apparently unused here; confirm
# print("File Critic Here")
# 0.定义一个返回膜电势的 LIFNode
class BCNoSpikingLIFNode(LIFNode):
    """LIF neuron variant that never fires: it integrates the input current and
    exposes the raw membrane potential as its output (used as the critic's
    value readout)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, dv: torch.Tensor) -> torch.Tensor:
        """Integrate ``dv`` into the membrane state (via ``LIFNode.integral``)
        and return the membrane potential instead of emitting spikes."""
        # print("dv: ", dv)
        # print("dv.shape: ", dv.shape)
        self.integral(dv)
        return self.mem
# Surrogate-gradient spike function shared by mem_update.
act_fun = AtanGrad(alpha=2., requires_grad=False)


def mem_update(fc, x, mem, spike):
    """One explicit LIF step: leak the membrane where no spike occurred, add
    the synaptic current ``fc(x)``, then fire through the surrogate-gradient
    threshold at 1.0.

    Returns ``(new_mem, new_spike)``.
    """
    new_mem = fc(x) + mem * decay * (1 - spike)
    new_spike = act_fun(x=new_mem - 1)
    return new_mem, new_spike
class Critic(nn.Module):
    """SNN state-value head.

    An MLP of spiking units is simulated for a fixed number of timesteps and
    read out either as a spike-rate average (``output_style='sum'``) or as
    the final membrane potential (``output_style='voltage'``).
    """

    def __init__(self, in_dim, hidden_size, layers=2, node='LIFNode', time_window=16,
                 norm_in=True, output_style='voltage'):
        super().__init__()
        # Spiking-neuron parameters.
        self._threshold = 0.5
        self.v_reset = 0.0
        self._time_window = time_window
        # Readout mode: 'sum' (rate code) or 'voltage' (membrane readout).
        self.output_style = output_style
        # Input normalisation is currently disabled: the flag is recorded but
        # the normaliser is the identity.
        self.norm = norm_in
        self.in_norm = lambda x: x
        # Activation is a string key understood by build_model_snn.
        self.activation = node
        self.hidden_size = hidden_size
        self.layers = layers
        # MLP with spiking activations; the critic emits a single scalar.
        self.feedforward_model = build_model_snn(in_dim, 1, layers, hidden_size,
                                                 th=self._threshold, re=self.v_reset,
                                                 activation=self.activation, normalize=lambda x: x)
        # Output node: identity for rate readout, non-spiking LIF for voltage
        # readout (a plain Linear output is not meaningful for an SNN).
        if self.output_style == "sum":
            self.out_node = lambda x: x
        elif self.output_style == "voltage":
            self.out_node = BCNoSpikingLIFNode(tau=2.0)

    def forward(self, state_features, actions):
        """Return the value estimate; ``actions`` is accepted but unused."""
        self.reset()  # clear membrane state left over from the previous call
        state_features = self.in_norm(state_features)
        outputs = []
        for _ in range(self._time_window):
            step = self.feedforward_model(state_features)
            step = self.out_node(step)
            outputs.append(step)
        if self.output_style == 'sum':
            return sum(outputs) / self._time_window
        elif self.output_style == 'voltage':
            return outputs[-1]

    def reset(self):
        """Call ``n_reset`` on every sub-module that defines it."""
        for module in filter(lambda m: hasattr(m, 'n_reset'), self.modules()):
            module.n_reset()
#SNN
# class MADDPGCritic(nn.Module):
# def __init__(self, in_dim, hidden_size, node='nn.Tanh', time_window=1, # time_window=16,
# norm_in=True, output_style='ann'): # in_dim 1280 hidden_size 256
# super().__init__()
#
# # 1.加入SNN的脉冲参数
# self._threshold = 0.5
# self.v_reset = 0.0
# self._time_window = time_window
# # 2.设置输出格式
# self.output_style = output_style
# # 3. ffn是否归一化
# self.norm = norm_in
#
# # TODO no normalize
# self.in_norm = lambda x: x
# # 4.改变linear层的激活函数为LIFNode
# self.activation = node # TODO!!!!!!!!!!!
#
# self.feedforward_model = build_model_snn(hidden_size, 1, 1, hidden_size,
# th=self._threshold, re=self.v_reset,
# activation=self.activation, normalize=lambda x: x)
# # in_dim, out_dim, layers, hidden
# # (in_dim = hidden)->hidden->hidden......-> (out_dim = 1)
#
# self._attention_stack = AttentionEncoder(1, hidden_size, hidden_size)
# self.embed = nn.Linear(in_dim, hidden_size) # 1280 256
# self.prior = build_model_snn(in_dim, 1, 3, hidden_size, # 1280 256
# th=self._threshold, re=self._threshold,
# activation=self.activation, normalize=lambda x: x)
# # also in_dim, out_dim, layers, hidden
# # (in_dim = hidden)->hidden->hidden......-> (out_dim = 1)
# # 可能是个决策函数,决策优先选择哪个action
#
# # 5. 定义输出神经元node
# if self.output_style == "sum":
# self.out_node = lambda x: x
# elif self.output_style == "voltage":
# self.out_node = BCNoSpikingLIFNode(tau=2.0)
# elif self.output_style == 'ann':
# self.out_node = lambda x: x
#
# def forward(self, state_features, actions):
# self.reset() # reset函数得看看怎么加
# n_agents = state_features.shape[-2]
# batch_size = state_features.shape[:-2]
# # 6.加入第一次输入的归一化
# state_features = self.in_norm(state_features)
# # 7.暂时不把编码加入模拟时长
# embeds = F.relu(self.embed(state_features))
# embeds = embeds.view(-1, n_agents, embeds.shape[-1])
# attn_embeds = F.relu(self._attention_stack(embeds).view(*batch_size, n_agents, embeds.shape[-1]))
#
# # 7.设置脉冲发放时长模拟,只在ffn层
# qs = []
# for t in range(self._time_window):
# x = self.feedforward_model(attn_embeds)
# x = self.out_node(x)
# qs.append(x)
#
# value = qs[-1] # after 16 mem
# # x = self.feedforward_model(attn_embeds)
# # value = self.out_node(x) # only mem once
# return value
#
# # 调用modules里面node的n_reset
# def reset(self):
# for mod in self.modules():
# if hasattr(mod, 'n_reset'):
# mod.n_reset()
#ANN
class MADDPGCritic(nn.Module):
    """Plain ANN critic: an MLP value head.

    ``actions`` is accepted by ``forward`` for interface compatibility with
    the SNN critic but is not used.
    """

    def __init__(self, in_dim, hidden_size, layers=2, activation=nn.ELU):
        super().__init__()
        self.hidden_size = hidden_size
        self.layers = layers
        self.activation = activation
        # in_dim -> hidden (x layers) -> scalar value
        self.feedforward_model = build_model(in_dim, 1, layers, hidden_size, activation)

    def forward(self, state_features, actions):
        """Return the value estimate for ``state_features``."""
        return self.feedforward_model(state_features)
# critic_net = Critic(in_dim=2, hidden_size=32, layers=2, node='LIFNode', time_window=16, norm_in=True,
# output_style='voltage')
# print(critic_net)
# maddpg_critic_net = MADDPGCritic(in_dim=2, hidden_size=32, node='LIFNode', time_window=16, norm_in=True,
# output_style='voltage')
# print(maddpg_critic_net)
================================================
FILE: examples/Social_Cognition/ToCM/networks/ToCM/dense.py
================================================
import torch
import torch.distributions as td
import torch.nn as nn
from networks.ToCM.utils import build_model_snn
class DenseModel(nn.Module):
    """Thin wrapper around an SNN MLP built by ``build_model_snn``."""

    def __init__(self, in_dim, out_dim, layers, hidden, activation="nn.ELU"):
        super().__init__()
        # ``activation`` is forwarded to build_model_snn as a string key.
        self.model = build_model_snn(in_dim, out_dim, layers, hidden, activation=activation)

    def forward(self, features):
        """Apply the wrapped network to ``features``."""
        return self.model(features)
class DenseBinaryModel(DenseModel):
    """Dense head whose outputs parameterise independent Bernoulli variables."""

    def __init__(self, in_dim, out_dim, layers, hidden, activation="nn.ELU"):
        super().__init__(in_dim, out_dim, layers, hidden, activation=activation)

    def forward(self, features):
        """Return an ``Independent(Bernoulli)`` distribution whose event is the
        last dimension of the network output (treated as per-entry logits)."""
        dist_inputs = self.model(features)
        return td.independent.Independent(td.Bernoulli(logits=dist_inputs), 1)
================================================
FILE: examples/Social_Cognition/ToCM/networks/ToCM/rnns.py
================================================
import torch
import torch.nn as nn
from torch.distributions import OneHotCategorical
from configs.ToCM.ToCMAgentConfig import RSSMState
from networks.transformer.layers import AttentionEncoder
def stack_states(rssm_states: list, dim):
    """Stack a list of RSSMStates field-wise with ``torch.stack`` along ``dim``."""
    return reduce_states(rssm_states, dim, torch.stack)
def cat_states(rssm_states: list, dim):
    """Concatenate a list of RSSMStates field-wise with ``torch.cat`` along ``dim``."""
    return reduce_states(rssm_states, dim, torch.cat)
def reduce_states(rssm_states: list, dim, func):
    """Merge a list of RSSMStates into one by applying ``func`` (e.g.
    ``torch.stack``/``torch.cat``) to each field across the list, along ``dim``.

    Field order follows the first state's ``__dict__`` keys, matching the
    RSSMState positional constructor.
    """
    field_names = rssm_states[0].__dict__.keys()
    merged = [func([getattr(state, name) for state in rssm_states], dim=dim)
              for name in field_names]
    return RSSMState(*merged)
class DiscreteLatentDist(nn.Module):
    """Discrete latent head: maps features to ``n_categoricals`` categorical
    variables with ``n_classes`` classes each, sampled with a straight-through
    gradient estimator."""

    def __init__(self, in_dim, n_categoricals, n_classes, hidden_size):
        super().__init__()
        self.n_categoricals = n_categoricals
        self.n_classes = n_classes
        # Two-layer MLP producing one logit per (categorical, class) pair.
        self.dists = nn.Sequential(nn.Linear(in_dim, hidden_size),
                                   nn.ReLU(),
                                   nn.Linear(hidden_size, n_classes * n_categoricals))

    def forward(self, x):
        """Return ``(logits, latents)``, both flattened over the last two dims."""
        lead_shape = x.shape[:-1]
        logits = self.dists(x).view(lead_shape + (self.n_categoricals, self.n_classes))
        class_dist = OneHotCategorical(logits=logits)
        sample = class_dist.sample()
        # Straight-through estimator: the forward value is the hard one-hot
        # sample; the backward pass flows through class_dist.probs.
        latents = sample + class_dist.probs - class_dist.probs.detach()
        return logits.view(lead_shape + (-1,)), latents.view(lead_shape + (-1,))
class RSSMTransition(nn.Module):
    """Prior transition model of the RSSM: predicts the next latent state from
    the previous action and state only (no observation)."""

    def __init__(self, config, hidden_size=200, activation=nn.ReLU):
        super().__init__()
        self._stoch_size = config.STOCHASTIC       # size of the stochastic latent
        self._deter_size = config.DETERMINISTIC    # size of the deterministic (GRU) state
        self._hidden_size = hidden_size
        self._activation = activation
        # GRU advances the deterministic state; attention mixes the per-agent
        # inputs before the recurrent update.
        self._cell = nn.GRU(hidden_size, self._deter_size)
        self._attention_stack = AttentionEncoder(3, hidden_size, hidden_size, dropout=0.1)
        self._rnn_input_model = self._build_rnn_input_model(config.ACTION_SIZE + self._stoch_size)
        self._stochastic_prior_model = DiscreteLatentDist(self._deter_size, config.N_CATEGORICALS, config.N_CLASSES,
                                                          self._hidden_size)

    def _build_rnn_input_model(self, in_dim):
        """Linear + activation projecting [action, stoch] to the GRU input size."""
        rnn_input_model = [nn.Linear(in_dim, self._hidden_size)]
        rnn_input_model += [self._activation()]
        return nn.Sequential(*rnn_input_model)

    def forward(self, prev_actions, prev_states, mask=None):
        """One prior step.

        :param prev_actions: size(batch, n_agents, action_size)
        :param prev_states: RSSMState with ``stoch``/``deter`` fields
        :param mask: optional attention mask forwarded to the encoder
        :return: prior RSSMState (logits, stoch, deter)
        """
        batch_size = prev_actions.shape[0]
        n_agents = prev_actions.shape[1]
        stoch_input = self._rnn_input_model(torch.cat([prev_actions, prev_states.stoch], dim=-1))
        attn = self._attention_stack(stoch_input, mask=mask)
        # The GRU expects (seq_len=1, batch*n_agents, features); reshape back
        # to (batch, n_agents, deter) afterwards.
        deter_state = self._cell(attn.reshape(1, batch_size * n_agents, -1),
                                 prev_states.deter.reshape(1, batch_size * n_agents, -1))[0].reshape(batch_size, n_agents, -1)
        logits, stoch_state = self._stochastic_prior_model(deter_state)
        return RSSMState(logits=logits, stoch=stoch_state, deter=deter_state)
class RSSMRepresentation(nn.Module):
    """Posterior model of the RSSM: refines the prior with the observation
    embedding to produce the filtered (posterior) latent state."""

    def __init__(self, config, transition_model: RSSMTransition):
        super().__init__()
        self._transition_model = transition_model
        self._stoch_size = config.STOCHASTIC
        self._deter_size = config.DETERMINISTIC
        # The posterior conditions on [deterministic state, observation embedding].
        self._stochastic_posterior_model = DiscreteLatentDist(self._deter_size + config.EMBED, config.N_CATEGORICALS,
                                                              config.N_CLASSES, config.HIDDEN)

    def initial_state(self, batch_size, n_agents, **kwargs):
        """All-zero RSSMState; ``**kwargs`` forwards dtype/device to torch.zeros."""
        return RSSMState(stoch=torch.zeros(batch_size, n_agents, self._stoch_size, **kwargs),
                         logits=torch.zeros(batch_size, n_agents, self._stoch_size, **kwargs),
                         deter=torch.zeros(batch_size, n_agents, self._deter_size, **kwargs))

    def forward(self, obs_embed, prev_actions, prev_states, mask=None):
        """
        :param obs_embed: size(batch, n_agents, obs_size)
        :param prev_actions: size(batch, n_agents, action_size)
        :param prev_states: size(batch, n_agents, state_size)
        :return: (prior RSSMState, posterior RSSMState); both share the
                 prior's deterministic state.
        """
        prior_states = self._transition_model(prev_actions, prev_states, mask)
        x = torch.cat([prior_states.deter, obs_embed], dim=-1)
        logits, stoch_state = self._stochastic_posterior_model(x)
        posterior_states = RSSMState(logits=logits, stoch=stoch_state, deter=prior_states.deter)
        return prior_states, posterior_states
def rollout_representation(representation_model, steps, obs_embed, action, prev_states, done):
    """
    Roll out the model with actions and observations from data.
    :param representation_model: maps (obs_embed, action, prev_state) -> (prior, posterior)
    :param steps: number of steps to roll out
    :param obs_embed: size(time_steps, batch_size, n_agents, embedding_size)
    :param action: size(time_steps, batch_size, n_agents, action_size)
    :param prev_states: RSSM state, size(batch_size, n_agents, state_size)
    :param done: per-step terminal flags; a finished episode zeroes the carried state
    :return: prior, posterior states. size(time_steps, batch_size, n_agents, state_size)
    """
    priors = []
    posteriors = []
    for t in range(steps):
        prior_states, posterior_states = representation_model(obs_embed[t], action[t], prev_states)
        # Zero the carried state wherever the episode terminated at step t.
        prev_states = posterior_states.map(lambda x: x * (1.0 - done[t]))
        priors.append(prior_states)
        posteriors.append(posterior_states)
    prior = stack_states(priors, dim=0)
    post = stack_states(posteriors, dim=0)
    # Drop the final step from both stacked states; additionally return the
    # posterior deterministic features shifted forward by one step.
    return prior.map(lambda x: x[:-1]), post.map(lambda x: x[:-1]), post.deter[1:]
def rollout_policy(transition_model, av_action, steps, policy, prev_state, prev_action, config):
    """
    Roll out the world model with a policy function (imagination rollout).
    :param transition_model: world model providing ``.transition`` and ``.observation_decoder``
    :param av_action: optional available-action model (None when all actions are legal)
    :param steps: number of steps to roll out
    :param policy: features -> (action, policy logits)
    :param prev_state: RSSM state, size(batch_size, state_size)
    :param prev_action: previous actions; the last time step is dropped and the
        time/batch dims are flattened below
    :param config: experiment config; ``num_agents`` selects the observation slice below
    :return: dict of imagined states, predicted observations, actions,
        available actions (or None) and the policy logits, each stacked along time
    """
    state = prev_state
    action = prev_action[:-1].reshape((prev_action.shape[0] - 1) * prev_action.shape[1], prev_action.shape[2], -1)  # TODO
    next_states = []
    actions = []
    av_actions = []
    policies = []
    obs_preds = []
    for t in range(steps):
        feat = state.get_features().detach()
        obs_pred, _ = transition_model.observation_decoder(feat)  # TODO
        next_state = transition_model.transition(action, state)  # TODO
        next_feat = next_state.get_features().detach()  # TODO
        observations_next_other, _ = transition_model.observation_decoder(next_feat)  # TODO
        # The policy sees its own predicted observation plus a slice of the
        # other agents' predicted next observations.
        # NOTE(review): the slice [-(num_agents-1)*4 : -(num_agents-1)*2]
        # assumes a fixed per-agent observation layout -- confirm against the
        # decoder's output ordering.
        action, pi = policy(torch.cat((obs_pred.detach(), observations_next_other[:, :, -(config.num_agents-1)*4:-(config.num_agents-1)*2]), -1))  # TODO 3 vs 3
        if av_action is not None:
            # Mask out unavailable actions before re-sampling the action.
            avail_actions = av_action(feat).sample()
            pi[avail_actions == 0] = -1e10
            action_dist = OneHotCategorical(logits=pi)
            action = action_dist.sample().squeeze(0)
            av_actions.append(avail_actions.squeeze(0))
        next_states.append(state)
        obs_preds.append(obs_pred)
        policies.append(pi)
        actions.append(action)
        state = transition_model.transition(action, state)
    return {"imag_states": stack_states(next_states, dim=0),
            "obs_preds": torch.stack(obs_preds, dim=0),
            "actions": torch.stack(actions, dim=0),
            "av_actions": torch.stack(av_actions, dim=0) if len(av_actions) > 0 else None,
            "old_policy": torch.stack(policies, dim=0)}
================================================
FILE: examples/Social_Cognition/ToCM/networks/ToCM/utils.py
================================================
import torch.nn as nn
from braincog.base.node import LIFNode
from braincog.base.node.node import LIFNode, DoubleSidePLIFNode, PLIFNode
from braincog.base.strategy.surrogate import AtanGrad
import torch
class AtanLIFNode(LIFNode):
    """LIF neuron whose surrogate gradient is an arctan with learnable alpha.

    NOTE(review): ``tau`` is forwarded positionally to ``LIFNode.__init__``;
    confirm the base class's first positional parameter is indeed ``tau`` and
    not ``threshold``.
    """
    def __init__(self, tau=0.5, *args, **kwargs):
        super().__init__(tau, *args, **kwargs)
        # Override the base node's surrogate-gradient function.
        self.act_fun = AtanGrad(alpha=1., requires_grad=True)
class BCNoSpikingLIFNode(LIFNode):
    """LIF neuron that returns its membrane potential instead of spiking.

    Used as a continuous ("voltage") readout layer on top of spiking MLPs.
    """
    def __init__(self, tau, *args, **kwargs):
        # ``tau`` is deliberately NOT forwarded to the base constructor; it is
        # stored after base init, overriding whatever value the base set.
        super().__init__(*args, **kwargs)
        self.tau = tau

    def forward(self, dv: torch.Tensor):
        """Integrate input current ``dv`` and return the membrane potential.

        No threshold/spike step -- the output is continuous.
        """
        # LIFNode.integral accumulates dv into self.mem (BrainCog base class).
        self.integral(dv)
        return self.mem
def build_model_snn(in_dim, out_dim, layers, hidden, th=0.5, re=0.0, tau=0.5, activation='LIFNode',
                    normalize=lambda x: x):
    """Build an MLP whose hidden activations are (spiking) neuron nodes.

    Architecture: in_dim -> hidden (x ``layers``) -> out_dim, with one
    activation node after every hidden Linear. The final Linear has no node:
    callers attach their own readout (see the ``out_node`` of the critics).

    :param in_dim: input feature size
    :param out_dim: output feature size
    :param layers: number of hidden Linear+node pairs
    :param hidden: hidden layer width
    :param th: firing threshold for threshold-based nodes
    :param re: reset potential (kept for interface compatibility; not used by
        any of the node constructors below)
    :param tau: membrane time constant for LIF-style nodes
    :param activation: string key selecting the node/activation type
    :param normalize: wrapper applied to each Linear (identity by default)
    :raises ValueError: if ``activation`` is not a recognised key (the
        original code left ``node`` unbound and crashed with a NameError)
    :return: nn.Sequential model

    BUG FIX: the original appended the SAME node instance after every Linear
    layer, so stateful spiking nodes (membrane potential) were shared across
    all layers. Each layer now gets a freshly constructed node.
    """
    def make_node():
        # One fresh activation/node instance per call so no state is shared
        # between layers.
        if activation == 'LIFNode':
            return LIFNode(threshold=th, tau=tau)
        elif activation == 'AtanLIFNode':
            return AtanLIFNode(tau=tau)
        elif activation == 'BCNoSpikingLIFNode':
            return BCNoSpikingLIFNode(tau=tau)
        elif activation == 'DoubleSidePLIFNode':
            return DoubleSidePLIFNode(tau=tau)
        elif activation == 'PLIFNode':
            return PLIFNode(threshold=th)
        elif activation == 'nn.ELU':
            return nn.ELU()
        elif activation == 'nn.ReLU':
            return nn.ReLU()
        elif activation == 'nn.Tanh':
            return nn.Tanh()
        raise ValueError(f"build_model_snn: unknown activation {activation!r}")

    model = [normalize(nn.Linear(in_dim, hidden)), make_node()]
    for _ in range(layers - 1):
        model += [normalize(nn.Linear(hidden, hidden)), make_node()]
    model += [normalize(nn.Linear(hidden, out_dim))]
    return nn.Sequential(*model)
def build_model(in_dim, out_dim, layers, hidden, activation, normalize=lambda x: x):
    """Build a plain ANN MLP: in_dim -> hidden (x ``layers``) -> out_dim.

    :param activation: a module class (e.g. ``nn.ELU``), instantiated once per
        hidden layer
    :param normalize: wrapper applied to each Linear (identity by default)
    :return: nn.Sequential model (no activation after the final Linear)
    """
    stack = [normalize(nn.Linear(in_dim, hidden)), activation()]
    for _ in range(layers - 1):
        stack.append(normalize(nn.Linear(hidden, hidden)))
        stack.append(activation())
    stack.append(normalize(nn.Linear(hidden, out_dim)))
    return nn.Sequential(*stack)
================================================
FILE: examples/Social_Cognition/ToCM/networks/ToCM/vae.py
================================================
import torch.nn as nn
import torch.nn.functional as F
from networks.ToCM.utils import build_model_snn
class Decoder(nn.Module):
    """Two-stage decoder: an SNN MLP followed by a linear projection.

    ``forward`` returns both the projected output and the intermediate
    (post-ReLU) hidden activation.
    """

    def __init__(self, embed, hidden, out_dim, layers=2):
        super().__init__()
        self.fc1 = build_model_snn(embed, hidden, layers, hidden, activation='nn.ReLU')
        self.fc2 = nn.Linear(hidden, out_dim)

    def forward(self, z):
        hidden_out = F.relu(self.fc1(z))
        return self.fc2(hidden_out), hidden_out
class Encoder(nn.Module):
    """Linear projection followed by an SNN MLP encoder."""

    def __init__(self, in_dim, hidden, embed, layers=2):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden)
        self.encoder = build_model_snn(hidden, embed, layers, hidden, activation='nn.ReLU')

    def forward(self, x):
        projected = F.relu(self.fc1(x))
        # The second relu is mathematically redundant (relu is idempotent) but
        # is kept so the computation matches the original exactly.
        return self.encoder(F.relu(projected))
================================================
FILE: examples/Social_Cognition/ToCM/networks/transformer/layers.py
================================================
import numpy as np
import torch
import torch.nn as nn
# 位置编码
class PositionalEncoding(nn.Module):
    """Fixed (non-learned) sinusoidal positional encoding."""
    __author__ = "Yu-Hsiang Huang"

    def __init__(self, d_hid, n_position=2):
        super(PositionalEncoding, self).__init__()
        # The table is registered as a buffer: it is part of the module's
        # state (moves with .to()/.cuda(), saved in state_dict) but is not a
        # trainable parameter -- like BatchNorm's running_mean.
        self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))

    @staticmethod
    def _get_sinusoid_encoding_table(n_position, d_hid):
        """Sinusoid table of shape (1, n_position, d_hid): sin on even dims,
        cos on odd dims, angle(pos, j) = pos / 10000^(2*(j//2)/d_hid)."""
        positions = np.arange(n_position)[:, None]
        dims = np.arange(d_hid)[None, :]
        table = positions / np.power(10000, 2 * (dims // 2) / d_hid)
        table[:, 0::2] = np.sin(table[:, 0::2])  # dim 2i
        table[:, 1::2] = np.cos(table[:, 1::2])  # dim 2i+1
        return torch.FloatTensor(table).unsqueeze(0)

    def forward(self, x):
        """Add the (detached) positional table to the first x.size(1) positions."""
        return x + self.pos_table[:, :x.size(1)].clone().detach()
class AttentionEncoder(nn.Module):
    """Positional encoding followed by stacked TransformerEncoder layers
    (8 attention heads).

    Input/output layout is (batch, seq, features); the tensor is permuted to
    the (seq, batch, features) layout that nn.TransformerEncoder expects and
    permuted back afterwards.
    """

    def __init__(self, n_layers, in_dim, hidden, dropout=0.):
        super().__init__()
        # NOTE(review): the positional table is built with width ``hidden``;
        # the addition in forward only broadcasts when hidden == in_dim.
        self.pos_embed = PositionalEncoding(hidden, 30)
        layer = nn.TransformerEncoderLayer(d_model=in_dim, nhead=8,
                                           dim_feedforward=hidden,
                                           dropout=dropout)
        self.encoder = nn.TransformerEncoder(layer, n_layers)

    def forward(self, enc_input, **kwargs):
        with_pos = self.pos_embed(enc_input)
        encoded = self.encoder(with_pos.permute(1, 0, 2), **kwargs)
        return encoded.permute(1, 0, 2)
class AttentionActorEncoder(nn.Module):
    """Stacked TransformerEncoder layers (8 heads) with no positional encoding;
    the input is passed straight through in whatever layout the caller uses."""

    def __init__(self, n_layers, in_dim, hidden, dropout=0.):
        super().__init__()
        layer = nn.TransformerEncoderLayer(d_model=in_dim, nhead=8,
                                           dim_feedforward=hidden,
                                           dropout=dropout)
        self.encoder = nn.TransformerEncoder(layer, n_layers)

    def forward(self, enc_input, **kwargs):
        return self.encoder(enc_input, **kwargs)
================================================
FILE: examples/Social_Cognition/ToCM/requirements.txt
================================================
numpy~=1.18.5
torch~=1.7.0
ray~=1.13.0
git+https://github.com/oxwhirl/smac.git
wandb~=0.13.11
argparse~=1.4.0
================================================
FILE: examples/Social_Cognition/ToCM/run.sh
================================================
#!/bin/sh
# Launch ToCM training twice with a fixed seed (50), killing the leftover
# main thread between runs.
seed_max=10  # upper bound for the (currently disabled) seed sweep below
# Disabled: sweep seeds 1..seed_max, one training run per seed.
#for seed in `seq ${seed_max}`;
#do
# echo "seed is ${seed}:"
# python train.py
# kill Main_Thread
#done
#seed_max=10
#for seed in $(seq 1 $seed_max); do
# echo "seed is $seed:"
# python train.py --seed $seed  # pass the current seed to train.py
#done
python train.py --seed 50
pkill Main_Thread
python train.py --seed 50
pkill Main_Thread
#python train.py --seed 1
#pkill Main_Thread
================================================
FILE: examples/Social_Cognition/ToCM/smac/__init__.py
================================================
================================================
FILE: examples/Social_Cognition/ToCM/smac/bin/__init__.py
================================================
================================================
FILE: examples/Social_Cognition/ToCM/smac/bin/map_list.py
================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from smac.env.starcraft2.maps import smac_maps
from pysc2 import maps as pysc2_maps
def main():
    """Print name, agent/enemy counts and episode limit for every SMAC map
    that has a corresponding installed pysc2 map file."""
    smac_map_registry = smac_maps.get_smac_map_registry()
    all_maps = pysc2_maps.get_maps()
    print("{:<15} {:7} {:7} {:7}".format("Name", "Agents", "Enemies", "Limit"))
    for map_name, map_params in smac_map_registry.items():
        map_class = all_maps[map_name]
        if not map_class.path:
            # Skip maps without an installed .SC2Map file.
            continue
        print(
            "{:<15} {:<7} {:<7} {:<7}".format(
                map_name,
                map_params["n_agents"],
                map_params["n_enemies"],
                map_params["limit"],
            )
        )


if __name__ == "__main__":
    main()
================================================
FILE: examples/Social_Cognition/ToCM/smac/env/__init__.py
================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from smac.env.multiagentenv import MultiAgentEnv
from smac.env.starcraft2.starcraft2 import StarCraft2Env
__all__ = ["MultiAgentEnv", "StarCraft2Env"]
================================================
FILE: examples/Social_Cognition/ToCM/smac/env/multiagentenv.py
================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
class MultiAgentEnv(object):
    """Abstract interface for cooperative multi-agent environments.

    Concrete environments (e.g. StarCraft2Env) implement every method below;
    this base class only raises NotImplementedError.
    """

    def step(self, actions):
        """Returns reward, terminated, info."""
        raise NotImplementedError

    def get_obs(self):
        """Returns all agent observations in a list."""
        raise NotImplementedError

    def get_obs_agent(self, agent_id):
        """Returns observation for agent_id."""
        raise NotImplementedError

    def get_obs_size(self):
        """Returns the size of the observation."""
        raise NotImplementedError

    def get_state(self):
        """Returns the global state."""
        raise NotImplementedError

    def get_state_size(self):
        """Returns the size of the global state."""
        raise NotImplementedError

    def get_avail_actions(self):
        """Returns the available actions of all agents in a list."""
        raise NotImplementedError

    def get_avail_agent_actions(self, agent_id):
        """Returns the available actions for agent_id."""
        raise NotImplementedError

    def get_total_actions(self):
        """Returns the total number of actions an agent could ever take."""
        raise NotImplementedError

    def reset(self):
        """Returns initial observations and states."""
        raise NotImplementedError

    def render(self):
        """Render the environment (optional)."""
        raise NotImplementedError

    def close(self):
        """Release any resources held by the environment."""
        raise NotImplementedError

    def seed(self):
        """Seed the environment's random number generator."""
        raise NotImplementedError

    def save_replay(self):
        """Save a replay."""
        raise NotImplementedError

    def get_env_info(self):
        """Summarise environment dimensions for agent construction.

        NOTE(review): relies on ``self.n_agents`` and ``self.episode_limit``
        being set by the concrete subclass.
        """
        env_info = {
            "state_shape": self.get_state_size(),
            "obs_shape": self.get_obs_size(),
            "n_actions": self.get_total_actions(),
            "n_agents": self.n_agents,
            "episode_limit": self.episode_limit,
        }
        return env_info
================================================
FILE: examples/Social_Cognition/ToCM/smac/env/pettingzoo/StarCraft2PZEnv.py
================================================
from smac.env import StarCraft2Env
from gym.utils import EzPickle
from gym.utils import seeding
from gym import spaces
from pettingzoo.utils.env import ParallelEnv
from pettingzoo.utils.conversions import parallel_to_aec as from_parallel_wrapper
from pettingzoo.utils import wrappers
import numpy as np
max_cycles_default = 1000
def parallel_env(max_cycles=max_cycles_default, **smac_args):
    """Create a PettingZoo parallel SMAC env; ``smac_args`` go to StarCraft2Env."""
    return _parallel_env(max_cycles, **smac_args)
def raw_env(max_cycles=max_cycles_default, **smac_args):
    """Create an AEC-style env by wrapping the parallel env."""
    return from_parallel_wrapper(parallel_env(max_cycles, **smac_args))
def make_env(raw_env):
    """Wrap a raw-env factory with the standard PettingZoo safety wrappers."""
    def env_fn(**kwargs):
        wrapped = raw_env(**kwargs)
        # env = wrappers.TerminateIllegalWrapper(env, illegal_reward=-1)
        wrapped = wrappers.AssertOutOfBoundsWrapper(wrapped)
        wrapped = wrappers.OrderEnforcingWrapper(wrapped)
        return wrapped
    return env_fn
class smac_parallel_env(ParallelEnv):
    """PettingZoo ParallelEnv adapter around a SMAC StarCraft2Env.

    Exposes per-agent observation dicts (observation + action_mask) and maps
    PettingZoo agent names back to SMAC agent ids. The SMAC no-op action
    (index 0) is hidden from the exposed action space and re-inserted in
    ``step``.
    """

    def __init__(self, env, max_cycles):
        self.max_cycles = max_cycles  # hard episode-length cap in env steps
        self.env = env
        self.env.reset()
        self.reset_flag = 0
        self.agents, self.action_spaces = self._init_agents()
        self.possible_agents = self.agents[:]
        observation_size = env.get_obs_size()
        # One Dict space per agent: the raw observation plus a binary mask of
        # currently-available actions (no-op removed, hence action_spaces[name].n).
        self.observation_spaces = {
            name: spaces.Dict(
                {
                    "observation": spaces.Box(
                        low=-1,
                        high=1,
                        shape=(observation_size,),
                        dtype="float32",
                    ),
                    "action_mask": spaces.Box(
                        low=0,
                        high=1,
                        shape=(self.action_spaces[name].n,),
                        dtype=np.int8,
                    ),
                }
            )
            for name in self.agents
        }
        self._reward = 0

    def _init_agents(self):
        """Name each SMAC unit "<type>_<i>" and build its action space.

        Returns ``(agent_names, action_spaces)`` and fills ``self.agents_id``
        mapping agent name -> SMAC agent id.
        """
        last_type = ""
        agents = []
        action_spaces = {}
        self.agents_id = {}
        i = 0
        for agent_id, agent_info in self.env.agents.items():
            unit_action_space = spaces.Discrete(
                self.env.get_total_actions() - 1
            )  # no-op in dead units is not an action
            if agent_info.unit_type == self.env.marine_id:
                agent_type = "marine"
            elif agent_info.unit_type == self.env.marauder_id:
                agent_type = "marauder"
            elif agent_info.unit_type == self.env.medivac_id:
                agent_type = "medivac"
            elif agent_info.unit_type == self.env.hydralisk_id:
                agent_type = "hydralisk"
            elif agent_info.unit_type == self.env.zergling_id:
                agent_type = "zergling"
            elif agent_info.unit_type == self.env.baneling_id:
                agent_type = "baneling"
            elif agent_info.unit_type == self.env.stalker_id:
                agent_type = "stalker"
            elif agent_info.unit_type == self.env.colossus_id:
                agent_type = "colossus"
            elif agent_info.unit_type == self.env.zealot_id:
                agent_type = "zealot"
            else:
                # BUG FIX: the original message referenced ``agent_type``,
                # which is unbound on this branch (NameError); report the raw
                # unit type instead.
                raise AssertionError(
                    f"agent unit type {agent_info.unit_type} not supported"
                )
            # Consecutive agents of the same type are numbered type_0, type_1, ...
            if agent_type == last_type:
                i += 1
            else:
                i = 0
            agents.append(f"{agent_type}_{i}")
            self.agents_id[agents[-1]] = agent_id
            action_spaces[agents[-1]] = unit_action_space
            last_type = agent_type
        return agents, action_spaces

    def seed(self, seed=None):
        """Set the SMAC RNG seed (random when None) and fully restart the game."""
        if seed is None:
            self.env._seed = seeding.create_seed(seed, max_bytes=4)
        else:
            self.env._seed = seed
        self.env.full_restart()

    def render(self, mode="human"):
        self.env.render(mode)

    def close(self):
        self.env.close()

    def reset(self):
        """Start a new episode and return the initial per-agent observations."""
        self.env._episode_count = 1
        self.env.reset()
        self.agents = self.possible_agents[:]
        self.frames = 0
        self.all_dones = {agent: False for agent in self.possible_agents}
        return self._observe_all()

    def get_agent_smac_id(self, agent):
        """Translate a PettingZoo agent name to its SMAC agent id."""
        return self.agents_id[agent]

    def _all_rewards(self, reward):
        """Broadcast the shared team reward to every live agent."""
        all_rewards = [reward] * len(self.agents)
        return {
            agent: reward for agent, reward in zip(self.agents, all_rewards)
        }

    def _observe_all(self):
        """Collect {observation, action_mask} for every live agent.

        The leading no-op entry of the SMAC availability mask is dropped to
        match the reduced Discrete action space.
        """
        all_obs = []
        for agent in self.agents:
            agent_id = self.get_agent_smac_id(agent)
            obs = self.env.get_obs_agent(agent_id)
            action_mask = self.env.get_avail_agent_actions(agent_id)
            action_mask = action_mask[1:]
            action_mask = np.array(action_mask).astype(np.int8)
            obs = np.asarray(obs, dtype=np.float32)
            all_obs.append(
                {"observation": obs, "action_mask": action_mask}
            )
        return {agent: obs for agent, obs in zip(self.agents, all_obs)}

    def _all_dones(self, step_done=False):
        """Per-agent done flags: everyone when the episode ended, otherwise
        only agents whose unit health reached zero."""
        dones = [True] * len(self.agents)
        if not step_done:
            for i, agent in enumerate(self.agents):
                agent_done = False
                agent_id = self.get_agent_smac_id(agent)
                agent_info = self.env.get_unit_by_id(agent_id)
                if agent_info.health == 0:
                    agent_done = True
                dones[i] = agent_done
        return {agent: bool(done) for agent, done in zip(self.agents, dones)}

    def step(self, all_actions):
        """Advance SMAC one step with the given per-agent actions.

        Missing/None actions become the no-op (0); all others are shifted by
        +1 to restore the no-op slot removed from the exposed action space.
        """
        action_list = [0] * self.env.n_agents
        for agent in self.agents:
            agent_id = self.get_agent_smac_id(agent)
            if agent in all_actions:
                if all_actions[agent] is None:
                    action_list[agent_id] = 0
                else:
                    action_list[agent_id] = all_actions[agent] + 1
        self._reward, terminated, smac_info = self.env.step(action_list)
        self.frames += 1
        done = terminated or self.frames >= self.max_cycles
        all_infos = {agent: {} for agent in self.agents}
        # all_infos.update(smac_info)
        all_dones = self._all_dones(done)
        all_rewards = self._all_rewards(self._reward)
        all_observes = self._observe_all()
        # Drop agents that finished this step.
        self.agents = [
            agent for agent in self.agents if not all_dones[agent]
        ]
        return all_observes, all_rewards, all_dones, all_infos

    def __del__(self):
        self.env.close()
env = make_env(raw_env)
class _parallel_env(smac_parallel_env, EzPickle):
    """Pickle-able parallel env: EzPickle records the constructor arguments so
    the underlying StarCraft2Env can be rebuilt on unpickle."""
    metadata = {"render.modes": ["human"], "name": "sc2"}

    def __init__(self, max_cycles, **smac_args):
        EzPickle.__init__(self, max_cycles, **smac_args)
        env = StarCraft2Env(**smac_args)
        super().__init__(env, max_cycles)
================================================
FILE: examples/Social_Cognition/ToCM/smac/env/pettingzoo/__init__.py
================================================
================================================
FILE: examples/Social_Cognition/ToCM/smac/env/pettingzoo/test/__init__.py
================================================
================================================
FILE: examples/Social_Cognition/ToCM/smac/env/pettingzoo/test/all_test.py
================================================
from smac.env.starcraft2.maps import smac_maps
from pysc2 import maps as pysc2_maps
from smac.env.pettingzoo import StarCraft2PZEnv as sc2
import pytest
from pettingzoo import test
import pickle
# Collect every SMAC map that has an installed pysc2 map file; these become
# the parametrized cases for test_env below.
smac_map_registry = smac_maps.get_smac_map_registry()
all_maps = pysc2_maps.get_maps()
map_names = []
for map_name in smac_map_registry.keys():
    map_class = all_maps[map_name]
    if map_class.path:
        map_names.append(map_name)
@pytest.mark.parametrize(("map_name"), map_names)
def test_env(map_name):
    """Run the PettingZoo API and render tests on each map, then re-check the
    API on a pickle round-tripped copy of the environment."""
    env = sc2.env(map_name=map_name)
    test.api_test(env)
    # test.parallel_api_test(sc2_v0.parallel_env()) # does not pass it due to
    # illegal actions test.seed_test(sc2.env, 50) # not required, sc2 env only
    # allows reseeding at initialization
    test.render_test(env)
    recreated_env = pickle.loads(pickle.dumps(env))
    test.api_test(recreated_env)
================================================
FILE: examples/Social_Cognition/ToCM/smac/env/pettingzoo/test/smac_pettingzoo_test.py
================================================
import os
import sys
import inspect
from pettingzoo import test
from smac.env.pettingzoo import StarCraft2PZEnv as sc2
import pickle
# Make the package parent directory importable when this file runs as a script.
current_dir = os.path.dirname(
    os.path.abspath(inspect.getfile(inspect.currentframe()))
)
parent_dir = os.path.dirname(current_dir)
sys.path.insert(0, parent_dir)

if __name__ == "__main__":
    # Smoke-test the PettingZoo API on the "corridor" map, then re-check a
    # pickle round-tripped copy of the environment.
    env = sc2.env(map_name="corridor")
    test.api_test(env)
    # test.parallel_api_test(sc2_v0.parallel_env()) # does not pass it due to
    # illegal actions test.seed_test(sc2_v0.env, 50) # not required, sc2 env
    # only allows reseeding at initialization
    recreated_env = pickle.loads(pickle.dumps(env))
    test.api_test(recreated_env)
================================================
FILE: examples/Social_Cognition/ToCM/smac/env/starcraft2/__init__.py
================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import flags

FLAGS = flags.FLAGS
# Parse an empty argv so absl flags are usable even when this package is
# imported outside an absl app entry point.
FLAGS(["main.py"])
================================================
FILE: examples/Social_Cognition/ToCM/smac/env/starcraft2/maps/__init__.py
================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from smac.env.starcraft2.maps import smac_maps
def get_map_params(map_name):
    """Return the SMAC parameter dict registered for ``map_name``.

    Raises KeyError for unknown map names, same as a direct registry lookup.
    """
    return smac_maps.get_smac_map_registry()[map_name]
================================================
FILE: examples/Social_Cognition/ToCM/smac/env/starcraft2/maps/smac_maps.py
================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from pysc2.maps import lib
class SMACMap(lib.Map):
    """Base pysc2 map definition shared by all SMAC scenarios.

    Concrete map classes (one per registry entry) are generated dynamically
    at import time from ``map_param_registry`` below.
    """

    # Subdirectory of the SC2 maps folder holding the SMAC .SC2Map files.
    directory = "SMAC_Maps"
    # Where to obtain the map files if they are missing locally.
    download = "https://github.com/oxwhirl/smac#smac-maps"
    players = 2
    step_mul = 8
    # 0 means no game-enforced episode limit; SMAC enforces its own.
    game_steps_per_episode = 0
# Registry of SMAC scenarios. Per map: team sizes (n_agents/n_enemies),
# episode step limit, races of the controlled team (a_race) and the enemy
# bot (b_race), how many one-hot bits encode unit type in observations
# (unit_type_bits), and the scenario family (map_type).
map_param_registry = {
    "3m": {
        "n_agents": 3,
        "n_enemies": 3,
        "limit": 60,
        "a_race": "T",
        "b_race": "T",
        "unit_type_bits": 0,
        "map_type": "marines",
    },
    "8m": {
        "n_agents": 8,
        "n_enemies": 8,
        "limit": 120,
        "a_race": "T",
        "b_race": "T",
        "unit_type_bits": 0,
        "map_type": "marines",
    },
    "25m": {
        "n_agents": 25,
        "n_enemies": 25,
        "limit": 150,
        "a_race": "T",
        "b_race": "T",
        "unit_type_bits": 0,
        "map_type": "marines",
    },
    "5m_vs_6m": {
        "n_agents": 5,
        "n_enemies": 6,
        "limit": 70,
        "a_race": "T",
        "b_race": "T",
        "unit_type_bits": 0,
        "map_type": "marines",
    },
    "8m_vs_9m": {
        "n_agents": 8,
        "n_enemies": 9,
        "limit": 120,
        "a_race": "T",
        "b_race": "T",
        "unit_type_bits": 0,
        "map_type": "marines",
    },
    "10m_vs_11m": {
        "n_agents": 10,
        "n_enemies": 11,
        "limit": 150,
        "a_race": "T",
        "b_race": "T",
        "unit_type_bits": 0,
        "map_type": "marines",
    },
    "27m_vs_30m": {
        "n_agents": 27,
        "n_enemies": 30,
        "limit": 180,
        "a_race": "T",
        "b_race": "T",
        "unit_type_bits": 0,
        "map_type": "marines",
    },
    "MMM": {
        "n_agents": 10,
        "n_enemies": 10,
        "limit": 150,
        "a_race": "T",
        "b_race": "T",
        "unit_type_bits": 3,
        "map_type": "MMM",
    },
    "MMM2": {
        "n_agents": 10,
        "n_enemies": 12,
        "limit": 180,
        "a_race": "T",
        "b_race": "T",
        "unit_type_bits": 3,
        "map_type": "MMM",
    },
    "2s3z": {
        "n_agents": 5,
        "n_enemies": 5,
        "limit": 120,
        "a_race": "P",
        "b_race": "P",
        "unit_type_bits": 2,
        "map_type": "stalkers_and_zealots",
    },
    "3s5z": {
        "n_agents": 8,
        "n_enemies": 8,
        "limit": 150,
        "a_race": "P",
        "b_race": "P",
        "unit_type_bits": 2,
        "map_type": "stalkers_and_zealots",
    },
    "3s5z_vs_3s6z": {
        "n_agents": 8,
        "n_enemies": 9,
        "limit": 170,
        "a_race": "P",
        "b_race": "P",
        "unit_type_bits": 2,
        "map_type": "stalkers_and_zealots",
    },
    "3s_vs_3z": {
        "n_agents": 3,
        "n_enemies": 3,
        "limit": 150,
        "a_race": "P",
        "b_race": "P",
        "unit_type_bits": 0,
        "map_type": "stalkers",
    },
    "3s_vs_4z": {
        "n_agents": 3,
        "n_enemies": 4,
        "limit": 200,
        "a_race": "P",
        "b_race": "P",
        "unit_type_bits": 0,
        "map_type": "stalkers",
    },
    "3s_vs_5z": {
        "n_agents": 3,
        "n_enemies": 5,
        "limit": 250,
        "a_race": "P",
        "b_race": "P",
        "unit_type_bits": 0,
        "map_type": "stalkers",
    },
    "1c3s5z": {
        "n_agents": 9,
        "n_enemies": 9,
        "limit": 180,
        "a_race": "P",
        "b_race": "P",
        "unit_type_bits": 3,
        "map_type": "colossi_stalkers_zealots",
    },
    "2m_vs_1z": {
        "n_agents": 2,
        "n_enemies": 1,
        "limit": 150,
        "a_race": "T",
        "b_race": "P",
        "unit_type_bits": 0,
        "map_type": "marines",
    },
    "corridor": {
        "n_agents": 6,
        "n_enemies": 24,
        "limit": 400,
        "a_race": "P",
        "b_race": "Z",
        "unit_type_bits": 0,
        "map_type": "zealots",
    },
    "6h_vs_8z": {
        "n_agents": 6,
        "n_enemies": 8,
        "limit": 150,
        "a_race": "Z",
        "b_race": "P",
        "unit_type_bits": 0,
        "map_type": "hydralisks",
    },
    "2s_vs_1sc": {
        "n_agents": 2,
        "n_enemies": 1,
        "limit": 300,
        "a_race": "P",
        "b_race": "Z",
        "unit_type_bits": 0,
        "map_type": "stalkers",
    },
    "so_many_baneling": {
        "n_agents": 7,
        "n_enemies": 32,
        "limit": 100,
        "a_race": "P",
        "b_race": "Z",
        "unit_type_bits": 0,
        "map_type": "zealots",
    },
    "bane_vs_bane": {
        "n_agents": 24,
        "n_enemies": 24,
        "limit": 200,
        "a_race": "Z",
        "b_race": "Z",
        "unit_type_bits": 2,
        "map_type": "bane",
    },
    "2c_vs_64zg": {
        "n_agents": 2,
        "n_enemies": 64,
        "limit": 400,
        "a_race": "P",
        "b_race": "Z",
        "unit_type_bits": 0,
        "map_type": "colossus",
    },
}
def get_smac_map_registry():
    """Return the live registry mapping SMAC map names to parameter dicts.

    Note: the registry object itself is returned (not a copy).
    """
    return map_param_registry
# Dynamically create one SMACMap subclass per registered scenario and bind
# it at module level so pysc2's map lookup can resolve each map by name.
for name in map_param_registry.keys():
    globals()[name] = type(name, (SMACMap,), dict(filename=name))
================================================
FILE: examples/Social_Cognition/ToCM/smac/env/starcraft2/render.py
================================================
import numpy as np
import re
import subprocess
import platform
from absl import logging
import math
import time
import collections
import os
import pygame
import queue
from pysc2.lib import colors
from pysc2.lib import point
from pysc2.lib.renderer_human import _Surface
from pysc2.lib import transform
from pysc2.lib import features
def clamp(n, smallest, largest):
    """Limit ``n`` to the inclusive range [smallest, largest]."""
    bounded_above = min(n, largest)
    return max(smallest, bounded_above)
def _get_desktop_size():
    """Get the desktop size.

    On Linux, ask ``xrandr`` for the primary display's resolution first;
    otherwise (or on failure) fall back to pygame's display info.
    """
    if platform.system() == "Linux":
        try:
            xrandr_query = subprocess.check_output(["xrandr", "--query"])
            sizes = re.findall(
                r"\bconnected primary (\d+)x(\d+)", str(xrandr_query)
            )
            # `sizes` is empty when no primary display is reported; the
            # original indexed sizes[0] unconditionally, which raised an
            # uncaught IndexError in that case.
            if sizes:
                return point.Point(int(sizes[0][0]), int(sizes[0][1]))
        except (ValueError, OSError, subprocess.SubprocessError):
            # OSError: xrandr binary missing; SubprocessError covers
            # CalledProcessError from a non-zero exit status.
            logging.error("Failed to get the resolution from xrandr.")
    # Most general, but doesn't understand multiple monitors.
    display_info = pygame.display.Info()
    return point.Point(display_info.current_w, display_info.current_h)
class StarCraft2Renderer:
    """Pygame renderer for a running StarCraft2Env.

    ``mode == "human"`` opens a window and draws into it each frame;
    ``mode == "rgb_array"`` draws off-screen and ``render`` returns the
    frame as a (height, width, 3) ndarray.
    """

    def __init__(self, env, mode):
        """Build the pygame surfaces and world->screen transforms.

        Parameters
        ----------
        env : StarCraft2Env
            The environment being rendered (its controller is queried for
            game info and static unit data).
        mode : str
            Either "human" or "rgb_array".
        """
        os.environ["PYGAME_HIDE_SUPPORT_PROMPT"] = "hide"
        self.env = env
        self.mode = mode
        self.obs = None
        self._window_scale = 0.75
        self.game_info = game_info = self.env._controller.game_info()
        self.static_data = self.env._controller.data()
        self._obs_queue = queue.Queue()
        self._game_times = collections.deque(
            maxlen=100
        )  # Avg FPS over 100 frames. # pytype: disable=wrong-keyword-args
        self._render_times = collections.deque(
            maxlen=100
        )  # pytype: disable=wrong-keyword-args
        self._last_time = time.time()
        self._last_game_loop = 0
        self._name_lengths = {}
        self._map_size = point.Point.build(game_info.start_raw.map_size)
        self._playable = point.Rect(
            point.Point.build(game_info.start_raw.playable_area.p0),
            point.Point.build(game_info.start_raw.playable_area.p1),
        )
        # Scale the window to the map's aspect ratio, capped at 75% of the
        # configured window size.
        window_size_px = point.Point(
            self.env.window_size[0], self.env.window_size[1]
        )
        window_size_px = self._map_size.scale_max_size(
            window_size_px * self._window_scale
        ).ceil()
        self._scale = window_size_px.y // 32
        self.display = pygame.Surface(window_size_px)
        if mode == "human":
            self.display = pygame.display.set_mode(window_size_px, 0, 32)
            pygame.display.init()
            pygame.display.set_caption("Starcraft Viewer")
        pygame.font.init()
        # World coordinates have y pointing up; flip to top-left origin.
        self._world_to_world_tl = transform.Linear(
            point.Point(1, -1), point.Point(0, self._map_size.y)
        )
        self._world_tl_to_screen = transform.Linear(scale=window_size_px / 32)
        self.screen_transform = transform.Chain(
            self._world_to_world_tl, self._world_tl_to_screen
        )
        surf_loc = point.Rect(point.origin, window_size_px)
        sub_surf = self.display.subsurface(
            pygame.Rect(surf_loc.tl, surf_loc.size)
        )
        self._surf = _Surface(
            sub_surf,
            None,
            surf_loc,
            self.screen_transform,
            None,
            self.draw_screen,
        )
        self._font_small = pygame.font.Font(None, int(self._scale * 0.5))
        self._font_large = pygame.font.Font(None, self._scale)

    def close(self):
        """Shut down pygame (window and all subsystems)."""
        pygame.display.quit()
        pygame.quit()

    def _get_units(self):
        """Yield (unit_proto, world_position) sorted for stable draw order."""
        for u in sorted(
            self.obs.observation.raw_data.units,
            key=lambda u: (u.pos.z, u.owner != 16, -u.radius, u.tag),
        ):
            yield u, point.Point.build(u.pos)

    def get_unit_name(self, surf, name, radius):
        """Get a length limited unit name for drawing units."""
        key = (name, radius)
        if key not in self._name_lengths:
            max_len = surf.world_to_surf.fwd_dist(radius * 1.6)
            for i in range(len(name)):
                if self._font_small.size(name[: i + 1])[0] > max_len:
                    self._name_lengths[key] = name[:i]
                    break
            else:
                self._name_lengths[key] = name
        return self._name_lengths[key]

    def render(self, mode):
        """Draw one frame.

        Returns the frame as an ndarray when ``mode == "rgb_array"``,
        otherwise None (the frame is flipped to the window instead).
        """
        self.obs = self.env._obs
        self.score = self.env.reward
        self.step = self.env._episode_steps
        now = time.time()
        self._game_times.append(
            (
                now - self._last_time,
                max(
                    1,
                    # BUG FIX: the original subtracted game_loop from
                    # itself (always 0); the intended delta is against the
                    # loop count recorded on the previous frame, which is
                    # what the FPS overlay divides by.
                    self.obs.observation.game_loop - self._last_game_loop,
                ),
            )
        )
        if mode == "human":
            pygame.event.pump()
        self._surf.draw(self._surf)
        observation = np.array(pygame.surfarray.pixels3d(self.display))
        if mode == "human":
            pygame.display.flip()
        self._last_time = now
        self._last_game_loop = self.obs.observation.game_loop
        # self._obs_queue.put(self.obs)
        return (
            np.transpose(observation, axes=(1, 0, 2))
            if mode == "rgb_array"
            else None
        )

    def draw_base_map(self, surf):
        """Draw the base map."""
        hmap_feature = features.SCREEN_FEATURES.height_map
        # terrain_height is normalised to [0, 1]; rescale to byte range.
        hmap = self.env.terrain_height * 255
        hmap = hmap.astype(np.uint8)
        # These maps are stored with a different orientation.
        if (
            self.env.map_name == "corridor"
            or self.env.map_name == "so_many_baneling"
            or self.env.map_name == "2s_vs_1sc"
        ):
            hmap = np.flip(hmap)
        else:
            hmap = np.rot90(hmap, axes=(1, 0))
        if not hmap.any():
            hmap = hmap + 100  # pylint: disable=g-no-augmented-assignment
        hmap_color = hmap_feature.color(hmap)
        out = hmap_color * 0.6
        surf.blit_np_array(out)

    def draw_units(self, surf):
        """Draw the units."""
        unit_dict = None  # Cache the units {tag: unit_proto} for orders.
        tau = 2 * math.pi
        for u, p in self._get_units():
            fraction_damage = clamp(
                (u.health_max - u.health) / (u.health_max or 1), 0, 1
            )
            surf.draw_circle(
                colors.PLAYER_ABSOLUTE_PALETTE[u.owner], p, u.radius
            )
            if fraction_damage > 0:
                # Darker inner disc whose radius grows with damage taken.
                surf.draw_circle(
                    colors.PLAYER_ABSOLUTE_PALETTE[u.owner] // 2,
                    p,
                    u.radius * fraction_damage,
                )
            surf.draw_circle(colors.black, p, u.radius, thickness=1)
            if self.static_data.unit_stats[u.unit_type].movement_speed > 0:
                # Short white arc marking the unit's facing direction.
                surf.draw_arc(
                    colors.white,
                    p,
                    u.radius,
                    u.facing - 0.1,
                    u.facing + 0.1,
                    thickness=1,
                )

            def draw_arc_ratio(
                color, world_loc, radius, start, end, thickness=1
            ):
                # start/end are fractions of a full turn.
                surf.draw_arc(
                    color, world_loc, radius, start * tau, end * tau, thickness
                )

            if u.shield and u.shield_max:
                draw_arc_ratio(
                    colors.blue, p, u.radius - 0.05, 0, u.shield / u.shield_max
                )
            if u.energy and u.energy_max:
                draw_arc_ratio(
                    colors.purple * 0.9,
                    p,
                    u.radius - 0.1,
                    0,
                    u.energy / u.energy_max,
                )
            elif u.orders and 0 < u.orders[0].progress < 1:
                draw_arc_ratio(
                    colors.cyan, p, u.radius - 0.15, 0, u.orders[0].progress
                )
            if u.buff_duration_remain and u.buff_duration_max:
                draw_arc_ratio(
                    colors.white,
                    p,
                    u.radius - 0.2,
                    0,
                    u.buff_duration_remain / u.buff_duration_max,
                )
            if u.attack_upgrade_level:
                # NOTE(review): self.upgrade_colors is never assigned in
                # __init__, so this branch (and the two below) would raise
                # AttributeError for units with upgrade levels — confirm
                # against pysc2's renderer_human, which defines an
                # upgrade palette.
                draw_arc_ratio(
                    self.upgrade_colors[u.attack_upgrade_level],
                    p,
                    u.radius - 0.25,
                    0.18,
                    0.22,
                    thickness=3,
                )
            if u.armor_upgrade_level:
                draw_arc_ratio(
                    self.upgrade_colors[u.armor_upgrade_level],
                    p,
                    u.radius - 0.25,
                    0.23,
                    0.27,
                    thickness=3,
                )
            if u.shield_upgrade_level:
                draw_arc_ratio(
                    self.upgrade_colors[u.shield_upgrade_level],
                    p,
                    u.radius - 0.25,
                    0.28,
                    0.32,
                    thickness=3,
                )

            def write_small(loc, s):
                surf.write_world(self._font_small, colors.white, loc, str(s))

            name = self.get_unit_name(
                surf,
                self.static_data.units.get(u.unit_type, ""),
                u.radius,
            )
            if name:
                write_small(p, name)
            # Draw the unit's order chain as connected cyan lines.
            start_point = p
            for o in u.orders:
                target_point = None
                if o.HasField("target_unit_tag"):
                    if unit_dict is None:
                        unit_dict = {
                            t.tag: t
                            for t in self.obs.observation.raw_data.units
                        }
                    target_unit = unit_dict.get(o.target_unit_tag)
                    if target_unit:
                        target_point = point.Point.build(target_unit.pos)
                if target_point:
                    surf.draw_line(colors.cyan, start_point, target_point)
                    start_point = target_point
                else:
                    break

    def draw_overlay(self, surf):
        """Draw the overlay describing resources."""
        obs = self.obs.observation
        times, steps = zip(*self._game_times)
        # 22.4 game loops per real-time second at "faster" game speed.
        sec = obs.game_loop // 22.4
        surf.write_screen(
            self._font_large,
            colors.green,
            (-0.2, 0.2),
            "Score: %s, Step: %s, %.1f/s, Time: %d:%02d"
            % (
                self.score,
                self.step,
                sum(steps) / (sum(times) or 1),
                sec // 60,
                sec % 60,
            ),
            align="right",
        )
        surf.write_screen(
            self._font_large,
            colors.green * 0.8,
            (-0.2, 1.2),
            "APM: %d, EPM: %d, FPS: O:%.1f, R:%.1f"
            % (
                obs.score.score_details.current_apm,
                obs.score.score_details.current_effective_apm,
                len(times) / (sum(times) or 1),
                len(self._render_times) / (sum(self._render_times) or 1),
            ),
            align="right",
        )

    def draw_screen(self, surf):
        """Draw the screen area."""
        self.draw_base_map(surf)
        self.draw_units(surf)
        self.draw_overlay(surf)
================================================
FILE: examples/Social_Cognition/ToCM/smac/env/starcraft2/starcraft2.py
================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from smac.env.multiagentenv import MultiAgentEnv
from smac.env.starcraft2.maps import get_map_params
import atexit
from warnings import warn
from operator import attrgetter
from copy import deepcopy
import numpy as np
import enum
import math
from absl import logging
from pysc2 import maps
from pysc2 import run_configs
from pysc2.lib import protocol
from s2clientprotocol import common_pb2 as sc_common
from s2clientprotocol import sc2api_pb2 as sc_pb
from s2clientprotocol import raw_pb2 as r_pb
from s2clientprotocol import debug_pb2 as d_pb
# Single-letter race codes (as used in the map registry) -> s2clientprotocol
# race constants.
races = {
    "R": sc_common.Random,
    "P": sc_common.Protoss,
    "T": sc_common.Terran,
    "Z": sc_common.Zerg,
}
# Built-in bot difficulty setting "1"-"9"/"A" -> s2clientprotocol constants.
difficulties = {
    "1": sc_pb.VeryEasy,
    "2": sc_pb.Easy,
    "3": sc_pb.Medium,
    "4": sc_pb.MediumHard,
    "5": sc_pb.Hard,
    "6": sc_pb.Harder,
    "7": sc_pb.VeryHard,
    "8": sc_pb.CheatVision,
    "9": sc_pb.CheatMoney,
    "A": sc_pb.CheatInsane,
}
# SC2 ability IDs for the raw unit commands this environment issues.
actions = {
    "move": 16,  # target: PointOrUnit
    "attack": 23,  # target: PointOrUnit
    "stop": 4,  # target: None
    "heal": 386,  # Unit
}
# Compass directions; IntEnum so members compare equal to plain ints.
Direction = enum.IntEnum(
    "Direction",
    {"NORTH": 0, "SOUTH": 1, "EAST": 2, "WEST": 3},
)
class StarCraft2Env(MultiAgentEnv):
"""The StarCraft II environment for decentralised multi-agent
micromanagement scenarios.
"""
    def __init__(
        self,
        map_name="8m",
        step_mul=8,
        move_amount=2,
        difficulty="7",
        game_version=None,
        seed=None,
        continuing_episode=False,
        obs_all_health=True,
        obs_own_health=True,
        obs_last_action=False,
        obs_pathing_grid=False,
        obs_terrain_height=False,
        obs_instead_of_state=False,
        obs_timestep_number=False,
        state_last_action=True,
        state_timestep_number=False,
        reward_sparse=False,
        reward_only_positive=True,
        reward_death_value=10,
        reward_win=200,
        reward_defeat=0,
        reward_negative_scale=0.5,
        reward_scale=True,
        reward_scale_rate=20,
        replay_dir="",
        replay_prefix="",
        window_size_x=1920,
        window_size_y=1200,
        heuristic_ai=False,
        heuristic_rest=False,
        debug=False,
    ):
        """
        Create a StarCraft2Env environment.

        Parameters
        ----------
        map_name : str, optional
            The name of the SC2 map to play (default is "8m"). The full list
            can be found by running bin/map_list.
        step_mul : int, optional
            How many game steps per agent step (default is 8). None
            indicates to use the default map step_mul.
        move_amount : float, optional
            How far away units are ordered to move per step (default is 2).
        difficulty : str, optional
            The difficulty of built-in computer AI bot (default is "7").
        game_version : str, optional
            StarCraft II game version (default is None). None indicates the
            latest version.
        seed : int, optional
            Random seed used during game initialisation. This allows to
            reproduce game runs.
        continuing_episode : bool, optional
            Whether to consider episodes continuing or finished after time
            limit is reached (default is False).
        obs_all_health : bool, optional
            Agents receive the health of all units (in the sight range) as part
            of observations (default is True).
        obs_own_health : bool, optional
            Agents receive their own health as a part of observations (default
            is True). This flag is ignored when obs_all_health == True.
        obs_last_action : bool, optional
            Agents receive the last actions of all units (in the sight range)
            as part of observations (default is False).
        obs_pathing_grid : bool, optional
            Whether observations include pathing values surrounding the agent
            (default is False).
        obs_terrain_height : bool, optional
            Whether observations include terrain height values surrounding the
            agent (default is False).
        obs_instead_of_state : bool, optional
            Use combination of all agents' observations as the global state
            (default is False).
        obs_timestep_number : bool, optional
            Whether observations include the current timestep of the episode
            (default is False).
        state_last_action : bool, optional
            Include the last actions of all agents as part of the global state
            (default is True).
        state_timestep_number : bool, optional
            Whether the state include the current timestep of the episode
            (default is False).
        reward_sparse : bool, optional
            Receive 1/-1 reward for winning/losing an episode (default is
            False). The rest of reward parameters are ignored if True.
        reward_only_positive : bool, optional
            Reward is always positive (default is True).
        reward_death_value : float, optional
            The amount of reward received for killing an enemy unit (default
            is 10). This is also the negative penalty for having an allied unit
            killed if reward_only_positive == False.
        reward_win : float, optional
            The reward for winning in an episode (default is 200).
        reward_defeat : float, optional
            The reward for losing in an episode (default is 0). This value
            should be nonpositive.
        reward_negative_scale : float, optional
            Scaling factor for negative rewards (default is 0.5). This
            parameter is ignored when reward_only_positive == True.
        reward_scale : bool, optional
            Whether or not to scale the reward (default is True).
        reward_scale_rate : float, optional
            Reward scale rate (default is 20). When reward_scale == True, the
            reward received by the agents is divided by (max_reward /
            reward_scale_rate), where max_reward is the maximum possible
            reward per episode without considering the shield regeneration
            of Protoss units.
        replay_dir : str, optional
            The directory to save replays (default is None). If None, the
            replay will be saved in Replays directory where StarCraft II is
            installed.
        replay_prefix : str, optional
            The prefix of the replay to be saved (default is None). If None,
            the name of the map will be used.
        window_size_x : int, optional
            The length of StarCraft II window size (default is 1920).
        window_size_y: int, optional
            The height of StarCraft II window size (default is 1200).
        heuristic_ai: bool, optional
            Whether or not to use a non-learning heuristic AI (default False).
        heuristic_rest: bool, optional
            At any moment, restrict the actions of the heuristic AI to be
            chosen from actions available to RL agents (default is False).
            Ignored if heuristic_ai == False.
        debug: bool, optional
            Log messages about observations, state, actions and rewards for
            debugging purposes (default is False).
        """
        # Map arguments
        self.map_name = map_name
        map_params = get_map_params(self.map_name)
        self.n_agents = map_params["n_agents"]
        self.n_enemies = map_params["n_enemies"]
        self.episode_limit = map_params["limit"]
        self._move_amount = move_amount
        self._step_mul = step_mul
        self.difficulty = difficulty
        # Observations and state
        self.obs_own_health = obs_own_health
        self.obs_all_health = obs_all_health
        self.obs_instead_of_state = obs_instead_of_state
        self.obs_last_action = obs_last_action
        self.obs_pathing_grid = obs_pathing_grid
        self.obs_terrain_height = obs_terrain_height
        self.obs_timestep_number = obs_timestep_number
        self.state_last_action = state_last_action
        self.state_timestep_number = state_timestep_number
        if self.obs_all_health:
            # Own health is a subset of all-unit health, so force it on.
            self.obs_own_health = True
        self.n_obs_pathing = 8
        self.n_obs_height = 9
        # Rewards args
        self.reward_sparse = reward_sparse
        self.reward_only_positive = reward_only_positive
        self.reward_negative_scale = reward_negative_scale
        self.reward_death_value = reward_death_value
        self.reward_win = reward_win
        self.reward_defeat = reward_defeat
        self.reward_scale = reward_scale
        self.reward_scale_rate = reward_scale_rate
        # Other
        self.game_version = game_version
        self.continuing_episode = continuing_episode
        self._seed = seed
        self.heuristic_ai = heuristic_ai
        self.heuristic_rest = heuristic_rest
        self.debug = debug
        self.window_size = (window_size_x, window_size_y)
        self.replay_dir = replay_dir
        self.replay_prefix = replay_prefix
        # Actions
        self.n_actions_no_attack = 6
        self.n_actions_move = 4
        self.n_actions = self.n_actions_no_attack + self.n_enemies
        # Map info
        self._agent_race = map_params["a_race"]
        self._bot_race = map_params["b_race"]
        self.shield_bits_ally = 1 if self._agent_race == "P" else 0
        self.shield_bits_enemy = 1 if self._bot_race == "P" else 0
        self.unit_type_bits = map_params["unit_type_bits"]
        self.map_type = map_params["map_type"]
        self._unit_types = None
        # Max reward: killing every enemy plus winning (ignores Protoss
        # shield regeneration).
        self.max_reward = (
            self.n_enemies * self.reward_death_value + self.reward_win
        )
        # create lists containing the names of attributes returned in states
        self.ally_state_attr_names = [
            "health",
            "energy/cooldown",
            "rel_x",
            "rel_y",
        ]
        self.enemy_state_attr_names = ["health", "rel_x", "rel_y"]
        if self.shield_bits_ally > 0:
            self.ally_state_attr_names += ["shield"]
        if self.shield_bits_enemy > 0:
            self.enemy_state_attr_names += ["shield"]
        if self.unit_type_bits > 0:
            bit_attr_names = [
                "type_{}".format(bit) for bit in range(self.unit_type_bits)
            ]
            self.ally_state_attr_names += bit_attr_names
            self.enemy_state_attr_names += bit_attr_names
        self.agents = {}
        self.enemies = {}
        self._episode_count = 0
        self._episode_steps = 0
        self._total_steps = 0
        self._obs = None
        self.battles_won = 0
        self.battles_game = 0
        self.timeouts = 0
        self.force_restarts = 0
        self.last_stats = None
        self.death_tracker_ally = np.zeros(self.n_agents)
        self.death_tracker_enemy = np.zeros(self.n_enemies)
        self.previous_ally_units = None
        self.previous_enemy_units = None
        self.last_action = np.zeros((self.n_agents, self.n_actions))
        self._min_unit_type = 0
        self.marine_id = self.marauder_id = self.medivac_id = 0
        self.hydralisk_id = self.zergling_id = self.baneling_id = 0
        self.stalker_id = self.colossus_id = self.zealot_id = 0
        self.max_distance_x = 0
        self.max_distance_y = 0
        self.map_x = 0
        self.map_y = 0
        self.reward = 0
        self.renderer = None
        self.terrain_height = None
        self.pathing_grid = None
        self._run_config = None
        self._sc2_proc = None
        self._controller = None
        # Try to avoid leaking SC2 processes on shutdown
        atexit.register(lambda: self.close())
def _launch(self):
"""Launch the StarCraft II game."""
self._run_config = run_configs.get(version=self.game_version)
_map = maps.get(self.map_name)
# Setting up the interface
interface_options = sc_pb.InterfaceOptions(raw=True, score=False)
self._sc2_proc = self._run_config.start(
window_size=self.window_size, want_rgb=False
)
self._controller = self._sc2_proc.controller
# Request to create the game
create = sc_pb.RequestCreateGame(
local_map=sc_pb.LocalMap(
map_path=_map.path,
map_data=self._run_config.map_data(_map.path),
),
realtime=False,
random_seed=self._seed,
)
create.player_setup.add(type=sc_pb.Participant)
create.player_setup.add(
type=sc_pb.Computer,
race=races[self._bot_race],
difficulty=difficulties[self.difficulty],
)
self._controller.create_game(create)
join = sc_pb.RequestJoinGame(
race=races[self._agent_race], options=interface_options
)
self._controller.join_game(join)
game_info = self._controller.game_info()
map_info = game_info.start_raw
map_play_area_min = map_info.playable_area.p0
map_play_area_max = map_info.playable_area.p1
self.max_distance_x = map_play_area_max.x - map_play_area_min.x
self.max_distance_y = map_play_area_max.y - map_play_area_min.y
self.map_x = map_info.map_size.x
self.map_y = map_info.map_size.y
if map_info.pathing_grid.bits_per_pixel == 1:
vals = np.array(list(map_info.pathing_grid.data)).reshape(
self.map_x, int(self.map_y / 8)
)
self.pathing_grid = np.transpose(
np.array(
[
[(b >> i) & 1 for b in row for i in range(7, -1, -1)]
for row in vals
],
dtype=np.bool,
)
)
else:
self.pathing_grid = np.invert(
np.flip(
np.transpose(
np.array(
list(map_info.pathing_grid.data), dtype=np.bool
).reshape(self.map_x, self.map_y)
),
axis=1,
)
)
self.terrain_height = (
np.flip(
np.transpose(
np.array(list(map_info.terrain_height.data)).reshape(
self.map_x, self.map_y
)
),
1,
)
/ 255
)
    def reset(self):
        """Reset the environment. Required after each full episode.
        Returns initial observations and states.
        """
        self._episode_steps = 0
        if self._episode_count == 0:
            # Launch StarCraft II
            self._launch()
        else:
            # Kill remaining units; the map trigger restarts the episode.
            self._restart()
        # Information kept for counting the reward
        self.death_tracker_ally = np.zeros(self.n_agents)
        self.death_tracker_enemy = np.zeros(self.n_enemies)
        self.previous_ally_units = None
        self.previous_enemy_units = None
        self.win_counted = False
        self.defeat_counted = False
        self.last_action = np.zeros((self.n_agents, self.n_actions))
        if self.heuristic_ai:
            self.heuristic_targets = [None] * self.n_agents
        try:
            self._obs = self._controller.observe()
            self.init_units()
        except (protocol.ProtocolError, protocol.ConnectionError):
            # Connection to SC2 lost: relaunch the whole process.
            self.full_restart()
        if self.debug:
            logging.debug(
                "Started Episode {}".format(self._episode_count).center(
                    60, "*"
                )
            )
        return self.get_obs(), self.get_state()
    def _restart(self):
        """Restart the environment by killing all units on the map.
        There is a trigger in the SC2Map file, which restarts the
        episode when there are no units left.
        """
        try:
            self._kill_all_units()
            # Advance a couple of game steps so the restart trigger fires.
            self._controller.step(2)
        except (protocol.ProtocolError, protocol.ConnectionError):
            self.full_restart()
    def full_restart(self):
        """Full restart. Closes the SC2 process and launches a new one."""
        self._sc2_proc.close()
        self._launch()
        # Track how often a full (slow) restart was needed.
        self.force_restarts += 1
    def step(self, actions):
        """A single environment step. Returns reward, terminated, info.

        `actions` holds one discrete action index per agent. On a lost
        SC2 connection the episode is terminated with zero reward.
        """
        actions_int = [int(a) for a in actions]
        # One-hot record of the actions taken, exposed via the state.
        self.last_action = np.eye(self.n_actions)[np.array(actions_int)]
        # Collect individual actions
        sc_actions = []
        if self.debug:
            logging.debug("Actions".center(60, "-"))
        for a_id, action in enumerate(actions_int):
            if not self.heuristic_ai:
                sc_action = self.get_agent_action(a_id, action)
            else:
                sc_action, action_num = self.get_agent_action_heuristic(
                    a_id, action
                )
                actions[a_id] = action_num
            if sc_action:
                sc_actions.append(sc_action)
        # Send action request
        req_actions = sc_pb.RequestAction(actions=sc_actions)
        try:
            self._controller.actions(req_actions)
            # Make step in SC2, i.e. apply actions
            self._controller.step(self._step_mul)
            # Observe here so that we know if the episode is over.
            self._obs = self._controller.observe()
        except (protocol.ProtocolError, protocol.ConnectionError):
            self.full_restart()
            return 0, True, {}
        self._total_steps += 1
        self._episode_steps += 1
        # Update units
        game_end_code = self.update_units()
        terminated = False
        reward = self.reward_battle()
        info = {"battle_won": False}
        # count units that are still alive
        dead_allies, dead_enemies = 0, 0
        for _al_id, al_unit in self.agents.items():
            if al_unit.health == 0:
                dead_allies += 1
        for _e_id, e_unit in self.enemies.items():
            if e_unit.health == 0:
                dead_enemies += 1
        info["dead_allies"] = dead_allies
        info["dead_enemies"] = dead_enemies
        if game_end_code is not None:
            # Battle is over
            terminated = True
            self.battles_game += 1
            if game_end_code == 1 and not self.win_counted:
                self.battles_won += 1
                self.win_counted = True
                info["battle_won"] = True
                if not self.reward_sparse:
                    reward += self.reward_win
                else:
                    reward = 1
            elif game_end_code == -1 and not self.defeat_counted:
                self.defeat_counted = True
                if not self.reward_sparse:
                    reward += self.reward_defeat
                else:
                    reward = -1
        elif self._episode_steps >= self.episode_limit:
            # Episode limit reached
            terminated = True
            if self.continuing_episode:
                info["episode_limit"] = True
            self.battles_game += 1
            self.timeouts += 1
        if self.debug:
            logging.debug("Reward = {}".format(reward).center(60, "-"))
        if terminated:
            self._episode_count += 1
        if self.reward_scale:
            reward /= self.max_reward / self.reward_scale_rate
        self.reward = reward
        return reward, terminated, info
def get_agent_action(self, a_id, action):
"""Construct the action for agent a_id."""
avail_actions = self.get_avail_agent_actions(a_id)
assert (
avail_actions[action] == 1
), "Agent {} cannot perform action {}".format(a_id, action)
unit = self.get_unit_by_id(a_id)
tag = unit.tag
x = unit.pos.x
y = unit.pos.y
if action == 0:
# no-op (valid only when dead)
assert unit.health == 0, "No-op only available for dead agents."
if self.debug:
logging.debug("Agent {}: Dead".format(a_id))
return None
elif action == 1:
# stop
cmd = r_pb.ActionRawUnitCommand(
ability_id=actions["stop"],
unit_tags=[tag],
queue_command=False,
)
if self.debug:
logging.debug("Agent {}: Stop".format(a_id))
elif action == 2:
# move north
cmd = r_pb.ActionRawUnitCommand(
ability_id=actions["move"],
target_world_space_pos=sc_common.Point2D(
x=x, y=y + self._move_amount
),
unit_tags=[tag],
queue_command=False,
)
if self.debug:
logging.debug("Agent {}: Move North".format(a_id))
elif action == 3:
# move south
cmd = r_pb.ActionRawUnitCommand(
ability_id=actions["move"],
target_world_space_pos=sc_common.Point2D(
x=x, y=y - self._move_amount
),
unit_tags=[tag],
queue_command=False,
)
if self.debug:
logging.debug("Agent {}: Move South".format(a_id))
elif action == 4:
# move east
cmd = r_pb.ActionRawUnitCommand(
ability_id=actions["move"],
target_world_space_pos=sc_common.Point2D(
x=x + self._move_amount, y=y
),
unit_tags=[tag],
queue_command=False,
)
if self.debug:
logging.debug("Agent {}: Move East".format(a_id))
elif action == 5:
# move west
cmd = r_pb.ActionRawUnitCommand(
ability_id=actions["move"],
target_world_space_pos=sc_common.Point2D(
x=x - self._move_amount, y=y
),
unit_tags=[tag],
queue_command=False,
)
if self.debug:
logging.debug("Agent {}: Move West".format(a_id))
else:
# attack/heal units that are in range
target_id = action - self.n_actions_no_attack
if self.map_type == "MMM" and unit.unit_type == self.medivac_id:
target_unit = self.agents[target_id]
action_name = "heal"
else:
target_unit = self.enemies[target_id]
action_name = "attack"
action_id = actions[action_name]
target_tag = target_unit.tag
cmd = r_pb.ActionRawUnitCommand(
ability_id=action_id,
target_unit_tag=target_tag,
unit_tags=[tag],
queue_command=False,
)
if self.debug:
logging.debug(
"Agent {} {}s unit # {}".format(
a_id, action_name, target_id
)
)
sc_action = sc_pb.Action(action_raw=r_pb.ActionRaw(unit_command=cmd))
return sc_action
    def get_agent_action_heuristic(self, a_id, action):
        """Construct a scripted action for agent ``a_id``.

        Medivacs heal the closest damaged non-medivac ally; other units
        attack the closest live enemy (marauders skip medivacs). Returns
        (sc_action, action_num), where action_num is the equivalent index
        in the RL action space; (None, 0) when no valid target exists.
        """
        unit = self.get_unit_by_id(a_id)
        tag = unit.tag
        target = self.heuristic_targets[a_id]
        if unit.unit_type == self.medivac_id:
            # (Re)select a heal target if the current one is dead, gone,
            # or at full health.
            if (
                target is None
                or self.agents[target].health == 0
                or self.agents[target].health == self.agents[target].health_max
            ):
                min_dist = math.hypot(self.max_distance_x, self.max_distance_y)
                min_id = -1
                for al_id, al_unit in self.agents.items():
                    if al_unit.unit_type == self.medivac_id:
                        continue
                    if (
                        al_unit.health != 0
                        and al_unit.health != al_unit.health_max
                    ):
                        dist = self.distance(
                            unit.pos.x,
                            unit.pos.y,
                            al_unit.pos.x,
                            al_unit.pos.y,
                        )
                        if dist < min_dist:
                            min_dist = dist
                            min_id = al_id
                self.heuristic_targets[a_id] = min_id
                if min_id == -1:
                    # Nobody to heal.
                    self.heuristic_targets[a_id] = None
                    return None, 0
            action_id = actions["heal"]
            target_tag = self.agents[self.heuristic_targets[a_id]].tag
        else:
            # (Re)select an attack target if the current one is dead/gone.
            if target is None or self.enemies[target].health == 0:
                min_dist = math.hypot(self.max_distance_x, self.max_distance_y)
                min_id = -1
                for e_id, e_unit in self.enemies.items():
                    if (
                        unit.unit_type == self.marauder_id
                        and e_unit.unit_type == self.medivac_id
                    ):
                        continue
                    if e_unit.health > 0:
                        dist = self.distance(
                            unit.pos.x, unit.pos.y, e_unit.pos.x, e_unit.pos.y
                        )
                        if dist < min_dist:
                            min_dist = dist
                            min_id = e_id
                self.heuristic_targets[a_id] = min_id
                if min_id == -1:
                    # No enemy left to attack.
                    self.heuristic_targets[a_id] = None
                    return None, 0
            action_id = actions["attack"]
            target_tag = self.enemies[self.heuristic_targets[a_id]].tag
        action_num = self.heuristic_targets[a_id] + self.n_actions_no_attack
        # Check if the action is available
        if (
            self.heuristic_rest
            and self.get_avail_agent_actions(a_id)[action_num] == 0
        ):
            # Move towards the target rather than attacking/healing
            if unit.unit_type == self.medivac_id:
                target_unit = self.agents[self.heuristic_targets[a_id]]
            else:
                target_unit = self.enemies[self.heuristic_targets[a_id]]
            delta_x = target_unit.pos.x - unit.pos.x
            delta_y = target_unit.pos.y - unit.pos.y
            if abs(delta_x) > abs(delta_y):  # east or west
                if delta_x > 0:  # east
                    target_pos = sc_common.Point2D(
                        x=unit.pos.x + self._move_amount, y=unit.pos.y
                    )
                    action_num = 4
                else:  # west
                    target_pos = sc_common.Point2D(
                        x=unit.pos.x - self._move_amount, y=unit.pos.y
                    )
                    action_num = 5
            else:  # north or south
                if delta_y > 0:  # north
                    target_pos = sc_common.Point2D(
                        x=unit.pos.x, y=unit.pos.y + self._move_amount
                    )
                    action_num = 2
                else:  # south
                    target_pos = sc_common.Point2D(
                        x=unit.pos.x, y=unit.pos.y - self._move_amount
                    )
                    action_num = 3
            cmd = r_pb.ActionRawUnitCommand(
                ability_id=actions["move"],
                target_world_space_pos=target_pos,
                unit_tags=[tag],
                queue_command=False,
            )
        else:
            # Attack/heal the target
            cmd = r_pb.ActionRawUnitCommand(
                ability_id=action_id,
                target_unit_tag=target_tag,
                unit_tags=[tag],
                queue_command=False,
            )
        sc_action = sc_pb.Action(action_raw=r_pb.ActionRaw(unit_command=cmd))
        return sc_action, action_num
def reward_battle(self):
"""Reward function when self.reward_spare==False.
Returns accumulative hit/shield point damage dealt to the enemy
+ reward_death_value per enemy unit killed, and, in case
self.reward_only_positive == False, - (damage dealt to ally units
+ reward_death_value per ally unit killed) * self.reward_negative_scale
"""
if self.reward_sparse:
return 0
reward = 0
delta_deaths = 0
delta_ally = 0
delta_enemy = 0
neg_scale = self.reward_negative_scale
# update deaths
for al_id, al_unit in self.agents.items():
if not self.death_tracker_ally[al_id]:
# did not die so far
prev_health = (
self.previous_ally_units[al_id].health
+ self.previous_ally_units[al_id].shield
)
if al_unit.health == 0:
# just died
self.death_tracker_ally[al_id] = 1
if not self.reward_only_positive:
delta_deaths -= self.reward_death_value * neg_scale
delta_ally += prev_health * neg_scale
else:
# still alive
delta_ally += neg_scale * (
prev_health - al_unit.health - al_unit.shield
)
for e_id, e_unit in self.enemies.items():
if not self.death_tracker_enemy[e_id]:
prev_health = (
self.previous_enemy_units[e_id].health
+ self.previous_enemy_units[e_id].shield
)
if e_unit.health == 0:
self.death_tracker_enemy[e_id] = 1
delta_deaths += self.reward_death_value
delta_enemy += prev_health
else:
delta_enemy += prev_health - e_unit.health - e_unit.shield
if self.reward_only_positive:
reward = abs(delta_enemy + delta_deaths) # shield regeneration
else:
reward = delta_enemy + delta_deaths - delta_ally
return reward
    def get_total_actions(self):
        """Returns the total number of actions an agent could ever take.

        This is the fixed size of the per-agent action space
        (``self.n_actions``), independent of which actions are currently
        available.
        """
        return self.n_actions
@staticmethod
def distance(x1, y1, x2, y2):
"""Distance between two points."""
return math.hypot(x2 - x1, y2 - y1)
    def unit_shoot_range(self, agent_id):
        """Returns the shooting range for an agent.

        Constant (6) for every agent, regardless of unit type.
        """
        return 6
    def unit_sight_range(self, agent_id):
        """Returns the sight range for an agent.

        Constant (9) for every agent; always larger than the shoot range.
        """
        return 9
def unit_max_cooldown(self, unit):
"""Returns the maximal cooldown for a unit."""
switcher = {
self.marine_id: 15,
self.marauder_id: 25,
self.medivac_id: 200, # max energy
self.stalker_id: 35,
self.zealot_id: 22,
self.colossus_id: 24,
self.hydralisk_id: 10,
self.zergling_id: 11,
self.baneling_id: 1,
}
return switcher.get(unit.unit_type, 15)
def save_replay(self):
"""Save a replay."""
prefix = self.replay_prefix or self.map_name
replay_dir = self.replay_dir or ""
replay_path = self._run_config.save_replay(
self._controller.save_replay(),
replay_dir=replay_dir,
prefix=prefix,
)
logging.info("Replay saved at: %s" % replay_path)
def unit_max_shield(self, unit):
"""Returns maximal shield for a given unit."""
if unit.unit_type == 74 or unit.unit_type == self.stalker_id:
return 80 # Protoss's Stalker
if unit.unit_type == 73 or unit.unit_type == self.zealot_id:
return 50 # Protoss's Zaelot
if unit.unit_type == 4 or unit.unit_type == self.colossus_id:
return 150 # Protoss's Colossus
def can_move(self, unit, direction):
"""Whether a unit can move in a given direction."""
m = self._move_amount / 2
if direction == Direction.NORTH:
x, y = int(unit.pos.x), int(unit.pos.y + m)
elif direction == Direction.SOUTH:
x, y = int(unit.pos.x), int(unit.pos.y - m)
elif direction == Direction.EAST:
x, y = int(unit.pos.x + m), int(unit.pos.y)
else:
x, y = int(unit.pos.x - m), int(unit.pos.y)
if self.check_bounds(x, y) and self.pathing_grid[x, y]:
return True
return False
def get_surrounding_points(self, unit, include_self=False):
"""Returns the surrounding points of the unit in 8 directions."""
x = int(unit.pos.x)
y = int(unit.pos.y)
ma = self._move_amount
points = [
(x, y + 2 * ma),
(x, y - 2 * ma),
(x + 2 * ma, y),
(x - 2 * ma, y),
(x + ma, y + ma),
(x - ma, y - ma),
(x + ma, y - ma),
(x - ma, y + ma),
]
if include_self:
points.append((x, y))
return points
def check_bounds(self, x, y):
"""Whether a point is within the map bounds."""
return 0 <= x < self.map_x and 0 <= y < self.map_y
def get_surrounding_pathing(self, unit):
"""Returns pathing values of the grid surrounding the given unit."""
points = self.get_surrounding_points(unit, include_self=False)
vals = [
self.pathing_grid[x, y] if self.check_bounds(x, y) else 1
for x, y in points
]
return vals
def get_surrounding_height(self, unit):
"""Returns height values of the grid surrounding the given unit."""
points = self.get_surrounding_points(unit, include_self=True)
vals = [
self.terrain_height[x, y] if self.check_bounds(x, y) else 1
for x, y in points
]
return vals
    def get_obs_agent(self, agent_id):
        """Returns observation for agent_id. The observation is composed of:

        - agent movement features (where it can move to, height information
          and pathing grid)
        - enemy features (available_to_attack, health, relative_x, relative_y,
          shield, unit_type)
        - ally features (visible, distance, relative_x, relative_y, shield,
          unit_type)
        - agent unit features (health, shield, unit_type)

        All of this information is flattened and concatenated into a list,
        in the aforementioned order. To know the sizes of each of the
        features inside the final list of features, take a look at the
        functions ``get_obs_move_feats_size()``,
        ``get_obs_enemy_feats_size()``, ``get_obs_ally_feats_size()`` and
        ``get_obs_own_feats_size()``.

        The size of the observation vector may vary, depending on the
        environment configuration and type of units present in the map.
        For instance, non-Protoss units will not have shields, movement
        features may or may not include terrain height and pathing grid,
        unit_type is not included if there is only one type of unit in the
        map etc.).

        NOTE: Agents should have access only to their local observations
        during decentralised execution.
        """
        unit = self.get_unit_by_id(agent_id)

        move_feats_dim = self.get_obs_move_feats_size()
        enemy_feats_dim = self.get_obs_enemy_feats_size()
        ally_feats_dim = self.get_obs_ally_feats_size()
        own_feats_dim = self.get_obs_own_feats_size()

        move_feats = np.zeros(move_feats_dim, dtype=np.float32)
        enemy_feats = np.zeros(enemy_feats_dim, dtype=np.float32)
        ally_feats = np.zeros(ally_feats_dim, dtype=np.float32)
        own_feats = np.zeros(own_feats_dim, dtype=np.float32)

        if unit.health > 0:  # otherwise dead, return all zeros
            x = unit.pos.x
            y = unit.pos.y
            sight_range = self.unit_sight_range(agent_id)

            # Movement features: availability bits come straight from the
            # action mask (indices 2..5 are the move actions).
            avail_actions = self.get_avail_agent_actions(agent_id)
            for m in range(self.n_actions_move):
                move_feats[m] = avail_actions[m + 2]

            ind = self.n_actions_move

            if self.obs_pathing_grid:
                move_feats[
                    ind : ind + self.n_obs_pathing  # noqa
                ] = self.get_surrounding_pathing(unit)
                ind += self.n_obs_pathing

            if self.obs_terrain_height:
                move_feats[ind:] = self.get_surrounding_height(unit)

            # Enemy features
            for e_id, e_unit in self.enemies.items():
                e_x = e_unit.pos.x
                e_y = e_unit.pos.y
                dist = self.distance(x, y, e_x, e_y)

                if (
                    dist < sight_range and e_unit.health > 0
                ):  # visible and alive
                    # Sight range > shoot range
                    enemy_feats[e_id, 0] = avail_actions[
                        self.n_actions_no_attack + e_id
                    ]  # available
                    enemy_feats[e_id, 1] = dist / sight_range  # distance
                    enemy_feats[e_id, 2] = (
                        e_x - x
                    ) / sight_range  # relative X
                    enemy_feats[e_id, 3] = (
                        e_y - y
                    ) / sight_range  # relative Y

                    # ind tracks the next free feature column.
                    ind = 4
                    if self.obs_all_health:
                        enemy_feats[e_id, ind] = (
                            e_unit.health / e_unit.health_max
                        )  # health
                        ind += 1
                        if self.shield_bits_enemy > 0:
                            max_shield = self.unit_max_shield(e_unit)
                            enemy_feats[e_id, ind] = (
                                e_unit.shield / max_shield
                            )  # shield
                            ind += 1

                    if self.unit_type_bits > 0:
                        type_id = self.get_unit_type_id(e_unit, False)
                        enemy_feats[e_id, ind + type_id] = 1  # unit type

            # Ally features (the observing agent itself is excluded)
            al_ids = [
                al_id for al_id in range(self.n_agents) if al_id != agent_id
            ]
            for i, al_id in enumerate(al_ids):

                al_unit = self.get_unit_by_id(al_id)
                al_x = al_unit.pos.x
                al_y = al_unit.pos.y
                dist = self.distance(x, y, al_x, al_y)

                if (
                    dist < sight_range and al_unit.health > 0
                ):  # visible and alive
                    ally_feats[i, 0] = 1  # visible
                    ally_feats[i, 1] = dist / sight_range  # distance
                    ally_feats[i, 2] = (al_x - x) / sight_range  # relative X
                    ally_feats[i, 3] = (al_y - y) / sight_range  # relative Y

                    ind = 4
                    if self.obs_all_health:
                        ally_feats[i, ind] = (
                            al_unit.health / al_unit.health_max
                        )  # health
                        ind += 1
                        if self.shield_bits_ally > 0:
                            max_shield = self.unit_max_shield(al_unit)
                            ally_feats[i, ind] = (
                                al_unit.shield / max_shield
                            )  # shield
                            ind += 1

                    if self.unit_type_bits > 0:
                        type_id = self.get_unit_type_id(al_unit, True)
                        ally_feats[i, ind + type_id] = 1
                        ind += self.unit_type_bits

                    if self.obs_last_action:
                        ally_feats[i, ind:] = self.last_action[al_id]

            # Own features
            ind = 0
            if self.obs_own_health:
                own_feats[ind] = unit.health / unit.health_max
                ind += 1
                if self.shield_bits_ally > 0:
                    max_shield = self.unit_max_shield(unit)
                    own_feats[ind] = unit.shield / max_shield
                    ind += 1

            if self.unit_type_bits > 0:
                type_id = self.get_unit_type_id(unit, True)
                own_feats[ind + type_id] = 1

        agent_obs = np.concatenate(
            (
                move_feats.flatten(),
                enemy_feats.flatten(),
                ally_feats.flatten(),
                own_feats.flatten(),
            )
        )

        if self.obs_timestep_number:
            agent_obs = np.append(
                agent_obs, self._episode_steps / self.episode_limit
            )

        if self.debug:
            logging.debug("Obs Agent: {}".format(agent_id).center(60, "-"))
            logging.debug(
                "Avail. actions {}".format(
                    self.get_avail_agent_actions(agent_id)
                )
            )
            logging.debug("Move feats {}".format(move_feats))
            logging.debug("Enemy feats {}".format(enemy_feats))
            logging.debug("Ally feats {}".format(ally_feats))
            logging.debug("Own feats {}".format(own_feats))

        return agent_obs
def get_obs(self):
"""Returns all agent observations in a list.
NOTE: Agents should have access only to their local observations
during decentralised execution.
"""
agents_obs = [self.get_obs_agent(i) for i in range(self.n_agents)]
return agents_obs # return a list of agent obs
def get_state(self):
"""Returns the global state.
NOTE: This functon should not be used during decentralised execution.
"""
if self.obs_instead_of_state:
obs_concat = np.concatenate(self.get_obs(), axis=0).astype(
np.float32
)
return obs_concat
state_dict = self.get_state_dict()
state = np.append(
state_dict["allies"].flatten(), state_dict["enemies"].flatten()
)
if "last_action" in state_dict:
state = np.append(state, state_dict["last_action"].flatten())
if "timestep" in state_dict:
state = np.append(state, state_dict["timestep"])
state = state.astype(dtype=np.float32)
if self.debug:
logging.debug("STATE".center(60, "-"))
logging.debug("Ally state {}".format(state_dict["allies"]))
logging.debug("Enemy state {}".format(state_dict["enemies"]))
if self.state_last_action:
logging.debug("Last actions {}".format(self.last_action))
return state
    def get_ally_num_attributes(self):
        """Returns the number of per-ally attributes in the global state."""
        return len(self.ally_state_attr_names)
    def get_enemy_num_attributes(self):
        """Returns the number of per-enemy attributes in the global state."""
        return len(self.enemy_state_attr_names)
    def get_state_dict(self):
        """Returns the global state as a dictionary.

        - allies: numpy array containing agents and their attributes
        - enemies: numpy array containing enemies and their attributes
        - last_action: numpy array of previous actions for each agent
        - timestep: current no. of steps divided by total no. of steps

        NOTE: This function should not be used during decentralised execution.
        """
        # number of features equals the number of attribute names
        nf_al = self.get_ally_num_attributes()
        nf_en = self.get_enemy_num_attributes()

        ally_state = np.zeros((self.n_agents, nf_al))
        enemy_state = np.zeros((self.n_enemies, nf_en))

        # Positions are stored relative to the map centre.
        center_x = self.map_x / 2
        center_y = self.map_y / 2

        for al_id, al_unit in self.agents.items():
            if al_unit.health > 0:
                x = al_unit.pos.x
                y = al_unit.pos.y
                max_cd = self.unit_max_cooldown(al_unit)

                ally_state[al_id, 0] = (
                    al_unit.health / al_unit.health_max
                )  # health
                if (
                    self.map_type == "MMM"
                    and al_unit.unit_type == self.medivac_id
                ):
                    # Medivacs expose energy instead of a weapon cooldown.
                    ally_state[al_id, 1] = al_unit.energy / max_cd  # energy
                else:
                    ally_state[al_id, 1] = (
                        al_unit.weapon_cooldown / max_cd
                    )  # cooldown
                ally_state[al_id, 2] = (
                    x - center_x
                ) / self.max_distance_x  # relative X
                ally_state[al_id, 3] = (
                    y - center_y
                ) / self.max_distance_y  # relative Y

                if self.shield_bits_ally > 0:
                    max_shield = self.unit_max_shield(al_unit)
                    ally_state[al_id, 4] = (
                        al_unit.shield / max_shield
                    )  # shield

                if self.unit_type_bits > 0:
                    type_id = self.get_unit_type_id(al_unit, True)
                    # One-hot unit type stored in the trailing columns.
                    ally_state[al_id, type_id - self.unit_type_bits] = 1

        for e_id, e_unit in self.enemies.items():
            if e_unit.health > 0:
                x = e_unit.pos.x
                y = e_unit.pos.y

                enemy_state[e_id, 0] = (
                    e_unit.health / e_unit.health_max
                )  # health
                enemy_state[e_id, 1] = (
                    x - center_x
                ) / self.max_distance_x  # relative X
                enemy_state[e_id, 2] = (
                    y - center_y
                ) / self.max_distance_y  # relative Y

                if self.shield_bits_enemy > 0:
                    max_shield = self.unit_max_shield(e_unit)
                    enemy_state[e_id, 3] = e_unit.shield / max_shield  # shield

                if self.unit_type_bits > 0:
                    type_id = self.get_unit_type_id(e_unit, False)
                    enemy_state[e_id, type_id - self.unit_type_bits] = 1

        state = {"allies": ally_state, "enemies": enemy_state}
        if self.state_last_action:
            state["last_action"] = self.last_action
        if self.state_timestep_number:
            state["timestep"] = self._episode_steps / self.episode_limit

        return state
def get_obs_enemy_feats_size(self):
"""Returns the dimensions of the matrix containing enemy features.
Size is n_enemies x n_features.
"""
nf_en = 4 + self.unit_type_bits
if self.obs_all_health:
nf_en += 1 + self.shield_bits_enemy
return self.n_enemies, nf_en
def get_obs_ally_feats_size(self):
"""Returns the dimensions of the matrix containing ally features.
Size is n_allies x n_features.
"""
nf_al = 4 + self.unit_type_bits
if self.obs_all_health:
nf_al += 1 + self.shield_bits_ally
if self.obs_last_action:
nf_al += self.n_actions
return self.n_agents - 1, nf_al
def get_obs_own_feats_size(self):
"""
Returns the size of the vector containing the agents' own features.
"""
own_feats = self.unit_type_bits
if self.obs_own_health:
own_feats += 1 + self.shield_bits_ally
if self.obs_timestep_number:
own_feats += 1
return own_feats
def get_obs_move_feats_size(self):
"""Returns the size of the vector containing the agents's movement-
related features.
"""
move_feats = self.n_actions_move
if self.obs_pathing_grid:
move_feats += self.n_obs_pathing
if self.obs_terrain_height:
move_feats += self.n_obs_height
return move_feats
def get_obs_size(self):
"""Returns the size of the observation."""
own_feats = self.get_obs_own_feats_size()
move_feats = self.get_obs_move_feats_size()
n_enemies, n_enemy_feats = self.get_obs_enemy_feats_size()
n_allies, n_ally_feats = self.get_obs_ally_feats_size()
enemy_feats = n_enemies * n_enemy_feats
ally_feats = n_allies * n_ally_feats
return move_feats + enemy_feats + ally_feats + own_feats
def get_state_size(self):
"""Returns the size of the global state."""
if self.obs_instead_of_state:
return self.get_obs_size() * self.n_agents
nf_al = 4 + self.shield_bits_ally + self.unit_type_bits
nf_en = 3 + self.shield_bits_enemy + self.unit_type_bits
enemy_state = self.n_enemies * nf_en
ally_state = self.n_agents * nf_al
size = enemy_state + ally_state
if self.state_last_action:
size += self.n_agents * self.n_actions
if self.state_timestep_number:
size += 1
return size
def get_visibility_matrix(self):
"""Returns a boolean numpy array of dimensions
(n_agents, n_agents + n_enemies) indicating which units
are visible to each agent.
"""
arr = np.zeros(
(self.n_agents, self.n_agents + self.n_enemies),
dtype=np.bool,
)
for agent_id in range(self.n_agents):
current_agent = self.get_unit_by_id(agent_id)
if current_agent.health > 0: # it agent not dead
x = current_agent.pos.x
y = current_agent.pos.y
sight_range = self.unit_sight_range(agent_id)
# Enemies
for e_id, e_unit in self.enemies.items():
e_x = e_unit.pos.x
e_y = e_unit.pos.y
dist = self.distance(x, y, e_x, e_y)
if dist < sight_range and e_unit.health > 0:
# visible and alive
arr[agent_id, self.n_agents + e_id] = 1
# The matrix for allies is filled symmetrically
al_ids = [
al_id for al_id in range(self.n_agents) if al_id > agent_id
]
for _, al_id in enumerate(al_ids):
al_unit = self.get_unit_by_id(al_id)
al_x = al_unit.pos.x
al_y = al_unit.pos.y
dist = self.distance(x, y, al_x, al_y)
if dist < sight_range and al_unit.health > 0:
# visible and alive
arr[agent_id, al_id] = arr[al_id, agent_id] = 1
return arr
def get_unit_type_id(self, unit, ally):
"""Returns the ID of unit type in the given scenario."""
if ally: # use new SC2 unit types
type_id = unit.unit_type - self._min_unit_type
else: # use default SC2 unit types
if self.map_type == "stalkers_and_zealots":
# id(Stalker) = 74, id(Zealot) = 73
type_id = unit.unit_type - 73
elif self.map_type == "colossi_stalkers_zealots":
# id(Stalker) = 74, id(Zealot) = 73, id(Colossus) = 4
if unit.unit_type == 4:
type_id = 0
elif unit.unit_type == 74:
type_id = 1
else:
type_id = 2
elif self.map_type == "bane":
if unit.unit_type == 9:
type_id = 0
else:
type_id = 1
elif self.map_type == "MMM":
if unit.unit_type == 51:
type_id = 0
elif unit.unit_type == 48:
type_id = 1
else:
type_id = 2
return type_id
    def get_avail_agent_actions(self, agent_id):
        """Returns the available actions for agent_id as a binary mask of
        length ``n_actions`` (1 = currently available).
        """
        unit = self.get_unit_by_id(agent_id)
        if unit.health > 0:
            # cannot choose no-op when alive
            avail_actions = [0] * self.n_actions

            # stop should be allowed
            avail_actions[1] = 1

            # see if we can move
            if self.can_move(unit, Direction.NORTH):
                avail_actions[2] = 1
            if self.can_move(unit, Direction.SOUTH):
                avail_actions[3] = 1
            if self.can_move(unit, Direction.EAST):
                avail_actions[4] = 1
            if self.can_move(unit, Direction.WEST):
                avail_actions[5] = 1

            # Can attack only alive units that are alive in the shooting range
            shoot_range = self.unit_shoot_range(agent_id)

            target_items = self.enemies.items()
            if self.map_type == "MMM" and unit.unit_type == self.medivac_id:
                # Medivacs cannot heal themselves or other flying units,
                # so their "attack" targets are non-Medivac allies instead.
                target_items = [
                    (t_id, t_unit)
                    for (t_id, t_unit) in self.agents.items()
                    if t_unit.unit_type != self.medivac_id
                ]

            for t_id, t_unit in target_items:
                if t_unit.health > 0:
                    dist = self.distance(
                        unit.pos.x, unit.pos.y, t_unit.pos.x, t_unit.pos.y
                    )
                    if dist <= shoot_range:
                        avail_actions[t_id + self.n_actions_no_attack] = 1

            return avail_actions

        else:
            # only no-op allowed
            return [1] + [0] * (self.n_actions - 1)
def get_avail_actions(self):
"""Returns the available actions of all agents in a list."""
avail_actions = []
for agent_id in range(self.n_agents):
avail_agent = self.get_avail_agent_actions(agent_id)
avail_actions.append(avail_agent)
return avail_actions
def close(self):
"""Close StarCraft II."""
if self.renderer is not None:
self.renderer.close()
self.renderer = None
if self._sc2_proc:
self._sc2_proc.close()
    def seed(self):
        """Returns the random seed used by the environment.

        Pure accessor: it does not reseed any RNG.
        """
        return self._seed
    def render(self, mode="human"):
        """Render the environment using the SC2 renderer.

        The renderer is created lazily on the first call; ``mode`` must be
        the same on every subsequent call.
        """
        if self.renderer is None:
            # Imported lazily so rendering dependencies are only required
            # when render() is actually used.
            from smac.env.starcraft2.render import StarCraft2Renderer

            self.renderer = StarCraft2Renderer(self, mode)
        assert (
            mode == self.renderer.mode
        ), "mode must be consistent across render calls"
        return self.renderer.render(mode)
def _kill_all_units(self):
"""Kill all units on the map."""
units_alive = [
unit.tag for unit in self.agents.values() if unit.health > 0
] + [unit.tag for unit in self.enemies.values() if unit.health > 0]
debug_command = [
d_pb.DebugCommand(kill_unit=d_pb.DebugKillUnit(tag=units_alive))
]
self._controller.debug(debug_command)
    def init_units(self):
        """Initialise the units.

        Repeatedly reads the raw SC2 observation until all expected ally
        and enemy units exist, stepping the controller between attempts;
        on protocol errors the whole environment is restarted.
        """
        while True:
            # Sometimes not all units have yet been created by SC2
            self.agents = {}
            self.enemies = {}

            # Units owned by player 1 become the controllable agents.
            ally_units = [
                unit
                for unit in self._obs.observation.raw_data.units
                if unit.owner == 1
            ]
            # Stable ordering so agent ids are deterministic across resets.
            ally_units_sorted = sorted(
                ally_units,
                key=attrgetter("unit_type", "pos.x", "pos.y"),
                reverse=False,
            )

            for i in range(len(ally_units_sorted)):
                self.agents[i] = ally_units_sorted[i]
                if self.debug:
                    logging.debug(
                        "Unit {} is {}, x = {}, y = {}".format(
                            len(self.agents),
                            self.agents[i].unit_type,
                            self.agents[i].pos.x,
                            self.agents[i].pos.y,
                        )
                    )

            # Units owned by player 2 are the enemies.
            for unit in self._obs.observation.raw_data.units:
                if unit.owner == 2:
                    self.enemies[len(self.enemies)] = unit
                    if self._episode_count == 0:
                        # Maximum attainable reward is the total enemy
                        # health + shield (computed once on first episode).
                        self.max_reward += unit.health_max + unit.shield_max

            if self._episode_count == 0:
                min_unit_type = min(
                    unit.unit_type for unit in self.agents.values()
                )
                self._init_ally_unit_types(min_unit_type)

            all_agents_created = len(self.agents) == self.n_agents
            all_enemies_created = len(self.enemies) == self.n_enemies

            self._unit_types = [
                unit.unit_type for unit in ally_units_sorted
            ] + [
                unit.unit_type
                for unit in self._obs.observation.raw_data.units
                if unit.owner == 2
            ]

            if all_agents_created and all_enemies_created:  # all good
                return

            # Not all units spawned yet: advance the simulation one step
            # and retry; restart the env on connection problems.
            try:
                self._controller.step(1)
                self._obs = self._controller.observe()
            except (protocol.ProtocolError, protocol.ConnectionError):
                self.full_restart()
                self.reset()
def get_unit_types(self):
if self._unit_types is None:
warn(
"unit types have not been initialized yet, please call"
"env.reset() to populate this and call t1286he method again."
)
return self._unit_types
    def update_units(self):
        """Update units after an environment step.

        This function assumes that self._obs is up-to-date. Returns 1 on
        a win, -1 on a loss, 0 when both sides are wiped out, and None
        while the episode is still undecided.
        """
        n_ally_alive = 0
        n_enemy_alive = 0

        # Store previous state
        self.previous_ally_units = deepcopy(self.agents)
        self.previous_enemy_units = deepcopy(self.enemies)

        for al_id, al_unit in self.agents.items():
            updated = False
            # A unit absent from the raw observation is considered dead.
            for unit in self._obs.observation.raw_data.units:
                if al_unit.tag == unit.tag:
                    self.agents[al_id] = unit
                    updated = True
                    n_ally_alive += 1
                    break

            if not updated:  # dead
                al_unit.health = 0

        for e_id, e_unit in self.enemies.items():
            updated = False
            for unit in self._obs.observation.raw_data.units:
                if e_unit.tag == unit.tag:
                    self.enemies[e_id] = unit
                    updated = True
                    n_enemy_alive += 1
                    break

            if not updated:  # dead
                e_unit.health = 0

        # NOTE: `a and b or c` groups as `(a and b) or c` below.
        if (
            n_ally_alive == 0
            and n_enemy_alive > 0
            or self.only_medivac_left(ally=True)
        ):
            return -1  # lost
        if (
            n_ally_alive > 0
            and n_enemy_alive == 0
            or self.only_medivac_left(ally=False)
        ):
            return 1  # won
        if n_ally_alive == 0 and n_enemy_alive == 0:
            return 0

        return None
def _init_ally_unit_types(self, min_unit_type):
"""Initialise ally unit types. Should be called once from the
init_units function.
"""
self._min_unit_type = min_unit_type
if self.map_type == "marines":
self.marine_id = min_unit_type
elif self.map_type == "stalkers_and_zealots":
self.stalker_id = min_unit_type
self.zealot_id = min_unit_type + 1
elif self.map_type == "colossi_stalkers_zealots":
self.colossus_id = min_unit_type
self.stalker_id = min_unit_type + 1
self.zealot_id = min_unit_type + 2
elif self.map_type == "MMM":
self.marauder_id = min_unit_type
self.marine_id = min_unit_type + 1
self.medivac_id = min_unit_type + 2
elif self.map_type == "zealots":
self.zealot_id = min_unit_type
elif self.map_type == "hydralisks":
self.hydralisk_id = min_unit_type
elif self.map_type == "stalkers":
self.stalker_id = min_unit_type
elif self.map_type == "colossus":
self.colossus_id = min_unit_type
elif self.map_type == "bane":
self.baneling_id = min_unit_type
self.zergling_id = min_unit_type + 1
def only_medivac_left(self, ally):
"""Check if only Medivac units are left."""
if self.map_type != "MMM":
return False
if ally:
units_alive = [
a
for a in self.agents.values()
if (a.health > 0 and a.unit_type != self.medivac_id)
]
if len(units_alive) == 0:
return True
return False
else:
units_alive = [
a
for a in self.enemies.values()
if (a.health > 0 and a.unit_type != self.medivac_id)
]
if len(units_alive) == 1 and units_alive[0].unit_type == 54:
return True
return False
    def get_unit_by_id(self, a_id):
        """Get ally unit by ID.

        Raises KeyError if ``a_id`` is not a current agent id.
        """
        return self.agents[a_id]
def get_stats(self):
stats = {
"battles_won": self.battles_won,
"battles_game": self.battles_game,
"battles_draw": self.timeouts,
"win_rate": self.battles_won / self.battles_game,
"timeouts": self.timeouts,
"restarts": self.force_restarts,
}
return stats
    def get_env_info(self):
        """Returns the base env info dict extended with the names of the
        per-ally and per-enemy state features.
        """
        env_info = super().get_env_info()
        env_info["agent_features"] = self.ally_state_attr_names
        env_info["enemy_features"] = self.enemy_state_attr_names
        return env_info
================================================
FILE: examples/Social_Cognition/ToCM/smac/examples/__init__.py
================================================
================================================
FILE: examples/Social_Cognition/ToCM/smac/examples/pettingzoo/README.rst
================================================
SMAC on PettingZoo
==================
This example shows how to run SMAC environments with PettingZoo multi-agent API.
Instructions
------------
To get started, first install PettingZoo with ``pip install pettingzoo``.
The SMAC environment for PettingZoo, ``StarCraft2PZEnv``, can be initialized with two different API templates.
* **AEC**: PettingZoo is based on the *Agent Environment Cycle* game model; more information about "AEC" can be read in the following `paper <https://arxiv.org/abs/2009.14471>`_. To create a SMAC environment as an "AEC" PettingZoo game model use: ::
from smac.env.pettingzoo import StarCraft2PZEnv
env = StarCraft2PZEnv.env()
* **Parallel**: PettingZoo also supports parallel environments where all agents have simultaneous actions and observations. This type of environment can be created as follows: ::
from smac.env.pettingzoo import StarCraft2PZEnv
env = StarCraft2PZEnv.parallel_env()
``pettingzoo_demo.py`` has an example of a SMAC environment being used as a PettingZoo "AEC" environment. With ``env.render()`` it is possible to display an instance of the environment as a frame in pygame, which is useful for debugging purposes.
| See https://www.pettingzoo.ml/api for more documentation.
================================================
FILE: examples/Social_Cognition/ToCM/smac/examples/pettingzoo/__init__.py
================================================
================================================
FILE: examples/Social_Cognition/ToCM/smac/examples/pettingzoo/pettingzoo_demo.py
================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import random
import numpy as np
from smac.env.pettingzoo import StarCraft2PZEnv
def main():
    """Run the SMAC PettingZoo AEC environment with random actions."""
    env = StarCraft2PZEnv.env()
    n_episodes = 10
    reward_sum = 0

    for _ in range(n_episodes):
        env.reset()
        for agent in env.agent_iter():
            env.render()
            obs, reward, done, _ = env.last()
            reward_sum += reward
            if done:
                # Finished agents must step with a None action.
                action = None
            elif isinstance(obs, dict) and "action_mask" in obs:
                # Sample only among the currently legal actions.
                action = random.choice(np.flatnonzero(obs["action_mask"]))
            else:
                action = env.action_spaces[agent].sample()
            env.step(action)

    env.close()
    print("Average total reward", reward_sum / n_episodes)


if __name__ == "__main__":
    main()
================================================
FILE: examples/Social_Cognition/ToCM/smac/examples/random_agents.py
================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from smac.env import StarCraft2Env
import numpy as np
def main():
    """Run random agents on the SMAC 8m map for a few episodes."""
    env = StarCraft2Env(map_name="8m")
    env_info = env.get_env_info()

    n_actions = env_info["n_actions"]
    n_agents = env_info["n_agents"]
    n_episodes = 10

    for e in range(n_episodes):
        env.reset()
        terminated = False
        episode_reward = 0

        while not terminated:
            obs = env.get_obs()
            state = env.get_state()
            # env.render() # Uncomment for rendering

            # Each agent samples uniformly among its legal actions.
            actions = []
            for agent_id in range(n_agents):
                avail = env.get_avail_agent_actions(agent_id)
                legal = np.nonzero(avail)[0]
                actions.append(np.random.choice(legal))

            reward, terminated, _ = env.step(actions)
            episode_reward += reward

        print("Total reward in episode {} = {}".format(e, episode_reward))

    env.close()


if __name__ == "__main__":
    main()
================================================
FILE: examples/Social_Cognition/ToCM/smac/examples/rllib/README.rst
================================================
SMAC on RLlib
=============
This example shows how to run SMAC environments with RLlib multi-agent.
Instructions
------------
To get started, first install RLlib with ``pip install -U ray[rllib]``. You will also need TensorFlow installed.
In ``run_ppo.py``, each agent will be controlled by an independent PPO policy (the policies share weights). This setup serves as a single-agent baseline for this task.
In ``run_qmix.py``, the agents are controlled by the multi-agent QMIX policy. This setup is an example of centralized training and decentralized execution.
See https://ray.readthedocs.io/en/latest/rllib.html for more documentation.
================================================
FILE: examples/Social_Cognition/ToCM/smac/examples/rllib/__init__.py
================================================
from smac.examples.rllib.env import RLlibStarCraft2Env
from smac.examples.rllib.model import MaskedActionsModel
__all__ = ["RLlibStarCraft2Env", "MaskedActionsModel"]
================================================
FILE: examples/Social_Cognition/ToCM/smac/examples/rllib/env.py
================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import random
import numpy as np
from gym.spaces import Discrete, Box, Dict
from ray import rllib
from smac.env import StarCraft2Env
class RLlibStarCraft2Env(rllib.MultiAgentEnv):
    """Wraps a smac StarCraft env to be compatible with RLlib multi-agent."""

    def __init__(self, **smac_args):
        """Create a new multi-agent StarCraft env compatible with RLlib.

        Arguments:
            smac_args (dict): Arguments to pass to the underlying
                smac.env.starcraft.StarCraft2Env instance.

        Examples:
            >>> from smac.examples.rllib import RLlibStarCraft2Env
            >>> env = RLlibStarCraft2Env(map_name="8m")
            >>> print(env.reset())
        """
        self._env = StarCraft2Env(**smac_args)
        # Agents that produced an observation last step and therefore
        # must supply an action in the next call to step().
        self._ready_agents = []
        # Observations carry an action mask so the policy can avoid
        # selecting currently unavailable actions.
        self.observation_space = Dict(
            {
                "obs": Box(-1, 1, shape=(self._env.get_obs_size(),)),
                "action_mask": Box(
                    0, 1, shape=(self._env.get_total_actions(),)
                ),
            }
        )
        self.action_space = Discrete(self._env.get_total_actions())

    def reset(self):
        """Resets the env and returns observations from ready agents.

        Returns:
            obs (dict): New observations for each ready agent.
        """
        obs_list, state_list = self._env.reset()
        return_obs = {}
        for i, obs in enumerate(obs_list):
            return_obs[i] = {
                "action_mask": np.array(self._env.get_avail_agent_actions(i)),
                "obs": obs,
            }

        self._ready_agents = list(range(len(obs_list)))
        return return_obs

    def step(self, action_dict):
        """Returns observations from ready agents.

        The returns are dicts mapping from agent_id strings to values. The
        number of agents in the env can vary over time.

        Returns
        -------
            obs (dict): New observations for each ready agent.
            rewards (dict): Reward values for each ready agent. If the
                episode is just started, the value will be None.
            dones (dict): Done values for each ready agent. The special key
                "__all__" (required) is used to indicate env termination.
            infos (dict): Optional info values for each agent id.
        """
        actions = []
        for i in self._ready_agents:
            if i not in action_dict:
                raise ValueError(
                    "You must supply an action for agent: {}".format(i)
                )
            actions.append(action_dict[i])

        if len(actions) != len(self._ready_agents):
            raise ValueError(
                "Unexpected number of actions: {}".format(
                    action_dict,
                )
            )

        # SMAC returns one team-level reward; it is split evenly across
        # the agents below.
        rew, done, info = self._env.step(actions)
        obs_list = self._env.get_obs()
        return_obs = {}
        for i, obs in enumerate(obs_list):
            return_obs[i] = {
                "action_mask": self._env.get_avail_agent_actions(i),
                "obs": obs,
            }

        rews = {i: rew / len(obs_list) for i in range(len(obs_list))}
        dones = {i: done for i in range(len(obs_list))}
        dones["__all__"] = done
        infos = {i: info for i in range(len(obs_list))}

        self._ready_agents = list(range(len(obs_list)))
        return return_obs, rews, dones, infos

    def close(self):
        """Close the environment"""
        self._env.close()

    def seed(self, seed):
        # Seeds only Python's and NumPy's global RNGs; the underlying
        # SC2 environment is not reseeded here.
        random.seed(seed)
        np.random.seed(seed)
================================================
FILE: examples/Social_Cognition/ToCM/smac/examples/rllib/model.py
================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from ray.rllib.models import Model
from ray.rllib.models.tf.misc import normc_initializer
class MaskedActionsModel(Model):
    """Custom RLlib model that emits -inf logits for invalid actions.

    This is used to handle the variable-length StarCraft action space.
    """

    def _build_layers_v2(self, input_dict, num_outputs, options):
        """Build a TF1-style fully connected network with masked outputs.

        :param input_dict: dict whose "obs" entry holds both the raw agent
            observation ("obs") and the binary availability mask
            ("action_mask").
        :param num_outputs: (int) number of action logits; must equal the
            mask width (the max number of available actions).
        :param options: (dict) model options; "fcnet_hiddens" lists the
            hidden layer sizes.
        :return: (masked_logits, last_layer) per the RLlib Model API.
        """
        action_mask = input_dict["obs"]["action_mask"]
        if num_outputs != action_mask.shape[1].value:
            raise ValueError(
                "This model assumes num outputs is equal to max avail actions",
                num_outputs,
                action_mask,
            )

        # Standard fully connected network
        last_layer = input_dict["obs"]["obs"]
        hiddens = options.get("fcnet_hiddens")
        for i, size in enumerate(hiddens):
            label = "fc{}".format(i)
            last_layer = tf.layers.dense(
                last_layer,
                size,
                kernel_initializer=normc_initializer(1.0),
                activation=tf.nn.tanh,
                name=label,
            )
        action_logits = tf.layers.dense(
            last_layer,
            num_outputs,
            kernel_initializer=normc_initializer(0.01),
            activation=None,
            name="fc_out",
        )

        # Mask out invalid actions: log(0) = -inf for unavailable actions,
        # clipped to tf.float32.min for numerical stability. Adding the mask
        # drives those logits to (effectively) -inf before the softmax.
        inf_mask = tf.maximum(tf.log(action_mask), tf.float32.min)
        masked_logits = inf_mask + action_logits
        return masked_logits, last_layer
================================================
FILE: examples/Social_Cognition/ToCM/smac/examples/rllib/run_ppo.py
================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
"""Example of running StarCraft2 with RLlib PPO.
In this setup, each agent will be controlled by an independent PPO policy.
However the policies share weights.
Increase the level of parallelism by changing --num-workers.
"""
import argparse
import ray
from ray.tune import run_experiments, register_env
from ray.rllib.models import ModelCatalog
from smac.examples.rllib.env import RLlibStarCraft2Env
from smac.examples.rllib.model import MaskedActionsModel
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-iters", type=int, default=100)
    parser.add_argument("--num-workers", type=int, default=2)
    parser.add_argument("--map-name", type=str, default="8m")
    args = parser.parse_args()

    ray.init()

    # The env factory and the action-masking model must both be registered
    # before tune resolves the experiment spec below.
    register_env("smac", lambda smac_args: RLlibStarCraft2Env(**smac_args))
    ModelCatalog.register_custom_model("mask_model", MaskedActionsModel)

    experiment_spec = {
        "ppo_sc2": {
            "run": "PPO",
            "env": "smac",
            "stop": {"training_iteration": args.num_iters},
            "config": {
                "num_workers": args.num_workers,
                "observation_filter": "NoFilter",  # breaks the action mask
                "vf_share_layers": True,  # no separate value model
                "env_config": {"map_name": args.map_name},
                "model": {"custom_model": "mask_model"},
            },
        },
    }
    run_experiments(experiment_spec)
================================================
FILE: examples/Social_Cognition/ToCM/smac/examples/rllib/run_qmix.py
================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
"""Example of running StarCraft2 with RLlib QMIX.
This assumes all agents are homogeneous. The agents are grouped and assigned
to the multi-agent QMIX policy. Note that the default hyperparameters for
RLlib QMIX are different from pymarl's QMIX.
"""
import argparse
from gym.spaces import Tuple
import ray
from ray.tune import run_experiments, register_env
from smac.examples.rllib.env import RLlibStarCraft2Env
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-iters", type=int, default=100)
    parser.add_argument("--num-workers", type=int, default=2)
    parser.add_argument("--map-name", type=str, default="8m")
    args = parser.parse_args()

    def env_creator(smac_args):
        """Wrap all (homogeneous) agents into one QMIX agent group."""
        env = RLlibStarCraft2Env(**smac_args)
        agent_list = list(range(env._env.n_agents))
        grouping = {"group_1": agent_list}
        obs_space = Tuple([env.observation_space for _ in agent_list])
        act_space = Tuple([env.action_space for _ in agent_list])
        return env.with_agent_groups(
            grouping, obs_space=obs_space, act_space=act_space
        )

    ray.init()
    register_env("sc2_grouped", env_creator)

    experiment_spec = {
        "qmix_sc2": {
            "run": "QMIX",
            "env": "sc2_grouped",
            "stop": {"training_iteration": args.num_iters},
            "config": {
                "num_workers": args.num_workers,
                "env_config": {"map_name": args.map_name},
            },
        },
    }
    run_experiments(experiment_spec)
================================================
FILE: examples/Social_Cognition/ToCM/train.py
================================================
import argparse
import os
import sys
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
from agent.runners.ToCMRunner import ToCMRunner
from configs import Experiment, SimpleObservationConfig, NearRewardConfig, DeadlockPunishmentConfig, \
RewardsComposerConfig
from configs.EnvConfigs import StarCraftConfig, EnvCurriculumConfig, MPEConfig # TODO
from configs.ToCM.ToCMControllerConfig import ToCMControllerConfig
from configs.ToCM.ToCMLearnerConfig import ToCMLearnerConfig
from environments import Env
from utils.util import get_dim_from_space, get_cent_act_dim
import torch
import random
import numpy as np
import setproctitle
setproctitle.setproctitle("MPE_obs_2_hetero")
def occumpy_mem(cuda_device):
    """Pre-allocate ~85% of the free memory on the given CUDA device.

    Parses `nvidia-smi` output to get total/used memory (in MiB) for the
    device, then allocates and immediately frees a large CUDA tensor so the
    torch caching allocator reserves that space up front.

    :param cuda_device: (int or str) row index of the GPU in nvidia-smi output.
    """
    # nvidia-smi prints one "total, used" line per GPU; pick this device's row.
    total, used = os.popen(
        '"/usr/bin/nvidia-smi" --query-gpu=memory.total,memory.used --format=csv,nounits,noheader').read().strip().split(
        "\n")[int(cuda_device)].split(',')
    total = int(total)
    used = int(used)
    cc = 0.85  # fraction of the free memory to grab
    block_mem = int((total - used) * cc)
    # 256 * 1024 float32 elements = 1 MiB, so this tensor is ~block_mem MiB.
    x = torch.cuda.FloatTensor(256, 1024, block_mem)
    del x
def parse_args():
    """Build and parse the command-line arguments for a training run."""
    ap = argparse.ArgumentParser()
    # Environment family and scenario.
    ap.add_argument('--env', type=str, default="mpe", help='starcraft or mpe')
    ap.add_argument('--env_name', type=str, default="hetero_spread", help='Specific setting')
    # StarCraft maps: 2s_vs_1sc MMM 2s3z 3s_vs_3z 3s5z_vs_3s6z simple_spread
    ap.add_argument('--n_workers', type=int, default=4, help='Number of workers')
    ap.add_argument('--device', type=str, default='cuda', help='device')
    ap.add_argument('--seed', type=int, default=50, help='random')
    # MPE-specific scenario knobs (num_adversaries/num_good_agents are used
    # by scenarios such as simple_adversary).
    ap.add_argument('--num_agents', type=int, default=2, help='mpe_num_agents')
    ap.add_argument('--num_adversaries', type=int, default=None, help='mpe_num_adversaries')
    ap.add_argument('--num_good_agents', type=int, default=None, help='mpe_num_good_agents')
    ap.add_argument('--num_landmarks', type=int, default=2, help='mpe_num_landmarks')
    ap.add_argument('--episode_length', type=int, default=25, help='mpe_episode_length')
    ap.add_argument('--num_rollout_threads', type=int, default=128, help='mpe_episode_length')
    ap.add_argument('--benchmark', type=bool, default=False, help='mpe_use_benchmark')
    return ap.parse_args()
def setup_seed(seed):
    """Seed every RNG in use (hash, stdlib, NumPy, torch CPU and CUDA) and
    force deterministic cuDNN behavior so runs are reproducible."""
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def train_ToCM(exp, n_workers):
    """Build a ToCM runner from the experiment's configs and train it.

    :param exp: (Experiment) bundle of env/learner/controller configs plus
        the step and episode budgets.
    :param n_workers: (int) number of parallel rollout workers.
    """
    runner = ToCMRunner(exp.env_config, exp.learner_config, exp.controller_config, n_workers)
    runner.run(exp.steps, exp.episodes)  # e.g. steps=10**10, episodes=50000
def get_env_info(configs, env):
    """Copy the env's observation/action dimensions onto each config, then
    close the throwaway env instance.

    :param configs: iterable of config objects; IN_DIM and ACTION_SIZE are
        set on each (e.g. 17 and 7 for 2s_vs_1sc).
    :param env: environment exposing n_obs, n_actions and close().
    """
    n_obs, n_actions = env.n_obs, env.n_actions
    for cfg in configs:
        cfg.IN_DIM = n_obs
        cfg.ACTION_SIZE = n_actions
    env.close()
def get_env_info_mpe(configs, env):
    """Copy per-agent and centralized obs/action dimensions from an MPE env
    onto each config (ToCM controller and learner), then close the env.

    The centralized dims come from the shared observation/action spaces, so
    CENT_OBS_DIM / CENT_ACT_DIM are num_agents times the per-agent sizes
    (e.g. 54 = 3 * 18 and 15 = 3 * 5).
    """
    cent_obs_dim = get_dim_from_space(env.env.share_observation_space[0])
    cent_act_dim = get_cent_act_dim(env.env.action_space)
    in_dim = get_dim_from_space(env.env.observation_space[0])
    action_size = get_dim_from_space(env.env.action_space[0])
    for cfg in configs:
        cfg.CENT_OBS_DIM = cent_obs_dim
        cfg.CENT_ACT_DIM = cent_act_dim
        cfg.IN_DIM = in_dim
        cfg.ACTION_SIZE = action_size
    env.close()
def prepare_starcraft_configs(env_name, device):
    """Build the controller/learner/env config bundle for a StarCraft run.

    Relies on the module-level RANDOM_SEED global, which __main__ sets from
    the --seed argument before calling this.

    :param env_name: (str) SMAC map name, e.g. '3s5z_vs_3s6z'.
    :param device: (str) torch device string, e.g. 'cuda:6'.
    :return: (dict) config bundle consumed by __main__. The env entry is a
        (config, 100) tuple — presumably a curriculum weight/count; TODO
        confirm against EnvCurriculumConfig.
    """
    agent_configs = [ToCMControllerConfig(env_name, RANDOM_SEED, device),
                     ToCMLearnerConfig(env_name, RANDOM_SEED, device)]
    env_config = StarCraftConfig(env_name, RANDOM_SEED)
    # Probe a throwaway env instance to fill IN_DIM / ACTION_SIZE on both
    # agent configs (get_env_info closes the env afterwards).
    get_env_info(agent_configs, env_config.create_env())
    return {"env_config": (env_config, 100),
            "controller_config": agent_configs[0],
            "learner_config": agent_configs[1],
            "reward_config": None,
            "obs_builder_config": None}
def prepare_mpe_configs(arg):
    """Build the controller/learner/env config bundle for an MPE run.

    Relies on the module-level RANDOM_SEED global, which __main__ sets from
    the --seed argument before calling this.

    :param arg: (argparse.Namespace) full CLI args; env_name and device are
        read here, the rest is forwarded to MPEConfig.
    :return: (dict) config bundle consumed by __main__. reward_config and
        obs_builder_config are None — unlike StarCraft, it is unclear whether
        MPE ever uses them; TODO confirm.
    """
    agent_configs = [ToCMControllerConfig(arg.env_name, RANDOM_SEED, arg.device),
                     ToCMLearnerConfig(arg.env_name, RANDOM_SEED, arg.device)]
    env_config = MPEConfig(arg)
    # Probe a throwaway env to fill per-agent and centralized dims.
    get_env_info_mpe(agent_configs, env_config.create_env())
    return {"env_config": (env_config, 100),
            "controller_config": agent_configs[0],
            "learner_config": agent_configs[1],
            "reward_config": None,
            "obs_builder_config": None}
if __name__ == "__main__":
    # occumpy_mem(2)
    args = parse_args()
    # RANDOM_SEED is a module-level global: prepare_*_configs read it.
    RANDOM_SEED = args.seed
    setup_seed(RANDOM_SEED)
    # Dispatch on the environment family chosen via --env.
    if args.env == Env.STARCRAFT:
        configs = prepare_starcraft_configs(args.env_name, args.device)
    elif args.env == Env.MPE:
        configs = prepare_mpe_configs(args)
        # e.g. env is mpe with env_name simple_adversary
    else:
        raise Exception("Unknown environment")
    # Stamp the Env enum value onto every config so each component knows
    # which environment family it is running.
    configs["env_config"][0].ENV_TYPE = Env(args.env)
    configs["learner_config"].ENV_TYPE = Env(args.env)
    configs["controller_config"].ENV_TYPE = Env(args.env)
    exp = Experiment(steps=10 ** 10,
                     episodes=50000,
                     random_seed=RANDOM_SEED,
                     env_config=EnvCurriculumConfig(*zip(configs["env_config"]), Env(args.env), args.device,
                                                    obs_builder_config=configs["obs_builder_config"],
                                                    reward_config=configs["reward_config"]),
                     controller_config=configs["controller_config"],
                     learner_config=configs["learner_config"])
    train_ToCM(exp, n_workers=args.n_workers)
================================================
FILE: examples/Social_Cognition/ToCM/utils/__init__.py
================================================
================================================
FILE: examples/Social_Cognition/ToCM/utils/mlp_buffer.py
================================================
import numpy as np
from utils.util import get_dim_from_space
from utils.segment_tree import SumSegmentTree, MinSegmentTree
def _cast(x):
return x.transpose(1, 0, 2)
class MlpReplayBuffer(object):
    def __init__(self, policy_info, policy_agents, buffer_size, use_same_share_obs, use_avail_acts,
                 use_reward_normalization=False):
        """
        Replay buffer class for training MLP policies.

        :param policy_info: (dict) maps policy id to a dict containing information about corresponding policy.
        :param policy_agents: (dict) maps policy id to list of agents controled by corresponding policy.
        :param buffer_size: (int) max number of transitions to store in the buffer.
        :param use_same_share_obs: (bool) whether all agents share the same centralized observation.
        :param use_avail_acts: (bool) whether to store what actions are available.
        :param use_reward_normalization: (bool) whether to use reward normalization.
        """
        self.policy_info = policy_info
        # One independent ring buffer per policy id; all of them are written
        # in lockstep by insert().
        self.policy_buffers = {p_id: MlpPolicyBuffer(buffer_size,
                                                     len(policy_agents[p_id]),
                                                     self.policy_info[p_id]['obs_space'],
                                                     self.policy_info[p_id]['share_obs_space'],
                                                     self.policy_info[p_id]['act_space'],
                                                     use_same_share_obs,
                                                     use_avail_acts,
                                                     use_reward_normalization)
                               for p_id in self.policy_info.keys()}

    def __len__(self):
        # All per-policy buffers fill in lockstep, so the fill level of
        # 'policy_0' stands in for the whole buffer (assumes a policy with
        # that id always exists — TODO confirm for all configs).
        return self.policy_buffers['policy_0'].filled_i

    def insert(self, num_insert_steps, obs, share_obs, acts, rewards,
               next_obs, next_share_obs, dones, dones_env, valid_transition,
               avail_acts, next_avail_acts):
        """
        Insert a set of transitions into buffer. If the buffer size overflows, old transitions are dropped.

        :param num_insert_steps: (int) number of transitions to be added to buffer
        :param obs: (dict) maps policy id to numpy array of observations of agents corresponding to that policy
        :param share_obs: (dict) maps policy id to numpy array of centralized observation corresponding to that policy
        :param acts: (dict) maps policy id to numpy array of actions of agents corresponding to that policy
        :param rewards: (dict) maps policy id to numpy array of rewards of agents corresponding to that policy
        :param next_obs: (dict) maps policy id to numpy array of next step observations of agents corresponding to that policy
        :param next_share_obs: (dict) maps policy id to numpy array of next step centralized observations corresponding to that policy
        :param dones: (dict) maps policy id to numpy array of terminal status of agents corresponding to that policy
        :param dones_env: (dict) maps policy id to numpy array of terminal status of env
        :param valid_transition: (dict) maps policy id to numpy array of whether the corresponding transition is valid of agents corresponding to that policy
        :param avail_acts: (dict) maps policy id to numpy array of available actions of agents corresponding to that policy
        :param next_avail_acts: (dict) maps policy id to numpy array of next step available actions of agents corresponding to that policy
        :return: (np.ndarray) indexes in which the new transitions were placed.
        """
        # Each per-policy buffer receives the same number of steps, so the
        # index range returned by the last one is valid for all of them.
        idx_range = None
        for p_id in self.policy_info.keys():
            idx_range = self.policy_buffers[p_id].insert(num_insert_steps,
                                                         np.array(obs[p_id]), np.array(share_obs[p_id]),
                                                         np.array(acts[p_id]), np.array(rewards[p_id]),
                                                         np.array(next_obs[p_id]), np.array(next_share_obs[p_id]),
                                                         np.array(dones[p_id]), np.array(dones_env[p_id]),
                                                         np.array(valid_transition[p_id]),
                                                         np.array(avail_acts[p_id]), np.array(next_avail_acts[p_id]))
        return idx_range

    def sample(self, batch_size):
        """
        Sample a set of transitions from buffer, uniformly at random.

        :param batch_size: (int) number of transitions to sample from buffer.
        :return: obs: (dict) maps policy id to sampled observations corresponding to that policy
        :return: share_obs: (dict) maps policy id to sampled observations corresponding to that policy
        :return: acts: (dict) maps policy id to sampled actions corresponding to that policy
        :return: rewards: (dict) maps policy id to sampled rewards corresponding to that policy
        :return: next_obs: (dict) maps policy id to sampled next step observations corresponding to that policy
        :return: next_share_obs: (dict) maps policy id to sampled next step centralized observations corresponding to that policy
        :return: dones: (dict) maps policy id to sampled terminal status of agents corresponding to that policy
        :return: dones_env: (dict) maps policy id to sampled environment terminal status corresponding to that policy
        :return: valid_transition: (dict) maps policy_id to whether each sampled transition is valid or not (invalid if corresponding agent is dead)
        :return: avail_acts: (dict) maps policy_id to available actions corresponding to that policy
        :return: next_avail_acts: (dict) maps policy_id to next step available actions corresponding to that policy
        """
        # The same indices are used for every policy so transitions stay
        # aligned across policies.
        inds = np.random.choice(len(self), batch_size)
        obs, share_obs, acts, rewards, next_obs, next_share_obs, dones, dones_env, valid_transition, avail_acts, next_avail_acts = {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}
        for p_id in self.policy_info.keys():
            obs[p_id], share_obs[p_id], acts[p_id], rewards[p_id], next_obs[p_id], next_share_obs[p_id], \
            dones[p_id], dones_env[p_id], valid_transition[p_id], avail_acts[p_id], next_avail_acts[p_id] = \
                self.policy_buffers[p_id].sample_inds(inds)
        # Trailing Nones mirror the (weights, idxes) slots returned by
        # PrioritizedMlpReplayBuffer.sample so callers share one signature.
        return obs, share_obs, acts, rewards, next_obs, next_share_obs, dones, dones_env, valid_transition, avail_acts, next_avail_acts, None, None
class MlpPolicyBuffer(object):
    def __init__(self, buffer_size, num_agents, obs_space, share_obs_space, act_space, use_same_share_obs,
                 use_avail_acts, use_reward_normalization=False):
        """
        Buffer class containing buffer data corresponding to a single policy.

        :param buffer_size: (int) max number of transitions to store in buffer.
        :param num_agents: (int) number of agents controlled by the policy.
        :param obs_space: (gym.Space) observation space of the environment.
        :param share_obs_space: (gym.Space) centralized observation space of the environment.
        :param act_space: (gym.Space) action space of the environment.
        :param use_same_share_obs: (bool) whether all agents share the same centralized observation.
        :param use_avail_acts: (bool) whether to store what actions are available.
        :param use_reward_normalization: (bool) whether to use reward normalization.
        """
        self.buffer_size = buffer_size
        self.num_agents = num_agents
        self.use_same_share_obs = use_same_share_obs
        self.use_avail_acts = use_avail_acts
        self.use_reward_normalization = use_reward_normalization
        # filled_i: number of valid transitions stored so far (caps at
        # buffer_size); current_i: ring-buffer write cursor.
        self.filled_i = 0
        self.current_i = 0
        # obs: spaces may be gym Boxes or plain Python lists of dims.
        if obs_space.__class__.__name__ == 'Box':
            obs_shape = obs_space.shape
            share_obs_shape = share_obs_space.shape
        elif obs_space.__class__.__name__ == 'list':
            obs_shape = obs_space
            share_obs_shape = share_obs_space
        else:
            raise NotImplementedError
        self.obs = np.zeros(
            (self.buffer_size, self.num_agents, obs_shape[0]), dtype=np.float32)
        # With a shared centralized observation only one copy per step is
        # stored; otherwise one per agent.
        if self.use_same_share_obs:
            self.share_obs = np.zeros((self.buffer_size, share_obs_shape[0]), dtype=np.float32)
        else:
            self.share_obs = np.zeros((self.buffer_size, self.num_agents, share_obs_shape[0]), dtype=np.float32)
        self.next_obs = np.zeros_like(self.obs, dtype=np.float32)
        self.next_share_obs = np.zeros_like(self.share_obs, dtype=np.float32)
        # action
        act_dim = np.sum(get_dim_from_space(act_space))
        self.acts = np.zeros((self.buffer_size, self.num_agents, act_dim), dtype=np.float32)
        if self.use_avail_acts:
            # All actions are considered available until real masks arrive.
            self.avail_acts = np.ones_like(self.acts, dtype=np.float32)
            self.next_avail_acts = np.ones_like(self.avail_acts, dtype=np.float32)
        # rewards
        self.rewards = np.zeros((self.buffer_size, self.num_agents, 1), dtype=np.float32)
        # default to done being True
        self.dones = np.ones_like(self.rewards, dtype=np.float32)
        self.dones_env = np.ones((self.buffer_size, 1), dtype=np.float32)
        self.valid_transition = np.zeros_like(self.dones, dtype=np.float32)

    def __len__(self):
        return self.filled_i

    def insert(self, num_insert_steps, obs, share_obs, acts, rewards,
               next_obs, next_share_obs, dones, dones_env, valid_transition,
               avail_acts=None, next_avail_acts=None):
        """
        Insert a set of transitions corresponding to this policy into buffer. If the buffer size overflows, old transitions are dropped.

        :param num_insert_steps: (int) number of transitions to be added to buffer
        :param obs: (np.ndarray) observations of agents corresponding to this policy.
        :param share_obs: (np.ndarray) centralized observations of agents corresponding to this policy.
        :param acts: (np.ndarray) actions of agents corresponding to this policy.
        :param rewards: (np.ndarray) rewards of agents corresponding to this policy.
        :param next_obs: (np.ndarray) next step observations of agents corresponding to this policy.
        :param next_share_obs: (np.ndarray) next step centralized observations of agents corresponding to this policy.
        :param dones: (np.ndarray) terminal status of agents corresponding to this policy.
        :param dones_env: (np.ndarray) environment terminal status.
        :param valid_transition: (np.ndarray) whether each transition is valid or not (invalid if agent was dead during transition)
        :param avail_acts: (np.ndarray) available actions of agents corresponding to this policy.
        :param next_avail_acts: (np.ndarray) next step available actions of agents corresponding to this policy.
        :return: (np.ndarray) indexes of the buffer the new transitions were placed in.
        """
        # obs: [step, episode, agent, dim]
        assert obs.shape[0] == num_insert_steps, ("different size!")
        # Compute destination indices, wrapping around the end of the ring
        # buffer when the write would overflow it.
        if self.current_i + num_insert_steps <= self.buffer_size:
            idx_range = np.arange(self.current_i, self.current_i + num_insert_steps)
        else:
            num_left_steps = self.current_i + num_insert_steps - self.buffer_size
            idx_range = np.concatenate((np.arange(self.current_i, self.buffer_size), np.arange(num_left_steps)))
        self.obs[idx_range] = obs.copy()
        self.share_obs[idx_range] = share_obs.copy()
        self.acts[idx_range] = acts.copy()
        self.rewards[idx_range] = rewards.copy()
        self.next_obs[idx_range] = next_obs.copy()
        self.next_share_obs[idx_range] = next_share_obs.copy()
        self.dones[idx_range] = dones.copy()
        self.dones_env[idx_range] = dones_env.copy()
        self.valid_transition[idx_range] = valid_transition.copy()
        if self.use_avail_acts:
            self.avail_acts[idx_range] = avail_acts.copy()
            self.next_avail_acts[idx_range] = next_avail_acts.copy()
        # Advance the cursor to just past the last written slot (handles
        # wrap-around, since idx_range[-1] is the post-wrap index) and cap
        # the fill level at the buffer size.
        self.current_i = idx_range[-1] + 1
        self.filled_i = min(self.filled_i + len(idx_range), self.buffer_size)
        return idx_range

    def sample_inds(self, sample_inds):
        """
        Sample a set of transitions from buffer from the specified indices.

        :param sample_inds: (np.ndarray) indices of samples to return from buffer.
        :return: obs: (np.ndarray) sampled observations corresponding to that policy
        :return: share_obs: (np.ndarray) sampled observations corresponding to that policy
        :return: acts: (np.ndarray) sampled actions corresponding to that policy
        :return: rewards: (np.ndarray) sampled rewards corresponding to that policy
        :return: next_obs: (np.ndarray) sampled next step observations corresponding to that policy
        :return: next_share_obs: (np.ndarray) sampled next step centralized observations corresponding to that policy
        :return: dones: (np.ndarray) sampled terminal status of agents corresponding to that policy
        :return: dones_env: (np.ndarray) sampled environment terminal status corresponding to that policy
        :return: valid_transition: (np.ndarray) whether each sampled transition is valid or not (invalid if corresponding agent is dead)
        :return: avail_acts: (np.ndarray) sampled available actions corresponding to that policy
        :return: next_avail_acts: (np.ndarray) sampled next step available actions corresponding to that policy
        """
        # _cast swaps the leading axes: [batch, agent, dim] -> [agent, batch, dim].
        obs = _cast(self.obs[sample_inds])
        acts = _cast(self.acts[sample_inds])
        if self.use_reward_normalization:
            # Normalize using the statistics of the *filled* portion only.
            mean_reward = self.rewards[:self.filled_i].mean()
            std_reward = self.rewards[:self.filled_i].std()
            rewards = _cast(
                (self.rewards[sample_inds] - mean_reward) / std_reward)
        else:
            rewards = _cast(self.rewards[sample_inds])
        next_obs = _cast(self.next_obs[sample_inds])
        # A shared centralized obs has no agent axis, so no _cast is needed.
        if self.use_same_share_obs:
            share_obs = self.share_obs[sample_inds]
            next_share_obs = self.next_share_obs[sample_inds]
        else:
            share_obs = _cast(self.share_obs[sample_inds])
            next_share_obs = _cast(self.next_share_obs[sample_inds])
        dones = _cast(self.dones[sample_inds])
        dones_env = self.dones_env[sample_inds]
        valid_transition = _cast(self.valid_transition[sample_inds])
        if self.use_avail_acts:
            avail_acts = _cast(self.avail_acts[sample_inds])
            next_avail_acts = _cast(self.next_avail_acts[sample_inds])
        else:
            avail_acts = None
            next_avail_acts = None
        return obs, share_obs, acts, rewards, next_obs, next_share_obs, dones, dones_env, valid_transition, avail_acts, next_avail_acts
class PrioritizedMlpReplayBuffer(MlpReplayBuffer):
    def __init__(self, alpha, policy_info, policy_agents, buffer_size, use_same_share_obs, use_avail_acts,
                 use_reward_normalization=False):
        """Prioritized replay buffer class for training MLP policies. See parent class.

        :param alpha: (float) priority exponent; 0 recovers uniform sampling.
        """
        super(PrioritizedMlpReplayBuffer, self).__init__(policy_info, policy_agents,
                                                         buffer_size, use_same_share_obs, use_avail_acts,
                                                         use_reward_normalization)
        self.alpha = alpha
        self.policy_info = policy_info
        # Segment trees require a power-of-two capacity.
        it_capacity = 1
        while it_capacity < buffer_size:
            it_capacity *= 2
        self._it_sums = {p_id: SumSegmentTree(it_capacity) for p_id in self.policy_info.keys()}
        self._it_mins = {p_id: MinSegmentTree(it_capacity) for p_id in self.policy_info.keys()}
        self.max_priorities = {p_id: 1.0 for p_id in self.policy_info.keys()}

    def insert(self, num_insert_steps, obs, share_obs, acts, rewards, next_obs, next_share_obs, dones, dones_env,
               valid_transition, avail_acts=None, next_avail_acts=None):
        """See parent class. New transitions enter with the current maximum priority."""
        idx_range = super().insert(num_insert_steps, obs, share_obs, acts, rewards, next_obs, next_share_obs, dones,
                                   dones_env, valid_transition, avail_acts, next_avail_acts)
        # BUG FIX: idx_range is an *array* of buffer indices (possibly
        # wrapping around the end of the ring buffer), not a (start, end)
        # pair. The previous `range(idx_range[0], idx_range[1])` initialized
        # at most the first slot's priority (and nothing at all on
        # wrap-around); iterate the indices directly instead.
        for idx in idx_range:
            for p_id in self.policy_info.keys():
                self._it_sums[p_id][idx] = self.max_priorities[p_id] ** self.alpha
                self._it_mins[p_id][idx] = self.max_priorities[p_id] ** self.alpha
        return idx_range

    def _sample_proportional(self, batch_size, p_id=None):
        # Draw indices with probability proportional to stored priorities
        # using inverse-CDF sampling on the sum tree.
        total = self._it_sums[p_id].sum(0, len(self) - 1)
        mass = np.random.random(size=batch_size) * total
        idx = self._it_sums[p_id].find_prefixsum_idx(mass)
        return idx

    def sample(self, batch_size, beta=0, p_id=None):
        """
        Sample a set of transitions from buffer; probability of choosing a given sample is proportional to its priority.

        :param batch_size: (int) number of transitions to sample.
        :param beta: (float) controls the amount of prioritization to apply; must be > 0.
        :param p_id: (str) policy which will be updated using the samples.
        :return: See parent class; additionally returns importance weights and the sampled indices.
        """
        assert len(self) > batch_size, "Not enough samples in the buffer!"
        assert beta > 0
        batch_inds = self._sample_proportional(batch_size, p_id)
        # Importance-sampling weights, normalized by the maximum weight so
        # they are at most 1.
        p_min = self._it_mins[p_id].min() / self._it_sums[p_id].sum()
        max_weight = (p_min * len(self)) ** (-beta)
        p_sample = self._it_sums[p_id][batch_inds] / self._it_sums[p_id].sum()
        weights = (p_sample * len(self)) ** (-beta) / max_weight
        obs, share_obs, acts, rewards, next_obs, next_share_obs, dones, dones_env, valid_transition, avail_acts, next_avail_acts = {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}
        # Pull the same indices from every policy buffer so transitions stay
        # aligned (note: this loop reuses the name p_id after the weights
        # above were computed for the requested policy).
        for p_id in self.policy_info.keys():
            p_buffer = self.policy_buffers[p_id]
            obs[p_id], share_obs[p_id], acts[p_id], rewards[p_id], next_obs[p_id], next_share_obs[p_id], dones[p_id], \
            dones_env[p_id], valid_transition[p_id], avail_acts[p_id], next_avail_acts[p_id] = p_buffer.sample_inds(batch_inds)
        return obs, share_obs, acts, rewards, next_obs, next_share_obs, dones, dones_env, valid_transition, avail_acts, next_avail_acts, weights, batch_inds

    def update_priorities(self, idxes, priorities, p_id=None):
        """
        Update priorities of sampled transitions.

        sets priority of transition at index idxes[i] in buffer
        to priorities[i].

        :param idxes: ([int]) List of idxes of sampled transitions
        :param priorities: ([float]) List of updated priorities corresponding to transitions at the sampled idxes
            denoted by variable `idxes`.
        """
        assert len(idxes) == len(priorities)
        assert np.min(priorities) > 0
        assert np.min(idxes) >= 0
        assert np.max(idxes) < len(self)
        self._it_sums[p_id][idxes] = priorities ** self.alpha
        self._it_mins[p_id][idxes] = priorities ** self.alpha
        self.max_priorities[p_id] = max(
            self.max_priorities[p_id], np.max(priorities))
================================================
FILE: examples/Social_Cognition/ToCM/utils/mlp_nstep_buffer.py
================================================
import numpy as np
import torch
import random
class NStepReplayBuffer:
    def __init__(self, max_size, episode_len, n, policy_ids, agent_ids, policy_agents, policy_obs_dim, policy_act_dim, gamma):
        """Episode-structured replay buffer that serves n-step returns.

        :param max_size: (int) max number of episodes stored per policy.
        :param episode_len: (int) fixed episode length.
        :param n: (int) horizon for n-step returns.
        :param policy_ids: (list) ids of the policies sharing this buffer.
        :param agent_ids: (list) ids of all agents.
        :param policy_agents: (dict) maps policy id to the agents it controls.
        :param policy_obs_dim: (dict) maps policy id to observation dim.
        :param policy_act_dim: (dict) maps policy id to action dim.
        :param gamma: (float) discount factor for the n-step return.
        """
        self.max_size = max_size
        self.episode_len = episode_len
        # n for n-step returns
        self.n = n
        self.policy_ids = policy_ids
        self.agent_ids = agent_ids
        self.policy_agents = policy_agents
        # One sub-buffer per policy; all are pushed in lockstep.
        self.policy_buffers = {
            p_id: NStepPolicyBuffer(p_id, self.max_size, episode_len, n, self.policy_agents[p_id], policy_obs_dim[p_id],
                                    policy_act_dim[p_id], gamma) for p_id in self.policy_ids}
        self.num_episodes = 0
        self.num_transitions = 0

    def push(self, t_env, observation_batch, action_batch, reward_batch, next_observation_batch, dones_batch, finish_episodes):
        """Store one environment step (for a batch of parallel envs).

        :param t_env: (int) timestep within the episode.
        :param observation_batch/action_batch/reward_batch/next_observation_batch/dones_batch:
            per-env lists of dicts keyed by agent id.
        :param finish_episodes: (bool) if True, force all dones to True and
            mark the episodes as completed.
        """
        batch_size = observation_batch.shape[0]
        # Re-group the per-env dicts into per-agent stacked arrays.
        observations = {a_id: np.vstack(
            [obs[a_id] for obs in observation_batch]) for a_id in self.agent_ids}
        actions = {a_id: np.vstack([act[a_id] for act in action_batch])
                   for a_id in self.agent_ids}
        rewards = {a_id: np.vstack([rew[a_id] for rew in reward_batch])
                   for a_id in self.agent_ids}
        n_observations = {a_id: np.vstack(
            [nobs[a_id] for nobs in next_observation_batch]) for a_id in self.agent_ids}
        if finish_episodes:
            dones = {a_id: np.ones_like(rewards[a_id]).astype(
                bool) for a_id in self.agent_ids}
        else:
            dones = {a_id: np.vstack(
                [done[a_id] for done in dones_batch]) for a_id in self.agent_ids}
        for p_id in self.policy_ids:
            self.policy_buffers[p_id].push(
                batch_size, t_env, observations, actions, rewards, n_observations, dones, finish_episodes)
        # Sanity check: every per-policy buffer must agree on its counters,
        # since they are all pushed in lockstep.
        assert len(
            set([p_buffer.num_episodes for p_buffer in self.policy_buffers.values()])) == 1
        assert len(
            set([p_buffer.num_transitions for p_buffer in self.policy_buffers.values()])) == 1
        self.num_episodes = self.policy_buffers[self.policy_ids[0]].num_episodes
        self.num_transitions = self.policy_buffers[self.policy_ids[0]
                                                   ].num_transitions

    def sample(self, batch_size):
        """Sample a batch of n-step transitions, aligned across policies.

        Draws random (episode, start-timestep) pairs and reads the same
        pairs from every policy buffer.
        :return: dicts (obs, act, rew, nobs, dones) keyed by policy id.
        """
        # NOTE(review): the assert checks num_transitions, but the message
        # speaks of completed episodes — sampling indexes num_episodes, so
        # this guard looks weaker than intended; confirm.
        assert self.num_transitions > batch_size, "Cannot sample with no completed episodes in the buffer!"
        chunk_starts = np.random.choice(self.episode_len, batch_size)
        batch_inds = np.random.choice(self.num_episodes, batch_size)
        obs = {}
        act = {}
        rew = {}
        nobs = {}
        dones = {}
        for p_id in self.policy_ids:
            p_buffer = self.policy_buffers[p_id]
            o, a, r, no, d = p_buffer.get(batch_inds, chunk_starts)
            obs[p_id] = o
            act[p_id] = a
            rew[p_id] = r
            nobs[p_id] = no
            dones[p_id] = d
        return obs, act, rew, nobs, dones
class NStepPolicyBuffer:
    def __init__(self, policy_id, max_size, episode_len, n, policy_agents, obs_dim, act_dim, gamma):
        """Per-policy storage for episodes, laid out for n-step reads.

        Arrays are shaped [agent, episode, timestep, dim]. The reward /
        next-obs / done arrays have episode_len + n - 1 timestep slots so an
        n-step window starting at the last real step never indexes out of
        bounds (the padding stays at its initialized value: 0 rewards,
        done=1).

        :param policy_id: (str) id of the owning policy.
        :param max_size: (int) max number of episodes kept.
        :param episode_len: (int) fixed episode length.
        :param n: (int) n-step horizon.
        :param policy_agents: (list) agent ids controlled by this policy.
        :param obs_dim: (int) observation dim.
        :param act_dim: (int) action dim.
        :param gamma: (float) discount factor.
        """
        self.max_size = max_size
        self.n = n
        self.num_agents = len(policy_agents)
        self.policy_id = policy_id
        self.episode_len = episode_len
        self.agent_ids = policy_agents
        self.gamma = gamma
        # Shuffle so the mapping from array row to agent id is randomized.
        random.shuffle(self.agent_ids)
        self.observations = np.zeros(
            (self.num_agents, max_size, episode_len, obs_dim))
        self.actions = np.zeros(
            (self.num_agents, max_size, episode_len, act_dim))
        self.rewards = np.zeros(
            (self.num_agents, max_size, episode_len + n - 1, 1))
        self.next_observations = np.zeros(
            (self.num_agents, max_size, episode_len + n - 1, obs_dim))
        self.dones = np.ones(
            (self.num_agents, max_size, episode_len + n - 1, 1))
        self.num_episodes = 0
        self.num_transitions = 0

    def push(self, num_envs, t_env, observation_batch, action_batch, reward_batch, next_observation_batch,
             dones_batch, finish_episodes):
        """Write one timestep for a batch of parallel episodes.

        :param num_envs: (int) number of parallel envs in the batch.
        :param t_env: (int) timestep within the episode (0-based).
        :param observation_batch/action_batch/reward_batch/next_observation_batch/dones_batch:
            dicts keyed by agent id with stacked per-env rows.
        :param finish_episodes: (bool) force dones to 1 and close the episodes.
        """
        assert t_env < self.episode_len
        if t_env == 0:
            # shuffle the agent ids at the start of a new episode batch
            random.shuffle(self.agent_ids)
        # Evict the oldest episodes (by shifting everything left) when the
        # incoming batch would overflow the episode capacity.
        if t_env == 0 and self.num_episodes + num_envs > self.max_size:
            diff = self.num_episodes + num_envs - self.max_size
            self.observations = np.roll(self.observations, -diff, axis=1)
            self.actions = np.roll(self.actions, -diff, axis=1)
            self.rewards = np.roll(self.rewards, -diff, axis=1)
            self.next_observations = np.roll(
                self.next_observations, -diff, axis=1)
            self.dones = np.roll(self.dones, -diff, axis=1)
            self.num_episodes -= diff
        # Write this timestep for every agent row, at the episode slots of
        # the currently-open batch [num_episodes, num_episodes + num_envs).
        for i in range(self.num_agents):
            if finish_episodes:
                dones = np.ones_like(dones_batch[self.agent_ids[i]])
            else:
                dones = dones_batch[self.agent_ids[i]]
            self.observations[i, self.num_episodes: self.num_episodes +
                              num_envs, t_env, :] = observation_batch[self.agent_ids[i]]
            self.actions[i, self.num_episodes: self.num_episodes +
                         num_envs, t_env, :] = action_batch[self.agent_ids[i]]
            self.rewards[i, self.num_episodes: self.num_episodes +
                         num_envs, t_env, :] = reward_batch[self.agent_ids[i]]
            self.next_observations[i, self.num_episodes: self.num_episodes + num_envs, t_env, :] = next_observation_batch[
                self.agent_ids[i]]
            self.dones[i, self.num_episodes: self.num_episodes +
                       num_envs, t_env, :] = dones
        self.num_transitions += num_envs
        if finish_episodes:
            self.num_episodes += num_envs

    def get(self, batch_inds, start_inds):
        """Read n-step transitions at the given (episode, start) pairs.

        :param batch_inds: (np.ndarray) episode indices.
        :param start_inds: (np.ndarray) start timestep for each sample.
        :return: torch tensors (obs, acts, n-step rews, nth next obs, nth dones),
            each leading with the agent axis.
        """
        batch_inds_col = batch_inds[:, None]
        start_inds_col = start_inds[:, None]
        # [batch, n] timestep window for each sample.
        nstep_inds = start_inds_col + np.arange(self.n)
        obs = self.observations[:, batch_inds, start_inds, :]
        acts = self.actions[:, batch_inds, start_inds, :]
        # get the n-step rewards and weight each by exponentiated discounts
        rews = self.rewards[:, batch_inds_col, nstep_inds, 0]
        rews = rews * \
            np.power((np.ones(self.n) * self.gamma), np.arange(self.n))
        # sum the n-step rewards: rewards for terminal states are pre-set to 0, so don't need to mask
        rews = np.sum(rews, axis=2).reshape(
            self.num_agents, len(batch_inds), 1)
        # get the nobs of the nth
        nobs = self.next_observations[:,
                                      batch_inds, start_inds + self.n - 1, :]
        dones = self.dones[:, batch_inds, start_inds + self.n - 1, :]
        return torch.from_numpy(obs), torch.from_numpy(acts), torch.from_numpy(rews), torch.from_numpy(
            nobs), torch.from_numpy(dones)
================================================
FILE: examples/Social_Cognition/ToCM/utils/popart.py
================================================
import numpy as np
import torch
import torch.nn as nn
class PopArt(nn.Module):
    """Normalize a vector of observations across the first `norm_axes` dimensions.

    Maintains debiased exponential moving averages of the mean and mean of
    squares (decay rate `beta`); `forward` whitens inputs with them and
    `denormalize` maps normalized values back to the original scale.
    """

    def __init__(self, input_shape, norm_axes=1, beta=0.99999, per_element_update=False, epsilon=1e-5, device=torch.device("cpu")):
        """
        :param input_shape: shape of the per-element statistics tensors.
        :param norm_axes: number of leading axes treated as batch axes.
        :param beta: EMA decay rate for the running statistics.
        :param per_element_update: if True, decay by beta ** batch_size per update.
        :param epsilon: lower clamp for the debiasing term (avoids division by ~0).
        :param device: device the statistics live on.
        """
        super(PopArt, self).__init__()

        self.input_shape = input_shape
        self.norm_axes = norm_axes
        self.epsilon = epsilon
        self.beta = beta
        self.per_element_update = per_element_update
        self.device = device
        self.tpdv = dict(dtype=torch.float32, device=device)

        # Non-trainable running statistics; debiasing_term corrects the
        # zero-initialization bias of the EMAs (as in Adam).
        self.running_mean = nn.Parameter(torch.zeros(input_shape, dtype=torch.float), requires_grad=False).to(self.device)
        self.running_mean_sq = nn.Parameter(torch.zeros(input_shape, dtype=torch.float), requires_grad=False).to(self.device)
        self.debiasing_term = nn.Parameter(torch.tensor(0.0, dtype=torch.float), requires_grad=False).to(self.device)

    def reset_parameters(self):
        """Reset all running statistics to zero."""
        self.running_mean.zero_()
        self.running_mean_sq.zero_()
        self.debiasing_term.zero_()

    def running_mean_var(self):
        """Return the debiased running mean and variance (variance floored at 1e-2)."""
        debiased_mean = self.running_mean / self.debiasing_term.clamp(min=self.epsilon)
        debiased_mean_sq = self.running_mean_sq / self.debiasing_term.clamp(min=self.epsilon)
        # BUG FIX: the original clamped with max=self.alpha, but `alpha` is
        # never defined on this class, so every call raised AttributeError.
        # Only the variance floor is needed (matches reference PopArt code).
        debiased_var = (debiased_mean_sq - debiased_mean ** 2).clamp(min=1e-2)
        return debiased_mean, debiased_var

    def forward(self, input_vector, train=True):
        """Normalize `input_vector`; when train=True, first update the running stats."""
        # Make sure input is float32
        input_vector = input_vector.to(**self.tpdv)

        if train:
            # Detach input before adding it to running means to avoid
            # backpropping through it on subsequent batches.
            detached_input = input_vector.detach()
            batch_mean = detached_input.mean(dim=tuple(range(self.norm_axes)))
            batch_sq_mean = (detached_input ** 2).mean(dim=tuple(range(self.norm_axes)))

            if self.per_element_update:
                # Faster decay when each update covers many elements.
                batch_size = np.prod(detached_input.size()[:self.norm_axes])
                weight = self.beta ** batch_size
            else:
                weight = self.beta

            self.running_mean.mul_(weight).add_(batch_mean * (1.0 - weight))
            self.running_mean_sq.mul_(weight).add_(batch_sq_mean * (1.0 - weight))
            self.debiasing_term.mul_(weight).add_(1.0 * (1.0 - weight))

        mean, var = self.running_mean_var()
        # (None,) * norm_axes prepends broadcastable batch axes to the stats.
        out = (input_vector - mean[(None,) * self.norm_axes]) / torch.sqrt(var)[(None,) * self.norm_axes]
        return out

    def denormalize(self, input_vector):
        """Transform normalized data back into the original distribution."""
        input_vector = input_vector.to(**self.tpdv)
        mean, var = self.running_mean_var()
        out = input_vector * torch.sqrt(var)[(None,) * self.norm_axes] + mean[(None,) * self.norm_axes]
        return out
================================================
FILE: examples/Social_Cognition/ToCM/utils/rec_buffer.py
================================================
import numpy as np
from utils.util import get_dim_from_space
from utils.segment_tree import SumSegmentTree, MinSegmentTree
def _cast(x):
return x.transpose(2, 0, 1, 3)
class RecReplayBuffer(object):
    """Replay buffer for RNN policies that stores whole episodes.

    Keeps one RecPolicyBuffer per policy id and fans every insert/sample
    call out to the per-policy buffers.
    """

    def __init__(self, policy_info, policy_agents, buffer_size, episode_length, use_same_share_obs, use_avail_acts,
                 use_reward_normalization=False):
        """
        :param policy_info: (dict) maps policy id to a dict with that policy's spaces.
        :param policy_agents: (dict) maps policy id to the list of agents it controls.
        :param buffer_size: (int) max number of episodes to store in the buffer.
        :param episode_length: (int) max length of an episode.
        :param use_same_share_obs: (bool) whether all agents share the same centralized observation.
        :param use_avail_acts: (bool) whether to store which actions are available.
        :param use_reward_normalization: (bool) whether rewards are normalized on sampling.
        """
        self.policy_info = policy_info
        self.policy_buffers = {}
        for p_id, info in self.policy_info.items():
            self.policy_buffers[p_id] = RecPolicyBuffer(
                buffer_size,
                episode_length,
                len(policy_agents[p_id]),
                info['obs_space'],
                info['share_obs_space'],
                info['act_space'],
                use_same_share_obs,
                use_avail_acts,
                use_reward_normalization,
            )

    def __len__(self):
        # All per-policy buffers fill in lockstep; 'policy_0' is used as the
        # representative count.
        return self.policy_buffers['policy_0'].filled_i

    def insert(self, num_insert_episodes, obs, share_obs, acts, rewards, dones, dones_env, avail_acts):
        """
        Insert a set of episodes into every per-policy buffer; old episodes
        are overwritten once the buffer is full.

        Each argument (except num_insert_episodes) is a dict mapping policy id
        to the corresponding per-policy arrays.
        :return: (np.ndarray) buffer slots the episodes were written to
            (taken from the last policy processed; identical across policies).
        """
        for p_id in self.policy_info.keys():
            idx_range = self.policy_buffers[p_id].insert(
                num_insert_episodes, np.array(obs[p_id]), np.array(share_obs[p_id]),
                np.array(acts[p_id]), np.array(rewards[p_id]), np.array(dones[p_id]),
                np.array(dones_env[p_id]), np.array(avail_acts[p_id]))
        return idx_range

    def sample(self, batch_size):
        """
        Sample `batch_size` episodes uniformly at random (with replacement).

        :return: per-policy dicts of obs, share_obs, acts, rewards, dones,
            dones_env and avail_acts, plus two trailing Nones so the return
            shape matches PrioritizedRecReplayBuffer.sample (weights, indices).
        """
        inds = np.random.choice(len(self), batch_size)
        obs, share_obs, acts, rewards = {}, {}, {}, {}
        dones, dones_env, avail_acts = {}, {}, {}
        for p_id in self.policy_info.keys():
            (obs[p_id], share_obs[p_id], acts[p_id], rewards[p_id],
             dones[p_id], dones_env[p_id], avail_acts[p_id]) = self.policy_buffers[p_id].sample_inds(inds)
        return obs, share_obs, acts, rewards, dones, dones_env, avail_acts, None, None
class RecPolicyBuffer(object):
    def __init__(self, buffer_size, episode_length, num_agents, obs_space, share_obs_space, act_space,
                 use_same_share_obs, use_avail_acts, use_reward_normalization=False):
        """
        Buffer class containing buffer data corresponding to a single policy.
        Episodes are stored as [step, episode, agent, dim] arrays (ring buffer
        over the episode axis).
        :param buffer_size: (int) max number of episodes to store in buffer.
        :param episode_length: (int) max length of an episode.
        :param num_agents: (int) number of agents controlled by the policy.
        :param obs_space: (gym.Space) observation space of the environment.
        :param share_obs_space: (gym.Space) centralized observation space of the environment.
        :param act_space: (gym.Space) action space of the environment.
        :param use_same_share_obs: (bool) whether all agents share the same centralized observation.
        :param use_avail_acts: (bool) whether to store what actions are available.
        :param use_reward_normalization: (bool) whether to use reward normalization.
        """
        self.buffer_size = buffer_size
        self.episode_length = episode_length
        self.num_agents = num_agents
        self.use_same_share_obs = use_same_share_obs
        self.use_avail_acts = use_avail_acts
        self.use_reward_normalization = use_reward_normalization
        # filled_i: number of valid episodes stored; current_i: next write slot.
        self.filled_i = 0
        self.current_i = 0
        # obs: only Box spaces or raw [dim] lists are supported
        if obs_space.__class__.__name__ == 'Box':
            obs_shape = obs_space.shape
            share_obs_shape = share_obs_space.shape
        elif obs_space.__class__.__name__ == 'list':
            obs_shape = obs_space
            share_obs_shape = share_obs_space
        else:
            raise NotImplementedError
        # +1 step so the observation after the final action is kept as well
        self.obs = np.zeros((self.episode_length + 1, self.buffer_size,
                             self.num_agents, obs_shape[0]), dtype=np.float32)
        if self.use_same_share_obs:
            # single centralized observation shared by all agents (no agent axis)
            self.share_obs = np.zeros((self.episode_length + 1, self.buffer_size, share_obs_shape[0]), dtype=np.float32)
        else:
            self.share_obs = np.zeros((self.episode_length + 1, self.buffer_size, self.num_agents, share_obs_shape[0]),
                                      dtype=np.float32)
        # action
        act_dim = np.sum(get_dim_from_space(act_space))
        self.acts = np.zeros((self.episode_length, self.buffer_size, self.num_agents, act_dim), dtype=np.float32)
        if self.use_avail_acts:
            self.avail_acts = np.ones((self.episode_length + 1, self.buffer_size, self.num_agents, act_dim),
                                      dtype=np.float32)
        # rewards
        self.rewards = np.zeros((self.episode_length, self.buffer_size, self.num_agents, 1), dtype=np.float32)
        # default to done being True
        self.dones = np.ones_like(self.rewards, dtype=np.float32)
        self.dones_env = np.ones((self.episode_length, self.buffer_size, 1), dtype=np.float32)

    def __len__(self):
        # Number of episodes currently stored (not the capacity).
        return self.filled_i

    def insert(self, num_insert_episodes, obs, share_obs, acts, rewards, dones, dones_env, avail_acts=None):
        """
        Insert a set of episodes corresponding to this policy into buffer. If the buffer size overflows, old episodes are dropped.
        :param num_insert_episodes: (int) number of episodes to be added to buffer
        :param obs: (np.ndarray) observations of agents corresponding to this policy.
        :param share_obs: (np.ndarray) centralized observations of agents corresponding to this policy.
        :param acts: (np.ndarray) actions of agents corresponding to this policy.
        :param rewards: (np.ndarray) rewards of agents corresponding to this policy.
        :param dones: (np.ndarray) terminal status of agents corresponding to this policy.
        :param dones_env: (np.ndarray) environment terminal status.
        :param avail_acts: (np.ndarray) available actions of agents corresponding to this policy.
        :return: (np.ndarray) buffer slots (episode indices) the new episodes were placed in.
        """
        # obs: [step, episode, agent, dim]
        episode_length = acts.shape[0]
        assert episode_length == self.episode_length, ("different dimension!")
        if self.current_i + num_insert_episodes <= self.buffer_size:
            idx_range = np.arange(self.current_i, self.current_i + num_insert_episodes)
        else:
            # wrap around: fill to the end, then continue from slot 0
            num_left_episodes = self.current_i + num_insert_episodes - self.buffer_size
            idx_range = np.concatenate((np.arange(self.current_i, self.buffer_size), np.arange(num_left_episodes)))
        if self.use_same_share_obs:
            # remove agent dimension since all agents share centralized observation
            share_obs = share_obs[:, :, 0]
        self.obs[:, idx_range] = obs.copy()
        self.share_obs[:, idx_range] = share_obs.copy()
        self.acts[:, idx_range] = acts.copy()
        self.rewards[:, idx_range] = rewards.copy()
        self.dones[:, idx_range] = dones.copy()
        self.dones_env[:, idx_range] = dones_env.copy()
        if self.use_avail_acts:
            self.avail_acts[:, idx_range] = avail_acts.copy()
        # advance the write pointer; when the last written slot was the final
        # one, the next insert call takes the wrap-around branch above
        self.current_i = idx_range[-1] + 1
        self.filled_i = min(self.filled_i + len(idx_range), self.buffer_size)
        return idx_range

    def sample_inds(self, sample_inds):
        """
        Sample a set of episodes from buffer from the specified indices.
        :param sample_inds: (np.ndarray) episode indices to return from buffer.
        :return: obs, share_obs, acts, rewards, dones, dones_env, avail_acts
            arrays for the chosen episodes, axes reordered by _cast (except
            dones_env and shared share_obs, which have no agent axis).
        """
        obs = _cast(self.obs[:, sample_inds])
        acts = _cast(self.acts[:, sample_inds])
        if self.use_reward_normalization:
            # Normalize by the mean/std of rewards over all *valid* steps of
            # the filled portion of the buffer.
            # [length, envs, agents, 1]
            # [length, envs, 1]
            all_dones_env = np.tile(np.expand_dims(
                self.dones_env[:, :self.filled_i], -1), (1, 1, self.num_agents, 1))
            # shift env-done flags forward one step: a step is invalid when the
            # *previous* step already terminated the environment
            first_step_dones_env = np.zeros((1, self.filled_i, self.num_agents, 1))
            curr_dones_env = np.concatenate((first_step_dones_env, all_dones_env[:self.episode_length - 1]))
            temp_rewards = self.rewards[:, :self.filled_i].copy()
            # NaN-mask post-termination steps so they don't bias the statistics
            temp_rewards[curr_dones_env == 1.0] = np.nan
            mean_reward = np.nanmean(temp_rewards)
            std_reward = np.nanstd(temp_rewards)
            rewards = _cast(
                (self.rewards[:, sample_inds] - mean_reward) / std_reward)
        else:
            rewards = _cast(self.rewards[:, sample_inds])
        if self.use_same_share_obs:
            # no agent axis to move, so no _cast here
            share_obs = self.share_obs[:, sample_inds]
        else:
            share_obs = _cast(self.share_obs[:, sample_inds])
        dones = _cast(self.dones[:, sample_inds])
        dones_env = self.dones_env[:, sample_inds]
        if self.use_avail_acts:
            avail_acts = _cast(self.avail_acts[:, sample_inds])
        else:
            avail_acts = None
        return obs, share_obs, acts, rewards, dones, dones_env, avail_acts
class PrioritizedRecReplayBuffer(RecReplayBuffer):
    """Prioritized replay buffer for RNN policies (proportional prioritization).

    Episode priorities live in per-policy sum/min segment trees; sampling
    probability is priority ** alpha, with importance-sampling weights
    computed in `sample`.
    """

    def __init__(self, alpha, policy_info, policy_agents, buffer_size, episode_length, use_same_share_obs,
                 use_avail_acts, use_reward_normalization=False):
        """See parent class.

        :param alpha: (float) priority exponent (0 = uniform sampling).
        """
        super(PrioritizedRecReplayBuffer, self).__init__(policy_info, policy_agents, buffer_size,
                                                         episode_length, use_same_share_obs, use_avail_acts,
                                                         use_reward_normalization)
        self.alpha = alpha
        self.policy_info = policy_info
        # Segment trees require a power-of-two capacity >= buffer_size.
        it_capacity = 1
        while it_capacity < buffer_size:
            it_capacity *= 2
        self._it_sums = {p_id: SumSegmentTree(it_capacity) for p_id in self.policy_info.keys()}
        self._it_mins = {p_id: MinSegmentTree(it_capacity) for p_id in self.policy_info.keys()}
        self.max_priorities = {p_id: 1.0 for p_id in self.policy_info.keys()}

    def insert(self, num_insert_episodes, obs, share_obs, acts, rewards, dones, dones_env, avail_acts=None):
        """See parent class; new episodes additionally get the current max priority."""
        idx_range = super().insert(num_insert_episodes, obs, share_obs, acts, rewards, dones, dones_env, avail_acts)
        # BUG FIX: idx_range is the array of buffer slots actually written
        # (it can wrap around the ring buffer), not a (start, stop) pair.
        # The old `range(idx_range[0], idx_range[1])` raised IndexError when
        # a single episode was inserted and set priorities for the wrong
        # slots otherwise. Iterate the returned indices themselves;
        # .tolist() yields plain ints, which the trees' scalar path expects.
        for idx in idx_range.tolist():
            for p_id in self.policy_info.keys():
                self._it_sums[p_id][idx] = self.max_priorities[p_id] ** self.alpha
                self._it_mins[p_id][idx] = self.max_priorities[p_id] ** self.alpha
        return idx_range

    def _sample_proportional(self, batch_size, p_id=None):
        # Draw `batch_size` prefix sums uniformly in [0, total priority) and
        # map each one to an episode index via the sum tree.
        total = self._it_sums[p_id].sum(0, len(self) - 1)
        mass = np.random.random(size=batch_size) * total
        idx = self._it_sums[p_id].find_prefixsum_idx(mass)
        return idx

    def sample(self, batch_size, beta=0, p_id=None):
        """
        Sample a set of episodes; probability of choosing a given episode is proportional to its priority.
        :param batch_size: (int) number of episodes to sample.
        :param beta: (float) controls the amount of importance-sampling correction (must be > 0).
        :param p_id: (str) policy whose priorities drive the sampling.
        :return: same tuple as the parent class, with importance weights and
            the sampled indices in the last two slots.
        """
        assert len(self) > batch_size, "Cannot sample with no completed episodes in the buffer!"
        assert beta > 0

        batch_inds = self._sample_proportional(batch_size, p_id)

        # Importance-sampling weights, normalized by the largest possible weight.
        p_min = self._it_mins[p_id].min() / self._it_sums[p_id].sum()
        max_weight = (p_min * len(self)) ** (-beta)
        p_sample = self._it_sums[p_id][batch_inds] / self._it_sums[p_id].sum()
        weights = (p_sample * len(self)) ** (-beta) / max_weight

        obs, share_obs, acts, rewards, dones, dones_env, avail_acts = {}, {}, {}, {}, {}, {}, {}
        # loop variable renamed so the `p_id` argument is not clobbered
        for pid in self.policy_info.keys():
            p_buffer = self.policy_buffers[pid]
            (obs[pid], share_obs[pid], acts[pid], rewards[pid],
             dones[pid], dones_env[pid], avail_acts[pid]) = p_buffer.sample_inds(batch_inds)
        return obs, share_obs, acts, rewards, dones, dones_env, avail_acts, weights, batch_inds

    def update_priorities(self, idxes, priorities, p_id=None):
        """
        Update priorities of sampled episodes: the priority of the episode at
        idxes[i] becomes priorities[i] (raised to alpha inside the trees).
        :param idxes: ([int]) indices of sampled episodes.
        :param priorities: ([float]) updated priorities for those episodes.
        """
        assert len(idxes) == len(priorities)
        assert np.min(priorities) > 0
        assert np.min(idxes) >= 0
        assert np.max(idxes) < len(self)
        self._it_sums[p_id][idxes] = priorities ** self.alpha
        self._it_mins[p_id][idxes] = priorities ** self.alpha
        self.max_priorities[p_id] = max(self.max_priorities[p_id], np.max(priorities))
================================================
FILE: examples/Social_Cognition/ToCM/utils/segment_tree.py
================================================
import numpy as np
def unique(sorted_array):
    """Drop consecutive duplicates from an already-sorted array.

    More efficient than np.unique because sortedness means a single
    neighbor comparison finds every duplicate.
    :param sorted_array: (np.ndarray) sorted input.
    :return: (np.ndarray) sorted_array without duplicate elements.
    """
    if len(sorted_array) == 1:
        return sorted_array
    # keep an element when it differs from its right neighbor; the last
    # element is always kept
    keep = np.empty(len(sorted_array), dtype=bool)
    keep[:-1] = sorted_array[1:] != sorted_array[:-1]
    keep[-1] = True
    return sorted_array[keep]
class SegmentTree(object):
    def __init__(self, capacity, operation, neutral_element):
        """
        Build a Segment Tree data structure.
        https://en.wikipedia.org/wiki/Segment_tree
        Can be used as regular array that supports Index arrays, but with two
        important differences:
            a) setting item's value is slightly slower.
               It is O(lg capacity) instead of O(1).
            b) user has access to an efficient ( O(log segment size) )
               `reduce` operation which reduces `operation` over
               a contiguous subsequence of items in the array.
        :param capacity: (int) Total size of the array - must be a power of two.
        :param operation: (lambda (Any, Any): Any) operation for combining elements (eg. sum, max) must form a
            mathematical group together with the set of possible values for array elements (i.e. be associative)
        :param neutral_element: (Any) neutral element for the operation above. eg. float('-inf') for max and 0 for sum.
        """
        assert capacity > 0 and capacity & (
            capacity - 1) == 0, "capacity must be positive and a power of 2."
        self._capacity = capacity
        # 1-indexed binary-heap layout: node i's children are 2i and 2i+1;
        # leaves occupy slots [capacity, 2*capacity).
        self._value = [neutral_element for _ in range(2 * capacity)]
        self._operation = operation
        self.neutral_element = neutral_element

    def _reduce_helper(self, start, end, node, node_start, node_end):
        # Recursively reduce the inclusive range [start, end] within the
        # subtree rooted at `node`, which covers [node_start, node_end].
        if start == node_start and end == node_end:
            return self._value[node]
        mid = (node_start + node_end) // 2
        if end <= mid:
            # range lies entirely in the left child
            return self._reduce_helper(start, end, 2 * node, node_start, mid)
        else:
            if mid + 1 <= start:
                # range lies entirely in the right child
                return self._reduce_helper(start, end, 2 * node + 1, mid + 1, node_end)
            else:
                # range straddles both children; combine the partial results
                return self._operation(
                    self._reduce_helper(start, mid, 2 * node, node_start, mid),
                    self._reduce_helper(
                        mid + 1, end, 2 * node + 1, mid + 1, node_end)
                )

    def reduce(self, start=0, end=None):
        """
        Returns result of applying `self.operation`
        to a contiguous subsequence of the array.
            self.operation(arr[start], operation(arr[start+1], operation(... arr[end])))
        :param start: (int) beginning of the subsequence
        :param end: (int) end of the subsequence (exclusive; None means the whole array,
            negative values count from the end)
        :return: (Any) result of reducing self.operation over the specified range of array elements.
        """
        if end is None:
            end = self._capacity
        if end < 0:
            end += self._capacity
        # convert the exclusive `end` to the inclusive index the helper expects
        end -= 1
        return self._reduce_helper(start, end, 1, 0, self._capacity - 1)

    def __setitem__(self, idx, val):
        # Accepts a scalar or an index array: write the leaves, then
        # re-combine each affected parent up to the root, deduplicating
        # shared parents at every level.
        # NOTE(review): the vectorized assignment below assumes a subclass
        # has replaced self._value with an ndarray (SumSegmentTree /
        # MinSegmentTree both do).
        # indexes of the leaf
        idxs = idx + self._capacity
        self._value[idxs] = val
        if isinstance(idxs, int):
            idxs = np.array([idxs])
        # go up one level in the tree and remove duplicate indexes
        idxs = unique(idxs // 2)
        while len(idxs) > 1 or idxs[0] > 0:
            # as long as there are non-zero indexes, update the corresponding values
            self._value[idxs] = self._operation(
                self._value[2 * idxs],
                self._value[2 * idxs + 1]
            )
            # go up one level in the tree and remove duplicate indexes
            idxs = unique(idxs // 2)

    def __getitem__(self, idx):
        # Leaf access; idx may be a scalar or an index array.
        assert np.max(idx) < self._capacity
        assert 0 <= np.min(idx)
        return self._value[self._capacity + idx]
class SumSegmentTree(SegmentTree):
    def __init__(self, capacity):
        super(SumSegmentTree, self).__init__(
            capacity=capacity,
            operation=np.add,
            neutral_element=0.0
        )
        # ndarray storage enables the vectorized index updates performed in
        # SegmentTree.__setitem__.
        self._value = np.array(self._value)

    def sum(self, start=0, end=None):
        """
        Returns arr[start] + ... + arr[end]
        :param start: (int) start position of the reduction (must be >= 0)
        :param end: (int) end position of the reduction (must be < len(arr), can be None for len(arr) - 1)
        :return: (Any) reduction of SumSegmentTree
        """
        return super(SumSegmentTree, self).reduce(start, end)

    def find_prefixsum_idx(self, prefixsum):
        """
        Find the highest index `i` in the array such that
            sum(arr[0] + arr[1] + ... + arr[i - 1]) <= prefixsum for each entry in prefixsum
        if array values are probabilities, this function
        allows to sample indexes according to the discrete
        probability efficiently.
        :param prefixsum: (np.ndarray) float upper bounds on the sum of array prefix
        :return: (np.ndarray) highest indexes satisfying the prefixsum constraint
        """
        if isinstance(prefixsum, float):
            prefixsum = np.array([prefixsum])
        assert 0 <= np.min(prefixsum)
        assert np.max(prefixsum) <= self.sum() + 1e-5
        assert isinstance(prefixsum[0], float)

        # Vectorized descent from the root (heap node 1): at every level each
        # query either moves to the left child (keeping its prefixsum) or
        # subtracts the left child's mass and moves right, until all queries
        # reach leaf nodes.
        idx = np.ones(len(prefixsum), dtype=int)
        cont = np.ones(len(prefixsum), dtype=bool)

        while np.any(cont):  # while not all nodes are leafs
            # descend to the left child of every still-active query
            idx[cont] = 2 * idx[cont]
            # prefixsum with the left child's mass removed, used when going right
            prefixsum_new = np.where(
                self._value[idx] <= prefixsum, prefixsum - self._value[idx], prefixsum)
            # step to the right sibling where the left child's mass was insufficient
            # (finished queries are left untouched)
            idx = np.where(np.logical_or(
                self._value[idx] > prefixsum, np.logical_not(cont)), idx, idx + 1)
            prefixsum = prefixsum_new
            # a query is done once its node index reaches the leaf range
            cont = idx < self._capacity
        # convert heap positions back to array indices
        return idx - self._capacity
class MinSegmentTree(SegmentTree):
    """Segment tree whose reduction operation is the element-wise minimum."""

    def __init__(self, capacity):
        super().__init__(
            capacity=capacity,
            operation=np.minimum,
            neutral_element=float('inf'),
        )
        # ndarray storage enables the vectorized index updates performed in
        # SegmentTree.__setitem__.
        self._value = np.array(self._value)

    def min(self, start=0, end=None):
        """
        Return min(arr[start], ..., arr[end]).
        :param start: (int) start position of the reduction (must be >= 0)
        :param end: (int) end position of the reduction (must be < len(arr); None means len(arr) - 1)
        :return: (Any) minimum over the requested range
        """
        return super().reduce(start, end)
================================================
FILE: examples/Social_Cognition/ToCM/utils/util.py
================================================
import copy
import gym
import numpy as np
from gym.spaces import Box, Discrete, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.autograd import Variable
def to_torch(input):
    """Convert a numpy ndarray to a torch tensor; return anything else unchanged."""
    if type(input) == np.ndarray:
        return torch.from_numpy(input)
    return input
def to_numpy(x):
    """Return tensor `x` as a numpy array, detached from autograd and moved to CPU."""
    detached = x.detach()
    return detached.cpu().numpy()
class FixedCategorical(torch.distributions.Categorical):
    """Categorical distribution with the conventions the policy code expects:
    `log_probs` keeps a trailing unit dimension and `mode` returns the argmax."""

    def sample(self):
        # Defer to the base class; kept explicit for symmetry with log_probs/mode.
        return super().sample()

    def log_probs(self, actions):
        """Summed log-probability per batch row, shape (batch, 1)."""
        flat_actions = actions.squeeze(-1)
        log_p = super().log_prob(flat_actions)
        log_p = log_p.view(actions.size(0), -1).sum(-1)
        return log_p.unsqueeze(-1)

    def mode(self):
        """Index of the most probable category, with a trailing unit dimension."""
        return self.probs.argmax(dim=-1, keepdim=True)
class MultiDiscrete(gym.Space):
    """
    - The multi-discrete action space consists of a series of discrete action spaces with different parameters
    - It can be adapted to both a Discrete action space or a continuous (Box) action space
    - It is useful to represent game controllers or keyboards where each key can be represented as a discrete action space
    - It is parametrized by passing an array of arrays containing [min, max] for each discrete action space
       where the discrete action space can take any integers from `min` to `max` (both inclusive)
    Note: A value of 0 always need to represent the NOOP action.
    e.g. Nintendo Game Controller
    - Can be conceptualized as 3 discrete action spaces:
        1) Arrow Keys: Discrete 5  - NOOP[0], UP[1], RIGHT[2], DOWN[3], LEFT[4]  - params: min: 0, max: 4
        2) Button A:   Discrete 2  - NOOP[0], Pressed[1] - params: min: 0, max: 1
        3) Button B:   Discrete 2  - NOOP[0], Pressed[1] - params: min: 0, max: 1
    - Can be initialized as MultiDiscrete([ [0,4], [0,1], [0,1] ])
    """

    def __init__(self, array_of_param_array):
        self.low = np.array([pair[0] for pair in array_of_param_array])
        self.high = np.array([pair[1] for pair in array_of_param_array])
        self.num_discrete_space = self.low.shape[0]
        # NOTE(review): '+ 2' (rather than, say, '+ num_discrete_space') is
        # preserved as-is; confirm against the consumers of `n`.
        self.n = np.sum(self.high) + 2

    def sample(self):
        """Return a list with one uniform sample from each discrete sub-space."""
        draws = np.random.rand(self.num_discrete_space)
        # per-row: floor(draw * (max - min + 1) + min) lands in [min, max]
        scaled = np.floor(np.multiply((self.high - self.low + 1.), draws) + self.low)
        return [int(v) for v in scaled]

    def contains(self, x):
        """True when x has one in-range entry per sub-space."""
        if len(x) != self.num_discrete_space:
            return False
        arr = np.array(x)
        return (arr >= self.low).all() and (arr <= self.high).all()

    @property
    def shape(self):
        # Number of sub-spaces (an int, not a tuple -- kept for compatibility).
        return self.num_discrete_space

    def __repr__(self):
        return "MultiDiscrete" + str(self.num_discrete_space)

    def __eq__(self, other):
        return np.array_equal(self.low, other.low) and np.array_equal(self.high, other.high)
class DecayThenFlatSchedule():
    """Anneal a value from `start` to `finish` over `time_length` steps, then stay flat.

    decay='linear': subtract a constant delta per step, floored at `finish`.
    decay='exp': exponential decay scaled so the value reaches `finish`
    at step `time_length`, clipped to [finish, start].
    """

    def __init__(self,
                 start,
                 finish,
                 time_length,
                 decay="exp"):
        self.start = start
        self.finish = finish
        self.time_length = time_length
        self.delta = (self.start - self.finish) / self.time_length
        self.decay = decay
        if self.decay in ["exp"]:
            # scaling chosen so exp(-time_length / exp_scaling) == finish
            self.exp_scaling = (-1) * self.time_length / \
                np.log(self.finish) if self.finish > 0 else 1

    def eval(self, T):
        """Return the scheduled value at step T (None for unknown decay modes)."""
        if self.decay in ["linear"]:
            return max(self.finish, self.start - self.delta * T)
        if self.decay in ["exp"]:
            return min(self.start, max(self.finish, np.exp(-T / self.exp_scaling)))
def huber_loss(e, d):
    """Elementwise Huber loss: quadratic for |e| <= d, linear beyond (continuous at the boundary)."""
    quadratic_region = (abs(e) <= d).float()
    linear_region = (abs(e) > d).float()
    return quadratic_region * e ** 2 / 2 + linear_region * d * (abs(e) - d / 2)
def mse_loss(e):
    """Elementwise squared error (no mean reduction)."""
    return e * e
def init(module, weight_init, bias_init, gain=1):
    """Initialize a layer's weight and bias in place and return the layer.

    :param module: layer exposing .weight and .bias tensors.
    :param weight_init: callable invoked as weight_init(tensor, gain=gain).
    :param bias_init: callable invoked as bias_init(tensor).
    :param gain: scaling factor forwarded to weight_init.
    :return: the same module, so call sites can chain.
    """
    weight_tensor = module.weight.data
    bias_tensor = module.bias.data
    weight_init(weight_tensor, gain=gain)
    bias_init(bias_tensor)
    return module
def get_clones(module, N):
    """Return a ModuleList holding N independent deep copies of `module`."""
    clones = [copy.deepcopy(module) for _ in range(N)]
    return nn.ModuleList(clones)
# https://github.com/ikostrikov/pytorch-ddpg-naf/blob/master/ddpg.py#L11
def soft_update(target, source, tau):
    """Polyak-average source parameters into target, in place.

    theta_target <- (1 - tau) * theta_target + tau * theta_source
    :param target: (torch.nn.Module) net whose parameters are updated.
    :param source: (torch.nn.Module) net whose parameters are read.
    :param tau: (float, 0 < tau < 1) interpolation weight toward the source.
    """
    for tgt_param, src_param in zip(target.parameters(), source.parameters()):
        mixed = tgt_param.data * (1.0 - tau) + src_param.data * tau
        tgt_param.data.copy_(mixed)
# https://github.com/ikostrikov/pytorch-ddpg-naf/blob/master/ddpg.py#L15
def hard_update(target, source):
    """Copy every parameter of `source` into `target`, in place.

    :param target: (torch.nn.Module) net whose parameters are overwritten.
    :param source: (torch.nn.Module) net whose parameters are read.
    """
    for tgt_param, src_param in zip(target.parameters(), source.parameters()):
        tgt_param.data.copy_(src_param.data)
# https://github.com/seba-1511/dist_tuto.pth/blob/gh-pages/train_dist.py
def average_gradients(model):
    """Average `model`'s gradients across all distributed workers, in place.

    Requires torch.distributed to be initialized; performs one blocking
    all-reduce per parameter tensor.
    """
    world_size = float(dist.get_world_size())
    for param in model.parameters():
        # FIX: `dist.reduce_op` is a long-deprecated alias that recent PyTorch
        # versions removed -- use dist.ReduceOp. The old `group=0` constant is
        # not a valid ProcessGroup object either; the default (WORLD) group is
        # what was intended.
        dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
        param.grad.data /= world_size
def onehot_from_logits(logits, avail_logits=None, eps=0.0):
    """Epsilon-greedy one-hot actions from a batch of logits.

    With probability 1 - eps each row gets the argmax as a one-hot vector,
    otherwise a uniformly random one-hot action. Unavailable actions (where
    avail_logits == 0) are masked to -1e10 before taking the argmax.
    """
    logits = to_torch(logits)
    dim = len(logits.shape) - 1
    if avail_logits is not None:
        avail_logits = to_torch(avail_logits)
        logits[avail_logits == 0] = -1e10  # in-place mask, as in the caller contract
    argmax_acs = (logits == logits.max(dim, keepdim=True)[0]).float()
    if eps == 0.0:
        return argmax_acs
    # random one-hot fallback action for every batch row
    rand_rows = np.random.choice(range(logits.shape[1]), size=logits.shape[0])
    rand_acs = Variable(torch.eye(logits.shape[1])[[rand_rows]], requires_grad=False)
    # epsilon-greedy choice between greedy and random actions, row by row
    picks = []
    for i, r in enumerate(torch.rand(logits.shape[0])):
        picks.append(argmax_acs[i] if r > eps else rand_acs[i])
    return torch.stack(picks)
# modified for PyTorch from https://github.com/ericjang/gumbel-softmax/blob/master/Categorical%20VAE.ipynb
def sample_gumbel(shape, eps=1e-20, tens_type=torch.FloatTensor):
    """Draw Gumbel(0, 1) noise of the given shape via the inverse-CDF of uniforms."""
    uniforms = Variable(tens_type(*shape).uniform_(), requires_grad=False)
    # -log(-log(U)) transforms U ~ Uniform(0,1) into Gumbel(0,1); eps guards log(0)
    return -torch.log(-torch.log(uniforms + eps) + eps)
# modified for PyTorch from https://github.com/ericjang/gumbel-softmax/blob/master/Categorical%20VAE.ipynb
def gumbel_softmax_sample(logits, avail_logits, temperature, device=torch.device('cpu')):
    """One Gumbel-Softmax draw: softmax((logits + Gumbel noise) / temperature)."""
    if str(device) == 'cpu':
        y = logits + sample_gumbel(logits.shape, tens_type=type(logits.data))
    else:
        # noise is sampled on CPU, then the perturbed logits move back to GPU
        noise = sample_gumbel(logits.shape, tens_type=type(logits.data))
        y = (logits.cpu() + noise).cuda()
    dim = len(logits.shape) - 1
    if avail_logits is not None:
        # mask unavailable actions so the softmax gives them ~0 probability
        avail_logits = to_torch(avail_logits).to(device)
        y[avail_logits == 0] = -1e10
    return F.softmax(y / temperature, dim=dim)
# modified for PyTorch from https://github.com/ericjang/gumbel-softmax/blob/master/Categorical%20VAE.ipynb
def gumbel_softmax(logits, avail_logits=None, temperature=1.0, hard=False, device=torch.device('cpu')):
    """Sample from the Gumbel-Softmax distribution and optionally discretize.

    Args:
        logits: [batch_size, n_class] unnormalized log-probs.
        avail_logits: optional mask; zero entries mark unavailable actions.
        temperature: non-negative softmax temperature.
        hard: if True, return a one-hot sample but keep gradients of the soft one.
        device: device the computation targets.
    Returns:
        [batch_size, n_class] sample; one-hot when hard=True, otherwise a
        probability distribution summing to 1 across classes.
    """
    y = gumbel_softmax_sample(logits, avail_logits, temperature, device)
    if not hard:
        return y
    y_hard = onehot_from_logits(y)
    # straight-through estimator: one-hot forward pass, soft-sample gradients
    return (y_hard - y).detach() + y
def gaussian_noise(shape, std):
    """Return a tensor of the given shape filled with zero-mean Gaussian noise of std `std`."""
    out = torch.empty(shape)
    return out.normal_(mean=0, std=std)
def get_obs_shape(obs_space):
    """Return the observation shape for a Box space, or pass a plain list through."""
    space_kind = obs_space.__class__.__name__
    if space_kind == "Box":
        return obs_space.shape
    if space_kind == "list":
        return obs_space
    raise NotImplementedError
def get_dim_from_space(space):
    """Flattened dimensionality of a gym space.

    MultiDiscrete-like spaces return a per-subspace size array; every other
    supported space returns an int.
    """
    if isinstance(space, Box):
        return space.shape[0]
    if isinstance(space, Discrete):
        return space.n
    if isinstance(space, Tuple):
        return sum(get_dim_from_space(sub) for sub in space)
    if "MultiDiscrete" in space.__class__.__name__:
        # number of choices per sub-space (inclusive bounds)
        return (space.high - space.low) + 1
    if isinstance(space, list):
        return space[0]
    raise Exception("Unrecognized space: ", type(space))
def get_state_dim(observation_dict, action_dict):
    """Return (combined obs dim, combined act dim, their sum) over all agents' spaces."""
    combined_obs_dim = sum(get_dim_from_space(space)
                           for space in observation_dict.values())
    combined_act_dim = 0
    for space in action_dict.values():
        dim = get_dim_from_space(space)
        # MultiDiscrete spaces yield an array of per-subspace sizes
        combined_act_dim += int(sum(dim)) if isinstance(dim, np.ndarray) else dim
    return combined_obs_dim, combined_act_dim, combined_obs_dim + combined_act_dim
def get_cent_act_dim(action_space):
    """Total action dimensionality summed over a list of per-agent action spaces."""
    total = 0
    for space in action_space:
        dim = get_dim_from_space(space)
        # MultiDiscrete spaces yield an array of per-subspace sizes
        total += int(sum(dim)) if isinstance(dim, np.ndarray) else dim
    return total
def is_discrete(space):
    """True for a Discrete space or any space whose class name marks it MultiDiscrete."""
    return isinstance(space, Discrete) or "MultiDiscrete" in space.__class__.__name__
def is_multidiscrete(space):
    """True when the space's class name marks it as a MultiDiscrete variant."""
    return "MultiDiscrete" in space.__class__.__name__
def make_onehot(int_action, action_dim, seq_len=None):
    """One-hot encode integer actions; with seq_len, encode each timestep and stack."""
    if type(int_action) == torch.Tensor:
        int_action = int_action.cpu().numpy()
    eye = np.eye(action_dim)
    if not seq_len:
        return eye[int_action]
    # per-timestep encoding stacked along a new leading axis
    per_step = [eye[int_action[t]] for t in range(seq_len)]
    return np.stack(per_step)
def avail_choose(x, avail_x=None):
    """Mask logits: entries where avail_x == 0 are set to -1e10 so they
    are never selected; returns the (possibly mutated) tensor."""
    logits = to_torch(x)
    if avail_x is None:
        return logits  # FixedCategorical(logits=logits)
    mask = to_torch(avail_x)
    logits[mask == 0] = -1e10
    return logits  # FixedCategorical(logits=logits)
def tile_images(img_nhwc):
    """
    Tile N images into one big PxQ image
    (P,Q) are chosen to be as close as possible, and if N
    is square, then P=Q.
    input: img_nhwc, list or array of images, ndim=4 once turned into array
    n = batch index, h = height, w = width, c = channel
    returns:
        bigim_HWc, ndarray with ndim=3
    """
    images = np.asarray(img_nhwc)
    n, h, w, c = images.shape
    rows = int(np.ceil(np.sqrt(n)))
    cols = int(np.ceil(float(n) / rows))
    # Pad with blank (all-zero) images so the grid is exactly rows*cols.
    blanks = [images[0] * 0] * (rows * cols - n)
    grid = np.array(list(images) + blanks)
    grid = grid.reshape(rows, cols, h, w, c)
    # Interleave image rows with grid rows, then flatten to one big image.
    grid = grid.transpose(0, 2, 1, 3, 4)
    return grid.reshape(rows * h, cols * w, c)
================================================
FILE: examples/Social_Cognition/ToM/BrainArea/PFC_ToM.py
================================================
from braincog.base.learningrule.STDP import *
from braincog.base.brainarea.PFC import dlPFC
from utils.Encoder import *
# exploitation vs. exploration settings
num_enpop = 6   # neurons per encoding population (presumably; TODO confirm against PopEncoder)
num_depop = 10  # neurons per decoding population for each action (used in PFC_ToM.forward)
greedy = 0.8  # base probability of exploiting in epsilon-greedy selection (was 0.5)
# state configuration
A_state = 4   # NOTE(review): looks like the number of agent states — confirm
N_state = 6   # NOTE(review): looks like the number of state values — confirm
cell_num = 6  # NOTE(review): looks like grid cells per side — confirm
# action configuration
C=10  # output rows allotted to each action in the weight matrix (used in update_s)
class PFC_ToM(dlPFC):
    """
    SNNLinear

    Spiking linear decision layer modelling dlPFC for the Theory-of-Mind
    agent.  The input state is population-encoded, driven through two
    spiking node layers joined by ``self.fc``, and the action whose
    decoding population spikes most is chosen epsilon-greedily.  Learning
    is reward-modulated STDP via the eligibility trace ``self.c``.
    """
    def __init__(self,
                 step,
                 encode_type,
                 in_features:int,
                 out_features:int,
                 bias,
                 node,
                 num_state,
                 greedy=0.8,
                 *args,
                 **kwargs):
        """
        @param step: number of simulation time steps per forward pass
        @param encode_type: encoding mode handed to PopEncoder
        @param in_features: input size of the linear connection
        @param out_features: output size of the linear connection
        @param bias: whether the linear connection has a bias term
        @param node: neuron node class, instantiated for both layers
        @param num_state: number of state values used by the encoder
        @param greedy: base exploitation probability (epsilon-greedy)
        """
        super().__init__(step, encode_type, in_features, out_features, bias, *args, **kwargs)
        self.encoder = PopEncoder(self.step, encode_type)
        self.encoder.device = torch.device('cpu')
        self.bias = bias
        self.in_features = in_features
        self.out_features = out_features
        self.node1 = node(threshold=0.5, tau=2.)
        self.node_name1 = node
        self.node2 = node(threshold=0.5, tau=2.)
        self.node_name2 = node
        self.num_state = num_state
        self.greedy = greedy
        self.fc = self._create_fc()
        self.c = self._rest_c()
    def _rest_c(self):
        # Re-initialise the eligibility trace with uniform random values.
        c = torch.rand((self.out_features, self.in_features))  # eligibility trace
        return c
    def _create_fc(self):
        """
        the connection of the SNN linear
        @return: nn.Linear
        """
        fc = nn.Linear(in_features=self.in_features,
                       out_features=self.out_features, bias=self.bias)
        return fc
    def update_c(self, c, dw, tau_c=0.2):
        """
        update the trace of eligibility
        @param c: a tensor to record eligibility
        @param dw: the results of STDP
        @param tau_c: the parameter of trace decay
        @return: a update tensor to record eligibility
        Equation:
            delta_c = (-(c / tau_c) + dw) * dela_t
            c = c + delta_c
        reference:
        """
        # delta_c = -(c / tau_c) + dw #dela_t = 1 ignore
        # c = c + delta_c
        # NOTE(review): the decay term above is commented out, so only the
        # scaled STDP increment is accumulated here.
        c = c + tau_c * dw
        return c
    def _call_reward(self, R, c, s, T_map): # eligibility
        """
        R-STDP
        @param R: reward
        @param c: a tensor to record eligibility
        @param s: weight of network
        @param T_map: the mapping of the state-action pair
        @return: update weight of network
        Equation:
            delta_s = c * reward
            s = s + delta_s
        reference:
        """
        # Modulate the trace by the reward; note this mutates `c` in place
        # (the caller passes self.c), then clamps to [-1, 1].
        c[c > 0] = c[c > 0] * R * 1
        c[c <= 0] = - c[c <= 0] * R * 1
        c = c.clamp(min=-1, max=1)
        # print('before',s[:, torch.where(T_map.gt(0))[1][0]])
        # Only the entries selected by T_map (the taken state-action pair)
        # receive the reward-modulated update.
        s = s + c * T_map
        # # print('after',s[:, torch.where(T_map.gt(0))[1][0]])
        # Min-max normalise each column of the weight matrix to [0, 1].
        s = (s - s.min(dim=0).values.unsqueeze(dim=1).T.detach().repeat(s.shape[0], 1)) / (
                s.max(dim=0).values.unsqueeze(dim=1).T.detach().repeat(s.shape[0], 1) -
                s.min(dim=0).values.unsqueeze(dim=1).T.detach().repeat(s.shape[0], 1)
        )
        # s = s * 0.5
        return s
    def update_s(self, R, mapping):
        """Apply the reward-modulated STDP update for one state-action pair.

        @param R: scalar reward
        @param mapping: dict with keys 'action' and 'state' identifying
            the state-action pair to credit
        """
        # T_map selects the block of C output rows belonging to the taken
        # action and the input columns active in the encoded state.
        T_map = torch.zeros((self.out_features, self.in_features))
        T_map[mapping['action']*C:mapping['action']*C+C,\
            torch.where(torch.tensor(self.encoder(mapping['state'],\
            self.in_features, self.num_state)[:, 0]).gt(0))]=1
        self.fc.weight.data = self._call_reward(R, self.c, self.fc.weight.data, T_map)
        # print(mapping, 'mapping')
    def forward(self, inputs, num_action, episode):
        """
        decision
        @param inputs: state
        @param num_action: num_action # consider to delete
        @param episode: current episode index; raises the exploitation
            probability as training progresses
        @return: action (plain Python int)
        """
        inputs = self.encoder(inputs, self.in_features, self.num_state)
        count_group = torch.zeros(num_action)
        stdp = STDP(self.node2, self.fc, decay=0.80)
        # self.c = self._rest_c()
        # stdp.connection.weight.data = torch.rand((self.out_features, self.in_features))
        for t in range(self.step):
            l1_in = torch.tensor(inputs[:, t])
            l1_out = self.node1(l1_in).unsqueeze(0) #pre : l1_out
            l2_out, dw = stdp(l1_out) #dw -- STDP
            self.c = self.update_c(self.c, dw[0])
            # l2_out = l2_out.T
            for i in range(num_action):
                # Spike total of the decoding population assigned to action i.
                count_group[i] = l2_out.T[i * num_depop:(i + 1) * num_depop].sum()
        # exploration or exploitation
        epsilon = random.random()
        if epsilon < self.greedy + episode * 0.004:#:
            action = count_group.argmax()
        else:
            # NOTE(review): random action is drawn from 0..3 regardless of
            # num_action — confirm num_action is always 4 here.
            action = torch.tensor(random.randint(0, 3))
        return action.item()
================================================
FILE: examples/Social_Cognition/ToM/BrainArea/TPJ.py
================================================
import torch
from braincog.base.brainarea.Insula import *
from rulebasedpolicy.world_model import *
from BrainArea.dACC import *
from BrainArea.PFC_ToM import *
NPC_1 = 2  # NOTE(review): presumably the grid code for NPC 1 — confirm against env
NPC_2 = 3  # NOTE(review): presumably the grid code for NPC 2 — confirm against env
Agent = 4  # NOTE(review): presumably the grid code for the agent — confirm against env
# exploitation vs. exploration settings
num_enpop = 6   # neurons per encoding population
num_depop = 10  # neurons per decoding population for each action
greedy = 0.8  # base probability of exploiting in epsilon-greedy selection (was 0.5)
# state configuration
A_state = 4
N_state = 6
cell_num = 6
# action configuration
C=10  # output rows allotted to each action in the PFC weight matrix
class ToM:
    """Theory-of-Mind agent utilities: perspective taking (TPJ), belief
    reasoning, danger evaluation via dACC, and altruistic action selection."""
    def __init__(self, env):
        """
        @param axis: position information observed by the agent itself
        @param obs: occlusion relations of the grid
        """
        self.axis = None
        self.obs = None
        self.NPC_num = None
        self.env = env
        self.env.trigger = 0
    def TPJ(self, NPC_num, axis, obs):
        """
        perspective_taking
        agent take NPC2's perspective
        @param NPC_num: which NPC?
        @return:
            axis_new: occlusion relations of the other agents as seen from
                the other's perspective, returned as axis
            axis_switch: occlusion relations of the other agents as seen
                from self's perspective, returned as axis
            obs_switch: occlusion relations of the other agents as seen
                from the other's perspective, returned as obs
        """
        self.env.trigger = 0
        # [6, 6] is the placeholder coordinate for "not visible / absent".
        axis_switch = [[6,6], [6,6], [6,6]]
        axis_new = [[6, 6], [6, 6], [6, 6]]
        self.axis = axis
        self.obs = obs
        # Swap self (index 0) with the selected NPC so the world is viewed
        # from that NPC's position; NPC_1 (index 1) stays where it is.
        axis_switch[0], axis_switch[NPC_num] = axis[NPC_num], axis[0]
        axis_switch[1] = axis[1]
        obs_switch = big_env(self.obs)
        # Swap the two agents' cells in the observation grid as well.
        obs_switch[self.axis[0][0], self.axis[0][1]], obs_switch[self.axis[NPC_num][0],self.axis[NPC_num][1]] = \
            obs_switch[self.axis[NPC_num][0],self.axis[NPC_num][1]],obs_switch[self.axis[0][0], self.axis[0][1]]
        x = np.argwhere((obs_switch==2)|(obs_switch==8))  # NOTE(review): unused
        if self.axis[NPC_num][0] != 6 or self.axis[NPC_num][1] != 6:
            # The NPC is visible: recompute what it can see behind shelters.
            shelter_obs = shelter_env(obs_switch[1:6,1:6])
            obs_switch[1:6,1:6], m = self.gain_obs(a=obs_switch[1:6,1:6], aa=shelter_obs, b=axis_switch[1], c=axis_switch[2], bb=2,cc=4)
            if m == True:
                # NPC_1 is hidden by a shelter from this viewpoint.
                axis_switch[1] = [6,6]
        else:
            obs_switch = []
        axis_new[0] = axis_switch[NPC_num]
        axis_new[1] = axis_switch[1]
        axis_new[NPC_num] = axis_switch[0]
        return axis_new, axis_switch, obs_switch
    def gain_obs(self, a,aa,b,c,bb,cc):
        """Mark agents b and c on grid *a*, hiding whichever sits behind a
        shelter.

        @param a: inner observation grid (modified in place)
        @param aa: shelter mask for the same region (0 = sheltered cell)
        @param b: [row, col] position (1-based) of the first agent, or [6, 6]
        @param c: [row, col] position (1-based) of the second agent
        @param bb: grid code written for b when visible
        @param cc: grid code written for c when visible
        @return: (a, m) where m is True when b was hidden by a shelter
        """
        m = False
        if b!=[6,6]:
            if aa[b[0]-1, b[1]-1] == 0:
                a[b[0]-1, b[1]-1] = 1#2
                m = True
            else:
                a[b[0] - 1, b[1] - 1] = bb
        if aa[c[0]-1, c[1]-1] ==0:
            # print('-------')
            a[c[0]-1, c[1]-1] = 1#4
        else:
            a[c[0] - 1, c[1] - 1] = cc
        return a, m
    def belief_reasoning(self, test_x, net_NPC, num_action, episode):
        """Infer the NPC's next action by running the agent's model of the
        NPC's policy network on the perspective-taken state."""
        output = net_NPC(inputs=test_x, \
                         num_action=num_action, \
                         episode=episode)
        return output
    def state_evaluation(self, prediction_next_state):
        """
        state_evaluation
        Judge whether the predicted next state is dangerous using a
        pre-trained dACC network.
        @param prediction_next_state: predicted positions of all agents
        @return: boolean verdict decoded from the dACC output
        """
        input = np.array(prediction_next_state)
        # Encode whether agent (index 0) and NPC_2 (index 2) differ on each
        # coordinate; 10 is the spike-input scaling factor.
        test_x = torch.tensor([[(int(bool(input[0][0] - input[2][0])))*10, (int(bool(input[0][1] - input[2][1])))*10]])
        T = 5
        num_popneurons = 2
        safety = 2
        dACC_net = dACC(step=T, encode_type='rate', bias=True,
                        in_features=num_popneurons, out_features=safety,
                        node=node.LIFNode)
        # Load the pre-trained dACC weights relative to the script location.
        dACC_net.load_state_dict(torch.load(os.path.join(sys.path[0], 'BrainArea/checkpoint', 'dACC_net.pth'))['dacc'])
        output = dACC_net(inputs=test_x, epoch=50)
        output = bool(int(output[0].cpu().detach().numpy().tolist()))
        print(output,test_x)
        return output
    def prediction_state(self, axis_new, axis, action_NPC1, net, num_action, episode):
        """
        Predict the next state from the current state and experience.
        @return: the state of the next step
        """
        self.env.trigger = 0
        # action id -> (dx, dy) displacement
        action_move = {
            0: (0, -1),
            1: (0, 1),
            2: (-1, 0),
            3: (1, 0),
            4: (0, 0)
        }
        next_axis = [[6,6],[6,6],[6,6]]
        # inputspike_test = np.array([axis_new[0],axis_new[1],axis_new[2]])
        inputspike_test = sum(axis_new, [])  # flatten [[r, c], ...] into [r, c, r, c, ...]
        action_NPC2 = self.belief_reasoning(test_x=inputspike_test, net_NPC=net, num_action=num_action, episode=episode)
        action_agent = 3
        #NPC_1
        next_axis[1][0] = axis[1][0] + action_move[action_NPC1][1]
        next_axis[1][1] = axis[1][1] + action_move[action_NPC1][0]
        #NPC_2
        # Only move NPC_2 when the target cell is not code 5
        # (NOTE(review): presumably an impassable cell — confirm).
        if self.obs[axis[2][0] + action_move[action_NPC2][1]-1, axis[2][1] + action_move[action_NPC2][0]-1] != 5:
            next_axis[2][0] = axis[2][0] + action_move[action_NPC2][1]
            next_axis[2][1] = axis[2][1] + action_move[action_NPC2][0]
        #NPC_agent
        next_axis[0][0] = axis[0][0] + action_move[action_agent][1]
        next_axis[0][1] = axis[0][1] + action_move[action_agent][0]
        return next_axis
    def altruism(self, axis_switch, axis_NPC, n_actions):
        """
        Assume there is a switch that the agent can press to make the NPC
        stop moving.
        Q_bad: the NPC's biased Q values from its erroneous observation
        Q_good: the correct Q values
        Q_delta: its per-action minima identify the actions most likely to
            put the NPC in danger; choose the action whose minimum is largest
        @param axis_switch: state as seen from self's perspective
        @param axis_NPC: state as seen by the NPC
        @param n_actions: size of the action set
        @return: the action of the next step
        """
        actions = list(range(n_actions))
        action_NPC_list = list(range(n_actions))
        #others' view
        data_NPC = pd.read_csv('./data/NPC_assessment.csv', index_col=[0],
                               dtype={1: np.float64, 2: np.float64, 3: np.float64, 4: np.float64,
                                      5: np.float64})
        #self's view
        data_agent = pd.read_csv('./data/agent_assessment.csv', index_col=[0],
                                 dtype={1: np.float64, 2: np.float64, 3: np.float64, 4: np.float64,
                                        5: np.float64})
        # print(axis_NPC, axis_switch)
        Q_bad = data_NPC.loc[str(axis_NPC), :]
        if str(axis_switch) not in data_agent.index:
            # append new state to q table
            # print('1')
            data_agent = data_agent.append(
                pd.Series(
                    [0] * len(list(range(self.env.n_actions))),
                    index=data_agent.columns,
                    name=str(axis_switch),
                )
            )
        Q_good = data_agent.loc[str(axis_switch), :]
        Q_delta = Q_good - Q_bad
        # print(Q_delta)
        # max_Q_delta = [None] * n_actions
        min_Q_delta_set = []
        #stop
        for action_a in actions:
            if action_a == 4:
                # Action 4 presses the switch, so the NPC can only stay put.
                action_NPC_list = [4]
            min_Q_delta = []
            for i in action_NPC_list:
                #
                # print(i)
                min_Q_delta.append(Q_delta[i])
            min_Q_delta_set.append(min(min_Q_delta))
        # print(min_Q_delta_set,'---------')
        # Maximin: the agent action with the best worst-case delta.
        action_altruism = min_Q_delta_set.index(max(min_Q_delta_set))
        if action_altruism == 4:
            self.env.trigger = 1
            # print('---------------------------------------------')
            # env.SHOW()
            # time.sleep(1.0)
        # max_Q_delta[action_a] = max(min_Q_delta_set)
        return action_altruism
    def INS(self, axis1, axis2):
        """Insula-based mismatch detector: count time steps (out of 2) in
        which the Insula network fires for the difference of axis1 and
        axis2; a positive count signals disagreement."""
        num_IPLM = axis1.shape[1]
        num_IPLV = axis1.shape[1]
        Insula_connection = []
        # IPLV-Insula
        con_matrix0 = torch.eye(num_IPLM, dtype=torch.float) * 2
        Insula_connection.append(CustomLinear(con_matrix0))
        # STS-Insula
        con_matrix1 = torch.eye(num_IPLV, dtype=torch.float) * 2
        Insula_connection.append(CustomLinear(con_matrix1))
        Insula = InsulaNet(Insula_connection)
        confidence = 0
        Insula.reset()
        for t in range(2):
            Insula((axis1-axis2) * 10, torch.zeros_like(axis1) * 10)
            if sum(sum(Insula.out_Insula)) > 0:
                confidence = confidence + 1
        return confidence
================================================
FILE: examples/Social_Cognition/ToM/BrainArea/__init__.py
================================================
================================================
FILE: examples/Social_Cognition/ToM/BrainArea/dACC.py
================================================
import torch
import matplotlib.pyplot as plt
import numpy as np
np.set_printoptions(threshold=np.inf)
from utils.one_hot import *
import os
import time
import sys
from tqdm import tqdm
from braincog.model_zoo.base_module import BaseLinearModule, BaseModule
from braincog.base.learningrule.STDP import *
import sys
sys.path.append("..")
class dACC(BaseModule):
    """
    SNNLinear

    Spiking linear module modelling dACC evaluation: two spiking node
    layers joined by ``self.fc``, with an STDP-driven eligibility trace.
    """
    def __init__(self,
                 step,
                 encode_type,
                 in_features:int,
                 out_features:int,
                 bias,
                 node,
                 *args,
                 **kwargs):
        """
        @param step: simulation time steps per input sample
        @param encode_type: encoder mode passed to BaseModule
        @param in_features: input size of the linear connection
        @param out_features: output size of the linear connection
        @param bias: whether the linear connection has a bias term
        @param node: neuron node class, instantiated for both layers
        """
        super().__init__(step, encode_type, *args, **kwargs)
        self.bias = bias
        self.in_features = in_features
        self.out_features = out_features
        self.node1 = node(threshold=0.5, tau=2.)
        self.node_name1 = node
        self.node2 = node(threshold=0.1, tau=2.)
        self.node_name2 = node
        self.fc = self._create_fc()
        self.c = self._rest_c()
    def _rest_c(self):
        # Re-initialise the eligibility trace with uniform random values.
        c = torch.rand((self.out_features, self.in_features)) # eligibility trace
        return c
    def _create_fc(self):
        """
        the connection of the SNN linear
        @return: nn.Linear
        """
        fc = nn.Linear(in_features=self.in_features,
                       out_features=self.out_features, bias=self.bias)
        return fc
    def update_c(self, c, STDP, tau_c=0.2):
        """
        update the trace of eligibility
        @param c: a tensor to record eligibility
        @param STDP: the results of STDP
        @param tau_c: the parameter of trace decay
        @return: a update tensor to record eligibility
        Equation:
            delta_c = (-(c / tau_c) + STDP) * dela_t
            c = c + delta_c
        reference:
        """
        # Only the scaled STDP increment is applied (no decay term here).
        c = c + tau_c * STDP
        return c
    def forward(self, inputs, epoch):
        """
        decision
        @param inputs: state
        @param epoch: unused in this forward pass; kept for interface
            compatibility with the training script
        @return: action
        """
        output = []
        stdp = STDP(self.node2, self.fc, decay=0.80)
        self.c = self._rest_c()
        # stdp.connection.weight.data = torch.rand((self.out_features, self.in_features))
        for i in range(inputs.shape[0]):
            for t in range(self.step):
                l1_in = torch.tensor(inputs[i, :])
                l1_out = self.node1(l1_in).unsqueeze(0) #pre : l1_out
                l2_out, dw = stdp(l1_out) #dw -- STDP
                self.c = self.update_c(self.c, dw[0])
            # Record the minimum output value after the last time step of
            # this sample.
            output.append(torch.min(l2_out))
            # output.append((l2_out.any() == 0).cpu().detach().numpy().tolist())
        return output
# if __name__ == '__main__':
# np.random.seed(6)
# T = 5
# num_popneurons = 2
# safety = 2
# epoch = 50
# file_name = "/home/zhaozhuoya/braincog/examples/ToM/data/injury_value.txt"
# state = []
# with open(file_name) as f:
# data = []
# data_split = f.readlines() #
# for i in data_split:
# state.append(one_hot(int(i[0])))
#
# output = np.array(state)
# train_y = output
# test_y = output[79:82]#output[12].reshape(1,2)
#
# file_name = "/home/zhaozhuoya/braincog/examples/ToM/data/injury_memory.txt"
# state = []
# with open(file_name) as f:
# data_split = f.readlines()
# for i in data_split:
# data = []
# data.append(int(bool(abs(int(i[2]) - int(i[18]))))*10)
# data.append(int(bool(abs(int(i[5]) - int(i[21]))))*10)
# state.append(data)
# input = np.array(state)
# train_x = input
# test_x = input[79:82]
# dACC_net = dACC(step=T, encode_type='rate', bias=True,
# in_features=num_popneurons, out_features=safety,
# node=node.LIFNode)
# dACC_net.fc.weight.data = torch.rand((safety, num_popneurons))
# dACC_net.load_state_dict(torch.load('./checkpoint/dACC_net.pth')['dacc'])
# output = dACC_net(inputs=train_x, epoch=50)
# for i in range(len(output)):
# print(output[i], train_x[i])
# torch.save({'dacc': dACC_net.state_dict()}, os.path.join('./checkpoint', 'dACC_net.pth'))
# dACC_net.load_state_dict(torch.load('./checkpoint/dACC_net.pth')['dacc'])
# output = dACC_net(inputs=test_x, epoch=50)
# for i in range(len(test_x)):
#
# print(output[i],test_x[i])
================================================
FILE: examples/Social_Cognition/ToM/BrainArea/one_hot.py
================================================
from numpy import argmax
import numpy as np
def one_hot(value, num_classes=2):
    """Return a (1, num_classes) one-hot row vector for *value*.

    @param value: index to set hot (0 <= value < num_classes)
    @param num_classes: length of the encoding; defaults to 2, matching
        the original hard-coded '01' alphabet, so existing callers are
        unaffected
    @return: np.ndarray of shape (1, num_classes)
    """
    letter = [0] * num_classes
    letter[value] = 1
    return np.array([letter])
# print(one_hot(4))
================================================
FILE: examples/Social_Cognition/ToM/BrainArea/test.py
================================================
import torch
from braincog.base.connection.CustomLinear import *
from braincog.base.node.node import *
from braincog.base.learningrule.STDP import *
from braincog.base.brainarea.IPL import *
from braincog.base.brainarea.Insula import *
if __name__ == "__main__":
    # Smoke test for InsulaNet: identical inputs (a vs b) should produce
    # no mismatch signal; differing inputs (a vs c) should.
    num_neuron = 4
    num_vPMC = num_neuron
    num_STS = num_neuron
    num_IPLM = num_neuron
    num_IPLV = num_neuron
    num_Insula = num_neuron
    # InsulaNet
    # connection
    Insula_connection = []
    # IPLV-Insula
    con_matrix0 = torch.eye(num_IPLM, dtype=torch.float) * 2
    Insula_connection.append(CustomLinear(con_matrix0))
    # STS-Insula
    con_matrix1 = torch.eye(num_IPLV, dtype=torch.float) * 2
    Insula_connection.append(CustomLinear(con_matrix1))
    Insula = InsulaNet(Insula_connection)
    a = torch.tensor([[1.,2.,1.,2.]])
    b = torch.tensor([[1., 2., 1., 2.]])
    c = torch.tensor([[2., 2., 4., 2.]])
    confidence = [0, 0]
    # Case 1: identical inputs -> tally into confidence[0].
    for t in range(2):
        Insula(a*10, b*10)
        if sum(sum(Insula.out_Insula)) > 0:
            confidence[0] = confidence[0] + 1
    Insula.reset()
    # Case 2: differing inputs -> tally into confidence[1].
    # BUG FIX: the original incremented confidence[0] here as well, so
    # confidence[1] was always 0 and the two cases were conflated.
    for t in range(2):
        Insula(a*10, c*10)
        if sum(sum(Insula.out_Insula)) > 0:
            confidence[1] = confidence[1] + 1
    Insula.reset()
    print(confidence)
================================================
FILE: examples/Social_Cognition/ToM/README.md
================================================
# Requirements
* numpy
* scipy
* pytorch >= 1.7.0
* torchvision
* pygame
# Run
## Train
* the file to be run: main_both.py
* args:
* the path to save net_NPC: --save_net_N
* the path to save net_a: --save_net_a
* time steps: --T
```bash
python main_both.py --save_net_N=net_NPC.pth --save_net_a=net_agent.pth --episodes=45 --trajectories=30 --T=50 --mode=train --task=both
```
## Test
You can use the weights saved during training in the test environment.
```bash
python main_ToM.py --save_net_N=net_NPC.pth --save_net_a=net_agent.pth --episodes=45 --trajectories=30 --T=50 --mode=train --task=both
```
# Citation
```
@article{zhao2022brain,
title = {A Brain-Inspired Theory of Mind Spiking Neural Network for Reducing Safety Risks of Other Agents},
author = {Zhao, Zhuoya and Lu, Enmeng and Zhao, Feifei and Zeng, Yi and Zhao, Yuxuan},
journal = {Frontiers in neuroscience},
pages = {446},
year = {2022},
publisher = {Frontiers}
}
```
================================================
FILE: examples/Social_Cognition/ToM/__init__.py
================================================
================================================
FILE: examples/Social_Cognition/ToM/data/NPC_assessment.csv
================================================
,0,1,2,3,4
"[[3, 5], [6, 6], [4, 2]]",-0.868083432541418,-0.733763528003341,11.6748771364348,-1.13941568407204,-0.640638369281893
"[[3, 5], [6, 6], [4, 3]]",-0.418851253559184,-0.353923546821901,2.45363281923355,-0.450103087254002,-0.37827704665847
"[[3, 5], [6, 6], [5, 3]]",-0.021898491492153,-0.01,-0.01,-0.01,-0.011802114342782
"[[4, 5], [6, 6], [5, 2]]",-0.024908407481436,-0.0199,0,-0.01009,-0.01017991
"[[3, 5], [6, 6], [5, 2]]",-0.023268882562677,-0.0199,-0.029791,-0.0298801,-0.037630851041005
"[[3, 4], [6, 6], [5, 3]]",-0.021197399198277,-0.01,-0.0199,-0.013038149463134,-0.01999
"[[3, 3], [6, 6], [5, 3]]",0,-0.5,0,0,-0.01
"[[3, 3], [6, 6], [5, 2]]",0,0,0,-0.01,0
"[[3, 4], [6, 6], [5, 1]]",-0.01,0,-0.01,-0.01009,0
"[[3, 4], [6, 6], [4, 1]]",-0.0101791,-0.0199,-0.01,-0.01,-0.01
"[[3, 3], [6, 6], [4, 1]]",-0.01,0,-0.01,0,0
"[[3, 2], [6, 6], [4, 2]]",0,-0.01,0,0,0
"[[4, 2], [6, 6], [4, 3]]",0,0,0,-0.01,0
"[[4, 3], [6, 6], [4, 4]]",0.089467366020037,-0.010984993825322,-0.01,-0.013452906086556,-0.011867720274541
"[[4, 2], [6, 6], [4, 5]]",4.09524024945437,-0.090412523006309,-0.09623134888059,-0.104387744707876,-0.070973835764104
"[[3, 2], [6, 6], [4, 5]]",1.40391860495941,0.022997490426814,43.0024182397514,1.53202462973447,3.08888716957364
"[[2, 2], [6, 6], [4, 5]]",-0.019749094852309,0.036781365436129,15.9675319796684,-0.020051902468894,0.04621779219384
"[[4, 1], [6, 6], [4, 5]]",5.50845885876033,-0.030316371018288,-0.01999,0.003924246487139,-0.01999
"[[5, 1], [6, 6], [4, 5]]",-0.04218527817012,-0.089619886118225,-0.089555438562711,-0.09202879886063,-0.089640838741259
"[[5, 2], [6, 6], [4, 5]]",-0.181115480137325,-0.188135727209872,-0.190391091543011,-0.189212077096485,-0.18825217812927
"[[4, 3], [6, 6], [4, 5]]",3.20388938261604,-0.277669046855729,-0.263170616607684,-0.263425702870801,-0.273172354685186
"[[4, 4], [6, 6], [4, 5]]",0.088264809806085,-0.536047549432865,-0.496801753642635,-1.7227487525,-0.525308942251634
"[[3, 3], [6, 6], [4, 5]]",-0.21666970575642,-0.137772024778711,35.7440209050429,0.074933770913733,1.72993072468866
"[[2, 3], [6, 6], [6, 6]]",-0.059222187527032,-0.030584932730533,1.19896558354272,-0.02997001,-0.039803368321337
"[[3, 4], [6, 6], [4, 5]]",-0.40734083054261,-0.590251991054757,12.0029208279787,-0.553842323501647,-0.473260640862758
"[[4, 5], [6, 6], [4, 5]]",0,0,0,0,0
"[[3, 5], [6, 6], [4, 4]]",-0.481765407319486,-0.511211461721295,-0.564380288792446,-0.492462845628926,-0.469959754386144
"[[3, 5], [6, 6], [4, 5]]",-0.931224461309762,-4.42392016540183,0.839256941677985,-0.924126498596367,-0.93129552223439
"[[3, 1], [6, 6], [4, 5]]",49.811526809228,0.227846091910513,5.03145376471985,3.02887653427164,4.38800221134503
"[[2, 1], [6, 6], [4, 5]]",0,0,0,0,0
"[[3, 4], [6, 6], [4, 3]]",1.19380671730375,-5.00699176831198,20.0406112237339,-0.271487593364538,18.7143725924663
"[[3, 4], [6, 6], [4, 4]]",0.218803966115085,-0.17930117271935,27.2213501282548,-0.25238045900879,0.399605811763158
"[[5, 3], [6, 6], [4, 5]]",-0.375176845280851,-0.382612376673214,-0.387234215280032,-0.391512374343505,-0.38256981360011
"[[5, 4], [6, 6], [4, 5]]",-0.648971759046833,-0.647885303763355,-0.648165623973329,-0.644724798572273,-0.648282307415749
"[[4, 5], [6, 6], [3, 2]]",-0.01999,-0.029791,-0.0199,-0.0199891,-0.029882503677127
"[[3, 5], [6, 6], [3, 2]]",-0.034237585544166,-0.029789209,-0.0297901,-0.031801666387697,-0.023636229278446
"[[4, 5], [6, 6], [4, 3]]",-0.254773250904236,-0.247726142705935,-0.4975,-0.251602554065065,-0.253564419614127
"[[4, 5], [6, 6], [4, 4]]",-0.414539088621474,-0.413180632405154,-0.45683393116936,-0.4975,-0.98509975
"[[5, 5], [6, 6], [4, 5]]",-2.88794812477989,-0.897593650362875,-0.896337578776872,-0.897383653648252,-0.896911728329473
"[[3, 4], [6, 6], [4, 2]]",-0.041846129714671,-0.0491890501,-0.049128264132635,-0.04752963742507,-0.041122932699059
"[[4, 4], [6, 6], [4, 3]]",-0.010582728786191,-0.0199,-0.019606621899442,-0.023027393843706,-0.25
"[[4, 4], [6, 6], [4, 4]]",0,0,0,0,0
"[[1, 3], [6, 6], [6, 6]]",-0.02997001,-0.037241738974595,-0.038877118273757,-0.02997001,-0.02997001
"[[1, 2], [6, 6], [6, 6]]",-0.01,0.092960550898947,-0.01,-0.01,-0.01
"[[1, 1], [6, 6], [4, 5]]",0,0.25,0,0,-0.01
"[[5, 5], [6, 6], [4, 2]]",-0.011245609205401,-0.010090875348172,-0.01,-0.01,-0.01
"[[5, 5], [6, 6], [4, 3]]",-0.023589076057546,-0.030475766514878,-0.0199,-0.010097260907938,-0.020481507078377
"[[5, 4], [6, 6], [4, 4]]",-0.014561932553648,-0.014123403608977,-0.012977094986838,-0.013958043392411,-0.0102689209
"[[5, 5], [6, 6], [3, 3]]",-0.01,0,0,0,-0.01009
"[[5, 5], [6, 6], [4, 4]]",-0.742525,-0.112519042227942,-0.104809666804011,-0.120899601119155,-0.114009048694518
"[[3, 3], [6, 6], [4, 4]]",-0.010153907821337,-0.013354155304859,3.90887441187658,-0.020453180303816,-0.010444104437097
"[[3, 5], [6, 6], [4, 1]]",-0.010526678655391,-0.01999,-0.0199,-0.013489754345763,-0.020341425501639
"[[4, 5], [6, 6], [4, 2]]",-0.040636256244247,-0.0297901,-0.029794680345807,-0.032266639674154,-0.030144504483964
"[[4, 5], [6, 6], [4, 1]]",-0.022834041173462,-0.029701,-0.029701,-0.030019473640868,-0.029791
"[[4, 4], [6, 6], [4, 2]]",-0.010267309,-0.01999,-0.25,-0.021462046263096,-0.01
"[[5, 5], [6, 6], [5, 4]]",0,-0.01,0,0,-0.010213344196849
"[[3, 5], [6, 6], [3, 4]]",-0.013599152071303,0,-0.010345480742543,-0.01,-0.25
"[[3, 5], [6, 6], [3, 5]]",0,0,0,0,0
"[[3, 5], [6, 6], [3, 3]]",-0.011856953925967,-0.01999,-0.0199,-0.032817580205153,-0.010785322603466
"[[4, 4], [6, 6], [5, 2]]",0,-0.01,-0.01,-0.01009,-0.01
"[[5, 4], [6, 6], [4, 2]]",0,0,-0.01,-0.01,0
"[[3, 3], [6, 6], [4, 3]]",-0.01,-0.01,-0.0199,-0.005562516126618,-0.009541766156335
"[[3, 2], [6, 6], [4, 4]]",0,0,-0.01,-0.01,0.236388148843983
"[[4, 3], [6, 6], [4, 2]]",0,-0.01,0,-0.01,0
"[[5, 3], [6, 6], [4, 3]]",-0.01,0,0,0,0
"[[4, 5], [6, 6], [5, 3]]",-0.01,-0.01009,-0.0199,-0.010784041942122,-0.01
"[[3, 5], [6, 6], [5, 4]]",-0.012894421101825,-0.012100903162978,-0.007186044736698,0,-0.011442680105686
"[[3, 4], [6, 6], [3, 2]]",-0.01009,-0.01999,-0.0199,-0.01009,-0.010091626994647
"[[4, 4], [6, 6], [5, 4]]",0,0,0,-0.010869859082056,0
"[[4, 5], [6, 6], [3, 3]]",-0.01,-0.01,-0.01999,-0.01,-0.011561635811299
"[[5, 4], [6, 6], [4, 3]]",0,0,-0.01,0,0
"[[5, 3], [6, 6], [4, 4]]",0,-0.012647205121462,0,-0.015190262813025,0
"[[3, 2], [6, 6], [4, 3]]",0,-0.01,0,0,-0.01
"[[3, 1], [6, 6], [4, 3]]",0,-0.01,0,0,0
"[[4, 1], [6, 6], [4, 4]]",0,0,0,0,-0.002916121937907
"[[4, 3], [6, 6], [4, 3]]",0,0,0,0,0
"[[5, 5], [6, 6], [3, 4]]",-0.012969708589991,-0.01,-0.01,-0.010643056796965,0
"[[3, 4], [6, 6], [3, 4]]",0,0,0,0,0
"[[2, 3], [6, 6], [4, 4]]",0,0,0,-0.01,0
"[[2, 3], [6, 6], [5, 4]]",0,0,-0.01,0,0
"[[2, 2], [6, 6], [4, 4]]",-0.01,0,0,-0.008244859693453,0
"[[1, 2], [6, 6], [4, 3]]",0,-0.01,0,0,0
"[[3, 4], [6, 6], [5, 2]]",-0.01999,-0.01999,-0.0199,-0.01999,-0.01035517099859
"[[3, 3], [6, 6], [4, 2]]",0,0,-0.01,-0.010197581295095,0
"[[3, 3], [6, 6], [5, 4]]",0,0,0,-0.01009,0
"[[5, 5], [6, 6], [3, 1]]",-0.01999081,-0.01999,-0.01999,-0.01999,-0.01999
"[[5, 5], [6, 6], [2, 1]]",-0.01,0,-0.01,-0.01,-0.01
"[[5, 4], [6, 6], [2, 2]]",0,0,-0.01,-0.01,-0.01
"[[5, 4], [6, 6], [3, 2]]",-0.01999,-0.0199,-0.01,-0.01,-0.01
"[[5, 3], [6, 6], [3, 1]]",-0.01,-0.01,0,0,0
"[[4, 3], [6, 6], [4, 1]]",-0.01,-0.01,-0.01,-0.01,-0.01
"[[4, 2], [6, 6], [5, 1]]",0,-0.01,0,-0.01,0
"[[5, 2], [6, 6], [5, 1]]",-0.01,0,0,0,0
"[[4, 2], [6, 6], [5, 2]]",0,0,0,0,-0.01
"[[5, 4], [6, 6], [3, 3]]",-0.01,-0.01,-0.01,-0.01,-0.01
"[[4, 4], [6, 6], [3, 2]]",-0.01,-0.01,-0.0199,-0.0200791,-0.0199
"[[4, 5], [6, 6], [6, 6]]",-0.01009,-0.0199,-0.0199,-0.0199,-0.01999
"[[4, 4], [6, 6], [5, 3]]",0,0,0,-0.01,-0.01
"[[5, 5], [6, 6], [2, 2]]",-0.01,-0.01,-0.01,-0.01,0
"[[5, 4], [6, 6], [2, 3]]",-0.01,-0.01,0,-0.01,-0.01
"[[4, 5], [6, 6], [2, 2]]",-0.010198196967161,-0.01999,-0.01,-0.01009,-0.01
"[[4, 5], [6, 6], [1, 1]]",0,-0.01,0,0,-0.01
"[[4, 5], [6, 6], [2, 1]]",-0.01,-0.01,-0.01,-0.0199,-0.01
"[[4, 5], [6, 6], [3, 1]]",-0.01,-0.01009,-0.0199,-0.01,-0.01009
"[[4, 3], [6, 6], [3, 3]]",0,0,-0.01,0,0
"[[4, 2], [6, 6], [3, 4]]",-0.008998626598505,0,0,0,0
"[[4, 5], [6, 6], [5, 1]]",-0.01009,-0.0199,-0.0199,-0.020016673101577,-0.01999
"[[4, 4], [6, 6], [5, 1]]",-0.01,-0.01,-0.01,-0.01009,-0.01
"[[3, 5], [6, 6], [5, 1]]",-0.01009,-0.01,-0.0199,-0.01999,-0.01999
"[[5, 5], [6, 6], [5, 1]]",-0.01009,0,-0.01,-0.0199,-0.01
"[[5, 5], [6, 6], [4, 1]]",-0.01,-0.01,-0.0199,-0.01,-0.01999
"[[4, 4], [6, 6], [3, 1]]",0,0,0,-0.01,-0.01
"[[5, 5], [6, 6], [1, 1]]",-0.01,0,-0.01,-0.01,0
"[[5, 5], [6, 6], [1, 2]]",-0.01,-0.01,0,-0.01,-0.01
"[[4, 4], [6, 6], [2, 1]]",0,-0.01,-0.01,-0.01,-0.01
"[[4, 3], [6, 6], [3, 1]]",-0.01,0,0,-0.01,-0.01
"[[4, 4], [6, 6], [4, 1]]",-0.010289510270133,-0.01,-0.0199,-0.02007991,-0.01999
"[[5, 4], [6, 6], [4, 1]]",0,0,-0.01,0,-0.01
"[[5, 4], [6, 6], [3, 1]]",-0.01,0,0,-0.01,-0.01
"[[3, 4], [6, 6], [2, 2]]",0,-0.01,0,-0.01,-0.01
"[[3, 4], [6, 6], [6, 6]]",-0.01,-0.01,-0.01,-0.01999,-0.01999
"[[3, 3], [6, 6], [1, 2]]",0,-0.01,0,0,0
"[[4, 3], [6, 6], [1, 3]]",0,0,-0.01,-0.01,0
"[[4, 4], [6, 6], [6, 6]]",-0.01,-0.01,-0.0199,-0.01,-0.01
"[[5, 5], [6, 6], [2, 3]]",-0.01,-0.01999,-0.01999,-0.01999,-0.01
"[[5, 4], [6, 6], [1, 3]]",-0.0199,-0.01,-0.01,-0.01,-0.01
"[[5, 4], [6, 6], [1, 2]]",0,-0.01,-0.01,-0.01,0
"[[5, 4], [6, 6], [1, 1]]",-0.0199,-0.01,-0.01,-0.01,-0.01
"[[4, 4], [6, 6], [1, 1]]",-0.01,-0.01,-0.01,0,0
"[[5, 3], [6, 6], [1, 1]]",0,0,-0.01,-0.01009,0
"[[5, 2], [6, 6], [1, 1]]",-0.01,0,0,-0.01,-0.01
"[[5, 3], [6, 6], [1, 2]]",0,-0.01,0,-0.01,0
"[[5, 3], [6, 6], [1, 3]]",-0.01,-0.01,-0.01,0,-0.0199
"[[5, 3], [6, 6], [2, 3]]",0,0,-0.01,0,0
"[[5, 2], [6, 6], [2, 3]]",-0.01,0,0,-0.01,0
"[[5, 3], [6, 6], [3, 3]]",0,0,0,-0.01,0
"[[4, 3], [6, 6], [1, 2]]",-0.01,-0.01,-0.01,-0.01,-0.01
"[[3, 3], [6, 6], [2, 2]]",-0.01,-0.01,-0.01,-0.01,-0.01
"[[2, 3], [6, 6], [3, 2]]",0,0,-0.01,-0.01,0
"[[2, 3], [6, 6], [3, 1]]",-0.01,-0.01,-0.01,-0.01,-0.01
"[[2, 3], [6, 6], [2, 1]]",0,-0.01,0,0,-0.01
"[[3, 2], [6, 6], [2, 3]]",0,-0.01,0,0,0
"[[4, 2], [6, 6], [2, 2]]",-0.01,0,0,-0.01,0
"[[3, 2], [6, 6], [2, 1]]",-0.01,0,0,-0.01,0
"[[3, 3], [6, 6], [1, 1]]",0,0,0,-0.01,-0.01
"[[3, 4], [6, 6], [2, 1]]",-0.01,-0.01,0,0,-0.01
"[[4, 4], [6, 6], [2, 2]]",0,-0.01,-0.01,0,0
"[[4, 3], [6, 6], [2, 2]]",0,0,-0.01,-0.01009,0
"[[4, 2], [6, 6], [1, 2]]",-0.01,0,-0.01,-0.01,-0.01
"[[3, 2], [6, 6], [1, 2]]",0,0,-0.01,0,0
"[[3, 1], [6, 6], [1, 1]]",0,0,-0.01,-0.01,0
"[[3, 2], [6, 6], [1, 1]]",-0.01,-0.01,0,0,0
"[[4, 1], [6, 6], [1, 2]]",-0.01,0,0,0,0
"[[3, 1], [6, 6], [1, 3]]",0.5,0,0,0,-0.01
"[[2, 1], [6, 6], [1, 3]]",0.995,0,0,0,0
"[[2, 1], [6, 6], [2, 3]]",1.489505,0,0,0,0
"[[2, 1], [6, 6], [2, 2]]",0.5,0,0,0,0
"[[2, 1], [6, 6], [3, 2]]",0.5,0,0,0,0
"[[2, 1], [6, 6], [4, 2]]",0.5,0,1.0039955,0,0
"[[2, 1], [6, 6], [5, 2]]",0.995,0,0,0,0
"[[2, 1], [6, 6], [5, 3]]",0.995,0,0,0,0
"[[2, 1], [6, 6], [4, 3]]",0.5045,0,0.5045,0,0
"[[2, 1], [6, 6], [5, 4]]",0.995,0,0,0,0
"[[2, 1], [6, 6], [4, 4]]",0.5,0,0.5,0,0
"[[4, 4], [6, 6], [1, 2]]",-0.01,0,0,-0.01009,0
"[[4, 3], [6, 6], [1, 1]]",-0.01,0,0,0,0
"[[3, 3], [6, 6], [2, 1]]",-0.0199,0,0,0,0
"[[1, 3], [6, 6], [3, 1]]",0,0,-0.01,0,0
"[[1, 2], [6, 6], [3, 1]]",-0.01,0,0,0,-0.01
"[[1, 2], [6, 6], [4, 1]]",-0.01,0,-0.0199,0,0
"[[1, 1], [6, 6], [5, 1]]",0,0,0,-0.01,-0.01
"[[1, 2], [6, 6], [5, 1]]",0,-0.01,0,-0.01,0
"[[1, 3], [6, 6], [5, 1]]",-0.01,0,0,0,0
"[[1, 3], [6, 6], [5, 2]]",0,0,-0.01,0,-0.01
"[[1, 2], [6, 6], [5, 2]]",-0.01,-0.01,0,-0.01,0
"[[2, 2], [6, 6], [5, 1]]",-0.01,0,0,0,-0.01
"[[2, 2], [6, 6], [4, 1]]",-0.01,0,0,0,0
"[[1, 1], [6, 6], [4, 2]]",-0.01,0,0,0,-0.01
"[[1, 1], [6, 6], [4, 1]]",0,0,0,0,-0.01
"[[1, 1], [6, 6], [5, 2]]",0,0,0,-0.01,-0.01
"[[1, 2], [6, 6], [5, 3]]",0,0,-0.01,0,-0.01
"[[1, 1], [6, 6], [5, 3]]",0,0,0,-0.01,0
"[[1, 3], [6, 6], [4, 2]]",-0.01,0,0,0,-0.01
"[[1, 3], [6, 6], [5, 3]]",0,0,0,0,-0.01
"[[1, 3], [6, 6], [5, 4]]",0,0,0,0,-0.010266711022056
"[[5, 5], [6, 6], [3, 2]]",-0.010094761832099,0,0,0,0
"[[3, 5], [6, 6], [3, 1]]",-0.01,-0.0101791,-0.01,0,-0.01
"[[3, 5], [6, 6], [2, 1]]",-0.01,-0.01009,-0.01,0,0
"[[3, 5], [6, 6], [6, 6]]",-0.01999,-0.0199891,-0.01999,-0.01999,-0.01999
"[[3, 4], [6, 6], [1, 1]]",0,-0.01,-0.01,0,0
"[[4, 5], [6, 6], [3, 4]]",0,0,-0.01,0,0
"[[4, 4], [6, 6], [3, 5]]",0,-0.01,0,0,0
"[[5, 4], [6, 6], [3, 4]]",0,0,0,-0.01,-0.01
"[[3, 2], [6, 6], [5, 2]]",0,0,-0.01,0,0
"[[3, 1], [6, 6], [5, 3]]",0,0,-0.01,-0.01,0
"[[4, 2], [6, 6], [4, 4]]",0,0,0.004567853841478,0,0
"[[5, 4], [6, 6], [2, 1]]",0,0,0,0,-0.01
"[[3, 3], [6, 6], [3, 1]]",-0.01,0,0,-0.01009,0
"[[5, 5], [6, 6], [5, 2]]",0,-0.01,-0.01,0,0
"[[5, 4], [6, 6], [5, 1]]",-0.01009,0,-0.0199,-0.01,-0.01
"[[5, 3], [6, 6], [5, 1]]",-0.01,0,-0.5,0,0
"[[5, 2], [6, 6], [5, 2]]",0,0,0,0,0
"[[3, 4], [6, 6], [3, 3]]",0,0,0,-0.01,0
"[[5, 3], [6, 6], [4, 2]]",0,0,-0.01,-0.01,0
"[[5, 2], [6, 6], [4, 3]]",0,0,-0.01,-0.01,0
"[[5, 1], [6, 6], [4, 4]]",0,0,-0.010759540535295,0,0
"[[4, 4], [6, 6], [2, 3]]",0,-0.01,0,0,0
"[[5, 5], [6, 6], [1, 3]]",-0.01009,0,-0.01,0,0
"[[4, 2], [6, 6], [3, 3]]",0,-0.01,0,0,0
"[[3, 3], [6, 6], [5, 1]]",0,-0.01,0,0,0
"[[5, 3], [6, 6], [2, 1]]",0,-0.01,-0.01,0,-0.01
"[[2, 2], [6, 6], [3, 1]]",0,0,0,-0.01,-0.01
"[[2, 3], [6, 6], [4, 1]]",-0.01,0,0,-0.01,-0.01
"[[1, 3], [6, 6], [4, 1]]",0,0,-0.01,-0.01,0
"[[1, 2], [6, 6], [3, 2]]",0,0,-0.01,0,0
"[[1, 1], [6, 6], [3, 2]]",0,0,-0.01,0,0
"[[1, 1], [6, 6], [3, 1]]",0,0,0,-0.01,0
"[[1, 2], [6, 6], [2, 1]]",-0.01,0,0,-0.01,-0.01
"[[1, 3], [6, 6], [2, 2]]",0,0,0,0,-0.01
"[[1, 3], [6, 6], [2, 1]]",0,0,-0.01,0,0
"[[1, 2], [6, 6], [1, 1]]",-0.5,0,0,0,0
"[[1, 2], [6, 6], [1, 2]]",0,0,0,0,0
"[[2, 3], [6, 6], [4, 2]]",0,0,-0.01,0,0
"[[2, 2], [6, 6], [4, 2]]",0,0,0.5,-0.01,0
"[[2, 3], [6, 6], [4, 3]]",0,-0.002363469450096,0,0,0
"[[4, 4], [6, 6], [3, 3]]",-0.011265121636209,0,-0.01,0,0
"[[4, 3], [6, 6], [2, 3]]",-0.01,0,0,0,0
"[[5, 3], [6, 6], [2, 2]]",0,0,0,0,-0.01
"[[4, 2], [6, 6], [1, 1]]",-0.01,0,-0.01,0,0
"[[4, 1], [6, 6], [1, 1]]",0,0,0,-0.01,0
"[[2, 2], [6, 6], [2, 1]]",0,-0.01,0,-0.01,0
"[[2, 1], [6, 6], [4, 1]]",0,0,1.988015495,0,0
"[[2, 1], [6, 6], [5, 1]]",0,0,1.988015495,0,0
"[[3, 4], [6, 6], [3, 1]]",0,-0.01009,0,0,0
"[[4, 3], [6, 6], [5, 1]]",0,-0.01,0,-0.01,-0.01
"[[5, 3], [6, 6], [4, 1]]",0,-0.01,0,-0.01,0
"[[5, 3], [6, 6], [3, 2]]",0,0,0,-0.01,0
================================================
FILE: examples/Social_Cognition/ToM/data/agent_assessment.csv
================================================
,0,1,2,3,4
"[[3, 5], [6, 6], [4, 2]]",-0.052875711096120555,-0.09588695110964494,-0.06793465209301,-0.05438523750227702,-0.06406147756340548
"[[3, 4], [6, 6], [4, 3]]",-0.01009,-0.5,-0.5,-0.01998648239166538,-0.010090809999999999
"[[3, 5], [3, 3], [4, 4]]",-0.5216067506862471,-1.9701995,-0.2957229880478374,-0.5521303794833469,-0.291912651239883
"[[3, 5], [4, 3], [4, 5]]",-0.17020038264944942,-0.995,-0.17003066056014654,-0.17084067541667597,-0.18061252619011572
"[[3, 5], [5, 3], [4, 5]]",-0.310721791006476,-0.995,-0.3044196312169591,-0.3066034117158839,-0.30258164309452773
"[[3, 4], [5, 4], [4, 5]]",-0.5068849351000232,-0.5215919496170031,0.33272479226972446,-0.529392040004043,-0.5476025113537396
"[[3, 5], [5, 4], [4, 5]]",-0.9653009777091316,-8.274311927495619,-0.9562443360051255,-0.9663825574739041,-0.9577571287880812
"[[3, 3], [5, 4], [4, 5]]",-0.12144345586147146,-0.129151719633336,6.61777080738469,-0.12913417684277412,-0.12187211348591424
"[[3, 2], [5, 4], [4, 5]]",22.462534983722144,-0.037566698284270936,0.0002049650900000019,-0.020492300663398463,0.36064342058027793
"[[2, 3], [5, 4], [6, 6]]",-0.0199,-0.01,1.911710899685282,-0.02507112120629763,-0.01
"[[4, 4], [5, 4], [4, 5]]",-0.38856856345777463,-0.995,-0.3848636851605009,-0.5,-0.39216615065025906
"[[4, 3], [5, 4], [4, 5]]",-0.042040801551562916,-0.17418582965398824,-0.17906219922741637,-0.20272904839771475,-0.17845067061226685
"[[2, 2], [5, 4], [4, 5]]",-0.03593611593419173,0.19699025702983197,41.38750349042489,-0.025347536417836183,0.5235719953054313
"[[1, 3], [5, 4], [6, 6]]",-0.01999,-0.010049500000000001,-0.0199,-0.01999,-0.01999
"[[1, 2], [5, 4], [6, 6]]",-0.01999,1.1413350726366454,-0.01,-0.01009,-0.02969406465902331
"[[4, 5], [5, 4], [4, 5]]",0.0,0.0,0.0,0.0,0.0
"[[4, 5], [6, 6], [4, 3]]",-0.022056368799828086,-0.0199,-0.5,-0.0199,-0.01999
"[[4, 5], [3, 3], [4, 4]]",-0.05334353898157222,-0.058519850599,-0.058698950598999995,-0.5,-0.5
"[[4, 5], [4, 3], [4, 5]]",-0.5,-1.9701995,0.0,-0.5,-0.5
"[[4, 5], [5, 3], [4, 5]]",-0.995,-3.3967326046505,0.0,-0.995,-0.5
"[[3, 5], [6, 6], [4, 3]]",-0.5052708076706848,-0.6503942485687713,-0.31464561043008055,-0.5567063168667764,-0.3196072406789152
"[[3, 4], [4, 3], [4, 5]]",-0.06538757013815065,-0.06815609254531127,0.07157941051049484,-0.06943525585877257,-0.07243109013283643
"[[3, 4], [5, 3], [4, 5]]",-0.03627763723533166,-0.04878628543714749,0.4425679590136676,-0.04432926021182233,-0.03615058165751049
"[[5, 3], [5, 4], [4, 5]]",-0.1326255476738545,-0.12913304314430477,-0.12670229132838703,-0.995,-0.12922285286285284
"[[5, 4], [5, 4], [4, 5]]",0.0,0.0,0.0,0.0,0.0
"[[3, 4], [3, 3], [4, 4]]",-0.08912453287009733,-0.514891,-0.08664897275426689,-0.0850562755111616,-0.08029821601054066
"[[3, 3], [4, 3], [4, 5]]",-0.01,-0.01,-0.01,-0.009878977320149305,0.020248904987081453
"[[4, 3], [5, 3], [4, 5]]",0.0,-0.011055227128542065,0.0,-0.01,-0.01101218190086918
"[[5, 5], [4, 3], [4, 5]]",-0.5,-0.01,-0.0199,0.0,-0.01
"[[5, 4], [5, 3], [4, 5]]",0.0,-0.5,-0.01,-0.010717485033705036,-0.5
"[[5, 2], [5, 4], [4, 5]]",-0.07919569212102055,-0.07971975334246782,-0.0787530179209123,-0.08104017906592893,-0.07963109795235838
"[[5, 1], [5, 4], [4, 5]]",-0.04890803822116347,-0.049900099950010005,-0.049900099950010005,-0.05034279750865073,-0.049900099950010005
"[[4, 1], [5, 4], [4, 5]]",0.10944099197973417,-0.01999,-0.01999,-0.02007991,-0.01999
"[[4, 2], [5, 4], [4, 5]]",0.15058860156940262,-0.059846938189381006,-0.06722016746200328,-0.07174124407437676,-0.059850199850059994
"[[3, 1], [5, 4], [4, 5]]",5.680756414193536,-0.01,-0.01,-0.01,-0.01
"[[2, 1], [5, 4], [4, 5]]",0.0,0.0,0.0,0.0,0.0
"[[4, 4], [3, 3], [4, 4]]",0.0,-0.5,-0.5,0.0,0.0
"[[4, 4], [4, 3], [4, 4]]",0.0,-0.5,-0.5,0.0,0.0
"[[4, 4], [5, 3], [4, 4]]",0.0,-0.5,-0.5,0.0,0.0
"[[4, 4], [5, 4], [4, 4]]",0.0,0.0,0.0,0.0,0.0
"[[3, 4], [6, 6], [5, 2]]",0.0,0.0,0.0,0.0,-0.01
"[[3, 4], [3, 3], [5, 2]]",0.0,-0.01,0.0,0.0,0.0
"[[4, 4], [4, 3], [5, 2]]",0.0,0.0,-0.5,0.0,0.0
"[[4, 3], [5, 3], [5, 3]]",0.0,0.0,0.0,0.0,0.0
"[[3, 5], [3, 3], [4, 3]]",-2.2448043853624212e-05,-0.01,-0.01,-6.208892285322127e-05,-0.01
"[[3, 5], [4, 3], [4, 4]]",-0.01028609422864929,-0.5,-0.016408702021722208,-0.020784796048626684,-0.011281391343046818
"[[3, 5], [3, 3], [4, 2]]",-0.5,-0.5,-0.5,-0.9900500000000001,-0.5
"[[4, 5], [4, 3], [4, 3]]",0.0,0.0,0.0,0.0,0.0
"[[3, 5], [4, 3], [3, 4]]",-0.01,0.0,0.0,-0.5,0.0
"[[3, 5], [5, 3], [3, 5]]",0.0,0.0,0.0,-0.5,0.0
"[[3, 5], [5, 4], [3, 5]]",0.0,0.0,0.0,0.0,0.0
"[[4, 5], [6, 6], [5, 2]]",0.0,0.0,-0.01,0.0,0.0
"[[4, 4], [3, 3], [5, 1]]",0.0,0.0,0.0,0.0,-0.01
"[[4, 4], [4, 3], [4, 1]]",0.0,0.0,0.0,0.0,-0.01
"[[4, 4], [5, 3], [4, 1]]",0.0,0.0,-0.01,0.0,0.0
"[[4, 3], [5, 4], [3, 1]]",0.0,0.0,0.0,-0.01,0.0
"[[4, 4], [5, 4], [3, 1]]",0.0,-0.5,0.0,0.0,0.0
"[[5, 4], [5, 4], [3, 1]]",0.0,0.0,0.0,0.0,0.0
"[[1, 1], [5, 4], [4, 5]]",0.0,0.0,0.0,-0.01,0.0
"[[4, 4], [4, 3], [4, 5]]",-0.016306998715494642,-0.0199,-0.0199,-0.5,-0.01009945645877131
"[[4, 4], [5, 3], [4, 5]]",-0.02368161102075795,-0.5,-0.022639868292312258,-0.5,-0.026451207697811334
"[[5, 5], [3, 3], [4, 4]]",-0.5,0.0,-0.01,0.0,0.0
"[[5, 4], [4, 3], [4, 5]]",0.0,0.0,0.0,-0.01,0.0
"[[5, 5], [5, 3], [4, 5]]",-0.5,-0.01044910089955009,0.0,-0.01,0.0
"[[3, 5], [3, 3], [3, 3]]",0.0,0.0,0.0,0.0,0.0
"[[3, 3], [5, 3], [4, 5]]",0.0,-0.01062811314685189,3.5215627346534975,0.0,0.0
"[[3, 5], [4, 3], [4, 3]]",0.0,0.0,0.0,0.0,0.0
"[[3, 5], [3, 3], [5, 3]]",0.0,0.0,0.0,0.0,0.0
"[[3, 5], [4, 3], [5, 2]]",0.0,-0.01,0.0,0.0,0.0
"[[4, 5], [5, 3], [4, 2]]",-0.01,0.0,0.0,0.0,0.0
"[[3, 5], [5, 4], [4, 3]]",0.0,0.0,0.0,-0.01,0.0
"[[3, 5], [5, 4], [4, 4]]",-0.015943472966397428,-0.01,0.0,-0.01654250067191671,-0.01
"[[4, 5], [5, 4], [4, 3]]",0.0,0.0,-0.5,-0.01,-0.01
"[[4, 5], [5, 4], [4, 4]]",-0.01639600436420607,-0.01,0.0,0.0,0.0
"[[5, 5], [5, 4], [4, 5]]",-1.9701995,-0.11934219505791074,-0.5,-0.11934219505791074,-0.10945164670461537
"[[3, 4], [4, 3], [5, 4]]",0.0,0.0,0.0,-0.01,0.0
"[[3, 5], [5, 3], [4, 4]]",0.0,0.0,-0.00819848349845159,-0.014389398124565982,-0.016963089341491665
"[[4, 5], [3, 3], [4, 2]]",0.0,-0.5,-0.5,0.0,0.0
"[[4, 4], [4, 3], [4, 3]]",0.0,0.0,0.0,0.0,0.0
"[[3, 5], [6, 6], [5, 2]]",0.0,-0.01,-0.01,0.0,0.0
"[[5, 5], [4, 3], [4, 3]]",0.0,0.0,0.0,0.0,0.0
"[[3, 5], [4, 3], [5, 4]]",0.0,0.0,-0.0199,0.0,0.0
"[[3, 4], [5, 3], [4, 4]]",0.0,0.0,0.0,0.0,-0.013570917448717714
"[[3, 4], [4, 3], [4, 4]]",0.0,0.0,0.0,-0.010943974970380799,0.0
"[[4, 5], [4, 3], [4, 4]]",-0.01,0.0,0.0,0.0,0.0
"[[3, 5], [5, 3], [3, 4]]",-0.01,0.0,0.0,0.0,0.0
"[[3, 5], [6, 6], [3, 2]]",0.0,0.0,-0.01,0.0,0.0
"[[3, 5], [3, 3], [3, 2]]",0.0,0.0,0.0,0.0,0.0
"[[3, 5], [4, 3], [4, 2]]",0.0,0.0,0.0,0.0,0.0
"[[3, 5], [5, 3], [4, 3]]",0.0,-0.01,0.0,0.0,0.0
"[[3, 5], [6, 6], [4, 1]]",0.0,0.0,0.0,0.0,-0.01
"[[3, 5], [3, 3], [4, 1]]",0.0,-0.01,-0.01,0.0,0.0
"[[3, 4], [4, 3], [5, 1]]",0.0,-0.01,0.0,0.0,0.0
"[[4, 4], [5, 3], [5, 1]]",0.0,0.0,0.0,-0.01,0.0
"[[4, 5], [5, 4], [5, 1]]",0.0,-0.01,0.0,0.0,0.0
"[[5, 5], [5, 4], [5, 1]]",0.0,0.0,0.0,0.0,-0.01
"[[5, 5], [5, 4], [5, 2]]",0.0,0.0,0.0,-0.01,0.0
"[[5, 5], [5, 4], [4, 2]]",0.0,-0.01,-0.5,0.0,0.0
"[[5, 4], [5, 4], [4, 3]]",0.0,0.0,0.0,0.0,0.0
"[[3, 4], [5, 3], [5, 4]]",0.0,0.0,0.0,-0.01,0.0
"[[5, 5], [4, 3], [3, 4]]",0.0,0.0,-0.01,0.0,0.0
"[[5, 4], [5, 3], [4, 4]]",-0.5,0.0,0.0,0.0,0.0
"[[4, 4], [5, 4], [5, 4]]",0.0,0.0,0.0,0.0,0.0
"[[4, 5], [4, 3], [3, 1]]",0.0,-0.01,0.0,0.0,0.0
"[[5, 5], [5, 3], [4, 1]]",0.0,0.0,0.0,-0.01,0.0
"[[5, 5], [5, 4], [4, 3]]",0.0,0.0,0.0,-0.01,0.0
"[[5, 5], [5, 4], [4, 4]]",0.0,-0.01,0.0,0.0,0.0
"[[3, 4], [4, 3], [4, 3]]",0.0,0.0,0.0,0.0,0.0
"[[3, 4], [3, 3], [4, 2]]",-0.5,0.0,-0.5,0.0,0.0
"[[3, 3], [4, 3], [4, 3]]",0.0,0.0,0.0,0.0,0.0
"[[3, 3], [3, 3], [4, 4]]",0.0,0.0,0.0,0.0,0.0
"[[3, 5], [5, 3], [4, 2]]",0.0,-0.01,0.0,0.0,0.0
"[[3, 5], [4, 3], [5, 3]]",0.0,-0.5,0.0,0.0,0.0
"[[4, 5], [5, 3], [5, 3]]",0.0,0.0,0.0,0.0,0.0
"[[3, 2], [5, 3], [4, 5]]",0.0,0.0,0.037097785583577604,0.0,0.0
"[[2, 3], [5, 3], [6, 6]]",0.0,0.0,0.0,0.004391030075795664,0.0
"[[4, 5], [3, 3], [3, 3]]",0.0,0.0,0.0,0.0,0.0
"[[3, 4], [3, 3], [4, 3]]",0.0,0.0,0.0,-0.01009092457551116,0.0
================================================
FILE: examples/Social_Cognition/ToM/data/injury_memory.txt
================================================
[[2, 1], [6, 6], [2, 1]]
[[3, 1], [6, 6], [3, 1]]
[[3, 1], [6, 6], [3, 1]]
[[2, 2], [6, 6], [2, 2]]
[[2, 1], [6, 6], [2, 1]]
[[3, 2], [6, 6], [3, 2]]
[[2, 1], [6, 6], [2, 1]]
[[3, 3], [6, 6], [3, 3]]
[[5, 3], [6, 6], [5, 3]]
[[4, 4], [6, 6], [4, 4]]
[[4, 1], [6, 6], [4, 1]]
[[3, 2], [6, 6], [3, 2]]
[[3, 2], [6, 6], [3, 2]]
[[3, 3], [6, 6], [3, 3]]
[[4, 4], [6, 6], [4, 4]]
[[4, 4], [6, 6], [4, 4]]
[[4, 2], [6, 6], [3, 5]]
[[4, 3], [6, 6], [3, 4]]
[[4, 4], [6, 6], [3, 3]]
[[4, 5], [6, 6], [3, 3]]
[[4, 5], [6, 6], [3, 2]]
[[4, 5], [6, 6], [2, 2]]
[[4, 5], [6, 6], [3, 2]]
[[4, 5], [6, 6], [3, 2]]
[[4, 5], [6, 6], [2, 2]]
[[4, 5], [6, 6], [3, 2]]
[[4, 5], [6, 6], [2, 2]]
[[4, 2], [6, 6], [3, 5]]
[[4, 3], [6, 6], [3, 4]]
[[4, 4], [6, 6], [3, 3]]
[[4, 5], [6, 6], [3, 4]]
[[4, 5], [6, 6], [3, 5]]
[[4, 5], [6, 6], [3, 4]]
[[4, 5], [6, 6], [3, 3]]
[[4, 5], [6, 6], [3, 2]]
[[4, 5], [6, 6], [2, 2]]
[[4, 2], [6, 6], [3, 5]]
[[4, 3], [6, 6], [3, 5]]
[[4, 1], [3, 3], [3, 3]]]
[[3, 2], [5, 4], [3, 2]]]
[[5, 1], [3, 3], [3, 3]]]
[[3, 3], [3, 3], [3, 5]]]
[[2, 2], [3, 3], [3, 3]]]
[[3, 3], [3, 3], [3, 4]]]
[[2, 1], [5, 4], [2, 1]]]
[[4, 2], [5, 4], [5, 4]]]
[[4, 3], [3, 3], [3, 3]]]
[[5, 2], [3, 3], [3, 3]]]
[[3, 1], [3, 3], [3, 3]]]
[[5, 4], [5, 4], [2, 1]]]
[[5, 4], [5, 4], [2, 2]]]
[[4, 1], [5, 4], [4, 1]]]
[[4, 3], [4, 3], [3, 4]]]
[[4, 2], [3, 3], [3, 3]]]
[[5, 4], [5, 4], [4, 1]]]
[[4, 4], [5, 4], [4, 4]]]
[[4, 3], [4, 3], [3, 5]]]
[[5, 3], [3, 3], [3, 3]]]
[[3, 2], [3, 3], [3, 3]]]
[[5, 4], [5, 4], [3, 3]]]
[[4, 3], [4, 3], [3, 3]]]
[[5, 4], [5, 4], [3, 1]]]
[[3, 1], [5, 4], [3, 1]]]
[[4, 3], [4, 3], [1, 5]]]
[[4, 3], [4, 3], [1, 4]]]
[[4, 3], [4, 3], [2, 3]]]
[[3, 1], [5, 4], [3, 1]]]
[[3, 2], [4, 4], [4, 4]]]
[[3, 1], [5, 4], [3, 1]]]
[[4, 2], [3, 3], [3, 3]]]
[[3, 1], [5, 4], [3, 1]]]
[[3, 3], [3, 3], [2, 4]]]
[[4, 4], [3, 3], [3, 3]]]
[[4, 3], [4, 3], [2, 4]]]
[[4, 4], [3, 3], [3, 3]]]
[[4, 2], [3, 3], [3, 3]]]
[[3, 3], [3, 3], [2, 5]]]
[[4, 2], [3, 3], [3, 3]]]
[[4, 2], [3, 3], [3, 3]]]
[[4, 5], [5, 4], [4, 5]]]
[[4, 4], [3, 3], [3, 3]]]
[[4, 4], [3, 3], [3, 3]]]
[[4, 4], [3, 3], [3, 3]]]
[[5, 4], [5, 4], [3, 2]]]
[[4, 4], [4, 4], [2, 4]]]
[[4, 4], [4, 4], [3, 2]]]
[[4, 3], [4, 3], [1, 4]]]
[[4, 3], [4, 3], [2, 3]]]
[[4, 4], [4, 4], [4, 4]]]
[[4, 4], [3, 3], [3, 3]]]
[[4, 4], [4, 4], [3, 3]]]
[[4, 4], [4, 4], [3, 2]]]
[[4, 3], [4, 3], [3, 3]]]
[[4, 3], [4, 3], [3, 3]]]
[[4, 5], [5, 4], [4, 5]]]
[[4, 4], [3, 3], [3, 3]]]
[[4, 4], [3, 3], [3, 3]]]
[[4, 4], [4, 4], [4, 4]]]
[[4, 4], [3, 3], [3, 3]]]
[[3, 4], [5, 4], [3, 4]]]
[[4, 5], [5, 4], [4, 5]]]
[[4, 4], [4, 4], [4, 4]]]
[[4, 5], [6, 6], [3, 1]]
[[4, 2], [6, 6], [3, 5]]
[[4, 3], [6, 6], [3, 5]]
[[4, 4], [6, 6], [3, 4]]
[[4, 5], [6, 6], [3, 4]]
[[4, 5], [6, 6], [3, 3]]
[[4, 5], [6, 6], [3, 3]]
[[4, 5], [6, 6], [3, 2]]
[[4, 5], [6, 6], [3, 1]]
[[4, 2], [6, 6], [3, 5]]
[[4, 3], [6, 6], [3, 5]]
[[4, 4], [6, 6], [3, 4]]
[[4, 5], [6, 6], [3, 5]]
[[4, 5], [6, 6], [3, 5]]
[[4, 5], [6, 6], [3, 4]]
[[4, 5], [6, 6], [3, 3]]
[[4, 5], [6, 6], [3, 3]]
[[4, 5], [6, 6], [3, 2]]
[[4, 5], [6, 6], [2, 2]]
[[4, 5], [6, 6], [3, 2]]
[[4, 5], [6, 6], [3, 2]]
[[4, 5], [6, 6], [3, 1]]
[[4, 5], [6, 6], [4, 1]]
[[4, 5], [6, 6], [4, 2]]
[[4, 5], [6, 6], [3, 2]]
[[4, 5], [6, 6], [3, 3]]
[[4, 5], [6, 6], [3, 4]]
[[4, 5], [6, 6], [3, 4]]
[[4, 5], [6, 6], [3, 3]]
[[4, 5], [6, 6], [3, 2]]
[[4, 5], [6, 6], [3, 1]]
[[4, 2], [6, 6], [3, 5]]
[[4, 3], [6, 6], [3, 4]]
[[4, 4], [6, 6], [3, 4]]
[[4, 5], [6, 6], [3, 4]]
================================================
FILE: examples/Social_Cognition/ToM/data/injury_value.txt
================================================
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
================================================
FILE: examples/Social_Cognition/ToM/data/one_hot.py
================================================
from numpy import argmax
import numpy as np
def one_hot(value, num_classes=2):
    """Return a one-hot row vector encoding *value*.

    @param value: class index to encode, 0 <= value < num_classes
    @param num_classes: length of the encoding; defaults to 2, matching
        the original binary '01' alphabet, so existing callers are
        unaffected
    @return: numpy int array of shape (1, num_classes)
    @raise IndexError: if value is outside [0, num_classes)
    """
    if not 0 <= value < num_classes:
        # The original implementation silently wrapped negative indices
        # (letter[-1] = 1); reject out-of-range values explicitly instead.
        raise IndexError('one_hot value %r out of range for %d classes' % (value, num_classes))
    letter = [0] * num_classes
    letter[value] = 1
    return np.array([letter])
# print(one_hot(4))
================================================
FILE: examples/Social_Cognition/ToM/env/__init__.py
================================================
================================================
FILE: examples/Social_Cognition/ToM/env/env.py
================================================
from numpy import argmax
import random, time, pygame, sys
import pygame
pygame.init()
from pygame.locals import *
import os
# os.environ['SDL_AUDIODRIVER'] = 'dsp'
# os.environ['SDL_VIDEODRIVER']='windib'
os.environ["SDL_VIDEODRIVER"] = "dummy"
# os.environ['DISPLAY'] = "localhost:13.0"
from rulebasedpolicy.world_model import *
from rulebasedpolicy.statedata_pre import *
from rulebasedpolicy.Find_a_way import *
import numpy as np
from utils.one_hot import one_hot
# =============================================================================
# set the value of interface
# =============================================================================
FPS = 25
WinWidth = 340 #window width in pixels
WinHeight = 260 #window height in pixels
BoxSize = 20 #the size of one grid
GridWidth = 7 #the number of lattices are there in the x-axis
GridHeight = 7 #the number of lattices are there in the y-axis
#representation of different objective (board cell codes)
BlankBox = 1
Wall = 5
Obstacle = 5
observer = 8  # marks the observing agent's own cell in its state grid
obeservation_1 = 11  # NPC_1's goal cell ("obeservation" spelling kept: referenced throughout the file)
obeservation_2 = 22  # NPC_2's goal cell
obeservation_3 = 33  # trained agent's goal cell
#Text = None
XMargin = int((WinWidth - GridWidth * BoxSize)/2)
# NOTE(review): int() wraps only the subtraction, so TopMargin is a float
# (the /2-5 happens after the cast) — confirm whether that is intended.
TopMargin = int((WinHeight - GridHeight * BoxSize))/2-5
# =============================================================================
# set color (RGB tuples)
# =============================================================================
White = (255,255,255)
Gray = (185,185,185)
Black = (0,0,0)
Red = (200,0,0)
Green = (0,139,0)
Green_B = (78, 238, 148)
Light_A = (233, 232, 170)
Blue = (30, 144, 255)
pink = (238, 99, 99)
BoardColor = White
BGColor = White
TextColor = White
Test = []
# =============================================================================
# agents - env interactive
# =============================================================================
class FalseBelief_env(object):
    """False-belief (Theory of Mind) grid world with three agents.

    Two NPCs and one trained agent move on a 7x7 pygame board toward
    per-agent goal cells (board codes 11, 22 and 33).  Each agent gets a
    5x5 observation whose occluded cells are masked by ``shelter_env``;
    the other agents are stamped into it only when actually visible.
    Rendering runs headless because SDL uses the dummy video driver.
    """
    def __init__(self, reward=10):
        """Create the environment; *reward* is the goal/collision reward magnitude."""
        super(FalseBelief_env, self).__init__()
        self.action_space = ['up', 'down', 'left', 'right', 'stay']
        # Action index -> (dx, dy) grid displacement.
        self.action_move = {
            0 : (0, -1),
            1 : (0, 1),
            2 : (-1, 0),
            3 : (1, 0),
            4 :(0, 0)
        }#[(0, -1), (0, 1), (-1, 0), (1, 0), (0, 0)]
        self.n_actions = len(self.action_space)
        self._build_AB()
        self.board, self.obs = self.getBlankBoard()
        self._agent_init()
        self.score = 0
        self.steps = 0
        self.n = 0
        self.R = int(5/2) * (BoxSize - 5)  # base radius used when drawing the agent circles
        self.trigger = 0  # when 1, DrawBoard additionally draws the red trigger line
        self.x = 0
        self.n_features = 30
        self.reward = reward
    def _build_AB(self):
        """Initialise the pygame window, clock and fonts (kept as module globals)."""
        global FPSCLOCK, DISPLAYSURF, BASICFONT, BIGFONT
        FPSCLOCK = pygame.time.Clock()
        DISPLAYSURF = pygame.display.set_mode((WinWidth, WinHeight))
        BASICFONT = pygame.font.Font('freesansbold.ttf', 18)
        BIGFONT = pygame.font.Font('freesansbold.ttf', 100)
        pygame.display.set_caption('AB')
        pygame.display.update()
        FPSCLOCK.tick()
    def _agent_init(self):
        """
        Aim:Initialize the basic information of the agent
        Each agent is a dict with 1-based grid coordinates, drawing
        attributes, observation/axis slots, its last reward and a Done flag.
        NOTE(review): the inline #row/#column comments look swapped — the
        board is indexed board[y][x], so 'y' is the row — confirm.
        """
        self.NPC_1 = {
            'shape' : [['#']],
            'x' : 3, #row
            'y' : 1, #column
            'color' : Blue,
            'style' : "circle",
            'obs' : None,
            'axis' : None,#,[[1,3],[3,5],[4,2]]
            'reward' : 0,
            'Done' : False
        }
        self.NPC_2 = {
            'shape' : [['@']],
            'x' : 5, #row
            'y' : 3, #column
            'color' : pink,
            'style' : "circle",
            'obs' : None,
            'axis' :None ,#,[[3,5],[1,3],[4,2]]
            'reward': 0,
            'Done': False
        }
        self.agent = {
            'shape' : [['$']],
            'x' : 2, #row
            'y' : 4, #column
            'color' : Green_B,
            'style' : "circle",
            'obs' : None,
            'axis' : None,#[[4,2],[1,3],[3,5]]
            'reward': 0,
            'Done': False
        }
    def actu_obs(self):
        """Convert the current agent positions into trainable state arrays.

        @return: three 9x5 int grids (the blank 5x5 state padded with
            four rows of ones) as seen by NPC_1, NPC_2 and the agent;
            each agent's own cell is marked with ``observer`` and the top
            5x5 region is occlusion-masked via ``shelter_env``.  Also
            refreshes each agent's 'obs' and 'axis' entries.
        """
        _, state = self.getBlankBoard()
        # a, b and c alias the same array, but np.r_ below copies, so the
        # three padded states end up independent of one another.
        a = state
        b = state
        c = state
        state1 = np.r_[a, np.ones((4, 5))].astype(np.int_)
        state2 = np.r_[b, np.ones((4, 5))].astype(np.int_)
        statea = np.r_[c, np.ones((4, 5))].astype(np.int_)
        NPC_1_state = state1
        NPC_2_state = state2
        Agent_state = statea
        NPC_1_state[self.NPC_1['y']-1, self.NPC_1['x']-1] = observer
        q = shelter_env(NPC_1_state[:5, :])
        NPC_1_state[:5, :] = shelter_env(NPC_1_state[:5, :])
        NPC_2_state[self.NPC_2['y']-1, self.NPC_2['x']-1] = observer
        r = shelter_env(NPC_2_state[:5, :])
        NPC_2_state[:5, :] = shelter_env(NPC_2_state[:5, :])
        Agent_state[self.agent['y']-1, self.agent['x']-1] = observer
        p = shelter_env(Agent_state[:5, :])
        Agent_state[:5, :] = shelter_env(Agent_state[:5, :])
        """
        ########### num ############
        #2-NPC1 in other agents' obs
        #3-NPC2 in other agents' obs
        #4-Agent in other agents' obs
        """
        self.NPC_1['obs'] = q
        self.NPC_1['obs'] = self.gain_obs(self.NPC_1['obs'],NPC_1_state,self.NPC_2,self.agent,3,4)
        self.NPC_1['axis'] = self.gain_axis(self.NPC_1,NPC_1_state,self.NPC_2,self.agent,3,4)
        self.NPC_2['obs'] = r
        self.NPC_2['obs'] = self.gain_obs(self.NPC_2['obs'], NPC_2_state, self.NPC_1, self.agent, 2, 4)
        self.NPC_2['axis'] = self.gain_axis(self.NPC_2, NPC_2_state, self.NPC_1, self.agent, 2, 4)
        self.agent['obs'] = p
        self.agent['obs'] = self.gain_obs(self.agent['obs'], Agent_state, self.NPC_1, self.NPC_2, 2, 3)
        self.agent['axis'] = self.gain_axis(self.agent, Agent_state, self.NPC_1, self.NPC_2, 2, 3)
        return NPC_1_state, NPC_2_state, Agent_state
    def gain_obs(self, a, aa, b, c, bb, cc):
        """
        Apply the true occlusion relations to this agent's observation:
        stamp agents b and c in only where their cell is still blank (1).
        @param a: self - observation (5x5, occlusion-masked; mutated in place)
        @param aa: state grid holding self / other-b / other-c positions
        @param b: other agent b (its visible area after occlusion, 5x5)
        @param c: other agent c (its visible area after occlusion, 5x5)
        @param bb: other-b's numeric marker
        @param cc: other-c's numeric marker
        @return: self - observation
        """
        if aa[b['y']-1, b['x']-1] == 1:
            a[b['y']-1, b['x']-1] = bb
        if aa[c['y']-1, c['x']-1] ==1:
            a[c['y'] - 1, c['x'] - 1] = cc
        return a
    def gain_axis(self,a,aa,b,c,bb,cc):
        """
        Collect the three agents' [y, x] coordinates; coordinates this
        agent cannot see are reported as the sentinel [6, 6].
        @param a: self agent dict
        @param aa: state grid used for the visibility test
        @param b: other agent b dict
        @param c: other agent c dict
        @param bb: marker for b (unused here)
        @param cc: marker for c (unused here)
        @return: [[y, x] of self, of b (or [6, 6]), of c (or [6, 6])]
        """
        axis = []
        axis.append([a['y'], a['x']])
        if aa[b['y']-1, b['x']-1] != 0:
            axis.append([b['y'], b['x']])
        else:
            axis.append([6,6])
        if aa[c['y']-1, c['x']-1] != 0:
            axis.append([c['y'] , c['x']])
        else:
            axis.append([6, 6])
        return axis
    def interact(self, action_NPC1, action_NPC2, action_agent):
        """
        Advance the world one step with the three agents' actions.
        @param action_NPC1: action index for NPC_1
        @param action_NPC2: action index for NPC_2
        @param action_agent: action index for the trained agent
        @return: the 5x5 post-occlusion views of NPC_1, NPC_2 and the agent
        """
        self.agent['reward'] = 0
        self.NPC_1['reward'] = 0
        self.NPC_2['reward'] = 0
        # What does each agent see before moving?
        NPC_1_state, NPC_2_state, Agent_state = self.actu_obs()
        # Which actions the agents take given these states is decided by RL.
        # Apply the actions: first update positions, then observations.
        base = np.where(np.array(self.board) == obeservation_1)
        # NOTE(review): int(base[0]) assumes exactly one goal cell; newer
        # numpy deprecates int() on size-1 arrays — confirm.
        base_x = int(base[0])
        base_y= int(base[1])
        if self.NPC_1['Done'] == False:
            dis1 = np.sqrt(np.square(base_x - self.NPC_1['y']) + np.square(base_y - self.NPC_1['x']))
            if self.isNotWall(self.board, self.NPC_1, self.action_move[action_NPC1][0], \
                              self.action_move[action_NPC1][1]):
                self.NPC_1['x'] = self.NPC_1['x'] + self.action_move[action_NPC1][0]
                self.NPC_1['y'] = self.NPC_1['y'] + self.action_move[action_NPC1][1]
                dis2 = np.sqrt(np.square(base_x - self.NPC_1['y']) + np.square(base_y - self.NPC_1['x']))
                # Distance-shaped reward: positive when moving closer to the goal.
                self.NPC_1['reward'] = (((dis1 - dis2) * 2 - 1) / dis1)
            else:
                self.NPC_1['reward'] = -1 * (1 / dis1)
            if self.board[self.NPC_1['y'], self.NPC_1['x']] == obeservation_1:
                self.NPC_1['reward'] = 50
                self.NPC_1['Done'] = True
        base = np.where(np.array(self.board) == obeservation_2)
        base_x = int(base[0])
        base_y= int(base[1])
        if self.NPC_2['Done'] == False:
            dis1 = np.sqrt(np.square(base_x - self.NPC_2['y']) + np.square(base_y - self.NPC_2['x']))
            if self.isNotWall(self.board, self.NPC_2, self.action_move[action_NPC2][0], \
                              self.action_move[action_NPC2][1]):
                self.NPC_2['x'] = self.NPC_2['x'] + self.action_move[action_NPC2][0]
                self.NPC_2['y'] = self.NPC_2['y'] + self.action_move[action_NPC2][1]
                dis2 = np.sqrt(np.square(base_x - self.NPC_2['y']) + np.square(base_y - self.NPC_2['x']))
                self.NPC_2['reward'] = (((dis1 - dis2)*10 - 1/2) / dis1)
                # Reward shaping: amplify small rewards into (-1, 1).
                # NOTE(review): if the shaped reward is exactly 0 this loop
                # never terminates (0 * 2 == 0) — confirm inputs exclude it
                # (the agent's loop below avoids this with "- 0.1").
                while self.NPC_2['reward'] < 0.5 and self.NPC_2['reward'] > -0.5 :
                    self.NPC_2['reward'] = self.NPC_2['reward'] * 2
                if self.NPC_2['reward'] > 1:
                    self.NPC_2['reward'] = 1
                elif self.NPC_2['reward'] < -1:
                    self.NPC_2['reward'] = -1
            else:
                self.NPC_2['reward'] = -0.9 #* (1 / dis1)
            if self.board[self.NPC_2['y'], self.NPC_2['x']] == obeservation_2:
                self.NPC_2['reward'] = self.reward
                self.NPC_2['Done'] = True
        base = np.where(np.array(self.board) == obeservation_3)
        base_x = int(base[0])
        base_y= int(base[1])
        if self.agent['Done'] == False:
            dis1 = np.sqrt(np.square(base_x - self.agent['y']) + np.square(base_y - self.agent['x']))
            if self.isNotWall(board=self.board, piece=self.agent, xT=self.action_move[action_agent][0], \
                              yT=self.action_move[action_agent][1]):
                self.agent['x'] = self.agent['x'] + self.action_move[action_agent][0]
                self.agent['y'] = self.agent['y'] + self.action_move[action_agent][1]
                dis2 = np.sqrt(np.square(base_x - self.agent['y']) + np.square(base_y - self.agent['x']))
                # self.agent['reward'] = (((dis1 - dis2) * 2 - 1) / dis1)
                # else:
                #     # print('action', action_agent)
                #     self.agent['reward'] = -1 * (1 / dis1)
                self.agent['reward'] = (((dis1 - dis2)*10 - 1/2) / dis1)
                # The "- 0.1" nudge guarantees termination even when the
                # shaped reward starts at exactly 0.
                while self.agent['reward'] < 0.5 and self.agent['reward'] > -0.5 :
                    self.agent['reward'] = self.agent['reward'] * 2 - 0.1
                if self.agent['reward'] > 1:
                    self.agent['reward'] = 1
                elif self.agent['reward'] < -1:
                    self.agent['reward'] = -1
            else:
                self.agent['reward'] = -0.9 #* (1 / dis1)
            if self.board[self.agent['y'], self.agent['x']] == obeservation_3:
                self.agent['reward'] = self.reward
                self.agent['Done'] = True
        NPC_1_state, NPC_2_state, Agent_state = self.actu_obs()
        # Did any two agents collide?
        location = [(self.NPC_1['x'], self.NPC_1['y']), (self.NPC_2['x'], self.NPC_2['y']),\
                    (self.agent['x'], self.agent['y'])]
        # Reaching a goal or colliding both end that agent's episode.
        terminal = self.gameover(location)
        if terminal[0] == True and self.NPC_1['Done'] == False:
            self.NPC_1['Done'] = True
            self.NPC_1['reward'] = -50
            self.NPC_1['color'] = Red
        if terminal[1] == True and self.NPC_2['Done'] == False:
            self.NPC_2['Done'] = True
            self.NPC_2['reward'] = -self.reward
            self.NPC_2['color'] = Red
        if terminal[2] == True and self.agent['Done'] == False:
            self.agent['Done'] = True
            self.agent['reward'] = -self.reward
            self.agent['color'] = Red
        return NPC_1_state, NPC_2_state, Agent_state
    def SHOW(self):
        """
        Render the board and the three agents to the pygame surface.
        """
        DISPLAYSURF.fill(BGColor)
        self.DrawBoard(self.board)
        self.DrawPiece(self.NPC_1)
        self.DrawPiece(self.NPC_2)
        self.DrawPiece(self.agent)
        pygame.display.update()
        FPSCLOCK.tick(FPS)
        # return flag
    def reset(self):
        """Reset the three agents to their initial positions and state."""
        self._agent_init()
    def getBlankBoard(self):
        """
        Build the initial board and the blank 5x5 state grid.
        11 - NPC1-goal
        22 - NPC2-goal
        33 - Agent-goal
        @return: (board expanded by big_env, blank 5x5 state grid)
        """
        board = np.array([[1,1,1,5,5],[22,1,1,5,5],[1,1,1,1,1],[1,1,1,1,33],[1,1,1,11,1]])
        state_init = np.array([[1,1,1,5,5],[1,1,1,5,5],[1,1,1,1,1],[1,1,1,1,1],[1,1,1,1,1]])
        board = big_env(board)
        # print(board)
        # for x in range(GridWidth):
        #     for y in range(GridHeight):
        #         print(board[x][y],x,y)
        return board, state_init
    def ValidPos(self, piece1, piece2, xT=0, yT=0):
        """
        to judge the next place vaild or not
        @param piece1: piece whose current cell is checked
        @param piece2: piece that would move by (xT, yT)
        @param xT: x displacement applied to piece2
        @param yT: y displacement applied to piece2
        @return: True if piece2 would land on piece1's cell
        """
        if piece1['x'] == (piece2['x'] + xT) and piece1['y'] == (piece2['y'] + yT):
            return True
        return False
    def isNotWall(self, board, piece, xT=0 , yT=0 ):
        """
        Check whether the move keeps the piece off a wall cell.
        @param board: board
        @param piece: agent
        @param xT: x displacement
        @param yT: y displacement
        @return: False when the target cell is a Wall, else True
        """
        if board[piece['y'] + yT][piece['x'] + xT] == Wall:#############
            return False
        else:
            return True
    def gameover(self, location):
        """
        Collision test for the current step.
        @param location: [(x, y)] tuples of NPC_1, NPC_2 and the agent
        @return: three booleans; True marks an agent involved in a collision
        """
        result = False  # NOTE(review): unused
        terminal = [False, False, False]
        if location[0] == location[1]:
            terminal[0] = True
            terminal[1] = True
        if location[2] == location[1]:
            terminal[2] = True
            terminal[1] = True
        if location[2] ==location[0]:
            terminal[2] = True
            terminal[0] = True
        return terminal
    def pixel(self, xbox, ybox):
        """Convert grid cell (xbox, ybox) to top-left pixel coordinates."""
        return (XMargin + (xbox * BoxSize)), (TopMargin + (ybox * BoxSize))
    def DrawBox(self, xbox, ybox, color, xpixel=None, ypixel=None):
        """Draw one cell; *color* actually receives the board cell code."""
        if color == BlankBox:
            return
        elif color == obeservation_1:
            if xpixel == None and ypixel == None:
                xpixel, ypixel = self.pixel(xbox, ybox)
            pygame.draw.rect(DISPLAYSURF, (60,107,255), (xpixel + 1, ypixel + 1, BoxSize - 1, BoxSize - 1))
        elif color == obeservation_2:
            if xpixel == None and ypixel == None:
                xpixel, ypixel = self.pixel(xbox, ybox)
            pygame.draw.rect(DISPLAYSURF, (205, 155, 155), (xpixel + 1, ypixel + 1, BoxSize - 1, BoxSize - 1))
        elif color == obeservation_3:
            if xpixel == None and ypixel == None:
                xpixel, ypixel = self.pixel(xbox, ybox)
            pygame.draw.rect(DISPLAYSURF, (154, 205, 50), (xpixel + 1, ypixel + 1, BoxSize - 1, BoxSize - 1))
        elif color == Wall:
            if xpixel == None and ypixel == None:
                xpixel, ypixel = self.pixel(xbox, ybox)
            pygame.draw.rect(DISPLAYSURF, Gray, (xpixel + 1, ypixel + 1, BoxSize - 1, BoxSize - 1))
        else:
            if xpixel == None and ypixel == None:
                xpixel, ypixel = self.pixel(xbox, ybox)
            pygame.draw.rect(DISPLAYSURF, color, (xpixel + 1, ypixel + 1, BoxSize - 1, BoxSize - 1))
    def fun_trigger(self):
        """Draw the red trigger line near cell (3, 5)."""
        xpixel, ypixel = self.pixel(3, 5)
        pygame.draw.line(DISPLAYSURF, Red, (xpixel + BoxSize, ypixel - BoxSize),
                         (xpixel + BoxSize, ypixel - 2*BoxSize), 5)
    def DrawCircle(self, xbox, ybox, color, xpixel=None, ypixel=None):
        """
        Draw an agent as a circle centred in its cell.
        @param xbox: unused (kept for symmetry with DrawBox)
        @param ybox: unused
        @param color: fill colour
        @param xpixel: top-left pixel x of the cell
        @param ypixel: top-left pixel y of the cell
        """
        pygame.draw.circle(DISPLAYSURF,
                           color,
                           (int(xpixel+BoxSize/2), int(ypixel+BoxSize/2)),
                           int(0.3 * self.R))
    def DrawPiece(self, piece, xpixel=None, ypixel=None):
        """Draw a piece at its grid position, as a circle or a box."""
        if xpixel == None and ypixel == None:
            xpixel, ypixel = self.pixel(piece['x'], piece['y'])
        if piece['style'] == "circle":
            self.DrawCircle(None, None, piece['color'], xpixel, ypixel)
        else:
            self.DrawBox(None, None, piece['color'], xpixel, ypixel)
    def DrawBoard(self, board):
        """Draw the board frame and every cell, plus the trigger line if set."""
        pygame.draw.rect(DISPLAYSURF, BoardColor,
                         (XMargin - 3, TopMargin - 7, (GridWidth * BoxSize) + 8, (GridHeight * BoxSize) + 8), 5)
        pygame.draw.rect(DISPLAYSURF, BGColor, (XMargin, TopMargin, GridWidth * BoxSize, GridHeight * BoxSize))
        for x in range(GridWidth):
            for y in range(GridHeight):
                self.DrawBox(y, x, board[x][y])
        if self.trigger == 1:
            self.fun_trigger()
    def ShowScore(self, score):
        """Blit 'Score : <score>' near the window's top-right corner."""
        scoreSurf = BASICFONT.render('Score : %s' % score, True, TextColor)
        scoreRect = scoreSurf.get_rect()
        scoreRect.topleft = (WinWidth - 250, 20)
        DISPLAYSURF.blit(scoreSurf, scoreRect)
    def Terminal(self, piece1, piece2, piece1_old, piece2_old):
        """Return 1 when the two pieces swapped cells this step (a pass-through collision), else 2."""
        # print(piece1['x'],piece1['y'],piece2_old[0],piece2_old[1],'/',piece2['x'],piece2['y'],piece1_old[0],piece1_old[1])
        # if self.steps == 1:# wrong!!!!
        if piece1['x'] == piece2_old[0] and piece1['y'] == piece2_old[1] and piece2['x'] == piece1_old[0] and piece2['y'] == piece1_old[1]:
            return 1
        else:
            return 2
    def Paint(self, board, piece, color):
        """Write *color* into the board at the piece's cell and recolour the piece."""
        board[piece['x']][piece['y']] = color
        piece['color'] = color
        return board
# if __name__ == "__main__":
# env0 = FalseBelief_env0()
# action_agent = 0
# action_NPC2 = 1
# action_NPC1 = 4
# for i in range(10):
# if i > 8:
# break
# else:
# env0.interact(action_NPC1, action_NPC2, action_agent)
#
# env0.SHOW()
# time.sleep(2)
# pygame.quit()
================================================
FILE: examples/Social_Cognition/ToM/env/env3_train_env00.py
================================================
from numpy import argmax
import random, time, pygame, sys
import pygame
pygame.init()
from pygame.locals import *
import os
os.environ["SDL_VIDEODRIVER"] = "dummy"
from rulebasedpolicy.world_model import *
from rulebasedpolicy.statedata_pre import *
from rulebasedpolicy.Find_a_way import *
import numpy as np
from utils.one_hot import one_hot
# =============================================================================
# set the value of interface
# =============================================================================
FPS = 25
WinWidth = 340 #window width in pixels
WinHeight = 260 #window height in pixels
BoxSize = 20 #the size of one grid
GridWidth = 7 #the number of lattices are there in the x-axis
GridHeight = 7 #the number of lattices are there in the y-axis
#representation of different objective (board cell codes)
BlankBox = 1
Wall = 5
Obstacle = 5
observer = 8  # marks the observing agent's own cell in its state grid
obeservation_1 = 11  # NPC_1's goal cell ("obeservation" spelling kept: referenced throughout the file)
obeservation_2 = 22  # NPC_2's goal cell
obeservation_3 = 33  # trained agent's goal cell
#Text = None
XMargin = int((WinWidth - GridWidth * BoxSize)/2)
# NOTE(review): int() wraps only the subtraction, so TopMargin is a float
# (the /2-5 happens after the cast) — confirm whether that is intended.
TopMargin = int((WinHeight - GridHeight * BoxSize))/2-5
# =============================================================================
# set color (RGB tuples)
# =============================================================================
White = (255,255,255)
Gray = (185,185,185)
Black = (0,0,0)
Red = (200,0,0)
Green = (0,139,0)
Green_B = (78, 238, 148)
Light_A = (233, 232, 170)
Blue = (30, 144, 255)
pink = (238, 99, 99)
BoardColor = White
BGColor = White
TextColor = White
Test = []
# =============================================================================
# agents - env interactive
# =============================================================================
class FalseBelief_env0(object):
    def __init__(self, reward=10):
        """Create the environment; *reward* is the goal/collision reward magnitude."""
        super(FalseBelief_env0, self).__init__()
        self.action_space = ['up', 'down', 'left', 'right', 'stay']
        # Action index -> (dx, dy) grid displacement.
        self.action_move = {
            0 : (0, -1),
            1 : (0, 1),
            2 : (-1, 0),
            3 : (1, 0),
            4 :(0, 0)
        }#[(0, -1), (0, 1), (-1, 0), (1, 0), (0, 0)]
        self.n_actions = len(self.action_space)
        self._build_AB()
        self.board, self.obs = self.getBlankBoard()
        self._agent_init()
        self.score = 0
        self.steps = 0
        self.n = 0
        self.R = int(5/2) * (BoxSize - 5)  # base radius used when drawing the agent circles
        self.x = 0
        self.n_features = 30
        self.reward = reward
    def _build_AB(self):
        """Initialise the pygame window, clock and fonts (kept as module globals)."""
        global FPSCLOCK, DISPLAYSURF, BASICFONT, BIGFONT
        FPSCLOCK = pygame.time.Clock()
        DISPLAYSURF = pygame.display.set_mode((WinWidth, WinHeight))
        BASICFONT = pygame.font.Font('freesansbold.ttf', 18)
        BIGFONT = pygame.font.Font('freesansbold.ttf', 100)
        pygame.display.set_caption('AB')
        pygame.display.update()
        FPSCLOCK.tick()
    def _agent_init(self):
        """
        Aim:Initialize the basic information of the agent
        Each agent is a dict with 1-based grid coordinates, drawing
        attributes, observation/axis slots, its last reward and a Done flag.
        NOTE(review): the inline #row/#column comments look swapped — the
        board is indexed board[y][x], so 'y' is the row — confirm.
        """
        self.NPC_1 = {
            'shape' : [['#']],
            'x' : 3, #row
            'y' : 1, #column
            'color' : Blue,
            'style' : "circle",
            'obs' : None,
            'axis' : None,#,[[1,3],[3,5],[4,2]]
            'reward' : 0,
            'Done' : False
        }
        self.NPC_2 = {
            'shape' : [['@']],
            'x' : 5, #row
            'y' : 3, #column
            'color' : pink,
            'style' : "circle",
            'obs' : None,
            'axis' :None ,#,[[3,5],[1,3],[4,2]]
            'reward': 0,
            'Done': False
        }
        self.agent = {
            'shape' : [['$']],
            'x' : 2, #row
            'y' : 4, #column
            'color' : Green_B,
            'style' : "circle",
            'obs' : None,
            'axis' : None,#[[4,2],[1,3],[3,5]]
            'reward': 0,
            'Done': False
        }
    def actu_obs(self):
        """Convert the current agent positions into trainable state arrays.

        @return: three 9x5 int grids (the blank 5x5 state padded with
            four rows of ones) as seen by NPC_1, NPC_2 and the agent;
            each agent's own cell is marked with ``observer`` and the top
            5x5 region is occlusion-masked via ``shelter_env``.  Also
            refreshes each agent's 'obs' and 'axis' entries.
        """
        _, state = self.getBlankBoard()
        # a, b and c alias the same array, but np.r_ below copies, so the
        # three padded states end up independent of one another.
        a = state
        b = state
        c = state
        state1 = np.r_[a, np.ones((4, 5))].astype(np.int_)
        state2 = np.r_[b, np.ones((4, 5))].astype(np.int_)
        statea = np.r_[c, np.ones((4, 5))].astype(np.int_)
        NPC_1_state = state1
        NPC_2_state = state2
        Agent_state = statea
        NPC_1_state[self.NPC_1['y']-1, self.NPC_1['x']-1] = observer
        q = shelter_env(NPC_1_state[:5, :])
        NPC_1_state[:5, :] = shelter_env(NPC_1_state[:5, :])
        NPC_2_state[self.NPC_2['y']-1, self.NPC_2['x']-1] = observer
        r = shelter_env(NPC_2_state[:5, :])
        NPC_2_state[:5, :] = shelter_env(NPC_2_state[:5, :])
        Agent_state[self.agent['y']-1, self.agent['x']-1] = observer
        p = shelter_env(Agent_state[:5, :])
        Agent_state[:5, :] = shelter_env(Agent_state[:5, :])
        """
        ########### num ############
        #2-NPC1 in other agents' obs
        #3-NPC2 in other agents' obs
        #4-Agent in other agents' obs
        """
        self.NPC_1['obs'] = q
        self.NPC_1['obs'] = self.gain_obs(self.NPC_1['obs'],NPC_1_state,self.NPC_2,self.agent,3,4)
        self.NPC_1['axis'] = self.gain_axis(self.NPC_1,NPC_1_state,self.NPC_2,self.agent,3,4)
        self.NPC_2['obs'] = r
        self.NPC_2['obs'] = self.gain_obs(self.NPC_2['obs'], NPC_2_state, self.NPC_1, self.agent, 2, 4)
        self.NPC_2['axis'] = self.gain_axis(self.NPC_2, NPC_2_state, self.NPC_1, self.agent, 2, 4)
        self.agent['obs'] = p
        self.agent['obs'] = self.gain_obs(self.agent['obs'], Agent_state, self.NPC_1, self.NPC_2, 2, 3)
        self.agent['axis'] = self.gain_axis(self.agent, Agent_state, self.NPC_1, self.NPC_2, 2, 3)
        return NPC_1_state, NPC_2_state, Agent_state
def gain_obs(self, a, aa, b, c, bb, cc):
"""
获得智能体真正的环境遮挡关系
@param a: self - observation
@param aa: self - self-axis, other-b-axis, other-c-axis
@param b: other-b 遮挡后的可见区域 5*5
@param c: other-c 遮挡后的可见区域 5*5
@param bb: other-b' num
@param cc: other-c' num
@return: self - observation
"""
if aa[b['y']-1, b['x']-1] == 1:
a[b['y']-1, b['x']-1] = bb
if aa[c['y']-1, c['x']-1] ==1:
a[c['y'] - 1, c['x'] - 1] = cc
return a
def gain_axis(self,a,aa,b,c,bb,cc):
"""
获得坐标,但是看不见的坐标就用6来表示
@param a:
@param aa:
@param b:
@param c:
@param bb:
@param cc:
@return:
"""
axis = []
axis.append([a['y'], a['x']])
if aa[b['y']-1, b['x']-1] != 0:
axis.append([b['y'], b['x']])
else:
axis.append([6,6])
if aa[c['y']-1, c['x']-1] != 0:
axis.append([c['y'] , c['x']])
else:
axis.append([6, 6])
return axis
    def interact(self, action_NPC1, action_NPC2, action_agent):
        """
        Advance the environment one step for all three agents.
        @param action_NPC1: action index for NPC1
        @param action_NPC2: action index for NPC2
        @param action_agent: action index for the agent
        @return: the occluded views of NPC1, NPC2 and the agent after moving
        """
        self.agent['reward'] = 0
        self.NPC_1['reward'] = 0
        self.NPC_2['reward'] = 0
        # What does each of the three agents currently see?
        NPC_1_state, NPC_2_state, Agent_state = self.actu_obs()
        # Which actions they take given these observations depends on RL;
        # here the chosen actions update positions first, then observations.
        base = np.where(np.array(self.board) == obeservation_1)
        base_x = int(base[0])
        base_y= int(base[1])
        if self.NPC_1['Done'] == False:
            # distance-shaped reward: positive when the move approaches the goal
            dis1 = np.sqrt(np.square(base_x - self.NPC_1['y']) + np.square(base_y - self.NPC_1['x']))
            if self.isNotWall(self.board, self.NPC_1, self.action_move[action_NPC1][0], \
                self.action_move[action_NPC1][1]):
                self.NPC_1['x'] = self.NPC_1['x'] + self.action_move[action_NPC1][0]
                self.NPC_1['y'] = self.NPC_1['y'] + self.action_move[action_NPC1][1]
                dis2 = np.sqrt(np.square(base_x - self.NPC_1['y']) + np.square(base_y - self.NPC_1['x']))
                self.NPC_1['reward'] = (((dis1 - dis2) * 2 - 1) / dis1)
            else:
                self.NPC_1['reward'] = -1 * (1 / dis1)
            if self.board[self.NPC_1['y'], self.NPC_1['x']] == obeservation_1:
                self.NPC_1['reward'] = 50
                self.NPC_1['Done'] = True
        base = np.where(np.array(self.board) == obeservation_2)
        base_x = int(base[0])
        base_y= int(base[1])
        if self.NPC_2['Done'] == False:
            dis1 = np.sqrt(np.square(base_x - self.NPC_2['y']) + np.square(base_y - self.NPC_2['x']))
            if self.isNotWall(self.board, self.NPC_2, self.action_move[action_NPC2][0], \
                self.action_move[action_NPC2][1]):
                self.NPC_2['x'] = self.NPC_2['x'] + self.action_move[action_NPC2][0]
                self.NPC_2['y'] = self.NPC_2['y'] + self.action_move[action_NPC2][1]
                dis2 = np.sqrt(np.square(base_x - self.NPC_2['y']) + np.square(base_y - self.NPC_2['x']))
                self.NPC_2['reward'] = (((dis1 - dis2)*10 - 1/2) / dis1)
                # double the shaped reward until it leaves (-0.5, 0.5), then clip
                # NOTE(review): if the reward is ever exactly 0 this loop never
                # exits; presumably unreachable with sqrt distances — verify.
                while self.NPC_2['reward'] < 0.5 and self.NPC_2['reward'] > -0.5 :
                    self.NPC_2['reward'] = self.NPC_2['reward'] * 2
                if self.NPC_2['reward'] > 1:
                    self.NPC_2['reward'] = 1
                elif self.NPC_2['reward'] < -1:
                    self.NPC_2['reward'] = -1
            else:
                self.NPC_2['reward'] = -0.9 #* (1 / dis1)
            if self.board[self.NPC_2['y'], self.NPC_2['x']] == obeservation_2:
                self.NPC_2['reward'] = self.reward
                self.NPC_2['Done'] = True
        base = np.where(np.array(self.board) == obeservation_3)
        base_x = int(base[0])
        base_y= int(base[1])
        if self.agent['Done'] == False:
            dis1 = np.sqrt(np.square(base_x - self.agent['y']) + np.square(base_y - self.agent['x']))
            if self.isNotWall(board=self.board, piece=self.agent, xT=self.action_move[action_agent][0], \
                yT=self.action_move[action_agent][1]):
                self.agent['x'] = self.agent['x'] + self.action_move[action_agent][0]
                self.agent['y'] = self.agent['y'] + self.action_move[action_agent][1]
                dis2 = np.sqrt(np.square(base_x - self.agent['y']) + np.square(base_y - self.agent['x']))
                # self.agent['reward'] = (((dis1 - dis2) * 2 - 1) / dis1)
                # else:
                #     # print('action', action_agent)
                #     self.agent['reward'] = -1 * (1 / dis1)
                self.agent['reward'] = (((dis1 - dis2)*10 - 1/2) / dis1)
                while self.agent['reward'] < 0.5 and self.agent['reward'] > -0.5 :
                    self.agent['reward'] = self.agent['reward'] * 2 - 0.1
                if self.agent['reward'] > 1:
                    self.agent['reward'] = 1
                elif self.agent['reward'] < -1:
                    self.agent['reward'] = -1
            else:
                self.agent['reward'] = -0.9 #* (1 / dis1)
            if self.board[self.agent['y'], self.agent['x']] == obeservation_3:
                self.agent['reward'] = self.reward
                self.agent['Done'] = True
        NPC_1_state, NPC_2_state, Agent_state = self.actu_obs()
        # Will the agents collide?
        location = [(self.NPC_1['x'], self.NPC_1['y']), (self.NPC_2['x'], self.NPC_2['y']),\
                    (self.agent['x'], self.agent['y'])]
        # Reaching the goal or colliding both end that agent's round
        terminal = self.gameover(location)
        if terminal[0] == True and self.NPC_1['Done'] == False:
            self.NPC_1['Done'] = True
            self.NPC_1['reward'] = -50
            self.NPC_1['color'] = Red
        if terminal[1] == True and self.NPC_2['Done'] == False:
            self.NPC_2['Done'] = True
            self.NPC_2['reward'] = -self.reward
            self.NPC_2['color'] = Red
        if terminal[2] == True and self.agent['Done'] == False:
            self.agent['Done'] = True
            self.agent['reward'] = -self.reward
            self.agent['color'] = Red
        return NPC_1_state, NPC_2_state, Agent_state
    def SHOW(self):
        """
        Render the board and all three agents to the pygame window.
        """
        DISPLAYSURF.fill(BGColor)
        self.DrawBoard(self.board)
        self.DrawPiece(self.NPC_1)
        self.DrawPiece(self.NPC_2)
        self.DrawPiece(self.agent)
        pygame.display.update()
        FPSCLOCK.tick(FPS)
        # return flag
    def reset(self):
        """Reset the episode by re-initializing all agent dictionaries."""
        self._agent_init()
def getBlankBoard(self):
"""
11 - NPC1-goal
22 - NPC2-goal
33 - Agent-goal
@return:
"""
# board = data_transfer('env_1.txt','env_11.txt')
board = np.array([[1,1,1,1,1],[22,1,1,1,1],[1,1,1,1,1],[1,1,1,1,33],[1,1,1,11,1]])
state_init = np.array([[1,1,1,1,1],[1,1,1,1,1],[1,1,1,1,1],[1,1,1,1,1],[1,1,1,1,1]])
board = big_env(board)
# print(board)
# for x in range(GridWidth):
# for y in range(GridHeight):
# print(board[x][y],x,y)
return board, state_init
def ValidPos(self, piece1, piece2, xT=0, yT=0):
"""
to judge the next place vaild or not
@param piece1:
@param piece2:
@param xT:
@param yT:
@return:
"""
if piece1['x'] == (piece2['x'] + xT) and piece1['y'] == (piece2['y'] + yT):
return True
return False
def isNotWall(self, board, piece, xT=0 , yT=0 ):
"""
判断是否到达墙
@param board: board
@param piece: agent
@param xT:
@param yT:
@return:
"""
if board[piece['y'] + yT][piece['x'] + xT] == Wall:#############
return False
else:
return True
def gameover(self, location):
"""
回合是否结束,以及奖励值
@param location:目标位置
"""
result = False
terminal = [False, False, False]
if location[0] == location[1]:
terminal[0] = True
terminal[1] = True
if location[2] == location[1]:
terminal[2] = True
terminal[1] = True
if location[2] ==location[0]:
terminal[2] = True
terminal[0] = True
return terminal
    def pixel(self, xbox, ybox):
        """Convert board coordinates (in boxes) to window pixel coordinates."""
        return (XMargin + (xbox * BoxSize)), (TopMargin + (ybox * BoxSize))
def DrawBox(self, xbox, ybox, color, xpixel=None, ypixel=None):
if color == BlankBox:
return
elif color == obeservation_1:
if xpixel == None and ypixel == None:
xpixel, ypixel = self.pixel(xbox, ybox)
pygame.draw.rect(DISPLAYSURF, (60,107,255), (xpixel + 1, ypixel + 1, BoxSize - 1, BoxSize - 1))
elif color == obeservation_2:
if xpixel == None and ypixel == None:
xpixel, ypixel = self.pixel(xbox, ybox)
pygame.draw.rect(DISPLAYSURF, (205, 155, 155), (xpixel + 1, ypixel + 1, BoxSize - 1, BoxSize - 1))
elif color == obeservation_3:
if xpixel == None and ypixel == None:
xpixel, ypixel = self.pixel(xbox, ybox)
pygame.draw.rect(DISPLAYSURF, (154, 205, 50), (xpixel + 1, ypixel + 1, BoxSize - 1, BoxSize - 1))
elif color == Wall:
if xpixel == None and ypixel == None:
xpixel, ypixel = self.pixel(xbox, ybox)
pygame.draw.rect(DISPLAYSURF, Gray, (xpixel + 1, ypixel + 1, BoxSize - 1, BoxSize - 1))
else:
if xpixel == None and ypixel == None:
xpixel, ypixel = self.pixel(xbox, ybox)
pygame.draw.rect(DISPLAYSURF, color, (xpixel + 1, ypixel + 1, BoxSize - 1, BoxSize - 1))
    def DrawCircle(self, xbox, ybox, color, xpixel=None, ypixel=None):
        """
        Draw an agent as a filled circle centred in its grid cell.
        @param xbox: unused; kept for signature parity with DrawBox
        @param ybox: unused; kept for signature parity with DrawBox
        @param color: RGB tuple
        @param xpixel: left pixel of the target cell (must not be None here)
        @param ypixel: top pixel of the target cell (must not be None here)
        """
        pygame.draw.circle(DISPLAYSURF,
                           color,
                           (int(xpixel+BoxSize/2), int(ypixel+BoxSize/2)),
                           int(0.3 * self.R))
def DrawPiece(self, piece, xpixel=None, ypixel=None):
if xpixel == None and ypixel == None:
xpixel, ypixel = self.pixel(piece['x'], piece['y'])
if piece['style'] == "circle":
self.DrawCircle(None, None, piece['color'], xpixel, ypixel)
else:
self.DrawBox(None, None, piece['color'], xpixel, ypixel)
    def DrawBoard(self, board):
        """
        Draw the board frame, the background, and every non-blank cell.
        Note the transposed call: board is indexed [x][y] while DrawBox
        receives (y, x), matching the drawing convention used elsewhere.
        """
        pygame.draw.rect(DISPLAYSURF, BoardColor,
                         (XMargin - 3, TopMargin - 7, (GridWidth * BoxSize) + 8, (GridHeight * BoxSize) + 8), 5)
        pygame.draw.rect(DISPLAYSURF, BGColor, (XMargin, TopMargin, GridWidth * BoxSize, GridHeight * BoxSize))
        for x in range(GridWidth):
            for y in range(GridHeight):
                self.DrawBox(y, x, board[x][y])
    def ShowScore(self, score):
        """Blit the current score in the top-right corner of the window."""
        scoreSurf = BASICFONT.render('Score : %s' % score, True, TextColor)
        scoreRect = scoreSurf.get_rect()
        scoreRect.topleft = (WinWidth - 250, 20)
        DISPLAYSURF.blit(scoreSurf, scoreRect)
def Terminal(self, piece1, piece2, piece1_old, piece2_old):
# print(piece1['x'],piece1['y'],piece2_old[0],piece2_old[1],'/',piece2['x'],piece2['y'],piece1_old[0],piece1_old[1])
# if self.steps == 1:# wrong!!!!
if piece1['x'] == piece2_old[0] and piece1['y'] == piece2_old[1] and piece2['x'] == piece1_old[0] and piece2['y'] == piece1_old[1]:
return 1
else:
return 2
def Paint(self, board, piece, color):
board[piece['x']][piece['y']] = color
piece['color'] = color
return board
# if __name__ == "__main__":
# env0 = FalseBelief_env0()
# action_agent = 0
# action_NPC2 = 1
# action_NPC1 = 4
# for i in range(10):
# if i > 8:
# break
# else:
# env0.interact(action_NPC1, action_NPC2, action_agent)
#
# env0.SHOW()
# time.sleep(2)
# pygame.quit()
================================================
FILE: examples/Social_Cognition/ToM/env/env3_train_env01.py
================================================
"""
Zoe Zhao 2022.5
Env Demo
"""
from numpy import argmax
import random, time, pygame, sys
import pygame
pygame.init()
from pygame.locals import *
# import os
# os.environ["SDL_VIDEODRIVER"] = "dummy"
from rulebasedpolicy.world_model import *
from rulebasedpolicy.statedata_pre import *
from rulebasedpolicy.Find_a_way import *
import numpy as np
from utils.one_hot import one_hot
# =============================================================================
# set the value of interface
# =============================================================================
FPS = 25
WinWidth = 340   # window width in pixels
WinHeight = 260  # window height in pixels
BoxSize = 20     # the size of one grid cell in pixels
GridWidth = 7    # number of cells along the x-axis
GridHeight = 7   # number of cells along the y-axis
# representation of the different board cell codes
BlankBox = 1
Wall = 5
Obstacle = 5
observer = 8         # marks the observing agent's own cell
obeservation_1 = 11  # NPC1 goal cell (historical misspelling kept: referenced everywhere)
obeservation_2 = 22  # NPC2 goal cell
obeservation_3 = 33  # agent goal cell
#Text = None
XMargin = int((WinWidth - GridWidth * BoxSize)/2)
TopMargin = int((WinHeight - GridHeight * BoxSize))/2-5
# =============================================================================
# set color
# =============================================================================
White = (255,255,255)
Gray = (185,185,185)
Black = (0,0,0)
Red = (200,0,0)
Green = (0,139,0)
Green_B = (78, 238, 148)
Light_A = (233, 232, 170)
Blue = (30, 144, 255)
pink = (238, 99, 99)
BoardColor = White
BGColor = White
TextColor = White
Test = []
# =============================================================================
# agents - env interactive
# =============================================================================
class FalseBelief_env1(object):
    """
    Two-agent false-belief grid world (NPC2 and the observed agent).

    Variant of the three-agent environment with NPC1 removed: a 5x5 board
    (padded by big_env) containing two wall cells, a goal for NPC2 (code 22)
    and a goal for the agent (code 33).  Shelters occlude parts of the board
    from each agent's point of view.
    """
    def __init__(self, reward=10):
        super(FalseBelief_env1, self).__init__()
        self.action_space = ['up', 'down', 'left', 'right', 'stay']
        # action index -> (dx, dy) board displacement
        self.action_move = {
            0: (0, -1),
            1: (0, 1),
            2: (-1, 0),
            3: (1, 0),
            4: (0, 0)
        }
        self.n_actions = len(self.action_space)
        self._build_AB()
        self.board, self.obs = self.getBlankBoard()
        self._agent_init()
        self.score = 0
        self.steps = 0
        self.n = 0
        self.R = int(5/2) * (BoxSize - 5)  # radius base used by DrawCircle
        self.x = 0
        self.n_features = 30
        self.reward = reward  # magnitude of the terminal reward/penalty

    def _build_AB(self):
        """Initialise the pygame window, clock and fonts (module globals)."""
        global FPSCLOCK, DISPLAYSURF, BASICFONT, BIGFONT
        pygame.init()
        FPSCLOCK = pygame.time.Clock()
        DISPLAYSURF = pygame.display.set_mode((WinWidth, WinHeight))
        BASICFONT = pygame.font.Font('freesansbold.ttf', 18)
        BIGFONT = pygame.font.Font('freesansbold.ttf', 100)
        pygame.display.set_caption('AB')
        pygame.display.update()
        FPSCLOCK.tick()

    def _agent_init(self):
        """
        Initialize the basic information of both agents.
        """
        self.NPC_2 = {
            'shape': [['@']],
            'x': 5,  # row
            'y': 3,  # column
            'color': pink,
            'style': "circle",
            'obs': None,   # occluded view, filled in by actu_obs()
            'axis': None,  # visible coordinates, filled in by actu_obs()
            'reward': 0,
            'Done': False
        }
        self.agent = {
            'shape': [['$']],
            'x': 2,  # row
            'y': 4,  # column
            'color': Green_B,
            'style': "circle",
            'obs': None,
            'axis': None,
            'reward': 0,
            'Done': False
        }

    def actu_obs(self):
        """
        Convert the current positions into trainable observation grids.

        Builds a 9x5 grid per agent (5x5 board view padded below with a 4x5
        block of ones), stamps the observer marker on the agent's own cell,
        applies shelter occlusion, then records which other agent is visible
        (obs) and the visible coordinates (axis).

        @return: (NPC_2_state, Agent_state)
        """
        _, state = self.getBlankBoard()
        b = state
        c = state
        # Fix: np.int was removed from NumPy (1.24+); np.int_ is the
        # supported equivalent dtype.
        state2 = np.r_[b, np.ones((4, 5))].astype(np.int_)
        statea = np.r_[c, np.ones((4, 5))].astype(np.int_)
        NPC_2_state = state2
        Agent_state = statea
        NPC_2_state[self.NPC_2['y']-1, self.NPC_2['x']-1] = observer
        r = shelter_env(NPC_2_state[:5, :])
        NPC_2_state[:5, :] = shelter_env(NPC_2_state[:5, :])
        Agent_state[self.agent['y']-1, self.agent['x']-1] = observer
        p = shelter_env(Agent_state[:5, :])
        Agent_state[:5, :] = shelter_env(Agent_state[:5, :])
        # marker codes: 3 -> NPC2 in the agent's view, 4 -> agent in NPC2's view
        self.NPC_2['obs'] = r
        self.NPC_2['obs'] = self.gain_obs(self.NPC_2['obs'], NPC_2_state, self.agent, 4)
        self.NPC_2['axis'] = self.gain_axis(self.NPC_2, NPC_2_state, 6, self.agent, 2, 4)
        self.agent['obs'] = p
        self.agent['obs'] = self.gain_obs(self.agent['obs'], Agent_state, self.NPC_2, 3)
        self.agent['axis'] = self.gain_axis(self.agent, Agent_state, 6, self.NPC_2, 2, 3)
        return NPC_2_state, Agent_state

    def gain_obs(self, a, aa, c, cc):
        """
        Stamp the other agent's marker into this agent's view when visible.
        @param a: this agent's occluded observation (mutated in place)
        @param aa: state grid with positions applied
        @param c: the other agent's dict ('x'/'y')
        @param cc: numeric marker for the other agent
        @return: the updated observation
        """
        if aa[c['y']-1, c['x']-1] == 1:
            a[c['y'] - 1, c['x'] - 1] = cc
        return a

    def gain_axis(self, a, aa, b, c, bb, cc):
        """
        Collect coordinates; hidden agents are replaced by the [6, 6] sentinel.

        The middle entry corresponds to the removed NPC1 and is always
        [6, 6] (both branches below are intentionally identical, preserving
        the 3-entry layout of the three-agent environment).
        """
        axis = []
        axis.append([a['y'], a['x']])
        if b == 6:
            axis.append([6, 6])
        else:
            axis.append([6, 6])
        if aa[c['y']-1, c['x']-1] != 0:
            axis.append([c['y'], c['x']])
        else:
            axis.append([6, 6])
        return axis

    def interact(self, action_NPC2, action_agent):
        """
        Advance the environment one step for both agents.
        @param action_NPC2: action index for NPC2
        @param action_agent: action index for the agent
        @return: (NPC_2_state, Agent_state) occluded views after the move
        """
        self.agent['reward'] = 0
        self.NPC_2['reward'] = 0
        # What does each agent currently see?
        NPC_2_state, Agent_state = self.actu_obs()
        # Apply the chosen actions: update positions, then shaped rewards.
        base = np.where(np.array(self.board) == obeservation_2)
        base_x = int(base[0])
        base_y = int(base[1])
        if self.NPC_2['Done'] == False:
            dis1 = np.sqrt(np.square(base_x - self.NPC_2['y']) + np.square(base_y - self.NPC_2['x']))
            if self.isNotWall(self.board, self.NPC_2, self.action_move[action_NPC2][0],
                              self.action_move[action_NPC2][1]):
                self.NPC_2['x'] = self.NPC_2['x'] + self.action_move[action_NPC2][0]
                self.NPC_2['y'] = self.NPC_2['y'] + self.action_move[action_NPC2][1]
                dis2 = np.sqrt(np.square(base_x - self.NPC_2['y']) + np.square(base_y - self.NPC_2['x']))
                # distance-shaped reward, doubled until it leaves (-0.5, 0.5),
                # then clipped to [-1, 1]
                self.NPC_2['reward'] = (((dis1 - dis2) * 10 - 1 / 2) / dis1)
                while self.NPC_2['reward'] < 0.5 and self.NPC_2['reward'] > -0.5:
                    self.NPC_2['reward'] = self.NPC_2['reward'] * 2
                if self.NPC_2['reward'] > 1:
                    self.NPC_2['reward'] = 1
                elif self.NPC_2['reward'] < -1:
                    self.NPC_2['reward'] = -1
            else:
                self.NPC_2['reward'] = -0.9  # bumped into a wall
            if self.board[self.NPC_2['y'], self.NPC_2['x']] == obeservation_2:
                self.NPC_2['reward'] = self.reward
                self.NPC_2['Done'] = True
        base = np.where(np.array(self.board) == obeservation_3)
        base_x = int(base[0])
        base_y = int(base[1])
        if self.agent['Done'] == False:
            dis1 = np.sqrt(np.square(base_x - self.agent['y']) + np.square(base_y - self.agent['x']))
            if self.isNotWall(board=self.board, piece=self.agent, xT=self.action_move[action_agent][0],
                              yT=self.action_move[action_agent][1]):
                self.agent['x'] = self.agent['x'] + self.action_move[action_agent][0]
                self.agent['y'] = self.agent['y'] + self.action_move[action_agent][1]
                dis2 = np.sqrt(np.square(base_x - self.agent['y']) + np.square(base_y - self.agent['x']))
                self.agent['reward'] = (((dis1 - dis2)*10 - 1/2) / dis1)
                # note the extra -0.1 per doubling compared with NPC2's shaping
                while self.agent['reward'] < 0.5 and self.agent['reward'] > -0.5:
                    self.agent['reward'] = self.agent['reward'] * 2 - 0.1
                if self.agent['reward'] > 1:
                    self.agent['reward'] = 1
                elif self.agent['reward'] < -1:
                    self.agent['reward'] = -1
            else:
                self.agent['reward'] = -0.9
            if self.board[self.agent['y'], self.agent['x']] == obeservation_3:
                self.agent['reward'] = self.reward
                self.agent['Done'] = True
        NPC_2_state, Agent_state = self.actu_obs()
        # Collision check: reaching the goal or colliding ends the round.
        location = [(self.NPC_2['x'], self.NPC_2['y']),
                    (self.agent['x'], self.agent['y'])]
        terminal = self.gameover(location)
        if self.agent['Done'] == False and terminal[1] == True:
            self.agent['Done'] = True
            self.agent['reward'] = -self.reward
            self.agent['color'] = Red
        if self.NPC_2['Done'] == False and terminal[0] == True:
            self.NPC_2['Done'] = True
            self.NPC_2['reward'] = -self.reward
            self.NPC_2['color'] = Red
        return NPC_2_state, Agent_state

    def SHOW(self):
        """Render the board and both agents to the pygame window."""
        DISPLAYSURF.fill(BGColor)
        self.DrawBoard(self.board)
        self.DrawPiece(self.NPC_2)
        self.DrawPiece(self.agent)
        pygame.display.update()
        FPSCLOCK.tick(FPS)

    def reset(self):
        """Reset the episode by re-initializing both agent dictionaries."""
        self._agent_init()

    def getBlankBoard(self):
        """
        Build the initial board and a blank state grid.
        Cell codes: 1 blank, 5 wall, 22 NPC2 goal, 33 agent goal, 11 the
        (unused) NPC1 goal.  The board is padded by big_env().
        @return: (board, state_init)
        """
        board = np.array([[1,1,1,5,5],[22,1,1,5,5],[1,1,1,1,1],[1,1,1,1,33],[1,1,1,11,1]])
        state_init = np.array([[1,1,1,5,5],[1,1,1,5,5],[1,1,1,1,1],[1,1,1,1,1],[1,1,1,1,1]])
        board = big_env(board)
        return board, state_init

    def ValidPos(self, piece1, piece2, xT=0, yT=0):
        """
        Return True when piece1 occupies the cell piece2 would move into.
        """
        if piece1['x'] == (piece2['x'] + xT) and piece1['y'] == (piece2['y'] + yT):
            return True
        return False

    def isNotWall(self, board, piece, xT=0, yT=0):
        """Return False when the destination cell of the move is a wall."""
        if board[piece['y'] + yT][piece['x'] + xT] == Wall:
            return False
        else:
            return True

    def gameover(self, location):
        """
        Pairwise collision check.
        @param location: [(x, y), (x, y)] current positions (NPC2, agent)
        @return: list of booleans, True for every agent involved in a collision
        """
        result = False
        terminal = [False, False]  # NPC_2, agent
        for r in range(len(location) - 1):
            for c in range(r + 1, len(location)):
                if location[r] == location[c]:
                    result = True  # a collision ends the round with a penalty
                if result == True:
                    terminal[r] = True
                    terminal[c] = True
        return terminal

    def pixel(self, xbox, ybox):
        """Convert board coordinates (in boxes) to window pixel coordinates."""
        return (XMargin + (xbox * BoxSize)), (TopMargin + (ybox * BoxSize))

    def DrawBox(self, xbox, ybox, color, xpixel=None, ypixel=None):
        """
        Draw one grid cell; blank cells are skipped.
        @param color: board cell code (goal markers, wall) or an RGB tuple
        """
        if color == BlankBox:
            return
        # Map symbolic board codes to display colours; anything else is
        # assumed to already be an RGB tuple.
        palette = {
            obeservation_1: (60, 107, 255),
            obeservation_2: (205, 155, 155),
            obeservation_3: (154, 205, 50),
            Wall: Gray,
        }
        rgb = palette.get(color, color)
        if xpixel is None and ypixel is None:
            xpixel, ypixel = self.pixel(xbox, ybox)
        pygame.draw.rect(DISPLAYSURF, rgb,
                         (xpixel + 1, ypixel + 1, BoxSize - 1, BoxSize - 1))

    def DrawCircle(self, xbox, ybox, color, xpixel=None, ypixel=None):
        """
        Draw an agent as a filled circle centred in its grid cell.
        @param xbox: unused; kept for signature parity with DrawBox
        @param ybox: unused; kept for signature parity with DrawBox
        """
        pygame.draw.circle(DISPLAYSURF,
                           color,
                           (int(xpixel+BoxSize/2), int(ypixel+BoxSize/2)),
                           int(0.3 * self.R))

    def DrawPiece(self, piece, xpixel=None, ypixel=None):
        """Render one agent at its board position (or at explicit pixels)."""
        if xpixel is None and ypixel is None:
            xpixel, ypixel = self.pixel(piece['x'], piece['y'])
        if piece['style'] == "circle":
            self.DrawCircle(None, None, piece['color'], xpixel, ypixel)
        else:
            self.DrawBox(None, None, piece['color'], xpixel, ypixel)

    def DrawBoard(self, board):
        """Draw the board frame, background and every non-blank cell."""
        pygame.draw.rect(DISPLAYSURF, BoardColor,
                         (XMargin - 3, TopMargin - 7, (GridWidth * BoxSize) + 8, (GridHeight * BoxSize) + 8), 5)
        pygame.draw.rect(DISPLAYSURF, BGColor, (XMargin, TopMargin, GridWidth * BoxSize, GridHeight * BoxSize))
        for x in range(GridWidth):
            for y in range(GridHeight):
                self.DrawBox(y, x, board[x][y])

    def ShowScore(self, score):
        """Blit the current score in the top-right corner of the window."""
        scoreSurf = BASICFONT.render('Score : %s' % score, True, TextColor)
        scoreRect = scoreSurf.get_rect()
        scoreRect.topleft = (WinWidth - 250, 20)
        DISPLAYSURF.blit(scoreSurf, scoreRect)

    def Terminal(self, piece1, piece2, piece1_old, piece2_old):
        """
        Detect a position swap between two pieces within one step.
        @return: 1 when the pieces exactly exchanged cells, else 2
        """
        if piece1['x'] == piece2_old[0] and piece1['y'] == piece2_old[1] and piece2['x'] == piece1_old[0] and piece2['y'] == piece1_old[1]:
            return 1
        else:
            return 2

    def Paint(self, board, piece, color):
        """
        Stamp *color* onto the board at the piece's cell and recolor the piece.
        @return: the mutated board
        """
        board[piece['x']][piece['y']] = color
        piece['color'] = color
        return board
================================================
FILE: examples/Social_Cognition/ToM/main_ToM.py
================================================
"""
Zoe Zhao 2022.5
ToM Demo
"""
import argparse
import copy
import numpy as np
import torch
np.set_printoptions(threshold=np.inf)
torch.set_printoptions(threshold=np.inf)
import matplotlib
import pygame
pygame.init()
matplotlib.rcParams.update({'font.size': 12})
import os
os.environ["SDL_VIDEODRIVER"] = "dummy"
from BrainArea.PFC_ToM import PFC_ToM
from BrainArea.TPJ import ToM
from BrainArea.dACC import *
from rulebasedpolicy.Find_a_way import *
from env.env import FalseBelief_env
from braincog.base.encoder.encoder import *
from braincog.base.node import node
# --- NPC2 network sizing ---
# state
N_state = 6   # number of state components fed to the network
cell_num = 6  # population size per state component
# action
N_action = 5  # number of discrete actions
NC=10         # cells representing one action in the decoding layer
# synaptic layer sizes
bfs = pow(cell_num, N_state)  # pre-synaptic (input) size
afs = N_action * NC           # post-synaptic (output) size
# --- agent network sizing ---
C=10
A_state = 4
abfs = pow(cell_num, A_state)  # agent pre-synaptic size
aafs = N_action * C            # agent post-synaptic size
parser = argparse.ArgumentParser(description='sequence character (policy inference)')
parser.add_argument('--mode', type=str, default='test')
parser.add_argument('--task', type=str, default='both')
parser.add_argument('--logdir', type=str, default='checkpoint')
# NOTE(review): save_net_a defaults to the same file as save_net_N
# ('net_NPC_11.pth'); confirm this is intentional before saving both nets.
parser.add_argument('--save_net_a', type=str, default='net_NPC_11.pth', help='save the parameters of net_agent')
parser.add_argument('--save_net_N', type=str, default='net_NPC_11.pth', help='save the parameters of net_NPC')
parser.add_argument('--device', default='cpu', help='device')  # cuda:0
parser.add_argument('--T', default=40, type=int, help='simulating time-steps')  # simulation duration
parser.add_argument('--dt', default=1, type=int, help='simulating dt')  # simulation time step
parser.add_argument('--episodes', default=25, type=int, help='episodes')
parser.add_argument('--trajectories', default=10, type=int, help='trajectories')
# Fix: greedy is a probability in [0, 1]; type=int would reject '--greedy 0.8'
# on the command line (main_both.py already declares it as float).
parser.add_argument('--greedy', default=0.8, type=float, help='exploration or exploitation')
parser.add_argument('--num_enpop', default=6, type=int, help='the number of one population in the encoding layer')
parser.add_argument('--num_depop', default=10, type=int, help='the number of one population in the decoding layer')
parser.add_argument('--num_stateA', default=2, type=int, help='the number of states')
parser.add_argument('--num_stateN', default=6, type=int, help='the number of states')
parser.add_argument('--num_action', default=5, type=int, help='the number of actions')
parser.add_argument('--reward', default=10, type=float, help='environment parameter reward')
args = parser.parse_args()
def update(env, net_agent_belief, net_NPC, episodes, trajectories):
    """
    agents learn to reach the goal without collision
    update agents' positions
    @param env: the FalseBelief environment instance
    @param net_agent_belief: the SNN network modelling the agent's belief about NPC2
    @param net_NPC: the SNN network of NPC2
    @param episodes: train times
    @param trajectories: maximum steps per episode
    @return: None
    """
    for episode in tqdm(range(episodes)):
        timer = 0
        env.reset()
        env.actu_obs()
        # per-episode score accumulators (accumulated elsewhere; not returned)
        scores = {
            'agent_0': 0,
            'NPC2_0' : 0,
            'agent_1': 0,
            'NPC2_1': 0,
        }
        Done_agent_0 = Done_agent_1 = False
        Done_NPC2_0 = Done_NPC2_1 = False
        # fixed initial actions for the first step of each episode
        action_agent = 3
        action_NPC2 = 2
        action_NPC1 = 1
        action_agent1 = 4
        # the start position are the same in two envs
        # mapping_a = {'state': sum(env.agent['axis'], []),
        #              'action': action_agent}
        mapping_N = {'state': sum(env.NPC_2['axis'], []),
                     'action': action_NPC2}
        while True and timer < trajectories:
            timer = timer + 1
            NPC_1_state, NPC_2_state, Agent_state \
                = env.interact(action_NPC1, action_NPC2, action_agent)
            env.SHOW()
            # time.sleep(2)
            # NPC_1 selects action by path planning (Find_a_way)
            if env.NPC_1['Done'] == False:
                action_seq1 = Find_a_way(size=5, board=NPC_1_state, \
                                         start_x=env.NPC_1['x'] - 1, \
                                         start_y=4 - (env.NPC_1['y'] - 1), \
                                         end_x=3, end_y=4 - 4)
                # map the first planned displacement back to an action index
                action_NPC1 = list(env.action_move.keys())[ \
                    list(env.action_move.values()).index(
                        (action_seq1[1][0] - (action_seq1[0][0]), -action_seq1[1][1] + (action_seq1[0][1])))]
            # agent selects action on purpose
            # Agent_obs = sum(env.agent['axis'], [])
            if env.agent['Done'] == False:
                # NOTE(review): ToM here is the module-level instance created in
                # __main__, not the imported class — confirm when refactoring.
                axis_new, axis_switch, obs_switch = ToM.TPJ(NPC_num=2, axis=env.agent['axis'], obs=env.agent['obs'], )
                if axis_new == env.agent['axis']:
                    '''
                    没有遮挡关系 have teached
                    '''
                    action_agent = 3
                else:
                    '''
                    有遮挡关系
                    '''
                    Agent_obs_NPC2 = sum(env.NPC_2['axis'], [])
                    action_agent = net_agent_belief(inputs=Agent_obs_NPC2,
                                                    num_action=args.num_action,
                                                    episode=episode)
                    prediction_next_state = ToM.prediction_state(axis_new, env.agent['axis'], action_NPC1, net_NPC,
                                                                 num_action=args.num_action,
                                                                 episode = episode)
                    # if the predicted next state is bad, act altruistically
                    if ToM.state_evaluation(prediction_next_state=prediction_next_state) == False:
                        print(False)
                        action_agent = ToM.altruism(axis_switch=axis_switch , axis_NPC=env.NPC_2['axis'], n_actions = env.n_actions)
                        env.trigger = 1
            else:
                action_agent = 3
            # NPC_2 selects action by E-STDP
            NPC2_obs = sum(env.NPC_2['axis'], [])
            if Done_NPC2_0 == False:
                # NPC2 waits when the agent waits (and is still active)
                if action_agent == 4 and env.agent['Done'] == False:
                    action_NPC2 = 4
                else:
                    action_NPC2 = net_NPC(inputs=NPC2_obs, \
                                          num_action=args.num_action, \
                                          episode=episode)
                state_NPC2 = copy.deepcopy(NPC2_obs)
                Done_NPC2_0 = copy.deepcopy(env.NPC_2['Done'])
                # mapping_N = {'state': state_NPC2, # at time t
                #              'action': action_NPC2}
def train():
    """
    Build the agent-belief and NPC networks, run update(), and save the
    NPC network's weights.  Uses the module-level `env`, `args` and
    `time_start` defined in __main__.
    """
    print('train mode loading ... ')
    if not os.path.isdir(args.logdir):
        os.mkdir(args.logdir)
    bfs = pow(args.num_enpop, args.num_stateN) # pre-synaptic (input) size
    afs = args.num_action * args.num_depop     # post-synaptic (output) size
    #agent
    # abfs = pow(args.num_enpop, args.num_stateA) # agent before synapstic
    # aafs = args.num_action * args.num_depop
    net_agent_belief = PFC_ToM(step=args.T, encode_type='rate', bias=True,
                               in_features=bfs, out_features=afs,
                               node=node.LIFNode, num_state=args.num_stateN,
                               greedy=args.greedy) #out_features the kinds of policies
    net_agent_belief.to(args.device)
    # random initial weights for the belief network
    net_agent_belief.fc.weight.data = torch.rand((afs, bfs))
    # net_agent_belief.load_state_dict(torch.load(os.path.join(args.logdir, args.save_net_N))['model'])
    #NPC
    net_NPC = PFC_ToM(step=args.T, encode_type='rate', bias=True,
                      in_features=bfs, out_features=afs,
                      node=node.LIFNode, num_state=args.num_stateN,
                      greedy=args.greedy) #out_features the kinds of policies
    net_NPC.to(args.device)
    # NPC network starts from a pretrained checkpoint
    net_NPC.load_state_dict(torch.load(os.path.join(args.logdir, args.save_net_N))['model'])
    # NOTE(review): update() returns None, so total_scores is always None.
    total_scores = update(env, net_agent_belief, net_NPC, args.episodes,\
                          args.trajectories)
    # torch.save({'model': net_agent.state_dict()}, os.path.join(args.logdir, args.save_net_a))
    torch.save({'model': net_NPC.state_dict()}, os.path.join(args.logdir, args.save_net_N))
    time_end = time.time()
    print('totally cost',time_end-time_start)
if __name__ == "__main__":
    time_start = time.time()  # read by train() for total-time reporting
    env = FalseBelief_env(args.reward)
    # NOTE(review): this rebinds the imported ToM class name to an instance;
    # update() relies on this module-level name.
    ToM = ToM(env=env)
    # args.task = 'both'#'zero'
    # args.mode = 'test'#'train'
    # args.save_net_N = 'net_NPC_3.pth'
    # args.save_net_a = 'net_agent_3.pth'
    # args.greedy = 111
    train()
================================================
FILE: examples/Social_Cognition/ToM/main_both.py
================================================
import argparse
import time
import copy
import numpy as np
import torch
np.set_printoptions(threshold=np.inf)
torch.set_printoptions(threshold=np.inf)
from tqdm import *
import matplotlib
import seaborn as sns
import pygame
pygame.init()
sns.set(style='ticks', palette='Set2')
matplotlib.rcParams.update({'font.size': 12})
import os
os.environ["SDL_VIDEODRIVER"] = "dummy"
from BrainArea.PFC_ToM import PFC_ToM
from rulebasedpolicy.Find_a_way import *
from env.env3_train_env00 import FalseBelief_env0 #3
from env.env3_train_env01 import FalseBelief_env1 #2
from braincog.base.encoder.encoder import *
from braincog.base.node import node
torch.manual_seed(1)  # fixed seed for reproducible runs
# --- NPC2 network sizing ---
# state
N_state = 6   # number of state components fed to the network
cell_num = 6  # population size per state component
# action
N_action = 5  # number of discrete actions
NC=10         # cells representing one action in the decoding layer
# synaptic layer sizes
bfs = pow(cell_num, N_state)  # pre-synaptic (input) size
afs = N_action * NC           # post-synaptic (output) size
# --- agent network sizing ---
C=10
A_state = 4
abfs = pow(cell_num, A_state)  # agent pre-synaptic size
aafs = N_action * C            # agent post-synaptic size
parser = argparse.ArgumentParser(description='sequence character (policy inference)')
parser.add_argument('--mode', type=str, default='train')
parser.add_argument('--task', type=str, default='both')
parser.add_argument('--logdir', type=str, default='checkpoint')
parser.add_argument('--save_net_a', type=str, default='net_agent_4.pth', help='save the parameters of net_agent')
parser.add_argument('--save_net_N', type=str, default='net_NPC_4.pth', help='save the parameters of net_NPC')
parser.add_argument('--device', default='cpu', help='device') # cuda:0
parser.add_argument('--T', default=40, type=int, help='simulating time-steps') # simulation duration
parser.add_argument('--dt', default=1, type=int, help='simulating dt') # simulation time step
parser.add_argument('--episodes', default=25, type=int, help='episodes')
parser.add_argument('--trajectories', default=10, type=int, help='trajectories')
parser.add_argument('--greedy', default=0.8, type=float, help='exploration or exploitation')
parser.add_argument('--num_enpop', default=6, type=int, help='the number of one population in the encoding layer') #
parser.add_argument('--num_depop', default=10, type=int, help='the number of one population in the decoding layer') #
parser.add_argument('--num_stateA', default=2, type=int, help='the number of states, (X, Y)')
parser.add_argument('--num_stateN', default=6, type=int, help='the number of states, [(X, Y), (X, Y), (X, Y)]')
parser.add_argument('--num_action', default=5, type=int, help='the number of actions')
parser.add_argument('--reward', default=10, type=float, help='environment parameter reward')
args = parser.parse_args()
def reward_plot(episodes, scores, Note):
    """
    Plot the per-episode reward curve and save it as a PNG.
    @param episodes: number of training episodes (x-axis upper bound)
    @param scores: dict keyed by curve label -> per-episode reward list
    @param Note: unused; kept for caller compatibility
    """
    figure = plt.figure(figsize=(7.5, 4.5))
    axes = figure.add_subplot(111)
    axes.set_title('Reward Plot')
    plt.xlim(1, episodes)
    plt.grid(ls='--', c='gray')
    plt.xlabel('Epoch')
    plt.ylabel('Reward')
    curve_label = 'be observed agent without the ToM'
    epoch_axis = list(range(1, episodes + 1))
    plt.plot(epoch_axis, scores[curve_label], label=curve_label)
    plt.legend()
    plt.savefig('reward_plot_' + str(episodes) + '.png')
def update(env0, env1, net_agent, net_NPC, episodes, trajectories, task):
    """
    agents learn to reach the goal without collision;
    update agents' positions.

    @param env0: false-belief environment variant 0 (used for task 'zero'/'both')
    @param env1: false-belief environment variant 1 (used for task 'one'/'both')
    @param net_agent: the SNN network of agent
    @param net_NPC: the SNN network of NPC
    @param episodes: train times
    @param trajectories: max steps per rollout before it is cut off
    @param task: 'zero', 'one' or 'both' -- which environment(s) to roll out
    @return: dict with per-episode score lists for the ToM agent and the
        observed (non-ToM) agent
    """
    scores_agent = []
    scores_NPC2 = []
    for episode in tqdm(range(episodes)):
        timer0 = 0
        timer1 = 0
        env0.reset()
        env1.reset()
        env0.actu_obs()
        env1.actu_obs()
        # cumulative reward per (actor, environment) pair for this episode
        scores = {
            'agent_0': 0,
            'NPC2_0': 0,
            'agent_1': 0,
            'NPC2_1': 0,
        }
        Done_agent_0 = Done_agent_1 = False
        Done_NPC2_0 = Done_NPC2_1 = False
        # initial actions (indices into env.action_move)
        action_agent = 3
        action_NPC2 = 2
        action_NPC1 = 4
        action_agent1 = 4
        # the start position are the same in two envs
        mapping_a = {'state': sum(env0.agent['axis'], []),
                     'action': action_agent}
        mapping_N = {'state': sum(env0.NPC_2['axis'], []),
                     'action': action_NPC2}
        if task == 'both' or task == 'zero':
            while True and timer0 < trajectories:
                timer0 = timer0 + 1
                NPC_1_state, NPC_2_state, Agent_state \
                    = env0.interact(action_NPC1, action_NPC2, action_agent)
                env0.SHOW()
                # time.sleep(2)
                # NPC_1 selects action by rule-based A* path planning
                if env0.NPC_1['Done'] == False:
                    action_seq1 = Find_a_way(size=5, board=NPC_1_state,
                                             start_x=env0.NPC_1['x'] - 1,
                                             start_y=4 - (env0.NPC_1['y'] - 1),
                                             end_x=3, end_y=4 - 4)
                    # translate the first A* step into the matching action id
                    action_NPC1 = list(env0.action_move.keys())[
                        list(env0.action_move.values()).index((action_seq1[1][0] - (action_seq1[0][0]), -action_seq1[1][1] + (action_seq1[0][1])))]
                # agent selects action by E-STDP
                Agent_obs = sum(env0.agent['axis'], [])
                if Done_agent_0 == False:
                    # agent learning is currently disabled: fixed action 3
                    action_agent = 3
                    # net_agent.update_s(R = env0.agent['reward'],\
                    #                    mapping=mapping_a)
                    # action_agent = net_agent(inputs = Agent_obs,\
                    #                          num_action = args.num_action,\
                    #                          episode = episode)
                    # state_agent = copy.deepcopy(Agent_obs)
                    # Done_agent_0 = copy.deepcopy(env0.agent['Done'])
                    # mapping_a = {'state': state_agent, # at time t
                    #              'action': action_agent}
                # NPC_2 selects action by E-STDP
                NPC2_obs = sum(env0.NPC_2['axis'], [])
                if Done_NPC2_0 == False:
                    net_NPC.update_s(R=env0.NPC_2['reward'],
                                     mapping=mapping_N)
                    action_NPC2 = net_NPC(inputs=NPC2_obs,
                                          num_action=args.num_action,
                                          episode=episode)
                    state_NPC2 = copy.deepcopy(NPC2_obs)
                    Done_NPC2_0 = copy.deepcopy(env0.NPC_2['Done'])
                    mapping_N = {'state': state_NPC2,  # at time t
                                 'action': action_NPC2}
                # continue
                scores['agent_0'] += env0.agent['reward']
                scores['NPC2_0'] += env0.NPC_2['reward']
                if env0.NPC_1['Done'] == env0.NPC_2['Done'] == env0.agent['Done'] == True:
                    break
            scores_agent.append(scores['agent_0'])
            scores_NPC2.append(scores['NPC2_0'])
        ######################
        if task == 'both' or task == 'one':
            while True and timer1 < trajectories:
                timer1 = timer1 + 1
                NPC_2_state, Agent_state \
                    = env1.interact(action_NPC2, action_agent)
                env1.SHOW()
                # time.sleep(2)
                # agent selects action by E-STDP
                Agent_obs = sum(env1.agent['axis'], [])
                if Done_agent_1 == False:
                    # agent learning is currently disabled: fixed action 3
                    action_agent = 3
                    # net_agent.update_s(R=env1.agent['reward'], \
                    #                    mapping=mapping_a)
                    # scores['agent_1'] += env1.agent['reward']
                    # action_agent = net_agent(inputs=Agent_obs, \
                    #                          num_action=args.num_action,\
                    #                          episode = episode)
                    # state_agent = copy.deepcopy(Agent_obs)
                    # Done_agent_1 = copy.deepcopy(env1.agent['Done'])
                    # mapping_a = {'state': state_agent, # at time t
                    #              'action': action_agent}
                # NPC_2 selects action by E-STDP
                NPC2_obs = sum(env1.NPC_2['axis'], [])
                if Done_NPC2_1 == False:
                    net_NPC.update_s(R=env1.NPC_2['reward'],
                                     mapping=mapping_N)
                    # NOTE(review): NPC2_1 is also accumulated at loop level
                    # below, so the reward is counted twice while NPC_2 is
                    # active -- confirm this is intended.
                    scores['NPC2_1'] += env1.NPC_2['reward']
                    action_NPC2 = net_NPC(inputs=NPC2_obs,
                                          num_action=args.num_action,
                                          episode=episode)
                    state_NPC2 = copy.deepcopy(NPC2_obs)
                    Done_NPC2_1 = copy.deepcopy(env1.NPC_2['Done'])
                    mapping_N = {'state': state_NPC2,  # at time t
                                 'action': action_NPC2}
                scores['agent_1'] += env1.agent['reward']
                scores['NPC2_1'] += env1.NPC_2['reward']
                if env1.NPC_2['Done'] == env1.agent['Done'] == True:
                    break
            scores_agent.append(scores['agent_1'])
            scores_NPC2.append(scores['NPC2_1'])
    total_scores = {
        'the agent with the ToM': scores_agent,
        'be observed agent without the ToM': scores_NPC2
    }
    return total_scores
def train():
    """Build both SNNs with random weights, run the rollout, save weights and plot rewards."""
    print('train mode loading ... ')
    if not os.path.isdir(args.logdir):
        os.mkdir(args.logdir)
    # observing agent (with ToM): pre-/post-synaptic layer sizes
    agent_pre = pow(args.num_enpop, args.num_stateA)
    agent_post = args.num_action * args.num_depop
    net_agent = PFC_ToM(step=args.T, encode_type='rate', bias=True,
                        in_features=agent_pre, out_features=agent_post,
                        node=node.LIFNode, num_state=args.num_stateA,
                        greedy=args.greedy)  # out_features: the kinds of policies
    net_agent.to(args.device)
    net_agent.fc.weight.data = torch.rand((agent_post, agent_pre))
    # net_agent.load_state_dict(torch.load('./checkpoint/net_agent_12.pth')['model'])
    # observed agent (NPC, without ToM)
    npc_pre = pow(args.num_enpop, args.num_stateN)
    npc_post = args.num_action * args.num_depop
    net_NPC = PFC_ToM(step=args.T, encode_type='rate', bias=True,
                      in_features=npc_pre, out_features=npc_post,
                      node=node.LIFNode, num_state=args.num_stateN,
                      greedy=args.greedy)  # out_features: the kinds of policies
    net_NPC.to(args.device)
    net_NPC.fc.weight.data = torch.rand((npc_post, npc_pre))
    # net_NPC.load_state_dict(torch.load('./checkpoint/net_NPC_12.pth')['model'])
    total_scores = update(env0, env1, net_agent, net_NPC, args.episodes,
                          args.trajectories, args.task)
    torch.save({'model': net_agent.state_dict()}, os.path.join(args.logdir, args.save_net_a))
    torch.save({'model': net_NPC.state_dict()}, os.path.join(args.logdir, args.save_net_N))
    time_end = time.time()
    print('totally cost', time_end - time_start)
    # 'both' runs each environment once, hence twice as many x-axis points
    if args.task in ('zero', 'one'):
        reward_plot(args.episodes, total_scores, 'Scores')
    elif args.task == 'both':
        reward_plot(args.episodes * 2, total_scores, 'Scores')
    plt.show()
def test():
    """Load the saved SNN weights, run a greedy rollout and plot the rewards."""
    args.greedy = 1  # force greedy action selection at evaluation time
    print('test mode loading ... ')
    print('greedy :', args.greedy)
    # observing agent (with ToM)
    agent_pre = pow(args.num_enpop, args.num_stateA)
    agent_post = args.num_action * args.num_depop
    net_agent = PFC_ToM(step=args.T, encode_type='rate', bias=True,
                        in_features=agent_pre, out_features=agent_post,
                        node=node.LIFNode, num_state=args.num_stateA,
                        greedy=args.greedy)
    net_agent.to(args.device)
    # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
    net_agent.load_state_dict(torch.load(os.path.join(args.logdir, args.save_net_a))['model'])
    # observed agent (NPC, without ToM)
    npc_pre = pow(args.num_enpop, args.num_stateN)
    npc_post = args.num_action * args.num_depop
    net_NPC = PFC_ToM(step=args.T, encode_type='rate', bias=True,
                      in_features=npc_pre, out_features=npc_post,
                      node=node.LIFNode, num_state=args.num_stateN,
                      greedy=args.greedy)
    net_NPC.to(args.device)
    net_NPC.load_state_dict(torch.load(os.path.join(args.logdir, args.save_net_N))['model'])
    total_scores = update(env0, env1, net_agent, net_NPC, args.episodes,
                          args.trajectories, args.task)
    time_end = time.time()
    print('totally cost', time_end - time_start)
    # 'both' runs each environment once, hence twice as many x-axis points
    if args.task in ('zero', 'one'):
        reward_plot(args.episodes, total_scores, 'Scores')
    elif args.task == 'both':
        reward_plot(args.episodes * 2, total_scores, 'Scores')
    plt.show()
if __name__ == "__main__":
    time_start = time.time()  # wall-clock start, reported by train()/test()
    # both false-belief environment variants share the same reward parameter
    env0 = FalseBelief_env0(args.reward)
    env1 = FalseBelief_env1(args.reward)
    # args.task = 'both'#'zero'
    # args.mode = 'test'#'train'
    # args.save_net_N = 'net_NPC_3.pth'
    # args.save_net_a = 'net_agent_3.pth'
    if args.mode == 'train':
        train()
    elif args.mode == 'test':
        test()
================================================
FILE: examples/Social_Cognition/ToM/rulebasedpolicy/Find_a_way.py
================================================
# main.py
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from rulebasedpolicy.random_map import *
from rulebasedpolicy.a_star import *
from env.env3_train_env01 import FalseBelief_env1
def Find_a_way(size, board, start_x, start_y, end_x, end_y):
    """
    Run A* over a grid built from *board* and return the action sequence.

    @param size: side length of the square grid
    @param board: 2-D array; cells equal to 5 are treated as obstacles
    @param start_x, start_y: start cell in map coordinates
    @param end_x, end_y: goal cell in map coordinates
    @return: list of (x, y) cells from start to goal, or None when A*
        finds no path (RunAndSaveImage returns None on failure)

    Fix: the original looped over every cell building matplotlib Rectangle
    patches that were never attached to any axes -- pure O(size^2) dead
    work per call -- and bound the map to a local named ``map``, shadowing
    the builtin. Both removed; the search result is unchanged.
    """
    grid = RandomMap(size=size, board=board)
    search = AStar(grid)
    return search.RunAndSaveImage(start_x, start_y, end_x, end_y)
================================================
FILE: examples/Social_Cognition/ToM/rulebasedpolicy/__init__.py
================================================
================================================
FILE: examples/Social_Cognition/ToM/rulebasedpolicy/a_star.py
================================================
import sys
import time
import numpy as np
from matplotlib.patches import Rectangle
from rulebasedpolicy.point import *
from rulebasedpolicy.random_map import *
class AStar:
    """A* path search on a RandomMap grid.

    Points are expanded from ``open_set`` in order of estimated total cost;
    ``RunAndSaveImage`` returns a list of (x, y) cells from start to goal,
    or None when no path exists.

    NOTE(review): BaseCost measures distance from the fixed corner (0, 0)
    and HeuristicCost from (size-1, size-1), yet RunAndSaveImage accepts
    arbitrary start/end points -- for other endpoints the cost estimates are
    inconsistent with the actual goal. Also, ProcessPoint never lowers the
    stored cost of a point already in the open set. Confirm both are intended.
    """

    def __init__(self, map):
        # map: provides .size and .IsObstacle(x, y)
        self.map = map
        self.open_set = []   # frontier: discovered but not yet expanded
        self.close_set = []  # already-expanded points

    def BaseCost(self, p):
        # Octile-style distance (diagonal moves cost sqrt(2)).
        x_dis = p.x
        y_dis = p.y
        # Distance to start point
        return x_dis + y_dis + (np.sqrt(2) - 2) * min(x_dis, y_dis)

    def HeuristicCost(self, p):
        x_dis = self.map.size - 1 - p.x
        y_dis = self.map.size - 1 - p.y
        # Distance to end point
        return x_dis + y_dis + (np.sqrt(2) - 2) * min(x_dis, y_dis)

    def TotalCost(self, p):
        # f(p) = g(p) + h(p)
        return self.BaseCost(p) + self.HeuristicCost(p)

    def IsValidPoint(self, x, y):
        # Inside the grid and not an obstacle.
        if x < 0 or y < 0:
            return False
        if x >= self.map.size or y >= self.map.size:
            return False
        return not self.map.IsObstacle(x, y)

    def IsInPointList(self, p, point_list):
        # Linear scan; points are compared by coordinates, not identity.
        for point in point_list:
            if point.x == p.x and point.y == p.y:
                return True
        return False

    def IsInOpenList(self, p):
        return self.IsInPointList(p, self.open_set)

    def IsInCloseList(self, p):
        return self.IsInPointList(p, self.close_set)

    def IsStartPoint(self, p, start_x, start_y):
        return p.x == start_x and p.y == start_y

    def IsEndPoint(self, p, end_x, end_y):
        return p.x == end_x and p.y == end_y

    def SaveImage(self, plt):
        # Snapshot the current figure, named by the current time in ms.
        millis = int(round(time.time() * 1000))
        filename = './' + str(millis) + '.png'
        plt.savefig(filename)

    def ProcessPoint(self, x, y, parent):
        """Add neighbour (x, y) to the open set unless invalid/visited/queued."""
        if not self.IsValidPoint(x, y):
            return  # Do nothing for invalid point
        p = Point(x, y)
        if self.IsInCloseList(p):
            return  # Do nothing for visited point
        # print('Process Point [', p.x, ',', p.y, ']', ', cost: ', p.cost)
        if not self.IsInOpenList(p):
            p.parent = parent
            p.cost = self.TotalCost(p)
            self.open_set.append(p)

    def SelectPointInOpenList(self):
        """Return the index of the cheapest open point, or -1 when empty."""
        index = 0
        selected_index = -1
        min_cost = sys.maxsize
        for p in self.open_set:
            cost = self.TotalCost(p)  # recomputed rather than using p.cost
            if cost < min_cost:
                min_cost = cost
                selected_index = index
            index += 1
        return selected_index

    def BuildPath(self, p, start_time, start_x, start_y, end_x, end_y):
        """Walk parent links back to the start and return the cell sequence.

        Diagonal steps are expanded into two axis-aligned cells so the
        caller only ever sees 4-neighbour moves.
        """
        path = []
        record = []
        while True:
            path.insert(0, p)  # Insert first: path ends up start -> goal
            if self.IsStartPoint(p, start_x, start_y):
                break
            else:
                p = p.parent
        p_x = start_x
        p_y = start_y
        for p in path:
            if abs(p.x - p_x) == abs(p.y - p_y) == 1:
                # Diagonal move: insert an intermediate axis-aligned cell,
                # chosen along the dominant direction of travel.
                if abs(end_x - start_x) >= abs(end_y - start_y):
                    record.append((p.x, p_y))
                    record.append((p.x, p.y))
                else:
                    record.append((p_x, p.y))
                    record.append((p.x, p.y))
            else:
                rec = Rectangle((p.x, p.y), 1, 1, color='g')
                record.append((p.x, p.y))
            p_x = p.x
            p_y = p.y
        end_time = time.time()
        # print('===== Algorithm finish in', int(end_time-start_time), ' seconds')
        return record

    def RunAndSaveImage(self, start_x, start_y, end_x, end_y):
        """Run the search; return the cell path, or None when no path exists."""
        start_time = time.time()
        start_point = Point(start_x, start_y)
        start_point.cost = 0
        self.open_set.append(start_point)
        while True:
            index = self.SelectPointInOpenList()
            if index < 0:
                print('No path found, algorithm failed!!!')
                return
            p = self.open_set[index]
            if self.IsEndPoint(p, end_x, end_y):
                return self.BuildPath(p, start_time, start_x, start_y, end_x, end_y)
            del self.open_set[index]
            self.close_set.append(p)
            # Process all neighbors (8-connected grid)
            x = p.x
            y = p.y
            self.ProcessPoint(x - 1, y + 1, p)
            self.ProcessPoint(x - 1, y, p)
            self.ProcessPoint(x - 1, y - 1, p)
            self.ProcessPoint(x, y - 1, p)
            self.ProcessPoint(x + 1, y - 1, p)
            self.ProcessPoint(x + 1, y, p)
            self.ProcessPoint(x + 1, y + 1, p)
            self.ProcessPoint(x, y + 1, p)
================================================
FILE: examples/Social_Cognition/ToM/rulebasedpolicy/load_statedata.py
================================================
import random
import os
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
# from torch.autograd import Variable
class StateDataset:
    """Dataset of paired 5x5 observation grids loaded from a text file.

    Each sample is a row block of shape (5, 10): the left 5 columns form
    the "A" grid and the right 5 columns the "B" grid.
    """

    def __init__(self, mode, num):
        # mode: path of the whitespace-separated integer text file
        # num:  number of samples the flat file is reshaped into
        # Fix: np.int was deprecated in NumPy 1.20 and removed in 1.24;
        # the builtin int is the equivalent dtype.
        self.state = np.loadtxt(mode, dtype=int)
        # NOTE(review): num rows-per-sample is hard-coded to 1 here, so the
        # grouping factor is always 5 regardless of the caller -- confirm.
        self.num = 1
        self.state = self.state.reshape(num, 5 * self.num, -1)

    # data: A  label: B
    def __getitem__(self, item):
        state = self.state[item]
        state_A = state[:, 0:5 * self.num]
        state_A = np.expand_dims(state_A, axis=0)  # add a channel dimension
        state_B = state[:, 5 * self.num:10 * self.num]
        state_B = np.expand_dims(state_B, axis=0)
        return {"A": state_A, "B": state_B}

    # the number of data
    def __len__(self):
        return len(self.state)
def get_dataloader(mode, num, batch):
    """Build a shuffling DataLoader over a StateDataset.

    @param mode: text-file path passed through to StateDataset
    @param num: number of samples in the file
    @param batch: batch size
    @return: torch DataLoader yielding {"A": ..., "B": ...} batches
    """
    dataset = StateDataset(mode, num)
    return DataLoader(dataset, batch, shuffle=True)
================================================
FILE: examples/Social_Cognition/ToM/rulebasedpolicy/point.py
================================================
import sys
class Point:
    """A grid cell used by the A* search.

    ``cost`` starts at sys.maxsize (effectively infinity) and is lowered
    by the search when the point is discovered.
    """

    def __init__(self, x, y):
        self.cost = sys.maxsize
        self.x = x
        self.y = y
================================================
FILE: examples/Social_Cognition/ToM/rulebasedpolicy/random_map.py
================================================
import numpy as np
from rulebasedpolicy.point import *
class RandomMap:
    """Obstacle map built from a board array (cells equal to 5 are walls).

    Board row i, column j becomes map coordinates (x=j, y=4-i), i.e. the
    board's top-down rows are flipped onto a bottom-up y axis.
    NOTE(review): the ``4`` hard-codes a 5x5 board -- confirm for other sizes.
    """

    def __init__(self, size, board):
        self.size = size
        self.board = board
        self.obstacle = size // 8  # NOTE(review): computed but never read
        self.GenerateObstacle()

    def GenerateObstacle(self):
        """Collect obstacle points; also cache (x, y) pairs for O(1) queries."""
        self.obstacle_point = []
        self._obstacle_xy = set()
        for i in range(self.size):
            for j in range(self.size):
                if self.board[i, j] == 5:
                    self.obstacle_point.append(Point(j, 4 - i))
                    self._obstacle_xy.add((j, 4 - i))

    def IsObstacle(self, i, j):
        """Return True when map cell (i, j) is an obstacle.

        Fix: previously an O(n) scan over obstacle_point per query; A*
        calls this from its inner loops, so use the cached set instead.
        """
        return (i, j) in self._obstacle_xy
================================================
FILE: examples/Social_Cognition/ToM/rulebasedpolicy/statedata_pre.py
================================================
import numpy as np
def data_transfer(B_txt, A_txt):
    """
    Read raw training data and convert it into a processable integer array.

    @param B_txt: input path; groups of rows of concatenated digits,
        groups separated by blank lines ('\\n\\n')
    @param A_txt: output path; the flattened, augmented matrix is saved here
    @return: the processed data reloaded from A_txt as a numpy int array

    Each group is parsed digit-by-digit, then augmented with its vertical
    mirror (np.flipud) before everything is stacked and flattened.
    """
    with open(B_txt, 'r') as f:
        data_all = []
        data_1 = []
        data = f.read()                      # whole file as one string
        data_split = data.split('\n\n')      # groups separated by blank lines
        for i in range(len(data_split) - 1): # last split element is empty
            data_split[i] = data_split[i].split('\n')
            for j in range(len(data_split[i])):
                # Split number: insert spaces between the concatenated
                # digits of the row, then parse each digit as an int.
                data_split[i][j] = " ".join(data_split[i][j])
                data_split[i][j] = data_split[i][j].split(' ')
                data_split[i][j] = list(map(int, data_split[i][j]))
            data_split[i] = np.array(data_split[i])
            # Data expansion: original group plus its vertical mirror.
            data_all.append(data_split[i])
            data_1.append(np.flipud(data_split[i]))
        data_all.extend(data_1)
        data_all = np.array(data_all)
        data_all = data_all.reshape(data_all.shape[0] * data_all.shape[1], data_all.shape[2])
        data_all = data_all.astype(int)
        np.savetxt(A_txt, data_all, fmt='%i')
        # Read the TXT data back into numpy.
        # Fix: np.int was removed in NumPy 1.24; builtin int is the dtype.
        state = np.loadtxt(A_txt, dtype=int)
        print(data_all.shape)
        return state
================================================
FILE: examples/Social_Cognition/ToM/rulebasedpolicy/train.txt
================================================
8 5 1 1 1 8 5 0 0 0
1 1 1 1 1 1 0 0 0 0
1 1 1 1 1 1 1 0 0 0
1 1 1 1 1 1 1 1 0 0
1 1 1 1 1 1 1 1 1 0
1 8 5 1 1 1 8 5 0 0
1 1 1 1 1 1 1 0 0 0
1 1 1 1 1 1 1 1 0 0
1 1 1 1 1 1 1 1 1 0
1 1 1 1 1 1 1 1 1 1
1 1 8 5 1 1 1 8 5 0
1 1 1 1 1 1 1 1 0 0
1 1 1 1 1 1 1 1 1 0
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1
1 1 1 8 5 1 1 1 8 5
1 1 1 1 1 1 1 1 1 0
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 0
1 1 1 8 5 1 1 1 8 5
1 1 1 1 1 1 1 1 1 0
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 0 0
1 1 8 5 1 1 1 8 5 0
1 1 1 1 1 1 1 1 0 0
1 1 1 1 1 1 1 1 1 0
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 0 0 0
1 8 5 1 1 1 8 5 0 0
1 1 1 1 1 1 1 0 0 0
1 1 1 1 1 1 1 1 0 0
1 1 1 1 1 1 1 1 1 0
1 1 1 1 1 1 0 0 0 0
8 5 1 1 1 8 5 0 0 0
1 1 1 1 1 1 0 0 0 0
1 1 1 1 1 1 1 0 0 0
1 1 1 1 1 1 1 1 0 0
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 0
1 1 1 8 5 1 1 1 8 5
1 1 1 1 1 1 1 1 1 0
8 1 5 1 1 8 1 5 0 0
1 1 1 1 1 1 1 1 0 0
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1
1 8 1 5 1 1 8 1 5 0
1 1 1 1 1 1 1 1 1 0
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 0 0
8 1 5 1 1 8 1 5 0 0
1 1 1 1 1 1 1 1 0 0
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 0
1 8 1 5 1 1 8 1 5 0
1 1 1 1 1 1 1 1 1 0
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1
8 1 1 5 1 8 1 1 5 0
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1
8 1 1 5 1 8 1 1 5 0
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1
8 1 1 1 1 8 1 1 1 1
1 5 1 1 1 1 5 0 0 1
1 1 1 1 1 1 0 0 0 0
1 1 1 1 1 1 0 0 0 0
1 1 1 1 1 1 1 0 0 0
1 8 1 1 1 1 8 1 1 1
1 1 5 1 1 1 1 5 0 0
1 1 1 1 1 1 1 0 0 0
1 1 1 1 1 1 1 0 0 0
1 1 1 1 1 1 1 1 0 0
1 1 8 1 1 1 1 8 1 1
1 1 1 5 1 1 1 1 5 0
1 1 1 1 1 1 1 1 0 0
1 1 1 1 1 1 1 1 0 0
1 1 1 1 1 1 1 1 1 0
1 1 1 8 1 1 1 1 8 1
1 1 1 1 5 1 1 1 1 5
1 1 1 1 1 1 1 1 1 0
1 1 1 1 1 1 1 1 1 0
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1
8 1 1 1 1 8 1 1 1 1
1 5 1 1 1 1 5 0 0 1
1 1 1 1 1 1 0 0 0 0
1 1 1 1 1 1 0 0 0 0
1 1 1 1 1 1 1 1 1 1
1 8 1 1 1 1 8 1 1 1
1 1 5 1 1 1 1 5 0 0
1 1 1 1 1 1 1 0 0 0
1 1 1 1 1 1 1 0 0 0
1 1 1 1 1 1 1 1 1 1
1 1 8 1 1 1 1 8 1 1
1 1 1 5 1 1 1 1 5 0
1 1 1 1 1 1 1 1 0 0
1 1 1 1 1 1 1 1 0 0
1 1 1 1 1 1 1 1 1 1
1 1 1 8 1 1 1 1 8 1
1 1 1 1 5 1 1 1 1 5
1 1 1 1 1 1 1 1 1 0
1 1 1 1 1 1 1 1 1 0
8 1 1 1 1 8 1 1 1 1
1 1 5 1 1 1 1 5 0 0
1 1 1 1 1 1 1 0 0 0
1 1 1 1 1 1 1 1 0 0
1 1 1 1 1 1 1 1 1 0
1 8 1 1 1 1 8 1 1 1
1 1 1 5 1 1 1 1 5 0
1 1 1 1 1 1 1 1 0 0
1 1 1 1 1 1 1 1 1 0
1 1 1 1 1 1 1 1 1 1
1 1 8 1 1 1 1 8 1 1
1 1 1 1 5 1 1 1 1 5
1 1 1 1 1 1 1 1 1 0
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1
8 1 1 1 1 8 1 1 1 1
1 1 5 1 1 1 1 5 0 0
1 1 1 1 1 1 1 0 0 0
1 1 1 1 1 1 1 1 0 0
1 1 1 1 1 1 1 1 1 1
1 8 1 1 1 1 8 1 1 1
1 1 1 5 1 1 1 1 5 0
1 1 1 1 1 1 1 1 0 0
1 1 1 1 1 1 1 1 1 0
1 1 1 1 1 1 1 1 1 1
1 1 8 1 1 1 1 8 1 1
1 1 1 1 5 1 1 1 1 5
1 1 1 1 1 1 1 1 1 0
1 1 1 1 1 1 1 1 1 1
8 1 1 1 1 8 1 1 1 1
1 1 1 5 1 1 1 1 5 0
1 1 1 1 1 1 1 1 1 0
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 0
1 1 1 1 1 1 1 1 0 0
1 1 1 1 1 1 1 0 0 0
1 1 1 1 1 1 0 0 0 0
8 5 1 1 1 8 5 0 0 0
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 0
1 1 1 1 1 1 1 1 0 0
1 1 1 1 1 1 1 0 0 0
1 8 5 1 1 1 8 5 0 0
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 0
1 1 1 1 1 1 1 1 0 0
1 1 8 5 1 1 1 8 5 0
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 0
1 1 1 8 5 1 1 1 8 5
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 0
1 1 1 8 5 1 1 1 8 5
1 1 1 1 1 1 1 1 1 0
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 0
1 1 1 1 1 1 1 1 0 0
1 1 8 5 1 1 1 8 5 0
1 1 1 1 1 1 1 1 0 0
1 1 1 1 1 1 1 1 1 0
1 1 1 1 1 1 1 1 0 0
1 1 1 1 1 1 1 0 0 0
1 8 5 1 1 1 8 5 0 0
1 1 1 1 1 1 1 0 0 0
1 1 1 1 1 1 1 1 0 0
1 1 1 1 1 1 1 0 0 0
1 1 1 1 1 1 0 0 0 0
8 5 1 1 1 8 5 0 0 0
1 1 1 1 1 1 0 0 0 0
1 1 1 1 1 1 1 1 1 0
1 1 1 8 5 1 1 1 8 5
1 1 1 1 1 1 1 1 1 0
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 0 0
8 1 5 1 1 8 1 5 0 0
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 0
1 8 1 5 1 1 8 1 5 0
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 0 0
8 1 5 1 1 8 1 5 0 0
1 1 1 1 1 1 1 1 0 0
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 0
1 8 1 5 1 1 8 1 5 0
1 1 1 1 1 1 1 1 1 0
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1
8 1 1 5 1 8 1 1 5 0
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1
8 1 1 5 1 8 1 1 5 0
1 1 1 1 1 1 1 1 1 1
================================================
FILE: examples/Social_Cognition/ToM/rulebasedpolicy/world_model.py
================================================
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
from rulebasedpolicy.load_statedata import *
import math
# Print full numpy arrays (no truncation) when debugging environment grids.
np.set_printoptions(threshold = np.inf)
def data():
    """Load one full shuffled batch of observation pairs.

    @return: (B, A_train) where B holds the "B" grids reshaped to
        (batch, -1, 5, 5) and A_train is the squared distance between
        the agent marker (8) and wall marker (5) in each "A" grid.
    """
    batch_size = 45
    # read data: the training file lives next to this module
    txt = os.path.join(sys.path[0], 'rulebasedpolicy', 'train.txt')
    loader = get_dataloader(mode=txt, num=batch_size, batch=batch_size)
    for batch in loader:
        B = batch["B"].numpy().reshape(batch_size, -1, 5, 5)
        A = batch["A"].numpy().reshape(batch_size, 5, 5)
        # squared agent-to-wall distance per observation
        A_train = np.sum(np.square(np.argwhere(A == 8) - np.argwhere(A == 5)), axis=1)
        return B, A_train
def flip180(arr):
    """Rotate *arr* by 180 degrees.

    Reversing the flattened (C-order) element sequence and restoring the
    original shape is equivalent to reversing every axis.
    @param arr: numpy array
    @return: the 180-degree-rotated array
    """
    return np.flip(arr)
def flip90_left(arr):
    """Rotate *arr* by 90 degrees counter-clockwise.

    Implemented as transpose followed by reversing the row order.
    @param arr: numpy array
    @return: the rotated array
    """
    return np.transpose(arr)[::-1]
def flip90_right(arr):
    """Rotate *arr* by 90 degrees clockwise.

    Implemented as a 180-degree flip of all axes (same as reversing the
    flattened order), then transpose, then reversing the row order --
    matching the original element-order construction exactly.
    @param arr: numpy array
    @return: the rotated array
    """
    rotated = np.flip(arr)
    return np.transpose(rotated)[::-1]
def gain_env(obs, agent, wall):
    """
    Stitch several partial 5x5 observations into one combined environment.

    @param obs: stack of partial observations, obs[0] is used as the base map
    @param agent: value marking the agent cell (8 by convention here)
    @param wall: value marking a wall cell (5 by convention here)
    @return: the combined environment as an int array

    Each further observation is aligned on the shared agent position; the
    base map is padded with 1s where the new view extends past it, and the
    overlapping 5x5 window is merged with a bitwise AND.

    Fix: astype(np.int) crashed on NumPy >= 1.24 (np.int removed); the
    builtin int dtype is equivalent.
    """
    obs_random = obs[0]
    for i in range(1, obs.shape[0]):
        # base map: agent and wall coordinates
        x_a, y_a = np.argwhere(obs_random == agent)[0]
        x_w, y_w = np.argwhere(obs_random == wall)[0]
        # external observation: agent and wall coordinates
        x_t, y_t = np.argwhere(obs[i] == agent)[0]
        x_tt, y_tt = np.argwhere(obs[i] == wall)[0]
        h, l = obs_random.shape
        # padding needed on each side so the new 5x5 view fits the base map
        delta_up = max(x_t - x_a, 0)
        delta_down = max(5 - x_tt - (h - x_w), 0)
        delta_left = max(y_t - y_a, 0)
        delta_right = max(5 - y_tt - (l - y_w), 0)
        obs_random = np.r_[np.ones((delta_up, l)), obs_random] if delta_up != 0 else obs_random
        h, l = obs_random.shape
        obs_random = np.r_[obs_random, np.ones((delta_down, l))] if delta_down != 0 else obs_random
        h, l = obs_random.shape
        obs_random = np.c_[np.ones((h, delta_left)), obs_random] if delta_left != 0 else obs_random
        h, l = obs_random.shape
        obs_random = np.c_[obs_random, np.ones((h, delta_right))] if delta_right != 0 else obs_random
        obs_random = obs_random.astype(int)
        # re-locate the agent in the (possibly padded) base map
        x_a, y_a = np.argwhere(obs_random == agent)[0]
        up = x_a - x_t
        left = y_a - y_t
        obs_random[up:up + 5, left:left + 5] = obs_random[up:up + 5, left:left + 5] & obs[i]
    return obs_random
def shelter_env(obs):
"""
Aim:用gain_env环境中的图,来描述更复杂的环境的遮挡关系
@param obs:复杂的环境
@return:环境的遮挡关系
"""
# print(obs,'----------------')
position_A = np.argwhere(obs==8)
position_W = np.argwhere(obs==5)
# print(position_W)
position = np.sum(np.square(np.argwhere(obs == 8) - np.argwhere(obs == 5)), axis=1) #numpy (walls,)
# print(position_W, position_A,position)
shelter_env_i = np.ones((5,5)).astype(np.int)
B, A_train = data()
for i in range(position.size):
o_idx = np.argwhere(A_train == position[i])
if o_idx.size == 0:
break
else:
model = gain_env(B[o_idx].reshape(-1, 5, 5), 8 ,5).astype(np.int)
# print(model,'=============')
if (position_A[0,0] > position_W[i,0] and position_A[0,1] < position_W[i,1] and \
position_A[0, 0] - position_W[i, 0] < -position_A[0, 1] + position_W[i, 1])\
or\
(position_A[0, 0] < position_W[i, 0] and position_A[0, 1] < position_W[i, 1] and \
-position_A[0, 0] + position_W[i, 0] > -position_A[0, 1] + position_W[i, 1])\
or\
(position_A[0, 0] > position_W[i, 0] and position_A[0, 1] > position_W[i, 1] and \
position_A[0, 0] - position_W[i, 0] > position_A[0, 1] - position_W[i, 1])\
or\
(position_A[0, 0] < position_W[i, 0] and position_A[0, 1] > position_W[i, 1] and \
-position_A[0, 0] + position_W[i, 0] < position_A[0, 1] - position_W[i, 1]):
model = np.flip(model, 0)
# print(model, '-=-=-=-=-=-=-====')
model = flip90_right(model)
# print(model,'-=-=-=-=-=-=-====')
if position_A[0, 0] >= position_W[i, 0] and position_A[0, 1] > position_W[i, 1]:
model = flip180(model)
elif position_A[0, 0] > position_W[i, 0] and position_A[0, 1] <= position_W[i, 1]:
model = flip90_left(model)
elif position_A[0, 0] < position_W[i, 0] and position_A[0, 1] >= position_W[i, 1]:
model = flip90_right(model)
else:
model = model
x_t, y_t = np.argwhere(model == 8)[0]
if y_t=3200:
error=3200
if error>0:
pain=1
if error==0:
pain=0
if pain==0:
X1=torch.tensor([[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1,
0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]])
X2=torch.tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
env.render()
if pain==1:
set_pain = 1
X1=torch.tensor([[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1,
0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]])
X2=torch.tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
env.render()
for i in range(T):
if i>=2:
X2=X1
OUTPUT = a(X1,X2)
a.UpdateWeight(2,OUTPUT[0][1],0.01)
a.UpdateWeight(5,OUTPUT[1][1],-0.1)
if OUTPUT[2][0][0]==1:
env.canvas.itemconfig(env.rect, fill="red", outline='red')
if OUTPUT[2][0][40]==1:
env.canvas.itemconfig(env.rect, fill="green", outline='green')
env.render()
print('out_ifg:',OUTPUT[2])
print('out_sma:',OUTPUT[3])
print('out_m1:',OUTPUT[4])
# print('con2:',a.connection[2].weight.data)
# print('con5:',a.connection[5].weight.data)
s = s_
if set_pain==1 and pain==0:
env.render()
env.destroy()
def BAESNN_test():
    """Evaluate the empathy network: watch a second agent, infer its pain
    from prediction error, and trigger helping behaviour.

    Uses module-level globals: ``a`` (the BAESNN network) and ``env2``
    (the second maze). Runs until the observed agent's pain episode ends.
    """
    a.reset()
    s1, s = env2.reset()
    pain = 0
    pain1 = 0
    i = 0
    set_pain = 0  # latches once pain has been observed at least once
    for i in range(1000):
        env2.render()
        s_now = env2.canvas.coords(env2.agent1)
        # random walk policy for the observed agent
        action1 = np.random.choice([0, 1, 2, 3], p=[0.2, 0.3, 0.3, 0.2])
        if env2.open_door == 1 and s_now[0] < (9 / 2) * 40:
            # once the door is open and the agent is in the left half,
            # bias movement vertically
            action1 = np.random.choice([0, 1, 2, 3], p=[0.5, 0.0, 0.0, 0.5])
        s1_, s1_pre, s1_color = env2.step1(action1, pain)
        print('s1_color:', s1_color)
        if env2.open_door == 1:
            env2.render()
        # pain signal = clipped squared prediction error of the next state
        true_s1_1 = np.array(s1_)
        predict_s1_1 = np.array(s1_pre)
        error1 = true_s1_1 - predict_s1_1
        error1 = sum([c * c for c in error1])
        if error1 >= 3200:
            error1 = 3200
        if error1 > 0:
            pain = 1
            set_pain = 1
        if error1 == 0:
            pain = 0
        env2.generate_expression1(pain)
        # encode the observed facial colour as input current X3:
        # red = pain pattern (first 20 channels), blue = no-pain pattern
        if s1_color == "red":
            X3 = torch.tensor([[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1,
                                0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1,
                                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
                                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]])
        if s1_color == "blue":
            X3 = torch.tensor([[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
                                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
                                0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1,
                                0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]])
        # NOTE(review): if s1_color is neither "red" nor "blue" on the first
        # step, X3 is unbound below -- confirm step1 only emits those colours.
        a.reset()
        for i in range(20):  # run the empathy circuit for 20 timesteps
            OUT = a.empathy(X3)
            print(OUT)
        if pain == 1:
            env2.agent_help()
        s1 = s1_
        env2.render()
        if pain == 0 and set_pain == 1:
            # pain episode observed and resolved: stop the evaluation
            env2.render()
            break
    # env2.destroy()
if __name__ == "__main__":
    # Phase 1: train the BAESNN self-pain model in the first maze.
    env = Maze()
    a = BAESNN()
    BAESNN_train()
    env.mainloop()
    # Phase 2: test empathy/helping behaviour in the second maze.
    env2 = Maze2()
    BAESNN_test()
    env2.mainloop()
================================================
FILE: examples/Social_Cognition/affective_empathy/BAE-SNN/README.md
================================================
# Requirements
* numpy
* scipy
* pytorch >= 1.7.0
* torchvision
# Run
## Train
* the file to be run: BAESNN.py
# Citation
```
@ARTICLE{Hui2022,
AUTHOR={Feng, Hui and Zeng, Yi and Lu, Enmeng},
TITLE={Brain-Inspired Affective Empathy Computational Model and Its Application on Altruistic Rescue Task},
JOURNAL={Frontiers in Computational Neuroscience},
VOLUME={16},
YEAR={2022},
URL={https://www.frontiersin.org/articles/10.3389/fncom.2022.784967},
DOI={10.3389/fncom.2022.784967},
ISSN={1662-5188}
}
```
================================================
FILE: examples/Social_Cognition/affective_empathy/BAE-SNN/env_poly.py
================================================
import numpy as np
np.random.seed(1)
import tkinter as tk
import time
from PIL import ImageGrab
UNIT = 40    # pixels per grid cell
MAZE_H = 9   # NOTE(review): labelled "grid height" but used as the window/canvas *width* below -- confirm
MAZE_W = 4   # NOTE(review): labelled "grid width" but used as the window/canvas *height* below -- confirm
class Maze(tk.Tk, object):
    def __init__(self):
        """Create the 'self-pain' maze window and initialise interaction flags."""
        super(Maze, self).__init__()
        self.action_space = ['u', 'd', 'l', 'r']  # up / down / left / right
        self.n_actions = len(self.action_space)
        self.title('self-pain')
        # tk geometry string is "width x height"
        self.geometry('{0}x{1}'.format(MAZE_H * UNIT, MAZE_W * UNIT))
        self._build_maze()
        # status flags updated during interaction (see step/reset)
        self.danger = 0
        self.action_hurt = 0
        self.sensory_hurt = 0
        self.open_door = 0
        self.pain_state = 0
    # create environment
    def _build_maze(self):
        """Draw the grid canvas and the agent polygon.

        The agent is a pentagon (an arrow-like shape) whose apex indicates
        its facing direction; four point lists are prepared (up/down/right/
        left) but only the downward-facing one is drawn initially.
        """
        self.canvas = tk.Canvas(self, bg='white',
                                height=MAZE_W * UNIT,
                                width=MAZE_H * UNIT)
        # create grids
        for c in range(0, MAZE_H * UNIT, UNIT):
            x0, y0, x1, y1 = c, 0, c, MAZE_H * UNIT
            self.canvas.create_line(x0, y0, x1, y1)
        for r in range(0, MAZE_H * UNIT, UNIT):
            x0, y0, x1, y1 = 0, r, MAZE_H * UNIT, r
            self.canvas.create_line(x0, y0, x1, y1)
        # create agent, centred on the first cell
        self.orgin = [20, 20]
        # facing up
        self.points0 = [
            # bottom-right corner
            self.orgin[0] + 15,  # 35
            self.orgin[1] + 15,  # 35
            # bottom-left corner
            self.orgin[0] - 15,  # 5
            self.orgin[1] + 15,  # 35
            # upper-left edge point
            self.orgin[0] - 15,  # 5
            self.orgin[1],       # 20
            # apex
            self.orgin[0],       # 20
            self.orgin[1] - 15,  # 5
            # upper-right edge point
            self.orgin[0] + 15,  # 35
            self.orgin[1],       # 20
        ]
        # self.rect0 = self.canvas.create_polygon(self.points0, fill="green")
        # self.agent_action0 = self.canvas.coords(self.rect0)
        # facing down
        self.points1 = [
            # top-left corner
            self.orgin[0] - 15,  # 5
            self.orgin[1] - 15,  # 5
            # top-right corner
            self.orgin[0] + 15,  # 35
            self.orgin[1] - 15,  # 5
            # lower-right edge point
            self.orgin[0] + 15,  # 35
            self.orgin[1],       # 20
            # apex
            self.orgin[0],       # 20
            self.orgin[1] + 15,  # 35
            # lower-left edge point
            self.orgin[0] - 15,  # 5
            self.orgin[1],       # 20
        ]
        self.rect = self.canvas.create_polygon(self.points1, fill="green")
        # self.agent_action1 = self.canvas.coords(self.rect1)
        # facing right
        self.points2 = [
            # bottom-left corner
            self.orgin[0] - 15,  # 5
            self.orgin[1] + 15,  # 35
            # top-left corner
            self.orgin[0] - 15,  # 5
            self.orgin[1] - 15,  # 5
            # upper-right edge point
            self.orgin[0],       # 20
            self.orgin[1] - 15,  # 5
            # apex
            self.orgin[0] + 15,  # 35
            self.orgin[1],       # 20
            # lower-right edge point
            self.orgin[0],       # 20
            self.orgin[1] + 15,  # 35
        ]
        # self.rect2 = self.canvas.create_polygon(self.points2, fill="green")
        # self.agent_action2 = self.canvas.coords(self.rect2)
        # facing left
        self.points3 = [
            # top-right corner
            self.orgin[0] + 15,  # 20+15
            self.orgin[1] - 15,  # 20-15
            # bottom-right corner
            self.orgin[0] + 15,  # 20+15
            self.orgin[1] + 15,  # 20+15
            # lower-left edge point
            self.orgin[0],       # 20
            self.orgin[1] + 15,  # 20+15
            # apex
            self.orgin[0] - 15,  # 20-15
            self.orgin[1],       # 20
            # upper-left edge point
            self.orgin[0],       # 20
            self.orgin[1] - 15,  # 20-15
        ]
        # self.rect3 = self.canvas.create_polygon(self.points3, fill="green")
        # self.agent_action3 = self.canvas.coords(self.rect3)
        self.canvas.pack()
    # reset agent location
    def reset(self):
        """Put the agent back at the start cell (facing down) and return its coords.

        @return: canvas coordinates of the freshly created agent polygon
        """
        self.open_door = 0
        self.update()
        time.sleep(0.5)  # brief pause so the reset is visible in the GUI
        self.canvas.delete(self.rect)
        self.orgin = [20, 20]
        # facing down (same pentagon as in _build_maze)
        self.points1 = [
            # top-left corner
            self.orgin[0] - 15,  # 5
            self.orgin[1] - 15,  # 5
            # top-right corner
            self.orgin[0] + 15,  # 35
            self.orgin[1] - 15,  # 5
            # lower-right edge point
            self.orgin[0] + 15,  # 35
            self.orgin[1],       # 20
            # apex
            self.orgin[0],       # 20
            self.orgin[1] + 15,  # 35
            # lower-left edge point
            self.orgin[0] - 15,  # 5
            self.orgin[1],       # 20
        ]
        self.rect = self.canvas.create_polygon(self.points1, fill="green")
        # self.agent_action1 = self.canvas.coords(self.rect1)
        return self.canvas.coords(self.rect)
def step(self, s, action, pain):
    """Advance the agent one step with pain-dependent action inversion.

    :param s: previous state (ignored; re-read from the canvas)
    :param action: commanded action, 0=up, 1=down, 2=right, 3=left
    :param pain: 1 draws the marker red, otherwise green
    :return: (s_, s_predict, s_color) — actual next coords, predicted next
             coords from the commanded action, and the colour of the cell
             the agent occupied ('yellow' switch / 'black' danger / 'white').
    """
    s = self.canvas.coords(self.rect)
    self.centre = [(s[4] + s[8]) / 2, (s[5] + s[9]) / 2]
    # Default cell colour.  The original left s_color unbound whenever
    # self.danger != 1, raising NameError at the return statement.
    s_color = 'white'
    # danger or switch
    if self.danger == 1:
        if all(self.centre == self.oval_center):
            # on the switch: open the door and jump 80px to the right
            s_color = 'yellow'
            self.canvas.delete(self.wall[3])
            self.render()
            self.render()
            self.open_door = 1
            self.canvas.move(self.rect, 80, 0)
            s = self.canvas.coords(self.rect)
            self.render()
        elif all(self.centre == self.hell1_center):
            # on the danger cell: subsequent moves get inverted
            s_color = 'black'
            self.action_hurt = 1
            self.render()
            self.render()

    def marker_points(cx, cy, a):
        # Five (x, y) vertices of the agent marker, oriented per action.
        shapes = {
            0: [(15, 15), (-15, 15), (-15, 0), (0, -15), (15, 0)],   # up
            1: [(-15, -15), (15, -15), (15, 0), (0, 15), (-15, 0)],  # down
            2: [(-15, 15), (-15, -15), (0, -15), (15, 0), (0, 15)],  # right
            3: [(15, -15), (15, 15), (0, 15), (-15, 0), (0, -15)],   # left
        }
        pts = []
        for dx, dy in shapes[a]:
            pts.extend((cx + dx, cy + dy))
        return pts

    # redraw the marker oriented along the commanded action
    self.canvas.delete(self.rect)
    self.centre = [(s[4] + s[8]) / 2, (s[5] + s[9]) / 2]
    if action in (0, 1, 2, 3):
        pts = marker_points(self.centre[0], self.centre[1], action)
        setattr(self, 'points%d' % action, pts)   # keep self.points<action> updated
        color = "red" if pain == 1 else "green"
        self.rect = self.canvas.create_polygon(pts, fill=color)
    s = self.canvas.coords(self.rect)
    self.render()
    if s[0] > (9 / 2) * 40:   # right of the wall: inversion wears off
        self.action_hurt = 0
    # hurt inverts the executed action (up<->down, left<->right)
    if self.action_hurt == 0:
        true_action = action
    else:
        true_action = {0: 1, 1: 0, 2: 3, 3: 2}[action]

    def bounded_delta(a, step_px):
        # (dx, dy) for action a, honouring the maze walls; the reachable
        # x-range depends on which side of the central wall we are on.
        cx, cy = self.centre
        on_left = cx <= ((MAZE_H - 1) / 2 + 1) * UNIT
        if a == 0 and cy > UNIT:
            return (0, -step_px)
        if a == 1 and cy < (MAZE_W - 1) * UNIT:
            return (0, step_px)
        if a == 2:
            limit = ((MAZE_H - 1) / 2 - 1) * UNIT if on_left else (MAZE_H - 1) * UNIT
            if cx < limit:
                return (step_px, 0)
        if a == 3:
            limit = UNIT if on_left else ((MAZE_H - 1) / 2 + 2) * UNIT
            if cx > limit:
                return (-step_px, 0)
        return (0, 0)

    # predicted next state from the COMMANDED action (40px per cell)
    dx, dy = bounded_delta(action, 40)
    b = [dx, dy] * 5
    s_predict = [s[i] + b[i] for i in range(len(b))]
    # actual move from the (possibly inverted) action
    mdx, mdy = bounded_delta(true_action, UNIT)
    base_action = np.array([mdx, mdy])
    self.canvas.move(self.rect, base_action[0], base_action[1])
    s_ = self.canvas.coords(self.rect)
    return s_, s_predict, s_color
def step_RL1(self, action):
    """Q-learning step, phase 1: move one cell and report pain on hell1.

    :param action: 0=up, 1=down, 2=right, 3=left
    :return: (s_, reward, pain_state) — reward is -1 on the hell cell,
             else 0.
    """
    s = self.canvas.coords(self.rect)
    delta = np.array([0, 0])
    left_side = s[0] <= ((MAZE_H - 1) / 2 + 1) * UNIT
    if action == 0:      # up
        if s[1] > UNIT:
            delta[1] -= UNIT
    elif action == 1:    # down
        if s[1] < (MAZE_W - 1) * UNIT:
            delta[1] += UNIT
    elif action == 2:    # right — limit depends on which side of the wall
        limit = ((MAZE_H - 1) / 2 - 1) * UNIT if left_side else (MAZE_H - 1) * UNIT
        if s[0] < limit:
            delta[0] += UNIT
    elif action == 3:    # left
        limit = UNIT if left_side else ((MAZE_H - 1) / 2 + 2) * UNIT
        if s[0] > limit:
            delta[0] -= UNIT
    self.canvas.move(self.rect, delta[0], delta[1])  # move agent
    s_ = self.canvas.coords(self.rect)               # next state
    if s_ == self.canvas.coords(self.hell1):
        self.canvas.itemconfig(self.rect, fill="red", outline='red')
        reward = -1
        self.pain_state = 1
    else:
        reward = 0
    return s_, reward, self.pain_state
def step_RL2(self, action):
    """Q-learning step, phase 2: switch handling plus movement and reward.

    Standing on the switch opens the door and hops the agent two cells
    right.  Reaching the switch afterwards pays +1 only if the agent was
    in pain; the hell cell pays -1 and sets pain.
    :return: (s_, reward, pain_state)
    """
    s = self.canvas.coords(self.rect)
    if s == self.canvas.coords(self.oval):
        # on the switch: open the door and jump through it
        self.canvas.delete(self.wall[3])
        self.open_door = 1
        self.canvas.move(self.rect, 40, 0)
        self.canvas.move(self.rect, 40, 0)
        self.render()
        self.canvas.itemconfig(self.rect, fill="green", outline='green')
        self.render()
    delta = np.array([0, 0])
    left_side = s[0] <= ((MAZE_H - 1) / 2 + 1) * UNIT
    if action == 0:      # up
        if s[1] > UNIT:
            delta[1] -= UNIT
    elif action == 1:    # down
        if s[1] < (MAZE_W - 1) * UNIT:
            delta[1] += UNIT
    elif action == 2:    # right — limit depends on which side of the wall
        limit = ((MAZE_H - 1) / 2 - 1) * UNIT if left_side else (MAZE_H - 1) * UNIT
        if s[0] < limit:
            delta[0] += UNIT
    elif action == 3:    # left
        limit = UNIT if left_side else ((MAZE_H - 1) / 2 + 2) * UNIT
        if s[0] > limit:
            delta[0] -= UNIT
    self.canvas.move(self.rect, delta[0], delta[1])  # move agent
    s_ = self.canvas.coords(self.rect)               # next state
    if s_ == self.canvas.coords(self.oval):
        self.open_door = 1
        if self.pain_state == 0:
            reward = 0
        if self.pain_state == 1:
            reward = 1   # relief reward: pain was on, switch resolves it
        self.pain_state = 0
    elif s_ == self.canvas.coords(self.hell1):
        self.canvas.itemconfig(self.rect, fill="red", outline='red')
        reward = -1
        self.pain_state = 1
        self.render()
    else:
        reward = 0
    return s_, reward, self.pain_state
def _set_danger(self):
    """Place the black danger disc at (60, 60) and flag the maze as dangerous."""
    self.hell1_center = np.array([60, 60])
    cx, cy = self.hell1_center[0], self.hell1_center[1]
    self.hell1 = self.canvas.create_oval(
        cx - 15, cy - 15,
        cx + 15, cy + 15,
        fill='black')
    # self.canvas.create_bitmap((40 , 40), bitmap='error')
    self.hell = self.canvas.coords(self.hell1)
    self.canvas.pack()
    self.danger = 1
def _set_switch(self):
    """Place the yellow switch oval one cell left of the wall column."""
    self.oval_center = np.array([(MAZE_H * UNIT) / 2 - UNIT,
                                 ((MAZE_W + 4) * UNIT) / 2 - UNIT / 2])
    cx, cy = self.oval_center[0], self.oval_center[1]
    self.oval = self.canvas.create_oval(cx - 15, cy - 15, cx + 15, cy + 15,
                                        fill='yellow')
    self.switch = self.canvas.coords(self.oval)
    self.canvas.pack()
def _set_wall(self):
    """Build the grey wall column down the middle of the maze and record
    the canvas coordinates of its first four segments."""
    self.wall = []
    mid_x = (MAZE_H * UNIT) / 2
    for row in range(MAZE_W):
        cy = row * UNIT + UNIT / 2
        self.wall.append(self.canvas.create_rectangle(
            mid_x - 20, cy - 20,
            mid_x + 20, cy + 20,
            fill='grey'))
    self.wall0 = self.canvas.coords(self.wall[0])
    self.wall1 = self.canvas.coords(self.wall[1])
    self.wall2 = self.canvas.coords(self.wall[2])
    self.wall3 = self.canvas.coords(self.wall[3])
    # self.canvas.pack()
def generate_expression(self, pain):
    """Recolour the agent by pain state: red when hurting (1), green when fine (0)."""
    faces = {1: "red", 0: "green"}
    if pain in faces:
        colour = faces[pain]
        self.canvas.itemconfig(self.rect, fill=colour, outline=colour)
        # self.canvas.pack()
def render(self):
    """Refresh the Tk window after a short pause so the move is visible."""
    time.sleep(0.01)
    self.update()
# def getter(self, widget):
# widget.update()
# x = tk.Tk.winfo_rootx(self) + widget.winfo_x()
# y = tk.Tk.winfo_rooty(self) + widget.winfo_y()
# x1 = x + widget.winfo_width()
# y1 = y + widget.winfo_height()
# ImageGrab.grab().crop((x, y, x1, y1)).save("first.jpg")
# return ImageGrab.grab().crop((x, y, x1, y1))
================================================
FILE: examples/Social_Cognition/affective_empathy/BAE-SNN/env_two_poly.py
================================================
import numpy as np
np.random.seed(1)
import tkinter as tk
import time
from PIL import ImageGrab
UNIT = 40 # pixels
MAZE_H = 9 # grid height
MAZE_W = 4 # grid width
class Maze2(tk.Tk, object):
def __init__(self):
    """Two-agent empathy maze window: a MAZE_H x MAZE_W grid of UNIT-px cells."""
    super(Maze2, self).__init__()
    self.action_space = ['u', 'd', 'l', 'r']     # action set of the first agent
    self.action_space1 = ['u', 'd', 'l', 'r']    # action set of the second agent
    self.n_actions = len(self.action_space)
    self.n_actions1 = len(self.action_space1)
    self.title('two_agent_empathy')
    self.geometry('{0}x{1}'.format(MAZE_H * UNIT, MAZE_W * UNIT))
    self._build_maze()
    # NOTE(review): _build_maze() ends by setting self.danger = 1; the next
    # line immediately overrides it to 0 — confirm which value is intended.
    self.danger=0
    self.action_hurt=0       # movement-inversion flag (green agent)
    self.sensory_hurt = 0
    self.action_hurt1 = 0    # movement-inversion flag (agent1)
    self.sensory_hurt1 = 0
    self.open_door=0         # 1 once the wall segment has been removed
# create environment
def _build_maze(self):
    """Create the canvas with grid lines, the yellow switch, two agents
    (agent1 blue at the top-left, helper agent green at the top-right),
    the grey middle wall and two black hazard cells."""
    self.canvas = tk.Canvas(self, bg='white',
                            height=MAZE_W * UNIT,
                            width=MAZE_H * UNIT)
    # create grids
    # NOTE(review): both loops run to MAZE_H * UNIT, so horizontal lines are
    # drawn past the MAZE_W-cell-high window — confirm intended.
    for c in range(0, MAZE_H * UNIT, UNIT):
        x0, y0, x1, y1 = c, 0, c, MAZE_H * UNIT
        self.canvas.create_line(x0, y0, x1, y1)
    for r in range(0, MAZE_H * UNIT, UNIT):
        x0, y0, x1, y1 = 0, r, MAZE_H * UNIT, r
        self.canvas.create_line(x0, y0, x1, y1)
    # create switch
    self.oval_center = np.array([(MAZE_H * UNIT)/2-UNIT+80, ((MAZE_W+4) * UNIT)/2-UNIT/2-80])
    self.oval = self.canvas.create_oval(
        self.oval_center[0] - 15, self.oval_center[1] - 15,
        self.oval_center[0] + 15, self.oval_center[1] + 15,
        fill='yellow')
    self.switch = self.canvas.coords(self.oval)
    self.orgin1 = np.array([20, 20])
    # down-facing marker for agent1
    self.points1 = [
        # top-left
        self.orgin1[0] - 15,  # 5
        self.orgin1[1] - 15,  # 5
        # top-right
        self.orgin1[0] + 15,  # 35
        self.orgin1[1] - 15,  # 5
        # right edge
        self.orgin1[0] + 15,  # 35
        self.orgin1[1],  # 20
        # bottom apex
        self.orgin1[0],  # 20
        self.orgin1[1] + 15,  # 35
        # left edge
        self.orgin1[0] - 15,  # 5
        self.orgin1[1],  # 20
    ]
    self.agent1 = self.canvas.create_polygon(self.points1, outline='black',fill="blue")
    self.orgin = np.array([MAZE_H * UNIT - UNIT / 2, 20])
    # down-facing marker for the helper agent
    self.points = [
        # top-left
        self.orgin[0] - 15,  # 5
        self.orgin[1] - 15,  # 5
        # top-right
        self.orgin[0] + 15,  # 35
        self.orgin[1] - 15,  # 5
        # right edge
        self.orgin[0] + 15,  # 35
        self.orgin[1],  # 20
        # bottom apex
        self.orgin[0],  # 20
        self.orgin[1] + 15,  # 35
        # left edge
        self.orgin[0] - 15,  # 5
        self.orgin[1],  # 20
    ]
    self.agent = self.canvas.create_polygon(self.points, fill="green")
    # grey wall: one rectangle per row in the middle column
    wall_center = []
    self.wall = []
    for i in range(MAZE_W):
        wall_center.append([])
        self.wall.append([])
    for i in range(MAZE_W):
        wall_center[i] = np.array([(MAZE_H * UNIT) / 2, ((i) * UNIT) + UNIT / 2])
        self.wall[i] = self.canvas.create_rectangle(
            wall_center[i][0] - 20, wall_center[i][1] - 20,
            wall_center[i][0] + 20, wall_center[i][1] + 20,
            fill='grey')
    # two black hazard cells on the left side of the wall
    self.hell1_center = np.array([100, 20])
    self.hell1 = self.canvas.create_oval(
        self.hell1_center[0] - 15, self.hell1_center[1] - 15,
        self.hell1_center[0] + 15, self.hell1_center[1] + 15,
        fill='black')
    self.hell2_center = np.array([60, 100])
    self.hell2 = self.canvas.create_oval(
        self.hell2_center[0] - 15, self.hell2_center[1] - 15,
        self.hell2_center[0] + 15, self.hell2_center[1] + 15,
        fill='black')
    # self.canvas.create_bitmap((40 , 40), bitmap='error')
    self.danger = 1
    self.canvas.pack()
#reset agent location
def reset(self):
    """Delete both agents and recreate them at their start cells.

    :return: (coords(agent1), coords(agent)) canvas coordinate lists.
    """
    self.update()
    time.sleep(0.5)
    self.canvas.delete(self.agent1)
    self.canvas.delete(self.agent)

    def down_marker(x, y):
        # vertex list of the down-facing marker centred on (x, y)
        return [x - 15, y - 15,   # top-left
                x + 15, y - 15,   # top-right
                x + 15, y,        # right edge
                x, y + 15,        # bottom apex
                x - 15, y]        # left edge

    # agent1 (blue) restarts at the top-left cell
    self.orgin1 = np.array([20, 20])
    self.points1 = down_marker(self.orgin1[0], self.orgin1[1])
    self.agent1 = self.canvas.create_polygon(self.points1, outline='black', fill="blue")
    # helper agent (green) restarts at the top-right cell
    self.orgin = np.array([MAZE_H * UNIT - UNIT / 2, 20])
    self.points = down_marker(self.orgin[0], self.orgin[1])
    self.agent = self.canvas.create_polygon(self.points, fill="green")
    return self.canvas.coords(self.agent1), self.canvas.coords(self.agent)
# move agent1
def step1(self, action1,pain):
s1 = self.canvas.coords(self.agent1)
self.centre1 = [(s1[4] + s1[8]) / 2, (s1[5] + s1[9]) / 2]
if all(self.centre1 == self.hell1_center):
self.action_hurt1 = 1
if all(self.centre1 == self.hell2_center):
self.action_hurt1 = 1
self.oval_center111 = np.array([(MAZE_H * UNIT) / 2, ((MAZE_W + 4) * UNIT) / 2 - UNIT / 2])
if all(self.centre1 ==self.oval_center111):
move = np.array([80, 0])
self.canvas.move(self.agent1, move[0], move[1])
s1 = self.canvas.coords(self.agent1)
self.render()
self.oval_center111 = np.array([(MAZE_H * UNIT) / 2 - UNIT, ((MAZE_W + 4) * UNIT) / 2 - UNIT / 2])
if all(self.centre1 ==self.oval_center111):
move = np.array([80, 0])
self.canvas.move(self.agent1, move[0], move[1])
s1 = self.canvas.coords(self.agent1)
self.render()
self.oval_center111 = np.array([(MAZE_H * UNIT) / 2 - UNIT*2, ((MAZE_W + 4) * UNIT) / 2 - UNIT / 2])
if all(self.centre1 == self.oval_center111):
move = np.array([80, 0])
self.canvas.move(self.agent1, move[0], move[1])
s1 = self.canvas.coords(self.agent1)
self.render()
self.oval_center111 = np.array([(MAZE_H * UNIT) / 2 - UNIT*3, ((MAZE_W + 4) * UNIT) / 2 - UNIT / 2])
if all(self.centre1 == self.oval_center111):
move = np.array([80, 0])
self.canvas.move(self.agent1, move[0], move[1])
s1 = self.canvas.coords(self.agent1)
self.render()
self.oval_center111 = np.array([(MAZE_H * UNIT) / 2 - UNIT*4, ((MAZE_W + 4) * UNIT) / 2 - UNIT / 2])
if all(self.centre1 == self.oval_center111):
move = np.array([80, 0])
self.canvas.move(self.agent1, move[0], move[1])
s1 = self.canvas.coords(self.agent1)
self.render()
self.canvas.delete(self.agent1) # 主要为开关那几步考虑,所以重复写了
self.centre1 = [(s1[4] + s1[8]) / 2, (s1[5] + s1[9]) / 2]
if action1==0:
self.points0 = [
# 右下
self.centre1[0] + 15, # 35
self.centre1[1] + 15, # 35
# 左下
self.centre1[0] - 15, # 5
self.centre1[1] + 15, # 35
# 左上+
self.centre1[0] - 15, # 5
self.centre1[1], # 20
# 顶点
self.centre1[0], # 20
self.centre1[1] - 15, # 5
# 右上+
self.centre1[0] + 15, # 35
self.centre1[1], # 20
]
if pain==0:
color="blue"
if pain == 1:
color = "red"
self.agent1 = self.canvas.create_polygon(self.points0, fill=color,outline='black')
if action1==1:
self.points1 = [
# 左上
self.centre1[0] - 15, # 5
self.centre1[1] - 15, # 5
# 右上
self.centre1[0] + 15, # 35
self.centre1[1] - 15, # 5
# 右下+
self.centre1[0] + 15, # 35
self.centre1[1], # 20
# 顶点
self.centre1[0], # 20
self.centre1[1] + 15, # 35
# 左下+
self.centre1[0] - 15, # 5
self.centre1[1], # 20
]
if pain==0:
color="blue"
if pain == 1:
color = "red"
self.agent1 = self.canvas.create_polygon(self.points1, fill=color,outline='black')
if action1==2:
self.points2 = [
# 左下
self.centre1[0] - 15, # 5
self.centre1[1] + 15, # 35
# 左上
self.centre1[0] - 15, # 5
self.centre1[1] - 15, # 5
# 右上+
self.centre1[0], # 20
self.centre1[1] - 15, # 5
# 顶点
self.centre1[0] + 15, # 35
self.centre1[1], # 20
# 右下+
self.centre1[0], # 20
self.centre1[1] + 15, # 35
]
if pain==0:
color="blue"
if pain == 1:
color = "red"
self.agent1 = self.canvas.create_polygon(self.points2, fill=color,outline='black')
if action1==3:
self.points3 = [
# 右上
self.centre1[0] + 15, # 20+15
self.centre1[1] - 15, # 20-15
# 右下
self.centre1[0] + 15, # 20+15
self.centre1[1] + 15, # 20+15
# 左下+
self.centre1[0], # 20
self.centre1[1] + 15, # 20+15
# 顶点
self.centre1[0] - 15, # 20-15
self.centre1[1], # 20
# 左上+
self.centre1[0], # 20
self.centre1[1] - 15, # 20-15
]
if pain==0:
color="blue"
if pain == 1:
color = "red"
self.agent1 = self.canvas.create_polygon(self.points3, fill=color,outline='black')
s1 = self.canvas.coords(self.agent1)
self.render()#显示当前的动作指令是什么
self.centre1 = [(s1[4] + s1[8]) / 2, (s1[5] + s1[9]) / 2]
if self.centre1[0] > (9 / 2) * 40:
self.action_hurt1 = 0
# whether hurt
if self.action_hurt1 == 0:
true_action1 = action1
else:
if action1 == 0:
true_action1 = 1
if action1 == 1:
true_action1 = 0
if action1 == 2:
true_action1 = 3
if action1 == 3:
true_action1 = 2
# predict next state
b = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
if self.centre1[0] <= ((MAZE_H - 1) / 2 + 1) * UNIT: # 120
if action1 == 0: # up
if self.centre1[1] > UNIT:
b = [0, -40, 0, -40, 0, -40, 0, -40, 0, -40]
elif action1 == 1: # down
if self.centre1[1] < (MAZE_W - 1) * UNIT:
b = [0, 40, 0, 40, 0, 40, 0, 40, 0, 40]
elif action1 == 2: # right
if self.centre1[0] < ((MAZE_H - 1) / 2 - 1) * UNIT:
b = [40, 0, 40, 0, 40, 0, 40, 0, 40, 0]
elif action1 == 3: # left
if self.centre1[0] > UNIT:
b = [-40, 0, -40, 0, -40, 0, -40, 0, -40, 0]
else:
if action1 == 0: # up
if self.centre1[1] > UNIT:
b = [0, -40, 0, -40, 0, -40, 0, -40, 0, -40]
elif action1 == 1: # down
if self.centre1[1] < (MAZE_W - 1) * UNIT:
b = [0, 40, 0, 40, 0, 40, 0, 40, 0, 40]
elif action1 == 2: # right
if self.centre1[0] < (MAZE_H - 1) * UNIT:
b = [40, 0, 40, 0, 40, 0, 40, 0, 40, 0]
elif action1 == 3: # left
if self.centre1[0] > ((MAZE_H - 1) / 2 + 2) * UNIT:
b = [-40, 0, -40, 0, -40, 0, -40, 0, -40, 0]
s_predict = []
for i in range(len(b)):
s_predict1 = s1[i] + b[i]
s_predict.append(s_predict1)
base_action1 = np.array([0, 0])
# true next state
if self.centre1[0] <= ((MAZE_H - 1) / 2 + 1) * UNIT:
if true_action1 == 0: # up
if self.centre1[1] > UNIT:
base_action1[1] -= UNIT
elif true_action1 == 1: # down
if self.centre1[1] < (MAZE_W - 1) * UNIT:
base_action1[1] += UNIT
elif true_action1 == 2: # right
if self.centre1[0] < ((MAZE_H - 1) / 2 - 1) * UNIT:
base_action1[0] += UNIT
elif true_action1 == 3: # left
if self.centre1[0] > UNIT:
base_action1[0] -= UNIT
else:
if true_action1 == 0: # up
if self.centre1[1] > UNIT:
base_action1[1] -= UNIT
elif true_action1 == 1: # down
if self.centre1[1] < (MAZE_W - 1) * UNIT:
base_action1[1] += UNIT
elif true_action1 == 2: # right
if self.centre1[0] < (MAZE_H - 1) * UNIT:
base_action1[0] += UNIT
elif true_action1 == 3: # left
if self.centre1[0] > ((MAZE_H - 1) / 2 + 2) * UNIT:
base_action1[0] -= UNIT
self.canvas.move(self.agent1, base_action1[0], base_action1[1])
s1_ = self.canvas.coords(self.agent1)
return s1_, s_predict,color
def agent_help(self):
    """Helper-agent script: if standing on the switch, open the door;
    otherwise walk three cells left and one cell down toward the switch.

    :return: the helper agent's canvas coordinates after the move.
    """
    s = self.canvas.coords(self.agent)
    self.centre2 = [(s[4] + s[8]) / 2, (s[5] + s[9]) / 2]
    if all(self.centre2 == self.oval_center):
        self.canvas.delete(self.wall[3])
        self.render()
        self.open_door = 1
    else:
        # scripted approach path: left x3, then down x1
        for dx, dy in ((-40, 0), (-40, 0), (-40, 0), (0, 40)):
            self.canvas.move(self.agent, dx, dy)
            self.render()
    return self.canvas.coords(self.agent)
def _set_danger(self):
    """Add two black square hazards and mark the maze as dangerous."""
    x1, y1 = 140, 60
    self.hell1 = self.canvas.create_rectangle(
        x1 - 15, y1 - 15, x1 + 15, y1 + 15, fill='black')
    x2, y2 = 100, 140
    self.hell2 = self.canvas.create_rectangle(
        x2 - 15, y2 - 15, x2 + 15, y2 + 15, fill='black')
    # self.canvas.create_bitmap((40 , 40), bitmap='error')
    self.canvas.pack()
    self.danger = 1
def _set_wall(self):
    """Rebuild the grey wall column down the middle of the maze."""
    self.wall = []
    mid_x = (MAZE_H * UNIT) / 2
    for row in range(MAZE_W):
        cy = row * UNIT + UNIT / 2
        self.wall.append(self.canvas.create_rectangle(
            mid_x - 20, cy - 20,
            mid_x + 20, cy + 20,
            fill='grey'))
    self.canvas.pack()
def generate_expression1(self, pain1):
    """Recolour agent1 by pain state: red when hurting (1), blue when fine (0)."""
    faces = {1: "red", 0: "blue"}
    if pain1 in faces:
        self.canvas.itemconfig(self.agent1, fill=faces[pain1], outline='black')
        self.canvas.pack()
def render(self):
    """Refresh the Tk window after a short pause so the move is visible."""
    time.sleep(0.2)
    self.update()
================================================
FILE: examples/Social_Cognition/affective_empathy/BEEAD-SNN/BEEAD-SNN.py
================================================
import os
import sys
import imageio
from env_poly_SNN import Maze
from env import Maze2
from RL_brain import QLearningTable
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
np.random.seed(1)
from torch.utils.tensorboard import SummaryWriter
from sklearn.preprocessing import MinMaxScaler
import torch, os, sys
from torch import nn
from torch.nn import Parameter
import abc
import math
from abc import ABC
import torch.nn.functional as F
from braincog.base.node.node import *
from braincog.base.learningrule.STDP import *
from braincog.base.connection.CustomLinear import *
from braincog.base.utils.visualization import spike_rate_vis, spike_rate_vis_1d#, spike_vis_2, spike_vis_5
class BrainArea(nn.Module, abc.ABC):
    """Abstract base class for spiking brain areas.

    Subclasses must implement ``__init__`` and ``forward``; ``reset`` is an
    optional hook for clearing internal state between episodes.
    """

    @abc.abstractmethod
    def __init__(self):
        super().__init__()

    @abc.abstractmethod
    def forward(self, x):
        """Run one forward step; ``x`` is a spike tensor."""
        return x

    def reset(self):
        """Clear internal state between episodes (no-op by default)."""
        pass
class BAESNN(BrainArea):
    """
    Affective Empathy Network.

    Five IF-node populations wired by STDP-learnable linear connections:
    node0 = emotion (24), node1 = IFG (24), node2 = perception (24),
    node3 = SMA (10), node4 = M1 (10) — sizes inferred from the
    connection matrices below.
    """
    def __init__(self,):
        super().__init__()
        # one IF population per area; index mapping in the class docstring
        self.node = [IFNode() for i in range(5)]
        self.connection = []
        con_matrix0 = torch.eye(24, 24)*6
        self.connection.append(CustomLinear(con_matrix0))#input-emotion
        con_matrix1 = torch.eye(24, 24)
        self.connection.append(CustomLinear(con_matrix1))#emotion-ifg
        # starts at zero and is learned online via UpdateWeight(2, ...)
        con_matrix2 = torch.zeros((24, 24), dtype=torch.float)
        self.connection.append(CustomLinear(con_matrix2))#perception-ifg
        con_matrix3 = torch.eye(24, 24)*6
        self.connection.append(CustomLinear(con_matrix3))#input-perception
        # emotion->SMA: emotion neurons 0-11 drive SMA units 0-4 and
        # neurons 12-23 drive SMA units 5-9, all with weight 2
        con_matrix4=torch.zeros((24,10), dtype=torch.float)
        for j in range(10):
            if j in np.arange(0,5,1):
                for i in np.arange(0, 12, 1):
                    con_matrix4[i,j] =2
            if j in np.arange(5,10,1):
                for i in np.arange(12, 24, 1):
                    con_matrix4[i,j] =2
        self.connection.append(CustomLinear(con_matrix4))#emotion-sma
        # NOTE(review): labelled "perception-m1" originally, but this matrix
        # is wired to node3 (SMA) in stdp[3] and stdp[6] below — confirm.
        con_matrix5=torch.zeros((24,10), dtype=torch.float)
        self.connection.append(CustomLinear(con_matrix5))#perception-m1
        con_matrix6 = torch.eye(10, 10)*6
        self.connection.append(CustomLinear(con_matrix6))#sma-m1
        self.stdp = []
        self.stdp.append(STDP(self.node[0], self.connection[0]))#0 node0 emotion
        self.stdp.append(STDP(self.node[2], self.connection[3]))#1 node2 perception
        self.stdp.append(MutliInputSTDP(self.node[1], [self.connection[1], self.connection[2]]))#2 node1 ifg
        self.stdp.append(MutliInputSTDP(self.node[3], [self.connection[4], self.connection[5]]))#3 node3 sma
        self.stdp.append(STDP(self.node[4], self.connection[6]))#4 node4 m1
        self.stdp.append(STDP(self.node[1],self.connection[2]))#5 node1 ifg
        self.stdp.append(STDP(self.node[3],self.connection[5]))#6 node3 sma
    def forward(self, x1,x2):
        """One observation step: x1 drives emotion, x2 drives perception.

        :return: (dw perception->ifg, dw perception->sma, ifg spikes,
                  sma spikes, m1 spikes, emotion spikes, perception spikes)
        """
        out__m, dw0 = self.stdp[0](x1)#node0 emotion
        out__p, dw3 = self.stdp[1](x2)#node2 perception
        out__ifg,dw_p_i=self.stdp[2](out__m,out__p)#node1 ifg
        out__sma,dw_p_s=self.stdp[3](out__m,out__p)#node3 sma
        out__m1,dw1=self.stdp[4](out__sma)#node4 m1
        return dw_p_i,dw_p_s,out__ifg,out__sma,out__m1,out__m,out__p
    def empathy(self,x3):
        """Empathy pathway: perception input alone drives IFG, SMA and M1
        through the connections learned during direct experience."""
        out_p,dw2=self.stdp[1](x3)#node2 perception
        out_ifg,dw4=self.stdp[5](out_p)#node1 ifg
        out_sma,dw5=self.stdp[6](out_p)#node3 sma
        out_m1,dw6=self.stdp[4](out_sma)#node4 m1
        return out_ifg,out_sma,out_m1,out_p
    def UpdateWeight(self, i, dw, delta):
        """Add delta-scaled dw to connection i and clamp weights to [-1, 4]."""
        self.connection[i].update(dw*delta)
        self.connection[i].weight.data= torch.clamp(self.connection[i].weight.data,-1,4)
    def reset(self):
        """Reset the membrane state of every node and all STDP traces."""
        for i in range(5):
            self.node[i].n_reset()
        for i in range(len(self.stdp)):
            self.stdp[i].reset()
class DopamineArea(BrainArea):
    """
    Dopamine brain area with a group of spiking neurons, computes reward prediction error.
    """
    def __init__(self, n_neurons, beta=0.2):
        """
        :param n_neurons: number of IF neurons in the population
        :param beta: step size of the running-prediction update
        """
        super().__init__()
        self.n_neurons = n_neurons
        self.beta = beta
        self.node = [IFNode() for _ in range(n_neurons)]
        self.P = np.zeros(n_neurons)  # prediction for each neuron
    def forward(self, spikes):
        """One step: drive each neuron with its scalar input, then update
        the running prediction toward the population mean firing rate.

        :return: (delta, out_spikes).  NOTE(review): delta is an ndarray of
            shape (n_neurons,) — scalar mean S minus vector P — so every
            entry carries the same prediction error; confirm whether a
            scalar P was intended.
        """
        out_spikes = []
        for i in range(self.n_neurons):
            spike = self.node[i](torch.tensor([spikes[i]], dtype=torch.float32))
            out_spikes.append(spike)
        S = torch.stack(out_spikes).mean().item()  # population mean firing (scalar)
        delta = S - self.P                         # reward prediction error
        self.P = self.P + self.beta * delta        # move prediction toward S
        return delta, out_spikes
    def reset(self):
        """Clear the prediction and the membrane state of every neuron."""
        self.P = np.zeros(self.n_neurons)
        for n in self.node:
            n.n_reset()
def BAESNN_train():
    """Train the affective-empathy SNN on the single-agent maze.

    Uses the module-level globals ``env`` (maze environment) and ``snn2``
    (BAESNN network).  Each episode step: pick a random displacing action,
    derive pain/emotion from the danger detector, then run the SNN for T2
    timesteps while updating the perception->IFG and perception->SMA
    weights.  Stops once a previously-painful agent is pain-free.
    """
    s = env.reset()
    env._set_danger()
    env._set_wall()
    pain = 0
    set_pain = 0
    env._set_switch()
    for step_idx in range(100):
        snn2.reset()
        pain = 0
        print('**************step:', step_idx)
        env.render()
        # sample random actions until one actually displaces the agent
        action = np.random.choice(list(range(env.n_actions)))
        print('action:', action)
        d, d_pre, s_, sss = env.step(s, action, pain)
        print('d:', d, 'd_pre:', d_pre, 'sss:', sss)
        env.render()
        while (d == np.array([0, 0])).all():
            action = np.random.choice(list(range(env.n_actions)))
            print('action:', action)
            d, d_pre, s_, sss = env.step(s, action, pain)
            print('d:', d, 'd_pre:', d_pre, 'sss:', sss)
            env.render()
        # pain input / emotion signal from the danger detector
        if env.is_agent_in_danger():
            OUT_PAIN = torch.ones(24)
            pain = 1
            set_pain = 1
            emotion = -1
        else:
            OUT_PAIN = torch.zeros(24)
            pain = 0
            emotion = 0
        print("OUT_PAIN:", OUT_PAIN)
        print("pain:", pain)
        print("emotion:", emotion)
        T2 = 20
        X1 = OUT_PAIN
        X2 = torch.zeros(24)
        X3 = torch.cat([torch.ones(12) * 0.1, torch.zeros(12)])
        print('X1,X2:', X1, X2)
        spike_emotion = []
        spike_ifg = []
        spike_sma = []
        spike_m1 = []
        spike_per = []
        # NOTE: the original reused `i` here, shadowing the episode index
        for t in range(T2):
            if t >= 2:
                X2 = X3   # perception input switches on after 2 timesteps
            OUTPUT = snn2(X1, X2)
            snn2.UpdateWeight(2, OUTPUT[0][1], 0.01)    # perception->IFG, potentiate
            snn2.UpdateWeight(5, OUTPUT[1][1], -0.1)    # perception->SMA, depress
            # mirror the IFG output onto the agent's colour
            if OUTPUT[2][0] == 1:
                env.canvas.itemconfig(env.rect, fill="red", outline='red')
            if OUTPUT[2][0] == 0:
                env.canvas.itemconfig(env.rect, fill="green", outline='green')
            spike_emotion.append(OUTPUT[5])
            spike_per.append(OUTPUT[6])
            spike_ifg.append(OUTPUT[2])
            spike_sma.append(OUTPUT[3])
            spike_m1.append(OUTPUT[4])
            print('out_ifg:', OUTPUT[2])
            print('out_sma:', OUTPUT[3])
            print('out_m1:', OUTPUT[4])
            print('con2:', snn2.connection[2].weight.data)
            print('con5:', snn2.connection[5].weight.data)
        spike_emotion = torch.stack(spike_emotion)
        spike_per = torch.stack(spike_per)
        spike_ifg = torch.stack(spike_ifg)
        spike_sma = torch.stack(spike_sma)
        spike_m1 = torch.stack(spike_m1)
        print(spike_emotion.shape)
        env.render()
        s = s_
        # stop once pain occurred earlier but the agent is now pain-free
        if set_pain == 1 and pain == 0:
            env.render()
            break
    env.destroy()
def BAESNN_train_alstruism(lamda, E):
global writer
for episode in range(E):
print('*******************episode:', episode, ',factor:', lamda, '*********************************')
s1,s2=env2.reset()
env2._set_wall()
pain1 = 0
pain2 = 0
i = 0
set_pain = 0
env2.emotion = 0
env2.empathy_emotion = 0
env2.empathy_emotion_t_1 = 0
rr = 0
hh = 0
g = 0
a = []
if episode < 200:
e_greedy = 0.5
elif episode < 500:
e_greedy = 0.7
elif episode < 700:
e_greedy = 0.9
else:
e_greedy = 1
r = np.random.uniform()
for i in range(100):
env2.render()
done = False
env2.empathy_emotion_t_1=env2.empathy_emotion
action2 = RL.choose_action(str([(s2[4] + s2[8]) / 2, (s2[5] + s2[9]) / 2,env2.empathy_emotion]),e_greedy=e_greedy)
s2_, done, done_oval = env2.step2(action2)
env2.render()
T=100
print('**************step:',i)
env2.render()
action1 = np.random.choice(list(range(env.n_actions)))
print('action1:',action1)
if r<= 0.25:
if i==0:
action1=2
if 0.25 UNIT:
pre_displacement1 = np.array([0, -40])
elif action1 == 1: # down
if self.centre1[1] < (MAZE_W - 1) * UNIT:
pre_displacement1 = np.array([0, 40])
elif action1 == 2: # right
if self.centre1[0] < ((MAZE_H - 1) / 2 - 1) * UNIT:
pre_displacement1 = np.array([40, 0])
elif action1 == 3: # left
if self.centre1[0] > UNIT:
pre_displacement1 = np.array([-40, 0])
else:
if action1 == 0: # up
if self.centre1[1] > UNIT:
pre_displacement1 = np.array([0, -40])
elif action1 == 1: # down
if self.centre1[1] < (MAZE_W - 1) * UNIT:
pre_displacement1 = np.array([0, 40])
elif action1 == 2: # right
if self.centre1[0] < (MAZE_H - 1) * UNIT:
pre_displacement1 = np.array([40, 0])
elif action1 == 3: # left
if self.centre1[0] > ((MAZE_H - 1) / 2 + 2) * UNIT:
pre_displacement1 = np.array([-40, 0])
# true next state
displacement1 = np.array([0, 0])
if self.centre1[0] <= ((MAZE_H - 1) / 2 + 1) * UNIT:
if true_action1 == 0: # up
if self.centre1[1] > UNIT:
displacement1= np.array([0, -40])
elif true_action1 == 1: # down
if self.centre1[1] < (MAZE_W - 1) * UNIT:
displacement1= np.array([0,40])
elif true_action1 == 2: # right
if self.centre1[0] < ((MAZE_H - 1) / 2 - 1) * UNIT:
displacement1= np.array([40,0])
elif true_action1 == 3: # left
if self.centre1[0] > UNIT:
displacement1= np.array([-40,0])
else:
if true_action1 == 0: # up
if self.centre1[1] > UNIT:
displacement1= np.array([0,-40])
elif true_action1 == 1: # down
if self.centre1[1] < (MAZE_W - 1) * UNIT:
displacement1= np.array([0,40])
elif true_action1 == 2: # right
if self.centre1[0] < (MAZE_H - 1) * UNIT:
displacement1= np.array([40,0])
elif true_action1 == 3: # left
if self.centre1[0] > ((MAZE_H - 1) / 2 + 2) * UNIT:
displacement1= np.array([-40,0])
self.canvas.move(self.agent1, displacement1[0], displacement1[1])
s1_ = self.canvas.coords(self.agent1)
sss = [(s1_[4] + s1_[8]) / 2, (s1_[5] + s1_[9]) / 2]
return displacement1, pre_displacement1,s1_,sss
def is_agent1_in_danger(self):
    """
    Return True when agent1's centre coincides (within float tolerance)
    with any of the four hell-cell centres, else False.
    """
    s1 = self.canvas.coords(self.agent1)
    print([(s1[4] + s1[8]) / 2, (s1[5] + s1[9]) / 2])
    print(f'hell1_center: {self.hell1_center}, hell2_center: {self.hell2_center}, hell3_center: {self.hell3_center}, hell4_center: {self.hell4_center}')
    agent1_pos = [(s1[4] + s1[8]) / 2, (s1[5] + s1[9]) / 2]
    hazards = (self.hell1_center, self.hell2_center,
               self.hell3_center, self.hell4_center)
    return any(all(np.isclose(agent1_pos, centre)) for centre in hazards)
def step2(self, action):
    """Move agent2 one cell on the right-hand side of the maze.

    :param action: 0=up, 1=down, 2=right, 3=left
    :return: (s_, done, done_oval) — next coords, goal-reached flag, and
             1 if agent2 started this step on the switch while agent1 was sad.
    """
    coords = self.canvas.coords(self.agent2)
    s = [(coords[4] + coords[8]) / 2, (coords[5] + coords[9]) / 2]
    # flag: standing on the switch while empathising with agent1's distress
    done_oval = 1 if all(s == self.oval_center) and self.empathy_emotion == -1 else 0
    base_action = np.array([0, 0])
    if action == 0 and s[1] > UNIT:                              # up
        base_action[1] -= UNIT
    elif action == 1 and s[1] < (MAZE_W - 1) * UNIT:             # down
        base_action[1] += UNIT
    elif action == 2 and s[0] < (MAZE_H - 1) * UNIT:             # right
        base_action[0] += UNIT
    elif action == 3 and s[0] > ((MAZE_H - 1) / 2 + 2) * UNIT:   # left (stay right of the wall)
        base_action[0] -= UNIT
    self.canvas.move(self.agent2, base_action[0], base_action[1])  # move agent
    s_ = self.canvas.coords(self.agent2)  # next state
    self.centre2 = [(s_[4] + s_[8]) / 2, (s_[5] + s_[9]) / 2]
    if all(self.centre2 == self.oval_center) and self.empathy_emotion == -1:
        self.help_signal = 1
    done = True if all(self.centre2 == self.goal_centre) else False
    return s_, done, done_oval
def reward2(self):
s_ = self.canvas.coords(self.agent2)
self.centre2= [(s_[4] + s_[8]) / 2, (s_[5] + s_[9]) / 2]
if (self.empathy_emotion - self.empathy_emotion_t_1)==-1:
reward1=0
elif (self.empathy_emotion - self.empathy_emotion_t_1)==1:
reward1=10
else:
reward1=0
if all(self.centre2 == self.goal_centre):
reward2 = 10
else:
reward2 = -1
return reward1,reward2
def _set_wall(self):
self.oval_center = np.array([(MAZE_H * UNIT)-20, ((MAZE_W)*UNIT-20)])# [(MAZE_H * UNIT)/2+80, UNIT/2+UNIT]
self.oval = self.canvas.create_oval(
self.oval_center[0] - 15, self.oval_center[1] - 15,
self.oval_center[0] + 15, self.oval_center[1] + 15,
fill='yellow')
self.help = self.canvas.coords(self.oval)
wall_center = []
self.wall = []
for i in range(MAZE_W):
wall_center.append([])
self.wall.append([])
for i in range(MAZE_W):
wall_center[i] = np.array([(MAZE_H * UNIT) / 2, ((i) * UNIT) + UNIT / 2])# wall
self.wall[i] = self.canvas.create_rectangle(
wall_center[i][0] - 20, wall_center[i][1] - 20,
wall_center[i][0] + 20, wall_center[i][1] + 20,
fill='grey')
self.canvas.pack()
def generate_expression1(self,emotion):
if emotion==-1:
self.canvas.itemconfig(self.agent1, fill="red", outline='black')
self.canvas.pack()
if emotion==0:
self.canvas.itemconfig(self.agent1, fill="blue", outline='black')
self.canvas.pack()
def generate_expression2(self,emotion):
if emotion==-1:
self.canvas.itemconfig(self.agent2, fill="red")
self.canvas.pack()
if emotion==0:
self.canvas.itemconfig(self.agent2, fill="green")
self.canvas.pack()
    def render(self):
        """Refresh the Tk window; the tiny sleep yields to the event loop
        without visibly slowing the simulation."""
        time.sleep(0.000001)
        self.update()
# def getter(self,widget):
# widget.update()
# x = tk.Tk.winfo_rootx(self) + widget.winfo_x()
# y = tk.Tk.winfo_rooty(self) + widget.winfo_y()
# x1 = x + widget.winfo_width()
# y1 = y + widget.winfo_height()
# ImageGrab.grab().crop((x, y, x1, y1)).save("first.jpg")
# return ImageGrab.grab().crop((x, y, x1, y1))
================================================
FILE: examples/Social_Cognition/affective_empathy/BEEAD-SNN/env_poly_SNN.py
================================================
import numpy as np
np.random.seed(1)
import tkinter as tk
import time
from PIL import ImageGrab
UNIT = 40   # pixels per grid cell
MAZE_H = 9  # grid height -- NOTE(review): used as the HORIZONTAL cell count in geometry(); verify naming
MAZE_W = 4  # grid width -- NOTE(review): used as the VERTICAL cell count in geometry(); verify naming
class Maze(tk.Tk, object):
    """Tk-based single-agent "self-pain" maze.

    The agent is a 5-point polygon whose orientation encodes its last
    action and whose colour encodes pain (green = no pain, red = pain).
    Optional hazards (`_set_danger`), a door switch (`_set_switch`) and a
    wall column (`_set_wall`) configure the episode.  ``step`` returns both
    the predicted displacement (ignoring the pain-induced action flip) and
    the true displacement, which downstream code compares to detect pain.
    """
    def __init__(self):
        super(Maze, self).__init__()
        # Action indices 0..3 map to up/down/right/left (see step()).
        self.action_space = ['u', 'd', 'l', 'r']
        self.n_actions = len(self.action_space)
        self.title('self-pain')
        self.geometry('{0}x{1}'.format(MAZE_H * UNIT, MAZE_W * UNIT))
        self._build_maze()
        # Episode flags toggled by _set_danger()/step().
        self.danger=0
        self.action_hurt=0
        self.sensory_hurt = 0
        self.open_door = 0
        self.pain_state=0
    # create environment
    def _build_maze(self):
        """Create the canvas, draw the grid and the initial agent polygon."""
        self.canvas = tk.Canvas(self, bg='white',
                                height=MAZE_W * UNIT,
                                width=MAZE_H * UNIT)
        # create grids
        for c in range(0, MAZE_H * UNIT, UNIT):
            x0, y0, x1, y1 = c, 0, c, MAZE_H * UNIT
            self.canvas.create_line(x0, y0, x1, y1)
        for r in range(0, MAZE_H * UNIT, UNIT):
            x0, y0, x1, y1 = 0, r, MAZE_H * UNIT, r
            self.canvas.create_line(x0, y0, x1, y1)
        self.orgin=[20,20]
        # create agent (5-point polygon, initially pointing down)
        self.points1 = [
            self.orgin[0]-15,self.orgin[1]-15,
            self.orgin[0]+15,self.orgin[1]-15,
            self.orgin[0]+15,self.orgin[1],
            self.orgin[0],self.orgin[1]+15,
            self.orgin[0]-15,self.orgin[1],
        ]
        self.rect = self.canvas.create_polygon(self.points1, fill="green")
        self.canvas.pack()
    #reset agent location
    def reset(self):
        """Return the agent to the top-left cell and return its raw coords."""
        self.open_door = 0
        self.update()
        time.sleep(0.5)
        self.canvas.delete(self.rect)
        self.orgin = [20, 20]
        # polygon pointing down (translated from original Chinese comment)
        self.points1 = [
            self.orgin[0] - 15,self.orgin[1] - 15,
            self.orgin[0] + 15,self.orgin[1] - 15,
            self.orgin[0] + 15,self.orgin[1],
            self.orgin[0],self.orgin[1]+15,
            self.orgin[0] - 15,self.orgin[1],
        ]
        self.rect = self.canvas.create_polygon(self.points1, fill="green")
        return self.canvas.coords(self.rect)
    def step(self, s, action, pain):
        """Advance one timestep.

        :param s: ignored; the agent position is re-read from the canvas
        :param action: 0=up, 1=down, 2=right, 3=left
        :param pain: 0/1, selects the agent's fill colour
        :return: (true displacement, predicted displacement, raw coords
                  after the move, polygon centre after the move)
        """
        s = self.canvas.coords(self.rect)
        self.centre = [(s[4] + s[8]) / 2, (s[5] + s[9]) / 2]
        # danger or switch: react to standing on the switch (open the door
        # and jump right) or on the hazard (set action_hurt).
        if self.danger==1:
            if all(self.centre == self.oval_center):
                s_color = 'yellow'
                self.canvas.delete(self.wall[3])
                self.render()
                self.open_door = 1
                move = np.array([80, 0])
                self.canvas.move(self.rect, move[0], move[1])
                s = self.canvas.coords(self.rect)
                self.render()
            elif all(self.centre == self.hell1_center):
                s_color = 'black'
                self.action_hurt = 1
                self.render()
            else:
                # NOTE(review): s_color is assigned but never used — confirm
                # whether it was meant to feed the visualisation.
                s_color = 'white'
        # modify current state: redraw the polygon oriented by `action`
        # and coloured by `pain`.
        self.canvas.delete(self.rect)
        self.centre = [(s[4] + s[8]) / 2, (s[5] + s[9]) / 2]
        if action==0:
            # pointing up
            self.points0 = [
                self.centre[0] + 15,self.centre[1] + 15,
                self.centre[0] - 15,self.centre[1] + 15,
                self.centre[0] - 15,self.centre[1],
                self.centre[0],self.centre[1] - 15,
                self.centre[0] + 15,self.centre[1],
            ]
            if pain==0:
                color="green"
            if pain == 1:
                color = "red"
            self.rect = self.canvas.create_polygon(self.points0, fill=color)
        if action==1:
            # pointing down
            self.points1 = [
                self.centre[0] - 15,self.centre[1] - 15,
                self.centre[0] + 15,self.centre[1] - 15,
                self.centre[0] + 15,self.centre[1],
                self.centre[0],self.centre[1] + 15,
                self.centre[0] - 15,self.centre[1],
            ]
            if pain==0:
                color="green"
            if pain == 1:
                color = "red"
            self.rect = self.canvas.create_polygon(self.points1, fill=color)
        if action==2:
            # pointing right
            self.points2 = [
                self.centre[0] - 15,self.centre[1] + 15,
                self.centre[0] - 15,self.centre[1] - 15,
                self.centre[0],self.centre[1] - 15,
                self.centre[0] + 15,self.centre[1],
                self.centre[0],self.centre[1] + 15,
            ]
            if pain==0:
                color="green"
            if pain == 1:
                color = "red"
            self.rect = self.canvas.create_polygon(self.points2, fill=color)
        if action==3:
            # pointing left
            self.points3 = [
                self.centre[0] + 15,
                self.centre[1] - 15,
                self.centre[0] + 15,self.centre[1] + 15,
                self.centre[0],self.centre[1] + 15,
                self.centre[0] - 15,self.centre[1],
                self.centre[0],self.centre[1] - 15,
            ]
            if pain==0:
                color="green"
            if pain == 1:
                color = "red"
            self.rect = self.canvas.create_polygon(self.points3, fill=color)
        s = self.canvas.coords(self.rect)
        self.render()
        # Past the wall column (right half) the action-hurt effect wears off.
        if s[0] > (9 / 2) * 40:
            self.action_hurt = 0
        base_action = np.array([0, 0])
        # When hurt, the executed action is the opposite of the intended one
        # (up<->down, right<->left).
        if self.action_hurt == 0:
            true_action = action
        else:
            if action == 0:
                true_action = 1
            if action == 1:
                true_action = 0
            if action == 2:
                true_action = 3
            if action == 3:
                true_action = 2
        # predict next state: displacement the agent EXPECTS from `action`,
        # clamped by the borders of whichever half of the maze it is in.
        self.centre1 = [(s[4] + s[8]) / 2, (s[5] + s[9]) / 2]
        pre_displacement1 = np.array([0, 0])
        if self.centre1[0] <= ((MAZE_H - 1) / 2 + 1) * UNIT:  # left half (x <= 200)
            if action == 0:  # up
                if self.centre1[1] > UNIT:
                    pre_displacement1 = np.array([0, -40])
            elif action == 1:  # down
                if self.centre1[1] < (MAZE_W - 1) * UNIT:
                    pre_displacement1 = np.array([0, 40])
            elif action == 2:  # right
                if self.centre1[0] < ((MAZE_H - 1) / 2 - 1) * UNIT:
                    pre_displacement1 = np.array([40, 0])
            elif action == 3:  # left
                if self.centre1[0] > UNIT:
                    pre_displacement1 = np.array([-40, 0])
        else:
            if action == 0:  # up
                if self.centre1[1] > UNIT:
                    pre_displacement1 = np.array([0, -40])
            elif action == 1:  # down
                if self.centre1[1] < (MAZE_W - 1) * UNIT:
                    pre_displacement1 = np.array([0, 40])
            elif action == 2:  # right
                if self.centre1[0] < (MAZE_H - 1) * UNIT:
                    pre_displacement1 = np.array([40, 0])
            elif action == 3:  # left
                if self.centre1[0] > ((MAZE_H - 1) / 2 + 2) * UNIT:
                    pre_displacement1 = np.array([-40, 0])
        # true next state: same clamping, but for the (possibly flipped)
        # executed action.
        displacement1 = np.array([0, 0])
        if self.centre1[0] <= ((MAZE_H - 1) / 2 + 1) * UNIT:
            if true_action == 0:  # up
                if self.centre1[1] > UNIT:
                    displacement1=np.array([0,-40])
            elif true_action == 1:  # down
                if self.centre1[1] < (MAZE_W - 1) * UNIT:
                    displacement1=np.array([0,40])
            elif true_action == 2:  # right
                if self.centre1[0] < ((MAZE_H - 1) / 2 - 1) * UNIT:
                    displacement1=np.array([40,0])
            elif true_action == 3:  # left
                if self.centre1[0] > UNIT:
                    displacement1=np.array([-40,0])
        else:
            if true_action == 0:  # up
                if self.centre1[1] > UNIT:
                    displacement1=np.array([0,-40])
            elif true_action == 1:  # down
                if self.centre1[1] < (MAZE_W - 1) * UNIT:
                    displacement1=np.array([0,40])
            elif true_action == 2:  # right
                if self.centre1[0] < (MAZE_H - 1) * UNIT:
                    displacement1=np.array([40,0])
            elif true_action == 3:  # left
                if self.centre1[0] > ((MAZE_H - 1) / 2 + 2) * UNIT:
                    displacement1=np.array([-40,0])
        self.canvas.move(self.rect, displacement1[0], displacement1[1])
        s1_ = self.canvas.coords(self.rect)
        sss = [(s1_[4] + s1_[8]) / 2, (s1_[5] + s1_[9]) / 2]
        return displacement1, pre_displacement1,s1_,sss
    def _set_danger(self):
        """Place the black hazard circle and enable danger handling in step()."""
        self.hell1_center = np.array([60, 60])
        self.hell1 = self.canvas.create_oval(
            self.hell1_center[0] - 15, self.hell1_center[1] - 15,
            self.hell1_center[0] + 15, self.hell1_center[1] + 15,
            fill='black')
        # self.canvas.create_bitmap((40 , 40), bitmap='error')
        self.hell = self.canvas.coords(self.hell1)
        self.canvas.pack()
        self.danger=1
    def _set_switch(self):
        """Place the yellow switch oval that opens the door in step()."""
        self.oval_center = np.array([(MAZE_H * UNIT) / 2 - UNIT, ((MAZE_W + 4) * UNIT) / 2 - UNIT / 2])
        self.oval = self.canvas.create_oval(
            self.oval_center[0] - 15, self.oval_center[1] - 15,
            self.oval_center[0] + 15, self.oval_center[1] + 15,
            fill='yellow')
        self.switch = self.canvas.coords(self.oval)
        self.canvas.pack()
    def _set_wall(self):
        """Build the grey wall column down the middle of the maze."""
        wall_center=[]
        self.wall=[]
        for a in range(MAZE_W):
            wall_center.append([0,0])
            self.wall.append([])
        for b in range(MAZE_W):
            wall_center[b]=np.array([(MAZE_H*UNIT)/2,((b)*UNIT)+UNIT/2])
            self.wall[b] = self.canvas.create_rectangle(
                wall_center[b][0] - 20, wall_center[b][1] - 20,
                wall_center[b][0] + 20, wall_center[b][1] + 20,
                fill='grey')
        # Raw coords kept for external inspection of each wall square.
        self.wall0 = self.canvas.coords(self.wall[0])
        self.wall1 = self.canvas.coords(self.wall[1])
        self.wall2 = self.canvas.coords(self.wall[2])
        self.wall3 = self.canvas.coords(self.wall[3])
    def generate_expression(self,pain):
        """Recolour the agent: pain 1 -> red, pain 0 -> green."""
        if pain==1:
            self.canvas.itemconfig(self.rect, fill="red", outline='red')
        if pain == 0:
            self.canvas.itemconfig(self.rect, fill="green", outline='green')
    def render(self):
        """Refresh the Tk window (0.1 s delay keeps the animation visible)."""
        time.sleep(0.1)
        self.update()
    def is_agent_in_danger(self):
        """
        Check if the agent is in a danger zone (hell1).
        Returns: True/False
        """
        s = self.canvas.coords(self.rect)
        agent_pos = [(s[4] + s[8]) / 2, (s[5] + s[9]) / 2]
        # hasattr guard: hell1_center only exists after _set_danger().
        if hasattr(self, 'hell1_center'):
            if all(np.isclose(agent_pos, self.hell1_center)):
                return True
        return False
# def getter(self, widget):
# widget.update()
# x = tk.Tk.winfo_rootx(self) + widget.winfo_x()
# y = tk.Tk.winfo_rooty(self) + widget.winfo_y()
# x1 = x + widget.winfo_width()
# y1 = y + widget.winfo_height()
# ImageGrab.grab().crop((x, y, x1, y1)).save("first.jpg")
# return ImageGrab.grab().crop((x, y, x1, y1))
================================================
FILE: examples/Social_Cognition/affective_empathy/BEEAD-SNN/rsnn.py
================================================
import torch
from torch import nn
from braincog.base.node.node import IFNode
from braincog.base.learningrule.STDP import STDP,MutliInputSTDP
from braincog.base.connection.CustomLinear import CustomLinear
from collections import deque
from random import randint
class RSNN(nn.Module):
    """Spiking policy network trained with reward-modulated STDP.

    ``num_state`` input neurons project through a single plastic linear
    layer onto ``num_action`` output neurons.  The forward pass returns
    output spikes and the raw STDP weight change; the caller accumulates
    that change into ``weight_trace`` and gates it with a scalar reward
    via :meth:`UpdateWeight`.
    """
    def __init__(self, num_state, num_action):
        super().__init__()
        # parameters
        rsnn_mask = []
        rsnn_con = []
        con_matrix1 = torch.ones((num_state, num_action), dtype=torch.float)
        rsnn_mask.append(con_matrix1)
        rsnn_con.append(CustomLinear(torch.ones(num_state, num_action) * 0.1))
        self.num_subR = 2
        self.connection = rsnn_con
        self.mask = rsnn_mask
        self.node = [IFNode() for i in range(self.num_subR)]
        self.learning_rule = []
        self.learning_rule.append(MutliInputSTDP(self.node[1], [self.connection[0]]))
        # Eligibility trace, same shape as the plastic weight matrix.
        self.weight_trace = torch.zeros(con_matrix1.shape, dtype=torch.float)
        self.out_in = torch.zeros((num_state), dtype=torch.float)
        self.out = torch.zeros((self.connection[0].weight.size()[1]), dtype=torch.float)
        self.dw = torch.zeros((self.connection[0].weight.size()), dtype=torch.float)
    def forward(self, input):
        """Run one timestep: spike-encode ``input`` through the input layer
        and return (output spikes, STDP weight change)."""
        input = torch.tensor(input, dtype=torch.float)
        self.out_in = self.node[0](input)
        self.out, self.dw = self.learning_rule[0](self.out_in)
        return self.out, self.dw
    def UpdateWeight(self, reward, a, C, n):
        """Apply the reward-gated update for state ``n`` and action ``a``.

        Only trace entries in state row ``n`` and action-group columns
        ``a*C .. (a+1)*C-1`` survive; they are scaled by ``reward`` and
        added to the weights, then the trace is cleared.

        :param reward: scalar reward modulating the surviving trace
        :param a: chosen action index
        :param C: number of output neurons per action group
        :param n: state index whose row is updated
        """
        self.weight_trace[0:n, :] = 0
        self.weight_trace[n + 1:, :] = 0
        self.weight_trace[:, :a * C] = 0
        self.weight_trace[:, (a + 1) * C:] = 0
        self.weight_trace[self.weight_trace > 0] = self.weight_trace[self.weight_trace > 0] * reward
        # Negative traces are sign-flipped before scaling so reward always
        # pushes the surviving synapses in the same direction.
        self.weight_trace[self.weight_trace < 0] = -1 * self.weight_trace[self.weight_trace < 0] * reward
        self.connection[0].update(self.weight_trace)
        # BUGFIX: reset the trace to the actual weight shape instead of the
        # previously hard-coded (64, 5*C), which only worked for the
        # snowdrift defaults (num_state=64, num_action=5*C).
        self.weight_trace = torch.zeros(self.connection[0].weight.shape, dtype=torch.float)
    def reset(self):
        """Reset neuron membrane state and learning-rule traces."""
        for i in range(self.num_subR):
            self.node[i].n_reset()
        for i in range(len(self.learning_rule)):
            self.learning_rule[i].reset()
    def getweight(self):
        """Return the list of plastic connections (length 1)."""
        return self.connection
================================================
FILE: examples/Social_Cognition/affective_empathy/BEEAD-SNN/sd_env.py
================================================
import numpy as np
np.random.seed(1)
import tkinter as tk
import time
from PIL import ImageGrab
UNIT = 40     # pixels per grid cell
MAZE_H = 11   # grid horizontal -- NOTE(review): Snowdrift defines its own 8x8 self.MAZE_H/W; these module constants appear unused
MAZE_W = 5    # grid vertical -- NOTE(review): see above
class Snowdrift(tk.Tk, object):
    """Multi-agent snowdrift game on a Tk canvas.

    Agents (circles) move on an 8x8 grid; snowdrifts (black triangles)
    can be cleaned with action 4.  Cleaning gives the cleaner a small
    reward while every OTHER agent receives a larger empathy reward —
    the classic snowdrift payoff — and flips bystanders' emotion from
    negative (-1, drawn gray) to neutral (0, drawn in their colour).

    Fixes vs. the original: ``cleaned_this_step`` now records snowdrift
    indices (matching its own comment, the ``cleaned_positions`` info key,
    and the membership test against index ``i``) instead of agent ids,
    and a dead no-op expression statement was removed from step_all.
    """
    def __init__(self, n_agents=3, n_snowdrifts=4):
        super(Snowdrift, self).__init__()
        self.action_space = ['up', 'down', 'left', 'right', 'clean']
        self.n_actions = len(self.action_space)
        self.n_agents = n_agents
        self.n_snowdrifts = n_snowdrifts
        # These instance-level sizes shadow the module constants; the game
        # actually runs on an 8x8 grid of 40-pixel cells.
        self.UNIT = 40
        self.MAZE_H = 8
        self.MAZE_W = 8
        self.title('Snowdrift Game')
        self.geometry('{0}x{1}'.format(self.MAZE_H * self.UNIT, self.MAZE_W * self.UNIT))  # canvas
        self.agents = []
        self.agents_pos = []
        self.agents_emotion = [-1] * n_agents   # -1 = negative, 0 = neutral
        self.snowdrifts = []
        self.snowdrifts_pos = []
        self.cleaned = []
        self.empathy_emotion = 0
        self.empathy_emotion_t_1 = 0
        self.help_signal = 0
        self.lamda = 1.0
        self._build_maze()
    def _build_maze(self):
        """Create the canvas and draw agents and snowdrifts at random cells."""
        self.canvas = tk.Canvas(self, bg='white',
                                height=self.MAZE_H * self.UNIT,
                                width=self.MAZE_W * self.UNIT)
        # Create agents
        # NOTE(review): colors has 3 entries — n_agents > 3 would raise
        # IndexError here; confirm intended limit.
        colors = ['red', 'blue', 'green']
        for i in range(self.n_agents):
            pos = np.array([np.random.randint(0, self.MAZE_W) * self.UNIT + self.UNIT/2,
                            np.random.randint(0, self.MAZE_H) * self.UNIT + self.UNIT/2])
            agent = self.canvas.create_oval(
                pos[0] - 15, pos[1] - 15,
                pos[0] + 15, pos[1] + 15,
                fill=colors[i])
            self.agents.append(agent)
            self.agents_pos.append(pos)
            # Negative emotion is rendered gray.
            if self.agents_emotion[i] == -1:
                self.canvas.itemconfig(agent, fill='gray')
        # Create snowdrifts
        for _ in range(self.n_snowdrifts):
            pos = np.array([np.random.randint(0, self.MAZE_W) * self.UNIT + self.UNIT/2,
                            np.random.randint(0, self.MAZE_H) * self.UNIT + self.UNIT/2])
            points = [
                pos[0], pos[1] - 15,
                pos[0] - 15, pos[1] + 15,
                pos[0] + 15, pos[1] + 15
            ]
            snowdrift = self.canvas.create_polygon(points, fill='black')
            self.snowdrifts.append(snowdrift)
            self.snowdrifts_pos.append(pos)
            self.cleaned.append(False)
        self.canvas.pack()
    def reset(self, agent_id):
        """Reset environment and return initial state index"""
        self.update()
        time.sleep(0.001)
        # Reset all states
        self.empathy_emotion = 0
        self.empathy_emotion_t_1 = 0
        self.help_signal = 0
        # Reset agents: re-randomise positions, restore negative emotion.
        for i in range(self.n_agents):
            self.canvas.delete(self.agents[i])
            pos = np.array([np.random.randint(0, self.MAZE_W) * self.UNIT + self.UNIT/2,
                            np.random.randint(0, self.MAZE_H) * self.UNIT + self.UNIT/2])
            self.agents_pos[i] = pos
            self.agents[i] = self.canvas.create_oval(
                pos[0] - 15, pos[1] - 15,
                pos[0] + 15, pos[1] + 15,
                fill='gray')  # ['red', 'blue', 'green'][i]
            self.agents_emotion[i] = -1
        # Reset snowdrifts at their original positions.
        for i in range(self.n_snowdrifts):
            if hasattr(self, 'snowdrifts') and len(self.snowdrifts) > i:
                self.canvas.delete(self.snowdrifts[i])
            pos = self.snowdrifts_pos[i]
            points = [
                pos[0], pos[1] - 15,
                pos[0] - 15, pos[1] + 15,
                pos[0] + 15, pos[1] + 15
            ]
            snowdrift = self.canvas.create_polygon(points, fill='black')
            if not hasattr(self, 'snowdrifts') or len(self.snowdrifts) <= i:
                self.snowdrifts.append(snowdrift)
            else:
                self.snowdrifts[i] = snowdrift
        self.cleaned = [False] * self.n_snowdrifts
        # Calculate initial state index
        init_state = self._get_state_index(agent_id)
        return init_state
    def step_all(self, actions):
        """Multi-agent environment step
        Args:
            actions: List[int] - List of actions for each agent
        Returns:
            next_states: List[int] - Next state index for each agent
            rewards: List[float] - Rewards obtained by each agent
            empathtrewards_t: List[float] - Empathy rewards per agent
            done: bool - Whether the episode is finished
            info: dict - Additional information
        """
        rewards = [0] * self.n_agents
        empathtrewards_t = [0] * self.n_agents
        next_states = []
        cleaned_this_step = []  # Record snowdrifts cleaned in this step
        # 1. Move phase - all agents move simultaneously
        for agent_id, action in enumerate(actions):
            s = self.agents_pos[agent_id]
            base_action = np.array([0, 0])
            if action < 4:  # Move actions
                if action == 0:  # up
                    if s[1] > self.UNIT:
                        base_action[1] -= self.UNIT
                elif action == 1:  # down
                    if s[1] < (self.MAZE_H - 1) * self.UNIT:
                        base_action[1] += self.UNIT
                elif action == 2:  # right
                    if s[0] < (self.MAZE_W - 1) * self.UNIT:
                        base_action[0] += self.UNIT
                elif action == 3:  # left
                    if s[0] > self.UNIT:
                        base_action[0] -= self.UNIT
            self.canvas.move(self.agents[agent_id], base_action[0], base_action[1])
            self.agents_pos[agent_id] = self.agents_pos[agent_id] + base_action
        # 2. Cleaning phase - handle all cleaning actions
        for agent_id, action in enumerate(actions):
            if action == 4:
                s = self.agents_pos[agent_id]
                for i, pos in enumerate(self.snowdrifts_pos):
                    # BUGFIX: record the snowdrift index i (not the agent id)
                    # so the duplicate-clean guard compares like with like.
                    if all(s == pos) and not self.cleaned[i] and i not in cleaned_this_step:
                        self.canvas.itemconfig(self.snowdrifts[i], fill='')
                        self.cleaned[i] = True
                        cleaned_this_step.append(i)
                        # Cleaner pays the cost: small direct reward, emotion
                        # drops back to negative (gray).
                        rewards[agent_id] += 2
                        self.agents_emotion[agent_id] = -1
                        self.canvas.itemconfig(self.agents[agent_id], fill='gray')
                        # Every bystander benefits and cheers up.
                        for j in range(self.n_agents):
                            if j != agent_id:
                                rewards[j] += 6
                                if self.agents_emotion[j] == -1:
                                    self.agents_emotion[j] = 0
                                    self.canvas.itemconfig(self.agents[j], fill=['red', 'blue', 'green'][j])
                        empathtrewards_t[agent_id] += 6
        # 3. Calculate next state for each agent
        for agent_id in range(self.n_agents):
            next_state = self._get_state_index(agent_id)
            next_states.append(next_state)
        # 4. Check if finished
        done = all(self.cleaned)
        info = {
            'cleaned_positions': cleaned_this_step,
            'agent_emotions': self.agents_emotion.copy()
        }
        return next_states, rewards, empathtrewards_t, done, info
    def _get_state_index(self, agent_id):
        """Convert an agent's pixel position to a flat 8x8 cell index."""
        state_index = 0
        pos = self.agents_pos[agent_id]
        x = int(pos[0] / (self.MAZE_W * self.UNIT) * 8)
        y = int(pos[1] / (self.MAZE_H * self.UNIT) * 8)
        state_index += x + y * 8
        return state_index
    def render(self):
        """Render environment"""
        time.sleep(0.00001)
        self.update()
================================================
FILE: examples/Social_Cognition/affective_empathy/BEEAD-SNN/snowdrift_main.py
================================================
import time
import datetime
import os
import random
import numpy as np
import torch
from sd_env import Snowdrift
from rsnn import RSNN
from matplotlib import pyplot as plt
from torch.utils.tensorboard import SummaryWriter
# Global parameters
N_action = 5  # up, down, left, right, clean
N_state = 64  # 8*8 grid
C = 50  # output neurons per action group (net output width = N_action * C)
runtime = 100  # spike-train length (timesteps) per decision
trace_decay = 0.8  # eligibility-trace decay applied each timestep
torch.manual_seed(42)
np.random.seed(42)
def encode(n, e):
    """One-hot spike encoding of state index ``n`` over 100 timesteps,
    scaled to 0.51.

    The emotion argument ``e`` is accepted for interface compatibility
    with the other encoders but does not influence the output.
    """
    spikes = torch.zeros(N_state, 100)
    spikes[n] = 1
    return spikes * 0.51
def aoencode(n, e, env, agent_id):
    """Fully-observable spike encoding for one agent.

    Own cell gets amplitude 1, every other agent's cell +0.3, every
    remaining (uncleaned) snowdrift's cell +0.6; the whole train is then
    scaled by 0.51.  ``e`` is unused, kept for interface compatibility.
    """
    spikes = torch.zeros(N_state, 100)
    spikes[n] = 1
    for idx, other_pos in enumerate(env.agents_pos):
        if idx == agent_id:
            continue
        col = int(other_pos[0] / (env.MAZE_W * env.UNIT) * 8)
        row = int(other_pos[1] / (env.MAZE_H * env.UNIT) * 8)
        spikes[col + row * 8] += 0.3
    for idx, snow_pos in enumerate(env.snowdrifts_pos):
        if env.cleaned[idx]:
            continue
        col = int(snow_pos[0] / (env.MAZE_W * env.UNIT) * 8)
        row = int(snow_pos[1] / (env.MAZE_H * env.UNIT) * 8)
        spikes[col + row * 8] += 0.6
    return spikes * 0.51
def poencode(n, e, env, agent_id):
    """
    Encode state as a partially observable spike representation.

    Own cell gets amplitude 1; other agents (+0.3) and uncleaned
    snowdrifts (+0.6) contribute only when within ``obs_range`` grid
    cells of the encoding agent.  Output is scaled by 0.51.

    Args:
        n: state index of current agent position
        e: emotion state (unused, kept for interface compatibility)
        env: environment object
        agent_id: agent ID
    """
    spikes = torch.zeros(N_state, 100)
    me = env.agents_pos[agent_id]
    my_col = int(me[0] / (env.MAZE_W * env.UNIT) * 8)
    my_row = int(me[1] / (env.MAZE_H * env.UNIT) * 8)
    obs_range = 1  # observable grid range
    spikes[n] = 1
    for idx, other_pos in enumerate(env.agents_pos):
        if idx == agent_id:
            continue
        col = int(other_pos[0] / (env.MAZE_W * env.UNIT) * 8)
        row = int(other_pos[1] / (env.MAZE_H * env.UNIT) * 8)
        if abs(col - my_col) <= obs_range and abs(row - my_row) <= obs_range:
            spikes[col + row * 8] += 0.3
    for idx, snow_pos in enumerate(env.snowdrifts_pos):
        if env.cleaned[idx]:
            continue
        col = int(snow_pos[0] / (env.MAZE_W * env.UNIT) * 8)
        row = int(snow_pos[1] / (env.MAZE_H * env.UNIT) * 8)
        if abs(col - my_col) <= obs_range and abs(row - my_row) <= obs_range:
            spikes[col + row * 8] += 0.6
    return spikes * 0.51
def chooseAct(Net, input, explore, n, env, agent_id):
    """Run the SNN over the encoding window and select an action.

    Accumulates the eligibility trace on ``Net`` as a side effect (used
    later by UpdateWeight) and counts output spikes per action group.

    NOTE(review): ``explore`` acts as the probability of choosing
    GREEDILY (uniform() < explore -> argmax), not of exploring, and
    movement actions are always drawn at random when the agent is not on
    a snowdrift — confirm this matches the schedule in train_model.

    :return: (action, the mutated Net, last-step dw, 0)
    """
    count_group = np.zeros(N_action)
    count_output = np.zeros(N_action * C)
    # Drive the network for `runtime` timesteps, decaying and growing the
    # eligibility trace, and summing output spikes.
    for i_train in range(runtime):
        out, dw = Net(input[:, i_train])
        Net.weight_trace *= trace_decay
        Net.weight_trace += dw[0]
        count_output = count_output + np.array(out)
    # Collapse the C output neurons of each action group into one score.
    for i in range(N_action):
        count_group[i] = count_output[i*C:(i+1)*C].sum()
    # "clean" (action 4) is only legal on an uncleaned snowdrift cell.
    agent_pos = env.agents_pos[agent_id]
    at_snowdrift = False
    for i, snow_pos in enumerate(env.snowdrifts_pos):
        if not env.cleaned[i] and all(agent_pos == snow_pos):
            at_snowdrift = True
            break
    if not at_snowdrift:
        count_group[4] = float('-inf')
    if np.random.uniform() < explore:
        if not at_snowdrift:
            action = np.random.randint(0, 4)
        else:
            if count_group.max() > float('-inf'):
                action = count_group.argmax()
            else:
                action = np.random.randint(0, N_action)
    else:
        if not at_snowdrift:
            action = np.random.randint(0, 4)
        else:
            action = np.random.randint(0, N_action)
    return action, Net, dw[0], 0
def train_model(n_agents, lamdas, episodes):
    """Train one RSNN per agent on the snowdrift game.

    :param n_agents: number of agents (NOTE(review): the emotion logic
        below hard-codes indices 0..2, so only 3 agents are supported)
    :param lamdas: per-agent empathy weights blending empathy reward
        into the total reward
    :param episodes: number of training episodes
    :return: (per-agent episode reward lists, per-episode cleaned counts)
    """
    # TensorBoard writer
    current_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
    log_dir = os.path.join('run33obs', f'sd_partobs_a{n_agents}_l{lamdas[0]}{lamdas[1]}{lamdas[2]}_e{episodes}_{current_time}', f'')
    writer = SummaryWriter(log_dir)
    nets = [RSNN(N_state, N_action*C) for _ in range(n_agents)]
    learn_steps = [[] for _ in range(n_agents)]
    weight_marks = [np.zeros((N_state, N_action)) for _ in range(n_agents)]
    update_stops = [0 for _ in range(n_agents)]
    empathy_rewards_t = [0] * n_agents
    total_rewards = [0] * n_agents
    env = Snowdrift(n_agents=n_agents, n_snowdrifts=10)
    episode_cleaned_counts = []
    agent_cleaned_counts = [0] * n_agents
    for episode in range(episodes):
        print(f'Episode: {episode}, Lambda: {lamdas}')
        cleaned_count = 0
        episode_agent_cleaned_count = [0] * n_agents
        states = []
        emotion_t = [-1] * n_agents
        for i in range(n_agents):
            state = env.reset(i)
            states.append(state)
        episode_rewards = [0 for _ in range(n_agents)]
        episode_total_rewards = [0 for _ in range(n_agents)]
        # Greedy-probability schedule; after episode 900 learning stops
        # (update_stops) and action choice is fully greedy.
        if episode < 100:
            e_greedy = 0.2
        elif episode < 300:
            e_greedy = 0.5
        elif episode < 900:
            e_greedy = 0.9
        else:
            for i in range(n_agents):
                if update_stops[i] == 0:
                    update_stops[i] = 1
            e_greedy = 1
        for t in range(100):
            # Keep the last two emotion snapshots for the reward blending.
            emotion_tt = emotion_t.copy()
            emotion_t = env.agents_emotion.copy()
            actions = []
            for i in range(n_agents):
                input_state = poencode(states[i], env.agents_emotion[i], env, i)
                action, nets[i], dw, _ = chooseAct(nets[i], input_state, e_greedy, states[i], env, i)
                actions.append(action)
            next_states, rewards, empathy_rewards, done, info = env.step_all(actions)
            cleaned_count += len(info['cleaned_positions'])
            # NOTE(review): Snowdrift.step_all never sets 'cleaned_by_agent',
            # so this branch is currently dead — confirm intent.
            if 'cleaned_by_agent' in info:
                for snow_idx, agent_idx in info['cleaned_by_agent'].items():
                    episode_agent_cleaned_count[agent_idx] += 1
                    agent_cleaned_counts[agent_idx] += 1
            print(f'intereaction {t} :')
            # Total reward: use the empathy-reward DELTA while emotions are
            # stable (or all negative), otherwise the raw empathy reward.
            for i in range(n_agents):
                if env.agents_emotion[i] == emotion_t[i] and env.agents_emotion[i] == emotion_tt[i] and (emotion_tt[0]==0 or emotion_tt[1]==0 or emotion_tt[2]==0):
                    total_rewards[i] = lamdas[i] * (empathy_rewards[i] - empathy_rewards_t[i]) + rewards[i]
                elif env.agents_emotion[0]==-1 and env.agents_emotion[1]==-1 and env.agents_emotion[2]==-1:
                    total_rewards[i] = lamdas[i] * (empathy_rewards[i] - empathy_rewards_t[i]) + rewards[i]
                else:
                    total_rewards[i] = lamdas[i] * empathy_rewards[i] + rewards[i]
            print(f'Actions: {actions}, Rewards: {rewards}, Empathy Rewards: {empathy_rewards} Total Rewards: {total_rewards} emotion: {env.agents_emotion}')
            empathy_rewards_t = empathy_rewards
            env.render()
            for i in range(n_agents):
                if update_stops[i] == 0:
                    nets[i].UpdateWeight(total_rewards[i], actions[i], C, states[i])
            states = next_states
            for i in range(n_agents):
                episode_rewards[i] += rewards[i]
                episode_total_rewards[i] += total_rewards[i]
            if done:
                break
        for i in range(n_agents):
            writer.add_scalar(f'ERewards/Agent_{i+1}', episode_rewards[i], episode)
            writer.add_scalar(f'totalRewards/Agent_{i+1}', episode_total_rewards[i], episode)
            writer.add_scalar(f'Cleaned/Agent_{i+1}', episode_agent_cleaned_count[i], episode)
        for i in range(n_agents):
            learn_steps[i].append(episode_rewards[i])
        # Save weights
        if episode == episodes-1:
            for i in range(n_agents):
                torch.save(nets[i].connection[0].weight.data, f'weight_agent{i}_lambda{lamdas}_episode{episode}.pth')
        episode_cleaned_counts.append(cleaned_count)
        writer.add_scalar('Performance/Cleaned_Snowdrifts', cleaned_count, episode)
        print(f'Cleaned Count: {cleaned_count}')
    writer.close()
    return learn_steps, episode_cleaned_counts
if __name__ == "__main__":
    # Experiment setup: three agents, one empathy-weight configuration.
    n_agents = 3
    n_snowdrifts = 10
    self_factors = [[1.51, 1.51, 1.51]]
    all_learn_steps = []
    all_cleaned_counts = []
    for ii in range(len(self_factors)):
        all_learn_steps.append([[] for _ in range(n_agents)])
        all_cleaned_counts.append([])
    # Train once per lambda configuration and collect the learning curves.
    for iii, lamdas in enumerate(self_factors):
        steps, cleaned_counts = train_model(n_agents, lamdas, 1000)
        for i in range(n_agents):
            all_learn_steps[iii][i].extend(steps[i])
        all_cleaned_counts[iii] = cleaned_counts
    # Save plot images
    save_dir = 'results'
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    plt.figure(figsize=[24, 8])
    # Plot reward curves for each agent
    plt.subplot(1, 3, 1)
    colors = ['red', 'blue', 'green']
    labels = ['Agent 1', 'Agent 2', 'Agent 3']
    for i in range(n_agents):
        plt.plot(all_learn_steps[0][i], label=labels[i], color=colors[i])
    plt.legend(loc='lower right')
    plt.title('Rewards per Agent')
    plt.xlabel('Episode')
    plt.ylabel('Reward')
    # Plot number of cleaned snowdrifts
    plt.subplot(1, 3, 2)
    plt.plot(all_cleaned_counts[0], label='Cleaned Snowdrifts', color='black')
    plt.axhline(y=n_snowdrifts, color='r', linestyle='--', label='Total Snowdrifts')
    plt.legend(loc='lower right')
    plt.title('Number of Cleaned Snowdrifts per Episode')
    plt.xlabel('Episode')
    plt.ylabel('Count')
    plt.ylim([0, n_snowdrifts + 1])
    # Add total reward curve
    plt.subplot(1, 3, 3)
    total_rewards_per_episode = np.sum(all_learn_steps[0], axis=0) # Calculate total reward for each episode
    plt.plot(total_rewards_per_episode, label='Total Rewards', color='purple')
    plt.legend(loc='lower right')
    plt.title('Total Rewards of All Agents')
    plt.xlabel('Episode')
    plt.ylabel('Total Reward')
    plt.tight_layout()
    # Save image
    save_path = os.path.join(save_dir, f'training_results_{timestamp}.png')
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close() # Close the figure to free memory
    print(f'Image saved to: {save_path}')
================================================
FILE: examples/Social_Cognition/affective_empathy/BRP-SNN/BRP-SNN.py
================================================
import os
import sys
# Add the parent directory of this file's folder to PYTHONPATH
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import imageio
from env_poly_SNN import Maze
from env_two_poly_SNN import Maze2
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
import torch, os, sys
from torch import nn
from torch.nn import Parameter
import abc
import math
from abc import ABC
import torch.nn.functional as F
from braincog.base.node.node import *
from braincog.base.learningrule.STDP import *
from braincog.base.connection.CustomLinear import *
# Action index column vector: 0=up, 1=down, 2=right, 3=left.
X=np.array([[0],
    [1],
    [2],
    [3]])
# Pixel displacement produced by each action, matching X row-for-row.
Y=np.array([[0,-40],
    [0,40],
    [40,0],
    [-40,0]
    ])
class BrainArea(nn.Module, abc.ABC):
    """
    Base class for brain-area modules.
    """
    @abc.abstractmethod
    def __init__(self):
        """
        Subclasses construct their neurons and connections here.
        """
        super().__init__()
    @abc.abstractmethod
    def forward(self, x):
        """
        Run the forward pass.
        :return: x, the output spikes
        """
        return x
    def reset(self):
        """
        Reset intermediate neuron/learning-rule state (no-op by default).
        :return: None
        """
        pass
class BNESNN(BrainArea):
    """
    Negative-emotion network: predicts sensory pain from state and raises
    a pain signal when prediction and sensation disagree.

    Populations (nodes): 0=state, 1=prediction, 2=error, 3=sensory, 4=pain.
    """
    def __init__(self,):
        super().__init__()
        self.node = [IFNode() for i in range(5)]
        self.connection = []
        con_matrix0 = torch.eye(12, 12)*6
        self.connection.append(CustomLinear(con_matrix0))#input-state
        con_matrix1 = torch.zeros((12, 24), dtype=torch.float)
        self.connection.append(CustomLinear(con_matrix1))#state-prediction (plastic, starts at zero)
        con_matrix2 = torch.eye(24, 24)*6
        self.connection.append(CustomLinear(con_matrix2))#input-prediction
        con_matrix3 = torch.eye(24, 24)*6
        self.connection.append(CustomLinear(con_matrix3))#input-sensory
        con_matrix4 = torch.eye(24, 24)*6
        self.connection.append(CustomLinear(con_matrix4))#sensory-error (excitatory)
        con_matrix5 = torch.eye(24, 24)*(-6)
        self.connection.append(CustomLinear(con_matrix5))#prediction-error (inhibitory)
        con_matrix6 = torch.zeros((24, 24), dtype=torch.float)
        # p selects what fraction of the error population drives the pain
        # population (columns 0..12*p receive input).
        p=0.5
        if p==0.25:
            con_matrix6[:,0:3]=1
        if p==0.5:
            con_matrix6[:,0:6]=1
        if p==0.75:
            con_matrix6[:,0:9]=1
        if p==1:
            con_matrix6[:,0:12]=1
        self.connection.append(CustomLinear(con_matrix6))#error-pain
        self.stdp = []
        self.stdp.append(STDP(self.node[0], self.connection[0]))#node0-state,stdp0
        self.stdp.append(MutliInputSTDP(self.node[1], [self.connection[1], self.connection[2]]))#node1-prediction,stdp1
        self.stdp.append(STDP(self.node[3], self.connection[3]))#node3-sensory,stdp2
        self.stdp.append(MutliInputSTDP(self.node[2], [self.connection[4], self.connection[5]]))#node2-error,stdp3
        self.stdp.append(STDP(self.node[1], self.connection[1]))#node1-prediction,stdp4
        self.stdp.append(STDP(self.node[4], self.connection[6]))#node4-pain,stdp5
    def forward(self, x1,x2):
        """
        Training forward pass: drive state (x1) and prediction (x2) and
        return the STDP weight change for the state->prediction link.
        :return: (dw, state spikes, prediction spikes)
        """
        out__s, dw0 = self.stdp[0](x1)#node0
        out__p,dw = self.stdp[1](out__s,x2)#node1
        return dw,out__s,out__p
    def calculate_error(self, x1,x2):
        """
        Testing pass: propagate state (x1) and sensory input (x2) through
        prediction, error, and pain populations.
        :return: (state, prediction, sensory, error, pain) spike outputs
        """
        out__s,dw = self.stdp[0](x1)#node0-state,stdp0
        out__pre,dw= self.stdp[4](out__s)#node1-prediction,stdp1
        out__sensory,dw = self.stdp[2](x2)#node3-sensory,stdp2
        out__error,dw = self.stdp[3](out__sensory,out__pre)#node2-error,stdp3
        out__pain,dw = self.stdp[5](out__error)#node4-pain,stdp5
        return out__s,out__pre,out__sensory,out__error,out__pain
    def UpdateWeight(self, i, dw, delta):
        """
        Update the weights of connection ``i`` by ``dw * delta``.
        :param i: index of the connection to update
        :param dw: weight change
        :param delta: scalar learning-rate/modulation factor
        :return: None
        """
        self.connection[i].update(dw*delta)
        # Weights are clamped to the excitatory range [0, 6].
        self.connection[i].weight.data= torch.clamp(self.connection[i].weight.data,0,6)
    def reset(self):
        """
        Reset neuron state and learning-rule traces.
        :return: None
        """
        for i in range(5):
            self.node[i].n_reset()
        for i in range(len(self.stdp)):
            self.stdp[i].reset()
def GRF(X,N):
    """Gaussian receptive-field (population) temporal coding.

    Each of the ``N`` features is encoded by 12 Gaussian neurons; the
    Gaussian response is converted to a firing time in [1, 99] (stronger
    response -> earlier spike; responses below 0.1 do not spike).

    :param X: 2-D array of samples, min-max normalised internally
    :param N: number of features per sample
    :return: per-sample lists of (spike_time, neuron_index) tuples
    """
    gauss_neuron = 12
    center = np.ones((gauss_neuron, 1))
    width = 1 / 15
    # Gaussian centres spread over the normalised input range.
    for i in range(len(center)):
        center[i] = (2 * i - 3) / 20
    x = np.arange(0, 1, 0.0001)
    num_features = N
    # Precompute each neuron's tuning curve on a 10000-point grid so the
    # response can be read off by integer index.
    gauss_recpt_field = np.zeros((gauss_neuron, len(x)))
    for i in range(gauss_neuron):
        gauss_recpt_field[i, :] = np.exp(-(x - center[i]) ** 2 / (2 * width * width))
    def gauss_response(inputs,num_features):
        # Look up the tuning-curve response for each (neuron, feature) pair.
        spike_time = np.zeros((gauss_neuron, num_features))
        # input: shape [1, features]
        # output: shape [gaussian neurons*features] spiking time
        for i in range(num_features):
            for j in range(gauss_neuron):
                spike_time[j, i] = gauss_recpt_field[j, inputs[i]] #entry gauss function
        spikes = []
        for i in range(spike_time.shape[1]):
            spikes.extend(spike_time[:, i])
        return np.array(spikes)
    gauss_neurons = gauss_neuron * N
    # Normalise inputs to [0, 1] and quantise onto the 10000-point grid.
    scaler = MinMaxScaler()
    X = scaler.fit_transform(X)
    X = (X * 10000).astype(int) #10000
    X[X == 10000] = 9999
    input_spike = np.zeros((X.shape[0], gauss_neurons))
    for i in range(X.shape[0]):
        input_spike[i, :] = gauss_response(X[i, :],num_features)
    # Response -> latency: strong responses fire early; sub-threshold
    # responses (< 0.1) are mapped to time 0 (no spike) below.
    input_spike[input_spike < 0.1] = 0
    input_spike = np.around(100 * (1 - input_spike))
    input_spike[input_spike == 0] = 1
    input_spike[input_spike == 100] = 0
    state=[]
    for i in range(len(X)):
        aa=[]
        for j in range(gauss_neurons):
            if input_spike[i][j] != 0:
                number=input_spike[i][j]
                aa.append((int(number),j))
        state.append(aa)
    return state
def encode(input, n_neuron):
    """Expand (latency, neuron) pairs into binary spike trains.

    :param input: list of (spike_time, neuron_index) pairs, as produced
        by ``GRF``
    :param n_neuron: number of neurons (one length-100 train each)
    :return: list of ``n_neuron`` numpy arrays of length 100 with a 1 at
        each pair's spike time
    """
    trains = [np.zeros([100, ]) for _ in range(n_neuron)]
    for spike_time, neuron in input:
        trains[neuron][spike_time] = 1
    return trains
class BAESNN(BrainArea):
    """Affective empathy spiking network.

    Five IF populations: input/emotion (node0), IFG (node1),
    perception (node2), SMA (node3) and M1 (node4), wired by seven
    connections and trained with (multi-input) STDP.
    """
    def __init__(self,):
        """Build the populations, connection matrices and STDP rules."""
        super().__init__()
        self.node = [IFNode() for _ in range(5)]
        self.connection = []
        # 0: input -> emotion, one-to-one, fixed weight 6
        self.connection.append(CustomLinear(torch.eye(24, 24) * 6))
        # 1: emotion -> ifg; first 12 emotion units drive the first 25
        # ifg units, the last 12 drive the last 25 (weight 2)
        w_emotion_ifg = torch.zeros((24, 50), dtype=torch.float)
        w_emotion_ifg[0:12, 0:25] = 2
        w_emotion_ifg[12:24, 25:50] = 2
        self.connection.append(CustomLinear(w_emotion_ifg))
        # 2: perception -> ifg, learned, starts at zero
        self.connection.append(CustomLinear(torch.zeros((24, 50), dtype=torch.float)))
        # 3: input -> perception, one-to-one, fixed weight 6
        self.connection.append(CustomLinear(torch.eye(24, 24) * 6))
        # 4: emotion -> sma, same block structure as connection 1
        w_emotion_sma = torch.zeros((24, 10), dtype=torch.float)
        w_emotion_sma[0:12, 0:5] = 2
        w_emotion_sma[12:24, 5:10] = 2
        self.connection.append(CustomLinear(w_emotion_sma))
        # 5: perception -> m1, learned, starts at zero
        self.connection.append(CustomLinear(torch.zeros((24, 10), dtype=torch.float)))
        # 6: sma -> m1, one-to-one, fixed weight 6
        self.connection.append(CustomLinear(torch.eye(10, 10) * 6))
        self.stdp = [
            STDP(self.node[0], self.connection[0]),                                  # 0
            STDP(self.node[2], self.connection[3]),                                  # 1
            MutliInputSTDP(self.node[1], [self.connection[1], self.connection[2]]),  # 2
            MutliInputSTDP(self.node[3], [self.connection[4], self.connection[5]]),  # 3
            STDP(self.node[4], self.connection[6]),                                  # 4
            STDP(self.node[1], self.connection[2]),                                  # 5
            STDP(self.node[3], self.connection[5]),                                  # 6
        ]
    def forward(self, x1, x2):
        """One simulation step of the full network.

        :param x1: input spikes for the emotion pathway
        :param x2: input spikes for the perception pathway
        :return: (dw for perception->ifg, dw for perception->m1,
            ifg spikes, sma spikes, m1 spikes)
        """
        emotion, _ = self.stdp[0](x1)                           # node0
        perception, _ = self.stdp[1](x2)                        # node2
        ifg, dw_perc_ifg = self.stdp[2](emotion, perception)    # node1
        sma, dw_perc_sma = self.stdp[3](emotion, perception)    # node3
        m1, _ = self.stdp[4](sma)                               # node4
        return dw_perc_ifg, dw_perc_sma, ifg, sma, m1
    def empathy(self, x3):
        """Run only the perception-driven (empathy) pathway.

        :param x3: observed spikes fed into the perception population
        :return: (ifg spikes, sma spikes, m1 spikes)
        """
        perception, _ = self.stdp[1](x3)    # node2
        ifg, _ = self.stdp[5](perception)   # node1
        sma, _ = self.stdp[6](perception)   # node3
        m1, _ = self.stdp[4](sma)           # node4
        return ifg, sma, m1
    def UpdateWeight(self, i, dw, delta):
        """Apply ``dw * delta`` to connection i and clamp weights to [-1, 4].

        :param i: index of the connection to update
        :param dw: weight change produced by the learning rule
        :param delta: scaling factor applied to ``dw``
        :return: None
        """
        conn = self.connection[i]
        conn.update(dw * delta)
        conn.weight.data = torch.clamp(conn.weight.data, -1, 4)
    def reset(self):
        """Reset the intermediate state of the neurons and STDP rules.

        :return: None
        """
        for node in self.node:
            node.n_reset()
        for rule in self.stdp:
            rule.reset()
def BNESNN_train():
    """Train snn1 to associate encoded actions with displacement predictions.

    Relies on the module-level globals ``X`` (actions), ``Y``
    (displacements) and ``snn1``.
    """
    state = GRF(X, 1)
    prediction = GRF(Y, 2)
    T = 100
    n_epochs = 10
    for ep in range(n_epochs):
        print('epoch:', ep)
        for sample in range(4):
            snn1.reset()
            st = torch.tensor(np.array(encode(state[sample], 12)), dtype=torch.float32)
            pred = torch.tensor(np.array(encode(prediction[sample], 24)), dtype=torch.float32)
            for t in range(T):
                out = snn1(st[:, t], pred[:, t])
                # strengthen connection 1 by the STDP weight change
                snn1.UpdateWeight(1, out[0][0], 1)
def BAESNN_train():
    """Training stage of the affective-empathy experiment.

    The agent explores the single-agent Maze; snn1 detects the mismatch
    between predicted and actual displacement (pain), and the pain signal
    trains the empathy network snn2. Relies on the module-level globals
    ``env``, ``snn1``, ``snn2``, ``X`` and ``Y``. Destroys ``env`` on exit.
    """
    s = env.reset()
    env._set_danger()
    env._set_wall()
    pain=0
    i=0
    set_pain=0
    env._set_switch()
    for i in range(100):
        snn1.reset()
        T=100
        pain=0
        print('**************step:',i)
        env.render()
        # random motor command for the agent
        action = np.random.choice(list(range(env.n_actions)))
        print('action:',action)
        d,d_pre,s_,sss = env.step(s, action, pain)
        print('d:',d,'d_pre:',d_pre,'sss:',sss)
        env.render()
        # retry until the agent actually moves (non-zero displacement)
        while (d==np.array([0,0])).all():
            action = np.random.choice(list(range(env.n_actions)))
            print('action:',action)
            d,d_pre,s_,sss = env.step(s, action, pain)
            print('d:',d,'d_pre:',d_pre,'sss:',sss)
            env.render()
        # aa: index of the chosen action in X; b: index of the observed
        # displacement in Y
        aa=np.argwhere(X==action)[0][0]
        for i in range(4):
            if (Y[i]==d).all():
                b=i
        print('aa:',aa,'b:',b)
        state=GRF(X,1)
        prediction=GRF(Y,2)
        x=encode(state[aa],12)
        y=encode(prediction[b],24)
        train_state = np.array(x)
        train_state=torch.tensor(train_state,dtype=torch.float32)
        train_prediction = np.array(y)
        train_prediction=torch.tensor(train_prediction,dtype=torch.float32)
        OUT_PAIN=torch.tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                                0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                                0., 0., 0., 0.]])
        spike_pain=[]
        spike_error=[]
        # run snn1 over the spike trains; any spike in the pain area marks
        # this step as painful
        for i in range(T):
            OUTPUT_TEST = snn1.calculate_error(train_state[:,i],train_prediction[:,i])
            spike_pain.append(OUTPUT_TEST[4])
            spike_error.append(OUTPUT_TEST[3])
            if OUTPUT_TEST[3].sum() != 0:
                print('OUTPUT_TEST3:',i,OUTPUT_TEST[3])
            if OUTPUT_TEST[4].sum() != 0:#pain brain area
                print('OUTPUT_TEST4:',i,OUTPUT_TEST[4])
                OUT_PAIN=OUTPUT_TEST[4]
                pain=1
                set_pain=1
        spike_pain = torch.stack(spike_pain)
        spike_error=torch.stack(spike_error)
        if pain==1:
            spike_rate_vis_1d(spike_error)
            spike_rate_vis_1d(spike_pain)
        print('pain:',pain)
        # feed the pain signal into the empathy network snn2 and train it;
        # after two steps the second input follows the pain signal
        snn2.reset()
        T2=20
        X1= OUT_PAIN.view(1, -1)
        X2=torch.tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                          0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                          0., 0., 0., 0.]])
        print('X1,X2:',X1,X2)
        for i in range(T2):
            if i>=2:
                X2=X1
            OUTPUT = snn2(X1,X2)
            # potentiate perception->ifg, depress perception->m1
            snn2.UpdateWeight(2,OUTPUT[0][1],0.01)
            snn2.UpdateWeight(5,OUTPUT[1][1],-0.1)
            # recolour the agent according to the ifg (pain-recognition) output
            if OUTPUT[2][0][0]==1:
                env.canvas.itemconfig(env.rect, fill="red", outline='red')
            if OUTPUT[2][0][0]==0:
                env.canvas.itemconfig(env.rect, fill="green", outline='green')
            env.render()
            print('out_ifg:',OUTPUT[2])
            print('out_sma:',OUTPUT[3])
            print('out_m1:',OUTPUT[4])
        print('con2:',snn2.connection[2].weight.data)
        print('con5:',snn2.connection[5].weight.data)
        s = s_
        # stop once pain has been experienced and then relieved
        if set_pain==1 and pain==0:
            env.render()
            break
    env.destroy()
def BAESNN_test():
    """Test stage: the observer agent empathises with agent1's pain.

    agent1 moves randomly in the two-agent Maze2; snn1 detects its pain,
    and the trained empathy network snn2 decides whether the observer
    should help. Relies on the globals ``env2``, ``snn1``, ``snn2``,
    ``X`` and ``Y``. Destroys ``env2`` on exit.

    Fixes relative to the original:
    - actions are sampled from ``env2.n_actions`` (was ``env.n_actions``,
      a stale reference to the training environment)
    - the retry loop passes ``s1`` to ``env2.step`` (was ``s``, a
      leftover from the training code)
    """
    s1, s = env2.reset()
    pain = 0
    pain1 = 0
    i = 0
    set_pain = 0
    for i in range(100):
        snn1.reset()
        T = 100
        pain = 0
        print('**************test_step:', i)
        env2.render()
        # random motor command for agent1
        action1 = np.random.choice(list(range(env2.n_actions)))
        print('action1:', action1)
        d, d_pre, s1_, sss = env2.step(s1, action1, pain1)
        print('d:', d, 'd_pre:', d_pre, 'sss:', sss)
        env2.render()
        # retry until agent1 actually moves (non-zero displacement)
        while (d == np.array([0, 0])).all():
            action1 = np.random.choice(list(range(env2.n_actions)))
            print('action1:', action1)
            d, d_pre, s1_, sss = env2.step(s1, action1, pain1)
            print('d:', d, 'd_pre:', d_pre, 'sss:', sss)
            env2.render()
        # aa: index of the chosen action in X; b: index of the observed
        # displacement in Y
        aa = np.argwhere(X == action1)[0][0]
        for i in range(4):
            if (Y[i] == d).all():
                b = i
        # print('aa:',aa,'b:',b)
        state = GRF(X, 1)
        prediction = GRF(Y, 2)
        x = encode(state[aa], 12)
        y = encode(prediction[b], 24)
        train_state = np.array(x)
        train_state = torch.tensor(train_state, dtype=torch.float32)
        train_prediction = np.array(y)
        train_prediction = torch.tensor(train_prediction, dtype=torch.float32)
        OUT_PAIN = torch.tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                                  0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                                  0., 0., 0., 0.]])
        # run snn1; any spike in the pain area marks agent1 as hurt
        for i in range(T):
            OUTPUT_TEST = snn1.calculate_error(train_state[:, i], train_prediction[:, i])
            if OUTPUT_TEST[3].sum() != 0:
                print('OUTPUT_TEST3:', i, OUTPUT_TEST[3])
            if OUTPUT_TEST[4].sum() != 0:  # pain brain area
                print('OUTPUT_TEST4:', i, OUTPUT_TEST[4])
                OUT_PAIN = OUTPUT_TEST[4]
                pain1 = 1
                set_pain = 1
        print('pain1:', pain1)
        env2.generate_expression1(pain1)
        # observer side: the empathy network reads agent1's pain signal
        snn2.reset()
        T2 = 20
        X3 = OUT_PAIN.view(1, -1)
        for i in range(T2):
            OUT = snn2.empathy(X3)
            print('out_ifg:', OUT[0])
            if OUT[0][0][0] == 1:
                pain = 1
            if OUT[0][0][0] == 0:
                pain = 0
        if pain == 1:
            env2.agent_help()
        s1 = s1_
        env2.render()
        # stop once pain has been recognised and then relieved
        if pain == 0 and set_pain == 1:
            env2.render()
            break
    env2.destroy()
if __name__ == "__main__":
    # Build the training environment and the two spiking networks.
    env = Maze()        # single-agent "self-pain" maze
    snn1 = BNESNN()     # network producing prediction-error / pain signals
    snn2 = BAESNN()     # affective empathy network
    BNESNN_train()      # learn action -> displacement prediction
    BAESNN_train()      # learn pain/empathy associations (calls env.destroy())
    env.mainloop()
    # Test stage uses a separate two-agent environment.
    env2 = Maze2()
    BAESNN_test()       # observer empathises with agent1 (calls env2.destroy())
    env2.mainloop()
================================================
FILE: examples/Social_Cognition/affective_empathy/BRP-SNN/README.md
================================================
================================================
FILE: examples/Social_Cognition/affective_empathy/BRP-SNN/env_poly_SNN.py
================================================
import numpy as np
np.random.seed(1)
import tkinter as tk
import time
from PIL import ImageGrab
UNIT = 40 # pixels
MAZE_H = 9 # grid height
MAZE_W = 4 # grid width
class Maze(tk.Tk, object):
    """Tkinter maze environment for the single-agent self-pain experiment.

    The agent is drawn as a pentagon whose apex points in the direction of
    the last commanded action. A black oval is a danger zone (actions get
    inverted afterwards), a yellow oval is a switch that opens part of the
    grey wall and pushes the agent through.
    """
    def __init__(self):
        super(Maze, self).__init__()
        self.action_space = ['u', 'd', 'l', 'r']
        self.n_actions = len(self.action_space)
        self.title('self-pain')
        self.geometry('{0}x{1}'.format(MAZE_H * UNIT, MAZE_W * UNIT))
        self._build_maze()
        # status flags; danger/wall/switch objects are added lazily by
        # _set_danger / _set_wall / _set_switch
        self.danger=0
        self.action_hurt=0
        self.sensory_hurt = 0
        self.open_door = 0
        self.pain_state=0
    # create environment
    def _build_maze(self):
        """Create the canvas, the grid lines and the agent polygon."""
        self.canvas = tk.Canvas(self, bg='white',
                                height=MAZE_W * UNIT,
                                width=MAZE_H * UNIT)
        # create grids
        for c in range(0, MAZE_H * UNIT, UNIT):
            x0, y0, x1, y1 = c, 0, c, MAZE_H * UNIT
            self.canvas.create_line(x0, y0, x1, y1)
        for r in range(0, MAZE_H * UNIT, UNIT):
            x0, y0, x1, y1 = 0, r, MAZE_H * UNIT, r
            self.canvas.create_line(x0, y0, x1, y1)
        self.orgin=[20,20]
        # create agent
        # pentagon pointing down
        self.points1 = [
            # top-left
            self.orgin[0]-15,#5
            self.orgin[1]-15,#5
            # top-right
            self.orgin[0]+15,#35
            self.orgin[1]-15,#5
            # bottom-right+
            self.orgin[0]+15,#35
            self.orgin[1],#20
            # apex
            self.orgin[0],#20
            self.orgin[1]+15,#35
            # bottom-left+
            self.orgin[0]-15,#5
            self.orgin[1],#20
        ]
        self.rect = self.canvas.create_polygon(self.points1, fill="green")
        self.canvas.pack()
    #reset agent location
    def reset(self):
        """Recreate the agent at the start cell and return its coords."""
        self.open_door = 0
        self.update()
        time.sleep(0.5)
        self.canvas.delete(self.rect)
        self.orgin = [20, 20]
        # pentagon pointing down
        self.points1 = [
            # top-left
            self.orgin[0] - 15,  # 5
            self.orgin[1] - 15,  # 5
            # top-right
            self.orgin[0] + 15,  # 35
            self.orgin[1] - 15,  # 5
            # bottom-right+
            self.orgin[0] + 15,  # 35
            self.orgin[1],  # 20
            # apex
            self.orgin[0],  # 20
            self.orgin[1] + 15,  # 35
            # bottom-left+
            self.orgin[0] - 15,  # 5
            self.orgin[1],  # 20
        ]
        self.rect = self.canvas.create_polygon(self.points1, fill="green")
        return self.canvas.coords(self.rect)
    def step(self, s, action, pain):
        """Execute one action command.

        :param s: previous state (re-read from the canvas, so effectively
            ignored)
        :param action: 0=up, 1=down, 2=right, 3=left
        :param pain: 1 draws the agent red, 0 green
        :return: (displacement, predicted_displacement, new_coords,
            new_centre)
        """
        s = self.canvas.coords(self.rect)
        self.centre = [(s[4] + s[8]) / 2, (s[5] + s[9]) / 2]
        # danger or switch
        if self.danger==1:
            if all(self.centre == self.oval_center):
                # on the switch: open the wall and push the agent through
                s_color = 'yellow'
                self.canvas.delete(self.wall[3])
                self.render()
                self.open_door = 1
                move = np.array([80, 0])
                self.canvas.move(self.rect, move[0], move[1])
                s = self.canvas.coords(self.rect)
                self.render()
            elif all(self.centre == self.hell1_center):
                # on the danger zone: later actions get inverted
                s_color = 'black'
                self.action_hurt = 1
                self.render()
            else:
                s_color = 'white'
        # modify current state
        self.canvas.delete(self.rect)# repeated mainly to account for the switch steps
        self.centre = [(s[4] + s[8]) / 2, (s[5] + s[9]) / 2]
        if action==0:
            # pentagon pointing up
            self.points0 = [
                # bottom-right
                self.centre[0] + 15, # 35
                self.centre[1] + 15, # 35
                # bottom-left
                self.centre[0] - 15, # 5
                self.centre[1] + 15, # 35
                # top-left+
                self.centre[0] - 15, # 5
                self.centre[1], # 20
                # apex
                self.centre[0], # 20
                self.centre[1] - 15, # 5
                # top-right+
                self.centre[0] + 15, # 35
                self.centre[1], # 20
            ]
            if pain==0:
                color="green"
            if pain == 1:
                color = "red"
            self.rect = self.canvas.create_polygon(self.points0, fill=color)
        if action==1:
            # pentagon pointing down
            self.points1 = [
                # top-left
                self.centre[0] - 15, # 5
                self.centre[1] - 15, # 5
                # top-right
                self.centre[0] + 15, # 35
                self.centre[1] - 15, # 5
                # bottom-right+
                self.centre[0] + 15, # 35
                self.centre[1], # 20
                # apex
                self.centre[0], # 20
                self.centre[1] + 15, # 35
                # bottom-left+
                self.centre[0] - 15, # 5
                self.centre[1], # 20
            ]
            if pain==0:
                color="green"
            if pain == 1:
                color = "red"
            self.rect = self.canvas.create_polygon(self.points1, fill=color)
        if action==2:
            # pentagon pointing right
            self.points2 = [
                # bottom-left
                self.centre[0] - 15, # 5
                self.centre[1] + 15, # 35
                # top-left
                self.centre[0] - 15, # 5
                self.centre[1] - 15, # 5
                # top-right+
                self.centre[0], # 20
                self.centre[1] - 15, # 5
                # apex
                self.centre[0] + 15, # 35
                self.centre[1], # 20
                # bottom-right+
                self.centre[0], # 20
                self.centre[1] + 15, # 35
            ]
            if pain==0:
                color="green"
            if pain == 1:
                color = "red"
            self.rect = self.canvas.create_polygon(self.points2, fill=color)
        if action==3:
            # pentagon pointing left
            self.points3 = [
                # top-right
                self.centre[0] + 15, # 20+15
                self.centre[1] - 15, # 20-15
                # bottom-right
                self.centre[0] + 15, # 20+15
                self.centre[1] + 15, # 20+15
                # bottom-left+
                self.centre[0], # 20
                self.centre[1] + 15, # 20+15
                # apex
                self.centre[0] - 15, # 20-15
                self.centre[1], # 20
                # top-left+
                self.centre[0], # 20
                self.centre[1] - 15, # 20-15
            ]
            if pain==0:
                color="green"
            if pain == 1:
                color = "red"
            self.rect = self.canvas.create_polygon(self.points3, fill=color)
        s = self.canvas.coords(self.rect)
        self.render()# show the current action command
        if s[0] > (9 / 2) * 40:
            # in the right half of the maze the hurt effect wears off
            self.action_hurt = 0
        base_action = np.array([0, 0])
        if self.action_hurt == 0:
            true_action = action
        else:
            # hurt: the executed action is the opposite of the command
            if action == 0:
                true_action = 1
            if action == 1:
                true_action = 0
            if action == 2:
                true_action = 3
            if action == 3:
                true_action = 2
        # predict next state
        # predict next state
        self.centre1 = [(s[4] + s[8]) / 2, (s[5] + s[9]) / 2]
        pre_displacement1 = np.array([0, 0])
        if self.centre1[0] <= ((MAZE_H - 1) / 2 + 1) * UNIT: # 120
            if action == 0: # up
                if self.centre1[1] > UNIT:
                    pre_displacement1 = np.array([0, -40])
            elif action == 1: # down
                if self.centre1[1] < (MAZE_W - 1) * UNIT:
                    pre_displacement1 = np.array([0, 40])
            elif action == 2: # right
                if self.centre1[0] < ((MAZE_H - 1) / 2 - 1) * UNIT:
                    pre_displacement1 = np.array([40, 0])
            elif action == 3: # left
                if self.centre1[0] > UNIT:
                    pre_displacement1 = np.array([-40, 0])
        else:
            if action == 0: # up
                if self.centre1[1] > UNIT:
                    pre_displacement1 = np.array([0, -40])
            elif action == 1: # down
                if self.centre1[1] < (MAZE_W - 1) * UNIT:
                    pre_displacement1 = np.array([0, 40])
            elif action == 2: # right
                if self.centre1[0] < (MAZE_H - 1) * UNIT:
                    pre_displacement1 = np.array([40, 0])
            elif action == 3: # left
                if self.centre1[0] > ((MAZE_H - 1) / 2 + 2) * UNIT:
                    pre_displacement1 = np.array([-40, 0])
        # true next state
        displacement1 = np.array([0, 0])
        if self.centre1[0] <= ((MAZE_H - 1) / 2 + 1) * UNIT:
            if true_action == 0: # up
                if self.centre1[1] > UNIT:
                    displacement1=np.array([0,-40])
            elif true_action == 1: # down
                if self.centre1[1] < (MAZE_W - 1) * UNIT:
                    displacement1=np.array([0,40])
            elif true_action == 2: # right
                if self.centre1[0] < ((MAZE_H - 1) / 2 - 1) * UNIT:
                    displacement1=np.array([40,0])
            elif true_action == 3: # left
                if self.centre1[0] > UNIT:
                    displacement1=np.array([-40,0])
        else:
            if true_action == 0: # up
                if self.centre1[1] > UNIT:
                    displacement1=np.array([0,-40])
            elif true_action == 1: # down
                if self.centre1[1] < (MAZE_W - 1) * UNIT:
                    displacement1=np.array([0,40])
            elif true_action == 2: # right
                if self.centre1[0] < (MAZE_H - 1) * UNIT:
                    displacement1=np.array([40,0])
            elif true_action == 3: # left
                if self.centre1[0] > ((MAZE_H - 1) / 2 + 2) * UNIT:
                    displacement1=np.array([-40,0])
        self.canvas.move(self.rect, displacement1[0], displacement1[1])
        s1_ = self.canvas.coords(self.rect)
        sss = [(s1_[4] + s1_[8]) / 2, (s1_[5] + s1_[9]) / 2]
        return displacement1, pre_displacement1,s1_,sss
    def _set_danger(self):
        """Create the black danger oval and enable the danger logic."""
        self.hell1_center = np.array([60, 60])
        self.hell1 = self.canvas.create_oval(
            self.hell1_center[0] - 15, self.hell1_center[1] - 15,
            self.hell1_center[0] + 15, self.hell1_center[1] + 15,
            fill='black')
        # self.canvas.create_bitmap((40 , 40), bitmap='error')
        self.hell = self.canvas.coords(self.hell1)
        self.canvas.pack()
        self.danger=1
    def _set_switch(self):
        """Create the yellow switch oval that opens the wall."""
        self.oval_center = np.array([(MAZE_H * UNIT) / 2 - UNIT, ((MAZE_W + 4) * UNIT) / 2 - UNIT / 2])
        self.oval = self.canvas.create_oval(
            self.oval_center[0] - 15, self.oval_center[1] - 15,
            self.oval_center[0] + 15, self.oval_center[1] + 15,
            fill='yellow')
        self.switch = self.canvas.coords(self.oval)
        self.canvas.pack()
    def _set_wall(self):
        """Create the column of grey wall rectangles splitting the maze."""
        wall_center=[]
        self.wall=[]
        for a in range(MAZE_W):
            wall_center.append([0,0])
            self.wall.append([])
        for b in range(MAZE_W):
            wall_center[b]=np.array([(MAZE_H*UNIT)/2,((b)*UNIT)+UNIT/2])
            self.wall[b] = self.canvas.create_rectangle(
                wall_center[b][0] - 20, wall_center[b][1] - 20,
                wall_center[b][0] + 20, wall_center[b][1] + 20,
                fill='grey')
        self.wall0 = self.canvas.coords(self.wall[0])
        self.wall1 = self.canvas.coords(self.wall[1])
        self.wall2 = self.canvas.coords(self.wall[2])
        self.wall3 = self.canvas.coords(self.wall[3])
        # self.canvas.pack()
    def generate_expression(self,pain):
        """Recolour the agent: red when in pain, green otherwise."""
        if pain==1:
            self.canvas.itemconfig(self.rect, fill="red", outline='red')
            # self.canvas.pack()
        if pain == 0:
            self.canvas.itemconfig(self.rect, fill="green", outline='green')
            # self.canvas.pack()
    def render(self):
        """Refresh the window (with a short delay for visualisation)."""
        time.sleep(0.2)
        self.update()
    # def getter(self, widget):
    #     widget.update()
    #     x = tk.Tk.winfo_rootx(self) + widget.winfo_x()
    #     y = tk.Tk.winfo_rooty(self) + widget.winfo_y()
    #     x1 = x + widget.winfo_width()
    #     y1 = y + widget.winfo_height()
    #     ImageGrab.grab().crop((x, y, x1, y1)).save("first.jpg")
    #     return ImageGrab.grab().crop((x, y, x1, y1))
================================================
FILE: examples/Social_Cognition/affective_empathy/BRP-SNN/env_two_poly_SNN.py
================================================
import numpy as np
np.random.seed(1)
import tkinter as tk
import time
from PIL import ImageGrab
UNIT = 40 # pixels
MAZE_H = 9 # grid height
MAZE_W = 4 # grid width
class Maze2(tk.Tk, object):
    """Two-agent tkinter maze for the empathy test stage.

    agent1 (blue pentagon) moves on the left and may get hurt on the
    black danger ovals; the observer agent (green pentagon) starts on the
    right and can move to the yellow switch to open the wall.

    Fix relative to the original: ``_set_danger`` and ``_set_wall``
    called the non-existent canvas method ``create_agent1angle`` (a
    broken rename of ``create_rectangle``) and raised AttributeError
    when invoked; they now call ``create_rectangle``.
    """
    def __init__(self):
        super(Maze2, self).__init__()
        self.action_space = ['u', 'd', 'l', 'r']
        self.action_space1 = ['u', 'd', 'l', 'r']
        self.n_actions = len(self.action_space)
        self.n_actions1 = len(self.action_space1)
        self.title('two_agent_empathy')
        self.geometry('{0}x{1}'.format(MAZE_H * UNIT, MAZE_W * UNIT))
        self._build_maze()
        # status flags for both agents
        self.danger=0
        self.action_hurt=0
        self.sensory_hurt = 0
        self.action_hurt1 = 0
        self.sensory_hurt1 = 0
        self.open_door=0
    # create environment
    def _build_maze(self):
        """Create canvas, grid, switch, both agents, wall and dangers."""
        self.canvas = tk.Canvas(self, bg='white',
                                height=MAZE_W * UNIT,
                                width=MAZE_H * UNIT)
        # create grids
        for c in range(0, MAZE_H * UNIT, UNIT):
            x0, y0, x1, y1 = c, 0, c, MAZE_H * UNIT
            self.canvas.create_line(x0, y0, x1, y1)
        for r in range(0, MAZE_H * UNIT, UNIT):
            x0, y0, x1, y1 = 0, r, MAZE_H * UNIT, r
            self.canvas.create_line(x0, y0, x1, y1)
        # create switch
        self.oval_center = np.array([(MAZE_H * UNIT)/2-UNIT+80, ((MAZE_W+4) * UNIT)/2-UNIT/2-80])
        self.oval = self.canvas.create_oval(
            self.oval_center[0] - 15, self.oval_center[1] - 15,
            self.oval_center[0] + 15, self.oval_center[1] + 15,
            fill='yellow')
        self.switch = self.canvas.coords(self.oval)
        self.orgin1 = np.array([20, 20])
        # agent1: pentagon pointing down
        self.points1 = [
            # top-left
            self.orgin1[0] - 15,  # 5
            self.orgin1[1] - 15,  # 5
            # top-right
            self.orgin1[0] + 15,  # 35
            self.orgin1[1] - 15,  # 5
            # bottom-right+
            self.orgin1[0] + 15,  # 35
            self.orgin1[1],  # 20
            # apex
            self.orgin1[0],  # 20
            self.orgin1[1] + 15,  # 35
            # bottom-left+
            self.orgin1[0] - 15,  # 5
            self.orgin1[1],  # 20
        ]
        self.agent1 = self.canvas.create_polygon(self.points1, outline='black',fill="blue")
        self.orgin = np.array([MAZE_H * UNIT - UNIT / 2, 20])
        # observer agent: pentagon pointing down
        self.points = [
            # top-left
            self.orgin[0] - 15,  # 5
            self.orgin[1] - 15,  # 5
            # top-right
            self.orgin[0] + 15,  # 35
            self.orgin[1] - 15,  # 5
            # bottom-right+
            self.orgin[0] + 15,  # 35
            self.orgin[1],  # 20
            # apex
            self.orgin[0],  # 20
            self.orgin[1] + 15,  # 35
            # bottom-left+
            self.orgin[0] - 15,  # 5
            self.orgin[1],  # 20
        ]
        self.agent = self.canvas.create_polygon(self.points, fill="green")
        wall_center = []
        self.wall = []
        for i in range(MAZE_W):
            wall_center.append([])
            self.wall.append([])
        for i in range(MAZE_W):
            wall_center[i] = np.array([(MAZE_H * UNIT) / 2, ((i) * UNIT) + UNIT / 2])
            self.wall[i] = self.canvas.create_rectangle(
                wall_center[i][0] - 20, wall_center[i][1] - 20,
                wall_center[i][0] + 20, wall_center[i][1] + 20,
                fill='grey')
        self.hell1_center = np.array([100, 20])
        self.hell1 = self.canvas.create_oval(
            self.hell1_center[0] - 15, self.hell1_center[1] - 15,
            self.hell1_center[0] + 15, self.hell1_center[1] + 15,
            fill='black')
        self.hell2_center = np.array([60, 100])
        self.hell2 = self.canvas.create_oval(
            self.hell2_center[0] - 15, self.hell2_center[1] - 15,
            self.hell2_center[0] + 15, self.hell2_center[1] + 15,
            fill='black')
        # self.canvas.create_bitmap((40 , 40), bitmap='error')
        self.danger = 1
        self.canvas.pack()
    #reset agent location
    def reset(self):
        """Recreate both agents at their start cells.

        :return: (agent1 coords, observer agent coords)
        """
        self.update()
        time.sleep(0.5)
        self.canvas.delete(self.agent1)
        self.canvas.delete(self.agent)
        self.orgin1 = np.array([20, 20])
        # agent1: pentagon pointing down
        self.points1 = [
            # top-left
            self.orgin1[0] - 15,  # 5
            self.orgin1[1] - 15,  # 5
            # top-right
            self.orgin1[0] + 15,  # 35
            self.orgin1[1] - 15,  # 5
            # bottom-right+
            self.orgin1[0] + 15,  # 35
            self.orgin1[1],  # 20
            # apex
            self.orgin1[0],  # 20
            self.orgin1[1] + 15,  # 35
            # bottom-left+
            self.orgin1[0] - 15,  # 5
            self.orgin1[1],  # 20
        ]
        self.agent1 = self.canvas.create_polygon(self.points1, outline='black',fill="blue")
        self.orgin = np.array([MAZE_H * UNIT - UNIT / 2, 20])
        # observer agent: pentagon pointing down
        self.points = [
            # top-left
            self.orgin[0] - 15,  # 5
            self.orgin[1] - 15,  # 5
            # top-right
            self.orgin[0] + 15,  # 35
            self.orgin[1] - 15,  # 5
            # bottom-right+
            self.orgin[0] + 15,  # 35
            self.orgin[1],  # 20
            # apex
            self.orgin[0],  # 20
            self.orgin[1] + 15,  # 35
            # bottom-left+
            self.orgin[0] - 15,  # 5
            self.orgin[1],  # 20
        ]
        self.agent = self.canvas.create_polygon(self.points, fill="green")
        return self.canvas.coords(self.agent1),self.canvas.coords(self.agent)
    def step(self, s, action, pain):
        """Execute one action for agent1.

        :param s: previous state (re-read from the canvas, so effectively
            ignored)
        :param action: 0=up, 1=down, 2=right, 3=left
        :param pain: 1 draws agent1 red, 0 blue
        :return: (displacement, predicted_displacement, new_coords,
            new_centre)
        """
        s1 = self.canvas.coords(self.agent1)
        self.centre1 = [(s1[4] + s1[8]) / 2, (s1[5] + s1[9]) / 2]
        # danger zones: subsequent actions get inverted
        if all(self.centre1 == self.hell1_center):
            self.action_hurt1 = 1
        if all(self.centre1 == self.hell2_center):
            self.action_hurt1 = 1
        # once the wall is open, positions in front of the gap get pushed
        # through (checked for several candidate cells)
        self.oval_center111 = np.array([(MAZE_H * UNIT) / 2, ((MAZE_W + 4) * UNIT) / 2 - UNIT / 2])
        if all(self.centre1 ==self.oval_center111):
            move = np.array([80, 0])
            self.canvas.move(self.agent1, move[0], move[1])
            s1 = self.canvas.coords(self.agent1)
            self.render()
        self.oval_center111 = np.array([(MAZE_H * UNIT) / 2 - UNIT, ((MAZE_W + 4) * UNIT) / 2 - UNIT / 2])
        if all(self.centre1 ==self.oval_center111):
            move = np.array([80, 0])
            self.canvas.move(self.agent1, move[0], move[1])
            s1 = self.canvas.coords(self.agent1)
            self.render()
        self.oval_center111 = np.array([(MAZE_H * UNIT) / 2 - UNIT*2, ((MAZE_W + 4) * UNIT) / 2 - UNIT / 2])
        if all(self.centre1 == self.oval_center111):
            move = np.array([80, 0])
            self.canvas.move(self.agent1, move[0], move[1])
            s1 = self.canvas.coords(self.agent1)
            self.render()
        self.oval_center111 = np.array([(MAZE_H * UNIT) / 2 - UNIT*3, ((MAZE_W + 4) * UNIT) / 2 - UNIT / 2])
        if all(self.centre1 == self.oval_center111):
            move = np.array([80, 0])
            self.canvas.move(self.agent1, move[0], move[1])
            s1 = self.canvas.coords(self.agent1)
            self.render()
        self.oval_center111 = np.array([(MAZE_H * UNIT) / 2 - UNIT*4, ((MAZE_W + 4) * UNIT) / 2 - UNIT / 2])
        if all(self.centre1 == self.oval_center111):
            move = np.array([80, 0])
            self.canvas.move(self.agent1, move[0], move[1])
            s1 = self.canvas.coords(self.agent1)
            self.render()
        # redraw agent1 facing the commanded direction
        self.canvas.delete(self.agent1)
        self.centre = [(s1[4] + s1[8]) / 2, (s1[5] + s1[9]) / 2]
        if action==0:
            # pentagon pointing up
            self.points0 = [
                # bottom-right
                self.centre[0] + 15,  # 35
                self.centre[1] + 15,  # 35
                # bottom-left
                self.centre[0] - 15,  # 5
                self.centre[1] + 15,  # 35
                # top-left+
                self.centre[0] - 15,  # 5
                self.centre[1],  # 20
                # apex
                self.centre[0],  # 20
                self.centre[1] - 15,  # 5
                # top-right+
                self.centre[0] + 15,  # 35
                self.centre[1],  # 20
            ]
            if pain==0:
                color="blue"
            if pain == 1:
                color = "red"
            self.agent1 = self.canvas.create_polygon(self.points0, fill=color)
        if action==1:
            # pentagon pointing down
            self.points1 = [
                # top-left
                self.centre[0] - 15,  # 5
                self.centre[1] - 15,  # 5
                # top-right
                self.centre[0] + 15,  # 35
                self.centre[1] - 15,  # 5
                # bottom-right+
                self.centre[0] + 15,  # 35
                self.centre[1],  # 20
                # apex
                self.centre[0],  # 20
                self.centre[1] + 15,  # 35
                # bottom-left+
                self.centre[0] - 15,  # 5
                self.centre[1],  # 20
            ]
            if pain==0:
                color="blue"
            if pain == 1:
                color = "red"
            self.agent1 = self.canvas.create_polygon(self.points1, fill=color)
        if action==2:
            # pentagon pointing right
            self.points2 = [
                # bottom-left
                self.centre[0] - 15,  # 5
                self.centre[1] + 15,  # 35
                # top-left
                self.centre[0] - 15,  # 5
                self.centre[1] - 15,  # 5
                # top-right+
                self.centre[0],  # 20
                self.centre[1] - 15,  # 5
                # apex
                self.centre[0] + 15,  # 35
                self.centre[1],  # 20
                # bottom-right+
                self.centre[0],  # 20
                self.centre[1] + 15,  # 35
            ]
            if pain==0:
                color="blue"
            if pain == 1:
                color = "red"
            self.agent1 = self.canvas.create_polygon(self.points2, fill=color)
        if action==3:
            # pentagon pointing left
            self.points3 = [
                # top-right
                self.centre[0] + 15,  # 20+15
                self.centre[1] - 15,  # 20-15
                # bottom-right
                self.centre[0] + 15,  # 20+15
                self.centre[1] + 15,  # 20+15
                # bottom-left+
                self.centre[0],  # 20
                self.centre[1] + 15,  # 20+15
                # apex
                self.centre[0] - 15,  # 20-15
                self.centre[1],  # 20
                # top-left+
                self.centre[0],  # 20
                self.centre[1] - 15,  # 20-15
            ]
            if pain==0:
                color="blue"
            if pain == 1:
                color = "red"
            self.agent1 = self.canvas.create_polygon(self.points3, fill=color)
        s1 = self.canvas.coords(self.agent1)
        self.render()# show the current action command
        self.centre1 = [(s1[4] + s1[8]) / 2, (s1[5] + s1[9]) / 2]
        if self.centre1[0] > (9 / 2) * 40:
            # in the right half of the maze the hurt effect wears off
            self.action_hurt1 = 0
        # whether hurt
        if self.action_hurt1 == 0:
            true_action = action
        else:
            # hurt: the executed action is the opposite of the command
            if action == 0:
                true_action = 1
            if action == 1:
                true_action = 0
            if action == 2:
                true_action = 3
            if action == 3:
                true_action = 2
        base_action = np.array([0, 0])
        # predict next state
        self.centre1 = [(s1[4] + s1[8]) / 2, (s1[5] + s1[9]) / 2]
        pre_displacement1 = np.array([0, 0])
        if self.centre1[0] <= ((MAZE_H - 1) / 2 + 1) * UNIT: # 120
            if action == 0: # up
                if self.centre1[1] > UNIT:
                    pre_displacement1 = np.array([0, -40])
            elif action == 1: # down
                if self.centre1[1] < (MAZE_W - 1) * UNIT:
                    pre_displacement1 = np.array([0, 40])
            elif action == 2: # right
                if self.centre1[0] < ((MAZE_H - 1) / 2 - 1) * UNIT:
                    pre_displacement1 = np.array([40, 0])
            elif action == 3: # left
                if self.centre1[0] > UNIT:
                    pre_displacement1 = np.array([-40, 0])
        else:
            if action == 0: # up
                if self.centre1[1] > UNIT:
                    pre_displacement1 = np.array([0, -40])
            elif action == 1: # down
                if self.centre1[1] < (MAZE_W - 1) * UNIT:
                    pre_displacement1 = np.array([0, 40])
            elif action == 2: # right
                if self.centre1[0] < (MAZE_H - 1) * UNIT:
                    pre_displacement1 = np.array([40, 0])
            elif action == 3: # left
                if self.centre1[0] > ((MAZE_H - 1) / 2 + 2) * UNIT:
                    pre_displacement1 = np.array([-40, 0])
        # true next state
        displacement1 = np.array([0, 0])
        if self.centre1[0] <= ((MAZE_H - 1) / 2 + 1) * UNIT:
            if true_action == 0: # up
                if self.centre1[1] > UNIT:
                    displacement1=np.array([0,-40])
            elif true_action == 1: # down
                if self.centre1[1] < (MAZE_W - 1) * UNIT:
                    displacement1=np.array([0,40])
            elif true_action == 2: # right
                if self.centre1[0] < ((MAZE_H - 1) / 2 - 1) * UNIT:
                    displacement1=np.array([40,0])
            elif true_action == 3: # left
                if self.centre1[0] > UNIT:
                    displacement1=np.array([-40,0])
        else:
            if true_action == 0: # up
                if self.centre1[1] > UNIT:
                    displacement1=np.array([0,-40])
            elif true_action == 1: # down
                if self.centre1[1] < (MAZE_W - 1) * UNIT:
                    displacement1=np.array([0,40])
            elif true_action == 2: # right
                if self.centre1[0] < (MAZE_H - 1) * UNIT:
                    displacement1=np.array([40,0])
            elif true_action == 3: # left
                if self.centre1[0] > ((MAZE_H - 1) / 2 + 2) * UNIT:
                    displacement1=np.array([-40,0])
        self.canvas.move(self.agent1, displacement1[0], displacement1[1])
        s1_ = self.canvas.coords(self.agent1)
        sss = [(s1_[4] + s1_[8]) / 2, (s1_[5] + s1_[9]) / 2]
        return displacement1, pre_displacement1,s1_,sss
    def agent_help(self):
        """Move the observer agent toward the switch; on the switch,
        open the wall.

        :return: the observer agent's coordinates after the move
        """
        s = self.canvas.coords(self.agent)
        self.centre2= [(s[4] + s[8]) / 2, (s[5] + s[9]) / 2]
        if all(self.centre2 == self.oval_center):
            self.canvas.delete(self.wall[3])
            self.render()
            self.open_door=1
        else:
            self.canvas.move(self.agent, -40, 0)  # move agent
            self.render()
            self.canvas.move(self.agent, -40, 0)
            self.render()
            self.canvas.move(self.agent, -40, 0)
            self.render()
            self.canvas.move(self.agent, 0, 40)
            self.render()
        s_ = self.canvas.coords(self.agent)  # next state
        return s_
    def _set_danger(self):
        """Create two extra danger rectangles and enable the danger logic.

        Fixed: was calling the non-existent ``create_agent1angle``.
        """
        hell1_center = np.array([140, 60])
        self.hell1 = self.canvas.create_rectangle(
            hell1_center[0] - 15, hell1_center[1] - 15,
            hell1_center[0] + 15, hell1_center[1] + 15,
            fill='black')
        hell2_center = np.array([100, 140])
        self.hell2 = self.canvas.create_rectangle(
            hell2_center[0] - 15, hell2_center[1] - 15,
            hell2_center[0] + 15, hell2_center[1] + 15,
            fill='black')
        # self.canvas.create_bitmap((40 , 40), bitmap='error')
        self.canvas.pack()
        self.danger=1
    def _set_wall(self):
        """Create the column of grey wall rectangles.

        Fixed: was calling the non-existent ``create_agent1angle``.
        """
        wall_center=[]
        self.wall=[]
        for i in range(MAZE_W):
            wall_center.append([])
            self.wall.append([])
        for i in range(MAZE_W):
            wall_center[i]=np.array([(MAZE_H*UNIT)/2,((i)*UNIT)+UNIT/2])
            self.wall[i] = self.canvas.create_rectangle(
                wall_center[i][0] - 20, wall_center[i][1] - 20,
                wall_center[i][0] + 20, wall_center[i][1] + 20,
                fill='grey')
        self.canvas.pack()
    def generate_expression1(self,pain1):
        """Recolour agent1: red when in pain, blue otherwise."""
        if pain1==1:
            self.canvas.itemconfig(self.agent1, fill="red", outline='black')
            self.canvas.pack()
        if pain1 ==0:
            self.canvas.itemconfig(self.agent1, fill="blue", outline='black')
            self.canvas.pack()
    def render(self):
        """Refresh the window (with a short delay for visualisation)."""
        time.sleep(0.2)
        self.update()
================================================
FILE: examples/Social_Cognition/mirror_test/README.md
================================================
# Mirror Test
The mirror_test.py implements the core code of the Multi-Robots Mirror Self-Recognition Test in "Toward Robot Self-Consciousness (II): Brain-Inspired Robot Bodily Self Model for Self-Recognition".
The experiment is: three robots with identical appearance move their arms randomly in front of the mirror at the same time.
In the training stage, according to the spiking time difference of neurons in IPLM and IPLV, the robot learns the correlations between its self-generated actions and the resulting visual feedback through the spike timing dependent plasticity (STDP) mechanism.
In the test stage, the robot can predict the visual feedback generated by its arm movement according to the training results. With the InsulaNet, the robot can identify which mirror image belongs to it.
In the results, Motion Detection shows the outcome of visual detection, and Motion Prediction shows the visual feedback the robot predicts for its own movement. The red line in the figure indicates that the robot has determined that the corresponding mirror image belongs to itself.
Differences from the original article:
Since there is no motion error under the simulation conditions, the theta_threshold is set to zero.
### Citation
If you find this package helpful, please consider citing the following papers:
```BibTex
@article{zeng2018toward,
title={Toward robot self-consciousness (ii): brain-inspired robot bodily self model for self-recognition},
author={Zeng, Yi and Zhao, Yuxuan and Bai, Jun and Xu, Bo},
journal={Cognitive Computation},
volume={10},
number={2},
pages={307--320},
year={2018},
publisher={Springer}
}
@misc{https://doi.org/10.48550/arxiv.2207.08533,
doi = {10.48550/ARXIV.2207.08533},
url = {https://arxiv.org/abs/2207.08533},
author = {Zeng, Yi and Zhao, Dongcheng and Zhao, Feifei and Shen, Guobin and Dong, Yiting and Lu, Enmeng and Zhang, Qian and Sun, Yinqian and Liang, Qian and Zhao, Yuxuan and Zhao, Zhuoya and Fang, Hongjian and Wang, Yuwei and Li, Yang and Liu, Xin and Du, Chengcheng and Kong, Qingqun and Ruan, Zizhe and Bi, Weida},
title = {BrainCog: A Spiking Neural Network based Brain-inspired Cognitive Intelligence Engine for Brain-inspired AI and Brain Simulation},
publisher = {arXiv},
year = {2022},
}
```
================================================
FILE: examples/Social_Cognition/mirror_test/mirror_test.py
================================================
from braincog.base.brainarea.Insula import *
from braincog.base.brainarea.IPL import *
from braincog.base.learningrule.STDP import *
from braincog.base.node.node import *
from braincog.base.connection.CustomLinear import *
import random
import numpy as np
import torch
import os
import sys
from torch import nn
from torch.nn import Parameter
import abc
import math
from abc import ABC
import numpy as np
import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
import matplotlib.pyplot as plt
from braincog.base.strategy.surrogate import *
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
if __name__ == "__main__":
"""
Set the number of neurons, and each neuron represents unique motion information (such as angle)
"""
# number of neurons
num_neuron = 5
num_vPMC = num_neuron
num_STS = num_neuron
num_IPLM = num_neuron
num_IPLV = num_neuron
num_Insula = num_neuron
"""
Setting the network structure and the initial weight of IPL
"""
# IPLNet
# connection
connection = []
# vPMC-IPLM
con_matrix0 = torch.eye(num_IPLM, dtype=torch.float) * 2.5
connection.append(CustomLinear(con_matrix0))
# STS-IPLV
con_matrix1 = torch.eye(num_IPLV, dtype=torch.float) * 2.5
connection.append(CustomLinear(con_matrix1))
# IPLM-IPLV
con_matrix2 = torch.zeros([num_IPLM, num_IPLV], dtype=torch.float)
connection.append(CustomLinear(con_matrix2))
IPL = IPLNet(connection)
print("IPL Connection (Before training):", connection[2].weight)
"""
Setting the network structure and the initial weight of Insula
"""
# InsulaNet
# connection
Insula_connection = []
# IPLV-Insula
con_matrix0 = torch.eye(num_IPLM, dtype=torch.float) * 2
Insula_connection.append(CustomLinear(con_matrix0))
# STS-Insula
con_matrix1 = torch.eye(num_IPLV, dtype=torch.float) * 2
Insula_connection.append(CustomLinear(con_matrix1))
Insula = InsulaNet(Insula_connection)
"""
Training process
:param train_num: number of movements during training
"""
# Train
for vPMC_Angel in range(1, num_vPMC + 1):
# vPMC Angle
vPMC_Angel_v = torch.zeros([1, num_vPMC], dtype=torch.float)
vPMC_Angel_v[0, vPMC_Angel - 1] = 20
dwIPL_temp = torch.zeros([num_IPLM, num_IPLV], dtype=torch.float)
train_num = 10
for i_train in range(train_num):
# STS 1
STS_Angel_1 = vPMC_Angel
for t in range(2):
vPMC_input = vPMC_Angel_v
STS_Angel_v = torch.zeros([1, num_STS], dtype=torch.float)
STS_Angel_v[0, STS_Angel_1 - 1] = 20
STS_input = STS_Angel_v
IPLV_out, dwIPL = IPL(vPMC_input, STS_input)
dwIPL_temp = dwIPL_temp + dwIPL
IPL.reset()
# STS 2
STS_Angel_2 = random.randint(1, num_neuron)
for t in range(2):
vPMC_input = vPMC_Angel_v
STS_Angel_v = torch.zeros([1, num_STS], dtype=torch.float)
STS_Angel_v[0, STS_Angel_2 - 1] = 20
STS_input = STS_Angel_v
IPLV_out, dwIPL = IPL(vPMC_input, STS_input)
dwIPL_temp = dwIPL_temp + dwIPL
IPL.reset()
# STS 3
STS_Angel_3 = random.randint(1, num_neuron)
for t in range(2):
vPMC_input = vPMC_Angel_v
STS_Angel_v = torch.zeros([1, num_STS], dtype=torch.float)
STS_Angel_v[0, STS_Angel_3 - 1] = 20
STS_input = STS_Angel_v
IPLV_out, dwIPL = IPL(vPMC_input, STS_input)
dwIPL_temp = dwIPL_temp + dwIPL
IPL.reset()
IPL.UpdateWeight(2, dwIPL_temp)
print("IPL Connection (After training):", connection[2].weight)
"""
Test process
:param move_count: number of movements during test
"""
# Test
move_count = 10
TestList_vPMC_Angel = np.random.randint(1, num_vPMC, move_count)
TestList_STS_Angel_1 = TestList_vPMC_Angel
TestList_STS_Angel_2 = np.random.randint(1, num_STS, move_count)
TestList_STS_Angel_3 = np.random.randint(1, num_STS, move_count)
TestMat_STS_Angle = np.vstack((TestList_STS_Angel_1, TestList_STS_Angel_2, TestList_STS_Angel_3))
np.random.shuffle(TestMat_STS_Angle)
TestList_IPLV_out = []
for i_test in range(move_count):
Test_vPMC_Angel = TestList_vPMC_Angel[i_test]
Test_vPMC_Angel_v = torch.zeros([1, num_vPMC], dtype=torch.float)
Test_vPMC_Angel_v[0, Test_vPMC_Angel - 1] = 20
Test_STS_Angel_v = torch.zeros([1, num_STS], dtype=torch.float)
for t in range(2):
IPL(Test_vPMC_Angel_v, Test_STS_Angel_v)
IPLV_out_f = torch.argmax(IPL.node[1].u) + 1
IPL.reset()
TestList_IPLV_out.append(IPLV_out_f.numpy().item())
confidence = [0, 0, 0]
for i in range(move_count):
theta_predict = TestList_IPLV_out[i]
theta_visual_1 = TestMat_STS_Angle[0][i]
theta_visual_2 = TestMat_STS_Angle[1][i]
theta_visual_3 = TestMat_STS_Angle[2][i]
Test_IPL_v = torch.zeros([1, num_IPLV], dtype=torch.float)
Test_IPL_v[0, theta_predict - 1] = 20
Test_STS1_v = torch.zeros([1, num_STS], dtype=torch.float)
Test_STS1_v[0, theta_visual_1 - 1] = 20
for t in range(2):
Insula(Test_IPL_v, Test_STS1_v)
if sum(sum(Insula.out_Insula)) > 0:
confidence[0] = confidence[0] + 1
Insula.reset()
Test_STS2_v = torch.zeros([1, num_STS], dtype=torch.float)
Test_STS2_v[0, theta_visual_2 - 1] = 20
for t in range(2):
Insula(Test_IPL_v, Test_STS2_v)
if sum(sum(Insula.out_Insula)) > 0:
confidence[1] = confidence[1] + 1
Insula.reset()
Test_STS3_v = torch.zeros([1, num_STS], dtype=torch.float)
Test_STS3_v[0, theta_visual_3 - 1] = 20
for t in range(2):
Insula(Test_IPL_v, Test_STS3_v)
if sum(sum(Insula.out_Insula)) > 0:
confidence[2] = confidence[2] + 1
Insula.reset()
x_0 = torch.arange(0, move_count)
x_1 = torch.arange(move_count * 1, move_count * 2)
x_2 = torch.arange(move_count * 2, move_count * 3)
color_list = ['k', 'k', 'k']
color_list[confidence.index(max(confidence))] = 'r'
plt.subplot(211)
plt.figure(1)
plt.plot(x_0, TestMat_STS_Angle[0], color=color_list[0])
plt.plot(x_1, TestMat_STS_Angle[1], color=color_list[1])
plt.plot(x_2, TestMat_STS_Angle[2], color=color_list[2])
plt.title("Motion Detection")
plt.subplot(212)
plt.plot(x_0, TestList_IPLV_out, color='r')
plt.title("Motion Prediction")
plt.tight_layout()
plt.show()
================================================
FILE: examples/Spiking-Transformers/LIFNode.py
================================================
from timm.models.layers import to_2tuple, trunc_normal_, DropPath
from timm.models.registry import register_model
from timm.models.vision_transformer import _cfg
import torch.nn.functional as F
from braincog.model_zoo.base_module import BaseModule
from braincog.base.node.node import *
from braincog.base.connection.layer import *
from braincog.base.strategy.surrogate import *
class MyBaseNode(BaseNode):
    """Base spiking node with layout adapters for spiking Transformers.

    Extends ``BaseNode`` with reshaping helpers that, in addition to the
    usual image layouts, understand the Transformer ``T B N C`` token
    layout (the added 3-D / 4-D branches below).
    """

    def __init__(self, threshold=0.5, step=4, layer_by_layer=False, mem_detach=False):
        super().__init__(threshold=threshold, step=step, layer_by_layer=layer_by_layer, mem_detach=mem_detach)

    def rearrange2node(self, inputs):
        """Unfold the time dimension out of the op output before the node update."""
        ndim = len(inputs.shape)
        if self.groups != 1:
            # time steps are interleaved into the channel dimension
            grouped = {4: 'b (c t) w h -> t b c w h', 2: 'b (c t) -> t b c'}
            if ndim not in grouped:
                raise NotImplementedError
            return rearrange(inputs, grouped[ndim], t=self.step)
        if self.layer_by_layer:
            # time steps are folded into the batch dimension; the 3-D case
            # covers the Transformer (T B) N C token layout
            flat = {
                4: '(t b) c w h -> t b c w h',
                3: '(t b) n c -> t b n c',
                2: '(t b) c -> t b c',
            }
            if ndim not in flat:
                raise NotImplementedError
            return rearrange(inputs, flat[ndim], t=self.step)
        return inputs

    def rearrange2op(self, inputs):
        """Fold the time dimension back so the next op sees its expected layout."""
        ndim = len(inputs.shape)
        if self.groups != 1:
            grouped = {5: 't b c w h -> b (c t) w h', 3: ' t b c -> b (c t)'}
            if ndim not in grouped:
                raise NotImplementedError
            return rearrange(inputs, grouped[ndim])
        if self.layer_by_layer:
            # the 4-D case covers the Transformer T B N C token layout
            flat = {
                5: 't b c w h -> (t b) c w h',
                4: ' t b n c -> (t b) n c',
                3: ' t b c -> (t b) c',
            }
            if ndim not in flat:
                raise NotImplementedError
            return rearrange(inputs, flat[ndim])
        return inputs
class MyGrad(SurrogateFunctionBase):
    """Surrogate gradient based on the ``sigmoid`` surrogate from
    ``braincog.base.strategy.surrogate``.

    :param alpha: steepness parameter passed to the surrogate
    :param requires_grad: whether ``alpha`` is learnable
    """
    def __init__(self, alpha=4., requires_grad=False):
        super().__init__(alpha, requires_grad)

    @staticmethod
    def act_fun(x, alpha):
        # Delegates to the `sigmoid` autograd surrogate; `x` is the
        # membrane potential minus the threshold.
        return sigmoid.apply(x, alpha)
class MyNode(MyBaseNode):
    """LIF neuron with a sigmoid surrogate gradient.

    :param threshold: firing threshold
    :param step: number of simulation time steps
    :param layer_by_layer: process all time steps of a layer before the next
    :param tau: membrane time constant of the leaky integration
    :param act_fun: surrogate-gradient class, or its name as a string
    :param mem_detach: detach the membrane potential from the autograd graph
    """
    def __init__(self, threshold=1., step=4, layer_by_layer=True, tau=2., act_fun=MyGrad, mem_detach=True, *args,
                 **kwargs):
        super().__init__(threshold=threshold, step=step, layer_by_layer=layer_by_layer, mem_detach=mem_detach)
        self.tau = tau
        if isinstance(act_fun, str):
            # HACK: resolves the surrogate class from its name; eval()
            # executes arbitrary code -- only pass trusted strings here.
            act_fun = eval(act_fun)
        self.act_fun = act_fun(alpha=4., requires_grad=False)

    def integral(self, inputs):
        # Leaky integration: mem decays toward the input with time constant tau.
        self.mem = self.mem + (inputs - self.mem) / self.tau

    def calc_spike(self):
        # Spike via the surrogate step, then hard-reset the membrane where a
        # spike occurred (detached so the reset is not back-propagated).
        self.spike = self.act_fun(self.mem - self.threshold)
        self.mem = self.mem * (1 - self.spike.detach())
================================================
FILE: examples/Spiking-Transformers/README.md
================================================
# Spiking Transformers Reproduced With Braincog
Here is the current Spiking Transformer code reproduced using [BrainCog](http://www.brain-cog.network/). Welcome to follow the work of BrainCog and utilize the [BrainCog framework](https://github.com/BrainCog-X/Brain-Cog) to create relevant brain-inspired AI endeavors. The works implemented here will also be merged into BrainCog Repo.
### Models
**Spikformer(ICLR 2023)**
[Zhou, Z., Zhu, Y., He, C., Wang, Y., Yan, S., Tian, Y., & Yuan, L. (2022). Spikformer: When spiking neural network meets transformer. arXiv preprint arXiv:2209.15425.](https://openreview.net/forum?id=frE4fUwz_h)

**Spike-driven Transformer (NeurIPS 2023)**
[Yao, M., Hu, J., Zhou, Z., Yuan, L., Tian, Y., Xu, B., & Li, G. (2024). Spike-driven transformer. Advances in Neural Information Processing Systems, 36.](https://proceedings.neurips.cc/paper_files/paper/2023/hash/ca0f5358dbadda74b3049711887e9ead-Abstract-Conference.html)

**Spike-driven Transformer V2(ICLR 2024)**
[Yao, M., Hu, J., Hu, T., Xu, Y., Zhou, Z., Tian, Y., ... & Li, G. (2023, October). Spike-driven Transformer V2: Meta Spiking Neural Network Architecture Inspiring the Design of Next-generation Neuromorphic Chips. In The Twelfth International Conference on Learning Representations.](https://openreview.net/forum?id=1SIBN5Xyw7)

## Models coming soon
**SpikingResFormer**
[Shi, X., Hao, Z., & Yu, Z. (2024). SpikingResformer: Bridging ResNet and Vision Transformer in Spiking Neural Networks. arXiv preprint arXiv:2403.14302.](https://arxiv.org/abs/2403.14302)
**TIM**
[Shen, S., Zhao, D., Shen, G., & Zeng, Y. (2024). TIM: An Efficient Temporal Interaction Module for Spiking Transformer. arXiv preprint arXiv:2401.11687.](https://arxiv.org/abs/2401.11687)
**SGLFormer(Frontiers in Neuroscience)**
[Zhang, H., Zhou, C., Yu, L., Huang, L., Ma, Z., Fan, X., ... & Tian, Y. (2024). SGLFormer: Spiking Global-Local-Fusion Transformer with High Performance. Frontiers in Neuroscience, 18, 1371290.](https://www.frontiersin.org/journals/neuroscience/articles/10.3389/fnins.2024.1371290/full)
**QKFormer(CVPR2024)**
[Zhou, C., Zhang, H., Zhou, Z., Yu, L., Huang, L., Fan, X., ... & Tian, Y. (2024). QKFormer: Hierarchical Spiking Transformer using QK Attention. arXiv preprint arXiv:2403.16552.](https://arxiv.org/abs/2403.16552)
## Requirements
- Braincog
- einops >= 0.4.1
- timm >= 0.5.4
## Training Examples
### Training on CIFAR10-DVS
python main.py --dataset dvsc10 --epochs 500 --batch-size 16 --seed 42 --event-size 64 --model spikformer_dvs
### Training on ImageNet
python main.py --dataset imnet --epochs 500 --batch-size 16 --seed 42 --model spikformer
================================================
FILE: examples/Spiking-Transformers/datasets.py
================================================
import os
import random
import warnings
from typing import Any, Dict, Optional, Sequence, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils
import torchvision.datasets
import torchvision.datasets as datasets
from timm.data import ImageDataset, create_loader, Mixup, FastCollateMixup, AugMixDataset
from timm.data import create_transform, distributed_sampler
from timm.data.loader import PrefetchLoader
from torchvision import transforms

try:
    import tonic
    from tonic import DiskCachedDataset
except:
    warnings.warn("tonic should be installed, 'pip install git+https://github.com/FloyedShen/tonic.git'")
from tonic import DiskCachedDataset

import braincog.datasets.ucf101_dvs
from braincog.datasets.NOmniglot.nomniglot_full import NOmniglotfull
from braincog.datasets.NOmniglot.nomniglot_nw_ks import NOmniglotNWayKShot
from braincog.datasets.NOmniglot.nomniglot_pair import NOmniglotTrainSet, NOmniglotTestSet
# Restored imports: these names are used by get_mnist_data / get_dvsg_data /
# get_dvsc10_data below but were previously commented out, causing NameError
# at call time.  Absolute paths are used because this file lives outside the
# braincog package.
from braincog.datasets.cut_mix import CutMix, EventMix, MixUp
from braincog.datasets.rand_aug import RandAugment
from braincog.datasets.utils import dvs_channel_check_expend, rescale
# Still unresolved: CIFAR10Policy / Cutout -- the module referenced by the
# original commented import does not exist in this tree.
# from braincog.base.conversion.conversion import CIFAR10Policy, Cutout
# from .event_drop import event_drop
# Channel statistics of DVS-CIFAR10 frames (the name suggests 16-bin frames --
# TODO confirm); used by the commented-out Normalize in the DVS pipelines below.
DVSCIFAR10_MEAN_16 = [0.3290, 0.4507]
DVSCIFAR10_STD_16 = [1.8398, 1.6549]
# Root directory under which every dataset is stored / downloaded.
DATA_DIR = '/data/datasets'
# Default center-crop percentage for evaluation.
DEFAULT_CROP_PCT = 0.875
# ImageNet normalization statistics (torchvision convention).
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
# Inception-style normalization (maps pixels to [-1, 1]).
IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
# DPN normalization statistics.
IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255)
IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3)
# CIFAR-10 / CIFAR-100 per-channel mean and std.
CIFAR10_DEFAULT_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_DEFAULT_STD = (0.2023, 0.1994, 0.2010)
CIFAR100_DEFAULT_MEAN = (0.5071, 0.4867, 0.4408)
CIFAR100_DEFAULT_STD = (0.2675, 0.2565, 0.2761)
def unpack_mix_param(args):
    """Extract the mix-augmentation hyper-parameters from ``args``.

    Uses ``dict.get`` with defaults instead of repeated ``k in args`` tests
    (same behavior, one lookup per key).

    :param args: dict of keyword options (typically the ``**kwargs`` of a
        ``get_*_data`` function)
    :return: tuple ``(mix_up, cut_mix, event_mix, beta, prob, num,
        num_classes, noise, gaussian_n)``
    """
    mix_up = args.get('mix_up', False)
    cut_mix = args.get('cut_mix', False)
    event_mix = args.get('event_mix', False)
    beta = args.get('beta', 1.)
    prob = args.get('prob', .5)
    num = args.get('num', 1)
    num_classes = args.get('num_classes', 10)
    noise = args.get('noise', 0.)
    gaussian_n = args.get('gaussian_n', None)
    return mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n
def build_transform(is_train, img_size):
    """Build the augmentation pipeline for static image datasets.

    :param is_train: whether the transform is for the training split
    :param img_size: output (square) image size
    :return: the composed transform pipeline
    """
    needs_resize = img_size > 32
    if is_train:
        # timm's standard ImageNet-style training transform
        pipeline = create_transform(
            input_size=img_size,
            is_training=True,
            color_jitter=0.4,
            # auto_augment='rand-m9-mstd0.5-inc1',
            interpolation='bicubic',
            # re_prob=0.25,
            # re_mode='pixel',
            # re_count=1,
        )
        if not needs_resize:
            # small images: padded random crop instead of resized crop
            pipeline.transforms[0] = transforms.RandomCrop(
                img_size, padding=4)
        return pipeline

    # evaluation pipeline
    steps = []
    if needs_resize:
        # keep the standard 256/224 resize-to-crop ratio
        steps.append(transforms.Resize(int((256 / 224) * img_size), interpolation=3))
    steps.append(transforms.CenterCrop(img_size))
    steps.append(transforms.ToTensor())
    if img_size > 32:
        steps.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
    else:
        steps.append(transforms.Normalize(CIFAR10_DEFAULT_MEAN, CIFAR10_DEFAULT_STD))
    return transforms.Compose(steps)
def build_dataset(is_train, img_size, dataset, path, same_da=False):
    """Build an augmented torchvision dataset.

    :param is_train: whether to load the training split
    :param img_size: output image size
    :param dataset: dataset name ('CIFAR10' or 'CIFAR100')
    :param path: dataset root path
    :param same_da: use the test-set augmentation for the training set
    :return: (dataset, number of classes)
    """
    # same_da forces the evaluation transform even on the training split
    transform = build_transform(False if same_da else is_train, img_size)
    if dataset == 'CIFAR10':
        return datasets.CIFAR10(
            path, train=is_train, transform=transform, download=True), 10
    if dataset == 'CIFAR100':
        return datasets.CIFAR100(
            path, train=is_train, transform=transform, download=True), 100
    raise NotImplementedError
class MNISTData(object):
    """
    Load MNIST datasets and expose train/test DataLoaders.
    """

    def __init__(self,
                 data_path: str,
                 batch_size: int,
                 train_trans: Sequence[torch.nn.Module] = None,
                 test_trans: Sequence[torch.nn.Module] = None,
                 pin_memory: bool = True,
                 drop_last: bool = True,
                 shuffle: bool = True,
                 ) -> None:
        self._data_path = data_path
        self._batch_size = batch_size
        self._pin_memory = pin_memory
        self._drop_last = drop_last
        self._shuffle = shuffle
        # Compose the user-supplied transform lists; None means "no transform".
        self._train_transform = transforms.Compose(train_trans) if train_trans else None
        self._test_transform = transforms.Compose(test_trans) if test_trans else None

    def get_data_loaders(self):
        """Build and return ``(train_loader, test_loader)`` for MNIST."""
        print('Batch size: ', self._batch_size)
        train_set = datasets.MNIST(root=self._data_path, train=True, transform=self._train_transform, download=True)
        test_set = datasets.MNIST(root=self._data_path, train=False, transform=self._test_transform, download=True)
        train_loader = torch.utils.data.DataLoader(
            train_set, batch_size=self._batch_size,
            pin_memory=self._pin_memory, drop_last=self._drop_last, shuffle=self._shuffle
        )
        # the test loader never shuffles and never drops the last batch
        test_loader = torch.utils.data.DataLoader(
            test_set, batch_size=self._batch_size,
            pin_memory=self._pin_memory, drop_last=False
        )
        return train_loader, test_loader

    def get_standard_data(self):
        """Install the canonical MNIST normalization (plus a padded random
        crop for training) and return the loaders."""
        mean, std = 0.1307, 0.3081
        self._train_transform = transforms.Compose([
            transforms.RandomCrop(28, padding=4),
            transforms.ToTensor(),
            transforms.Normalize((mean,), (std,)),
        ])
        self._test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((mean,), (std,)),
        ])
        return self.get_data_loaders()
def get_mnist_data(batch_size, num_workers=8, same_da=False, **kwargs):
    """
    Get the MNIST dataset.
    http://data.pymvpa.org/datasets/mnist/
    :param batch_size: batch size
    :param num_workers: number of data-loading workers
    :param same_da: use the test-set augmentation for the training set
    :param kwargs: ``skip_norm=True`` replaces normalization with ``rescale``
    :return: (train loader, test loader, mixup_active, mixup_fn)
    """
    MNIST_MEAN = 0.1307
    MNIST_STD = 0.3081
    if 'skip_norm' in kwargs and kwargs['skip_norm'] is True:
        # NOTE(review): `rescale` is undefined in this file -- the import
        # `from .utils import ... rescale` is commented out at the top, so
        # this branch raises NameError; restore the import
        # (braincog.datasets.utils) before using skip_norm.
        train_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(rescale)
        ])
        test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(rescale)
        ])
    else:
        train_transform = transforms.Compose([transforms.RandomCrop(28, padding=4),
                                              # transforms.RandomRotation(10),
                                              transforms.ToTensor(),
                                              transforms.Normalize((MNIST_MEAN,), (MNIST_STD,))])
        test_transform = transforms.Compose([transforms.ToTensor(),
                                             transforms.Normalize((MNIST_MEAN,), (MNIST_STD,))])
    train_datasets = datasets.MNIST(
        root=DATA_DIR, train=True, transform=test_transform if same_da else train_transform, download=True)
    test_datasets = datasets.MNIST(
        root=DATA_DIR, train=False, transform=test_transform, download=True)
    train_loader = torch.utils.data.DataLoader(
        train_datasets, batch_size=batch_size,
        pin_memory=True, drop_last=True, shuffle=True, num_workers=num_workers
    )
    test_loader = torch.utils.data.DataLoader(
        test_datasets, batch_size=batch_size,
        pin_memory=True, drop_last=False, num_workers=num_workers
    )
    return train_loader, test_loader, False, None
def get_fashion_data(batch_size, num_workers=8, same_da=False, **kwargs):
    """
    Get the Fashion-MNIST dataset.
    http://arxiv.org/abs/1708.07747
    :param batch_size: batch size
    :param num_workers: number of data-loading workers
    :param same_da: use the test-set augmentation for the training set
    :param kwargs:
    :return: (train loader, test loader, mixup_active, mixup_fn)
    """
    eval_transform = transforms.Compose([transforms.ToTensor()])
    aug_transform = transforms.Compose([
        transforms.RandomCrop(28, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
    ])
    chosen_train = eval_transform if same_da else aug_transform
    train_set = datasets.FashionMNIST(
        root=DATA_DIR, train=True, transform=chosen_train, download=True)
    test_set = datasets.FashionMNIST(
        root=DATA_DIR, train=False, transform=eval_transform, download=True)
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size,
        pin_memory=True, drop_last=True, shuffle=True, num_workers=num_workers
    )
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=batch_size,
        pin_memory=True, drop_last=False, num_workers=num_workers
    )
    return train_loader, test_loader, False, None
def get_cifar10_data(batch_size, num_workers=8, same_da=False, **kwargs):
    """
    Get the CIFAR10 dataset.
    https://www.cs.toronto.edu/~kriz/cifar.html
    :param batch_size: batch size
    :param num_workers: number of data-loading workers
    :param same_da: unused in this variant; kept for interface consistency
    :param kwargs:
    :return: (train loader, test loader, mixup_active, mixup_fn)
    """
    normalize = transforms.Normalize(CIFAR10_DEFAULT_MEAN, CIFAR10_DEFAULT_STD)
    # NOTE(review): CIFAR10Policy and Cutout are undefined in this file --
    # their import is commented out at the top and the referenced module
    # (braincog.base.conversion.conversion) is absent from this tree, so
    # this function raises NameError until those helpers are provided.
    transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(),
                                          CIFAR10Policy(),
                                          transforms.ToTensor(),
                                          Cutout(n_holes=1, length=16),
                                          normalize])
    transform_test = transforms.Compose([transforms.ToTensor(), normalize])
    train_dataset = datasets.CIFAR10(root=DATA_DIR, train=True, download=True, transform=transform_train)
    test_dataset = datasets.CIFAR10(root=DATA_DIR, train=False, download=True, transform=transform_test)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size,
        shuffle=True, num_workers=num_workers,
        pin_memory=True
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size,
        shuffle=False, num_workers=num_workers,
        pin_memory=True
    )
    return train_loader, test_loader, None, None
def get_cifar100_data(batch_size, num_workers=8, same_data=False, *args, **kwargs):
    """
    Get the CIFAR100 dataset.
    https://www.cs.toronto.edu/~kriz/cifar.html
    :param batch_size: batch size
    :param num_workers: number of data-loading workers
    :param same_data: unused in this variant; kept for interface consistency
    :param kwargs:
    :return: (train loader, test loader, mixup_active, mixup_fn)
    """
    normalize = transforms.Normalize(CIFAR100_DEFAULT_MEAN, CIFAR100_DEFAULT_STD)
    # NOTE(review): CIFAR10Policy and Cutout are undefined in this file --
    # their import is commented out at the top and the referenced module
    # (braincog.base.conversion.conversion) is absent from this tree, so
    # this function raises NameError until those helpers are provided.
    transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(),
                                          CIFAR10Policy(),
                                          transforms.ToTensor(),
                                          Cutout(n_holes=1, length=16),
                                          normalize])
    transform_test = transforms.Compose([transforms.ToTensor(), normalize])
    train_dataset = datasets.CIFAR100(root=DATA_DIR, train=True, download=True, transform=transform_train)
    test_dataset = datasets.CIFAR100(root=DATA_DIR, train=False, download=True, transform=transform_test)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size,
        shuffle=True, num_workers=num_workers,
        pin_memory=True
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size,
        shuffle=False, num_workers=num_workers,
        pin_memory=True
    )
    return train_loader, test_loader, None, None
def get_imnet_data(args, _logger, data_config, num_aug_splits, **kwargs):
    """
    Get the ImageNet (ILSVRC2012) dataset.
    http://arxiv.org/abs/1409.0575
    :param args: runtime arguments (batch size, augmentation flags, workers, ...)
    :param _logger: logger used to report missing data folders
    :param data_config: resolved data configuration (input size, mean/std, crop pct, ...)
    :param num_aug_splits: number of augmentation splits
    :param kwargs:
    :return: (train loader, test loader, mixup_active, mixup_fn)
    """
    train_dir = os.path.join(DATA_DIR, 'ILSVRC2012/train')
    if not os.path.exists(train_dir):
        _logger.error(
            'Training folder does not exist at: {}'.format(train_dir))
        exit(1)
    dataset_train = ImageDataset(train_dir)
    # Mixup / cutmix and AugMix-split handling is intentionally disabled here.
    # collate_fn = None
    # mixup_fn = None
    # mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    # if mixup_active:
    #     mixup_args = dict(
    #         mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
    #         prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
    #         label_smoothing=args.smoothing, num_classes=args.num_classes)
    #     if args.prefetcher:
    #         # collate conflict (need to support deinterleaving in collate mixup)
    #         assert not num_aug_splits
    #         collate_fn = FastCollateMixup(**mixup_args)
    #     else:
    #         mixup_fn = Mixup(**mixup_args)
    # if num_aug_splits > 1:
    #     dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
    train_interpolation = args.train_interpolation
    if args.no_aug or not train_interpolation:
        # fall back to the dataset's default interpolation when augmentation is off
        train_interpolation = data_config['interpolation']
    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        no_aug=args.no_aug,
        # re_prob=args.reprob,
        # re_mode=args.remode,
        # re_count=args.recount,
        # re_split=args.resplit,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        # vflip=args.vflip,
        # color_jitter=args.color_jitter,
        # auto_augment=args.aa,
        # num_aug_splits=num_aug_splits,
        interpolation=train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        # collate_fn=collate_fn,
        pin_memory=args.pin_mem,
        # use_multi_epochs_loader=args.use_multi_epochs_loader
    )
    # Accept either 'val' or 'validation' as the evaluation folder name.
    eval_dir = os.path.join(DATA_DIR, 'ILSVRC2012/val')
    if not os.path.isdir(eval_dir):
        eval_dir = os.path.join(DATA_DIR, 'ILSVRC2012/validation')
        if not os.path.isdir(eval_dir):
            _logger.error(
                'Validation folder does not exist at: {}'.format(eval_dir))
            exit(1)
    dataset_eval = ImageDataset(eval_dir)
    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size_multiplier * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
    )
    return loader_train, loader_eval, None, None
def get_dvsg_data(batch_size, step, **kwargs):
    """
    Get the DVS Gesture dataset.
    DOI: 10.1109/CVPR.2017.781
    :param batch_size: batch size
    :param step: number of simulation time steps (event frames per sample)
    :param kwargs: optional: ``size`` (frame resolution, default 48),
        ``rand_aug``/``randaug_n``/``randaug_m`` (RandAugment settings),
        plus the mix-augmentation options read by ``unpack_mix_param``
    :return: (train loader, test loader, mixup_active, mixup_fn)
    """
    sensor_size = tonic.datasets.DVSGesture.sensor_size
    size = kwargs['size'] if 'size' in kwargs else 48
    # Event-level transforms: bin the raw events into `step` time frames.
    train_transform = transforms.Compose([
        # tonic.transforms.Denoise(filter_time=10000),
        # tonic.transforms.DropEvent(p=0.1),
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step),
    ])
    test_transform = transforms.Compose([
        # tonic.transforms.Denoise(filter_time=10000),
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step),
    ])
    train_dataset = tonic.datasets.DVSGesture(os.path.join(DATA_DIR, 'DVS/DVSGesture'),
                                              transform=train_transform, train=True)
    test_dataset = tonic.datasets.DVSGesture(os.path.join(DATA_DIR, 'DVS/DVSGesture'),
                                             transform=test_transform, train=False)
    # Frame-level transforms applied to the cached tensors.
    # NOTE(review): `dvs_channel_check_expend` (and `RandAugment`, `CutMix`,
    # `EventMix`, `MixUp` below) are undefined in this file -- their imports
    # are commented out at the top; restore them before calling this function.
    train_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
        lambda x: dvs_channel_check_expend(x),
        transforms.RandomCrop(size, padding=size // 12),
        # lambda x: event_drop(x),
        # transforms.RandomHorizontalFlip(),
        # transforms.RandomRotation(15)
    ])
    test_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
        lambda x: dvs_channel_check_expend(x),
    ])
    if 'rand_aug' in kwargs.keys():
        if kwargs['rand_aug'] is True:
            n = kwargs['randaug_n']
            m = kwargs['randaug_m']
            train_transform.transforms.insert(2, RandAugment(m=m, n=n))
    # if 'temporal_flatten' in kwargs.keys():
    #     if kwargs['temporal_flatten'] is True:
    #         train_transform.transforms.insert(-1, lambda x: temporal_flatten(x))
    #         test_transform.transforms.insert(-1, lambda x: temporal_flatten(x))
    # Cache the binned frames on disk so later epochs skip event binning.
    train_dataset = DiskCachedDataset(train_dataset,
                                      cache_path=os.path.join(DATA_DIR, 'DVS/DVSGesture/train_cache_{}'.format(step)),
                                      transform=train_transform, num_copies=3)
    test_dataset = DiskCachedDataset(test_dataset,
                                     cache_path=os.path.join(DATA_DIR, 'DVS/DVSGesture/test_cache_{}'.format(step)),
                                     transform=test_transform, num_copies=3)
    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)
    # mixup_active is True if any of the mix augmentations is enabled
    mixup_active = cut_mix | event_mix | mix_up
    if cut_mix:
        train_dataset = CutMix(train_dataset,
                               beta=beta,
                               prob=prob,
                               num_mix=num,
                               num_class=num_classes,
                               noise=noise)
    if event_mix:
        train_dataset = EventMix(train_dataset,
                                 beta=beta,
                                 prob=prob,
                                 num_mix=num,
                                 num_class=num_classes,
                                 noise=noise,
                                 gaussian_n=gaussian_n)
    if mix_up:
        train_dataset = MixUp(train_dataset,
                              beta=beta,
                              prob=prob,
                              num_mix=num,
                              num_class=num_classes,
                              noise=noise)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size,
        pin_memory=True, drop_last=True, num_workers=8,
        shuffle=True,
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size,
        pin_memory=True, drop_last=False, num_workers=2,
        shuffle=False,
    )
    return train_loader, test_loader, mixup_active, None
def get_dvsc10_data(batch_size, step, **kwargs):
    """
    Get the DVS CIFAR10 dataset.
    http://journal.frontiersin.org/article/10.3389/fnins.2017.00309/full
    :param batch_size: batch size
    :param step: number of simulation time steps (event frames per sample)
    :param kwargs: optional: ``size`` (frame resolution, default 48),
        ``portion`` (train split fraction per class, default 0.9),
        ``rand_aug``/``randaug_n``/``randaug_m`` (RandAugment settings),
        plus the mix-augmentation options read by ``unpack_mix_param``
    :return: (train loader, test loader, mixup_active, mixup_fn)
    """
    size = kwargs['size'] if 'size' in kwargs else 48
    sensor_size = tonic.datasets.CIFAR10DVS.sensor_size
    # Event-level transforms: bin the raw events into `step` time frames.
    train_transform = transforms.Compose([
        # tonic.transforms.Denoise(filter_time=10000),
        # tonic.transforms.DropEvent(p=0.1),
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])
    test_transform = transforms.Compose([
        # tonic.transforms.Denoise(filter_time=10000),
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])
    train_dataset = tonic.datasets.CIFAR10DVS(os.path.join(DATA_DIR, 'DVS/DVS_Cifar10'), transform=train_transform)
    test_dataset = tonic.datasets.CIFAR10DVS(os.path.join(DATA_DIR, 'DVS/DVS_Cifar10'), transform=test_transform)
    # Frame-level transforms applied to the cached tensors.
    # NOTE(review): `RandAugment`, `CutMix`, `EventMix` and `MixUp` used below
    # are undefined in this file -- their imports are commented out at the
    # top; restore them before enabling those options.
    train_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
        # lambda x: TemporalShift(x, .01),
        # lambda x: drop(x, 0.15),
        # lambda x: ShearX(x, 15),
        # lambda x: ShearY(x, 15),
        # lambda x: TranslateX(x, 0.225),
        # lambda x: TranslateY(x, 0.225),
        # lambda x: Rotate(x, 15),
        # lambda x: CutoutAbs(x, 0.25),
        # lambda x: CutoutTemporal(x, 0.25),
        # lambda x: GaussianBlur(x, 0.5),
        # lambda x: SaltAndPepperNoise(x, 0.1),
        # transforms.Normalize(DVSCIFAR10_MEAN_16, DVSCIFAR10_STD_16),
        transforms.RandomCrop(size, padding=size // 12),
        transforms.RandomHorizontalFlip(),
        # lambda x: event_drop(x),
        # transforms.RandomRotation(15)
    ])
    test_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
    ])
    if 'rand_aug' in kwargs.keys():
        if kwargs['rand_aug'] is True:
            n = kwargs['randaug_n']
            m = kwargs['randaug_m']
            # print('randaug', m, n)
            train_transform.transforms.insert(2, RandAugment(m=m, n=n))
    # if 'temporal_flatten' in kwargs.keys():
    #     if kwargs['temporal_flatten'] is True:
    #         train_transform.transforms.insert(-1, lambda x: temporal_flatten(x))
    #         test_transform.transforms.insert(-1, lambda x: temporal_flatten(x))
    # Cache the binned frames on disk so later epochs skip event binning.
    train_dataset = DiskCachedDataset(train_dataset,
                                      cache_path=os.path.join(DATA_DIR, 'DVS/DVS_Cifar10/train_cache_{}'.format(step)),
                                      transform=train_transform)
    test_dataset = DiskCachedDataset(test_dataset,
                                     cache_path=os.path.join(DATA_DIR, 'DVS/DVS_Cifar10/test_cache_{}'.format(step)),
                                     transform=test_transform)
    # CIFAR10-DVS has no official split: take the first `portion` of each
    # class for training and the remainder for testing (10 classes assumed,
    # samples grouped contiguously by class).
    num_train = len(train_dataset)
    num_per_cls = num_train // 10
    indices_train, indices_test = [], []
    portion = kwargs['portion'] if 'portion' in kwargs else .9
    for i in range(10):
        indices_train.extend(
            list(range(i * num_per_cls, round(i * num_per_cls + num_per_cls * portion))))
        indices_test.extend(
            list(range(round(i * num_per_cls + num_per_cls * portion), (i + 1) * num_per_cls)))
    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)
    # mixup_active is True if any of the mix augmentations is enabled
    mixup_active = cut_mix | event_mix | mix_up
    if cut_mix:
        # print('cut_mix', beta, prob, num, num_classes)
        train_dataset = CutMix(train_dataset,
                               beta=beta,
                               prob=prob,
                               num_mix=num,
                               num_class=num_classes,
                               indices=indices_train,
                               noise=noise)
    if event_mix:
        train_dataset = EventMix(train_dataset,
                                 beta=beta,
                                 prob=prob,
                                 num_mix=num,
                                 num_class=num_classes,
                                 indices=indices_train,
                                 noise=noise,
                                 gaussian_n=gaussian_n)
    if mix_up:
        train_dataset = MixUp(train_dataset,
                              beta=beta,
                              prob=prob,
                              num_mix=num,
                              num_class=num_classes,
                              indices=indices_train,
                              noise=noise)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices_train),
        pin_memory=True, drop_last=True, num_workers=8
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices_test),
        pin_memory=True, drop_last=False, num_workers=2
    )
    return train_loader, test_loader, mixup_active, None
def get_nmnist_data(batch_size, step, **kwargs):
    """
    Build train/test dataloaders for the N-MNIST dataset.
    http://journal.frontiersin.org/Article/10.3389/fnins.2015.00437/abstract

    :param batch_size: batch size
    :param step: number of simulation time steps (event frames per sample)
    :param kwargs: optional keys:
        ``size`` -- spatial resolution after interpolation (default 48)
        ``portion`` -- per-class fraction of samples used for training (default .9)
        ``rand_aug``/``randaug_n``/``randaug_m`` -- RandAugment switch and parameters
        plus the mix-augmentation keys consumed by ``unpack_mix_param``
    :return: (train loader, test loader, mixup_active, mixup_fn)
    """
    size = kwargs['size'] if 'size' in kwargs else 48
    sensor_size = tonic.datasets.NMNIST.sensor_size

    # Stage 1: event-stream -> frame-stack conversion, applied before disk caching.
    train_transform = transforms.Compose([
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step),
    ])
    test_transform = transforms.Compose([
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step),
    ])

    train_dataset = tonic.datasets.NMNIST(os.path.join(DATA_DIR, 'DVS/NMNIST'), transform=train_transform)
    test_dataset = tonic.datasets.NMNIST(os.path.join(DATA_DIR, 'DVS/NMNIST'), transform=test_transform)

    # Stage 2: tensor-level transforms, applied on cached frames at load time.
    train_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
        transforms.RandomCrop(size, padding=size // 12),
        transforms.RandomHorizontalFlip(),
    ])
    test_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
    ])

    if kwargs.get('rand_aug') is True:
        n = kwargs['randaug_n']
        m = kwargs['randaug_m']
        # Index 2 places RandAugment after the resize and before RandomCrop.
        train_transform.transforms.insert(2, RandAugment(m=m, n=n))

    train_dataset = DiskCachedDataset(train_dataset,
                                      cache_path=os.path.join(DATA_DIR, 'DVS/NMNIST/train_cache_{}'.format(step)),
                                      transform=train_transform)
    test_dataset = DiskCachedDataset(test_dataset,
                                     cache_path=os.path.join(DATA_DIR, 'DVS/NMNIST/test_cache_{}'.format(step)),
                                     transform=test_transform)

    # Per-class holdout split over the (train) dataset.
    # NOTE(review): assumes samples are stored grouped by class in equal
    # blocks of num_per_cls -- confirm against the on-disk ordering.
    num_train = len(train_dataset)
    num_per_cls = num_train // 10
    indices_train, indices_test = [], []
    portion = kwargs['portion'] if 'portion' in kwargs else .9
    for i in range(10):
        indices_train.extend(
            list(range(i * num_per_cls, round(i * num_per_cls + num_per_cls * portion))))
        indices_test.extend(
            list(range(round(i * num_per_cls + num_per_cls * portion), (i + 1) * num_per_cls)))

    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)
    mixup_active = cut_mix | event_mix | mix_up

    if cut_mix:
        train_dataset = CutMix(train_dataset,
                               beta=beta,
                               prob=prob,
                               num_mix=num,
                               num_class=num_classes,
                               indices=indices_train,
                               noise=noise)
    if event_mix:
        train_dataset = EventMix(train_dataset,
                                 beta=beta,
                                 prob=prob,
                                 num_mix=num,
                                 num_class=num_classes,
                                 indices=indices_train,
                                 noise=noise,
                                 gaussian_n=gaussian_n)
    if mix_up:
        train_dataset = MixUp(train_dataset,
                              beta=beta,
                              prob=prob,
                              num_mix=num,
                              num_class=num_classes,
                              indices=indices_train,
                              noise=noise)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices_train),
        pin_memory=True, drop_last=True, num_workers=8
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices_test),
        pin_memory=True, drop_last=False, num_workers=2
    )
    return train_loader, test_loader, mixup_active, None
def get_NCALTECH101_data(batch_size, step, **kwargs):
    """
    Build train/test dataloaders for the N-Caltech101 dataset.
    http://journal.frontiersin.org/Article/10.3389/fnins.2015.00437/abstract

    :param batch_size: batch size
    :param step: number of simulation time steps
    :param kwargs: optional keys: ``portion`` (train split fraction, default .9),
        ``size`` (spatial resolution, default 48), RandAugment keys and the
        mix-augmentation keys consumed by ``unpack_mix_param``
    :return: (train loader, test loader, mixup_active, mixup_fn)
    """
    sensor_size = braincog.datasets.ncaltech101.NCALTECH101.sensor_size
    cls_count = braincog.datasets.ncaltech101.NCALTECH101.cls_count
    dataset_length = braincog.datasets.ncaltech101.NCALTECH101.length
    portion = kwargs['portion'] if 'portion' in kwargs else .9
    size = kwargs['size'] if 'size' in kwargs else 48

    # Per-class split: the first round(portion * count) samples of each class
    # are training data; inverse-frequency weights rebalance the classes.
    train_sample_weight = []
    train_sample_index = []
    test_sample_index = []
    train_count = 0
    offset = 0
    for cls_size in cls_count:
        weight = dataset_length / cls_size
        n_train = round(portion * cls_size)
        n_test = cls_size - n_train
        train_count += n_train
        train_sample_weight += [weight] * n_train + [0.] * n_test
        train_sample_index += range(offset, offset + n_train)
        test_sample_index += range(offset + n_train, offset + n_train + n_test)
        offset += cls_size

    train_sampler = torch.utils.data.sampler.WeightedRandomSampler(train_sample_weight, train_count)
    test_sampler = torch.utils.data.sampler.SubsetRandomSampler(test_sample_index)

    # Event-stream -> frame-stack conversion, applied before disk caching.
    train_transform = transforms.Compose([
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step),
    ])
    test_transform = transforms.Compose([
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step),
    ])
    train_dataset = braincog.datasets.ncaltech101.NCALTECH101(os.path.join(DATA_DIR, 'DVS/NCALTECH101'), transform=train_transform)
    test_dataset = braincog.datasets.ncaltech101.NCALTECH101(os.path.join(DATA_DIR, 'DVS/NCALTECH101'), transform=test_transform)

    # Tensor-level transforms, applied on cached frames at load time.
    train_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
        transforms.RandomCrop(size, padding=size // 12),
    ])
    test_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
    ])
    if kwargs.get('rand_aug') is True:
        # Index 2 places RandAugment after the resize and before RandomCrop.
        train_transform.transforms.insert(
            2, RandAugment(m=kwargs['randaug_m'], n=kwargs['randaug_n']))

    train_dataset = DiskCachedDataset(train_dataset,
                                      cache_path=os.path.join(DATA_DIR, 'DVS/NCALTECH101/train_cache_{}'.format(step)),
                                      transform=train_transform, num_copies=3)
    test_dataset = DiskCachedDataset(test_dataset,
                                     cache_path=os.path.join(DATA_DIR, 'DVS/NCALTECH101/test_cache_{}'.format(step)),
                                     transform=test_transform, num_copies=3)

    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)
    mixup_active = cut_mix | event_mix | mix_up

    if cut_mix:
        train_dataset = CutMix(train_dataset,
                               beta=beta,
                               prob=prob,
                               num_mix=num,
                               num_class=num_classes,
                               indices=train_sample_index,
                               noise=noise)
    if event_mix:
        train_dataset = EventMix(train_dataset,
                                 beta=beta,
                                 prob=prob,
                                 num_mix=num,
                                 num_class=num_classes,
                                 indices=train_sample_index,
                                 noise=noise,
                                 gaussian_n=gaussian_n)
    if mix_up:
        train_dataset = MixUp(train_dataset,
                              beta=beta,
                              prob=prob,
                              num_mix=num,
                              num_class=num_classes,
                              indices=train_sample_index,
                              noise=noise)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size,
        sampler=train_sampler,
        pin_memory=True, drop_last=True, num_workers=8
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size,
        sampler=test_sampler,
        pin_memory=True, drop_last=False, num_workers=2
    )
    return train_loader, test_loader, mixup_active, None
def get_UCF101DVS_data(batch_size, step, **kwargs):
    """
    Build train/test dataloaders for the UCF101-DVS action-recognition dataset.

    :param batch_size: batch size
    :param step: number of simulation time steps (event frames per sample)
    :param kwargs: optional keys: ``rand_aug``/``randaug_n``/``randaug_m``
        (RandAugment control) plus the mix-augmentation keys consumed by
        ``unpack_mix_param``. A ``size`` key is accepted but currently
        ignored: spatial resizing/cropping is disabled for this dataset.
    :return: (train loader, test loader, mixup_active, mixup_fn)
    """
    sensor_size = braincog.datasets.ucf101_dvs.UCF101DVS.sensor_size

    # Event-stream -> frame-stack conversion, applied before disk caching.
    train_transform = transforms.Compose([
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step),
    ])
    test_transform = transforms.Compose([
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step),
    ])
    train_dataset = braincog.datasets.ucf101_dvs.UCF101DVS(os.path.join(DATA_DIR, 'DVS/UCF101DVS'), train=True, transform=train_transform)
    test_dataset = braincog.datasets.ucf101_dvs.UCF101DVS(os.path.join(DATA_DIR, 'DVS/UCF101DVS'), train=False, transform=test_transform)

    # Tensor-level transforms, applied on cached frames at load time.
    # No resize/crop here: frames keep the native sensor resolution.
    train_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
        transforms.RandomHorizontalFlip(),
    ])
    test_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
    ])

    if kwargs.get('rand_aug') is True:
        n = kwargs['randaug_n']
        m = kwargs['randaug_m']
        # Index 2 appends RandAugment after the horizontal flip.
        train_transform.transforms.insert(2, RandAugment(m=m, n=n))

    train_dataset = DiskCachedDataset(train_dataset,
                                      cache_path=os.path.join(DATA_DIR, 'UCF101DVS/train_cache_{}'.format(step)),
                                      transform=train_transform)
    test_dataset = DiskCachedDataset(test_dataset,
                                     cache_path=os.path.join(DATA_DIR, 'UCF101DVS/test_cache_{}'.format(step)),
                                     transform=test_transform)

    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)
    mixup_active = cut_mix | event_mix | mix_up

    # UCF101-DVS ships with its own train/test split, so the mix wrappers
    # operate on the whole training dataset (no explicit index list).
    if cut_mix:
        train_dataset = CutMix(train_dataset,
                               beta=beta,
                               prob=prob,
                               num_mix=num,
                               num_class=num_classes,
                               noise=noise)
    if event_mix:
        train_dataset = EventMix(train_dataset,
                                 beta=beta,
                                 prob=prob,
                                 num_mix=num,
                                 num_class=num_classes,
                                 noise=noise,
                                 gaussian_n=gaussian_n)
    if mix_up:
        train_dataset = MixUp(train_dataset,
                              beta=beta,
                              prob=prob,
                              num_mix=num,
                              num_class=num_classes,
                              noise=noise)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        pin_memory=True, drop_last=True, num_workers=8
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False,
        pin_memory=True, drop_last=False, num_workers=2
    )
    return train_loader, test_loader, mixup_active, None
def get_HMDBDVS_data(batch_size, step, **kwargs):
    """
    Build train/test dataloaders for the HMDB-DVS action-recognition dataset.

    The dataset has no official split, so each class is split 50/50 with a
    fixed-seed shuffle for reproducibility.

    Fixes over the previous version:
    - ``train_sample_weight`` previously gave nonzero weight to the *first*
      half of each class while the train indices were a *shuffled* subset,
      so the WeightedRandomSampler could draw held-out test samples into
      training (data leakage) and skipped part of the train split. Weights
      are now assigned to exactly the shuffled train indices.
    - Uses a local ``random.Random(0)`` (same permutation as the previous
      ``random.seed(0)``) instead of reseeding the global RNG in a loop,
      which clobbered module-wide random state.

    :param batch_size: batch size
    :param step: number of simulation time steps (event frames per sample)
    :param kwargs: RandAugment keys and the mix-augmentation keys consumed by
        ``unpack_mix_param``. A ``size`` key is accepted but currently
        ignored: spatial resizing/cropping is disabled for this dataset.
    :return: (train loader, test loader, mixup_active, mixup_fn)
    """
    sensor_size = braincog.datasets.hmdb_dvs.HMDBDVS.sensor_size

    # Event-stream -> frame-stack conversion, applied before disk caching.
    train_transform = transforms.Compose([
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step),
    ])
    test_transform = transforms.Compose([
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step),
    ])
    train_dataset = braincog.datasets.hmdb_dvs.HMDBDVS(os.path.join(DATA_DIR, 'HMDBDVS'), transform=train_transform)
    test_dataset = braincog.datasets.hmdb_dvs.HMDBDVS(os.path.join(DATA_DIR, 'HMDBDVS'), transform=test_transform)

    cls_count = train_dataset.cls_count
    dataset_length = train_dataset.length
    portion = .5  # fixed 50/50 split per class

    train_sample_weight = []
    train_sample_index = []
    test_sample_index = []
    train_count = 0
    idx_begin = 0
    for count in cls_count:
        sample_weight = dataset_length / count  # inverse-frequency class weight
        train_sample = round(portion * count)
        test_sample = count - train_sample
        train_count += train_sample

        # Deterministic per-class shuffle; Random(0).shuffle produces the
        # same permutation as seed(0) + random.shuffle did.
        lst = list(range(idx_begin, idx_begin + train_sample + test_sample))
        random.Random(0).shuffle(lst)
        cls_train_idx = lst[:train_sample]
        cls_test_idx = lst[train_sample:train_sample + test_sample]
        train_sample_index.extend(cls_train_idx)
        test_sample_index.extend(cls_test_idx)

        # Weight exactly the shuffled train indices; test indices get 0 so the
        # weighted sampler can never draw them (previously it could).
        cls_weights = [0.] * count
        for gi in cls_train_idx:
            cls_weights[gi - idx_begin] = sample_weight
        train_sample_weight.extend(cls_weights)

        idx_begin += count

    train_sampler = torch.utils.data.sampler.WeightedRandomSampler(train_sample_weight, train_count)
    test_sampler = torch.utils.data.sampler.SubsetRandomSampler(test_sample_index)

    # Tensor-level transforms, applied on cached frames at load time.
    train_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
    ])
    test_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
    ])
    if kwargs.get('rand_aug') is True:
        n = kwargs['randaug_n']
        m = kwargs['randaug_m']
        train_transform.transforms.insert(2, RandAugment(m=m, n=n))

    train_dataset = DiskCachedDataset(train_dataset,
                                      cache_path=os.path.join(DATA_DIR, 'HMDBDVS/train_cache_{}'.format(step)),
                                      transform=train_transform, num_copies=3)
    test_dataset = DiskCachedDataset(test_dataset,
                                     cache_path=os.path.join(DATA_DIR, 'HMDBDVS/test_cache_{}'.format(step)),
                                     transform=test_transform, num_copies=3)

    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)
    mixup_active = cut_mix | event_mix | mix_up

    if cut_mix:
        train_dataset = CutMix(train_dataset,
                               beta=beta,
                               prob=prob,
                               num_mix=num,
                               num_class=num_classes,
                               indices=train_sample_index,
                               noise=noise)
    if event_mix:
        train_dataset = EventMix(train_dataset,
                                 beta=beta,
                                 prob=prob,
                                 num_mix=num,
                                 num_class=num_classes,
                                 indices=train_sample_index,
                                 noise=noise,
                                 gaussian_n=gaussian_n)
    if mix_up:
        train_dataset = MixUp(train_dataset,
                              beta=beta,
                              prob=prob,
                              num_mix=num,
                              num_class=num_classes,
                              indices=train_sample_index,
                              noise=noise)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size,
        sampler=train_sampler,
        pin_memory=True, drop_last=True, num_workers=8
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size,
        sampler=test_sampler,
        pin_memory=True, drop_last=False, num_workers=2
    )
    return train_loader, test_loader, mixup_active, None
# def get_NCARS_data(batch_size, step, **kwargs):
# """
# 获取N-Cars数据
# https://ieeexplore.ieee.org/document/8578284/
# :param batch_size: batch size
# :param step: 仿真步长
# :param kwargs:
# :return: (train loader, test loader, mixup_active, mixup_fn)
# """
# sensor_size = tonic.datasets.NCARS.sensor_size
# size = kwargs['size'] if 'size' in kwargs else 48
#
# train_transform = transforms.Compose([
# # tonic.transforms.Denoise(filter_time=10000),
# # tonic.transforms.DropEvent(p=0.1),
# tonic.transforms.ToFrame(sensor_size=None, n_time_bins=step),
# ])
# test_transform = transforms.Compose([
# # tonic.transforms.Denoise(filter_time=10000),
# tonic.transforms.ToFrame(sensor_size=None, n_time_bins=step),
# ])
#
# train_dataset = tonic.datasets.NCARS(os.path.join(DATA_DIR, 'DVS/NCARS'), transform=train_transform, train=True)
# test_dataset = tonic.datasets.NCARS(os.path.join(DATA_DIR, 'DVS/NCARS'), transform=test_transform, train=False)
#
# train_transform = transforms.Compose([
# lambda x: torch.tensor(x, dtype=torch.float),
# lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
# lambda x: dvs_channel_check_expend(x),
# transforms.RandomCrop(size, padding=size // 12),
# transforms.RandomHorizontalFlip(),
# transforms.RandomRotation(15)
# ])
# test_transform = transforms.Compose([
# lambda x: torch.tensor(x, dtype=torch.float),
# lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
# lambda x: dvs_channel_check_expend(x),
# ])
# if 'rand_aug' in kwargs.keys():
# if kwargs['rand_aug'] is True:
# n = kwargs['randaug_n']
# m = kwargs['randaug_m']
# train_transform.transforms.insert(2, RandAugment(m=m, n=n))
#
# # if 'temporal_flatten' in kwargs.keys():
# # if kwargs['temporal_flatten'] is True:
# # train_transform.transforms.insert(-1, lambda x: temporal_flatten(x))
# # test_transform.transforms.insert(-1, lambda x: temporal_flatten(x))
#
# train_dataset = DiskCachedDataset(train_dataset,
# cache_path=os.path.join(DATA_DIR, 'DVS/NCARS/train_cache_{}'.format(step)),
# transform=train_transform, num_copies=3)
# test_dataset = DiskCachedDataset(test_dataset,
# cache_path=os.path.join(DATA_DIR, 'DVS/NCARS/test_cache_{}'.format(step)),
# transform=test_transform, num_copies=3)
#
# mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)
# mixup_active = cut_mix | event_mix | mix_up
#
# if cut_mix:
# train_dataset = CutMix(train_dataset,
# beta=beta,
# prob=prob,
# num_mix=num,
# num_class=num_classes,
# noise=noise)
#
# if event_mix:
# train_dataset = EventMix(train_dataset,
# beta=beta,
# prob=prob,
# num_mix=num,
# num_class=num_classes,
# noise=noise,
# gaussian_n=gaussian_n)
# if mix_up:
# train_dataset = MixUp(train_dataset,
# beta=beta,
# prob=prob,
# num_mix=num,
# num_class=num_classes,
# noise=noise)
#
# train_loader = torch.utils.data.DataLoader(
# train_dataset, batch_size=batch_size,
# pin_memory=True, drop_last=True, num_workers=8,
# shuffle=True,
# )
#
# test_loader = torch.utils.data.DataLoader(
# test_dataset, batch_size=batch_size,
# pin_memory=True, drop_last=False, num_workers=2,
# shuffle=False,
# )
#
# return train_loader, test_loader, mixup_active, None
def get_nomni_data(batch_size, train_portion=1., **kwargs):
    """
    Build train/test dataloaders for the N-Omniglot dataset.

    :param batch_size: batch size
    :param train_portion: kept for interface compatibility; not used here
    :param kwargs: optional keys:
        ``data_mode`` -- one of "full", "nkks" (N-way K-shot), "pair"
        (default "full")
        ``frames_num`` -- frames per sample (default 10)
        ``data_type`` -- "event" or "frequency" (default "event")
        ``n_way``/``k_shot``/``k_query`` -- required for "nkks" and "pair" modes
    :return: (train loader, test loader, None, None)
    :raises ValueError: if ``data_mode`` is not a supported mode (previously
        this fell through silently and crashed later with a ``NameError``)
    """
    data_mode = kwargs.get("data_mode", "full")
    frames_num = kwargs.get("frames_num", 10)
    data_type = kwargs.get("data_type", "event")

    train_transform = transforms.Compose([
        transforms.Resize((64, 64))])
    test_transform = transforms.Compose([
        transforms.Resize((64, 64))])

    if data_mode == "full":
        train_datasets = NOmniglotfull(root=os.path.join(DATA_DIR, 'DVS/NOmniglot'), train=True, frames_num=frames_num,
                                       data_type=data_type,
                                       transform=train_transform)
        test_datasets = NOmniglotfull(root=os.path.join(DATA_DIR, 'DVS/NOmniglot'), train=False, frames_num=frames_num,
                                      data_type=data_type,
                                      transform=test_transform)
    elif data_mode == "nkks":
        train_datasets = NOmniglotNWayKShot(os.path.join(DATA_DIR, 'DVS/NOmniglot'),
                                            n_way=kwargs["n_way"],
                                            k_shot=kwargs["k_shot"],
                                            k_query=kwargs["k_query"],
                                            train=True,
                                            frames_num=frames_num,
                                            data_type=data_type,
                                            transform=train_transform)
        test_datasets = NOmniglotNWayKShot(os.path.join(DATA_DIR, 'DVS/NOmniglot'),
                                           n_way=kwargs["n_way"],
                                           k_shot=kwargs["k_shot"],
                                           k_query=kwargs["k_query"],
                                           train=False,
                                           frames_num=frames_num,
                                           data_type=data_type,
                                           transform=test_transform)
    elif data_mode == "pair":
        # The pair datasets handle their own resizing (resize=105), so the
        # transforms above are not passed in.
        train_datasets = NOmniglotTrainSet(root=os.path.join(DATA_DIR, 'DVS/NOmniglot'), use_frame=True,
                                           frames_num=frames_num, data_type=data_type,
                                           use_npz=False, resize=105)
        test_datasets = NOmniglotTestSet(root=os.path.join(DATA_DIR, 'DVS/NOmniglot'), time=2000, way=kwargs["n_way"],
                                         shot=kwargs["k_shot"], use_frame=True,
                                         frames_num=frames_num, data_type=data_type, use_npz=False, resize=105)
    else:
        # Fail fast instead of leaving train_datasets/test_datasets unbound.
        raise ValueError(
            'unknown data_mode {!r}; expected "full", "nkks" or "pair"'.format(data_mode))

    train_loader = torch.utils.data.DataLoader(
        train_datasets, batch_size=batch_size, num_workers=12,
        pin_memory=True, drop_last=True, shuffle=True
    )
    test_loader = torch.utils.data.DataLoader(
        test_datasets, batch_size=batch_size, num_workers=12,
        pin_memory=True, drop_last=False
    )
    return train_loader, test_loader, None, None
================================================
FILE: examples/Spiking-Transformers/main.py
================================================
import argparse
import time
import timm.models
import yaml
import os
import random as buildin_random
import logging
from collections import OrderedDict
from contextlib import suppress
from datetime import datetime
from braincog.base.node.node import *
from braincog.utils import *
from braincog.base.utils.criterions import *
# from braincog.datasets.datasets import *
from datasets import *
from braincog.model_zoo.resnet import *
from braincog.model_zoo.convnet import *
from braincog.model_zoo.vgg_snn import VGG_SNN, SNN5
# from braincog.model_zoo.fc_snn import SHD_SNN
from braincog.model_zoo.resnet19_snn import resnet19
#from braincog.model_zoo.sew_resnet import sew_resnet18, sew_resnet34, sew_resnet50
from braincog.utils import save_feature_map, setup_seed
from braincog.base.utils.visualization import plot_tsne_3d, plot_tsne, plot_confusion_matrix, plot_mem_distribution
import torch
import torch.nn as nn
import torchvision.utils
from torch.nn.parallel import DistributedDataParallel as NativeDDP
from timm.data import ImageDataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
from timm.models import load_checkpoint, create_model, resume_checkpoint, convert_splitbn_model, register_model
from timm.utils import *
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
from timm.optim import create_optimizer
from timm.scheduler import create_scheduler
from timm.utils import ApexScaler, NativeScaler
from torch.utils.tensorboard import SummaryWriter
# load spiking transformer models
from models.spikformer import spikformer
from models.spikformer_dvs import spikformer_dvs
from models.spike_driven_transformer import sd_transformer
from models.spike_driven_transformer_dvs import sd_transformer_dvs
from models.spike_driven_transformer_v2 import sd_transformer_v2
from models.spike_driven_transformer_v2_dvs import sd_transformer_v2_dvs
# choose ur device here
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('train')
# The first arg parser parses out only the --config argument, this argument is used to
# load a yaml file containing key-values that override the defaults for the main parser below
config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)
parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
help='YAML config file specifying default arguments')
parser = argparse.ArgumentParser(description='SNN Training and Evaluating')
# Model parameters
parser.add_argument('--dataset', default='dvsc10', type=str)
parser.add_argument('--model', default='spikformer', type=str, metavar='MODEL',
help='Name of model to train (default: "countception"')
parser.add_argument('--pretrained', action='store_true', default=False,
help='Start with pretrained version of specified network (if avail)')
parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
help='Initialize model from this checkpoint (default: none)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
help='Resume full model and optimizer state from checkpoint (default: none)')
parser.add_argument('--eval_checkpoint', default='', type=str, metavar='PATH',
help='path to eval checkpoint (default: none)')
parser.add_argument('--no-resume-opt', action='store_true', default=False,
help='prevent resume of optimizer state when resuming model')
parser.add_argument('--num-classes', type=int, default=10, metavar='N',
help='number of label classes (default: 1000)')
parser.add_argument('--gp', default=None, type=str, metavar='POOL',
help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
# Dataset parameters for static datasets
parser.add_argument('--img-size', type=int, default=224, metavar='N',
help='Image patch size (default: None => model default)')
parser.add_argument('--crop-pct', default=None, type=float,
metavar='N', help='inputs image center crop percent (for validation only)')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
help='Override std deviation of of dataset')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
help='Image resize interpolation type (overrides model)')
# Dataloader parameters
parser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',
help='inputs batch size for training (default: 128)')
parser.add_argument('-vb', '--validation-batch-size-multiplier', type=int, default=1, metavar='N',
help='ratio of validation batch size to training batch size (default: 1)')
# Optimizer parameters
parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
help='Optimizer (default: "adamw"')
parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
help='Optimizer Epsilon (default: None, use opt default)')
parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
help='Optimizer Betas (default: None, use opt default)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
help='Optimizer momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=1e-4,
help='weight decay (default: 0.01 for adamw)')
parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
help='Clip gradient norm (default: None, no clipping)')
parser.add_argument('--adam-epoch', type=int, default=1000, help='lamb switch to adamw')
# Learning rate schedule parameters
parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
help='LR scheduler (default: "cosine"')
parser.add_argument('--lr', type=float, default=5e-3, metavar='LR',
help='learning rate (default: 0.01)')
parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
help='learning rate noise on/off epoch percentages')
parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
help='learning rate noise limit percent (default: 0.67)')
parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
help='learning rate noise std-dev (default: 1.0)')
parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
help='learning rate cycle len multiplier (default: 1.0)')
parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
help='learning rate cycle limit')
parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
help='warmup learning rate (default: 0.0001)')
parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
parser.add_argument('--epochs', type=int, default=400, metavar='N',
help='number of epochs to train (default: 2)')
parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
help='manual epoch number (useful on restarts)')
parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
help='epoch interval to decay LR')
parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
help='epochss to warmup LR, if scheduler supports')
parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
help='patience epochs for Plateau LR scheduler (default: 10')
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
help='LR decay rate (default: 0.1)')
parser.add_argument('--power', type=int, default=1, help='power')
# Augmentation & regularization parameters ONLY FOR IMAGE NET
parser.add_argument('--no-aug', action='store_true', default=False,
help='Disable all training augmentation, override other train aug args')
parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
help='Random resize scale (default: 0.08 1.0)')
parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
help='Random resize aspect ratio (default: 0.75 1.33)')
parser.add_argument('--hflip', type=float, default=0.5,
help='Horizontal flip training aug probability')
parser.add_argument('--vflip', type=float, default=0.,
help='Vertical flip training aug probability')
parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
help='Color jitter factor (default: 0.4)')
parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
help='Use AutoAugment policy. "v0" or "original". (default: None)'),
parser.add_argument('--aug-splits', type=int, default=0,
                    help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
parser.add_argument('--jsd', action='store_true', default=False,
                    help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
                    help='Random erase prob (default: 0.25)')
parser.add_argument('--remode', type=str, default='pixel',
                    help='Random erase mode (default: "pixel")')
parser.add_argument('--recount', type=int, default=1,
                    help='Random erase count (default: 1)')
parser.add_argument('--resplit', action='store_true', default=False,
                    help='Do not random erase first (clean) augmentation split')
parser.add_argument('--mixup', type=float, default=0.,
                    help='mixup alpha, mixup enabled if > 0. (default: 0.)')
parser.add_argument('--cutmix', type=float, default=0.,
                    help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')
parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
                    help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
parser.add_argument('--mixup-prob', type=float, default=0.,
                    help='Probability of performing mixup or cutmix when either/both is enabled')
parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
                    help='Probability of switching to cutmix when both mixup and cutmix enabled')
parser.add_argument('--mixup-mode', type=str, default='batch',
                    help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
                    help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
parser.add_argument('--smoothing', type=float, default=0.1,
                    help='Label smoothing (default: 0.1)')
parser.add_argument('--train-interpolation', type=str, default='random',
                    help='Training interpolation (random, bilinear, bicubic default: "random")')
parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                    help='Dropout rate (default: 0.0)')
parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT',
                    help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',
                    help='Drop path rate (default: 0.1)')
parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
                    help='Drop block rate (default: None)')
parser.add_argument('--newton-maxiter', default=20, type=int,
                    help='max iteration in newton method')
parser.add_argument('--reset-drop', action='store_true', default=False,
                    help='whether to reset drop')
parser.add_argument('--kernel-method', type=str, default='cuda', choices=['torch', 'cuda'],
                    help='The implementation way of gaussian kernel method, choose from "cuda" and "torch"')
# Batch norm parameters (only works with gen_efficientnet based models currently)
parser.add_argument('--bn-tf', action='store_true', default=False,
                    help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')
parser.add_argument('--bn-momentum', type=float, default=None,
                    help='BatchNorm momentum override (if not None)')
parser.add_argument('--bn-eps', type=float, default=None,
                    help='BatchNorm epsilon override (if not None)')
parser.add_argument('--sync-bn', action='store_true',
                    help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
parser.add_argument('--dist-bn', type=str, default='',
                    help='Distribute BatchNorm stats between node after each epoch ("broadcast", "reduce", or "")')
parser.add_argument('--split-bn', action='store_true',
                    help='Enable separate BN layers per augmentation split.')
# Model Exponential Moving Average
parser.add_argument('--model-ema', action='store_true', default=False,
                    help='Enable tracking moving average of model weights')
parser.add_argument('--model-ema-force-cpu', action='store_true', default=False,
                    help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
parser.add_argument('--model-ema-decay', type=float, default=0.99996,
                    help='decay factor for model weights moving average (default: 0.99996)')
# Misc
parser.add_argument('--seed', type=int, default=42, metavar='S',
                    help='random seed (default: 42)')
parser.add_argument('--log-interval', type=int, default=50, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
                    help='how many batches to wait before writing recovery checkpoint')
parser.add_argument('-j', '--workers', type=int, default=8, metavar='N',
                    help='how many training processes to use (default: 8)')
parser.add_argument('--num-gpu', type=int, default=1,
                    help='Number of GPUS to use')
parser.add_argument('--save-images', action='store_true', default=False,
                    help='save images of inputs bathes every log interval for debugging')
parser.add_argument('--amp', action='store_true', default=False,
                    help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
parser.add_argument('--apex-amp', action='store_true', default=False,
                    help='Use NVIDIA Apex AMP mixed precision')
parser.add_argument('--native-amp', action='store_true', default=False,
                    help='Use Native Torch AMP mixed precision')
parser.add_argument('--channels-last', action='store_true', default=False,
                    help='Use channels_last memory layout')
parser.add_argument('--pin-mem', action='store_true', default=False,
                    help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--no-prefetcher', action='store_true', default=False,
                    help='disable fast prefetcher')
# NOTE: default was a machine-specific absolute path; '' matches the documented
# behavior (fall back to './output' in main()).
parser.add_argument('--output', default='', type=str, metavar='PATH',
                    help='path to output folder (default: none, current dir)')
parser.add_argument('--tensorboard-dir', default='./runs', type=str)
parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
                    help='Best metric (default: "top1")')
parser.add_argument('--tta', type=int, default=0, metavar='N',
                    help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
parser.add_argument('--local_rank', default=0, type=int)
parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
                    help='use the multi-epochs-loader to save time at the beginning of every epoch')
parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
parser.add_argument('--device', type=int, default=0)
# Spike parameters
parser.add_argument('--step', type=int, default=10, help='Simulation time step (default: 10)')
parser.add_argument('--encode', type=str, default='direct', help='Input encode method (default: direct)')
parser.add_argument('--temporal-flatten', action='store_true',
                    help='Temporal flatten to channels. ONLY FOR EVENT DATA TRAINING BY ANN')
parser.add_argument('--adaptive-node', action='store_true')
parser.add_argument('--critical-loss', action='store_true')
parser.add_argument('--conv-type', type=str, default='normal')
parser.add_argument('--sew-cnf', type=str, default='ADD')
parser.add_argument('--rand-step', action='store_true')
# neuron type
parser.add_argument('--node-type', type=str, default='LIFNode', help='Node type in network (default: LIFNode)')
parser.add_argument('--act-fun', type=str, default='QGateGrad',
                    help='Surrogate Function in node. Only for Surrogate nodes (default: QGateGrad)')
parser.add_argument('--threshold', type=float, default=.5, help='Firing threshold (default: 0.5)')
parser.add_argument('--tau', type=float, default=2., help='Attenuation coefficient (default: 2.)')
parser.add_argument('--requires-thres-grad', action='store_true')
parser.add_argument('--sigmoid-thres', action='store_true')
parser.add_argument('--loss-fn', type=str, default='ce', help='loss function (default: ce)')
parser.add_argument('--noisy-grad', type=float, default=0.,
                    help='Add noise to backward, sometime will make higher accuracy (default: 0.)')
parser.add_argument('--spike-output', action='store_true', default=False,
                    help='Using mem output or spike output (default: False)')
parser.add_argument('--n_groups', type=int, default=1)
parser.add_argument('--n-encode-type', type=str, default='linear')
parser.add_argument('--n-preact', action='store_true')
parser.add_argument('--layer-by-layer', action='store_true',
                    help='forward step-by-step or layer-by-layer. '
                         'Larger Model with layer-by-layer will be faster (default: False)')
parser.add_argument('--tet-loss', action='store_true')
# EventData Augmentation
parser.add_argument('--mix-up', action='store_true', help='Mix-up for event data (default: False)')
parser.add_argument('--cut-mix', action='store_true', help='CutMix for event data (default: False)')
parser.add_argument('--event-mix', action='store_true', help='EventMix for event data (default: False)')
parser.add_argument('--cutmix_beta', type=float, default=2.0, help='cutmix_beta (default: 2.)')
parser.add_argument('--cutmix_prob', type=float, default=0.5, help='cutmix_prob for event data (default: .5)')
parser.add_argument('--cutmix_num', type=int, default=1, help='cutmix_num for event data (default: 1)')
parser.add_argument('--cutmix_noise', type=float, default=0.,
                    help='Add Pepper noise after mix, sometimes work (default: 0.)')
parser.add_argument('--gaussian-n', type=int, default=3)
parser.add_argument('--rand-aug', action='store_true',
                    help='Rand Augment for Event data (default: False)')
parser.add_argument('--randaug_n', type=int, default=3,
                    help='Rand Augment times n (default: 3)')
parser.add_argument('--randaug_m', type=int, default=15,
                    help='Rand Augment magnitude m (default: 15) (0-30)')
parser.add_argument('--train-portion', type=float, default=0.9,
                    help='Dataset portion, only for datasets which do not have validation set (default: 0.9)')
parser.add_argument('--event-size', default=48, type=int,
                    help='Event size. Resize event data before process (default: 48)')
parser.add_argument('--node-resume', type=str, default='',
                    help='resume weights in node for adaptive node. (default: False)')
# visualize
parser.add_argument('--visualize', action='store_true',
                    help='Visualize spiking map for each layer, only for validate (default: False)')
parser.add_argument('--spike-rate', action='store_true',
                    help='Print spiking rate for each layer, only for validate(default: False)')
parser.add_argument('--tsne', action='store_true')
parser.add_argument('--conf-mat', action='store_true')
parser.add_argument('--mem-dist', action='store_true')
parser.add_argument('--adaptation-info', action='store_true')
parser.add_argument('--suffix', type=str, default='',
                    help='Add an additional suffix to the save path (default: \'\')')
# Probe optional mixed-precision backends at import time; the flags below are
# consumed in main() when resolving the --amp/--apex-amp/--native-amp options.
try:
    # NVIDIA Apex provides its own AMP, DDP wrapper and SyncBN converter.
    from apex import amp
    from apex.parallel import DistributedDataParallel as ApexDDP
    from apex.parallel import convert_syncbn_model
    has_apex = True
except ImportError:
    has_apex = False
has_native_amp = False
try:
    # torch.cuda.amp.autocast is only present on PyTorch >= 1.6; the getattr
    # raises AttributeError on older versions, which we swallow.
    if getattr(torch.cuda.amp, 'autocast') is not None:
        has_native_amp = True
except AttributeError:
    pass
def _parse_args():
    """Parse command-line arguments, optionally seeding defaults from a YAML config.

    Returns a tuple ``(args, args_text)`` where ``args_text`` is the YAML dump
    of the parsed namespace, saved later into the output directory.
    """
    # First pass: only look for --config so a YAML file can override defaults.
    cfg_args, remaining = config_parser.parse_known_args()
    if cfg_args.config:
        with open(cfg_args.config, 'r') as cfg_file:
            parser.set_defaults(**yaml.safe_load(cfg_file))
    # Second pass: full parse; CLI flags still win over config-file defaults.
    args = parser.parse_args(remaining)
    args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
    return args, args_text
def main():
    """Run one full training (or evaluation-only) session.

    Flow: parse args -> set up output dir / logging / tensorboard (rank 0)
    -> configure device / distributed state -> build model, optimizer, AMP,
    EMA, scheduler and data loaders -> either evaluate (``--eval``) or run
    the epoch loop with per-epoch validation and checkpointing.
    """
    args, args_text = _parse_args()
    # args.no_spike_output = args.no_spike_output | args.cut_mix
    args.no_spike_output = True
    output_dir = ''
    # Output directory, file logging and the tensorboard writer exist only on
    # the main process; other ranks log to console and get summary_writer=None.
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            args.model,
            args.dataset,
            args.node_type,
            str(args.step),
            args.suffix,
            datetime.now().strftime("%Y%m%d-%H%M%S"),
            # str(args.img_size)
        ])
        output_dir = get_outdir(output_base, 'train', exp_name)
        args.output_dir = output_dir
        setup_default_logging(log_path=os.path.join(output_dir, 'log.txt'))
        summary_writer = SummaryWriter(log_dir=os.path.join(args.tensorboard_dir, exp_name))
        args.tensorboard_prefix = os.path.join(args.dataset, args.model)
    else:
        summary_writer = None
        setup_default_logging()
    args.prefetcher = not args.no_prefetcher
    # Distributed mode is inferred from the WORLD_SIZE env var set by the launcher.
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    if args.distributed and args.num_gpu > 1:
        _logger.warning(
            'Using more than one GPU per process in distributed mode is not allowed.Setting num_gpu to 1.')
        args.num_gpu = 1
    # args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        # One process per GPU; rank/world size come from torch.distributed.
        args.num_gpu = 1
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    else:
        torch.cuda.set_device('cuda:%d' % args.device)
    assert args.rank >= 0
    if args.distributed:
        _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
                     % (args.rank, args.world_size))
    else:
        _logger.info('Training with a single process on %d GPUs.' % args.num_gpu)
    # torch.manual_seed(args.seed + args.rank)
    # Per-rank seed offset keeps augmentation streams distinct across processes.
    setup_seed(args.seed + args.rank)
    model = create_model(
        args.model,
        # pretrained=args.pretrained,
        # num_classes=args.num_classes,
        # dataset=args.dataset,
        # step=args.step,
        # encode_type=args.encode,
        # node_type=eval(args.node_type),
        # threshold=args.threshold,
        # tau=args.tau,
        # sigmoid_thres=args.sigmoid_thres,
        # requires_thres_grad=args.requires_thres_grad,
        # spike_output=not args.no_spike_output,
        # act_fun=args.act_fun,
        # temporal_flatten=args.temporal_flatten,
        # layer_by_layer=args.layer_by_layer,
        # n_groups=args.n_groups,
        # n_encode_type=args.n_encode_type,
        # n_preact=args.n_preact,
        # tet_loss=args.tet_loss,
        # sew_cnf=args.sew_cnf,
        # conv_type=args.conv_type,
    )
    _logger.info('[MODEL ARCH]\n{}'.format(model))
    # Infer input channel count from the dataset name (event data: 2 polarities).
    if 'dvs' in args.dataset:
        args.channels = 2
    elif 'mnist' in args.dataset:
        args.channels = 1
    else:
        args.channels = 3
    # flops, params = profile(model, inputs=(torch.randn(1, args.channels, args.event_size, args.event_size),), verbose=False)
    # _logger.info('flops = %fM', flops / 1e6)
    # _logger.info('param size = %fM', params / 1e6)
    # Linear LR scaling by global batch size (reference batch size 1024).
    linear_scaled_lr = args.lr * args.batch_size * args.world_size / 1024.0
    args.lr = linear_scaled_lr
    _logger.info("learning rate is %f" % linear_scaled_lr)
    if args.local_rank == 0:
        _logger.info('Model %s created, param count: %d' %
                     (args.model, sum([m.numel() for m in model.parameters()])))
    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits
    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))
    # Resolve which AMP implementation (if any) to use.
    use_amp = None
    if args.amp:
        # for backwards compat, `--amp` arg tries apex before native amp
        if has_apex:
            args.apex_amp = True
        elif has_native_amp:
            args.native_amp = True
    if args.apex_amp and has_apex:
        use_amp = 'apex'
    elif args.native_amp and has_native_amp:
        use_amp = 'native'
    elif args.apex_amp or args.native_amp:
        _logger.warning("Neither APEX or native Torch AMP is available, using float32. "
                        "Install NVIDA apex or upgrade to PyTorch 1.6")
    if args.num_gpu > 1:
        if use_amp == 'apex':
            _logger.warning(
                'Apex AMP does not work well with nn.DataParallel, disabling. Use DDP or Torch AMP.')
            use_amp = None
        model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
        assert not args.channels_last, "Channels last not supported with DP, use DDP."
    else:
        model = model.cuda()
        if args.channels_last:
            model = model.to(memory_format=torch.channels_last)
    optimizer = create_optimizer(args, model)
    _logger.info('[OPTIMIZER]\n{}'.format(optimizer))
    amp_autocast = suppress  # do nothing
    loss_scaler = None
    if use_amp == 'apex':
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        loss_scaler = ApexScaler()
        if args.local_rank == 0:
            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif use_amp == 'native':
        amp_autocast = torch.cuda.amp.autocast
        loss_scaler = NativeScaler()
        if args.local_rank == 0:
            _logger.info('Using native Torch AMP. Training in mixed precision.')
    else:
        if args.local_rank == 0:
            _logger.info('AMP not enabled. Training in float32.')
    # optionally resume from a checkpoint
    resume_epoch = None
    # NOTE(review): passing --resume forces evaluation mode (args.eval = True)
    # below; this script cannot resume *training* from a checkpoint as written.
    if args.resume and args.eval_checkpoint == '':
        args.eval_checkpoint = args.resume
    if args.resume:
        args.eval = True
        # checkpoint = torch.load(args.resume, map_location='cpu')
        # model.load_state_dict(checkpoint['state_dict'], False)
        resume_epoch = resume_checkpoint(
            model, args.resume,
            optimizer=None if args.no_resume_opt else optimizer,
            loss_scaler=None if args.no_resume_opt else loss_scaler,
            log_info=args.local_rank == 0)
        # print(model.get_attr('mu'))
        # print(model.get_attr('sigma'))
    if hasattr(model, 'set_threshold'):
        model.set_threshold(args.threshold)
    if args.critical_loss or args.spike_rate:
        model.set_requires_fp(True)
    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(
            model,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else '',
            resume=args.resume)
    if args.node_resume:
        ckpt = torch.load(args.node_resume, map_location='cpu')
        model.load_node_weight(ckpt, args.node_trainable)
    model_without_ddp = model
    if args.distributed:
        if args.sync_bn:
            assert not args.split_bn
            try:
                if has_apex and use_amp != 'native':
                    # Apex SyncBN preferred unless native amp is activated
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
                if args.local_rank == 0:
                    _logger.info(
                        'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                        'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
            except Exception as e:
                _logger.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')
        if has_apex and use_amp != 'native':
            # Apex DDP preferred unless native amp is activated
            if args.local_rank == 0:
                _logger.info("Using NVIDIA APEX DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                _logger.info("Using native Torch DistributedDataParallel.")
            model = NativeDDP(model.cuda(), device_ids=[args.local_rank],
                              find_unused_parameters=True)  # can use device str in Torch >= 1.1
        model_without_ddp = model.module
    # NOTE: EMA model does not need to be wrapped by DDP
    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)
    if args.local_rank == 0:
        _logger.info('Scheduled epochs: {}'.format(num_epochs))
    # now config only for imnet
    data_config = resolve_data_config(vars(args), model=model, verbose=False)
    # NOTE(review): eval() dispatches to a get_<dataset>_data factory defined
    # elsewhere; args.dataset is effectively trusted input here.
    loader_train, loader_eval, mixup_active, mixup_fn = eval('get_%s_data' % args.dataset)(
        batch_size=args.batch_size,
        step=args.step,
        args=args,
        # NOTE(review): '_logge' looks like a typo duplicating '_logger' below —
        # confirm against the get_*_data signatures before removing either.
        _logge=_logger,
        data_config=data_config,
        num_aug_splits=num_aug_splits,
        size=args.event_size,
        mix_up=args.mix_up,
        cut_mix=args.cut_mix,
        event_mix=args.event_mix,
        beta=args.cutmix_beta,
        prob=args.cutmix_prob,
        gaussian_n=args.gaussian_n,
        num=args.cutmix_num,
        noise=args.cutmix_noise,
        num_classes=args.num_classes,
        rand_aug=args.rand_aug,
        randaug_n=args.randaug_n,
        randaug_m=args.randaug_m,
        portion=args.train_portion,
        _logger=_logger,
    )
    # _logger.info('train_loader:\n{}\nval_loader:\n{}'.format(loader_train, loader_eval))
    # Select train/validation loss functions based on --loss-fn and mixup state.
    if args.loss_fn == 'mse':
        train_loss_fn = UnilateralMse(1.)
        validate_loss_fn = UnilateralMse(1.)
    elif args.loss_fn == 'onehot-mse':
        train_loss_fn = OnehotMse(args.num_classes)
        validate_loss_fn = OnehotMse(args.num_classes)
    else:
        if args.jsd:
            assert num_aug_splits > 1  # JSD only valid with aug splits set
            train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()
        elif mixup_active:
            # smoothing is handled with mixup target transform
            train_loss_fn = SoftTargetCrossEntropy().cuda()
        elif args.smoothing:
            train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
        else:
            train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    if args.loss_fn == 'mix':
        train_loss_fn = MixLoss(train_loss_fn)
        validate_loss_fn = MixLoss(validate_loss_fn)
    if args.tet_loss:
        train_loss_fn = TetLoss(train_loss_fn)
        validate_loss_fn = TetLoss(validate_loss_fn)
    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    if args.eval:  # evaluate the model
        # if args.distributed:
        #     raise NotImplementedError('eval not has not been verified for distributed')
        # else:
        #     load_checkpoint(model, args.eval_checkpoint, args.model_ema)
        model.eval()
        # Sweep the simulation step count to study accuracy vs. inference steps.
        for t in range(1, args.step * 3):
            # for t in range(args.step, args.step + 1):
            model.set_attr('step', t)
            val_metrics = validate(start_epoch, model, loader_eval, validate_loss_fn, args,
                                   visualize=args.visualize, spike_rate=args.spike_rate,
                                   tsne=args.tsne, conf_mat=args.conf_mat, summary_writer=summary_writer)
            print(f"[STEP:{t}], Top-1 accuracy of the model is: {val_metrics['top1']:.1f}%")
        return
    saver = None
    if args.local_rank == 0:
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(
            model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
            checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, max_history=3)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)
    try:  # train the model
        if args.reset_drop:
            model_without_ddp.reset_drop_path(0.0)
        for epoch in range(start_epoch, args.epochs):
            if epoch == 0 and args.reset_drop:
                model_without_ddp.reset_drop_path(args.drop_path)
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)
            train_metrics = train_epoch(
                epoch, model, loader_train, optimizer, train_loss_fn, args,
                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                amp_autocast=amp_autocast, loss_scaler=loss_scaler,
                model_ema=model_ema, mixup_fn=mixup_fn, summary_writer=summary_writer
            )
            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    _logger.info("Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
            eval_metrics = validate(epoch, model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast,
                                    visualize=args.visualize, spike_rate=args.spike_rate,
                                    tsne=args.tsne, conf_mat=args.conf_mat, summary_writer=summary_writer)
            # EMA weights are validated too and, if present, replace the
            # plain eval metrics for checkpoint selection.
            if model_ema is not None and not args.model_ema_force_cpu:
                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                    distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
                ema_eval_metrics = validate(
                    epoch, model_ema.ema, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)',
                    visualize=args.visualize, spike_rate=args.spike_rate,
                    tsne=args.tsne, conf_mat=args.conf_mat, summary_writer=summary_writer
                )
                eval_metrics = ema_eval_metrics
            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
            update_summary(
                epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
                write_header=best_metric is None)
            # if saver is not None and epoch >= args.n_warm_up:
            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)
    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
def train_epoch(
        epoch, model, loader, optimizer, loss_fn, args,
        lr_scheduler=None, saver=None, output_dir='', amp_autocast=suppress,
        loss_scaler=None, model_ema=None, mixup_fn=None, summary_writer=None):
    """Train the model for one epoch and return an OrderedDict of metrics.

    Handles AMP (via ``amp_autocast``/``loss_scaler``), optional gradient
    noise/clipping, EMA updates, per-batch tensorboard logging (rank 0),
    recovery checkpoints, and per-update LR scheduling.
    Returns only {'loss': avg} — top1/top5 go to tensorboard/logs.
    """
    # Disable mixup after --mixup-off-epoch, wherever it is implemented
    # (prefetcher-side or via the mixup_fn collate transform).
    if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
        if args.prefetcher and loader.mixup_enabled:
            loader.mixup_enabled = False
        elif mixup_fn is not None:
            mixup_fn.mixup_enabled = False
    second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    losses_m = AverageMeter()
    # closses_m = AverageMeter()
    top1_m = AverageMeter()
    top5_m = AverageMeter()
    model.train()
    # t, k = adjust_surrogate_coeff(100, args.epochs)
    # model.set_attr('t', t)
    # model.set_attr('k', k)
    end = time.time()
    last_idx = len(loader) - 1
    num_updates = epoch * len(loader)
    iters_per_epoch = len(loader)
    for batch_idx, (inputs, target) in enumerate(loader):
        last_batch = batch_idx == last_idx
        # Randomize the simulation step per batch when --rand-step is set;
        # restored to args.step at the end of the epoch.
        if args.rand_step:
            step = buildin_random.randint(1, args.step + 2)
            model.set_attr('step', step)
        data_time_m.update(time.time() - end)
        if not args.prefetcher or args.dataset != 'imnet':
            inputs, target = inputs.type(torch.FloatTensor).cuda(), target.cuda()
            if mixup_fn is not None:
                inputs, target = mixup_fn(inputs, target)
        if args.channels_last:
            inputs = inputs.contiguous(memory_format=torch.channels_last)
        with amp_autocast():
            output = model(inputs)
            loss = loss_fn(output, target)
        # TET loss consumes per-step outputs; average over time for accuracy.
        if args.tet_loss:
            output = output.mean(0)
        # Accuracy is meaningless on mixed targets, so it is zeroed out when
        # any mixup/cutmix variant is active.
        if not (args.cut_mix | args.mix_up | args.event_mix | (args.cutmix != 0.) | (args.mixup != 0.)):
            # print(output.shape, target.shape)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            # acc1, = accuracy(output, target)
        else:
            acc1, acc5 = torch.tensor([0.]), torch.tensor([0.])
        optimizer.zero_grad()
        if loss_scaler is not None:
            # AMP path: scaler handles backward, clipping and optimizer step.
            loss_scaler(
                loss, optimizer, clip_grad=args.clip_grad, parameters=model.parameters(), create_graph=second_order)
        else:
            loss.backward(create_graph=second_order)
            if args.noisy_grad != 0.:
                random_gradient(model, args.noisy_grad)
            if args.clip_grad is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
            # if args.opt == 'lamb':
            #     optimizer.step(epoch=epoch)
            # else:
            optimizer.step()
        torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)
        num_updates += 1
        batch_time_m.update(time.time() - end)
        if args.local_rank == 0:
            summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'batch/train/top1'), acc1.item(), epoch * iters_per_epoch + batch_idx)
            summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'batch/train/top5'), acc5.item(), epoch * iters_per_epoch + batch_idx)
            summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'batch/train/loss'), loss.item(), epoch * iters_per_epoch + batch_idx)
        # NOTE(review): meters only update on log-interval batches, so the
        # returned epoch averages are over sampled batches, not all batches.
        if last_batch or batch_idx % args.log_interval == 0:
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)
            if args.distributed:
                loss = reduce_tensor(loss.data, args.world_size)
                acc1 = reduce_tensor(acc1, args.world_size)
                acc5 = reduce_tensor(acc5, args.world_size)
            losses_m.update(loss.item(), inputs.size(0))
            top1_m.update(acc1.item(), output.size(0))
            top5_m.update(acc5.item(), output.size(0))
            # closses_m.update(reduced_loss.item(), inputs.size(0))
            if args.local_rank == 0:
                # if args.distributed:
                _logger.info(
                    'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
                    'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) '
                    'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
                    'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f}) '
                    'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
                    '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                    'LR: {lr:.3e} '
                    'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
                        epoch,
                        batch_idx, len(loader),
                        100. * batch_idx / last_idx,
                        loss=losses_m,
                        top1=top1_m,
                        top5=top5_m,
                        batch_time=batch_time_m,
                        rate=inputs.size(0) * args.world_size / batch_time_m.val,
                        rate_avg=inputs.size(0) * args.world_size / batch_time_m.avg,
                        lr=lr,
                        data_time=data_time_m
                    ))
                if args.save_images and output_dir:
                    torchvision.utils.save_image(
                        inputs,
                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
                        padding=0,
                        normalize=True)
        if saver is not None and args.recovery_interval and (
                last_batch or (batch_idx + 1) % args.recovery_interval == 0):
            saver.save_recovery(epoch, batch_idx=batch_idx)
        if lr_scheduler is not None:
            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
        end = time.time()
        # end for
    if hasattr(optimizer, 'sync_lookahead'):
        optimizer.sync_lookahead()
    if args.local_rank == 0:
        summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'epoch/train/top1'), top1_m.avg, epoch)
        summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'epoch/train/top5'), top5_m.avg, epoch)
        summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'epoch/train/loss'), losses_m.avg, epoch)
    if args.rand_step:
        model.set_attr('step', args.step)
    return OrderedDict([('loss', losses_m.avg)])
def validate(epoch, model, loader, loss_fn, args, amp_autocast=suppress,
             log_suffix='', visualize=False, spike_rate=False, tsne=False, conf_mat=False, summary_writer=None):
    """Evaluate the model on ``loader`` and return {'loss': avg, 'top1': avg}.

    Runs under torch.no_grad(); supports distributed metric reduction, TTA
    output folding, optional spike-rate reporting, and per-batch/per-epoch
    tensorboard logging on rank 0.
    """
    batch_time_m = AverageMeter()
    losses_m = AverageMeter()
    # closses_m = AverageMeter()
    top1_m = AverageMeter()
    top5_m = AverageMeter()
    spike_m = AverageMeter()
    model.eval()
    # Buffers for visualization hooks; populated elsewhere when enabled.
    feature_vec = []
    feature_cls = []
    logits_vec = []
    labels_vec = []
    mem_vec = []
    end = time.time()
    last_idx = len(loader) - 1
    iters_per_epoch = len(loader)
    with torch.no_grad():
        for batch_idx, (inputs, target) in enumerate(loader):
            # inputs = inputs.type(torch.float64)
            last_batch = batch_idx == last_idx
            if not args.prefetcher or args.dataset != 'imnet':
                inputs = inputs.type(torch.FloatTensor).cuda()
                target = target.cuda()
            if args.channels_last:
                inputs = inputs.contiguous(memory_format=torch.channels_last)
            if not args.distributed:
                # Enable firing-pattern capture when any inspection flag is on
                # (skip if critical loss already enabled it in main()).
                if (visualize or spike_rate or tsne or conf_mat or args.mem_dist) and not args.critical_loss:
                    model.set_requires_fp(True)
            with amp_autocast():
                output = model(inputs)
            if isinstance(output, (tuple, list)):
                output = output[0]
            # augmentation reduction
            reduce_factor = args.tta
            if reduce_factor > 1:
                output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
                target = target[0:target.size(0):reduce_factor]
            # print(args.rank, output.shape, target.shape, max(target))
            loss = loss_fn(output, target)
            if args.tet_loss:
                output = output.mean(0)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            if args.distributed:
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                acc1 = reduce_tensor(acc1, args.world_size)
                acc5 = reduce_tensor(acc5, args.world_size)
            else:
                reduced_loss = loss.data
            torch.cuda.synchronize()
            losses_m.update(reduced_loss.item(), inputs.size(0))
            top1_m.update(acc1.item(), output.size(0))
            top5_m.update(acc5.item(), output.size(0))
            # closses_m.update(closs, inputs.size(0))
            batch_time_m.update(time.time() - end)
            end = time.time()
            if args.local_rank == 0:
                summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'batch/val/top1'), acc1.item(), epoch * iters_per_epoch + batch_idx)
                summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'batch/val/top5'), acc5.item(), epoch * iters_per_epoch + batch_idx)
                summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'batch/val/loss'), loss.item(), epoch * iters_per_epoch + batch_idx)
            if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
                # NOTE(review): log_name is never used in the message below;
                # the EMA log_suffix is effectively dropped.
                log_name = 'Test' + log_suffix
                if not args.distributed and spike_rate:
                    spike_m.update(model.get_tot_spike() / output.size(0), output.size(0))
                if not args.distributed and spike_rate:
                    _logger.info(
                        '[Spike Info]: {spike.val} ({spike.avg})'.format(
                            spike=spike_m
                        )
                    )
                if last_batch or batch_idx % args.log_interval == 0:
                    # NOTE(review): the format string has a single '{}' but three
                    # positional args (epoch, batch_idx, last_idx); str.format
                    # silently ignores extras, so the batch index is never shown.
                    _logger.info(
                        'Eval : {} '
                        'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                        'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                        'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})'
                        'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
                            epoch,
                            batch_idx,
                            last_idx,
                            batch_time=batch_time_m,
                            loss=losses_m,
                            top1=top1_m,
                            top5=top5_m,
                        ))
    # metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])
    metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg)])
    if args.local_rank == 0:
        summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'epoch/val/top1'), top1_m.avg, epoch)
        summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'epoch/val/top5'), top5_m.avg, epoch)
        summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'epoch/val/loss'), losses_m.avg, epoch)
    return metrics
# Script entry point.
if __name__ == '__main__':
    main()
================================================
FILE: examples/Spiking-Transformers/models/spike_driven_transformer.py
================================================
import torch
import torch.nn as nn
from timm.models.layers import to_2tuple, trunc_normal_, DropPath
from timm.models.registry import register_model
from timm.models.vision_transformer import _cfg
import torch.nn.functional as F
from braincog.model_zoo.base_module import BaseModule
from braincog.base.node.node import *
from braincog.base.connection.layer import *
from braincog.base.strategy.surrogate import *
from LIFNode import MyNode # LIFNode setting for Spiking Tranformers
from functools import partial
__all__ = ['spikformer']
'''The input shape of neuromorphic datasets in Spiking Transformer when using Braincog
are used to set to 64*64 '''
class MLP(BaseModule):
    """Spiking MLP block: two 1x1 Conv1d "linear" layers with BatchNorm, each
    preceded by a LIF neuron.

    Input and output are (T, B, C, N) tensors, where T is the simulation step
    count; conv/BN operate on the time-flattened (T*B, C, N) view.
    """
    # Linear here is substituted by convs
    def __init__(self, in_features, step=10, encode_type='direct', hidden_features=None, out_features=None, drop=0.):
        # FIX: forward the caller's step/encode_type to the base class.
        # Previously `super().__init__(step=10, encode_type='direct')` was
        # hard-coded, silently ignoring the constructor arguments (defaults
        # are unchanged, so default behavior is identical).
        super().__init__(step=step, encode_type=encode_type)
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1_conv = nn.Conv1d(in_features, hidden_features, kernel_size=1, stride=1)
        self.fc1_bn = nn.BatchNorm1d(hidden_features)
        self.fc1_lif = MyNode(step=step, tau=2.0)
        self.fc2_conv = nn.Conv1d(hidden_features, out_features, kernel_size=1, stride=1)
        self.fc2_bn = nn.BatchNorm1d(out_features)
        self.fc2_lif = MyNode(step=step, tau=2.0)
        self.c_hidden = hidden_features
        self.c_output = out_features

    def forward(self, x):
        """Apply LIF -> conv1x1 -> BN twice; returns a (T, B, C, N) tensor."""
        self.reset()
        T, B, C, N = x.shape
        x = self.fc1_lif(x.flatten(0, 1)).reshape(T, B, C, N).contiguous()
        x = self.fc1_conv(x.flatten(0, 1))
        x = self.fc1_bn(x).reshape(T, B, self.c_hidden, N).contiguous()  # T B C N
        x = self.fc2_lif(x.flatten(0, 1)).reshape(T, B, self.c_hidden, N).contiguous()
        x = self.fc2_conv(x.flatten(0, 1))
        # NOTE(review): the final reshape back to (T, B, C, N) assumes
        # out_features == in_features (true for transformer MLP usage here).
        x = self.fc2_bn(x).reshape(T, B, C, N).contiguous()
        return x
class SSA(BaseModule):
    """Spike-driven self-attention.

    Q/K/V are produced by 1x1-Conv1d + BN + LIF branches; attention is the
    element-wise, softmax-free form of the Spike-driven Transformer:
    ``kv = sum_N(k * v)`` gated by a PLIF node, then multiplied into ``q``.
    Layout: (T, B, C, N) in and out.
    """

    def __init__(self, dim, step=10, encode_type='direct', num_heads=16, qkv_bias=False, qk_scale=None, attn_drop=0.,
                 proj_drop=0., sr_ratio=1):
        # Bug fix: forward step/encode_type to BaseModule instead of the
        # hard-coded ``step=10, encode_type='direct'``.
        super().__init__(step=step, encode_type=encode_type)
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
        self.dim = dim
        # LIF in front of the QKV projections (membrane shortcut enters here)
        self.head_lif = MyNode(step=step, tau=2.0)
        self.num_heads = num_heads
        # scale (kept for reference; the spike-driven path below does not use it)
        self.scale = 0.25
        self.q_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)
        self.q_bn = nn.BatchNorm1d(dim)
        self.q_lif = MyNode(step=step, tau=2.0)
        self.k_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)
        self.k_bn = nn.BatchNorm1d(dim)
        self.k_lif = MyNode(step=step, tau=2.0)
        self.v_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)
        self.v_bn = nn.BatchNorm1d(dim)
        self.v_lif = MyNode(step=step, tau=2.0)
        self.attn_drop = nn.Dropout(0.2)
        self.res_lif = MyNode(step=step, tau=2.0)
        self.attn_lif = MyNode(step=step, tau=2.0, v_threshold=0.5, )
        self.proj_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)
        self.proj_bn = nn.BatchNorm1d(dim)
        self.proj_lif = MyNode(step=step, tau=2.0, )
        # gate on the aggregated key-value product
        self.sd_lif = PLIFNode(step=step, threshold=0.5, tau=2)

    def forward(self, x):
        self.reset()  # clear neuron membrane state
        T, B, C, N = x.shape
        x_for_qkv = x.flatten(0, 1)  # TB, C, N
        x_for_qkv = self.head_lif(x_for_qkv)
        q_conv_out = self.q_conv(x_for_qkv)  # [TB] C N
        q_conv_out = self.q_bn(q_conv_out).reshape(T, B, C, N).contiguous()  # T B C N
        q_conv_out = self.q_lif(q_conv_out.flatten(0, 1)).reshape(T, B, C, N)  # TB C N
        q = q_conv_out.reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4).contiguous()
        k_conv_out = self.k_conv(x_for_qkv)
        k_conv_out = self.k_bn(k_conv_out).reshape(T, B, C, N).contiguous()
        k_conv_out = self.k_lif(k_conv_out.flatten(0, 1)).reshape(T, B, C, N)  # TB C N
        k = k_conv_out.reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4).contiguous()
        v_conv_out = self.v_conv(x_for_qkv)
        v_conv_out = self.v_bn(v_conv_out).reshape(T, B, C, N).contiguous()
        v_conv_out = self.v_lif(v_conv_out.flatten(0, 1)).reshape(T, B, C, N)  # TB C N
        v = v_conv_out.reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4).contiguous()
        # Spike-driven Transformer attention: element-wise mask + column sum
        kv = k.mul(v)
        kv = kv.sum(dim=-2, keepdim=True)
        kv = self.sd_lif(kv)
        x = q.mul(kv)
        x = x.transpose(3, 4).reshape(T, B, C, N).contiguous()  # T B C N
        # ignore following lines for membrane shortcut
        # x = self.attn_lif(x.flatten(0,1)) #[TB] C N
        # x = self.proj_lif(self.proj_bn(self.proj_conv(x))).reshape(T, B, C, N) #T B C N
        return x
class Block(nn.Module):
    """Transformer block: spike-driven self-attention followed by a spiking
    MLP, each wrapped in an additive (membrane-shortcut) residual."""

    def __init__(self, dim, num_heads, step=10, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, sr_ratio=1):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = SSA(dim, step=step, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                        attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
        self.norm2 = norm_layer(dim)
        hidden = int(dim * mlp_ratio)
        self.mlp = MLP(step=step, in_features=dim, hidden_features=hidden, drop=drop)

    def forward(self, x):
        # Residual connections around the attention and MLP sub-blocks.
        out = x + self.attn(x)
        out = out + self.mlp(out)
        return out
# embed_dims = 256
class SPS(BaseModule):
    """Spiking Patch Splitting: four conv+BN(+LIF)+maxpool stages that reduce
    spatial resolution by 16x, plus a relative-position-encoding conv added
    through a membrane shortcut.

    Input: (T, B, C, H, W).  Output: (T, B, embed_dims, H//16 * W//16).
    """

    def __init__(self, step=10, encode_type='direct', img_size_h=128, img_size_w=128, patch_size=4, in_channels=2,
                 embed_dims=256):
        # Bug fix: forward step/encode_type to BaseModule instead of the
        # hard-coded ``step=10, encode_type='direct'``.
        super().__init__(step=step, encode_type=encode_type)
        self.image_size = [img_size_h, img_size_w]
        patch_size = to_2tuple(patch_size)  # 4 -> (4, 4)
        self.patch_size = patch_size
        self.C = in_channels  # image channels
        self.H, self.W = self.image_size[0] // patch_size[0], self.image_size[1] // patch_size[1]
        self.num_patches = self.H * self.W
        # DVS variant uses 2 extra MaxPool stages
        self.proj_conv = nn.Conv2d(in_channels, embed_dims // 8, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj_bn = nn.BatchNorm2d(embed_dims // 8)
        self.proj_lif = MyNode(step=step, tau=2.0)
        self.maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        self.proj_conv1 = nn.Conv2d(embed_dims // 8, embed_dims // 4, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj_bn1 = nn.BatchNorm2d(embed_dims // 4)
        self.proj_lif1 = MyNode(step=step, tau=2.0)
        self.maxpool1 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        self.proj_conv2 = nn.Conv2d(embed_dims // 4, embed_dims // 2, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj_bn2 = nn.BatchNorm2d(embed_dims // 2)
        self.proj_lif2 = MyNode(step=step, tau=2.0)
        self.maxpool2 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        self.proj_conv3 = nn.Conv2d(embed_dims // 2, embed_dims, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj_bn3 = nn.BatchNorm2d(embed_dims)
        # self.proj_lif3 = MyNode(step=step, tau=2.0)
        self.maxpool3 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        self.rpe_conv = nn.Conv2d(embed_dims, embed_dims, kernel_size=3, stride=1, padding=1, bias=False)
        self.rpe_bn = nn.BatchNorm2d(embed_dims)
        self.rpe_lif = MyNode(step=step, tau=2.0)

    def forward(self, x):
        self.reset()  # clear neuron membrane state
        T, B, C, H, W = x.shape
        x = self.proj_conv(x.flatten(0, 1))  # may carry non-binary values
        x = self.proj_bn(x).reshape(T, B, -1, H, W).contiguous()
        x = self.proj_lif(x.flatten(0, 1)).contiguous()
        x = self.maxpool(x)
        x = self.proj_conv1(x)
        x = self.proj_bn1(x).reshape(T, B, -1, H // 2, W // 2).contiguous()
        x = self.proj_lif1(x.flatten(0, 1)).contiguous()
        x = self.maxpool1(x)
        x = self.proj_conv2(x)
        x = self.proj_bn2(x).reshape(T, B, -1, H // 4, W // 4).contiguous()
        x = self.proj_lif2(x.flatten(0, 1)).contiguous()
        x = self.maxpool2(x)
        x = self.proj_conv3(x)
        x = self.proj_bn3(x).reshape(T, B, -1, H // 8, W // 8)
        # abandon the LIF here to leverage the membrane shortcut
        # x = self.proj_lif3(x.flatten(0, 1)).contiguous()
        x = self.maxpool3(x.flatten(0, 1)).reshape(T, B, -1, H // 16, W // 16)
        # Order differs from spikformer because of the membrane shortcut:
        # LIF first, then conv+BN, then add.
        x_rpe = self.rpe_lif(x.flatten(0, 1)).contiguous()
        x_rpe = self.rpe_bn(self.rpe_conv(x_rpe)).reshape(T, B, -1, H // 16, W // 16).contiguous()
        x = x + x_rpe  # membrane shortcut
        # Bug fix: token count is (H//16)*(W//16); the original used
        # (H//16)*(H//16), which only worked for square inputs.
        x = x.reshape(T, B, -1, (H // 16) * (W // 16)).contiguous()
        return x  # T B C N
class Spikformer(BaseModule):
    """Spike-driven Transformer for static images.

    Pipeline: direct encoder (from BaseModule) -> SPS patch embedding ->
    ``depths`` transformer Blocks -> final LIF (membrane shortcut) -> mean
    over tokens and time steps -> linear classification head.
    """

    def __init__(self, step=10, encode_type='direct',
                 img_size_h=224, img_size_w=224, patch_size=16, in_channels=3, num_classes=1000,
                 embed_dims=512, num_heads=12, mlp_ratios=4, qkv_bias=False, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
                 depths=8, sr_ratios=4,
                 ):
        # Bug fix: pass the configured step/encode_type to BaseModule.  The
        # previous hard-coded ``step=10`` made the inherited direct encoder
        # always run 10 time steps regardless of the ``step`` argument.
        super().__init__(step=step, encode_type=encode_type)
        self.step = step  # time step
        self.num_classes = num_classes
        self.depths = depths
        # for membrane shortcut
        self.final_lif = MyNode(step=step, tau=2.0)
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depths)]  # stochastic depth decay rule
        # Plain attribute assignment; the original setattr(f"...") added nothing.
        self.patch_embed = SPS(step=step,
                               img_size_h=img_size_h,
                               img_size_w=img_size_w,
                               patch_size=patch_size,
                               in_channels=in_channels,
                               embed_dims=embed_dims)
        self.block = nn.ModuleList([Block(step=step,
                                          dim=embed_dims, num_heads=num_heads, mlp_ratio=mlp_ratios, qkv_bias=qkv_bias,
                                          qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[j],
                                          norm_layer=norm_layer, sr_ratio=sr_ratios)
                                    for j in range(depths)])
        # classification head
        self.head = nn.Linear(embed_dims, num_classes) if num_classes > 0 else nn.Identity()
        self.apply(self._init_weights)

    @torch.jit.ignore
    def _get_pos_embed(self, pos_embed, patch_embed, H, W):
        # NOTE(review): references ``self.patch_embed1``, which is never
        # defined — this helper looks like dead code carried over from PVT and
        # would raise AttributeError if called.
        if H * W == self.patch_embed1.num_patches:
            return pos_embed
        else:
            return F.interpolate(
                pos_embed.reshape(1, patch_embed.H, patch_embed.W, -1).permute(0, 3, 1, 2),
                size=(H, W), mode="bilinear").reshape(1, -1, H * W).permute(0, 2, 1)

    def _init_weights(self, m):
        """timm-style init: trunc-normal Linear weights, zero bias, unit LayerNorm."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_features(self, x):
        """(T, B, C, H, W) -> (T, B, embed_dims): patch embed, blocks, token mean."""
        x = self.patch_embed(x)
        for blk in self.block:
            x = blk(x)
        # final LIF for the membrane shortcut
        T, B, C, N = x.shape
        x = self.final_lif(x.flatten(0, 1)).reshape(T, B, C, N).contiguous()
        return x.mean(3)

    def forward(self, x):
        self.reset()  # clear neuron membrane state
        x = self.encoder(x)  # direct encoding: repeat input over time steps
        x = self.forward_features(x)
        x = self.head(x.mean(0))  # average over time, then classify
        return x
# Adjust your hyperparameters here
@register_model
def sd_transformer(pretrained=False, **kwargs):
    """Spike-driven Transformer, ImageNet-scale configuration.

    Bug fix: merge ``kwargs`` over the defaults before constructing the model.
    Previously the defaults were passed positionally alongside ``**kwargs``,
    so any overlapping keyword (e.g. ``num_classes`` supplied by
    ``timm.create_model``) raised ``TypeError: got multiple values``.
    """
    cfg = dict(step=4,
               img_size_h=224, img_size_w=224,
               patch_size=16, embed_dims=512, num_heads=16, mlp_ratios=4,
               in_channels=3, num_classes=1000, qkv_bias=False,
               depths=8, sr_ratios=1)
    cfg.update(kwargs)  # caller overrides win
    model = Spikformer(**cfg)
    model.default_cfg = _cfg()
    return model
================================================
FILE: examples/Spiking-Transformers/models/spike_driven_transformer_dvs.py
================================================
import torch
import torch.nn as nn
from timm.models.layers import to_2tuple, trunc_normal_, DropPath
from timm.models.registry import register_model
from timm.models.vision_transformer import _cfg
import torch.nn.functional as F
from braincog.model_zoo.base_module import BaseModule
from braincog.base.node.node import *
from braincog.base.connection.layer import *
from braincog.base.strategy.surrogate import *
from LIFNode import MyNode # LIFNode setting for Spiking Tranformers
from functools import partial
__all__ = ['spikformer']
'''When using BrainCog, the input of neuromorphic datasets for the Spiking
Transformer is usually resized to 64*64.'''
class MLP(BaseModule):
    """Spiking MLP block (DVS variant): two 1x1-Conv1d "linear" layers, each
    followed by BatchNorm and preceded by a LIF node.

    Layout: (T, B, C, N) in and out; the final reshape requires
    ``out_features == in_features``.
    """

    def __init__(self, in_features, step=10, encode_type='direct', hidden_features=None, out_features=None, drop=0.):
        # Bug fix: forward the caller's ``step``/``encode_type`` to BaseModule
        # instead of the previously hard-coded ``step=10, encode_type='direct'``.
        super().__init__(step=step, encode_type=encode_type)
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1_conv = nn.Conv1d(in_features, hidden_features, kernel_size=1, stride=1)
        self.fc1_bn = nn.BatchNorm1d(hidden_features)
        self.fc1_lif = MyNode(step=step, tau=2.0)
        self.fc2_conv = nn.Conv1d(hidden_features, out_features, kernel_size=1, stride=1)
        self.fc2_bn = nn.BatchNorm1d(out_features)
        self.fc2_lif = MyNode(step=step, tau=2.0)
        self.c_hidden = hidden_features
        self.c_output = out_features

    def forward(self, x):
        """x: (T, B, C, N) -> (T, B, C, N)."""
        self.reset()  # clear neuron membrane state
        T, B, C, N = x.shape
        x = self.fc1_lif(x.flatten(0, 1)).reshape(T, B, C, N).contiguous()
        x = self.fc1_conv(x.flatten(0, 1))
        x = self.fc1_bn(x).reshape(T, B, self.c_hidden, N).contiguous()  # T B C N
        x = self.fc2_lif(x.flatten(0, 1)).reshape(T, B, self.c_hidden, N).contiguous()
        x = self.fc2_conv(x.flatten(0, 1))
        x = self.fc2_bn(x).reshape(T, B, C, N).contiguous()
        return x
class SSA(BaseModule):
    """Spike-driven self-attention (DVS variant).

    Q/K/V come from 1x1-Conv1d + BN + LIF branches; attention is the
    softmax-free form ``q * PLIF(sum_N(k * v))``.  Layout: (T, B, C, N).
    """

    def __init__(self, dim, step=10, encode_type='direct', num_heads=16, qkv_bias=False, qk_scale=None, attn_drop=0.,
                 proj_drop=0., sr_ratio=1):
        # Bug fix: forward step/encode_type to BaseModule instead of the
        # hard-coded ``step=10, encode_type='direct'``.
        super().__init__(step=step, encode_type=encode_type)
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
        self.dim = dim
        # LIF in front of the QKV projections (membrane shortcut enters here)
        self.head_lif = MyNode(step=step, tau=2.0)
        self.num_heads = num_heads
        # scale (kept for reference; the spike-driven path below does not use it)
        self.scale = 0.25
        self.q_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)
        self.q_bn = nn.BatchNorm1d(dim)
        self.q_lif = MyNode(step=step, tau=2.0)
        self.k_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)
        self.k_bn = nn.BatchNorm1d(dim)
        self.k_lif = MyNode(step=step, tau=2.0)
        self.v_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)
        self.v_bn = nn.BatchNorm1d(dim)
        self.v_lif = MyNode(step=step, tau=2.0)
        self.attn_drop = nn.Dropout(0.2)
        self.res_lif = MyNode(step=step, tau=2.0)
        self.attn_lif = MyNode(step=step, tau=2.0, v_threshold=0.5, )
        self.proj_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)
        self.proj_bn = nn.BatchNorm1d(dim)
        self.proj_lif = MyNode(step=step, tau=2.0, )
        # gate on the aggregated key-value product
        self.sd_lif = PLIFNode(step=step, threshold=0.5, tau=2)

    def forward(self, x):
        self.reset()  # clear neuron membrane state
        T, B, C, N = x.shape
        x_for_qkv = x.flatten(0, 1)  # TB, C, N
        x_for_qkv = self.head_lif(x_for_qkv)
        q_conv_out = self.q_conv(x_for_qkv)  # [TB] C N
        q_conv_out = self.q_bn(q_conv_out).reshape(T, B, C, N).contiguous()  # T B C N
        q_conv_out = self.q_lif(q_conv_out.flatten(0, 1)).reshape(T, B, C, N)  # TB C N
        q = q_conv_out.reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4).contiguous()
        k_conv_out = self.k_conv(x_for_qkv)
        k_conv_out = self.k_bn(k_conv_out).reshape(T, B, C, N).contiguous()
        k_conv_out = self.k_lif(k_conv_out.flatten(0, 1)).reshape(T, B, C, N)  # TB C N
        k = k_conv_out.reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4).contiguous()
        v_conv_out = self.v_conv(x_for_qkv)
        v_conv_out = self.v_bn(v_conv_out).reshape(T, B, C, N).contiguous()
        v_conv_out = self.v_lif(v_conv_out.flatten(0, 1)).reshape(T, B, C, N)  # TB C N
        v = v_conv_out.reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4).contiguous()
        # Spike-driven Transformer attention: element-wise mask + column sum
        kv = k.mul(v)
        kv = kv.sum(dim=-2, keepdim=True)
        kv = self.sd_lif(kv)
        x = q.mul(kv)
        x = x.transpose(3, 4).reshape(T, B, C, N).contiguous()  # T B C N
        # ignore following lines for membrane shortcut
        # x = self.attn_lif(x.flatten(0,1)) #[TB] C N
        # x = self.proj_lif(self.proj_bn(self.proj_conv(x))).reshape(T, B, C, N) #T B C N
        return x
class Block(nn.Module):
    """Transformer block (DVS variant): spike-driven self-attention followed
    by a spiking MLP, each wrapped in an additive residual connection."""

    def __init__(self, dim, num_heads, step=10, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, sr_ratio=1):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = SSA(dim, step=step, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                        attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
        self.norm2 = norm_layer(dim)
        hidden = int(dim * mlp_ratio)
        self.mlp = MLP(step=step, in_features=dim, hidden_features=hidden, drop=drop)

    def forward(self, x):
        # Residual connections around the attention and MLP sub-blocks.
        out = x + self.attn(x)
        out = out + self.mlp(out)
        return out
# embed_dims = 256
class SPS(BaseModule):
    """Spiking Patch Splitting (DVS variant): four conv+BN(+LIF)+maxpool
    stages reducing resolution by 16x, plus a relative-position-encoding conv
    added through a membrane shortcut.

    Input: (T, B, C, H, W).  Output: (T, B, embed_dims, H//16 * W//16).
    """

    def __init__(self, step=10, encode_type='direct', img_size_h=128, img_size_w=128, patch_size=4, in_channels=2,
                 embed_dims=256):
        # Bug fix: forward step/encode_type to BaseModule instead of the
        # hard-coded ``step=10, encode_type='direct'``.
        super().__init__(step=step, encode_type=encode_type)
        self.image_size = [img_size_h, img_size_w]
        patch_size = to_2tuple(patch_size)  # 4 -> (4, 4)
        self.patch_size = patch_size
        self.C = in_channels  # image channels
        self.H, self.W = self.image_size[0] // patch_size[0], self.image_size[1] // patch_size[1]
        self.num_patches = self.H * self.W
        # DVS variant uses 2 extra MaxPool stages
        self.proj_conv = nn.Conv2d(in_channels, embed_dims // 8, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj_bn = nn.BatchNorm2d(embed_dims // 8)
        self.proj_lif = MyNode(step=step, tau=2.0)
        self.maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        self.proj_conv1 = nn.Conv2d(embed_dims // 8, embed_dims // 4, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj_bn1 = nn.BatchNorm2d(embed_dims // 4)
        self.proj_lif1 = MyNode(step=step, tau=2.0)
        self.maxpool1 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        self.proj_conv2 = nn.Conv2d(embed_dims // 4, embed_dims // 2, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj_bn2 = nn.BatchNorm2d(embed_dims // 2)
        self.proj_lif2 = MyNode(step=step, tau=2.0)
        self.maxpool2 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        self.proj_conv3 = nn.Conv2d(embed_dims // 2, embed_dims, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj_bn3 = nn.BatchNorm2d(embed_dims)
        # self.proj_lif3 = MyNode(step=step, tau=2.0)
        self.maxpool3 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        self.rpe_conv = nn.Conv2d(embed_dims, embed_dims, kernel_size=3, stride=1, padding=1, bias=False)
        self.rpe_bn = nn.BatchNorm2d(embed_dims)
        self.rpe_lif = MyNode(step=step, tau=2.0)

    def forward(self, x):
        self.reset()  # clear neuron membrane state
        T, B, C, H, W = x.shape
        x = self.proj_conv(x.flatten(0, 1))  # may carry non-binary values
        x = self.proj_bn(x).reshape(T, B, -1, H, W).contiguous()
        x = self.proj_lif(x.flatten(0, 1)).contiguous()
        x = self.maxpool(x)
        x = self.proj_conv1(x)
        x = self.proj_bn1(x).reshape(T, B, -1, H // 2, W // 2).contiguous()
        x = self.proj_lif1(x.flatten(0, 1)).contiguous()
        x = self.maxpool1(x)
        x = self.proj_conv2(x)
        x = self.proj_bn2(x).reshape(T, B, -1, H // 4, W // 4).contiguous()
        x = self.proj_lif2(x.flatten(0, 1)).contiguous()
        x = self.maxpool2(x)
        x = self.proj_conv3(x)
        x = self.proj_bn3(x).reshape(T, B, -1, H // 8, W // 8)
        # abandon the LIF here to leverage the membrane shortcut
        # x = self.proj_lif3(x.flatten(0, 1)).contiguous()
        x = self.maxpool3(x.flatten(0, 1)).reshape(T, B, -1, H // 16, W // 16)
        # Order differs from spikformer because of the membrane shortcut:
        # LIF first, then conv+BN, then add.
        x_rpe = self.rpe_lif(x.flatten(0, 1)).contiguous()
        x_rpe = self.rpe_bn(self.rpe_conv(x_rpe)).reshape(T, B, -1, H // 16, W // 16).contiguous()
        x = x + x_rpe  # membrane shortcut
        # Bug fix: token count is (H//16)*(W//16); the original used
        # (H//16)*(H//16), which only worked for square inputs.
        x = x.reshape(T, B, -1, (H // 16) * (W // 16)).contiguous()
        return x  # T B C N
class Spikformer(BaseModule):
    """Spike-driven Transformer for neuromorphic (DVS) data.

    Pipeline: input permuted to (T, B, 2, H, W) -> SPS patch embedding ->
    ``depths`` transformer Blocks -> final LIF (membrane shortcut) -> mean
    over tokens and time steps -> linear classification head.
    """

    def __init__(self, step=10, encode_type='direct',
                 img_size_h=64, img_size_w=64, patch_size=4, in_channels=2, num_classes=10,
                 embed_dims=256, num_heads=16, mlp_ratios=4, qkv_bias=False, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
                 depths=2, sr_ratios=4,
                 ):
        # Bug fix: pass the configured step/encode_type to BaseModule instead
        # of the hard-coded ``step=10, encode_type='direct'``.
        super().__init__(step=step, encode_type=encode_type)
        self.step = step  # time step
        self.num_classes = num_classes
        self.depths = depths
        # for membrane shortcut
        self.final_lif = MyNode(step=step, tau=2.0)
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depths)]  # stochastic depth decay rule
        # Plain attribute assignment; the original setattr(f"...") added nothing.
        self.patch_embed = SPS(step=step,
                               img_size_h=img_size_h,
                               img_size_w=img_size_w,
                               patch_size=patch_size,
                               in_channels=in_channels,
                               embed_dims=embed_dims)
        self.block = nn.ModuleList([Block(step=step,
                                          dim=embed_dims, num_heads=num_heads, mlp_ratio=mlp_ratios, qkv_bias=qkv_bias,
                                          qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[j],
                                          norm_layer=norm_layer, sr_ratio=sr_ratios)
                                    for j in range(depths)])
        # classification head
        self.head = nn.Linear(embed_dims, num_classes) if num_classes > 0 else nn.Identity()
        self.apply(self._init_weights)

    @torch.jit.ignore
    def _get_pos_embed(self, pos_embed, patch_embed, H, W):
        # NOTE(review): references ``self.patch_embed1``, which is never
        # defined — this helper looks like dead code carried over from PVT and
        # would raise AttributeError if called.
        if H * W == self.patch_embed1.num_patches:
            return pos_embed
        else:
            return F.interpolate(
                pos_embed.reshape(1, patch_embed.H, patch_embed.W, -1).permute(0, 3, 1, 2),
                size=(H, W), mode="bilinear").reshape(1, -1, H * W).permute(0, 2, 1)

    def _init_weights(self, m):
        """timm-style init: trunc-normal Linear weights, zero bias, unit LayerNorm."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_features(self, x):
        """(T, B, C, H, W) -> (T, B, embed_dims): patch embed, blocks, token mean."""
        x = self.patch_embed(x)
        for blk in self.block:
            x = blk(x)
        # final LIF for the membrane shortcut
        T, B, C, N = x.shape
        x = self.final_lif(x.flatten(0, 1)).reshape(T, B, C, N).contiguous()
        return x.mean(3)

    def forward(self, x):
        self.reset()  # clear neuron membrane state
        # DVS input arrives as (B, T, 2, H, W); move time to the leading axis.
        x = x.permute(1, 0, 2, 3, 4)  # [T, N, 2, *, *]
        x = self.forward_features(x)
        x = self.head(x.mean(0))  # average over time, then classify
        return x
# Adjust your hyperparameters here
@register_model
def sd_transformer_dvs(pretrained=False, **kwargs):
    """Spike-driven Transformer, DVS (64x64, 2-channel) configuration.

    Bug fix: merge ``kwargs`` over the defaults before constructing the model.
    Previously the defaults were passed alongside ``**kwargs``, so any
    overlapping keyword (e.g. ``num_classes`` from ``timm.create_model``)
    raised ``TypeError: got multiple values``.
    """
    cfg = dict(step=8,
               img_size_h=64, img_size_w=64,
               patch_size=4, embed_dims=256, num_heads=16, mlp_ratios=4,
               in_channels=2, num_classes=10, qkv_bias=False,
               depths=2, sr_ratios=1)
    cfg.update(kwargs)  # caller overrides win
    model = Spikformer(**cfg)
    model.default_cfg = _cfg()
    return model
================================================
FILE: examples/Spiking-Transformers/models/spike_driven_transformer_v2.py
================================================
import torch
import torch.nn as nn
from timm.models.layers import to_2tuple, trunc_normal_, DropPath
from timm.models.registry import register_model
from timm.models.vision_transformer import _cfg
import torch.nn.functional as F
from braincog.model_zoo.base_module import BaseModule
from braincog.base.node.node import *
from braincog.base.connection.layer import *
from braincog.base.strategy.surrogate import *
from LIFNode import MyNode # LIFNode setting for Spiking Tranformers
from functools import partial
__all__ = ['spikformer']
'''When using BrainCog, the input of neuromorphic datasets for the Spiking
Transformer is usually resized to 64*64.'''
'''Note: the second version of the Spike-driven Transformer only open-sourced
the code for image classification.'''
# Modified Operators
class BNAndPadLayer(nn.Module):
    """BatchNorm2d followed by spatial padding whose fill value equals the BN
    output for a zero input, so the padded border is consistent with the
    normalization (used inside RepConv re-parameterization)."""

    def __init__(
        self,
        pad_pixels,
        num_features,
        eps=1e-5,
        momentum=0.1,
        affine=True,
        track_running_stats=True,
    ):
        super(BNAndPadLayer, self).__init__()
        self.bn = nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats)
        self.pad_pixels = pad_pixels

    def forward(self, input):
        output = self.bn(input)
        p = self.pad_pixels
        if p > 0:
            # Fill value = BN(0): bias - mean * weight / sqrt(var + eps)
            running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
            if self.bn.affine:
                pad_values = (
                    self.bn.bias.detach()
                    - self.bn.running_mean * self.bn.weight.detach() / running_std
                )
            else:
                pad_values = -self.bn.running_mean / running_std
            output = F.pad(output, [p] * 4)
            fill = pad_values.view(1, -1, 1, 1)
            output[:, :, 0:p, :] = fill
            output[:, :, -p:, :] = fill
            output[:, :, :, 0:p] = fill
            output[:, :, :, -p:] = fill
        return output

    # Expose the wrapped BN's parameters/statistics for re-parameterization.
    @property
    def weight(self):
        return self.bn.weight

    @property
    def bias(self):
        return self.bn.bias

    @property
    def running_mean(self):
        return self.bn.running_mean

    @property
    def running_var(self):
        return self.bn.running_var

    @property
    def eps(self):
        return self.bn.eps
class RepConv(nn.Module):
    """Re-parameterizable convolution: 1x1 conv -> BN with BN-consistent
    padding -> (depthwise 3x3 -> pointwise 1x1 -> BN).

    ``bias`` is accepted for interface compatibility but unused, matching the
    original implementation.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        bias=False,
    ):
        super().__init__()
        # hidden channel count equals the input channel count
        stem = nn.Conv2d(in_channels, in_channels, 1, 1, 0, bias=False, groups=1)
        pad_bn = BNAndPadLayer(pad_pixels=1, num_features=in_channels)
        dw_pw = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 3, 1, 0, groups=in_channels, bias=False),
            nn.Conv2d(in_channels, out_channels, 1, 1, 0, groups=1, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        self.body = nn.Sequential(stem, pad_bn, dw_pw)

    def forward(self, x):
        return self.body(x)
class SepConv(BaseModule):
    r"""
    Inverted separable convolution from MobileNetV2: https://arxiv.org/abs/1801.04381.

    Spiking variant: LIF -> 1x1 pointwise expand -> BN -> LIF -> depthwise
    conv -> 1x1 pointwise project -> BN.  Layout: (T, B, C, H, W) in and out.
    """

    def __init__(
        self,
        dim,
        step=8,
        encode_type='direct',
        expansion_ratio=2,
        act2_layer=nn.Identity,  # NOTE(review): unused — kept for interface compatibility
        bias=False,
        kernel_size=7,
        padding=3,
    ):
        super().__init__(step=step, encode_type=encode_type,)
        med_channels = int(expansion_ratio * dim)
        self.lif1 = MyNode(step=step, tau=2.0)
        self.pwconv1 = nn.Conv2d(dim, med_channels, kernel_size=1, stride=1, bias=bias)
        self.bn1 = nn.BatchNorm2d(med_channels)
        self.lif2 = MyNode(step=step, tau=2.0)
        self.dwconv = nn.Conv2d(
            med_channels,
            med_channels,
            kernel_size=kernel_size,
            padding=padding,
            groups=med_channels,
            bias=bias,
        )  # depthwise conv
        self.pwconv2 = nn.Conv2d(med_channels, dim, kernel_size=1, stride=1, bias=bias)
        self.bn2 = nn.BatchNorm2d(dim)

    def forward(self, x):
        # Reset neuron state, then run the LIF/conv chain time-flattened.
        self.reset()
        T, B, C, H, W = x.shape
        x = self.lif1(x.flatten(0, 1)).reshape(T, B, C, H, W).contiguous()
        x = self.bn1(self.pwconv1(x.flatten(0, 1))).reshape(T, B, -1, H, W)
        x = self.lif2(x.flatten(0, 1)).reshape(T, B, -1, H, W).contiguous()
        x = self.dwconv(x.flatten(0, 1))
        x = self.bn2(self.pwconv2(x)).reshape(T, B, -1, H, W)
        return x  # T B C H W
class MLP(BaseModule):
    """Spiking MLP block (V2): two 1x1-Conv1d "linear" layers over flattened
    spatial tokens, each followed by BatchNorm and preceded by a LIF node.

    Input/output layout: (T, B, C, H, W); internally spatial dims are
    flattened to N = H*W.  The final reshape requires
    ``out_features == in_features``.
    """

    def __init__(self, in_features, step=10, encode_type='direct', hidden_features=None, out_features=None, drop=0.):
        # Bug fix: forward the caller's ``step``/``encode_type`` to BaseModule
        # instead of the previously hard-coded ``step=10, encode_type='direct'``.
        super().__init__(step=step, encode_type=encode_type)
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1_conv = nn.Conv1d(in_features, hidden_features, kernel_size=1, stride=1)
        self.fc1_bn = nn.BatchNorm1d(hidden_features)
        self.fc1_lif = MyNode(step=step, tau=2.0)
        self.fc2_conv = nn.Conv1d(hidden_features, out_features, kernel_size=1, stride=1)
        self.fc2_bn = nn.BatchNorm1d(out_features)
        self.fc2_lif = MyNode(step=step, tau=2.0)
        self.c_hidden = hidden_features
        self.c_output = out_features

    def forward(self, x):
        """x: (T, B, C, H, W) -> (T, B, C, H, W)."""
        self.reset()  # clear neuron membrane state
        T, B, C, H, W = x.shape
        x = x.flatten(3)  # T B C N
        _, _, _, N = x.shape
        x = self.fc1_lif(x.flatten(0, 1)).reshape(T, B, C, N).contiguous()
        x = self.fc1_conv(x.flatten(0, 1))
        x = self.fc1_bn(x).reshape(T, B, self.c_hidden, N).contiguous()  # T B C N
        x = self.fc2_lif(x.flatten(0, 1)).reshape(T, B, self.c_hidden, N).contiguous()
        x = self.fc2_conv(x.flatten(0, 1))
        x = self.fc2_bn(x).reshape(T, B, C, H, W).contiguous()
        return x  # T B C H W
# convs in SDSA V3/V4 should be substituted
class SDSA(BaseModule):
    """Spike-driven self-attention V2: Q/K/V from RepConv + BN + LIF branches,
    linear attention ``(q @ (k^T @ v)) * scale`` with a spiking non-linearity,
    followed by a RepConv projection.  Layout: (T, B, C, H, W).

    Note: convs in SDSA V3/V4 should be substituted.
    """

    def __init__(self, dim, step=10, encode_type='direct', num_heads=16, qkv_bias=False, qk_scale=None, attn_drop=0.,
                 proj_drop=0., sr_ratio=1):
        # Bug fix: forward step/encode_type to BaseModule instead of the
        # hard-coded ``step=10, encode_type='direct'``.
        super().__init__(step=step, encode_type=encode_type)
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
        self.dim = dim
        self.num_heads = num_heads
        # attention scale
        self.scale = 0.125
        self.head_lif = MyNode(step=step, tau=2.0)  # for spike-driven input
        self.q_conv = RepConv(dim, dim, bias=False)
        self.q_bn = nn.BatchNorm2d(dim)
        self.q_lif = MyNode(step=step, tau=2.0)
        self.k_conv = RepConv(dim, dim, bias=False)
        self.k_bn = nn.BatchNorm2d(dim)
        self.k_lif = MyNode(step=step, tau=2.0)
        self.v_conv = RepConv(dim, dim, bias=False)
        self.v_bn = nn.BatchNorm2d(dim)
        self.v_lif = MyNode(step=step, tau=2.0)
        self.attn_drop = nn.Dropout(0.2)
        self.res_lif = MyNode(step=step, tau=2.0)
        self.attn_lif = MyNode(step=step, tau=2.0, v_threshold=0.5, )
        self.proj_conv = RepConv(dim, dim, bias=False)
        self.proj_bn = nn.BatchNorm2d(dim)

    def forward(self, x):
        self.reset()  # clear neuron membrane state
        # Unlike V1, the block keeps the (T, B, C, H, W) layout throughout.
        T, B, C, H, W = x.shape
        N = H * W
        x = self.head_lif(x.flatten(0, 1)).reshape(T, B, C, H, W).contiguous()
        x_for_qkv = x.flatten(0, 1)  # TB C H W
        q_conv_out = self.q_conv(x_for_qkv)  # [TB] C H W
        q_conv_out = self.q_bn(q_conv_out).reshape(T, B, C, H, W).contiguous()  # T B C H W
        q_conv_out = self.q_lif(q_conv_out.flatten(0, 1)).reshape(T, B, C, N).transpose(-1, -2)  # T B N C
        q = q_conv_out.reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4).contiguous()
        k_conv_out = self.k_conv(x_for_qkv)
        k_conv_out = self.k_bn(k_conv_out).reshape(T, B, C, H, W).contiguous()
        k_conv_out = self.k_lif(k_conv_out.flatten(0, 1)).reshape(T, B, C, N).transpose(-1, -2)  # T B N C
        k = k_conv_out.reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4).contiguous()
        v_conv_out = self.v_conv(x_for_qkv)
        v_conv_out = self.v_bn(v_conv_out).reshape(T, B, C, H, W).contiguous()
        v_conv_out = self.v_lif(v_conv_out.flatten(0, 1)).reshape(T, B, C, N).transpose(-1, -2)  # T B N C
        v = v_conv_out.reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4).contiguous()
        # Linear attention: aggregate K^T V once, then project through Q.
        x = k.transpose(-2, -1) @ v
        x = (q @ x) * self.scale
        x = x.transpose(3, 4).reshape(T, B, C, N).contiguous()
        # NOTE(review): attn_lif is applied without time-flattening here,
        # unlike every other node in this file — confirm MyNode treats
        # (T, B, C, N) input consistently with (T*B, C, N).
        x = self.attn_lif(x).reshape(T, B, C, H, W)
        # (removed a second, literally redundant reshape to the same shape)
        x = x.flatten(0, 1)
        x = self.proj_conv(x)
        x = self.proj_bn(x).reshape(T, B, C, H, W)
        return x  # T B C H W
class Block(nn.Module):
    """Transformer block (V2): spike-driven self-attention (SDSA) followed by
    a spiking MLP, each wrapped in an additive residual connection."""

    def __init__(self, dim, num_heads, step=10, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, sr_ratio=1):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = SDSA(dim, step=step, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                         attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
        self.norm2 = norm_layer(dim)
        hidden = int(dim * mlp_ratio)
        self.mlp = MLP(step=step, in_features=dim, hidden_features=hidden, drop=drop)

    def forward(self, x):
        # Residual connections around the attention and MLP sub-blocks.
        out = x + self.attn(x)
        out = out + self.mlp(out)
        return out
class DownSampling(BaseModule):
    """Strided conv + BN downsampling stage.  A LIF node precedes the conv in
    every stage except the first, which receives raw (non-spike) input."""

    def __init__(
        self,
        step=10,
        encode_type='direct',
        in_channels=2,
        embed_dims=512,
        kernel_size=3,
        stride=2,
        padding=1,
        first_layer=True,
    ):
        super().__init__(step=step,
                         encode_type=encode_type,)
        self.encode_conv = nn.Conv2d(
            in_channels,
            embed_dims,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.encode_bn = nn.BatchNorm2d(embed_dims)
        # First stage gets raw input, so no spiking node before the conv.
        if not first_layer:
            self.encode_lif = MyNode(tau=2.0, step=step)

    def forward(self, x):
        self.reset()  # clear neuron membrane state
        T, B, C, H, W = x.shape
        if hasattr(self, "encode_lif"):
            x = self.encode_lif(x.flatten(0, 1)).reshape(T, B, C, H, W).contiguous()
        feat = self.encode_conv(x.flatten(0, 1))
        _, _, H_out, W_out = feat.shape
        return self.encode_bn(feat).reshape(T, B, -1, H_out, W_out).contiguous()
class ConvBlock(BaseModule):
    """Conv stage block: SepConv with residual, then a LIF->3x3 conv expand /
    LIF->3x3 conv project pair with a second residual.  Layout: (T, B, C, H, W).
    """

    def __init__(
        self,
        dim,
        step=10,
        encode_type='direct',
        mlp_ratio=4.0,
    ):
        super().__init__(step=step,
                         encode_type=encode_type,)
        self.Conv = SepConv(step=step, dim=dim)
        # self.Conv = MHMC(dim=dim)
        self.lif1 = MyNode(step=step, tau=2.0)
        # Bug fix: the default mlp_ratio is a float (4.0) and nn.Conv2d
        # requires integer channel counts — cast the expanded width explicitly.
        hidden_dim = int(dim * mlp_ratio)
        self.conv1 = nn.Conv2d(
            dim, hidden_dim, kernel_size=3, padding=1, groups=1, bias=False
        )
        # self.conv1 = RepConv(dim, hidden_dim)
        self.bn1 = nn.BatchNorm2d(hidden_dim)
        self.lif2 = MyNode(step=step, tau=2.0)
        self.conv2 = nn.Conv2d(
            hidden_dim, dim, kernel_size=3, padding=1, groups=1, bias=False
        )
        # self.conv2 = RepConv(hidden_dim, dim)
        self.bn2 = nn.BatchNorm2d(dim)

    def forward(self, x):
        self.reset()  # clear neuron membrane state
        T, B, C, H, W = x.shape
        x = self.Conv(x) + x  # SepConv residual
        x_feat = x
        # Bug fix: reshape with -1 instead of the hard-coded ``4 * C`` so the
        # block works for any mlp_ratio, not only 4.
        x = self.bn1(self.conv1(self.lif1(x.flatten(0, 1)))).reshape(T, B, -1, H, W)
        x = self.bn2(self.conv2(self.lif2(x.flatten(0, 1)))).reshape(T, B, C, H, W)
        x = x_feat + x  # second residual
        return x
class Spikformer(BaseModule):
def __init__(self, step=4, encode_type='direct',
img_size_h=64, img_size_w=64, patch_size=4, in_channels=2, num_classes=1000,
embed_dims=512, num_heads=16, mlp_ratios=4, qkv_bias=False, qk_scale=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
depths=8, sr_ratios=4,kd=False,
):
super().__init__(step=10, encode_type='direct')
self.step = step # time step
self.num_classes = num_classes
self.depths = depths
self.block3_depths = 6
# for membrane shortcut
self.final_lif = MyNode(step=step,tau=2.0)
# channel for dvs
# 16 32 64 128 256
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depths)] # stochastic depth decay rule
self.downsample1_1 = DownSampling(
step=step,
in_channels=in_channels,
embed_dims=embed_dims // 16,
kernel_size=7,
stride=2,
padding=3,
first_layer=True,
)
self.ConvBlock1_1 = nn.ModuleList(
[ConvBlock(step=step,dim= embed_dims // 16, mlp_ratio=mlp_ratios)]
)
self.downsample1_2 = DownSampling(
step=step,
in_channels = embed_dims // 16,
embed_dims= embed_dims // 8,
kernel_size=3,
stride=2,
padding=1,
first_layer=False,
)
self.ConvBlock1_2 = nn.ModuleList(
[ConvBlock(step=step,dim=embed_dims // 8, mlp_ratio=mlp_ratios)]
)
self.downsample2 = DownSampling(
step=step,
in_channels=embed_dims // 8,
embed_dims=embed_dims // 4,
kernel_size=3,
stride=2,
padding=1,
first_layer=False,
)
self.ConvBlock2_1 = nn.ModuleList(
[ConvBlock(step=step,dim=embed_dims // 4, mlp_ratio=mlp_ratios)]
)
self.ConvBlock2_2 = nn.ModuleList(
[ConvBlock(step=step,dim=embed_dims // 4, mlp_ratio=mlp_ratios)]
)
self.downsample3 = DownSampling(
step=step,
in_channels=embed_dims // 4,
embed_dims=embed_dims // 2,
kernel_size=3,
stride=2,
padding=1,
first_layer=False,
)
self.block3 = nn.ModuleList(
[
Block(
step=step,
dim=embed_dims // 2,
num_heads=num_heads,
mlp_ratio=mlp_ratios,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
# drop_path=dpr[j],
norm_layer=norm_layer,
sr_ratio=sr_ratios,
)
for j in range(self.block3_depths)
]
)
self.downsample4 = DownSampling(
step=step,
in_channels=embed_dims // 2,
embed_dims=embed_dims,
kernel_size=3,
stride=1,
padding=1,
first_layer=False,
)
self.block4 = nn.ModuleList(
[
Block(
step=step,
dim=embed_dims,
num_heads=num_heads,
mlp_ratio=mlp_ratios,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[j],
norm_layer=norm_layer,
sr_ratio=sr_ratios,
)
for j in range(self.depths-self.block3_depths)
]
)
# classification head
self.lif = MyNode(step=step,tau=2.0,)
self.head = (
nn.Linear(embed_dims, num_classes) if num_classes > 0 else nn.Identity()
)
self.kd = kd
if self.kd:
self.head_kd = (
nn.Linear(embed_dims, num_classes)
if num_classes > 0
else nn.Identity()
)
self.apply(self._init_weights)
# setattr(self, f"patch_embed", patch_embed)
# setattr(self, f"block", block)
    @torch.jit.ignore
    def _get_pos_embed(self, pos_embed, patch_embed, H, W):
        """Resize a positional-embedding table to a new (H, W) token grid.

        Returns ``pos_embed`` unchanged when the grid already matches the
        embedded patch count; otherwise bilinearly interpolates it to H*W
        tokens.

        NOTE(review): ``self.patch_embed1`` is never assigned anywhere in this
        class, so calling this would raise AttributeError — it looks like dead
        code carried over from a PVT-style implementation; confirm before use.
        """
        if H * W == self.patch_embed1.num_patches:
            return pos_embed
        else:
            return F.interpolate(
                pos_embed.reshape(1, patch_embed.H, patch_embed.W, -1).permute(0, 3, 1, 2),
                size=(H, W), mode="bilinear").reshape(1, -1, H * W).permute(0, 2, 1)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward_features(self, x):
x = self.downsample1_1(x)
for blk in self.ConvBlock1_1:
x = blk(x)
x = self.downsample1_2(x)
for blk in self.ConvBlock1_2:
x = blk(x)
x = self.downsample2(x)
for blk in self.ConvBlock2_1:
x = blk(x)
for blk in self.ConvBlock2_2:
x = blk(x)
x = self.downsample3(x)
for blk in self.block3: # attention here
x = blk(x)
x = self.downsample4(x) # attention here
for blk in self.block4:
x = blk(x)
return x # T,B,C,H,W
    def forward(self, x):
        """Classify a static-image batch.

        Encodes input into time steps, extracts features, pools spatially,
        applies the output LIF, and averages the head logits over time.
        With ``kd`` on, returns ``(logits, kd_logits)`` in training mode and
        their average in eval mode.
        """
        self.reset()
        x = self.encoder(x) # [T, N, 2, *, *]
        x = self.forward_features(x)
        # global average pool over spatial dims -> [T, B, C]
        x = x.flatten(3).mean(3)
        T,B,_ = x.shape
        x_lif = self.lif(x.flatten(0,1)).reshape(T,B,-1)
        # average logits over the time dimension
        x = self.head(x_lif).mean(0)
        if self.kd:
            x_kd = self.head_kd(x_lif).mean(0)
            if self.training:
                return x, x_kd
            else:
                return (x + x_kd) / 2
        return x
# Adjust your hyperparameters here
@register_model
def sd_transformer_v2(pretrained=False, **kwargs):
    """Build a Spike-driven Transformer v2 for 224x224 RGB classification.

    ``pretrained`` is accepted for timm-registry compatibility but unused.
    """
    model = Spikformer(step = 4,
                       img_size_h=224, img_size_w=224,
                       patch_size=16, embed_dims=512, num_heads=12, mlp_ratios=4,
                       in_channels=3, num_classes=1000, qkv_bias=False,
                       depths=2, sr_ratios=1,
                       **kwargs
                       )
    model.default_cfg = _cfg()
    return model
================================================
FILE: examples/Spiking-Transformers/models/spike_driven_transformer_v2_dvs.py
================================================
import torch
import torch.nn as nn
from timm.models.layers import to_2tuple, trunc_normal_, DropPath
from timm.models.registry import register_model
from timm.models.vision_transformer import _cfg
import torch.nn.functional as F
from braincog.model_zoo.base_module import BaseModule
from braincog.base.node.node import *
from braincog.base.connection.layer import *
from braincog.base.strategy.surrogate import *
from LIFNode import MyNode # LIFNode setting for Spiking Tranformers
from functools import partial
__all__ = ['spikformer']
'''The input shape of neuromorphic datasets in Spiking Transformers, when using
BrainCog, is usually set to 64*64.'''
'''The second version of the Spike-driven Transformer only open-sourced the
code for image classification.'''
# Modified Operators
class BNAndPadLayer(nn.Module):
    """BatchNorm2d followed by padding with the BN-of-zero value.

    After normalizing, the feature map is padded by ``pad_pixels`` on every
    side with the value an all-zero input would produce under this BN
    (``bias - running_mean * weight / sqrt(running_var + eps)``), so a
    following valid convolution behaves like pad-then-conv on the pre-BN
    features.  Used by RepConv for re-parameterizable branches.
    """
    def __init__(
        self,
        pad_pixels,  # border width padded on each side (0 disables padding)
        num_features,
        eps=1e-5,
        momentum=0.1,
        affine=True,
        track_running_stats=True,
    ):
        super(BNAndPadLayer, self).__init__()
        self.bn = nn.BatchNorm2d(
            num_features, eps, momentum, affine, track_running_stats
        )
        self.pad_pixels = pad_pixels
    def forward(self, input):
        output = self.bn(input)
        if self.pad_pixels > 0:
            if self.bn.affine:
                # the value BN maps zero input to, using running statistics
                pad_values = (
                    self.bn.bias.detach()
                    - self.bn.running_mean
                    * self.bn.weight.detach()
                    / torch.sqrt(self.bn.running_var + self.bn.eps)
                )
            else:
                pad_values = -self.bn.running_mean / torch.sqrt(
                    self.bn.running_var + self.bn.eps
                )
            output = F.pad(output, [self.pad_pixels] * 4)
            pad_values = pad_values.view(1, -1, 1, 1)
            # overwrite the zero padding on all four borders
            output[:, :, 0 : self.pad_pixels, :] = pad_values
            output[:, :, -self.pad_pixels :, :] = pad_values
            output[:, :, :, 0 : self.pad_pixels] = pad_values
            output[:, :, :, -self.pad_pixels :] = pad_values
        return output
    # Expose the wrapped BN's parameters/statistics so this layer can be
    # treated as a drop-in BatchNorm2d by re-parameterization code.
    @property
    def weight(self):
        return self.bn.weight
    @property
    def bias(self):
        return self.bn.bias
    @property
    def running_mean(self):
        return self.bn.running_mean
    @property
    def running_var(self):
        return self.bn.running_var
    @property
    def eps(self):
        return self.bn.eps
class RepConv(nn.Module):
    """Re-parameterizable conv block: 1x1 mix -> padding BN -> depthwise 3x3
    -> pointwise 1x1 -> BN.  All convolutions are bias-free; the ``bias``
    argument is kept for interface compatibility.
    """

    def __init__(self, in_channels, out_channels, bias=False):
        super().__init__()
        channel_mix = nn.Conv2d(in_channels, in_channels, 1, 1, 0, bias=False, groups=1)
        pad_bn = BNAndPadLayer(pad_pixels=1, num_features=in_channels)
        spatial = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 3, 1, 0, groups=in_channels, bias=False),
            nn.Conv2d(in_channels, out_channels, 1, 1, 0, groups=1, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        self.body = nn.Sequential(channel_mix, pad_bn, spatial)

    def forward(self, x):
        return self.body(x)
class SepConv(BaseModule):
    r"""
    Inverted separable convolution from MobileNetV2: https://arxiv.org/abs/1801.04381.

    Structure: LIF -> 1x1 expand -> BN -> LIF -> depthwise kxk -> 1x1 project -> BN.
    Input/output layout is [T, B, C, H, W]; the channel count is preserved.
    """
    def __init__(
        self,
        dim,
        step=8,
        encode_type='direct',
        expansion_ratio=2,  # width multiplier for the hidden (expanded) channels
        act2_layer=nn.Identity,  # accepted but never used in this implementation
        bias=False,
        kernel_size=7,
        padding=3,
    ):
        super().__init__(step=step,encode_type=encode_type,)
        med_channels = int(expansion_ratio * dim)
        self.lif1 = MyNode(step=step,tau=2.0)
        self.pwconv1 = nn.Conv2d(dim, med_channels, kernel_size=1, stride=1, bias=bias)
        self.bn1 = nn.BatchNorm2d(med_channels)
        self.lif2 =MyNode(step=step,tau=2.0)
        self.dwconv = nn.Conv2d(
            med_channels,
            med_channels,
            kernel_size=kernel_size,
            padding=padding,
            groups=med_channels,
            bias=bias,
        ) # depthwise conv
        self.pwconv2 = nn.Conv2d(med_channels, dim, kernel_size=1, stride=1, bias=bias)
        self.bn2 = nn.BatchNorm2d(dim)
    def forward(self, x):
        self.reset()
        T, B, C, H, W = x.shape
        # spike, expand channels, normalize
        x = self.lif1(x.flatten(0,1)).reshape(T,B,C,H,W).contiguous()
        x = self.bn1(self.pwconv1(x.flatten(0, 1))).reshape(T, B, -1, H, W)
        # spike, depthwise spatial conv, project back to ``dim`` channels
        x = self.lif2(x.flatten(0,1)).reshape(T,B,-1,H,W).contiguous()
        x = self.dwconv(x.flatten(0, 1))
        x = self.bn2(self.pwconv2(x)).reshape(T, B, -1, H, W)
        return x # T B C H W
class MLP(BaseModule):
    """Conv-based MLP over [T, B, C, H, W]: 1x1 Conv1d -> BN -> LIF, twice.

    Linear layers are substituted by 1x1 Conv1d applied over the flattened
    spatial dimension.
    """

    def __init__(self, in_features, step=10, encode_type='direct', hidden_features=None, out_features=None, drop=0.):
        # BUGFIX: the base class was previously initialized with hard-coded
        # step=10/encode_type='direct', silently ignoring the constructor args.
        super().__init__(step=step, encode_type=encode_type)
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1_conv = nn.Conv1d(in_features, hidden_features, kernel_size=1, stride=1)
        self.fc1_bn = nn.BatchNorm1d(hidden_features)
        self.fc1_lif = MyNode(step=step, tau=2.0)
        self.fc2_conv = nn.Conv1d(hidden_features, out_features, kernel_size=1, stride=1)
        self.fc2_bn = nn.BatchNorm1d(out_features)
        self.fc2_lif = MyNode(step=step, tau=2.0)
        self.c_hidden = hidden_features
        self.c_output = out_features

    def forward(self, x):
        self.reset()
        T, B, C, H, W = x.shape
        x = x.flatten(3)  # T B C N
        _, _, _, N = x.shape
        x = self.fc1_lif(x.flatten(0, 1)).reshape(T, B, C, N).contiguous()
        x = self.fc1_conv(x.flatten(0, 1))
        x = self.fc1_bn(x).reshape(T, B, self.c_hidden, N).contiguous()  # T B C N
        x = self.fc2_lif(x.flatten(0, 1)).reshape(T, B, self.c_hidden, N).contiguous()
        x = self.fc2_conv(x.flatten(0, 1))
        # use c_output (not the input C) so the block stays correct when
        # out_features != in_features; identical when they match (the default)
        x = self.fc2_bn(x).reshape(T, B, self.c_output, H, W).contiguous()
        return x  # T B C H W
# convs in SDSA V3/V4 should be substituted
class SDSA(BaseModule):
    """Spike-Driven Self-Attention (v2) over [T, B, C, H, W] feature maps.

    q/k/v come from RepConv + BN + LIF branches; attention aggregates
    ``k^T @ v`` per head, multiplies by ``q``, rescales by a fixed factor,
    then projects with RepConv + BN.

    NOTE: ``attn_drop`` / ``res_lif`` are constructed but unused in
    ``forward``; they are kept for state-dict compatibility.
    """

    def __init__(self, dim, step=10, encode_type='direct', num_heads=16, qkv_bias=False, qk_scale=None, attn_drop=0.,
                 proj_drop=0., sr_ratio=1):
        # BUGFIX: the base class was previously initialized with hard-coded
        # step=10/encode_type='direct', silently ignoring the constructor args.
        super().__init__(step=step, encode_type=encode_type)
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
        self.dim = dim
        self.num_heads = num_heads
        # fixed attention scale
        self.scale = 0.125
        self.head_lif = MyNode(step=step, tau=2.0)  # for spike-drivens
        self.q_conv = RepConv(dim, dim, bias=False)
        self.q_bn = nn.BatchNorm2d(dim)
        self.q_lif = MyNode(step=step, tau=2.0)
        self.k_conv = RepConv(dim, dim, bias=False)
        self.k_bn = nn.BatchNorm2d(dim)
        self.k_lif = MyNode(step=step, tau=2.0)
        self.v_conv = RepConv(dim, dim, bias=False)
        self.v_bn = nn.BatchNorm2d(dim)
        self.v_lif = MyNode(step=step, tau=2.0)
        self.attn_drop = nn.Dropout(0.2)
        self.res_lif = MyNode(step=step, tau=2.0)
        self.attn_lif = MyNode(step=step, tau=2.0, v_threshold=0.5, )
        self.proj_conv = RepConv(dim, dim, bias=False)
        self.proj_bn = nn.BatchNorm2d(dim)

    def forward(self, x):
        self.reset()
        T, B, C, H, W = x.shape
        N = H * W
        x = self.head_lif(x.flatten(0, 1)).reshape(T, B, C, H, W).contiguous()
        x_for_qkv = x.flatten(0, 1)  # TB C H W
        q_conv_out = self.q_conv(x_for_qkv)  # [TB] C H W
        q_conv_out = self.q_bn(q_conv_out).reshape(T, B, C, H, W).contiguous()  # T B C H W
        q_conv_out = self.q_lif(q_conv_out.flatten(0, 1)).reshape(T, B, C, N).transpose(-1, -2)  # T B N C
        q = q_conv_out.reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4).contiguous()
        k_conv_out = self.k_conv(x_for_qkv)
        k_conv_out = self.k_bn(k_conv_out).reshape(T, B, C, H, W).contiguous()
        k_conv_out = self.k_lif(k_conv_out.flatten(0, 1)).reshape(T, B, C, N).transpose(-1, -2)  # T B N C
        k = k_conv_out.reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4).contiguous()
        v_conv_out = self.v_conv(x_for_qkv)
        v_conv_out = self.v_bn(v_conv_out).reshape(T, B, C, H, W).contiguous()
        v_conv_out = self.v_lif(v_conv_out.flatten(0, 1)).reshape(T, B, C, N).transpose(-1, -2)  # T B N C
        v = v_conv_out.reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4).contiguous()
        # linear-complexity spike attention: aggregate channels first
        x = k.transpose(-2, -1) @ v
        x = (q @ x) * self.scale
        x = x.transpose(3, 4).reshape(T, B, C, N).contiguous()
        # NOTE(review): attn_lif receives a [T, B, C, N] tensor here while every
        # other MyNode call gets a flattened [T*B, ...] tensor — confirm MyNode
        # handles this layout as intended.
        x = self.attn_lif(x).reshape(T, B, C, H, W)
        x = x.flatten(0, 1)
        x = self.proj_conv(x)
        x = self.proj_bn(x).reshape(T, B, C, H, W)
        return x  # T B C H W
class Block(nn.Module):
    """Transformer block: SDSA attention and conv MLP, each with a residual add."""

    def __init__(self, dim, num_heads, step=10, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, sr_ratio=1):
        super().__init__()
        hidden_dim = int(dim * mlp_ratio)
        self.norm1 = norm_layer(dim)
        self.attn = SDSA(dim, step=step, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                         attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
        self.norm2 = norm_layer(dim)
        self.mlp = MLP(step=step, in_features=dim, hidden_features=hidden_dim, drop=drop)

    def forward(self, x):
        # membrane-shortcut residual connections
        out = x + self.attn(x)
        out = out + self.mlp(out)
        return out
class DownSampling(BaseModule):
    """Strided Conv2d + BN stage, optionally preceded by a LIF spike layer.

    When ``first_layer`` is True the raw input is fed straight into the
    convolution (no spiking neuron in front); otherwise an ``encode_lif``
    is created and applied first.  Layout: [T, B, C, H, W] in and out.
    """

    def __init__(self, step=10, encode_type='direct', in_channels=2, embed_dims=256,
                 kernel_size=3, stride=2, padding=1, first_layer=True):
        super().__init__(step=step, encode_type=encode_type)
        self.encode_conv = nn.Conv2d(
            in_channels,
            embed_dims,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.encode_bn = nn.BatchNorm2d(embed_dims)
        if not first_layer:
            # only non-first stages spike before downsampling
            self.encode_lif = MyNode(tau=2.0, step=step)

    def forward(self, x):
        self.reset()
        T, B, C, H, W = x.shape
        spike = getattr(self, "encode_lif", None)
        if spike is not None:
            x = spike(x.flatten(0, 1)).reshape(T, B, C, H, W).contiguous()
        feat = self.encode_conv(x.flatten(0, 1))
        _, _, H, W = feat.shape
        return self.encode_bn(feat).reshape(T, B, -1, H, W).contiguous()
class ConvBlock(BaseModule):
    """SepConv followed by two 3x3 conv/BN/LIF layers, each residual-added.

    Channel count is preserved; the hidden width is ``int(dim * mlp_ratio)``.
    """

    def __init__(
        self,
        dim,
        step=10,
        encode_type='direct',
        mlp_ratio=4.0,
    ):
        super().__init__(step=step,
                         encode_type=encode_type,)
        self.Conv = SepConv(step=step, dim=dim)
        # self.Conv = MHMC(dim=dim)
        # BUGFIX: cast the expanded width to int — mlp_ratio defaults to the
        # float 4.0 and nn.Conv2d/BatchNorm2d require integer channel counts.
        hidden = int(dim * mlp_ratio)
        self.hidden = hidden
        self.lif1 = MyNode(step=step, tau=2.0)
        self.conv1 = nn.Conv2d(
            dim, hidden, kernel_size=3, padding=1, groups=1, bias=False
        )
        # self.conv1 = RepConv(dim, dim*mlp_ratio)
        self.bn1 = nn.BatchNorm2d(hidden)
        self.lif2 = MyNode(step=step, tau=2.0)
        self.conv2 = nn.Conv2d(
            hidden, dim, kernel_size=3, padding=1, groups=1, bias=False
        )
        # self.conv2 = RepConv(dim*mlp_ratio, dim)
        self.bn2 = nn.BatchNorm2d(dim)

    def forward(self, x):
        self.reset()
        T, B, C, H, W = x.shape
        x = self.Conv(x) + x  # separable-conv residual
        x_feat = x
        # use the stored hidden width instead of hard-coding 4 * C so the
        # block stays correct for any mlp_ratio
        x = self.bn1(self.conv1(self.lif1(x.flatten(0, 1)))).reshape(T, B, self.hidden, H, W)
        x = self.bn2(self.conv2(self.lif2(x.flatten(0, 1)))).reshape(T, B, C, H, W)
        x = x_feat + x
        return x
class Spikformer(BaseModule):
    """Spike-driven Transformer v2 backbone for DVS (event-camera) input.

    Pipeline: four conv stages (DownSampling + ConvBlock) growing channels
    as embed_dims//16 -> //8 -> //4 -> //2 -> embed_dims, then two
    transformer stages (``block3`` with ``block3_depths`` blocks, ``block4``
    with the remaining ``depths - block3_depths``), an output LIF and a
    linear head.  ``kd=True`` adds a second head for knowledge distillation.
    """

    def __init__(self, step=10, encode_type='direct',
                 img_size_h=64, img_size_w=64, patch_size=4, in_channels=2, num_classes=10,
                 embed_dims=256, num_heads=16, mlp_ratios=4, qkv_bias=False, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
                 depths=2, sr_ratios=4, kd=False,
                 ):
        # BUGFIX: the base class was previously initialized with hard-coded
        # step=10/encode_type='direct', silently ignoring the constructor args.
        super().__init__(step=step, encode_type=encode_type)
        self.step = step  # number of simulation time steps
        self.num_classes = num_classes
        self.depths = depths
        self.block3_depths = 1  # blocks assigned to stage 3; the rest go to stage 4
        # for membrane shortcut
        self.final_lif = MyNode(step=step, tau=2.0)
        # channel plan for dvs: 16 32 64 128 256
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depths)]  # stochastic depth decay rule
        self.downsample1_1 = DownSampling(
            step=step,
            in_channels=in_channels,
            embed_dims=embed_dims // 16,
            kernel_size=7,
            stride=2,
            padding=3,
            first_layer=True,
        )
        self.ConvBlock1_1 = nn.ModuleList(
            [ConvBlock(step=step, dim=embed_dims // 16, mlp_ratio=mlp_ratios)]
        )
        self.downsample1_2 = DownSampling(
            step=step,
            in_channels=embed_dims // 16,
            embed_dims=embed_dims // 8,
            kernel_size=3,
            stride=2,
            padding=1,
            first_layer=False,
        )
        self.ConvBlock1_2 = nn.ModuleList(
            [ConvBlock(step=step, dim=embed_dims // 8, mlp_ratio=mlp_ratios)]
        )
        self.downsample2 = DownSampling(
            step=step,
            in_channels=embed_dims // 8,
            embed_dims=embed_dims // 4,
            kernel_size=3,
            stride=2,
            padding=1,
            first_layer=False,
        )
        self.ConvBlock2_1 = nn.ModuleList(
            [ConvBlock(step=step, dim=embed_dims // 4, mlp_ratio=mlp_ratios)]
        )
        self.ConvBlock2_2 = nn.ModuleList(
            [ConvBlock(step=step, dim=embed_dims // 4, mlp_ratio=mlp_ratios)]
        )
        self.downsample3 = DownSampling(
            step=step,
            in_channels=embed_dims // 4,
            embed_dims=embed_dims // 2,
            kernel_size=3,
            stride=2,
            padding=1,
            first_layer=False,
        )
        self.block3 = nn.ModuleList(
            [
                Block(
                    step=step,
                    dim=embed_dims // 2,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratios,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    # drop_path=dpr[j],
                    norm_layer=norm_layer,
                    sr_ratio=sr_ratios,
                )
                for j in range(self.block3_depths)
            ]
        )
        self.downsample4 = DownSampling(
            step=step,
            in_channels=embed_dims // 2,
            embed_dims=embed_dims,
            kernel_size=3,
            stride=1,
            padding=1,
            first_layer=False,
        )
        self.block4 = nn.ModuleList(
            [
                Block(
                    step=step,
                    dim=embed_dims,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratios,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[j],
                    norm_layer=norm_layer,
                    sr_ratio=sr_ratios,
                )
                for j in range(self.depths - self.block3_depths)
            ]
        )
        # classification head
        self.lif = MyNode(step=step, tau=2.0,)
        self.head = (
            nn.Linear(embed_dims, num_classes) if num_classes > 0 else nn.Identity()
        )
        self.kd = kd
        if self.kd:
            self.head_kd = (
                nn.Linear(embed_dims, num_classes)
                if num_classes > 0
                else nn.Identity()
            )
        self.apply(self._init_weights)

    @torch.jit.ignore
    def _get_pos_embed(self, pos_embed, patch_embed, H, W):
        """Interpolate a positional embedding to an (H, W) grid.

        NOTE(review): ``self.patch_embed1`` is never assigned in this class,
        so this method would raise AttributeError if called — it looks like
        dead code from a PVT-style implementation.
        """
        if H * W == self.patch_embed1.num_patches:
            return pos_embed
        else:
            return F.interpolate(
                pos_embed.reshape(1, patch_embed.H, patch_embed.W, -1).permute(0, 3, 1, 2),
                size=(H, W), mode="bilinear").reshape(1, -1, H * W).permute(0, 2, 1)

    def _init_weights(self, m):
        """Trunc-normal init for Linear layers, unit-gain init for LayerNorm."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_features(self, x):
        """Run the conv stem and transformer stages; returns [T, B, C, H, W]."""
        x = self.downsample1_1(x)
        for blk in self.ConvBlock1_1:
            x = blk(x)
        x = self.downsample1_2(x)
        for blk in self.ConvBlock1_2:
            x = blk(x)
        x = self.downsample2(x)
        for blk in self.ConvBlock2_1:
            x = blk(x)
        for blk in self.ConvBlock2_2:
            x = blk(x)
        x = self.downsample3(x)
        for blk in self.block3:  # attention here
            x = blk(x)
        x = self.downsample4(x)  # attention here
        for blk in self.block4:
            x = blk(x)
        return x  # T, B, C, H, W

    def forward(self, x):
        """Classify a DVS batch shaped [B, T, C, H, W]; logits averaged over T."""
        self.reset()
        x = x.permute(1, 0, 2, 3, 4)  # -> [T, B, C, H, W]
        x = self.forward_features(x)
        x = x.flatten(3).mean(3)  # global average pool -> [T, B, C]
        T, B, _ = x.shape
        x_lif = self.lif(x.flatten(0, 1)).reshape(T, B, -1)
        x = self.head(x_lif).mean(0)
        if self.kd:
            x_kd = self.head_kd(x_lif).mean(0)
            if self.training:
                return x, x_kd
            return (x + x_kd) / 2
        return x
# Adjust your hyperparameters here
@register_model
def sd_transformer_v2_dvs(pretrained=False, **kwargs):
    """Build a Spike-driven Transformer v2 for 64x64 2-channel DVS input.

    ``pretrained`` is accepted for timm-registry compatibility but unused.
    """
    model = Spikformer(step = 8,
                       img_size_h=64, img_size_w=64,
                       patch_size=4, embed_dims=256, num_heads=16, mlp_ratios=4,
                       in_channels=2, num_classes=10, qkv_bias=False,
                       depths=2, sr_ratios=1,
                       **kwargs
                       )
    model.default_cfg = _cfg()
    return model
================================================
FILE: examples/Spiking-Transformers/models/spikformer.py
================================================
import torch
import torch.nn as nn
from timm.models.layers import to_2tuple, trunc_normal_, DropPath
from timm.models.registry import register_model
from timm.models.vision_transformer import _cfg
import torch.nn.functional as F
from braincog.model_zoo.base_module import BaseModule
from braincog.base.node.node import *
from braincog.base.connection.layer import *
from braincog.base.strategy.surrogate import *
from LIFNode import MyNode # LIFNode setting for Spiking Tranformers
from functools import partial
__all__ = ['spikformer']
class MLP(BaseModule):
    """Token-wise MLP on [T, B, N, C]: Linear -> BN -> LIF, applied twice."""

    def __init__(self, in_features, step=4, encode_type='direct', hidden_features=None, out_features=None, drop=0.):
        super().__init__(step=step, encode_type=encode_type)
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1_linear = nn.Linear(in_features, hidden_features)
        self.fc1_bn = nn.BatchNorm1d(hidden_features)
        self.fc1_lif = MyNode(step=step, tau=2.0)
        self.fc2_linear = nn.Linear(hidden_features, out_features)
        self.fc2_bn = nn.BatchNorm1d(out_features)
        self.fc2_lif = MyNode(step=step, tau=2.0)
        self.c_hidden = hidden_features
        self.c_output = out_features

    def forward(self, x):
        self.reset()
        T, B, N, C = x.shape
        flat = x.flatten(0, 1)  # TB N C
        hidden = self.fc1_linear(flat)
        # BatchNorm1d normalizes the channel dim, hence the transposes
        hidden = self.fc1_bn(hidden.transpose(-1, -2)).transpose(-1, -2)
        hidden = hidden.reshape(T, B, N, self.c_hidden).contiguous()
        hidden = self.fc1_lif(hidden.flatten(0, 1)).reshape(T, B, N, self.c_hidden)
        out = self.fc2_linear(hidden.flatten(0, 1))
        out = self.fc2_bn(out.transpose(-1, -2)).transpose(-1, -2)
        out = out.reshape(T, B, N, C).contiguous()
        out = self.fc2_lif(out.flatten(0, 1)).reshape(T, B, N, self.c_output)
        return out
class SSA(BaseModule):
    """Spiking Self-Attention on [T, B, N, C] token sequences.

    q/k/v come from Linear -> BatchNorm1d -> LIF branches (producing spike
    tensors); attention is (q @ k^T) * scale followed by a LIF, a linear
    projection, BN, and a final LIF.
    """
    def __init__(self, dim, step=4, encode_type='rate', num_heads=12, qkv_bias=False, qk_scale=None, attn_drop=0.,
                 proj_drop=0., sr_ratio=1):
        super().__init__(step=step, encode_type=encode_type)
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
        self.dim = dim
        # number of attention heads
        self.num_heads = num_heads
        # fixed scale factor keeping the Q.K products from growing too large
        self.scale = 0.125
        self.q_linear = nn.Linear(dim, dim)
        self.q_bn = nn.BatchNorm1d(dim)
        self.q_lif = MyNode(step=step,tau=2.0)
        self.k_linear = nn.Linear(dim, dim)
        self.k_bn = nn.BatchNorm1d(dim)
        self.k_lif = MyNode(step=step,tau=2.0)
        self.v_linear = nn.Linear(dim, dim)
        self.v_bn = nn.BatchNorm1d(dim)
        self.v_lif = MyNode(step=step,tau=2.0)
        self.attn_lif = MyNode(step=step, tau=2.0, v_threshold=0.5, )
        self.proj_linear = nn.Linear(dim, dim)
        self.proj_bn = nn.BatchNorm1d(dim)
        self.proj_lif = MyNode(step=step, tau=2.0, )
    def forward(self, x):
        self.reset()
        T, B, N, C = x.shape
        x_for_qkv = x.flatten(0, 1) # TB, N, C
        q_linear_out = self.q_linear(x_for_qkv) # [TB, N, C]
        q_linear_out = self.q_bn(q_linear_out.transpose(-1, -2)).transpose(-1, -2).reshape(T, B, N,
                                                                                           C).contiguous() # T B N C
        q_linear_out = self.q_lif(q_linear_out.flatten(0, 1)).reshape(T, B, N, C) # TB N C
        q = q_linear_out.reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4).contiguous()
        k_linear_out = self.k_linear(x_for_qkv)
        k_linear_out = self.k_bn(k_linear_out.transpose(-1, -2)).transpose(-1, -2).reshape(T, B, N, C).contiguous()
        k_linear_out = self.k_lif(k_linear_out.flatten(0, 1)).reshape(T, B, N, C) # TB N C
        k = k_linear_out.reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4).contiguous()
        v_linear_out = self.v_linear(x_for_qkv)
        v_linear_out = self.v_bn(v_linear_out.transpose(-1, -2)).transpose(-1, -2).reshape(T, B, N, C).contiguous()
        v_linear_out = self.v_lif(v_linear_out.flatten(0, 1)).reshape(T, B, N, C) # TB N C
        v = v_linear_out.reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4).contiguous()
        # "@" is matrix multiplication, equivalent to torch.matmul
        # K,QV -> attention -> scale -> LIF -> Linear -> BN
        attn = (q @ k.transpose(-2, -1)) * self.scale
        x = attn @ v
        x = x.transpose(2, 3).reshape(T, B, N, C).contiguous()
        x = self.attn_lif(x.flatten(0, 1)).reshape(T, B, N, C) # T B N C
        x = x.flatten(0, 1) # TB N C
        x = self.proj_lif(self.proj_bn(self.proj_linear(x).transpose(-1, -2)).transpose(-1, -2)).reshape(T, B, N, C)
        return x
class Block(nn.Module):
    """Spikformer encoder block: SSA attention + MLP, each with a residual add."""

    def __init__(self, dim, num_heads, step=4, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, sr_ratio=1):
        super().__init__()
        # BUGFIX: previously hard-coded ``self.step = 4``, silently ignoring
        # the ``step`` argument supplied by the caller.
        self.step = step
        self.norm1 = norm_layer(dim)
        self.attn = SSA(dim, step=self.step, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                        attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MLP(step=self.step, in_features=dim, hidden_features=mlp_hidden_dim, drop=drop)

    def forward(self, x):
        # residual connections
        x = x + self.attn(x)
        x = x + self.mlp(x)
        return x
# SPS for dimension adjustment
# embed_dims = 256
class SPS(BaseModule):
    """Spiking Patch Splitting stem for static images.

    Four conv/BN/LIF stages (the last two followed by 2x max-pooling, 4x
    total spatial reduction) plus a relative-position-encoding conv with a
    residual add.  Output layout is [T, B, N, C].
    """
    def __init__(self, step=4, encode_type='direct', img_size_h=32, img_size_w=32, patch_size=4, in_channels=3,
                 embed_dims=384):
        super().__init__(step=step, encode_type=encode_type)
        self.image_size = [img_size_h, img_size_w]
        # timm's to_2tuple converts an int into a 2-tuple
        patch_size = to_2tuple(patch_size) # 4->(4,4)
        self.patch_size = patch_size # patch_size
        self.C = in_channels # image_channel
        self.H, self.W = self.image_size[0] // patch_size[0], self.image_size[1] // patch_size[1] # grid size after patching
        self.num_patches = self.H * self.W # number of patches
        self.proj_conv = nn.Conv2d(in_channels, embed_dims // 8, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj_bn = nn.BatchNorm2d(embed_dims // 8)
        self.proj_lif = MyNode(step=step,tau=2.0)
        self.proj_conv1 = nn.Conv2d(embed_dims // 8, embed_dims // 4, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj_bn1 = nn.BatchNorm2d(embed_dims // 4)
        self.proj_lif1 = MyNode(step=step,tau=2.0)
        self.proj_conv2 = nn.Conv2d(embed_dims // 4, embed_dims // 2, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj_bn2 = nn.BatchNorm2d(embed_dims // 2)
        self.proj_lif2 = MyNode(step=step,tau=2.0)
        self.maxpool2 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        self.proj_conv3 = nn.Conv2d(embed_dims // 2, embed_dims, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj_bn3 = nn.BatchNorm2d(embed_dims)
        self.proj_lif3 = MyNode(step=step,tau=2.0)
        self.maxpool3 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        self.rpe_conv = nn.Conv2d(embed_dims, embed_dims, kernel_size=3, stride=1, padding=1, bias=False)
        self.rpe_bn = nn.BatchNorm2d(embed_dims)
        self.rpe_lif = MyNode(step=step,tau=2.0)
    def forward(self, x):
        self.reset()
        T, B, C, H, W = x.shape
        x = self.proj_conv(x.flatten(0, 1)) # have some fire value
        x = self.proj_bn(x).reshape(T, B, -1, H, W).contiguous()
        x = self.proj_lif(x.flatten(0, 1)).contiguous()
        x = self.proj_conv1(x)
        x = self.proj_bn1(x).reshape(T, B, -1, H, W).contiguous()
        x = self.proj_lif1(x.flatten(0, 1)).contiguous()
        x = self.proj_conv2(x)
        x = self.proj_bn2(x).reshape(T, B, -1, H, W).contiguous()
        x = self.proj_lif2(x.flatten(0, 1)).contiguous()
        x = self.maxpool2(x)
        x = self.proj_conv3(x)
        x = self.proj_bn3(x).reshape(T, B, -1, H // 2, W // 2).contiguous()
        x = self.proj_lif3(x.flatten(0, 1)).contiguous()
        x = self.maxpool3(x)
        # relative position encoding branch added as a residual
        x_feat = x.reshape(T, B, -1, H // 4, W // 4).contiguous()
        x = self.rpe_conv(x)
        x = self.rpe_bn(x).reshape(T, B, -1, H // 4, W // 4).contiguous()
        x = self.rpe_lif(x.flatten(0, 1)).reshape(T, B, -1, H // 4, W // 4)
        x = x + x_feat
        x = x.flatten(-2).transpose(-1, -2) # T,B,N,C
        return x
class Spikformer(BaseModule):
    """Spikformer backbone for static images: SPS patch embedding followed by
    ``depths`` SSA encoder blocks, token averaging, and a linear head.
    """

    def __init__(self, step=4, encode_type='direct',
                 img_size_h=224, img_size_w=224, patch_size=16, in_channels=3, num_classes=1000,
                 embed_dims=384, num_heads=12, mlp_ratios=4, qkv_bias=False, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
                 depths=4, sr_ratios=4,
                 ):
        super().__init__(step=step, encode_type=encode_type)
        self.step = step  # time step
        self.num_classes = num_classes
        self.depths = depths
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depths)]  # stochastic depth decay rule
        # direct attribute assignment instead of the needless setattr/f-string
        self.patch_embed = SPS(step=self.step,
                               img_size_h=img_size_h,
                               img_size_w=img_size_w,
                               patch_size=patch_size,
                               in_channels=in_channels,
                               embed_dims=embed_dims)
        self.block = nn.ModuleList([Block(step=self.step,
                                          dim=embed_dims, num_heads=num_heads, mlp_ratio=mlp_ratios,
                                          qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate,
                                          attn_drop=attn_drop_rate, drop_path=dpr[j],
                                          norm_layer=norm_layer, sr_ratio=sr_ratios)
                                    for j in range(depths)])
        # classification head
        self.head = nn.Linear(embed_dims, num_classes) if num_classes > 0 else nn.Identity()
        self.apply(self._init_weights)

    @torch.jit.ignore
    def _get_pos_embed(self, pos_embed, patch_embed, H, W):
        """Interpolate a positional embedding to an (H, W) grid.

        NOTE(review): ``self.patch_embed1`` is never assigned in this class;
        this method looks like dead code from a PVT-style implementation.
        """
        if H * W == self.patch_embed1.num_patches:
            return pos_embed
        else:
            return F.interpolate(
                pos_embed.reshape(1, patch_embed.H, patch_embed.W, -1).permute(0, 3, 1, 2),
                size=(H, W), mode="bilinear").reshape(1, -1, H * W).permute(0, 2, 1)

    def _init_weights(self, m):
        """Trunc-normal init for Linear layers, unit-gain init for LayerNorm."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_features(self, x):
        """Embed patches, run encoder blocks, average over tokens -> [T, B, C]."""
        x = self.patch_embed(x)
        for blk in self.block:
            x = blk(x)
        return x.mean(2)

    def forward(self, x):
        """Encode input over time steps and return time-averaged logits."""
        x = self.encoder(x)
        x = self.forward_features(x)
        x = self.head(x.mean(0))
        return x
@register_model
def spikformer(pretrained=False, **kwargs):
    """Build a Spikformer-8-512 for ImageNet-style input (224x224 RGB).

    ``pretrained`` is accepted for timm-registry compatibility but unused.
    """
    model = Spikformer(
        step=4,
        img_size_h=224, img_size_w=224,
        patch_size=16, embed_dims=512, num_heads=8, mlp_ratios=4,
        in_channels=3, num_classes=1000, qkv_bias=False,
        depths=8, sr_ratios=1,
        **kwargs
    )
    model.default_cfg = _cfg()
    return model
================================================
FILE: examples/Spiking-Transformers/models/spikformer_dvs.py
================================================
import torch
import torch.nn as nn
from timm.models.layers import to_2tuple, trunc_normal_, DropPath
from timm.models.registry import register_model
from timm.models.vision_transformer import _cfg
import torch.nn.functional as F
from braincog.model_zoo.base_module import BaseModule
from braincog.base.node.node import *
from braincog.base.connection.layer import *
from braincog.base.strategy.surrogate import *
from LIFNode import MyNode # LIFNode setting for Spiking Tranformers
from functools import partial
__all__ = ['spikformer']
'''The input shape of neuromorphic datasets in Spiking Transformers, when using
BrainCog, is usually set to 64*64.'''
class MLP(BaseModule):
    """Conv1d-based MLP over [T, B, C, N]: 1x1 conv -> BN -> LIF, twice."""

    def __init__(self, in_features, step=10, encode_type='direct', hidden_features=None, out_features=None, drop=0.):
        super().__init__(step=step, encode_type=encode_type)
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1_conv = nn.Conv1d(in_features, hidden_features, kernel_size=1, stride=1)
        self.fc1_bn = nn.BatchNorm1d(hidden_features)
        self.fc1_lif = MyNode(step=step, tau=2.0)
        self.fc2_conv = nn.Conv1d(hidden_features, out_features, kernel_size=1, stride=1)
        self.fc2_bn = nn.BatchNorm1d(out_features)
        self.fc2_lif = MyNode(step=step, tau=2.0)
        self.c_hidden = hidden_features
        self.c_output = out_features

    def forward(self, x):
        self.reset()
        T, B, C, N = x.shape
        hidden = self.fc1_conv(x.flatten(0, 1))
        hidden = self.fc1_bn(hidden).reshape(T, B, self.c_hidden, N).contiguous()  # T B C N
        hidden = self.fc1_lif(hidden.flatten(0, 1)).reshape(T, B, self.c_hidden, N).contiguous()
        out = self.fc2_conv(hidden.flatten(0, 1))
        out = self.fc2_bn(out).reshape(T, B, C, N).contiguous()
        out = self.fc2_lif(out.flatten(0, 1)).reshape(T, B, C, N).contiguous()
        return out
class SSA(BaseModule):
    """Spiking Self-Attention over [T, B, C, N] with Conv1d q/k/v branches.

    Unlike the static-image SSA, this DVS variant scales the attention
    *output* (``(attn @ v) * 0.25``) rather than ``q @ k^T``.
    ``attn_drop`` and ``res_lif`` are constructed but unused in ``forward``.
    """
    def __init__(self, dim, step=10, encode_type='direct', num_heads=16, qkv_bias=False, qk_scale=None, attn_drop=0.,
                 proj_drop=0., sr_ratio=1):
        super().__init__(step=step, encode_type=encode_type)
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
        self.dim = dim
        self.num_heads = num_heads
        # scale
        self.scale = 0.25
        self.q_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)
        self.q_bn = nn.BatchNorm1d(dim)
        self.q_lif = MyNode(step=step, tau=2.0)
        self.k_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)
        self.k_bn = nn.BatchNorm1d(dim)
        self.k_lif = MyNode(step=step, tau=2.0)
        self.v_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)
        self.v_bn = nn.BatchNorm1d(dim)
        self.v_lif = MyNode(step=step, tau=2.0)
        self.attn_drop = nn.Dropout(0.2)
        self.res_lif = MyNode(step=step, tau=2.0)
        self.attn_lif = MyNode(step=step, tau=2.0, v_threshold=0.5, )
        self.proj_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)
        self.proj_bn = nn.BatchNorm1d(dim)
        self.proj_lif = MyNode(step=step, tau=2.0, )
    def forward(self, x):
        self.reset()
        T, B, C, N = x.shape
        x_for_qkv = x.flatten(0, 1) # TB, C N
        q_conv_out = self.q_conv(x_for_qkv) # [TB] C N
        q_conv_out = self.q_bn(q_conv_out).reshape(T, B, C, N).contiguous() # T B C N
        q_conv_out = self.q_lif(q_conv_out.flatten(0, 1)).reshape(T, B, C, N) # TB C N
        q = q_conv_out.reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4).contiguous()
        k_conv_out = self.k_conv(x_for_qkv)
        k_conv_out = self.k_bn(k_conv_out).reshape(T, B, C, N).contiguous()
        k_conv_out = self.k_lif(k_conv_out.flatten(0, 1)).reshape(T, B, C, N) # TB C N
        k = k_conv_out.reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4).contiguous()
        v_conv_out = self.v_conv(x_for_qkv)
        v_conv_out = self.v_bn(v_conv_out).reshape(T, B, C, N).contiguous()
        v_conv_out = self.v_lif(v_conv_out.flatten(0, 1)).reshape(T, B, C, N) # TB C N
        v = v_conv_out.reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4).contiguous()
        # "@" is matrix multiplication, equivalent to torch.matmul
        # K,QV -> attention -> scale -> LIF -> Linear -> BN
        attn = (q @ k.transpose(-2, -1))
        x = (attn @ v) * self.scale
        x = x.transpose(3, 4).reshape(T, B, C, N).contiguous() # T B C N
        x = self.attn_lif(x.flatten(0, 1)) # [TB] C N
        x = self.proj_lif(self.proj_bn(self.proj_conv(x))).reshape(T, B, C, N) # T B C N
        return x
# Full encoder block: adds residual connections on top of SSA and MLP
class Block(nn.Module):
    """Encoder block: SSA attention and conv MLP, each with a residual add."""

    def __init__(self, dim, num_heads, step=10, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, sr_ratio=1):
        super().__init__()
        hidden_dim = int(dim * mlp_ratio)
        self.norm1 = norm_layer(dim)
        self.attn = SSA(dim, step=step, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                        attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
        self.norm2 = norm_layer(dim)
        self.mlp = MLP(step=step, in_features=dim, hidden_features=hidden_dim, drop=drop)

    def forward(self, x):
        # residual connections
        out = x + self.attn(x)
        out = out + self.mlp(out)
        return out
# embed_dims = 256
class SPS(BaseModule):
    """Spiking Patch Splitting stem for DVS input.

    Four conv/BN/LIF stages, each followed by 2x max-pooling (16x total
    spatial reduction), plus a relative-position-encoding conv added as a
    residual.  Output layout is [T, B, C, N] with N = (H//16)*(W//16).
    """

    def __init__(self, step=10, encode_type='direct', img_size_h=128, img_size_w=128, patch_size=4, in_channels=2,
                 embed_dims=256):
        super().__init__(step=step, encode_type=encode_type)
        self.image_size = [img_size_h, img_size_w]
        # timm's to_2tuple converts an int into a 2-tuple
        patch_size = to_2tuple(patch_size)  # 4->(4,4)
        self.patch_size = patch_size  # patch_size
        self.C = in_channels  # image_channel
        self.H, self.W = self.image_size[0] // patch_size[0], self.image_size[1] // patch_size[1]
        self.num_patches = self.H * self.W
        # DVS with 2 more Maxpooling
        self.proj_conv = nn.Conv2d(in_channels, embed_dims // 8, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj_bn = nn.BatchNorm2d(embed_dims // 8)
        self.proj_lif = MyNode(step=step, tau=2.0)
        self.maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        self.proj_conv1 = nn.Conv2d(embed_dims // 8, embed_dims // 4, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj_bn1 = nn.BatchNorm2d(embed_dims // 4)
        self.proj_lif1 = MyNode(step=step, tau=2.0)
        self.maxpool1 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        self.proj_conv2 = nn.Conv2d(embed_dims // 4, embed_dims // 2, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj_bn2 = nn.BatchNorm2d(embed_dims // 2)
        self.proj_lif2 = MyNode(step=step, tau=2.0)
        self.maxpool2 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        self.proj_conv3 = nn.Conv2d(embed_dims // 2, embed_dims, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj_bn3 = nn.BatchNorm2d(embed_dims)
        self.proj_lif3 = MyNode(step=step, tau=2.0)
        self.maxpool3 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        self.rpe_conv = nn.Conv2d(embed_dims, embed_dims, kernel_size=3, stride=1, padding=1, bias=False)
        self.rpe_bn = nn.BatchNorm2d(embed_dims)
        self.rpe_lif = MyNode(step=step, tau=2.0)

    def forward(self, x):
        self.reset()
        T, B, C, H, W = x.shape
        x = self.proj_conv(x.flatten(0, 1))  # have some fire value
        x = self.proj_bn(x).reshape(T, B, -1, H, W).contiguous()
        x = self.proj_lif(x.flatten(0, 1)).contiguous()
        x = self.maxpool(x)
        x = self.proj_conv1(x)
        x = self.proj_bn1(x).reshape(T, B, -1, H // 2, W // 2).contiguous()
        x = self.proj_lif1(x.flatten(0, 1)).contiguous()
        x = self.maxpool1(x)
        x = self.proj_conv2(x)
        x = self.proj_bn2(x).reshape(T, B, -1, H // 4, W // 4).contiguous()
        x = self.proj_lif2(x.flatten(0, 1)).contiguous()
        x = self.maxpool2(x)
        x = self.proj_conv3(x)
        x = self.proj_bn3(x).reshape(T, B, -1, H // 8, W // 8).contiguous()
        x = self.proj_lif3(x.flatten(0, 1)).contiguous()
        x = self.maxpool3(x)
        # relative position encoding branch added as a residual
        x_rpe = self.rpe_bn(self.rpe_conv(x)).reshape(T, B, -1, H // 16, W // 16).contiguous()
        x_rpe = self.rpe_lif(x_rpe.flatten(0, 1)).contiguous()
        x = x + x_rpe
        # BUGFIX: the flattened token count previously used (H // 16) twice;
        # the second factor must use W so non-square inputs reshape correctly
        # (identical for the default square 64x64 input).
        x = x.reshape(T, B, -1, (H // 16) * (W // 16)).contiguous()
        return x  # T B C N
class Spikformer(nn.Module):
    """Spiking transformer: SPS patch embedding followed by `depths` Blocks,
    a token/time average, and a linear classification head."""

    def __init__(self, step=10,
                 img_size_h=64, img_size_w=64, patch_size=4, in_channels=2, num_classes=10,
                 embed_dims=256, num_heads=16, mlp_ratios=4, qkv_bias=False, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
                 depths=2, sr_ratios=4,
                 ):
        super().__init__()
        self.step = step  # number of simulation time steps
        self.num_classes = num_classes
        self.depths = depths
        # stochastic depth decay rule: drop-path rate grows linearly with depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depths)]
        # direct attribute assignment; the original used setattr/getattr with
        # placeholder-free f-strings, which is equivalent but obfuscated
        self.patch_embed = SPS(step=step,
                               img_size_h=img_size_h,
                               img_size_w=img_size_w,
                               patch_size=patch_size,
                               in_channels=in_channels,
                               embed_dims=embed_dims)
        self.block = nn.ModuleList([Block(step=step,
                                          dim=embed_dims, num_heads=num_heads, mlp_ratio=mlp_ratios, qkv_bias=qkv_bias,
                                          qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[j],
                                          norm_layer=norm_layer, sr_ratio=sr_ratios)
                                    for j in range(depths)])
        # classification head
        self.head = nn.Linear(embed_dims, num_classes) if num_classes > 0 else nn.Identity()
        self.apply(self._init_weights)

    @torch.jit.ignore
    def _get_pos_embed(self, pos_embed, patch_embed, H, W):
        """Interpolate a positional embedding to a new (H, W) token grid.

        Bug fix: the original read `self.patch_embed1.num_patches`, an
        attribute that is never created; use the `patch_embed` argument.
        """
        if H * W == patch_embed.num_patches:
            return pos_embed
        else:
            return F.interpolate(
                pos_embed.reshape(1, patch_embed.H, patch_embed.W, -1).permute(0, 3, 1, 2),
                size=(H, W), mode="bilinear").reshape(1, -1, H * W).permute(0, 2, 1)

    def _init_weights(self, m):
        # truncated-normal linear weights, unit-gamma/zero-beta LayerNorm
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_features(self, x):
        # x: [T, B, C, H, W] -> tokens [T, B, embed_dims, N] -> mean over tokens
        x = self.patch_embed(x)
        for blk in self.block:
            x = blk(x)
        return x.mean(3)

    def forward(self, x):
        # incoming batch is [B, T, 2, H, W]; move time to the front
        x = x.permute(1, 0, 2, 3, 4)  # [T, N, 2, *, *]
        x = self.forward_features(x)
        # average over time steps, then classify
        x = self.head(x.mean(0))
        return x
# Adjust your hyperparameters here
@register_model
def spikformer_dvs(pretrained=False, **kwargs):
    """Build the DVS Spikformer configuration and register it with timm.

    `pretrained` is accepted for timm factory compatibility but unused.
    """
    net = Spikformer(
        step=8,
        img_size_h=64, img_size_w=64,
        patch_size=4, embed_dims=256, num_heads=16, mlp_ratios=4,
        in_channels=2, num_classes=10, qkv_bias=False,
        depths=2, sr_ratios=1,
        **kwargs
    )
    net.default_cfg = _cfg()
    return net
================================================
FILE: examples/Structural_Development/DPAP/README.md
================================================
# Developmental Plasticity-inspired Adaptive Pruning for Deep Spiking and Artificial Neural Networks #
## Requirements ##
* matplotlib==3.5.1
* numpy==1.22.4
* Pillow==9.3.0
* scipy==1.9.3
* tensorboardX==2.5.1
* torch==1.8.1+cu111
* torchvision==0.9.1+cu111
## Run ##
```CUDA_VISIBLE_DEVICES=0 python prun_main.py```
## Citation ##
If you find the code and dataset useful in your research, please consider citing:
```
@article{han2024similarity,
title={Developmental Plasticity-inspired Adaptive Pruning for Deep Spiking and Artificial Neural Networks},
author={Han, Bing and Zhao, Feifei and Zeng, Yi and Shen, Guobin},
journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
year={2024}
}
@article{zeng2023braincog,
title={Braincog: A spiking neural network based, brain-inspired cognitive intelligence engine for brain-inspired ai and brain simulation},
author={Zeng, Yi and Zhao, Dongcheng and Zhao, Feifei and Shen, Guobin and Dong, Yiting and Lu, Enmeng and Zhang, Qian and Sun, Yinqian and Liang, Qian and Zhao, Yuxuan and others},
journal={Patterns},
volume={4},
number={8},
year={2023},
publisher={Elsevier},
}
```
Enjoy!
================================================
FILE: examples/Structural_Development/DPAP/mask_model.py
================================================
import abc
from functools import partial
from torch.nn import functional as F
import torchvision
from timm.models import register_model
from braincog.base.node.node import *
from braincog.base.encoder.encoder import *
from braincog.model_zoo.base_module import BaseModule, BaseConvModule, BaseLinearModule
from utils import *
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Indices of the prunable conv sub-modules inside model.feature; the gaps skip
# the nn.MaxPool2d entries, and -1 denotes the (recorded) network input.
convlayer = [-1,0, 1, 3, 4, 6, 7]
# Dict keys used for the two fully-connected layers (flatten output / fc_prun).
fclayer=[8,9]
# Spatial size of the feature map recorded at each stage (32x32 CIFAR input).
imgsize = [32,32, 32, 16,16, 16, 8,8, 8]
# Channel counts of the conv layers (3-channel input).
size = [3,128, 128, 256, 256, 512,512]
# Channel counts aligned with the recorded stages, including pooling outputs.
size_pool = [3,128, 128,128, 256, 256,256, 512,512]
# FC widths: flattened 512x8x8 features -> 512 hidden -> num_classes head.
fcsize=[512*8*8,512]
class my_cifar_model(BaseModule):
    """VGG-style spiking CNN for CIFAR used by the DPAP pruning example.

    forward() returns both the time-averaged class logits and, for every time
    step, the detached spike output of each recorded stage (input, every
    feature sub-module, flatten, fc_prun) — these recordings feed the
    Trace/Mask pruning machinery.
    """

    def __init__(self,
                 num_classes=10,
                 step=8,
                 node_type=LIFNode,
                 encode_type='direct',
                 *args,
                 **kwargs):
        super().__init__(step, encode_type, *args, **kwargs)
        self.num_classes = num_classes
        # NOTE: `node_type` is accepted for API compatibility but not used;
        # BaseConvModule/BaseLinearModule pick their own node type.
        self.feature = nn.Sequential(
            BaseConvModule(size[0], size[1], kernel_size=(3, 3), padding=(1, 1)),
            BaseConvModule(size[1], size[2], kernel_size=(3, 3), padding=(1, 1)),
            nn.MaxPool2d(2),
            BaseConvModule(size[2], size[3], kernel_size=(3, 3), padding=(1, 1)),
            BaseConvModule(size[3], size[4], kernel_size=(3, 3), padding=(1, 1)),
            nn.MaxPool2d(2),
            BaseConvModule(size[4], size[5], kernel_size=(3, 3), padding=(1, 1)),
            BaseConvModule(size[5], size[6], kernel_size=(3, 3), padding=(1, 1)),
        )
        self.cfla = self._cflatten()
        self.fc_prun = self._create_fc_prun()
        self.fc = self._create_fc()

    def _cflatten(self):
        # Flatten conv features to a vector before the FC layers.
        return nn.Sequential(
            nn.Flatten(),
        )

    def _create_fc_prun(self):
        # Hidden FC layer; participates in pruning (masked by Mask.do_mask).
        return nn.Sequential(
            BaseLinearModule(fcsize[0], fcsize[1])
        )

    def _create_fc(self):
        # Final classification layer.
        return nn.Sequential(
            BaseLinearModule(fcsize[1], self.num_classes)
        )

    def forward(self, inputs):
        inputs = self.encoder(inputs)
        self.reset()
        if not self.training:
            self.fire_rate.clear()
        outputs = []
        spikes = []
        for t in range(self.step):
            spikest = []
            x = inputs[t]
            # NOTE(review): this branch rescales large inputs to 64x64, but the
            # FC fan-in (512*8*8) only matches 32x32 inputs — confirm before
            # feeding anything larger than CIFAR.
            if x.shape[-1] > 32:
                x = F.interpolate(x, size=[64, 64])
            spikest.append(x.detach())  # record the encoded input as stage -1
            for i in range(len(self.feature)):
                x = self.feature[i](x)
                spikest.append(x.detach())
            x = self.cfla(x)
            spikest.append(x.detach())
            x = self.fc_prun(x)
            spikest.append(x.detach())
            x = self.fc(x)
            spikes.append(spikest)
            outputs.append(x)
        # average logits over time; spikes is [step][stage] of detached tensors
        return sum(outputs) / len(outputs), spikes
class Mask:
    """Maintains binary pruning masks ("codebooks") for the conv and FC layers.

    Per-neuron/per-weight importance deltas push a decaying `reduce` score up
    or down each epoch; entries whose score drops below zero are zeroed in
    `fullbook` and the corresponding weights are masked in-place by do_mask().
    Relies on the module-level `convlayer`/`fclayer`/`size`/`fcsize` tables.
    """

    def __init__(self, model, batch, step):
        self.model = model
        self.fullbook = {}   # per-layer binary mask, same shape as the weight
        self.mat = {}        # masks selected for the current epoch
        self.feature = model.feature
        self.fc = {}
        self.fc[1] = model.fc_prun[0]  # prunable hidden FC layer
        self.fc[2] = model.fc[0]       # classification head (not masked below)
        self.n_delta = {}    # per-neuron importance delta
        self.ww_delta = {}   # per-weight importance delta (FC only)
        self.reduce = {}     # decaying per-neuron survival score
        self.reduceww = {}   # decaying per-weight survival score (FC only)
        self.batch = batch
        self.step = step

    def init_length(self):
        """Allocate all-ones masks and initial survival scores (start at 10)."""
        for i in range(1, len(convlayer)):
            index = convlayer[i]
            # conv masks cover the full (out, in, 3, 3) kernel tensor
            self.fullbook[index] = torch.ones((size[i], size[i-1], 3, 3), device=device)
            self.n_delta[index] = torch.zeros(size[i], device=device)
            self.reduce[index] = 10*torch.ones(size[i], device=device)
        for i in range(1, len(fclayer)):
            index = fclayer[i]
            self.fullbook[index] = torch.ones((fcsize[i], fcsize[i-1]), device=device)
            self.n_delta[index] = torch.zeros(fcsize[i], device=device)
            self.ww_delta[index] = torch.zeros(fcsize[i]*fcsize[i-1], device=device)
            self.reduce[index] = 10*torch.ones(fcsize[i], device=device)
            self.reduceww[index] = 10*torch.ones(fcsize[i]*fcsize[i-1], device=device)

    def get_filter_codebook(self, ww, dendrite, ii, index, epoch):
        """Update and return the mask for layer `index`.

        `ii` selects the path: 4 = conv layer (neuron-level pruning only),
        2 = FC layer (weight-level then neuron-level pruning).  `ww` is the
        per-weight importance (BCM matrix), `dendrite` the per-neuron
        "spine" importance.  The magic offsets (0.65/1.5) and bonuses (+5/+2)
        bias healthy entries upward; exp(-(epoch-5)/k) anneals the update.
        """
        if ii == 4:
            wconv = dendrite#.cpu().numpy()
            # map importance to roughly [-0.65, 1.35]; positives get a +5 boost
            self.n_delta[index] = (unit(wconv)*2-0.65)
            pos = torch.nonzero(self.n_delta[index] > 0)
            self.n_delta[index][pos] = self.n_delta[index][pos]+5
            print(wconv.mean(), wconv.max(), wconv.min())
            # decay old score, add annealed delta
            self.reduce[index] = self.reduce[index]*0.999+self.n_delta[index]*math.exp(-int((epoch-5)/12))
            filter_ind = torch.nonzero(self.reduce[index] < 0)
            print(self.reduce[index].mean(), self.reduce[index].max(), self.reduce[index].min(), len(filter_ind))
            # zero the whole output filter of every pruned neuron
            for x in range(0, len(filter_ind)):
                self.fullbook[index][filter_ind[x]] = 0
        if ii == 2:
            # ---- weight-level pruning (FC only) ----
            length = ww.size()[0]*ww.size()[1]
            book = torch.ones(length, device=device)
            filter_ww = ww.view(-1)#.cpu().numpy()
            self.ww_delta[index] = (unit(filter_ww)*2-1.5)
            pos = torch.nonzero(self.ww_delta[index] > 0)
            self.ww_delta[index][pos] = self.ww_delta[index][pos]+2
            self.reduceww[index] = self.reduceww[index]*0.999+self.ww_delta[index]*math.exp(-int((epoch-5)/13))
            filter_indww = torch.nonzero(self.reduceww[index] < 0)
            book[filter_indww] = 0
            book = book.reshape((ww.size()[0], -1))
            # combine weight-level mask with the existing mask
            self.fullbook[index] = self.fullbook[index]*book
            print(self.reduceww[index].mean(), self.reduceww[index].max(), self.reduceww[index].min(), len(filter_indww))
            # ---- neuron-level pruning (same scheme as the conv path) ----
            wconv = dendrite#.cpu().numpy()
            self.n_delta[index] = (unit(wconv)*2-1.5)
            pos = torch.nonzero(self.n_delta[index] > 0)
            self.n_delta[index][pos] = self.n_delta[index][pos]+2
            self.reduce[index] = self.reduce[index]*0.999+self.n_delta[index]*math.exp(-int((epoch-5)/13))
            filter_ind = torch.nonzero(self.reduce[index] < 0)
            print(self.reduce[index].mean(), self.reduce[index].max(), self.reduce[index].min(), len(filter_ind))
            for x in range(0, len(filter_ind)):
                self.fullbook[index][filter_ind[x]] = 0
        return self.fullbook[index]

    def convert2tensor(self, x):
        # Helper kept for the commented-out .cuda() conversion path below.
        x = torch.FloatTensor(x)
        return x

    def init_mask(self, wwfc, convtra, epoch):
        """Recompute masks for all layers: conv via the ii==4 path, FC via ii==2."""
        for i in range(1, len(convlayer)):
            index = convlayer[i]
            ww = wwfc[index]
            dendrite = convtra[index]
            self.mat[index] = self.get_filter_codebook(ww, dendrite, 4, index, epoch)
            #self.mat[index] = self.convert2tensor(self.mat[index]).cuda()
        for i in range(1, len(fclayer)):
            index = fclayer[i]
            ww = wwfc[index]
            dendrite = convtra[index]
            self.mat[index] = self.get_filter_codebook(ww, dendrite, 2, index, epoch)
            #self.mat[index] = self.convert2tensor(self.mat[index]).cuda()

    def do_mask(self):
        """Apply the masks in-place to the model's conv and FC weights."""
        for i in range(1, len(convlayer)):
            index = convlayer[i]
            ww = self.feature[index].conv.weight
            maskww = ww*self.mat[index]
            self.feature[index].conv.weight.data = maskww
        for i in range(1, len(fclayer)):
            ind = fclayer[i]
            ww = self.fc[i].fc.weight
            maskww = ww*self.mat[ind]
            self.fc[i].fc.weight.data = maskww

    def if_zero(self):
        """Print and return the per-layer percentage of zeroed weights."""
        cc = []
        for i in range(1, len(convlayer)):
            ww = self.feature[convlayer[i]].conv.weight
            b = ww.data.view(-1).cpu().numpy()
            print("number of weight is %d, zero is %.3f" % (len(b), 100*(len(b) - np.count_nonzero(b))/len(b)))
            cc.append(100*(len(b) - np.count_nonzero(b))/len(b))
        for i in range(1, len(fcsize)):
            ww = self.fc[i].fc.weight
            b = ww.data.view(-1).cpu().numpy()
            print("number of weight is %d, zero is %.3f" % (len(b), 100*(len(b) - np.count_nonzero(b))/len(b)))
            cc.append(100*(len(b) - np.count_nonzero(b))/len(b))
        return cc
class Trace:
    """Computes exponentially-decaying spike traces for every recorded stage
    of `my_cifar_model`, used as activity statistics for pruning."""

    def __init__(self, model, batch, step):
        self.model = model
        self.feature = model.feature
        self.ctrace = {}    # conv-stage traces, (B, C, H, W), keyed by stage-1
        self.fctrace = {}   # FC-stage traces, (B, features)
        self.csum = {}      # spatially-reduced, step-averaged conv traces
        self.fcsum = {}     # step-averaged FC traces
        self.delta = 0.5    # trace decay factor applied per time step
        self.batch = batch
        self.step = step

    def computing_trace(self, spikes):
        """Fold `spikes` ([step][stage] tensors from the model's forward) into
        decayed traces; returns (csum, fcsum) averaged over time steps.

        Conv keys run -1..7 (key -1 is the recorded network input, matching
        `convlayer`); FC keys are the `fclayer` values, and `spikes[t][k+1]`
        maps each key to its slot in the recording list.
        """
        # (re)allocate zeroed trace buffers for this batch
        for i in range(len(imgsize)):
            index = i-1
            self.ctrace[index] = torch.zeros((self.batch, size_pool[i], imgsize[i], imgsize[i]), device=device)
        for i in range(len(fclayer)):
            index = fclayer[i]
            self.fctrace[index] = torch.zeros((self.batch, fcsize[i]), device=device)
        # accumulate: trace <- delta * trace + spike, over all time steps
        for t in range(self.step):
            for i in range(len(imgsize)):
                index = i-1
                sp = spikes[t][index+1].detach()
                #print(sp.size(),self.ctrace[index].size())
                self.ctrace[index] = self.delta*self.ctrace[index].cuda()+sp.cuda()
            for i in range(len(fclayer)):
                index = fclayer[i]
                sp = spikes[t][index+1].detach()
                self.fctrace[index] = self.delta*self.fctrace[index].cuda()+sp.cuda()
        # average over steps; reduce conv maps over both spatial dims -> (B, C)
        for i in range(len(imgsize)):
            index = i-1
            self.csum[index] = self.ctrace[index]/(self.step)
            self.csum[index] = torch.sum(torch.sum(self.csum[index], dim=2), dim=2)
        for i in range(len(fclayer)):
            index = fclayer[i]
            self.fcsum[index] = self.fctrace[index]/(self.step)
        return self.csum, self.fcsum
================================================
FILE: examples/Structural_Development/DPAP/prun_main.py
================================================
import argparse
import time
import os
from collections import OrderedDict
from contextlib import suppress
from datetime import datetime
import sys
sys.path.append('..')
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as NativeDDP
import logging
from timm.utils import *
from timm.optim import create_optimizer
from timm.scheduler import create_scheduler
from braincog.base.node.node import *
from braincog.base.encoder.encoder import *
from braincog.model_zoo.base_module import BaseModule, BaseConvModule, BaseLinearModule
from braincog.base.utils.criterions import *
from braincog.datasets.datasets import *
from mask_model import *
from utils import *
# ---- experiment setup: logging, output dir, hyperparameters ----
_logger = logging.getLogger('train')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
exp_name = '-'.join([datetime.now().strftime("%Y%m%d-%H%M%S"), 'c10'])
output_dir = get_outdir('./', 'train', exp_name)
setup_default_logging(log_path=os.path.join(output_dir, 'log.txt'))
_logger.info(exp_name)
# NOTE: `cfg` is an ArgumentParser used purely as an attribute namespace —
# timm's create_optimizer/create_scheduler only read attributes from it.
config_parser = cfg = argparse.ArgumentParser(description='Training Config', add_help=False)
dataset = 'cifar10'
num_classes = 10
step = 8            # SNN simulation time steps
encode = 'direct'
node_type = 'PLIFNode'
thresh = 0.5        # firing threshold
tau = 2.0           # membrane decay constant
torch.backends.cudnn.benchmark = True
devicee = 0         # CUDA device index
seed = 42
channels = 2
batch_size = 50
epochs = 300
lr = 5e-3
# linear LR scaling rule relative to a reference batch size of 1024
linear_scaled_lr = lr * batch_size / 1024.0
cfg.opt = 'adamw'
cfg.lr = linear_scaled_lr
cfg.weight_decay = 0.01
cfg.momentum = 0.9
cfg.epochs = epochs
cfg.sched = 'cosine'
cfg.min_lr = 1e-5
cfg.warmup_lr = 1e-6
cfg.warmup_epochs = 5
cfg.cooldown_epochs = 10
cfg.decay_rate = 0.1
eval_metric = 'top1'
# best-accuracy bookkeeping (the *prun variants are the ones actually updated)
best_test = 0
best_testepoch = 0
best_testprun = 0
best_testepochprun = 0
epoch_prune = 1
rate_decay_epoch = 30
NUM = 0             # global batch counter for the BCM running mean
torch.cuda.set_device('cuda:%d' % devicee)
torch.manual_seed(seed)
# ---- model, optimizer, data, losses, pruning helpers ----
model = my_cifar_model(step=step, encode_type=encode, node_type=node_type, num_classes=num_classes)
model = model.cuda()
print(model)
optimizer = create_optimizer(cfg, model)
lr_scheduler, num_epochs = create_scheduler(cfg, optimizer)
loader_train, loader_eval, mixup_active, mixup_fn = eval('get_%s_data' % dataset)(batch_size=batch_size, step=step)
train_loss_fn = UnilateralMse(1.)
validate_loss_fn = UnilateralMse(1.)
m = Mask(model, batch_size, step)
m.init_length()
trace = Trace(model, batch_size, step)
neuron_th, spines, bcm, epoch_trace = init(batch_size, convlayer, fclayer, size, fcsize)
def BCM_and_trace(NUM, trace, spikes, neuron_th, bcm, epoch_trace):
    """Fold one batch of spike recordings into the pruning statistics.

    NUM counts contributing batches (for the running-mean threshold);
    `neuron_th` is the BCM sliding threshold, `bcm` the accumulated
    Hebbian/BCM synaptic importance, `epoch_trace` the summed activity.
    All three dicts are updated and returned along with the new NUM.
    """
    NUM = NUM + 1
    csum, fcsum = trace.computing_trace(spikes)
    for i in range(1, len(convlayer)):
        index = convlayer[i]
        # BCM term: post * (post - theta), theta = running mean activity
        post1 = (csum[index] * (csum[index] - neuron_th[index]))
        # Hebbian outer-product with the previous stage's trace
        hebb1 = torch.mm(post1.T, csum[index-1])
        bcm[index] = bcm[index] + hebb1
        # incremental running mean over all batches seen so far
        neuron_th[index] = torch.div(neuron_th[index] * (NUM - 1) + csum[index], NUM)
        cs = torch.sum(csum[index], dim=0)
        epoch_trace[index] = epoch_trace[index] + cs
    for i in range(1, len(fclayer)):
        index = fclayer[i]
        post1 = (fcsum[index] * (fcsum[index] - neuron_th[index]))
        hebb1 = torch.mm(post1.T, fcsum[fclayer[i - 1]])
        bcm[index] = bcm[index] + hebb1
        neuron_th[index] = torch.div(neuron_th[index] * (NUM - 1) + fcsum[index], NUM)
        cs = torch.sum(fcsum[index], dim=0)
        epoch_trace[index] = epoch_trace[index] + cs
    return epoch_trace, bcm, NUM
def train_epoch(
        epoch, model, loader, optimizer, loss_fn, trace, NUM, bcm, neuron_th, epoch_trace,
        lr_scheduler=None, saver=None, output_dir='', amp_autocast=suppress,
        loss_scaler=None, model_ema=None, mixup_fn=None):
    """Train one epoch while accumulating BCM/trace pruning statistics.

    Returns (metrics dict with 'loss', epoch_trace, bcm, NUM).  The unused
    keyword arguments mirror the timm training-loop signature.
    """
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    losses_m = AverageMeter()
    top1_m = AverageMeter()
    model.train()
    end = time.time()
    last_idx = len(loader) - 1
    num_updates = epoch * len(loader)
    for batch_idx, (inputs, target) in enumerate(loader):
        last_batch = batch_idx == last_idx
        data_time_m.update(time.time() - end)
        inputs, target = inputs.type(torch.FloatTensor).cuda(), target.cuda()
        # model returns (logits, per-step per-stage spike recordings)
        output, spikes = model(inputs)
        # fold this batch's spikes into the pruning statistics
        epoch_trace, bcm, NUM = BCM_and_trace(NUM, trace, spikes, neuron_th, bcm, epoch_trace)
        loss = loss_fn(output, target)
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses_m.update(loss.item(), inputs.size(0))
        top1_m.update(acc1.item(), inputs.size(0))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_time_m.update(time.time() - end)
        if last_batch or batch_idx % 100 == 0:
            # report the mean LR across parameter groups
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)
            print("Train: epoch:", epoch, batch_idx, "/", len(loader), "loss:", losses_m.avg, "acc1:", top1_m.avg, "lr:", lr)
        if lr_scheduler is not None:
            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
        end = time.time()
        # end for
    if hasattr(optimizer, 'sync_lookahead'):
        optimizer.sync_lookahead()
    return OrderedDict([('loss', losses_m.avg)]), epoch_trace, bcm, NUM
def validate(model, loader, loss_fn, amp_autocast=suppress, log_suffix=''):
    """Evaluate `model` on `loader`; returns an OrderedDict with the average
    'loss', 'top1' and 'top5' metrics.

    Bug fix: the original created `top5_m` and reported it, but never updated
    it, so the returned 'top5' value was always 0.
    """
    batch_time_m = AverageMeter()
    losses_m = AverageMeter()
    top1_m = AverageMeter()
    top5_m = AverageMeter()
    model.eval()
    end = time.time()
    last_idx = len(loader) - 1
    with torch.no_grad():
        for batch_idx, (inputs, target) in enumerate(loader):
            last_batch = batch_idx == last_idx
            inputs = inputs.type(torch.FloatTensor).cuda()
            target = target.cuda()
            # model returns (logits, spikes); only the logits are needed here
            output, spikes = model(inputs)
            if isinstance(output, (tuple, list)):
                output = output[0]
            loss = loss_fn(output, target)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            reduced_loss = loss.data
            torch.cuda.synchronize()
            losses_m.update(reduced_loss.item(), inputs.size(0))
            top1_m.update(acc1.item(), output.size(0))
            top5_m.update(acc5.item(), output.size(0))  # was missing: top5 stayed 0
            if last_batch or batch_idx % 100 == 0:
                print("Test: loss:", losses_m.avg, "acc1:", top1_m.avg)
            batch_time_m.update(time.time() - end)
            end = time.time()
    metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])
    return metrics
# ---- main training / pruning / evaluation loop ----
for epoch in range(epochs):
    train_metrics, epoch_trace, bcm, N = train_epoch(
        epoch, model, loader_train, optimizer, train_loss_fn, trace, NUM, bcm, neuron_th, epoch_trace,
        lr_scheduler=lr_scheduler)
    NUM = N
    # Combine normalized BCM importance with the normalized activity trace
    # into a per-neuron "dendritic spine" score for each prunable layer.
    for i in range(1, len(convlayer)):
        index = convlayer[i]
        bcmconv = torch.sum(bcm[index], dim=1)
        bcmconv = unit_tensor(bcmconv)
        traconv = unit_tensor(epoch_trace[index])
        spines[index] = bcmconv*traconv
    for i in range(1, len(fclayer)):
        index = fclayer[i]
        bcmfc = torch.sum(bcm[index], dim=1)
        bcmfc = unit_tensor(bcmfc)
        trafc = unit_tensor(epoch_trace[index])
        spines[index] = bcmfc*trafc
    # Start pruning only after a 5-epoch warm-up.
    if epoch > 4:
        m.model = model
        m.init_mask(bcm, spines, epoch)
        m.do_mask()
        print("Done pruning")
        cc = m.if_zero()  # per-layer pruning rates, logged below
        model = m.model
    eval_metrics = validate(model, loader_eval, validate_loss_fn)
    top1 = eval_metrics['top1']
    if top1 > best_testprun:
        best_testprun = top1
        best_testepochprun = epoch
    if epoch % 40 == 0:
        print('best acc:', best_testprun, 'best epoch:', best_testepochprun)
    if epoch > 4:
        _logger.info('*** epoch: {0} (pruning rate {1},acc:{2})'.format(epoch, cc, top1))
    if lr_scheduler is not None:
        lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
FILE: examples/Structural_Development/DPAP/utils.py
================================================
import torch
import numpy as np
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def print_log(print_string, log):
    """Write `print_string` (plus a newline) to the file-like `log` and flush.

    The previous console echo was left as dead commented-out code and has
    been removed.
    """
    log.write('{}\n'.format(print_string))
    log.flush()
def unit(x):
    """Robustly rescale a 1-D tensor to [0, 1].

    Uses the 75th percentile as the upper reference (outlier-resistant) and
    the true minimum as the lower one, then clips to [0, 1].  A constant
    tensor maps to all 0.5; an empty tensor is returned unchanged.
    """
    if x.size()[0] <= 0:
        return x
    upper = np.percentile(x.cpu().numpy(), 75)
    lower = torch.min(x)
    span = upper - lower
    if span != 0:
        scaled = torch.clip((x - lower) / span, 0, 1)
    else:
        scaled = 0.5 * torch.ones_like(x)
    return scaled
def unit_tensor(x):
    """Min-max normalize a tensor to [0, 1].

    A constant tensor maps to all 0.5; an empty tensor is returned unchanged.
    """
    if x.size()[0] <= 0:
        return x
    lower, upper = torch.min(x), torch.max(x)
    span = upper - lower
    if span != 0:
        return (x - lower) / span
    return 0.5 * torch.ones_like(x)
def init(batch, convlayer, fclayer, size, fcsize):
    """Allocate zeroed per-layer state for BCM/trace-based pruning.

    Returns (neuron_th, convtra, bcm, epoch_trace): dicts keyed by the layer
    indices in `convlayer` / `fclayer` (position 0 of each list is skipped,
    matching the rest of the pruning code).  Tensors live on the module-level
    `device`.
    """
    neuron_th = {}
    convtra = {}
    bcm = {}
    epoch_trace = {}
    # conv layers use `size` widths, FC layers use `fcsize` widths
    for layers, widths in ((convlayer, size), (fclayer, fcsize)):
        for pos in range(1, len(layers)):
            key = layers[pos]
            neuron_th[key] = torch.zeros((batch, widths[pos]), device=device)
            convtra[key] = torch.zeros(widths[pos], device=device)
            bcm[key] = torch.zeros(widths[pos], widths[pos - 1], device=device)
            epoch_trace[key] = torch.zeros(widths[pos], device=device)
    return neuron_th, convtra, bcm, epoch_trace
================================================
FILE: examples/Structural_Development/DSD-SNN/README.md
================================================
# Enhancing Efficient Continual Learning with Dynamic Structure Development of Spiking Neural Networks #
## Requirements ##
* numpy
* timm
* pytorch >= 1.7.0
* collections
* argparse
## Introduction ##
Dynamic Structure Development of Spiking Neural Networks (DSD-SNN) for efficient and adaptive continual learning:
grow new neurons and prune redundant neurons, increasing memory capacity and reducing computational overhead.
Overlap shared structure to leverage acquired knowledge for new tasks, empowering a single network to support multiple incremental tasks.
We validate the effectiveness of the DSD-SNN on multiple TIL and CIL benchmarks.
## Run ##
```CUDA_VISIBLE_DEVICES=0 python main_simplified.py```
## Citation ##
If you find the code and dataset useful in your research, please consider citing:
```
@article{han2022developmental,
title={Enhancing Efficient Continual Learning with Dynamic Structure Development of Spiking Neural Networks},
author={Han, Bing and Zhao, Feifei and Zeng, Yi and Pan, Wenxuan and Shen, Guobin},
booktitle = {Proceedings of the Thirty-First International Joint Conference on
Artificial Intelligence, {IJCAI-23}},
publisher = {International Joint Conferences on Artificial Intelligence Organization},
year={2023}
}
@article{zeng2023braincog,
title={Braincog: A spiking neural network based, brain-inspired cognitive intelligence engine for brain-inspired ai and brain simulation},
author={Zeng, Yi and Zhao, Dongcheng and Zhao, Feifei and Shen, Guobin and Dong, Yiting and Lu, Enmeng and Zhang, Qian and Sun, Yinqian and Liang, Qian and Zhao, Yuxuan and others},
journal={Patterns},
volume={4},
number={8},
year={2023},
publisher={Elsevier},
}
```
Enjoy!
================================================
FILE: examples/Structural_Development/DSD-SNN/cifar100/available.py
================================================
from torchvision import datasets, transforms
from manipulate import UnNormalize
# specify available data-sets.
AVAILABLE_DATASETS = {
    'MNIST': datasets.MNIST,
    'CIFAR100': datasets.CIFAR100,
    'CIFAR10': datasets.CIFAR10,
}
# specify available transforms.
# NOTE: the *_norm entries operate on tensors, so they must be composed
# after a ToTensor() transform; the *_denorm entries invert them.
AVAILABLE_TRANSFORMS = {
    'MNIST': [
        transforms.ToTensor(),
    ],
    'MNIST32': [
        # pad 28x28 MNIST to 32x32
        transforms.Pad(2),
        transforms.ToTensor(),
    ],
    'CIFAR10': [
        transforms.ToTensor(),
    ],
    'CIFAR100': [
        transforms.ToTensor(),
    ],
    'CIFAR10_norm': [
        transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616])
    ],
    'CIFAR100_norm': [
        transforms.Normalize(mean=[0.5071, 0.4865, 0.4409], std=[0.2673, 0.2564, 0.2761])
    ],
    'CIFAR10_denorm': UnNormalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616]),
    'CIFAR100_denorm': UnNormalize(mean=[0.5071, 0.4865, 0.4409], std=[0.2673, 0.2564, 0.2761]),
    # augmentation for data already converted to tensors (round-trips via PIL)
    'augment_from_tensor': [
        transforms.ToPILImage(),
        transforms.RandomCrop(32, padding=4, padding_mode='symmetric'),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ],
    # augmentation applied before ToTensor()
    'augment': [
        transforms.RandomCrop(32, padding=4, padding_mode='symmetric'),
        transforms.RandomHorizontalFlip(),
    ],
}
# specify configurations of available data-sets.
DATASET_CONFIGS = {
    'MNIST': {'size': 28, 'channels': 1, 'classes': 10},
    'MNIST32': {'size': 32, 'channels': 1, 'classes': 10},
    'CIFAR10': {'size': 32, 'channels': 3, 'classes': 10},
    'CIFAR100': {'size': 32, 'channels': 3, 'classes': 100},
}
================================================
FILE: examples/Structural_Development/DSD-SNN/cifar100/main_simplified.py
================================================
import argparse
import time
import timm.models
import yaml
import os
import logging
from collections import OrderedDict
from contextlib import suppress
from datetime import datetime
from braincog.base.node.node import *
from braincog.utils import *
from braincog.base.utils.criterions import *
from braincog.datasets.datasets import *
from braincog.model_zoo.resnet import *
from braincog.model_zoo.convnet import *
from braincog.utils import save_feature_map
import torch
import torch.nn as nn
import torchvision.utils
from torchvision import transforms
from timm.data import ImageDataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
from timm.models import load_checkpoint, create_model, resume_checkpoint, convert_splitbn_model
from timm.utils import *
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
from timm.optim import create_optimizer
from timm.scheduler import create_scheduler
from timm.utils import ApexScaler, NativeScaler
from maskcl2 import *
# from ptflops import get_model_complexity_info
from thop import profile, clever_format
from manipulate import SubDataset
torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('train')
from available import AVAILABLE_DATASETS, AVAILABLE_TRANSFORMS, DATASET_CONFIGS
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import ConcatDataset
import copy
from vgg_snn import SNN
# torch.cuda.set_device(9)
# The first arg parser parses out only the --config argument, this argument is used to
# load a yaml file containing key-values that override the defaults for the main parser below
# Two-stage parsing: `config_parser` only extracts --config, whose YAML
# contents (if any) override the defaults of the main `parser` below.
# Fixes: several help strings stated defaults that contradicted the actual
# `default=` values (seed, batch size, step, epochs, workers, model,
# node-type, act-fun) and had unbalanced parentheses; the help text now
# matches the real defaults.  No default values were changed.
config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)
parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
                    help='YAML config file specifying default arguments')
parser = argparse.ArgumentParser(description='SNN Training and Evaluating')
# Model parameters
parser.add_argument('--dataset', default='cifar100', type=str)
parser.add_argument('--model', default='cifar_convnet', type=str, metavar='MODEL',
                    help='Name of model to train (default: "cifar_convnet")')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='Resume full model and optimizer state from checkpoint (default: none)')
parser.add_argument('--num-classes', type=int, default=100, metavar='N',
                    help='number of label classes (default: 100)')
parser.add_argument('--task_num', type=int, default=10, metavar='N',
                    help='number of incremental tasks (default: 10)')
# Dataloader parameters
parser.add_argument('-b', '--batch-size', type=int, default=50, metavar='N',
                    help='inputs batch size for training (default: 50)')
# Optimizer parameters
parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                    help='Optimizer (default: "adamw")')
parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
                    help='Optimizer Epsilon (default: 1e-8)')
parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                    help='Optimizer Betas (default: None, use opt default)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='Optimizer momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=0.01,
                    help='weight decay (default: 0.01 for adamw)')
parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
                    help='Clip gradient norm (default: None, no clipping)')
parser.add_argument('--adam-epoch', type=int, default=1000, help='lamb switch to adamw')
# Learning rate schedule parameters
parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
                    help='LR scheduler (default: "cosine")')
parser.add_argument('--lr', type=float, default=1e-2, metavar='LR',
                    help='learning rate (default: 0.01)')
parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                    help='learning rate noise on/off epoch percentages')
parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                    help='learning rate noise limit percent (default: 0.67)')
parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                    help='learning rate noise std-dev (default: 1.0)')
parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
                    help='learning rate cycle len multiplier (default: 1.0)')
parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
                    help='learning rate cycle limit')
parser.add_argument('--warmup-lr', type=float, default=1e-4, metavar='LR',
                    help='warmup learning rate (default: 0.0001)')
parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                    help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
parser.add_argument('--epochs', type=int, default=600, metavar='N',
                    help='number of epochs to train (default: 600)')
parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
                    help='epoch interval to decay LR')
parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
                    help='epochs to warmup LR, if scheduler supports')
parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
                    help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                    help='patience epochs for Plateau LR scheduler (default: 10)')
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                    help='LR decay rate (default: 0.1)')
parser.add_argument('--power', type=int, default=1, help='power')
# Misc
parser.add_argument('--seed', type=int, default=0, metavar='S',
                    help='random seed (default: 0)')
parser.add_argument('--log-interval', type=int, default=25, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
                    help='how many batches to wait before writing recovery checkpoint')
parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
                    help='how many training processes to use (default: 4)')
parser.add_argument('--device', type=int, default=0)
parser.add_argument('--output', default='/home/hanbing/brain/bp2/', type=str, metavar='PATH',
                    help='path to output folder (default: none, current dir)')
parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
                    help='Best metric (default: "top1")')
# Spike parameters
parser.add_argument('--step', type=int, default=4, help='Simulation time step (default: 4)')
parser.add_argument('--encode', type=str, default='direct', help='Input encode method (default: direct)')
# neuron type
parser.add_argument('--node-type', type=str, default='LIFNode', help='Node type in network (default: LIFNode)')
parser.add_argument('--act-fun', type=str, default='QGateGrad',
                    help='Surogate Function in node. Only for Surrogate nodes (default: QGateGrad)')
parser.add_argument('--thresh', type=float, default=.5, help='Firing threshold (default: 0.5)')
parser.add_argument('--tau', type=float, default=2., help='Attenuation coefficient (default: 2.)')
parser.add_argument('--loss-fn', type=str, default='ce', help='loss function (default: ce)')
parser.add_argument('--noisy-grad', type=float, default=0.,
                    help='Add noise to backward, sometime will make higher accuracy (default: 0.)')
parser.add_argument('--n_warm_up', type=int, default=0,
                    help='Warm up epoch, replace all node to ReLU to warm up weights in network before (default: 0)')
parser.add_argument('--spike-output', action='store_true', default=False,
                    help='Using mem output or spike output (default: False)')
# EventData Augmentation
parser.add_argument('--mix-up', action='store_true', help='Mix-up for event data (default: False)')
parser.add_argument('--cut-mix', action='store_true', help='CutMix for event data (default: False)')
parser.add_argument('--event-mix', action='store_true', help='EventMix for event data (default: False)')
parser.add_argument('--cutmix_beta', type=float, default=1.0, help='cutmix_beta (default: 1.)')
parser.add_argument('--cutmix_prob', type=float, default=0.5, help='cutmix_prib for event data (default: .5)')
parser.add_argument('--cutmix_num', type=int, default=1, help='cutmix_num for event data (default: 1)')
parser.add_argument('--cutmix_noise', type=float, default=0.,
                    help='Add Pepper noise after mix, sometimes work (default: 0.)')
parser.add_argument('--rand-aug', action='store_true',
                    help='Rand Augment for Event data (default: False)')
parser.add_argument('--randaug_n', type=int, default=3,
                    help='Rand Augment times n (default: 3)')
parser.add_argument('--randaug_m', type=int, default=15,
                    help='Rand Augment magnitude m (default: 15) (0-30)')
parser.add_argument('--temporal-flatten', action='store_true',
                    help='Temporal flatten to channels. ONLY FOR EVENT DATA TRAINING BY ANN')
parser.add_argument('--train-portion', type=float, default=0.9,
                    help='Dataset portion, only for datasets which do not have validation set (default: 0.9)')
parser.add_argument('--event-size', default=48, type=int,
                    help='Event size. Resize event data before process (default: 48)')
parser.add_argument('--layer-by-layer', action='store_true',
                    help='forward step-by-step or layer-by-layer. '
                         'Larger Model with layer-by-layer will be faster (default: False)')
parser.add_argument('--node-resume', type=str, default='',
                    help='resume weights in node for adaptive node. (default: False)')
parser.add_argument('--node-trainable', action='store_true')
# visualize
parser.add_argument('--visualize', action='store_true',
                    help='Visualize spiking map for each layer, only for validate (default: False)')
parser.add_argument('--spike-rate', action='store_true',
                    help='Print spiking rate for each layer, only for validate(default: False)')
parser.add_argument('--suffix', type=str, default='',
                    help='Add an additional suffix to the save path (default: \'\')')
def _parse_args():
    """Parse CLI arguments, optionally seeding defaults from a YAML config.

    The ``--config`` pre-parser runs first; if a config file is given, its
    values become parser defaults (explicit CLI flags still win).

    Returns a tuple ``(args, args_text)`` where ``args_text`` is a YAML dump
    of the parsed namespace, suitable for saving in the output directory.
    """
    # Do we have a config file to parse?
    cfg_ns, leftover = config_parser.parse_known_args()
    if cfg_ns.config:
        with open(cfg_ns.config, 'r') as fh:
            parser.set_defaults(**yaml.safe_load(fh))
    # The main arg parser parses the rest of the args; the usual defaults
    # will have been overridden if a config file was specified.
    parsed = parser.parse_args(leftover)
    # Cache the args as a text string to save them in the output dir later.
    return parsed, yaml.safe_dump(parsed.__dict__, default_flow_style=False)
def get_dataset(name, type='train', download=True, capacity=None, permutation=None, dir='./store/datasets',
                verbose=False, augment=False, normalize=False, target_transform=None):
    '''Create [train|valid|test]-dataset.'''
    # MNIST28/MNIST32 are size variants backed by the same underlying MNIST data.
    data_name = 'MNIST' if name in ('MNIST28', 'MNIST32') else name
    dataset_class = AVAILABLE_DATASETS[data_name]
    # Assemble the transform pipeline: optional augmentation, then the base
    # transforms for this dataset, then optional normalization.
    transforms_list = []
    if augment:
        transforms_list.extend(AVAILABLE_TRANSFORMS['augment'])
    transforms_list.extend(AVAILABLE_TRANSFORMS[name])
    if normalize:
        transforms_list.extend(AVAILABLE_TRANSFORMS[name+"_norm"])
    dataset_transform = transforms.Compose(transforms_list)
    # Load the data-set; anything that is not 'test' uses the training split.
    is_train = type != 'test'
    dataset = dataset_class('{dir}/{name}'.format(dir=dir, name=data_name), train=is_train,
                            download=download, transform=dataset_transform, target_transform=target_transform)
    if verbose:
        print(" --> {}: '{}'-dataset consisting of {} samples".format(name, type, len(dataset)))
    # If the dataset is (possibly) not large enough, tile deep copies until it is.
    if capacity is not None and len(dataset) < capacity:
        n_copies = int(np.ceil(capacity / len(dataset)))
        dataset = ConcatDataset([copy.deepcopy(dataset) for _ in range(n_copies)])
    return dataset
def main():
    """Entry point: build the multi-head SNN, split CIFAR100 into sequential
    tasks, and run the DSD-SNN grow/prune continual-learning loop.

    NOTE(review): this function relies on module-level names (``SNN``,
    ``Mask``, ``SubDataset``, ``device``, timm helpers ...) brought in by
    imports outside this view — confirm they are all in scope.
    """
    args, args_text = _parse_args()
    # args.no_spike_output = args.no_spike_output | args.cut_mix
    args.no_spike_output = True
    output_dir = ''
    output_base = args.output if args.output else './output'
    # Experiment directory name: timestamp plus run metadata.
    exp_name = '-'.join([
        datetime.now().strftime("%Y%m%d-%H%M%S"),
        'SNN',
        args.dataset,
        str(args.seed),
        'gwu',
        'xin-epochs'
    ])
    output_dir = get_outdir(output_base, 'train', exp_name)
    print(output_dir)
    args.output_dir = output_dir
    setup_default_logging(log_path=os.path.join(output_dir, 'log.txt'))
    torch.cuda.set_device('cuda:%d' % args.device)
    torch.manual_seed(args.seed)
    # Multi-head SNN: one decision head per task (see vgg_snn.SNN).
    model = SNN(
        num_classes=args.num_classes,
        dataset=args.dataset,
        step=args.step,
        encode_type=args.encode,
        node_type=eval(args.node_type),
        threshold=args.thresh,
        tau=args.tau,
        spike_output=not args.no_spike_output,
        act_fun=args.act_fun,
        temporal_flatten=args.temporal_flatten,
        layer_by_layer=args.layer_by_layer,
        batch_size=args.batch_size,
        task_num=args.task_num
    )
    print(model)
    # for n,p in enumerate(model.parameters()):
    #     print(n,p.size())
    # Infer the input channel count from the dataset family.
    if 'dvs' in args.dataset:
        args.channels = 2
    elif 'mnist' in args.dataset:
        args.channels = 1
    else:
        args.channels = 3
    # flops, params = profile(model, inputs=(torch.randn(1, args.channels, args.img_size, args.img_size),), verbose=False)
    # _logger.info('flops = %fM', flops / 1e6)
    # _logger.info('param size = %fM', params / 1e6)
    # Linear LR scaling rule relative to a reference batch size of 1024.
    linear_scaled_lr = args.lr * args.batch_size / 1024.0
    args.lr = linear_scaled_lr
    model = model.cuda()
    optimizer = create_optimizer(args, model)
    # optionally resume from a checkpoint
    resume_epoch = None
    if args.resume:
        # checkpoint = torch.load(args.resume, map_location='cpu')
        # model.load_state_dict(checkpoint['state_dict'], False)
        resume_epoch = resume_checkpoint(
            model, args.resume,
            optimizer=None if args.no_resume_opt else optimizer)
    if args.node_resume:
        ckpt = torch.load(args.node_resume, map_location='cpu')
        model.load_node_weight(ckpt, args.node_trainable)
    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    # The Mask controller drives the structural grow/prune schedule.
    m = Mask(model)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)
    _logger.info('Scheduled epochs: {}'.format(num_epochs))
    batch_size=args.batch_size
    data_dir = '/data0/datasets/'
    trainset = get_dataset('CIFAR100', type="train", dir=data_dir)
    testset = get_dataset('CIFAR100', type="test", dir=data_dir)
    # Split the 100-class label space into task_num contiguous chunks of
    # out_num classes each.
    out_num=int(args.num_classes/args.task_num)
    labels_per_dataset_train = [list(np.array(range(out_num))+out_num*context_id) for context_id in range(args.task_num)]
    labels_per_dataset_test = [list(np.array(range(out_num))+out_num*context_id) for context_id in range(args.task_num)]
    train_datasets = []
    for labels in labels_per_dataset_train:
        # Shift each task's labels back into [0, out_num).
        target_transform = transforms.Lambda(lambda y, x=labels[0]: y-x)
        train_datasets.append(SubDataset(trainset, labels, target_transform=target_transform))
    test_datasets = []
    for labels in labels_per_dataset_test:
        target_transform = transforms.Lambda(lambda y, x=labels[0]: y-x)
        test_datasets.append(SubDataset(testset, labels, target_transform=target_transform))
    train_data = []
    test_data = []
    t_data=[]
    for task in range(len(train_datasets)):
        # drop_last=True keeps every batch at a fixed size, matching the
        # model's pre-allocated (task_num, batch_size, out_num) output buffer.
        train_data.append(DataLoader(train_datasets[task], batch_size=batch_size, shuffle=True, drop_last=True, **({'num_workers': 4, 'pin_memory': True})))
        test_data.append(DataLoader(test_datasets[task], batch_size=batch_size, shuffle=True, drop_last=True, **({'num_workers': 4, 'pin_memory': True})))
    if args.loss_fn == 'mse':
        train_loss_fn = UnilateralMse(1.)
        validate_loss_fn = UnilateralMse(1.)
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    if args.loss_fn == 'mix':
        train_loss_fn = MixLoss(train_loss_fn)
        validate_loss_fn = MixLoss(validate_loss_fn)
    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = CheckpointSaver(
        model=model, optimizer=optimizer, args=args,
        checkpoint_dir=output_dir, recovery_dir=output_dir)
    with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
        f.write(args_text)
    loader_his=[]
    task_ready={}
    # Per-parameter record of neurons frozen by earlier tasks; indices <= 40
    # cover the shared backbone weights (matches Mask's index convention).
    # NOTE(review): ``device`` is assumed to come from a module-level import
    # (maskcl2 defines one) — confirm.
    for index, item in enumerate(model.parameters()):
        if len(item.size()) > 1 and index<=40:
            task_ready[index]=torch.zeros(item.size(),device=device)
    try: # train the model
        task_count=0
        regularization_terms= {}
        for task in range(len(train_datasets)):
            print("Task:",task)
            if task==0:
                # First task: initialize masks from scratch.
                m.model = model
                mat=m.init_length()
                model = m.model
                epochs=50
            else:
                # Later tasks: freeze previous-task neurons and grow new ones.
                m.model = model
                mat,task_ready,taskmaskk,taskww=m.init_grow(task)
                model = m.model
                epochs=30
            ta_his=[i for i in range(task+1)]
            for epoch in range(epochs):
                loader_train = iter(train_data[task])
                if task==0:
                    # No frozen weights to restore yet (taskww=None).
                    train_epoch(epoch, task, model, loader_train, optimizer, train_loss_fn, args,mat,task_ready,taskww=None,
                                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,regularization_terms=regularization_terms)
                else:
                    train_epoch(epoch, task, model, loader_train, optimizer, train_loss_fn, args,mat,task_ready,taskww=taskww,
                                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,regularization_terms=regularization_terms)
                print(epoch)
                if epoch>0:
                    # Prune under-used neurons after every epoch but the first.
                    m.model = model
                    m.init_mask(task,epoch)
                    mat=m.do_mask(task)
                    model = m.model
                # Evaluate on every task seen so far.
                ta_his=[i for i in range(task+1)]
                for t in ta_his:
                    loader_his=iter(test_data[t])
                    validate(t, model, loader_his, validate_loss_fn, args,mat)
                cc=m.if_zero()
                _logger.info('*** epoch: {0}, task: {1}, pruning: {2}'.format(epoch,task, cc))
                p_index=m.record()
    except KeyboardInterrupt:
        pass
    # if best_metric is not None:
    #     _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
def train_epoch(
        epoch, task,model, loader, optimizer, loss_fn, args,mat,task_ready,taskww=None,
        lr_scheduler=None, saver=None, output_dir='',regularization_terms={}):
    """Train one epoch of a single task with masked structural learning.

    :param epoch: current epoch index (used for logging / update counting)
    :param task: task id; selects this task's head from the model's output
    :param model: multi-head SNN whose forward takes (inputs, mat)
    :param loader: iterator over (inputs, target) batches for this task
    :param optimizer: optimizer over all model parameters
    :param loss_fn: training criterion applied to the selected task head
    :param args: parsed CLI namespace (log_interval etc.)
    :param mat: per-parameter binary masks gating the active neurons
    :param task_ready: per-parameter masks of neurons frozen by earlier tasks
    :param taskww: weight snapshot taken at task start; frozen neurons are
        restored from it after every optimizer step. ``None`` is safe for the
        first task because ``task_ready`` is all zeros then, so the restore
        loop never indexes it.
    :param regularization_terms: kept for the commented-out EWC-style penalty

    NOTE(review): ``regularization_terms={}`` is a mutable default argument —
    harmless while the penalty stays commented out, but fix before re-enabling.
    """
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    losses_m = AverageMeter()
    top1_m = AverageMeter()
    top5_m = AverageMeter()
    model.train()
    end = time.time()
    last_idx = len(loader) - 1
    num_updates = epoch * len(loader)
    for batch_idx, (inputs, target) in enumerate(loader):
        last_batch = batch_idx == last_idx
        data_time_m.update(time.time() - end)
        inputs, target = inputs.type(torch.FloatTensor).cuda(), target.cuda()
        # Forward through the masked network; output is (task_num, batch, out_num).
        output = model(inputs, mat)
        t_preds = output[task].cuda()
        loss = loss_fn(t_preds, target)
        # if len(regularization_terms)>0:
        #     reg_loss = 0
        #     for i,reg_term in regularization_terms.items():
        #         task_reg_loss = 0
        #         importance = reg_term['importance']
        #         task_param = reg_term['task_param']
        #         for n, p in enumerate(model.parameters()):
        #             if len(p.size())>=1:
        #                 task_reg_loss += (importance[n] * (p - task_param[n]) ** 2).sum()
        #         reg_loss += task_reg_loss
        #     loss += 10000 * reg_loss
        acc1, acc5 = accuracy(t_preds, target, topk=(1, 5))
        losses_m.update(loss.item(), inputs.size(0))
        top1_m.update(acc1.item(), inputs.size(0))
        top5_m.update(acc5.item(), inputs.size(0))
        optimizer.zero_grad()
        loss.backward()
        # for index, item in enumerate(model.parameters()):
        #     if len(item.size()) > 1 and index<=40:
        #         gradmask=torch.where(task_ready[index]>0,0.0,1.0)
        #         item.grad=item.grad*gradmask
        optimizer.step()
        # Undo the optimizer step on neurons owned by earlier tasks: any row
        # marked in task_ready is restored from the task-start snapshot.
        for index, item in enumerate(model.parameters()):
            if len(item.size()) > 1 and index<=40:
                if index<40:
                    # Conv weights: collapse (out, in, kh, kw) to per-output-row sums.
                    ready=task_ready[index].view(task_ready[index].size()[0],-1)
                    ready=torch.sum(ready,dim=1)
                else:
                    ready=torch.sum(task_ready[index],dim=1)
                windex=torch.nonzero(ready>0)
                for i in range(len(windex)):
                    item.data[windex[i]]=taskww[index][windex[i]]
        num_updates += 1
        # NOTE(review): ``end`` is never refreshed inside the loop (the update
        # is commented out below), so batch_time measures time since epoch
        # start, not per-batch time — confirm whether that is intended.
        batch_time_m.update(time.time() - end)
        if last_batch or batch_idx % args.log_interval == 0:
            # lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            # lr = sum(lrl) / len(lrl)
            _logger.info(
                'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
                'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) '
                'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
                'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'
                'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
                '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
                    epoch,
                    batch_idx, len(loader),
                    100. * batch_idx / last_idx,
                    loss=losses_m,
                    top1=top1_m, top5=top5_m,
                    batch_time=batch_time_m,
                    rate=inputs.size(0) / batch_time_m.val,
                    rate_avg=inputs.size(0) / batch_time_m.avg,
                    data_time=data_time_m))
        # if saver is not None and args.recovery_interval and (
        #         last_batch or (batch_idx + 1) % args.recovery_interval == 0):
        #     saver.save_recovery(epoch, batch_idx=batch_idx)
        # if lr_scheduler is not None:
        #     lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
        # end = time.time()
        # # end for
    # if hasattr(optimizer, 'sync_lookahead'):
    #     optimizer.sync_lookahead()
    # return OrderedDict([('loss', losses_m.avg)])
def validate(task, model, loader, loss_fn, args, mat, log_suffix='', visualize=False, spike_rate=False):
    """Evaluate ``model`` on one task's test loader and log loss/accuracy.

    :param task: task index; selects the matching head from the model output,
        which has shape (task_num, batch, out_num)
    :param model: multi-head SNN whose forward takes (inputs, mat)
    :param loader: iterator over (inputs, target) batches
    :param loss_fn: validation criterion
    :param args: parsed CLI namespace (unused here except for interface parity)
    :param mat: per-parameter mask dict forwarded to the model
    :param log_suffix: extra string appended to the per-task log tag
    :param visualize: unused; kept for interface compatibility
    :param spike_rate: unused; kept for interface compatibility
    """
    batch_time_m = AverageMeter()
    losses_m = AverageMeter()
    top1_m = AverageMeter()
    top5_m = AverageMeter()
    model.eval()
    end = time.time()
    with torch.no_grad():
        last_idx = len(loader) - 1
        for batch_idx, (inputs, target) in enumerate(loader):
            last_batch = batch_idx == last_idx
            inputs = inputs.type(torch.FloatTensor).cuda()
            target = target.cuda()
            output = model(inputs, mat)
            # Pick the head belonging to this task.
            t_preds = output[task]
            loss = loss_fn(t_preds, target)
            acc1, acc5 = accuracy(t_preds, target, topk=(1, 5))
            reduced_loss = loss.data
            torch.cuda.synchronize()
            losses_m.update(reduced_loss.item(), inputs.size(0))
            # BUG FIX: the accuracy meters must be weighted by the number of
            # samples in the batch. ``output`` is (task_num, batch, out_num),
            # so output.size(0) is the task count, not the batch size; use
            # inputs.size(0) like the loss meter above.
            top1_m.update(acc1.item(), inputs.size(0))
            top5_m.update(acc5.item(), inputs.size(0))
            batch_time_m.update(time.time() - end)
            end = time.time()
            log_name = 'Test'+str(task) + log_suffix
            _logger.info(
                '{0}: [{1:>4d}/{2}] '
                'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
                'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
                    log_name, batch_idx, last_idx, batch_time=batch_time_m,
                    loss=losses_m, top1=top1_m, top5=top5_m))
    # metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])
    # return metrics
# Script entry point.
if __name__ == '__main__':
    main()
================================================
FILE: examples/Structural_Development/DSD-SNN/cifar100/manipulate.py
================================================
import torch
from torch.utils.data import Dataset
def permutate_image_pixels(image, permutation):
    '''Permutate the pixels of an image according to [permutation].

    [image]        3D-tensor (C, H, W) containing the image
    [permutation]  pixel-indeces in their new order, or None for a no-op'''
    if permutation is None:
        return image
    channels, height, width = image.size()
    # Flatten the spatial dims, reorder the columns (the same permutation is
    # applied to every channel), then restore the original shape.
    flat = image.view(channels, -1)[:, permutation]
    return flat.view(channels, height, width)
#----------------------------------------------------------------------------------------------------------#
class SubDataset(Dataset):
    '''To sub-sample a dataset, taking only those samples with label in [sub_labels].
    After this selection of samples has been made, it is possible to transform the target-labels,
    which can be useful when doing continual learning with fixed number of output units.'''

    def __init__(self, original_dataset, sub_labels, target_transform=None):
        super().__init__()
        self.dataset = original_dataset
        # Keep only the indices whose (possibly remapped) label is wanted.
        self.sub_indeces = [
            idx for idx in range(len(original_dataset))
            if self._label_of(idx) in sub_labels
        ]
        self.target_transform = target_transform

    def _label_of(self, index):
        # Older torchvision datasets expose train_labels / test_labels
        # directly; otherwise fall back to indexing the sample tuple.
        if hasattr(self.dataset, "train_labels"):
            label = self.dataset.train_labels[index]
        elif hasattr(self.dataset, "test_labels"):
            label = self.dataset.test_labels[index]
        else:
            return self.dataset[index][1]
        # The wrapped dataset may already remap its own targets.
        if self.dataset.target_transform is not None:
            label = self.dataset.target_transform(label)
        return label

    def __len__(self):
        return len(self.sub_indeces)

    def __getitem__(self, index):
        sample = self.dataset[self.sub_indeces[index]]
        if self.target_transform:
            sample = (sample[0], self.target_transform(sample[1]))
        return sample
class MemorySetDataset(Dataset):
    '''Create dataset from list of arrays with shape (N, C, H, W) (i.e., with N images each).
    The images at the i-th entry of [memory_sets] belong to class [i], unless a [target_transform] is specified'''

    def __init__(self, memory_sets, target_transform=None):
        super().__init__()
        self.memory_sets = memory_sets
        self.target_transform = target_transform

    def __len__(self):
        # Total number of stored examples across all classes.
        return sum(len(class_set) for class_set in self.memory_sets)

    def __getitem__(self, index):
        # Walk the per-class sets until the flat index falls inside one.
        total = 0
        for class_id in range(len(self.memory_sets)):
            examples_in_this_class = len(self.memory_sets[class_id])
            if index < (total + examples_in_this_class):
                class_id_to_return = class_id if self.target_transform is None else self.target_transform(class_id)
                example_id = index - total
                break
            total += examples_in_this_class
        else:
            # BUG FIX: an out-of-range index previously fell off the loop and
            # hit ``example_id`` unbound (UnboundLocalError). Raise IndexError
            # instead, which also lets iteration protocols terminate cleanly.
            raise IndexError('index {} out of range for MemorySetDataset of size {}'.format(index, total))
        image = torch.from_numpy(self.memory_sets[class_id][example_id])
        return (image, class_id_to_return)
class TransformedDataset(Dataset):
    '''To modify an existing dataset with a transform.
    This is useful for creating different permutations of MNIST without loading the data multiple times.'''

    def __init__(self, original_dataset, transform=None, target_transform=None):
        super().__init__()
        self.dataset = original_dataset
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        sample, label = self.dataset[index]
        # Apply whichever transforms were provided; missing ones are no-ops.
        if self.transform:
            sample = self.transform(sample)
        if self.target_transform:
            label = self.target_transform(label)
        return (sample, label)
# ----------------------------------------------------------------------------------------------------------#
class UnNormalize(object):
    """Invert a torchvision-style per-channel normalization, in place."""

    def __init__(self, mean, std):
        # Per-channel statistics used by the forward Normalize transform.
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """Denormalize image, either single image (C,H,W) or image batch (N,C,H,W)"""
        is_batch = tensor.dim() == 4
        # For a batch, bring channels to the front so each zipped slice spans
        # one channel across all images. Mutation happens in place:
        # the normalize code was t.sub_(m).div_(s), so we apply the inverse.
        channels = tensor.permute(1, 0, 2, 3) if is_batch else tensor
        for channel, m, s in zip(channels, self.mean, self.std):
            channel.mul_(s).add_(m)
        return tensor
================================================
FILE: examples/Structural_Development/DSD-SNN/cifar100/maskcl2.py
================================================
import numpy as np
import torch
import math
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import random
def unit(x):
    """Min-max normalize a tensor of per-neuron scores into [0, 1].

    Returns ``x`` unchanged when it is empty, and an all-0.5 tensor when every
    element is equal (zero range), so downstream thresholding still behaves.

    :param x: 1-D tensor of scores
    :return: tensor of the same shape scaled into [0, 1]
    """
    if x.size()[0] == 0:
        return x
    # FIX: removed the unused ``xnp = x.cpu().numpy()`` local — it forced a
    # needless device->host copy (and sync) on every call.
    maxx = torch.max(x)
    minx = torch.min(x)
    marge = maxx - minx
    if marge == 0:
        # Degenerate range: map every element to the midpoint.
        return 0.5 * torch.ones_like(x)
    # Clamp guards against floating-point drift just outside [0, 1].
    return torch.clip((x - minx) / marge, 0, 1)
class Mask:
    """Structural-development (grow/prune) controller for DSD-SNN.

    Maintains per-parameter binary masks (``mat``) gating which neurons are
    alive, accumulates which neurons were frozen for earlier tasks
    (``task_ready``), and implements the per-task grow/prune schedule.

    NOTE(review): the hard-coded parameter indices (<=40 backbone, index+5
    next-layer coupling, ==40 the fc layer with a *16 spatial fan-in, 44 and
    44+task*2 the task heads) assume the exact layer ordering of the
    companion vgg_snn.SNN model — verify if the backbone changes.
    """
    def __init__(self, model):
        self.model = model
        self.mat = {}  # param index -> binary mask applied to the weights
        self.p_index={}  # param index -> indices of neurons ever selected
        self.p_num={}  # param index -> per-neuron counters (reset each task)
        self.k=15
        self.task_ready={}  # param index -> accumulated masks frozen by past tasks
        self.taskmask={}  # task id -> {param index -> mask at task start}
        self.init_rate=0.3  # fraction of neurons enabled for the first task
        self.grow_rate=0.125  # fraction of neurons grown per new task
        self.prunconv_init=0.8  # pruning thresholds (conv/fc, first/later tasks)
        self.prunfc_init=1.3
        self.prunconv_grow=0.5
        self.prunfc_grow=1
        self.n_delta={}  # per-neuron signed pruning pressure
        self.reduce={}  # per-neuron "health" accumulator; prune when it drops below 0
        self.taskww={}  # snapshot of weights taken at the start of a task

    def init_length(self):
        """Build the initial masks for task 0.

        Enables the first ``init_rate`` fraction of neurons in every backbone
        layer (and the matching input slices of the next layer), and opens
        only the first task's head. Returns the mask dict.
        """
        for index, item in enumerate(self.model.parameters()):
            if len(item.size()) > 1:
                print(index,item.size())
                self.mat[index]=torch.ones(item.size(),device=device)
        for index, item in enumerate(self.model.parameters()):
            if len(item.size()) > 1:
                if index<=40:
                    self.p_index[index]=torch.tensor([])
                    self.task_ready[index]=torch.zeros(item.size(),device=device)
                    self.reduce[index] = 5*torch.ones(item.size()[0],device=device)
                    if len(item.size()) == 4:
                        # Conv layer: disable all output channels past the
                        # initial fraction, plus the matching input channels
                        # of the layer 5 parameters further on.
                        self.p_num[index]=torch.zeros(item.size()[0],device=device)
                        self.mat[index][int(self.init_rate*item.size()[0]):]=0.0
                        if index+5<40:
                            self.mat[index+5][:,int(self.init_rate*item.size()[0]):]=0.0
                        if index+5==40:
                            # fc layer consumes flattened conv features; each
                            # channel maps to 16 input columns (4x4 spatial).
                            self.mat[index+5][:,int(self.init_rate*item.size()[0])*16:]=0.0
                    if len(item.size()) == 2:
                        self.p_num[index]=torch.zeros(item.size()[0]*item.size()[1],device=device)
                        self.mat[index][int(self.init_rate*item.size()[0]):]=0.0
                if index>40:
                    self.mat[index]=torch.ones(item.size(),device=device)
                    if index==44:
                        # Task-0 head: only connect to the initially active units.
                        self.mat[index]=torch.zeros(item.size(),device=device)
                        self.mat[index][:,:int(self.init_rate*item.size()[1])]=1.0
        return self.mat

    def get_filter_codebook(self,index,ww,task,epoch):
        """Update the pruning state for one parameter tensor.

        Scores each output neuron by its (masked) weight magnitude, converts
        the score into a signed pressure ``n_delta``, integrates it into the
        ``reduce`` health accumulator with an epoch-decaying step, and zeroes
        the mask rows (and coupled next-layer columns) of neurons whose health
        fell below 0. Neurons frozen by earlier tasks are never pruned.

        :param index: parameter index within model.parameters()
        :param ww: absolute weight tensor for this parameter
        :param task: current task id (selects init vs. grow thresholds)
        :param epoch: current epoch (controls the decay of the update step)
        """
        if task==0:
            pruncon=self.prunconv_init
            prunfc=self.prunfc_init
        else:
            pruncon=self.prunconv_grow
            prunfc=self.prunfc_grow
        if len(ww.size()) == 4:
            # Per-output-channel weight mass.
            p_ww=ww.view(ww.size()[0],-1)
            p_ww=torch.sum(p_ww,dim=1)
            # Channels with no currently-trainable (non-frozen) weights are
            # given the max score so they are not targeted for pruning here.
            task_use=self.mat[index]-torch.sign(self.task_ready[index])
            nouse=torch.sum(task_use.view(ww.size()[0],-1),dim=1)
            no=torch.nonzero(nouse<0.1)
            p_ww[no]=p_ww.max()
            self.n_delta[index]=(unit(p_ww)*2-pruncon)
            # Strongly reward neurons with positive pressure so they recover.
            pos=torch.nonzero(self.n_delta[index]>0)
            self.n_delta[index][pos]=self.n_delta[index][pos]+3
            self.reduce[index]=self.reduce[index]*0.999+self.n_delta[index]*math.exp(-int((epoch-1)/13))
            p_ind = torch.nonzero(self.reduce[index] <0)
            print(self.reduce[index].mean(),self.reduce[index].max(),self.reduce[index].min(),len(p_ind))
            for x in range(0, len(p_ind)):
                # Prune the output channel and its fan-out into the next layer.
                self.mat[index][p_ind[x]] = 0
                if index+5<40:
                    self.mat[index+5][:,p_ind[x]]=0
                if index+5==40:
                    self.mat[index+5][:,p_ind[x]*16:(p_ind[x]+1)*16]=0
            # Re-enable anything owned by earlier tasks (they stay alive).
            self.mat[index]=torch.sign(self.mat[index]+ self.task_ready[index])
        if len(ww.size()) == 2:
            # Same procedure for the fully-connected layer.
            p_ww=torch.sum(ww,dim=1)
            task_use=self.mat[index]-torch.sign(self.task_ready[index])
            nouse=torch.sum(task_use,dim=1)
            no=torch.nonzero(nouse<0.1)
            p_ww[no]=p_ww.max()
            self.n_delta[index]=(unit(p_ww)*2-prunfc)
            print(self.n_delta[index].mean(),self.n_delta[index].max(),self.n_delta[index].min())
            pos=torch.nonzero(self.n_delta[index]>0)
            self.n_delta[index][pos]=self.n_delta[index][pos]+3
            self.reduce[index]=self.reduce[index]*0.999+self.n_delta[index]*math.exp(-int((epoch-1)/13))
            p_ind = torch.nonzero(self.reduce[index] <0)
            print(self.reduce[index].mean(),self.reduce[index].max(),self.reduce[index].min(),len(p_ind))
            # This task's head parameter index.
            index_ta=44+task*2
            for x in range(0, len(p_ind)):
                self.mat[index][p_ind[x]] = 0
                self.mat[index_ta][:,p_ind[x]]=0
            self.mat[index]=torch.sign(self.mat[index]+self.task_ready[index])

    def convert2tensor(self, x):
        """Convert ``x`` to a float tensor."""
        x = torch.FloatTensor(x)
        return x

    def init_mask(self,task,epoch):
        """Run the pruning-state update for every backbone parameter."""
        for index, item in enumerate(self.model.parameters()):
            if len(item.size()) > 1 and index<=40:
                self.get_filter_codebook(index,abs(item.data),task,epoch)

    def do_mask(self,task):
        """Apply the current masks to the model weights; return the mask dict."""
        for index, item in enumerate(self.model.parameters()):
            if len(item.size()) > 1 and index<=40:
                ww=item.data
                item.data=ww*self.mat[index].cuda()
        return self.mat

    def init_grow(self,task):
        """Prepare masks for a new task: freeze what the previous task used,
        grow a fresh fraction of so-far-unused neurons, and open the new head.

        :param task: new task id (> 0)
        :return: (mat, task_ready, taskmask, taskww)
        """
        self.taskmask[task]={}
        index_ta=44+task*2
        self.mat[index_ta]=torch.zeros(self.mat[index_ta].size(),device=device)
        for index, item in enumerate(self.model.parameters()):
            if len(item.size()) > 1 and index<=40:
                # Freeze everything the previous task kept, and snapshot the
                # weights so training can restore frozen rows after each step.
                self.task_ready[index]=self.task_ready[index]+self.mat[index]
                self.taskmask[task][index]=self.mat[index]
                self.taskww[index]=item.data.clone()
                # Candidate neurons for growth: never selected by any task.
                ind_all=[x for x in range(item.size()[0])]
                pp=list(np.array(self.p_index[index]))
                ind_empty=set(ind_all)-set(pp)
                ind_empty=list(ind_empty)
                # random.shuffle(ind_empty)
                ind_grow=ind_empty[:int(item.size()[0]*self.grow_rate)]
                ind_grow=torch.tensor(ind_grow)
                # Fresh random weights for the grown neurons.
                ww_g=torch.empty(item.size(),device=device)
                if index<40:
                    torch.nn.init.kaiming_uniform_(ww_g, a=math.sqrt(5))
                if index==40:
                    kk=1/math.sqrt(item.size()[1])
                    torch.nn.init.uniform_(ww_g,a=-kk,b=kk)
                for x in range(0, len(ind_grow)):
                    self.mat[index][ind_grow[x]] = 1.0
                    item.data[ind_grow[x]]=ww_g[ind_grow[x]]
                    if index==40:
                        # New head reads from the newly grown fc units.
                        self.mat[index_ta][:,ind_grow[x]]=1.0
                self.mat[index]=torch.sign(self.mat[index]+self.task_ready[index])
                # Reset per-task counters and health accumulators.
                self.p_num[index]=torch.zeros(item.size()[0],device=device)
                self.reduce[index] = 5*torch.ones(item.size()[0],device=device)
        #self.mat[index_ta]=self.mat[index_ta]+self.mat[index_ta-2]
        # nn=torch.sum(self.task_ready[24],dim=1)
        # use_nn=torch.nonzero(nn>1)
        # for x in range(0, len(use_nn)):
        #     self.mat[index_ta][:,use_nn[x]] = 1.0
        # Propagate dead neurons forward: if a layer's output row is entirely
        # masked out, also cut the matching input columns of the next layer.
        for index, item in enumerate(self.model.parameters()):
            if len(item.size()) > 1 and index<40:
                nsum=self.mat[index].view(item.size()[0],-1)
                nn=torch.sum(abs(nsum),dim=1)
                empy_nn=torch.nonzero(nn<0.00001)
                for x in range(0, len(empy_nn)):
                    if index+5<40:
                        self.mat[index+5][:,empy_nn[x]]=0
                    if index+5==40:
                        self.mat[index+5][:,empy_nn[x]*16:(empy_nn[x]+1)*16]=0
        return self.mat,self.task_ready,self.taskmask,self.taskww

    def if_zero(self):
        """Report per-parameter sparsity percentages (and, for the fc layer,
        the number of fully-dead output rows); returns the list of figures."""
        cc=[]
        for index, item in enumerate(self.model.parameters()):
            if len(item.size()) > 1 and index<=40:
                b = item.data.view(-1).cpu().numpy()
                print("number of weight is %d, zero is %.3f" %(len(b),100*(len(b)- np.count_nonzero(b))/len(b)))
                cc.append(100*(len(b)- np.count_nonzero(b))/len(b))
                if index==40:
                    nouse=torch.sum(self.mat[index],dim=1)
                    no=torch.nonzero(nouse<0.1)
                    print(len(no))
                    cc.append(len(no))
        return cc

    def record(self):
        """Merge the currently-alive neuron indices into ``p_index`` so they
        are excluded from future growth; returns the updated dict."""
        for index, item in enumerate(self.model.parameters()):
            if len(item.size()) > 1 and index<=40:
                nsum=self.mat[index].view(item.size()[0],-1)
                nn=torch.sum(abs(nsum),dim=1)
                epoch_select=torch.nonzero(nn>0.00001)
                select=set(epoch_select)|set(self.p_index[index])
                self.p_index[index]= torch.tensor(list(select))
        return self.p_index
================================================
FILE: examples/Structural_Development/DSD-SNN/cifar100/vgg_snn.py
================================================
# encoding: utf-8
# Author : Floyed
# Datetime : 2022/7/26 18:56
# User : Floyed
# Product : PyCharm
# Project : BrainCog
# File : vgg_snn.py
# explain :
from functools import partial
from torch.nn import functional as F
import torchvision
from timm.models import register_model
from braincog.datasets import is_dvs_data
from braincog.base.node.node import *
from braincog.base.connection.layer import *
from braincog.base.encoder.encoder import *
from braincog.model_zoo.base_module import BaseModule, BaseConvModule, BaseLinearModule
@register_model
class SNN(BaseModule):
    """VGG-style multi-head spiking network for DSD-SNN continual learning.

    One shared convolutional backbone plus ``task_num`` independent linear
    decision heads; ``forward`` applies external binary masks (``mat``) to
    every multi-dimensional parameter before simulating ``step`` time steps.
    """
    def __init__(self,
                 num_classes=100,
                 step=8,
                 node_type=LIFNode,
                 encode_type='direct',
                 batch_size=100,
                 task_num=10,
                 *args,
                 **kwargs):
        super().__init__(step, encode_type, *args, **kwargs)
        self.n_preact = kwargs['n_preact'] if 'n_preact' in kwargs else False
        self.batch_size=batch_size
        self.num_classes = num_classes
        self.task_num=task_num
        # Output units per task head (classes split evenly across tasks).
        self.out_num=int(self.num_classes/self.task_num)
        self.node = node_type
        if issubclass(self.node, BaseNode):
            self.node = partial(self.node, **kwargs, step=step)
        self.dataset = kwargs['dataset']
        # Frame data has 3 input channels; DVS event data has 2.
        if not is_dvs_data(self.dataset):
            init_channel = 3
            output_size = 2
        else:
            init_channel = 2
            output_size = 3
        #self.channel_number=[256,512,1024]
        self.channel_number=[512,1024,2048]
        self.feature = nn.Sequential(
            BaseConvModule(init_channel, self.channel_number[0], kernel_size=(3, 3), padding=(1, 1), node=self.node),
            BaseConvModule(self.channel_number[0], self.channel_number[0], kernel_size=(3, 3), padding=(1, 1), node=self.node),
            nn.AvgPool2d(2),
            BaseConvModule(self.channel_number[0], self.channel_number[0], kernel_size=(3, 3), padding=(1, 1), node=self.node),
            BaseConvModule(self.channel_number[0], self.channel_number[0], kernel_size=(3, 3), padding=(1, 1), node=self.node),
            nn.AvgPool2d(2),
            BaseConvModule(self.channel_number[0], self.channel_number[1], kernel_size=(3, 3), padding=(1, 1), node=self.node),
            BaseConvModule(self.channel_number[1], self.channel_number[1], kernel_size=(3, 3), padding=(1, 1), node=self.node),
            nn.AvgPool2d(2),
            BaseConvModule(self.channel_number[1], self.channel_number[2], kernel_size=(3, 3), padding=(1, 1), node=self.node),
            BaseConvModule(self.channel_number[2], self.channel_number[2], kernel_size=(3, 3), padding=(1, 1), node=self.node),
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            BaseLinearModule(
                self.channel_number[2]*4*4, self.channel_number[2], node=self.node),
        )
        # One decision head per task, keyed by the task id as a string.
        self.dec = nn.ModuleDict()
        for task in range(self.task_num):
            ta=str(task)
            self.dec[ta] = self._create_decision()

    def logits(self, x):
        """Run every task head; returns a (task_num, batch, out_num) tensor.

        NOTE(review): the buffer is sized with the constructor's batch_size
        and allocated on 'cuda' — requires fixed-size CUDA batches (the
        training script uses drop_last=True for this reason).
        """
        outputs =torch.zeros((self.task_num,self.batch_size,self.out_num),device='cuda')
        for task, func in self.dec.items():
            ta=int(task)
            outputs[ta]=func(x)
        return outputs

    def _create_decision(self):
        """Build one task head: a plain linear read-out layer."""
        fc = nn.Linear(self.channel_number[2], self.out_num)
        # fc = BaseLinearModule(1024, 10, node=self.node)
        return fc

    def forward(self, inputs, mat):
        """Masked forward pass averaged over ``step`` simulation steps.

        :param inputs: input batch, expanded over time by the encoder
        :param mat: dict of binary masks, keyed by parameter index; applied
            in place to every weight tensor with more than one dimension
        :return: (task_num, batch, out_num) spike-rate logits / step average
        """
        inputs = self.encoder(inputs)
        self.reset()
        step = self.step
        outputs = []
        # Gate the weights with the structural masks before simulating.
        for index, item in enumerate(self.parameters()):
            if len(item.size()) > 1:
                ww=item.data
                item.data=ww*mat[index].cuda()
        for t in range(step):
            x = inputs[t]
            x = self.feature(x)
            x = self.fc(x)
            x = self.logits(x)
            outputs.append(x)
        out=sum(outputs).cuda()
        return out / step
# class MaskConvModule(nn.Module):
# """
# SNN卷积模块
# :param in_channels: 输入通道数
# :param out_channels: 输出通道数
# :param kernel_size: kernel size
# :param stride: stride
# :param padding: padding
# :param bias: Bias
# :param node: 神经元类型
# :param kwargs:
# """
# def __init__(self,
# in_channels: int,
# out_channels: int,
# kernel_size=(3, 3),
# stride=(1, 1),
# padding=(1, 1),
# bias=False,
# node=PLIFNode,
# **kwargs):
# super().__init__()
# if node is None:
# raise TypeError
# self.groups = kwargs['groups'] if 'groups' in kwargs else 1
# self.conv = MConv2d(in_channels=in_channels * self.groups,
# out_channels=out_channels * self.groups,
# kernel_size=kernel_size,
# padding=padding,
# stride=stride,
# bias=bias)
# self.bn = nn.BatchNorm2d(out_channels * self.groups)
# self.node = partial(node, **kwargs)()
# self.activation = nn.Identity()
# def forward(self, x, mat):
# x = self.conv(x,mat)
# x = self.bn(x)
# x = self.node(x)
# return x
# class MConv2d(nn.Conv2d):
# def __init__(self, in_channels, out_channels, kernel_size, stride=1,
# padding=0, dilation=1, groups=1, bias=True, gain=True):
# super(MConv2d, self).__init__(in_channels, out_channels, kernel_size, stride,
# padding, dilation, groups, bias)
# self.gain = 1.
# def forward(self, x, mat):
# weight = self.weight
# weight = weight*mat
# return F.conv2d(x, weight, self.bias, self.stride,
# self.padding, self.dilation, self.groups)
# class MaskLinearModule(nn.Module):
# """
# 线性模块
# :param in_features: 输入尺寸
# :param out_features: 输出尺寸
# :param bias: 是否有Bias, 默认 ``False``
# :param node: 神经元类型, 默认 ``LIFNode``
# :param args:
# :param kwargs:
# """
# def __init__(self,
# in_features: int,
# out_features: int,
# bias=True,
# node=LIFNode,
# *args,
# **kwargs):
# super().__init__()
# if node is None:
# raise TypeError
# self.fc = MLinear(in_features=in_features,
# out_features=out_features, bias=bias)
# self.node = partial(node, **kwargs)()
# def forward(self, x,mat):
# outputs = self.fc(x,mat)
# return self.node(outputs)
# class MLinear(nn.Linear):
# def __init__(self, in_features: int, out_features: int, bias: bool = True):
# super(MLinear, self).__init__(in_features, out_features, bias)
# self.gain = 1.
# def forward(self, input, mat):
# weight = self.weight
# weight = weight*mat
# return F.linear(input, weight, self.bias)
================================================
FILE: examples/Structural_Development/ELSM/evolve.py
================================================
import time
import threading
from threading import Thread
import os
import networkx as nx
import numpy as np
from population import *
import nsganet as engine
from pymop.problem import Problem
from pymoo.optimize import minimize
from pymoo.operators.sampling.random_sampling import RandomSampling
from pymoo.operators.mutation.bitflip_mutation import BinaryBitflipMutation
import logging
from model import *
from spikes import calc_f2
from mul import mul_f1
_logger = logging.getLogger('')
# Two-stage argument parsing: the config pre-parser's leftovers are handed to
# the main ELSM parser in _parse_args below.
config_parser = parser = argparse.ArgumentParser(description='Evolution Config', add_help=False)
parser = argparse.ArgumentParser(description='ELSM')
parser.add_argument('--device', type=int, default=2)
parser.add_argument('--seed', type=int, default=68, metavar='S')
parser.add_argument('--datapath', default='', type=str, metavar='PATH')
parser.add_argument('--output', default='', type=str, metavar='PATH')
# Number of liquid-layer neurons; the genome is liquid_size**2 binary genes.
parser.add_argument('--liquid-size', type=int, default=8000)
parser.add_argument('--pop-size', type=int, default=80)
# Upper/lower bounds on the total connection count (evolution constraints).
parser.add_argument('--up', type=int, default=32000000)
parser.add_argument('--low', type=int, default=320000)
parser.add_argument('--n_offspring', type=int, default=100)
parser.add_argument('--n_gens', type=int, default=10000)
# Normalization constants for the small-world objective — presumably the
# coefficients of an equivalent random graph; verify against the paper.
parser.add_argument('--arand', type=float, default=285)
parser.add_argument('--brand', type=float, default=1.8)
def _parse_args():
    """Parse the ELSM command-line arguments.

    The config pre-parser first strips any flags it owns; the remaining
    arguments are handled by the main parser.

    :return: parsed ``argparse.Namespace``
    """
    # FIX: the pre-parser's namespace was bound to an unused local
    # (``args_config``); only the leftover argument list is needed here.
    _, remaining = config_parser.parse_known_args()
    return parser.parse_args(remaining)
class Evolve(Problem):
    # first define the NAS problem (inherit from pymop)
    """Evolutionary search problem over liquid-state-machine topologies.

    Each individual is a flat binary genome of liquid_size**2 genes encoding
    the liquid adjacency matrix. Objective 0 rewards small-world structure,
    objective 1 penalizes distance from criticality; two constraints keep the
    total connection count between ``args.low`` and ``args.up``.
    """
    def __init__(self, args,n_var=20, n_obj=1, n_constr=0, lb=None, ub=None):
        super().__init__(n_var=n_var, n_obj=n_obj, n_constr=n_constr, type_var=np.int64)
        self.xl = lb
        self.xu = ub
        self._n_evaluated = 0  # keep track of how many architectures are sampled
        self.args=args

    def _evaluate(self, x, out, *args, **kwargs):
        """Score a population ``x`` and fill ``out['F']`` / ``out['G']``."""
        objs = np.full((x.shape[0], self.n_obj), np.nan)
        g1 = np.full((x.shape[0]), np.nan)
        g2 = np.full((x.shape[0]), np.nan)
        # Per-generation scratch directory for the serialized topologies.
        # NOTE(review): 'generaion' is a typo kept as-is — other tooling may
        # depend on the existing path spelling.
        gen_dir=os.path.join(self.args.output,'generaion'+str(kwargs['algorithm'].n_gen))
        os.makedirs(gen_dir,exist_ok = True)
        # np.save(os.path.join(gen_dir,"x.npy"),x)
        lsms = x.reshape(x.shape[0],self.args.liquid_size,self.args.liquid_size)
        for i in range(x.shape[0]):
            temp_G = nx.Graph(lsms[i])
            nx.write_gpickle(temp_G, os.path.join(gen_dir,str(i)+".pkl"))
        # Small-world coefficients computed over the whole population.
        self.ob1=mul_f1(pop=x.shape[0],steps=10,rootdir=gen_dir)
        for i in range(x.shape[0]):
            arch_id = self._n_evaluated + 1
            print('\n')
            _logger.info('Network= {}'.format(arch_id))
            genome = x[i, :]
            # Connection-count constraints (feasible when <= 0).
            g1[i]= genome.sum()-self.args.up
            g2[i]= self.args.low-genome.sum()
            lsmm = genome.reshape(self.args.liquid_size,self.args.liquid_size)
            small_coe_a,small_coe_b=self.ob1[i]
            lsmm=torch.tensor(lsmm,device='cuda:%d' % self.args.device).float()
            # Criticality score; objective 1 is the distance from 1.
            crit = calc_f2(lsmm,'cuda:%d' % self.args.device)
            objs[i, 1] = abs(crit-1)
            # all objectives assume to be MINIMIZED !!!!!
            objs[i, 0] = -(small_coe_a/self.args.arand)/(small_coe_b/self.args.brand)
            _logger.info('small word= {}'.format(objs[i, 0]))
            _logger.info('criticality= {}'.format(objs[i, 1]))
            self._n_evaluated += 1
        out["F"] = objs
        out["G"] = np.column_stack([g1,g2])
        # if your NAS problem has constraints, use the following line to set constraints
        # out["G"] = np.column_stack([g1, g2, g3, g4, g5, g6]) in case 6 constraints
# ---------------------------------------------------------------------------------------------------------
# Define what statistics to print or save for each generation
# ---------------------------------------------------------------------------------------------------------
def do_every_generations(algorithm):
    """Generation callback for pymoo's ``minimize``.

    Logs min/mean/median/max of both objectives plus the index of the best
    individual per objective, and every 20 generations saves the best genome
    of each objective to the current working directory.
    """
    gen = algorithm.n_gen
    genomes = algorithm.pop.get("X")
    objectives = algorithm.pop.get("F")
    _logger.info("generation = {}".format(gen))
    for k in (0, 1):
        tag = str(k + 1)
        col = objectives[:, k]
        _logger.info("population error{}: best = {}, mean = {}, "
                     "median{} = {}, worst{} = {}".format(tag, np.min(col), np.mean(col),
                                                          tag, np.median(col), tag, np.max(col)))
        _logger.info('Best{} Genome id= {}'.format(tag, np.argmin(col)))
    if gen % 20 == 0:
        # periodic snapshot: one file per objective, named after both scores
        for k in (0, 1):
            best = np.argmin(objectives[:, k])
            fname = '-'.join([
                'gen' + str(gen),
                's' + str(float('%.4f' % objectives[best, 0])),
                'c' + str(float('%.4f' % objectives[best, 1])),
            ])
            np.save(os.path.join('', fname + datetime.now().strftime("%Y%m%d-%H%M%S")),
                    genomes[best])
if __name__ == '__main__':
    args = _parse_args()
    # every run gets its own timestamped output directory
    out_base_dir= os.path.join(args.output, datetime.now().strftime("%Y%m%d-%H%M%S"))
    os.makedirs(out_base_dir,exist_ok = True)
    args.output=out_base_dir
    setup_default_logging(log_path=os.path.join(out_base_dir, 'log.txt'))
    # search space: one binary variable per potential liquid-layer synapse,
    # two objectives (small-world ratio, criticality) and two sum constraints
    kkk = Evolve(args,n_var=args.liquid_size*args.liquid_size,
                 n_obj=2, n_constr=2)
    method = engine.nsganet(pop_size=args.pop_size,
                            sampling=RandomSampling(var_type='custom'),
                            mutation=BinaryBitflipMutation(),
                            n_offsprings=args.n_offspring,
                            eliminate_duplicates=True)
    kres=minimize(kkk,
                  method,
                  callback=do_every_generations,
                  termination=('n_gen', args.n_gens))
================================================
FILE: examples/Structural_Development/ELSM/lsm.py
================================================
from __future__ import print_function
import torchvision
import torchvision.transforms as transforms
import os
import time
import numpy as np
import torch
from torch import nn as nn
from mnistmodel import *
from tqdm import tqdm
import argparse
from datetime import datetime
import logging
from timm.utils import *
from timm.loss import LabelSmoothingCrossEntropy
from braincog.base.utils import UnilateralMse, MixLoss
from braincog.base.learningrule.STDP import *
device='cuda:7'
def lr_scheduler(optimizer, epoch, init_lr=0.1, lr_decay_epoch=50):
    """Multiply every param group's learning rate by 0.1 once per
    ``lr_decay_epoch`` epochs.

    ``init_lr`` is kept for call-site compatibility but is not read.
    Returns the (mutated) optimizer.
    """
    decay_due = epoch > 1 and epoch % lr_decay_epoch == 0
    if decay_due:
        for group in optimizer.param_groups:
            group['lr'] *= 0.1
    return optimizer
# ---- hyper-parameters and data -------------------------------------------
batch_size=100
liquid_size=8000
learning_rate = 1e-3
num_epochs = 100  # max epoch
data_path = '/data'
load_path=''  # checkpoint to restore the pre-evolved liquid from
train_dataset = torchvision.datasets.MNIST(root=data_path, train=True, download=False, transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8)
test_set = torchvision.datasets.MNIST(root=data_path, train=False, download=False, transform=transforms.ToTensor())
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2)
# NOTE(review): lsm_tau / lsm_th are not defined in this file; presumably they
# come from the star import of mnistmodel — verify.
snn = SNN(ins=784,
          batchsize=batch_size,
          device=device,
          liquid_size=liquid_size,
          lsm_tau=lsm_tau,
          lsm_th=lsm_th)
# restore the readout and input weights from the checkpoint, then rebuild the
# recurrent (liquid) connection from the saved connectivity mask and weights
snn.load_state_dict(torch.load(load_path)['fc'])
snn.learning_rule=[]
snn.con[0].load_state_dict(torch.load(load_path)['lsm0'])
w2tmp=nn.Linear(liquid_size,liquid_size,bias=False,device=device)
snn.connectivity_matrix=torch.load(load_path)['connectivity_matrix'].to(device)
w2tmp.weight.data=(torch.load(load_path)['liquid_weight'].to(device))*snn.connectivity_matrix
snn.learning_rule.append(MutliInputSTDP(snn.node_lsm(), [snn.con[0], w2tmp])) # pm
snn.eval()
snn.to(device)
# ---- loss selection -------------------------------------------------------
ls = 'mse'
if ls == 'ce':
    criterion = nn.CrossEntropyLoss()
elif ls == 'bce':
    criterion = nn.BCEWithLogitsLoss()
elif ls == 'mse':
    criterion = UnilateralMse(1.)
elif ls == 'sce':
    criterion = LabelSmoothingCrossEntropy()
elif ls == 'sbce':
    criterion = LabelSmoothingBCEWithLogitsLoss()
elif ls == 'umse':
    criterion = UnilateralMse(.5)
# only the linear readout (snn.fc) is trained; the liquid stays fixed
optimizer = torch.optim.AdamW(snn.fc.parameters(),lr=0.001, weight_decay=1e-4)
l=[]           # best-accuracy history, one entry per epoch
best_acc=0
# ---- train / evaluate loop (indentation reconstructed — extraction lost it)
for epoch in range(num_epochs):
    running_loss = 0
    start_time = time.time()
    for i, (images, labels) in enumerate(tqdm(train_loader)):
        snn.zero_grad()
        optimizer.zero_grad()
        images = images.float().to(device)
        outputs = snn(images)
        labels=labels.to(device)
        loss = criterion(outputs, labels)
        running_loss += loss.item()
        loss.backward()
        optimizer.step()
        snn.reset()  # clear membrane potentials between batches
        if (i + 1) % 100 == 0:
            running_loss = 0
    # per-epoch evaluation on the test set
    correct = 0
    total = 0
    optimizer = lr_scheduler(optimizer, epoch, learning_rate, 40)
    for batch_idx, (inputs, targets) in enumerate(test_loader):
        inputs = inputs.float().to(device)
        snn.zero_grad()
        optimizer.zero_grad()
        outputs = snn(inputs)
        targets=targets.to(device)
        loss = criterion(outputs, targets)
        _, predicted = outputs.max(1)
        total += float(targets.size(0))
        correct += float(predicted.eq(targets).sum().item())
        snn.reset()
        if batch_idx % 100 == 0:
            acc = 100. * float(correct) / float(total)
            print(batch_idx, len(test_loader), ' Acc: %.5f' % acc)
    print('Test Accuracy: %.3f' % (100 * correct / total))
    acc = 100. * float(correct) / float(total)
    if best_acc < acc:
        best_acc = acc
    print(best_acc)
    l.append(best_acc)
================================================
FILE: examples/Structural_Development/ELSM/model.py
================================================
from functools import partial
from torch.nn import functional as F
from torch import nn as nn
import torchvision, pprint
from copy import deepcopy
from timm.models import register_model
from braincog.base.node.node import *
from braincog.base.connection.layer import *
from braincog.base.encoder.encoder import *
from braincog.model_zoo.base_module import BaseModule, BaseConvModule, BaseLinearModule
from braincog.base.brainarea.BrainArea import BrainArea
from braincog.base.connection.CustomLinear import *
from braincog.base.learningrule.STDP import *
import matplotlib.pyplot as plt
@register_model
class nSNN(BaseModule):
    """Liquid-state-machine SNN whose recurrent ("liquid") layer is gated by a
    fixed binary ``connectivity_matrix`` and driven by multi-input STDP.

    Expects event-style input of shape (batch, time, features): each time step
    is fed to the liquid, the readout ``fc`` spikes are accumulated, and the
    output is the mean readout spike count over the window.  ``firing_tw``
    keeps the per-step liquid activity for post-hoc analysis (e.g. the
    criticality measure in ``calc_f2``).
    """
    def __init__(self,
                 batchsize,
                 liquid_size,
                 device,
                 connectivity_matrix,
                 num_classes=10,
                 step=1,
                 node_type=LIFNode,
                 encode_type='direct',
                 lsm_th=0.3,
                 fc_th=0.3,
                 lsm_tau=3,
                 fc_tau=3,
                 ins=1156,
                 *args,
                 **kwargs):
        super().__init__(step, encode_type, *args, **kwargs)
        self.batchsize=batchsize
        self.ins=ins
        # separate neuron factories for the liquid and the readout
        self.node_lsm=partial(node_type, **kwargs, step=step,tau=lsm_tau,threshold=lsm_th)
        self.node_fc = partial(node_type, **kwargs, step=step,tau=fc_tau,threshold=fc_th)
        self.liquid_size=liquid_size
        self.device=device
        # NOTE: con is a plain list, so these Linears are NOT registered as
        # submodules; callers load their weights explicitly (see lsm.py).
        self.con=[]
        self.learning_rule=[]
        self.connectivity_matrix=connectivity_matrix
        w1tmp=nn.Linear(ins,liquid_size,bias=False).to(device)
        self.con.append(w1tmp)
        w2tmp=nn.Linear(liquid_size,liquid_size,bias=False).to(device)
        self.liquid_weight=w2tmp.weight.data  # unmasked copy of the recurrent weights
        # mask the recurrent weights with the evolved binary connectivity
        w2tmp.weight.data=w2tmp.weight.data*self.connectivity_matrix
        self.con.append(w2tmp)
        self.learning_rule.append(MutliInputSTDP(self.node_lsm(), [self.con[0], self.con[1]])) # pm
        self.fc = nn.Sequential(
            nn.Linear(liquid_size,num_classes),
            self.node_fc()
        )
    def forward(self, x):
        """Run the liquid over the input's time axis and average readout spikes.

        :param x: tensor of shape (batch, time, features)
        :return: (batch, num_classes) mean readout activity over the window
        """
        sum_spike=0
        self.out = torch.zeros(x.shape[0], self.liquid_size).to(self.device)
        tw=x.shape[1]
        self.tw=tw
        self.firing_tw=torch.zeros(tw, self.batchsize, self.liquid_size).to(self.device)
        for t in range(tw):
            # STDP step: combines the external input and the recurrent state
            self.out, self.dw = self.learning_rule[0](x[:,t,:], self.out)
            out_liquid=self.out[:,0:self.liquid_size]
            xout = self.fc(out_liquid)
            sum_spike=sum_spike+xout
            self.firing_tw[t]=out_liquid
        outputs = sum_spike / tw
        return outputs
@register_model
class mSNN(BaseModule):
    """MNIST variant of the liquid-state-machine SNN: a static image is
    flattened and presented repeatedly for a fixed time window.

    NOTE(review): the ``tw`` constructor argument is never read — the forward
    pass hard-codes a 20-step window; confirm whether that is intended.
    """
    def __init__(self,
                 batchsize,
                 liquid_size,
                 device,
                 connectivity_matrix,
                 num_classes=10,
                 step=1,
                 node_type=LIFNode,
                 encode_type='direct',
                 lsm_th=0.3,
                 fc_th=0.3,
                 lsm_tau=3,
                 fc_tau=3,
                 tw=100,
                 *args,
                 **kwargs):
        super().__init__(step, encode_type, *args, **kwargs)
        self.batchsize=batchsize
        self.node_lsm=partial(node_type, **kwargs, step=step,tau=lsm_tau,threshold=lsm_th)
        self.node_fc = partial(node_type, **kwargs, step=step,tau=fc_tau,threshold=fc_th)
        self.liquid_size=liquid_size
        self.out = torch.zeros(self.batchsize, liquid_size).to(device)
        self.device=device
        # plain lists: the Linears below are not registered submodules and are
        # loaded/saved explicitly by the training scripts
        self.con=[]
        self.learning_rule=[]
        self.connectivity_matrix=connectivity_matrix
        w1tmp=nn.Linear(784,liquid_size,bias=False).to(device)  # 784 = flattened MNIST
        self.con.append(w1tmp)
        w2tmp=nn.Linear(liquid_size,liquid_size,bias=False).to(device)
        self.liquid_weight=w2tmp.weight.data  # unmasked copy of the recurrent weights
        w2tmp.weight.data=w2tmp.weight.data*self.connectivity_matrix
        self.con.append(w2tmp)
        self.learning_rule.append(MutliInputSTDP(self.node_lsm(), [self.con[0], self.con[1]])) # pm
        self.fc = nn.Sequential(
            nn.Linear(liquid_size,num_classes),
            self.node_fc()
        )
    def forward(self, x):
        """Present the flattened image for 20 steps; return mean readout spikes."""
        x = x.reshape(x.shape[0], -1)
        sum_spike=0
        time_window=20
        self.tw=time_window
        self.firing_tw=torch.zeros(time_window, self.batchsize, self.liquid_size).to(self.device)
        self.out = torch.zeros(self.batchsize, self.liquid_size).to(self.device)
        for t in range(time_window):
            self.out, self.dw = self.learning_rule[0](x, self.out)
            out_liquid=self.out[:,0:self.liquid_size]
            xout = self.fc(out_liquid)
            sum_spike=sum_spike+xout
            self.firing_tw[t]=out_liquid
            # print(out_liquid.sum())
            # print(xout.sum())
        outputs = sum_spike / time_window
        return outputs
================================================
FILE: examples/Structural_Development/ELSM/nsganet.py
================================================
import numpy as np
from pymoo.algorithms.genetic_algorithm import GeneticAlgorithm
from pymoo.docs import parse_doc_string
from pymoo.model.individual import Individual
from pymoo.model.survival import Survival
from pymoo.operators.crossover.point_crossover import PointCrossover
from pymoo.operators.mutation.polynomial_mutation import PolynomialMutation
from pymoo.operators.mutation.bitflip_mutation import BinaryBitflipMutation
from pymoo.operators.sampling.random_sampling import RandomSampling
from pymoo.operators.selection.tournament_selection import compare, TournamentSelection
from pymoo.util.display import disp_multi_objective
from pymoo.util.dominator import Dominator
from pymoo.util.non_dominated_sorting import NonDominatedSorting
from pymoo.util.randomized_argsort import randomized_argsort
# =========================================================================================================
# Implementation
# based on nsga2 from https://github.com/msu-coinlab/pymoo
# =========================================================================================================
class NSGANet(GeneticAlgorithm):
    """NSGA-II-style genetic algorithm (pymoo) used as the search engine.

    Individuals carry a non-domination ``rank`` and a ``crowding`` value,
    which ``binary_tournament`` and ``RankAndCrowdingSurvival`` below rely on.
    """
    def __init__(self, **kwargs):
        # the prototype individual must be injected before the base __init__
        # copies it for the population
        kwargs['individual'] = Individual(rank=np.inf, crowding=-1)
        super().__init__(**kwargs)
        # read by binary_tournament to pick the comparison strategy
        self.tournament_type = 'comp_by_dom_and_crowding'
        self.func_display_attrs = disp_multi_objective
# ---------------------------------------------------------------------------------------------------------
# Binary Tournament Selection Function
# ---------------------------------------------------------------------------------------------------------
def binary_tournament(pop, P, algorithm, **kwargs):
    """Binary tournament selection used by NSGANet.

    :param pop: current population (pymoo individuals with ``CV``, ``F``,
        ``rank`` and ``crowding`` attributes)
    :param P: (n_tournaments, 2) matrix of competitor indices
    :param algorithm: the running algorithm; supplies ``tournament_type``
    :return: (n_tournaments, 1) integer array of winner indices
    :raises ValueError: if P does not describe pairwise tournaments
    """
    if P.shape[1] != 2:
        raise ValueError("Only implemented for binary tournament!")
    tournament_type = algorithm.tournament_type
    S = np.full(P.shape[0], np.nan)
    for i in range(P.shape[0]):
        a, b = P[i, 0], P[i, 1]
        # if at least one solution is infeasible, the smaller constraint
        # violation wins regardless of objective values
        if pop[a].CV > 0.0 or pop[b].CV > 0.0:
            S[i] = compare(a, pop[a].CV, b, pop[b].CV, method='smaller_is_better', return_random_if_equal=True)
        # both solutions are feasible
        else:
            if tournament_type == 'comp_by_dom_and_crowding':
                rel = Dominator.get_relation(pop[a].F, pop[b].F)
                if rel == 1:
                    S[i] = a
                elif rel == -1:
                    S[i] = b
            elif tournament_type == 'comp_by_rank_and_crowding':
                S[i] = compare(a, pop[a].rank, b, pop[b].rank,
                               method='smaller_is_better')
            else:
                raise Exception("Unknown tournament type.")
            # if rank or domination relation didn't make a decision compare by crowding
            if np.isnan(S[i]):
                S[i] = compare(a, pop[a].get("crowding"), b, pop[b].get("crowding"),
                               method='larger_is_better', return_random_if_equal=True)
    # FIX: np.int was a deprecated alias of the builtin int and was removed in
    # NumPy 1.24; using int preserves the exact original dtype behavior.
    return S[:, None].astype(int)
# ---------------------------------------------------------------------------------------------------------
# Survival Selection
# ---------------------------------------------------------------------------------------------------------
class RankAndCrowdingSurvival(Survival):
    """NSGA-II survival: keep individuals by non-domination rank, breaking
    ties on the splitting front by descending crowding distance."""
    def __init__(self) -> None:
        super().__init__(True)
    def _do(self, pop, n_survive, D=None, **kwargs):
        """Select ``n_survive`` individuals; annotates each survivor's
        ``rank`` and ``crowding`` attributes as a side effect."""
        # get the objective space values and objects
        F = pop.get("F")
        # the final indices of surviving individuals
        survivors = []
        # do the non-dominated sorting until splitting front
        fronts = NonDominatedSorting().do(F, n_stop_if_ranked=n_survive)
        for k, front in enumerate(fronts):
            # calculate the crowding distance of the front
            crowding_of_front = calc_crowding_distance(F[front, :])
            # save rank and crowding in the individual class
            for j, i in enumerate(front):
                pop[i].set("rank", k)
                pop[i].set("crowding", crowding_of_front[j])
            # current front sorted by crowding distance if splitting
            if len(survivors) + len(front) > n_survive:
                I = randomized_argsort(crowding_of_front, order='descending', method='numpy')
                I = I[:(n_survive - len(survivors))]
            # otherwise take the whole front unsorted
            else:
                I = np.arange(len(front))
            # extend the survivors by all or selected individuals
            survivors.extend(front[I])
        return pop[survivors]
def calc_crowding_distance(F):
infinity = 1e+14
n_points = F.shape[0]
n_obj = F.shape[1]
if n_points <= 2:
return np.full(n_points, infinity)
else:
# sort each column and get index
I = np.argsort(F, axis=0, kind='mergesort')
# now really sort the whole array
F = F[I, np.arange(n_obj)]
# get the distance to the last element in sorted list and replace zeros with actual values
dist = np.concatenate([F, np.full((1, n_obj), np.inf)]) \
- np.concatenate([np.full((1, n_obj), -np.inf), F])
index_dist_is_zero = np.where(dist == 0)
dist_to_last = np.copy(dist)
for i, j in zip(*index_dist_is_zero):
dist_to_last[i, j] = dist_to_last[i - 1, j]
dist_to_next = np.copy(dist)
for i, j in reversed(list(zip(*index_dist_is_zero))):
dist_to_next[i, j] = dist_to_next[i + 1, j]
# normalize all the distances
norm = np.max(F, axis=0) - np.min(F, axis=0)
norm[norm == 0] = np.nan
dist_to_last, dist_to_next = dist_to_last[:-1] / norm, dist_to_next[1:] / norm
# if we divided by zero because all values in one columns are equal replace by none
dist_to_last[np.isnan(dist_to_last)] = 0.0
dist_to_next[np.isnan(dist_to_next)] = 0.0
# sum up the distance to next and last and norm by objectives - also reorder from sorted list
J = np.argsort(I, axis=0)
crowding = np.sum(dist_to_last[J, np.arange(n_obj)] + dist_to_next[J, np.arange(n_obj)], axis=1) / n_obj
# replace infinity with a large number
crowding[np.isinf(crowding)] = infinity
return crowding
# =========================================================================================================
# Interface
# =========================================================================================================
# NOTE: the docstring placeholders below are filled in by parse_doc_string.
def nsganet(
        pop_size=100,
        # FIX: np.int (a deprecated alias of builtin int, removed in NumPy
        # 1.24) replaced by int — identical semantics, import-safe on modern
        # NumPy since defaults are evaluated when this module loads.
        sampling=RandomSampling(var_type=int),
        selection=TournamentSelection(func_comp=binary_tournament),
        crossover=PointCrossover(n_points=2),
        mutation=PolynomialMutation(eta=3, var_type=int),
        eliminate_duplicates=True,
        n_offsprings=None,
        **kwargs):
    """
    Parameters
    ----------
    pop_size : {pop_size}
    sampling : {sampling}
    selection : {selection}
    crossover : {crossover}
    mutation : {mutation}
    eliminate_duplicates : {eliminate_duplicates}
    n_offsprings : {n_offsprings}
    Returns
    -------
    nsganet : :class:`~pymoo.model.algorithm.Algorithm`
        Returns an NSGANet algorithm object.
    """
    return NSGANet(pop_size=pop_size,
                   sampling=sampling,
                   selection=selection,
                   crossover=crossover,
                   mutation=mutation,
                   survival=RankAndCrowdingSurvival(),
                   eliminate_duplicates=eliminate_duplicates,
                   n_offsprings=n_offsprings,
                   **kwargs)
parse_doc_string(nsganet)
================================================
FILE: examples/Structural_Development/ELSM/spikes.py
================================================
from __future__ import print_function
import torchvision
import torchvision.transforms as transforms
import os
import numpy as np
import torch
from torch import nn as nn
from model import *
from tqdm import tqdm
import argparse
from datetime import datetime
import logging
from timm.utils import *
from spikingjelly.datasets.n_mnist import NMNIST
from timm.loss import LabelSmoothingCrossEntropy
from braincog.base.utils.criterions import *
import networkx as nx
import time
from braincog.base.learningrule.STDP import *
def randbool(size, p=0.5):
    """Random boolean tensor of shape ``size``; each entry is True with
    probability ``p`` (uniform draws in [0, 1) compared against p)."""
    draws = torch.rand(*size)
    return torch.lt(draws, p)
def calc_f2(con,device):
batch_size=1
liquid_size=8000
images=torch.load('/1000images.pt')
labels=torch.load('/1000labels.pt')
load_path='970.t7'
snn = nSNN(ins=2312,
batchsize=batch_size,
device=device,
liquid_size=liquid_size,
lsm_tau=2.0,
lsm_th=0.20,
connectivity_matrix=randbool([liquid_size, liquid_size],p=0.01).to(device).int())
snn.load_state_dict(torch.load(load_path,map_location={'cuda:2':device})['fc'])
snn.con[0].load_state_dict(torch.load(load_path,map_location={'cuda:2':device})['lsm0'])
snn.to(device)
criterion = UnilateralMse(1.)
optimizer = torch.optim.AdamW(snn.fc.parameters(),lr=0.001, weight_decay=1e-4)
k=0
sbr=0
snn.connectivity_matrix=con
snn.learning_rule=[]
w2tmp=nn.Linear(liquid_size,liquid_size,bias=False,device=device)
w2tmp.weight.data=(torch.load(load_path,map_location={'cuda:2':device})['liquid_weight'])*snn.connectivity_matrix
snn.learning_rule.append(MutliInputSTDP(snn.node_lsm(), [snn.con[0], w2tmp]))
snn.eval()
for label,data in zip(labels,images):
running_loss = 0
snn.zero_grad()
optimizer.zero_grad()
data = data.to(device)
label = label.to(device)
data=data.reshape(batch_size,data.shape[0],-1)
output = snn(data)
# print(torch.argmax(output)==label)
out_liquid=snn.firing_tw.squeeze(-2)
mupost=torch.matmul(con,out_liquid.unsqueeze(-1))
mupre=torch.matmul(con.t(),out_liquid.unsqueeze(-1))
for t in range(snn.tw):
if t>5 and t0:
xnp=x.cpu().numpy()
maxx=torch.max(x)
#maxx=np.percentile(xnp, 99.5)
minx=torch.min(x)
marge=maxx-minx
if marge!=0:
xx=(x-minx)/marge
xx=torch.clip(xx, 0,1)
else:
xx=0.5*torch.ones_like(x)
return xx
else:
return x
class Mask:
    """Binary pruning masks over a model's parameters for task-incremental
    filter reuse (SCA-SNN).

    ``mat`` maps a parameter index (its position in ``model.parameters()``)
    to a 0/1 tensor of the parameter's shape; ``do_mask`` multiplies weights
    by these masks in place.  ``rereduce[task][index]`` is a running reuse
    score per old-task channel; channels whose score drops below 0 are
    pruned for the current task in ``get_filter_reuse``.
    NOTE(review): relies on module-level ``device``, ``math``, ``np`` and a
    ``unit`` normalizer defined elsewhere in this file — verify imports.
    """
    def __init__(self, model):
        self.model = model
        self.mat = {}            # param index -> binary mask tensor
        self.p_index={}
        self.p_num={}
        self.k=15
        self.task_ready={}
        self.regutask_ready={}
        self.taskmask={}
        # growth / pruning hyper-parameters
        self.init_rate=0.3
        self.grow_rate=0.125
        self.prunconv_init=0.8
        self.prunfc_init=1.3
        self.prunconv_grow=0.5
        self.prunfc_grow=1
        self.n_delta={}
        self.ren_delta={}
        self.reduce={}
        self.rereduce={}         # task -> {param index -> reuse-score vector}
        self.taskww={}
        self.tasknore={}
    def init_length(self,task=0,task_nn=None):
        """Create all-ones masks for every weight matrix / conv kernel and
        initialise per-task reuse-score vectors.

        ``task_nn`` maps task id -> cumulative per-stage channel counts; the
        score vector for task ``t`` spans the channels that task added.
        """
        for index, item in enumerate(self.model.parameters()):
            if len(item.size()) > 1 and item.size()[-1]!=1:
                print(index,item.size())
                self.mat[index]=torch.ones(item.size(),device=device)
        for t in range(task):
            self.rereduce[t]={}
            for index, item in enumerate(self.model.parameters()):
                if True:  # placeholder condition kept from the original
                    # map the parameter index onto one of four network stages
                    if index<=20:
                        c_index=0
                    elif index<=45:
                        c_index=1
                    elif index<=70:
                        c_index=2
                    else:
                        c_index=3
                    taskindb=task_nn[t-1][c_index]
                    taskinda=task_nn[t][c_index]
                    lenre=taskinda-taskindb
                    self.rereduce[t][index] = 1*torch.ones(lenre,device=device)
        return self.mat
    def get_filter_reuse(self,index,ww,task,epoch,c_index,cdim_before=None,task_nn=None,all_dist=None,bias=0):
        """Update reuse scores for old-task channels from gradient magnitudes
        ``ww`` and zero the *next* parameter's mask columns for channels whose
        score fell below 0 (i.e. prune their outgoing connections).

        ``all_dist`` is a task-similarity value; more similar tasks keep a
        higher reuse target (clamped to [0.2, 0.9]).
        """
        lenre=cdim_before[1]-cdim_before[0]
        similar=1-all_dist+bias
        if similar<0.2:
            similar=0.2
        if similar>0.9:
            similar=0.9
        revalue=similar*torch.ones(lenre).cuda() #1/8,1/4,1/2,1,1.5
        if len(ww.size()) == 4:
            # p_www=ww*self.mat[index]
            # per-output-channel gradient energy
            p_ww=torch.sum(torch.sum(torch.sum(ww,dim=3),dim=2),dim=1)
            # p_ww=p_ww[cdim_before[0]:cdim_before[1]]
            ren_delta=-(2*unit(p_ww)-revalue)#revalue0.8
            #print(self.ren_delta[index])
            pos=torch.nonzero(ren_delta>0)
            ren_delta[pos]=ren_delta[pos]+3
            # exponentially-damped running update of the reuse score
            self.rereduce[task][index]=self.rereduce[task][index]*0.999+ren_delta*math.exp(-int((epoch-1)/2))
            p_ind = torch.nonzero(self.rereduce[task][index] <0)
            # find the mask that follows this parameter in registration order
            matkey=self.mat.keys()
            matkey=torch.tensor(list(matkey))
            matindex=torch.nonzero(matkey==index)
            next_index=matkey[matindex+1]
            for x in range(0, len(p_ind)):
                self.mat[next_index.item()][:,p_ind[x]+cdim_before[0]]=0
            b = self.mat[next_index.item()][:,cdim_before[0]:cdim_before[1]].reshape(-1).cpu().numpy()
            pruning=100*(len(b)- np.count_nonzero(b))/len(b)
            #print(index,self.rereduce[task][index].mean(),self.rereduce[task][index].max(),self.rereduce[task][index].min(),len(b)-np.count_nonzero(b),pruning)
    def convert2tensor(self, x):
        """Wrap ``x`` in a float tensor."""
        x = torch.FloatTensor(x)
        return x
    def init_mask(self,task,epoch,dim_cur=None,task_nn=None,all_dist=None,all_model=None):
        """Refresh reuse-based masks against every previous task's model,
        using per-depth similarity biases."""
        for t in range(task):
            similart=all_dist[t]
            for index, item in enumerate(all_model[t].parameters()):
                if len(item.size()) > 2 and item.size()[-1]!=1 and index<95:
                    # stage of this parameter
                    if index<=20:
                        c_index=0
                    elif index<=45:
                        c_index=1
                    elif index<=70:
                        c_index=2
                    else:
                        c_index=3
                    # depth-dependent bias on the similarity target
                    if index<=25:
                        bias=0.2
                    elif index<=50:
                        bias=0.1
                    elif index<=74:
                        bias=-0.1
                    else:
                        bias=0.2
                    taskindb=task_nn[t-1][c_index]
                    taskinda=task_nn[t][c_index]
                    cdim_before=[taskindb,taskinda]
                    self.get_filter_reuse(index,abs(item.grad),t,epoch,c_index,cdim_before,task_nn=task_nn,all_dist=similart,bias=bias)
    def do_mask(self,task):
        """Apply the masks: multiply every maskable weight by its 0/1 mask."""
        for index, item in enumerate(self.model.parameters()):
            if len(item.size()) > 1 and item.size()[-1]!=1:
                ww=item.data
                item.data=ww*self.mat[index].cuda()
        return self.mat
    def if_zero(self):
        """Report (print + return) the percentage of zeroed weights per layer."""
        cc=[]
        for index, item in enumerate(self.model.parameters()):
            if len(item.size()) > 1 and item.size()[-1]!=1 and index>0:
                b = item.data.view(-1).cpu().numpy()
                print("number of weight is %d, zero is %.3f" %(len(b),100*(len(b)- np.count_nonzero(b))/len(b)))
                cc.append(100*(len(b)- np.count_nonzero(b))/len(b))
        return cc
================================================
FILE: examples/Structural_Development/SCA-SNN/inclearn/convnet/network.py
================================================
import copy
# import pdb
import torch
from torch import nn
import torch.nn.functional as F
from inclearn.tools import factory
from inclearn.convnet.imbalance import CR, All_av,BiC
from inclearn.convnet.classifier import CosineClassifier
from braincog.base.node.node import *
from braincog.base.connection.layer import *
from braincog.base.encoder.encoder import *
from braincog.model_zoo.base_module import BaseModule, BaseConvModule, BaseLinearModule
class BasicNet(nn.Module):
    """Task-expandable spiking network for class-incremental learning
    (SCA-SNN): one convnet column per task, with cross-task feature reuse
    in ``forward`` and mask-based pruning of old-task channels.

    Indentation of this class was reconstructed from semantics — the
    extraction lost it; nesting in ``forward`` should be verified against
    the repository.
    """
    def __init__(
        self,
        convnet_type,
        cfg,
        nf=64,
        use_bias=False,
        init="kaiming",
        device=None,
        dataset="cifar100",
    ):
        super(BasicNet, self).__init__()
        self.nf = nf
        self.init = init
        self.convnet_type = convnet_type
        self.dataset = dataset
        self.start_class = cfg['start_class']
        self.weight_normalization = cfg['weight_normalization']
        # cosine classifier (weight normalization) implies no bias / no last relu
        self.remove_last_relu = True if self.weight_normalization else False
        self.use_bias = use_bias if not self.weight_normalization else False
        self.dea = cfg['dea']                      # dynamic expansion on/off
        self.ft_type = cfg.get('feature_type', 'normal')
        self.at_res = cfg.get('attention_use_residual', False)
        self.div_type = cfg['div_type']
        self.reuse_oldfc = cfg['reuse_oldfc']
        self.prune = cfg.get('prune', False)
        self.reset = cfg.get('reset_se', True)
        self.torc=cfg['distillation']              # whether old columns feed the head
        self.node =LIFNode
        self.encoder = Encoder(4, 'direct', temporal_flatten=False, layer_by_layer=False, **cfg)
        # if self.dea:
        #     print("Enable dynamical reprensetation expansion!")
        #     self.convnets = nn.ModuleList()
        #     self.convnets.append(
        #         factory.get_convnet(convnet_type,
        #                             nf=nf,
        #                             dataset=dataset,
        #                             start_class=self.start_class,
        #                             remove_last_relu=self.remove_last_relu))
        #     self.out_dim = self.convnets[0].out_dim
        #     self.c_dim=self.convnets[0].channel_dim
        # else:
        #     self.convnet = factory.get_convnet(convnet_type,
        #                                        nf=nf,
        #                                        dataset=dataset,
        #                                        remove_last_relu=self.remove_last_relu)
        #     self.out_dim = self.convnet.out_dim
        # per-stage channel budget of the first task's column
        self.channel_number=[32,64,128,256]#[32,64,128,256] # [24,48,72,96] #[24,48,96,192][32,64,128,256]
        self.channel_dim=[48,96,192,384]
        self.c_number1=np.array(self.channel_number)  # cumulative channels across tasks
        if self.dea:
            print("Enable dynamical reprensetation expansion!")
            self.convnets = nn.ModuleList()
            self.convnets.append(
                factory.get_convnet(convnet_type,c_dim=self.channel_number,cdim_cur=self.channel_number)
            )
            self.out_dim = self.channel_number[-1]
            self.out_dim_cc = self.channel_number[-1]
        else:
            self.convnet = factory.get_convnet(convnet_type,
                                               nf=nf,
                                               dataset=dataset,
                                               remove_last_relu=self.remove_last_relu)
            self.out_dim = self.convnet.out_dim
        self.classifier = None
        self.se = None
        self.aux_classifier = None
        self.n_classes = 0
        self.ntask = 0
        self.device = device
        if cfg['postprocessor']['enable']:
            if cfg['postprocessor']['type'].lower() == "cr":
                self.postprocessor = CR()
            elif cfg['postprocessor']['type'].lower() == "aver":
                self.postprocessor = All_av()
            else:
                self.postprocessor = BiC(cfg['postprocessor']["lr"], cfg['postprocessor']["scheduling"],
                                         cfg['postprocessor']["lr_decay_factor"], cfg['postprocessor']["weight_decay"],
                                         cfg['postprocessor']["batch_size"], cfg['postprocessor']["epochs"])
        # task id -> cumulative per-stage channel counts (for mask bookkeeping)
        self.task_nn={}
        self.task_nn[-1]=np.array([0,0,0,0])
        self.task_nn[0]=np.array(self.channel_number)
        self.to(self.device)
    def forward(self, task,inputs,mask=None,classify=True):
        """Run all task columns over ``step`` encoder time steps, letting the
        current task's column consume features of older columns, then classify.

        Returns a dict with the (possibly attention-reweighted) 'feature',
        'logit', auxiliary 'div_logit', and the raw 'features'.
        """
        inputs = self.encoder(inputs)
        self.resetsnn()
        step = 4
        outputs = []
        if self.classifier is None:
            raise Exception("Add some classes before training.")
        if mask is not None:
            # apply the pruning masks to the selected column's weights in place
            mat=mask.mat
            if self.torc:
                ttc=task
            else:
                ttc=-1
            for index, item in enumerate(self.convnets[ttc].parameters()):
                if len(item.size()) > 1 and item.size()[-1]!=1:
                    ww=item.data
                    item.data=ww*mat[index].cuda()
        if self.dea:
            # feature = [convnet(x) for convnet in self.convnets]
            for time in range(step):
                x_init = inputs[time]
                task_feature={}
                # replay all old-task columns, caching intermediate features
                for t in range(task):
                    task_feature[t]={}
                    x=self.convnets[t].forward_init(x_init)
                    task_feature[t][0]=x
                    for l in range(len(self.convnets[t].layer_convnets)):
                        for lc in range(len(self.convnets[t].layer_convnets[l].conv)):
                            if lc==0 or lc==3:
                                identity = x
                            if lc==2:
                                # lc==2 is the downsample/shortcut position
                                if l>0:
                                    identity=self.convnets[t].layer_convnets[l].conv[lc](identity)
                                x=x+identity
                                task_feature[t][5*l+lc+1]=x
                            else:
                                # concatenate same-depth features of earlier tasks
                                old_tfeature=[]
                                for old_t in range(t):
                                    old_tfeature.append(task_feature[old_t][5*l+lc])
                                old_tfeature.append(x)
                                x= torch.cat(old_tfeature, 1)
                                x=self.convnets[t].layer_convnets[l].conv[lc](x)
                                if lc==4:
                                    x=x+identity
                                task_feature[t][5*l+lc+1]=x
                # current task's column, reusing all cached old-task features
                x=self.convnets[task].forward_init(x_init)
                for l in range(len(self.convnets[task].layer_convnets)):
                    for lc in range(len(self.convnets[task].layer_convnets[l].conv)):
                        if lc==0 or lc==3:
                            identity = x
                        if lc==2:
                            if l>0:
                                identity=self.convnets[task].layer_convnets[l].conv[lc](identity)
                            x=x+identity
                        else:
                            mid_feature=[]
                            for t in range(task):
                                mid_feature.append(task_feature[t][5*l+lc])
                            mid_feature.append(x)
                            x= torch.cat(mid_feature, 1)
                            x=self.convnets[task].layer_convnets[l].conv[lc](x)
                            if lc==4:
                                x=x+identity
                if self.torc:
                    # distillation head sees old + new final features concatenated
                    last_feature=[]
                    for t in range(task):
                        last=len(task_feature[t])-1
                        last_feature.append(task_feature[t][last])
                    last_feature.append(x)
                    outputs.append(torch.cat(last_feature, 1))
                else:
                    x=self.convnets[task].avgpool(x)
                    x=x.view(x.size()[0],-1)
                    outputs.append(x)
            # rate coding: average over the time steps
            feature=sum(outputs).cuda()/ step
            last_dim =x.size(1)
            width = feature.size(1)
            if self.torc:
                if self.reset:
                    se = factory.get_attention(width, self.ft_type, self.at_res).to(self.device)
                    features = se(feature)
                else:
                    features = self.se(feature)
            else:
                features=feature
        else:
            # NOTE(review): `x` is undefined on this branch as written — this
            # path looks unused/broken; confirm before relying on dea=False.
            features = self.convnet(x)
        if self.torc:
            if classify==True:
                logits = self.convnets[-1].classifer(features)
                # auxiliary head only sees the newest column's slice
                div_logits = self.convnets[-1].aux_classifier(features[:, -last_dim:]) if self.ntask > 1 else None
            else:
                logits=None
                div_logits=None
        else:
            if classify==True:
                logits = self.convnets[task].classifer(features)
            else:
                logits=None
            div_logits=None
        return {'feature': features, 'logit': logits, 'div_logit': div_logits, 'features': feature}
    def caculate_dim(self, x):
        """Probe the concatenated feature width of all columns on input ``x``;
        returns (total attention-output dim, last column's dim)."""
        feature = [convnet(x) for convnet in self.convnets]
        features = torch.cat(feature, 1)
        width = features.size(1)
        # se = factory.get_attention(width, self.ft_type, self.at_res).to(self.device)
        se = factory.get_attention(width, "ce", self.at_res).cuda()
        features = se(features)
        # import pdb
        # pdb.set_trace()
        return features.size(1), feature[-1].size(1)
    @property
    def features_dim(self,ntask):
        # NOTE(review): a @property cannot take an extra argument — accessing
        # this attribute as written would fail; likely leftover refactoring.
        if self.dea:
            return self.out_dim#+ntask*self.channel_number1[-1]
        else:
            return self.out_dim
    def freeze(self):
        """Freeze all parameters and switch to eval mode; returns self."""
        for param in self.parameters():
            param.requires_grad = False
        self.eval()
        return self
    def copy(self):
        """Deep copy of the whole network."""
        return copy.deepcopy(self)
    def add_classes(self, n_classes,min_dist):
        """Grow the network for ``n_classes`` new classes; ``min_dist`` is the
        task-similarity value that scales the new column's width."""
        self.ntask += 1
        if self.dea:
            self._add_classes_multi_fc(n_classes,min_dist)
        else:
            self._add_classes_single_fc(n_classes)
        self.n_classes += n_classes
    def _add_classes_multi_fc(self, n_classes,min_dist):
        """Add a new convnet column (width scaled by task similarity) plus a
        fresh classifier and auxiliary classifier."""
        self.classifier=self.convnets[-1].classifer
        if self.ntask > 1:
            if min_dist<0.1:
                min_dist=0.1
            # less-similar tasks (larger min_dist) get a wider new column
            self.channel_number1=np.array([32,64,96,128])*1*(1-math.exp(-5*min_dist)) #0.5,1,1.5,2,3,4 [24,48,72,96][16,32,48,64]
            self.channel_number1=self.channel_number1.astype(np.int64)
            self.channel_dim=self.channel_number1
            self.c_number1=self.c_number1+np.array(self.channel_number1)
            self.task_nn[self.ntask-1]=self.c_number1
            new_clf = factory.get_convnet("resnet18",c_dim=self.c_number1,cdim_cur=self.channel_number1).to(self.device)
            self.out_dim=self.out_dim+self.channel_number1[-1]
            self.out_dim_cc=self.channel_number1[-1]
            self.convnets.append(new_clf)
            if self.torc:
                if not self.reset:
                    self.se = factory.get_attention(512*len(self.convnets), self.ft_type, self.at_res)
                    self.se.to(self.device)
            if self.classifier is not None:
                weight = copy.deepcopy(self.classifier.weight.data)
            fc = self._gen_classifier(self.out_dim, self.n_classes + n_classes)
            if self.classifier is not None and self.reuse_oldfc:
                # copy old-class weights into the slice covering old features
                fc.weight.data[:self.n_classes, :(self.out_dim - self.out_dim_cc)] = weight
            del self.classifier
            self.classifier = fc
            self.convnets[-1].classifer=self.classifier
        else:
            fc = self._gen_classifier(self.out_dim_cc, n_classes)
            del self.classifier
            self.classifier = fc
            self.convnets[-1].classifer=fc
        if self.torc:
            if self.div_type == "n+1":
                div_fc = self._gen_classifier(self.out_dim_cc, n_classes + 1)
            elif self.div_type == "1+1":
                div_fc = self._gen_classifier(self.out_dim_cc, 2)
            elif self.div_type == "n+t":
                div_fc = self._gen_classifier(self.out_dim_cc, self.ntask + n_classes)
            else:
                div_fc = self._gen_classifier(self.out_dim_cc, self.n_classes + n_classes)
            del self.aux_classifier
            self.aux_classifier = div_fc
            self.convnets[-1].aux_classifier=self.aux_classifier
    def _add_classes_single_fc(self, n_classes):
        """Replace the single classifier with a wider one, keeping old weights."""
        if self.classifier is not None:
            weight = copy.deepcopy(self.classifier.weight.data)
            if self.use_bias:
                bias = copy.deepcopy(self.classifier.bias.data)
        classifier = self._gen_classifier(self.features_dim, self.n_classes + n_classes)
        if self.classifier is not None and self.reuse_oldfc:
            classifier.weight.data[:self.n_classes] = weight
            if self.use_bias:
                classifier.bias.data[:self.n_classes] = bias
        del self.classifier
        self.classifier = classifier
    def _gen_classifier(self, in_features, n_classes):
        """Build a cosine or plain linear classifier on ``self.device``."""
        if self.weight_normalization:
            classifier = CosineClassifier(in_features, n_classes).to(self.device)
            # classifier = CosineClassifier(in_features, n_classes).cuda()
        else:
            classifier = nn.Linear(in_features, n_classes, bias=self.use_bias).to(self.device)
            # classifier = nn.Linear(in_features, n_classes, bias=self.use_bias).cuda()
            if self.init == "kaiming":
                nn.init.kaiming_normal_(classifier.weight, nonlinearity="linear")
            if self.use_bias:
                nn.init.constant_(classifier.bias, 0.0)
        return classifier
    def resetsnn(self):
        """
        Reset the membrane potential of every spiking neuron.
        :return:
        """
        for mod in self.convnets.modules():
            if hasattr(mod, 'n_reset'):
                mod.n_reset()
================================================
FILE: examples/Structural_Development/SCA-SNN/inclearn/convnet/resnet.py
================================================
"""Taken & slightly modified from:
* https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
"""
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
from torch.nn import functional as F
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152']
model_urls = {
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 convolution with 1-pixel padding."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free 1x1 (pointwise) convolution."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
class BasicBlock(nn.Module):
    """Two-convolution residual block (conv-bn-relu-conv-bn + shortcut).

    When ``remove_last_relu`` is set, the ReLU after the residual addition
    is skipped — used for the final block of a backbone whose raw features
    are consumed directly.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, remove_last_relu=False):
        super(BasicBlock, self).__init__()
        # Submodule creation order matters for seeded weight-init reproducibility.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
        self.remove_last_relu = remove_last_relu

    def forward(self, x):
        # Projection shortcut when the shape changes, identity otherwise.
        shortcut = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = out + shortcut
        if not self.remove_last_relu:
            out = self.relu(out)
        return out
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual bottleneck with 4x channel expansion."""

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        # Submodule creation order matters for seeded weight-init reproducibility.
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = conv1x1(planes, planes * self.expansion)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Projection shortcut when the shape changes, identity otherwise.
        shortcut = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        return self.relu(out + shortcut)
class ChannelAttention(nn.Module):
    """CBAM-style channel attention.

    Average- and max-pooled descriptors are pushed through a shared
    two-layer 1x1-conv MLP, summed, and squashed to per-channel gates in
    (0, 1).

    Fix: the ``ratio`` argument (channel reduction factor of the MLP
    bottleneck) was previously ignored — the reduction was hard-coded to
    16. It is now honored; the default behaviour (``ratio=16``) is
    unchanged.

    :param in_planes: number of input channels.
    :param ratio: reduction factor of the MLP bottleneck.
    """

    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Shared-weight MLP implemented with 1x1 convolutions.
        self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Same MLP applied to both pooled descriptors, then fused by sum.
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)
class SpatialAttention(nn.Module):
    """CBAM-style spatial attention.

    Channel-wise mean and max maps are concatenated, convolved to a single
    channel, and squashed to a per-position gate in (0, 1).

    :param kernel_size: convolution kernel size, either 3 or 7.
    """

    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        # "Same" padding for the two supported kernel sizes.
        padding = 3 if kernel_size == 7 else 1
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        pooled = torch.cat(
            [x.mean(dim=1, keepdim=True), x.max(dim=1, keepdim=True).values],
            dim=1,
        )
        return self.sigmoid(self.conv1(pooled))
class SEFeatureAt(nn.Module):
    """Attention-then-pool feature head.

    Applies either a squeeze-and-excitation gate (``type == "se"``) or
    CBAM channel+spatial attention (``type == "ffm"``), optionally adds
    the pre-attention input back (``at_res``), then global-average-pools
    and flattens to a feature vector.
    """

    def __init__(self, inplanes, type, at_res):
        super(SEFeatureAt, self).__init__()
        # Squeeze-and-excitation gate with a fixed reduction of 16.
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(inplanes, inplanes // 16, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(inplanes // 16, inplanes, kernel_size=1),
            nn.Sigmoid(),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.type = type
        self.at_res = at_res
        self.ca = ChannelAttention(inplanes)
        self.sa = SpatialAttention()

    def forward(self, x):
        residual = x
        if self.type == "se":
            x = x * self.se(x)
        elif self.type == "ffm":
            # CBAM ordering: channel attention first, then spatial.
            x = self.ca(x) * x
            x = self.sa(x) * x
        if self.at_res:
            x += residual
        x = self.avgpool(x)
        return x.view(x.size(0), -1)
class ResNet(nn.Module):
    """ResNet feature extractor (no classification head).

    ``forward`` returns the final convolutional feature map; global
    pooling and flattening are intentionally left to the caller (see the
    commented-out lines in ``forward``). ``out_dim`` records the channel
    count of that map.

    :param block: residual block class (``BasicBlock`` or ``Bottleneck``).
    :param layers: number of blocks per stage (4 entries).
    :param nf: base channel width (64 in the standard ResNet).
    :param dataset: 'cifar'-like inputs get a 3x3 stem keeping resolution;
        'imagenet'-like inputs get the standard 7x7/stride-2 stem, or a
        3x3 stem when ``start_class != 0`` (following PODNet).
    :param remove_last_relu: drop the ReLU at the end of the final block.
    """
    def __init__(self,
                 block,
                 layers,
                 nf=64,
                 zero_init_residual=True,
                 dataset='cifar',
                 start_class=0,
                 remove_last_relu=False):
        super(ResNet, self).__init__()
        self.remove_last_relu = remove_last_relu
        self.inplanes = nf
        if 'cifar' in dataset:
            # Small-image stem: 3x3 conv, stride 1, keeps spatial size.
            self.conv1 = nn.Sequential(nn.Conv2d(3, nf, kernel_size=3, stride=1, padding=1, bias=False),
                                       nn.BatchNorm2d(nf), nn.ReLU(inplace=True))
        elif 'imagenet' in dataset:
            if start_class == 0:
                # Standard ImageNet stem: 7x7/stride-2 conv + max-pool.
                self.conv1 = nn.Sequential(
                    nn.Conv2d(3, nf, kernel_size=7, stride=2, padding=3, bias=False),
                    nn.BatchNorm2d(nf),
                    nn.ReLU(inplace=True),
                    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
                )
            else:
                # Following PODNet implementation
                self.conv1 = nn.Sequential(
                    nn.Conv2d(3, nf, kernel_size=3, stride=1, padding=1, bias=False),
                    nn.BatchNorm2d(nf),
                    nn.ReLU(inplace=True),
                    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
                )
        self.layer1 = self._make_layer(block, 1 * nf, layers[0])
        self.layer2 = self._make_layer(block, 2 * nf, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 4 * nf, layers[2], stride=2)
        # Only the last stage may drop its final ReLU.
        self.layer4 = self._make_layer(block, 8 * nf, layers[3], stride=2, remove_last_relu=remove_last_relu)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # Channel count of the feature map returned by forward().
        self.out_dim = 8 * nf * block.expansion
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, remove_last_relu=False, stride=1):
        """Stack ``blocks`` residual blocks for one stage.

        When ``remove_last_relu`` is set, only the last block is built
        without its trailing ReLU.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            # Projection shortcut when the shape changes.
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        if remove_last_relu:
            for i in range(1, blocks - 1):
                layers.append(block(self.inplanes, planes))
            layers.append(block(self.inplanes, planes, remove_last_relu=True))
        else:
            for _ in range(1, blocks):
                layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def reset_bn(self):
        # Reset BatchNorm running statistics across the whole backbone.
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.reset_running_stats()

    def forward(self, x):
        x = self.conv1(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        # Pooling/flattening is left to the caller:
        # x = self.avgpool(x)
        # x = x.view(x.size(0), -1)
        return x
def resnet18(pretrained=False, **kwargs):
    """Construct a ResNet-18 backbone (BasicBlock, layout [2, 2, 2, 2]).

    :param pretrained: when True, load the torchvision ImageNet weights.
    """
    net = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        state = model_zoo.load_url(model_urls['resnet18'])
        net.load_state_dict(state)
    return net
def resnet34(pretrained=False, **kwargs):
    """Construct a ResNet-34 backbone (BasicBlock, layout [3, 4, 6, 3]).

    :param pretrained: when True, load the torchvision ImageNet weights.
    """
    net = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if pretrained:
        state = model_zoo.load_url(model_urls['resnet34'])
        net.load_state_dict(state)
    return net
def resnet50(pretrained=False, **kwargs):
    """Construct a ResNet-50 backbone (Bottleneck, layout [3, 4, 6, 3]).

    :param pretrained: when True, load the torchvision ImageNet weights.
    """
    net = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        state = model_zoo.load_url(model_urls['resnet50'])
        net.load_state_dict(state)
    return net
def resnet101(pretrained=False, **kwargs):
    """Construct a ResNet-101 backbone (Bottleneck, layout [3, 4, 23, 3]).

    :param pretrained: when True, load the torchvision ImageNet weights.
    """
    net = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        state = model_zoo.load_url(model_urls['resnet101'])
        net.load_state_dict(state)
    return net
def resnet152(pretrained=False, **kwargs):
    """Construct a ResNet-152 backbone (Bottleneck, layout [3, 8, 36, 3]).

    :param pretrained: when True, load the torchvision ImageNet weights.
    """
    net = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if pretrained:
        state = model_zoo.load_url(model_urls['resnet152'])
        net.load_state_dict(state)
    return net
================================================
FILE: examples/Structural_Development/SCA-SNN/inclearn/convnet/sew_resnet.py
================================================
import torch
import torch.nn as nn
from copy import deepcopy
try:
from torchvision.models.utils import load_state_dict_from_url
except ImportError:
from torchvision._internally_replaced_utils import load_state_dict_from_url
from braincog.base.node import *
from braincog.model_zoo.base_module import *
from braincog.datasets import is_dvs_data
from timm.models import register_model
__all__ = ['SEWResNet', 'sew_resnet18', 'sew_resnet34', 'sew_resnet50', 'sew_resnet101',
'sew_resnet152', 'sew_resnext50_32x4d', 'sew_resnext101_32x8d',
'sew_wide_resnet50_2', 'sew_wide_resnet101_2']
model_urls = {
"resnet18": "https://download.pytorch.org/models/resnet18-f37072fd.pth",
"resnet34": "https://download.pytorch.org/models/resnet34-b627a593.pth",
"resnet50": "https://download.pytorch.org/models/resnet50-0676ba61.pth",
"resnet101": "https://download.pytorch.org/models/resnet101-63fe2227.pth",
"resnet152": "https://download.pytorch.org/models/resnet152-394f9c45.pth",
"resnext50_32x4d": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
"resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
"wide_resnet50_2": "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
"wide_resnet101_2": "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
}
# modified by https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
def sew_function(x: torch.Tensor, y: torch.Tensor, cnf: str):
    """Spike-element-wise fusion of the residual paths.

    :param cnf: 'ADD' (sum), 'AND' (product) or 'IAND' (x where y is silent).
    :raises NotImplementedError: for any other connect function name.
    """
    if cnf == 'ADD':
        return x + y
    if cnf == 'AND':
        return x * y
    if cnf == 'IAND':
        return x * (1. - y)
    raise NotImplementedError
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Bias-free 3x3 convolution; padding tracks dilation so the spatial
    size is preserved at stride 1."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
    )
def conv1x1(in_planes, out_planes, stride=1):
    """Bias-free 1x1 (pointwise) convolution."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
class BasicBlock(nn.Module):
    """SEW-ResNet basic block built from spiking conv modules.

    NOTE(review): ``__init__`` stores every layer inside the single
    ``self.conv`` Sequential, but ``forward`` reads ``self.conv1``,
    ``self.conv2``, ``self.downsample`` and ``self.downsample_sn`` — none
    of which are ever assigned, so ``forward`` would raise AttributeError
    as written. Confirm which layout is intended before use.
    """
    expansion = 1

    def __init__(self, inplanes, planes, planes_cur,stride=1, downsample=None, groups=1,
                 dilation=1, norm_layer=None, cnf: str = None, node: callable = LIFNode, **kwargs):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        # if groups != 1 or base_width != 64:
        #     raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # ``planes_cur`` is the width instantiated for the current task;
        # ``planes`` is the nominal full width (see SEWResNet.c_dim/cdim_cur).
        self.conv=nn.Sequential(
            # Both self.conv1 and self.downsample layers downsample the input when stride != 1
            BaseConvModule(inplanes, planes_cur, kernel_size=(3, 3), stride=stride,padding=(1, 1), node=node),
            BaseConvModule(planes, planes_cur, kernel_size=(3, 3), padding=(1, 1), node=node),
            downsample,
            BaseConvModule(planes, planes_cur, kernel_size=(3, 3), padding=(1, 1), node=node),
            BaseConvModule(planes, planes_cur, kernel_size=(3, 3), padding=(1, 1), node=node),)
        self.cnf = cnf

    def forward(self, x):
        identity = x
        # NOTE(review): see class docstring — these attributes are not set
        # by the __init__ above.
        out = self.conv1(x)
        out = self.conv2(out)
        if self.downsample is not None:
            identity = self.downsample_sn(self.downsample(x))
        # Fuse residual and main path with the spike-element-wise function.
        out = sew_function(identity, out, self.cnf)
        return out

    def extra_repr(self) -> str:
        # Surface the connect function in the module repr for debugging.
        return super().extra_repr() + f'cnf={self.cnf}'
class SEWResNet(BaseModule):
    """Spike-element-wise ResNet backbone with per-task channel widths.

    ``c_dim`` gives the nominal channel width of each of the four stages,
    while ``cdim_cur`` gives the widths actually instantiated for the
    current task (falls back to ``c_dim`` when empty). The four stages are
    stored in ``self.layer_convnets``; a classification head is attached
    externally via ``self.classifer``.
    """
    def __init__(self, block, layers, c_dim=[64,128,256,512], cdim_cur=[],step=4,encode_type="direct",zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, cnf: str = 'ADD', *args,**kwargs):
        super().__init__(
            step,
            encode_type,
            *args,
            **kwargs
        )
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer
        self.groups=groups
        # Spiking neuron type, bound with the simulation step count.
        self.node = LIFNode
        if issubclass(self.node, BaseNode):
            self.node = partial(self.node, **kwargs, step=step)
        self.c_dim=c_dim
        if len(cdim_cur)>0:
            self.cdim_cur=cdim_cur
        else:
            # Fall back to the full widths when no per-task widths are given.
            self.cdim_cur=self.c_dim
        self.inplanes = c_dim[0]
        # NOTE(review): this indexes the raw ``cdim_cur`` argument, which is
        # empty by default (IndexError) — confirm ``self.cdim_cur[0]`` was
        # intended.
        self.inplanes_cur = cdim_cur[0]
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        # NOTE(review): kernel_size=3 with padding=3 enlarges the spatial
        # size by 4 in each dimension — confirm this stem is intentional.
        self.conv1 = nn.Conv2d(3, self.cdim_cur[0], kernel_size=3, stride=1, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.cdim_cur[0])
        self.node1 = self.node()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Four residual stages, one ModuleList entry per stage.
        self.layer_convnets = nn.ModuleList()
        self.layer_convnets.append(self._make_layer(block, c_dim[0], self.cdim_cur[0],layers[0], cnf=cnf, node=self.node, **kwargs))
        self.layer_convnets.append(self._make_layer(block, c_dim[1], self.cdim_cur[1], layers[1], stride=2,
                                                    dilate=replace_stride_with_dilation[0], cnf=cnf, node=self.node, **kwargs))
        self.layer_convnets.append(self._make_layer(block, c_dim[2], self.cdim_cur[2], layers[2], stride=2,
                                                    dilate=replace_stride_with_dilation[1], cnf=cnf, node=self.node, **kwargs))
        self.layer_convnets.append(self._make_layer(block, c_dim[3], self.cdim_cur[3], layers[3], stride=2,
                                                    dilate=replace_stride_with_dilation[2], cnf=cnf, node=self.node, **kwargs))
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # self.fc = nn.Linear(512 * block.expansion, num_classes)
        # Head attached externally; the "classifer" spelling matches callers.
        self.classifer=None
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, planes_cur,blocks, stride=1, dilate=False, cnf: str=None, node: callable = None, **kwargs):
        """Build one stage.

        Returns a single block instance; the multi-block loop is currently
        disabled (see the commented code below), so ``blocks`` is unused
        beyond the signature.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            # Projection shortcut: 1x1 conv + norm + spiking node.
            downsample = nn.Sequential(
                conv1x1(self.inplanes_cur, planes_cur * block.expansion, stride),
                norm_layer(planes_cur * block.expansion),
                node()
            )
        layers =block(self.inplanes, planes, planes_cur,stride, downsample, self.groups,
                      previous_dilation, norm_layer, cnf, node, **kwargs)
        self.inplanes = planes * block.expansion
        self.inplanes_cur= planes_cur * block.expansion
        # for _ in range(1, blocks):
        #     layers.append(block(self.inplanes, planes, groups=self.groups,
        #                         dilation=self.dilation,
        #                         norm_layer=norm_layer, cnf=cnf, node=node, **kwargs))
        return layers

    def forward_init(self, inputs):
        # Stem: conv -> norm -> spiking node (max-pool currently disabled).
        # See note [TorchScript super()]
        x = self.conv1(inputs)
        x = self.bn1(x)
        x = self.node1(x)
        # x = self.maxpool(x)
        return x

    def forward_impl(self, inputs):
        # NOTE(review): ``x`` is read before assignment and ``layer1``..
        # ``layer4`` are not attributes of this class (stages live in
        # ``layer_convnets``) — this method raises as written; confirm the
        # intended traversal.
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x
def _sew_resnet(arch, block, layers, c_dim, cdim_cur, pretrained, progress, cnf, **kwargs):
    """Instantiate a SEWResNet and optionally load the ANN weights for ``arch``.

    :param arch: key into ``model_urls`` used when ``pretrained`` is True.
    """
    net = SEWResNet(block, layers, c_dim=c_dim, cdim_cur=cdim_cur, cnf=cnf, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
        net.load_state_dict(state_dict)
    return net
@register_model
def sew_resnet18(c_dim=[64,128,256,512],cdim_cur=[],pretrained=False, progress=True, cnf: str = None, **kwargs):
    """
    :param pretrained: If True, the SNN will load parameters from the ANN pre-trained on ImageNet
    :type pretrained: bool
    :param progress: If True, displays a progress bar of the download to stderr
    :type progress: bool
    :param cnf: the name of the spike-element-wise function ('ADD', 'AND' or 'IAND'); defaults to 'ADD'
    :type cnf: str
    :param node: a spiking neuron layer
    :type node: callable
    :param kwargs: kwargs for `node`
    :type kwargs: dict
    :return: Spiking ResNet-18
    :rtype: torch.nn.Module
    The spike-element-wise ResNet-18 `"Deep Residual Learning in Spiking Neural Networks" `_ modified by the ResNet-18 model from `"Deep Residual Learning for Image Recognition" `_
    """
    # Fix: ``cnf`` used to be ignored ('ADD' was always passed); an explicit
    # choice is now honored while 'ADD' remains the default behaviour.
    effective_cnf = cnf if cnf is not None else 'ADD'
    return _sew_resnet('resnet18', BasicBlock, [2, 2, 2, 2], c_dim, cdim_cur,
                       pretrained, progress, effective_cnf, **kwargs)
================================================
FILE: examples/Structural_Development/SCA-SNN/inclearn/convnet/utils.py
================================================
import numpy as np
import torch
from torch import nn
from torch.optim import SGD
import torch.nn.functional as F
from inclearn.tools.metrics import ClassErrorMeter, AverageValueMeter
def finetune_last_layer(
    logger,
    network,
    loader,
    n_class,
    nepoch=30,
    lr=0.1,
    scheduling=[15, 35],
    lr_decay=0.1,
    weight_decay=5e-4,
    loss_type="ce",
    temperature=5.0,
    test_loader=None,
    samples_per_cls = []
):
    """Fine-tune only the classifier head of ``network`` with SGD.

    ``network`` is expected to be a DataParallel-style wrapper exposing
    ``network.module.classifier`` and returning a dict with a 'logit'
    entry when called. Logits are divided by ``temperature`` before the
    loss is computed.

    :param loss_type: "ce" (cross-entropy), "bce" (one-hot BCE) or "cb"
        (class-balanced focal loss using ``samples_per_cls``).
    :param test_loader: optional held-out loader evaluated after each epoch.
    :return: the same ``network`` object, with its classifier fine-tuned.
    """
    # The whole network is put in eval mode; only the classifier's
    # parameters are handed to the optimizer below.
    network.eval()
    #if hasattr(network.module, "convnets"):
    #    for net in network.module.convnets:
    #        net.eval()
    #else:
    #    network.module.convnet.eval()
    optim = SGD(network.module.classifier.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optim, scheduling, gamma=lr_decay)
    if loss_type == "ce":
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = nn.BCEWithLogitsLoss()
    logger.info("Begin finetuning last layer")
    for i in range(nepoch):
        total_loss = 0.0
        total_correct = 0.0
        total_count = 0
        # print(f"dataset loader length {len(loader.dataset)}")
        for inputs, targets in loader:
            inputs, targets = inputs.cuda(), targets.cuda()
            if loss_type == "bce":
                # NOTE(review): ``to_onehot`` is not imported in this module —
                # confirm it resolves at runtime. Also, after this conversion
                # the ``preds == targets`` accuracy below compares class ids
                # against one-hot rows — confirm the intended accuracy metric.
                targets = to_onehot(targets, n_class)
            outputs = network(inputs)['logit']
            _, preds = outputs.max(1)
            optim.zero_grad()
            if loss_type == "cb":
                # NOTE(review): ``CB_loss`` is not imported in this module — confirm.
                loss = CB_loss(targets, outputs / temperature, samples_per_cls, n_class, "focal")
            else:
                loss = criterion(outputs / temperature, targets)
            loss.backward()
            optim.step()
            # Accumulate sample-weighted loss and correct predictions.
            total_loss += loss * inputs.size(0)
            total_correct += (preds == targets).sum()
            total_count += inputs.size(0)
        if test_loader is not None:
            # Evaluate on the held-out loader after each training epoch.
            test_correct = 0.0
            test_count = 0.0
            with torch.no_grad():
                for inputs, targets in test_loader:
                    outputs = network(inputs.cuda())['logit']
                    _, preds = outputs.max(1)
                    test_correct += (preds.cpu() == targets).sum().item()
                    test_count += inputs.size(0)
        # Step the LR schedule once per epoch.
        scheduler.step()
        if test_loader is not None:
            logger.info(
                "Epoch %d finetuning loss %.3f acc %.3f Eval %.3f" %
                (i, total_loss.item() / total_count, total_correct.item() / total_count, test_correct / test_count))
        else:
            logger.info("Epoch %d finetuning loss %.3f acc %.3f" %
                        (i, total_loss.item() / total_count, total_correct.item() / total_count))
    return network
def extract_features(task_i, model, loader, mask=None):
    """Run ``model`` over ``loader`` and collect features and labels.

    The model is called as ``model(task_i, inputs, mask)`` and must return
    a dict with a 'feature' entry.

    :return: ``(features, targets)`` as concatenated numpy arrays.
    """
    model.eval()
    all_feats, all_targets = [], []
    with torch.no_grad():
        for batch_x, batch_y in loader:
            feats = model(task_i, batch_x.cuda(), mask)['feature'].detach().cpu().numpy()
            all_feats.append(feats)
            all_targets.append(batch_y.numpy())
    return np.concatenate(all_feats), np.concatenate(all_targets)
def calc_class_mean(network, loader, class_idx, metric):
    """Mean feature vector ("exemplar mean") over the samples in ``loader``.

    For the "cosine"/"weight" metrics the mean is L2-normalized.

    NOTE(review): ``extract_features`` takes ``(task_i, model, loader,
    mask)``, so this two-argument call would raise a TypeError as written —
    confirm/repair the call site before use. ``class_idx`` is unused here.
    """
    EPSILON = 1e-8
    features, targets = extract_features(network, loader)
    # norm_feats = features/(np.linalg.norm(features, axis=1)[:,np.newaxis]+EPSILON)
    # examplar_mean = norm_feats.mean(axis=0)
    examplar_mean = features.mean(axis=0)
    if metric == "cosine" or metric == "weight":
        # Unit-normalize; EPSILON guards against a zero-norm mean.
        examplar_mean /= (np.linalg.norm(examplar_mean) + EPSILON)
    return examplar_mean
def update_classes_mean(network, inc_dataset, n_classes, task_size, share_memory=None, metric="cosine", EPSILON=1e-8):
    """Compute the per-class mean feature vectors over the incremental data.

    Features are accumulated per label over ``inc_dataset``'s full
    available data (current task + memory) and averaged per class.
    For the "cosine"/"weight" metrics each class mean is L2-normalized.

    Fix: each row is now normalized by its own norm
    (``np.linalg.norm(class_means[i])``). Previously the norm of the
    entire ``(n_classes, dim)`` matrix was used, which does not produce
    unit-length class means (cf. the per-vector normalization in
    ``calc_class_mean``).

    :param task_size: unused here; kept for interface compatibility.
    :return: ``(n_classes, features_dim)`` numpy array of class means.
    """
    loader = inc_dataset._get_loader(inc_dataset.data_inc,
                                     inc_dataset.targets_inc,
                                     shuffle=False,
                                     share_memory=share_memory,
                                     mode="test")
    class_means = np.zeros((n_classes, network.module.features_dim))
    count = np.zeros(n_classes)
    network.eval()
    with torch.no_grad():
        for x, y in loader:
            feat = network(x.cuda())['feature']
            # Accumulate feature sums and sample counts per label.
            for lbl in torch.unique(y):
                class_means[lbl] += feat[y == lbl].sum(0).cpu().numpy()
                count[lbl] += feat[y == lbl].shape[0]
    for i in range(n_classes):
        class_means[i] /= count[i]
        if metric == "cosine" or metric == "weight":
            # Normalize by this row's own norm; EPSILON guards zero norms.
            class_means[i] /= (np.linalg.norm(class_means[i]) + EPSILON)
    return class_means
================================================
FILE: examples/Structural_Development/SCA-SNN/inclearn/datasets/__init__.py
================================================
================================================
FILE: examples/Structural_Development/SCA-SNN/inclearn/datasets/data.py
================================================
import random
import cv2
import numpy as np
import os.path as osp
from copy import deepcopy
from PIL import Image
import multiprocessing as mp
from multiprocessing import Pool
import albumentations as A
from albumentations.pytorch import ToTensorV2
import warnings
warnings.filterwarnings("ignore", "Corrupt EXIF data", UserWarning)
import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler, WeightedRandomSampler
from torchvision import datasets, transforms
from torchvision.datasets.folder import pil_loader
from .dataset import get_dataset
from inclearn.tools.data_utils import construct_balanced_subset
def get_data_folder(data_folder, dataset_name):
    """Return the on-disk directory for ``dataset_name`` under ``data_folder``."""
    return osp.join(data_folder, dataset_name)
class IncrementalDataset:
    """Serve a classification dataset task-by-task for class-incremental learning.

    Two operating modes are controlled by ``torc``:
      * ``torc=True``: a single test loader covering all classes seen so far
        is returned and the rehearsal memory is mixed into the training data.
      * ``torc=False``: one test loader per past task is returned, and labels
        are remapped to start from 0 within each task (see ``_select``).
    """
    def __init__(
        self,
        trial_i,
        dataset_name,
        random_order=False,
        shuffle=True,
        workers=10,
        batch_size=128,
        seed=1,
        increment=10,
        validation_split=0.0,
        resampling=False,
        data_folder="./data",
        start_class=0,
        torc = False
    ):
        # The info about the incremental split.
        self.torc=torc
        self.trial_i = trial_i
        self.start_class = start_class
        # Number of classes added at each incremental step.
        self.task_size = increment
        self.increments = []
        self.random_order = random_order
        self.validation_split = validation_split
        # NOTE(review): ``seed`` is not stored, yet ``_setup_data_for_raw_data``
        # reads ``self._seed`` when ``random_order`` is set — this would raise
        # AttributeError; confirm.
        #-------------------------------------
        # Dataset Info
        #-------------------------------------
        # NOTE(review): the ``data_folder`` and ``dataset_name`` arguments are
        # ignored here; the path is hard-coded to a CIFAR100 folder — confirm.
        self.data_folder = get_data_folder('/data0/datasets/', 'CIFAR100/')
        self.dataset_name = dataset_name
        # self.transform = transform
        self.train_dataset = None
        self.test_dataset = None
        self.n_tot_cls = -1
        datasets = get_dataset(dataset_name)
        self._setup_data(datasets)
        self._workers = workers
        self._shuffle = shuffle
        self._batch_size = batch_size
        self._resampling = resampling
        # Currently, don't support multiple datasets.
        self.train_transforms = datasets.train_transforms
        self.test_transforms = datasets.test_transforms
        # Transform backend: torchvision or albumentations.
        self.transform_type = datasets.transform_type
        # Rehearsal memory M_t.
        self.data_memory = None
        self.targets_memory = None
        # Incoming data D_t.
        self.data_cur = None
        self.targets_cur = None
        # Available data \tilde{D}_t = D_t \cup M_t.
        self.data_inc = None  # Cur task data + memory
        self.targets_inc = None
        # Available data stored in cpu memory.
        self.shared_data_inc = None
        self.shared_test_data = None
        # Current state of the incremental-learning stage.
        self._current_task = 0

    @property
    def n_tasks(self):
        # Total number of incremental steps.
        return len(self.increments)

    def new_task(self,task_i):
        """Advance to the next task and build its data loaders.

        :return: ``(task_info, train_loader, val_loader, test_loader)``.
            In non-torc mode ``val_loader``/``test_loader`` are lists with
            one loader per task seen so far.
        :raises Exception: when all tasks have been consumed.
        """
        if self._current_task >= len(self.increments):
            raise Exception("No more tasks.")
        min_class, max_class, x_train, y_train, x_test, y_test,x_val, y_val = self._get_cur_step_data_for_raw_data(task_i)
        self.data_cur, self.targets_cur = x_train, y_train
        if self.torc:
            # Mix the rehearsal memory into the current task's training data.
            if self.data_memory is not None:
                print("Set memory of size: {}.".format(len(self.data_memory)))
                if len(self.data_memory) != 0:
                    x_train = np.concatenate((x_train, self.data_memory))
                    y_train = np.concatenate((y_train, self.targets_memory))
        self.data_inc, self.targets_inc = x_train, y_train
        self.data_test_inc, self.targets_test_inc = x_test, y_test
        train_loader = self._get_loader(x_train, y_train, mode="train")
        if self.torc:
            val_loader = self._get_loader(x_val, y_val, shuffle=False, mode="test")
            test_loader = self._get_loader(x_test, y_test, shuffle=False, mode="test")
        else:
            # One evaluation loader per task seen so far.
            val_loader=[]
            test_loader=[]
            for i in range(len(x_test)):
                val_loader.append(self._get_loader(x_test[i], y_test[i], shuffle=False, mode="test"))
                test_loader.append(self._get_loader(x_test[i], y_test[i], shuffle=False, mode="test"))
        task_info = {
            "min_class": min_class,
            "max_class": max_class,
            "increment": self.increments[self._current_task],
            "task": self._current_task,
            "max_task": len(self.increments),
            "n_train_data": len(x_train),
            # NOTE(review): this records the number of *training* labels,
            # not test samples — confirm whether len(y_test) was intended.
            "n_test_data": len(y_train),
        }
        self._current_task += 1
        return task_info, train_loader, val_loader, test_loader

    def _get_cur_step_data_for_raw_data(self, task_i):
        """Select the train/val/test splits for the current step.

        The class range ``[min_class, max_class)`` is derived from the
        cumulative increments; label remapping in non-torc mode is done by
        ``_select``.
        """
        min_class = sum(self.increments[:self._current_task])
        max_class = sum(self.increments[:self._current_task + 1])
        if self.torc:
            x_train, y_train = self._select(task_i,self.data_train, self.targets_train, low_range=min_class, high_range=max_class)
            # Test on every class seen so far; validate on the new classes.
            x_test, y_test = self._select(task_i,self.data_test, self.targets_test, low_range=0, high_range=max_class)
            x_val, y_val = self._select(task_i,self.data_test, self.targets_test, low_range=min_class, high_range=max_class)
        else:
            x_test=[]
            y_test=[]
            x_train, y_train = self._select(task_i,self.data_train, self.targets_train, low_range=min_class, high_range=max_class)
            min_c=0
            num_c=max_class-min_class
            # Build one per-task test split, assuming each task spans
            # ``num_c`` classes.
            for taski in range(task_i+1):
                x_test_i, y_test_i = self._select(taski,self.data_test, self.targets_test, low_range=min_c, high_range=min_c+num_c)
                x_test.append(x_test_i)
                y_test.append(y_test_i)
                min_c+=num_c
            # Validation reuses the per-task test splits in this mode.
            x_val=x_test
            y_val=y_test
        return min_class, max_class, x_train, y_train, x_test, y_test,x_val, y_val

    #--------------------------------
    # Data Setup
    #--------------------------------
    def _setup_data(self, dataset):
        """Load the raw dataset and materialize the train/val/test arrays."""
        # FIXME: handles online loading of images
        self.data_train, self.targets_train = [], []
        self.data_test, self.targets_test = [], []
        self.data_val, self.targets_val = [], []
        self.increments = []
        self.class_order = []
        current_class_idx = 0  # When using multiple datasets
        train_dataset = dataset(self.data_folder, train=True)
        test_dataset = dataset(self.data_folder, train=False)
        self.train_dataset = train_dataset
        self.test_datasets = test_dataset
        self.n_tot_cls = self.train_dataset.n_cls  # number of classes in whole dataset
        self._setup_data_for_raw_data(dataset, train_dataset, test_dataset, current_class_idx)
        # Flatten the per-dataset lists into single arrays.
        self.data_train = np.concatenate(self.data_train)
        self.targets_train = np.concatenate(self.targets_train)
        self.data_val = np.concatenate(self.data_val)
        self.targets_val = np.concatenate(self.targets_val)
        self.data_test = np.concatenate(self.data_test)
        self.targets_test = np.concatenate(self.targets_test)

    def _setup_data_for_raw_data(self, dataset, train_dataset, test_dataset, current_class_idx=0):
        """Split one raw dataset, fix the class order, and record the increments."""
        increment = self.task_size
        x_train, y_train = train_dataset.data, np.array(train_dataset.targets)
        x_val, y_val, x_train, y_train = self._split_per_class(x_train, y_train, self.validation_split)
        x_test, y_test = test_dataset.data, np.array(test_dataset.targets)
        # Get the class order.
        order = [i for i in range(len(np.unique(y_train)))]
        if self.random_order:
            # Ensure the following order is determined by the seed.
            # NOTE(review): ``self._seed`` is never assigned in __init__ — confirm.
            random.seed(self._seed)
            random.shuffle(order)
        elif dataset.class_order(self.trial_i) is not None:
            order = dataset.class_order(self.trial_i)
        self.class_order.append(order)
        # Remap labels so class ids follow the chosen order.
        y_train = self._map_new_class_index(y_train, order)
        y_val = self._map_new_class_index(y_val, order)
        y_test = self._map_new_class_index(y_test, order)
        y_train += current_class_idx
        y_val += current_class_idx
        y_test += current_class_idx
        current_class_idx += len(order)
        if self.start_class == 0:
            # e.g. with increment = 10 this yields [10, 10, 10, ...]
            self.increments = [increment for _ in range(len(order) // increment)]
        else:
            # A larger first task, then equal-sized increments.
            self.increments.append(self.start_class)
            for _ in range((len(order) - self.start_class) // increment):
                self.increments.append(increment)
        self.data_train.append(x_train)
        self.targets_train.append(y_train)
        self.data_val.append(x_val)
        self.targets_val.append(y_val)
        self.data_test.append(x_test)
        self.targets_test.append(y_test)

    @staticmethod
    def _split_per_class(x, y, validation_split=0.0):
        """Splits train data for a subset of validation data.

        The split is done so that each class contributes the same
        proportion of data. Returns ``(x_val, y_val, x_train, y_train)``.
        """
        shuffled_indexes = np.random.permutation(x.shape[0])
        x = x[shuffled_indexes]
        y = y[shuffled_indexes]
        x_val, y_val = [], []
        x_train, y_train = [], []
        for class_id in np.unique(y):
            class_indexes = np.where(y == class_id)[0]
            # Per-class validation quota.
            nb_val_elts = int(class_indexes.shape[0] * validation_split)
            val_indexes = class_indexes[:nb_val_elts]
            train_indexes = class_indexes[nb_val_elts:]
            x_val.append(x[val_indexes])
            y_val.append(y[val_indexes])
            x_train.append(x[train_indexes])
            y_train.append(y[train_indexes])
        # Flatten the per-class lists into single arrays.
        x_val, y_val = np.concatenate(x_val), np.concatenate(y_val)
        x_train, y_train = np.concatenate(x_train), np.concatenate(y_train)
        return x_val, y_val, x_train, y_train

    @staticmethod
    def _map_new_class_index(y, order):
        """Transforms targets for new class order."""
        return np.array(list(map(lambda x: order.index(x), y)))

    def _select(self, task_i,x, y, low_range=0, high_range=0):
        """Select samples whose label lies in ``[low_range, high_range)``.

        In non-torc mode labels are shifted so they start at 0 within the
        task; ``task_i`` itself is not used here.
        """
        idxes = sorted(np.where(np.logical_and(y >= low_range, y < high_range))[0])
        if isinstance(x, list):
            selected_x = [x[idx] for idx in idxes]
        else:
            selected_x = x[idxes]
        if self.torc:
            selected_y=y[idxes]
        else:
            # Remap to the task-local label space.
            selected_y=y[idxes]-low_range
        return selected_x, selected_y

    #--------------------------------
    # Get Loader
    #--------------------------------
    def get_datainc_loader(self, mode='train'):
        # Loader over the full available data (current task + memory).
        print(self.data_inc.shape)
        train_loader = self._get_loader(self.data_inc, self.targets_inc, mode=mode)
        return train_loader

    def get_custom_loader_from_memory(self, class_indexes, mode="test"):
        """Build a loader over the rehearsal memory for the given classes.

        NOTE(review): ``_select`` takes ``task_i`` as its first positional
        argument, so this two-positional-argument call would raise a
        TypeError (missing ``y``) — confirm/repair the call site.
        """
        if not isinstance(class_indexes, list):
            class_indexes = [class_indexes]
        data, targets = [], []
        for class_index in class_indexes:
            class_data, class_targets = self._select(self.data_memory,
                                                     self.targets_memory,
                                                     low_range=class_index,
                                                     high_range=class_index + 1)
            data.append(class_data)
            targets.append(class_targets)
        data = np.concatenate(data)
        targets = np.concatenate(targets)
        return data, targets, self._get_loader(data, targets, shuffle=False, mode=mode)

    def _get_loader(self, x, y, share_memory=None, shuffle=True, mode="train", batch_size=None, resample=None):
        """Wrap ``(x, y)`` in a DataLoader with mode-dependent transforms.

        ``mode`` selects the transforms and sampling: "train" (optionally
        class-resampled), "test", or "flip" (forced horizontal-flip TTA);
        a mode containing "balanced" first subsamples a class-balanced
        subset.
        """
        if "balanced" in mode:
            x, y = construct_balanced_subset(x, y)
        batch_size = batch_size if batch_size is not None else self._batch_size
        if "train" in mode:
            trsf = self.train_transforms
            resample_ = self._resampling if resample is None else True
            if resample_ is False:
                sampler = None
            else:
                # NOTE(review): ``get_weighted_random_sampler`` is not
                # imported in this module — confirm it resolves at runtime.
                sampler = get_weighted_random_sampler(y)
            # DataLoader forbids combining a sampler with shuffle=True.
            shuffle = False if resample_ is True else True
        elif "test" in mode:
            trsf = self.test_transforms
            sampler = None
        elif mode == "flip":
            # Test-time augmentation: force a horizontal flip.
            if "imagenet" in self.dataset_name:
                trsf = A.Compose([A.HorizontalFlip(p=1.0), *self.test_transforms.transforms])
            else:
                trsf = transforms.Compose([transforms.RandomHorizontalFlip(p=1.0), *self.test_transforms.transforms])
            sampler = None
        else:
            raise NotImplementedError("Unknown mode {}.".format(mode))
        return DataLoader(DummyDataset(x,
                                       y,
                                       trsf,
                                       trsf_type=self.transform_type,
                                       share_memory_=share_memory,
                                       dataset_name=self.dataset_name),
                          batch_size=batch_size,
                          shuffle=shuffle,
                          num_workers=self._workers,
                          sampler=sampler,
                          pin_memory=True)

    def get_custom_loader(self, class_indexes, mode="test", data_source="train", imgs=None, tgts=None):
        """Returns a custom loader.

        :param class_indexes: A list of class indexes that we want.
        :param mode: Various mode for the transformations applied on it.
        :param data_source: Whether to fetch from the train, val, or test set.
        :return: The raw data and a loader.

        NOTE(review): the ``_select`` call below has the same arity concern
        as in ``get_custom_loader_from_memory`` — confirm.
        """
        if not isinstance(class_indexes, list):  # TODO: deprecated, should always give a list
            class_indexes = [class_indexes]
        if data_source == "train":
            x, y = self.data_inc, self.targets_inc
        elif data_source == "val":
            x, y = self.data_val, self.targets_val
        elif data_source == "test":
            x, y = self.data_test, self.targets_test
        elif data_source == 'specified' and imgs is not None and tgts is not None:
            x, y = imgs, tgts
        else:
            raise ValueError("Unknown data source <{}>.".format(data_source))
        data, targets = [], []
        for class_index in class_indexes:
            class_data, class_targets, = self._select(x, y, low_range=class_index, high_range=class_index + 1)
            data.append(class_data)
            targets.append(class_targets)
        data = np.concatenate(data)
        targets = np.concatenate(targets)
        return data, targets, self._get_loader(data, targets, shuffle=False, mode=mode)
class DummyDataset(torch.utils.data.Dataset):
    """Array- or path-backed dataset with an optional shared decode cache.

    ``x`` is either an array of raw images (e.g. CIFAR) or a sequence of
    image file paths (e.g. ImageNet). Decoded images are cached in a
    multiprocessing-managed list so repeated epochs / workers can share the
    decoding work.
    """

    def __init__(self, x, y, trsf, trsf_type, share_memory_=None, dataset_name=None):
        """
        :param x: images (ndarray) or image file paths (list / ndarray of str).
        :param y: labels aligned with ``x``.
        :param trsf: transform pipeline applied to each sample.
        :param trsf_type: contains ``'torch'`` for torchvision-style pipelines,
            anything else is treated as an albumentations-style pipeline.
        :param share_memory_: pre-built shared cache to reuse; built here if None.
        :param dataset_name: informational dataset tag.
        """
        self.dataset_name = dataset_name
        self.x, self.y = x, y
        self.trsf = trsf
        self.trsf_type = trsf_type
        self.manager = mp.Manager()
        # Hard cap on the number of cached decoded images.
        self.buffer_size = 4000000
        if share_memory_ is None:
            # Bug fix: the original read ``self.x.shape[0]`` here, which raised
            # AttributeError for list inputs even though __len__ supports them.
            # ``len(self)`` handles both lists and ndarrays identically.
            n_slots = min(len(self), self.buffer_size)
            self.share_memory = self.manager.list([None] * n_slots)
        else:
            self.share_memory = share_memory_

    def __len__(self):
        # Lists have no .shape; ndarrays report their first dimension.
        if isinstance(self.x, list):
            return len(self.x)
        else:
            return self.x.shape[0]

    def __getitem__(self, idx):
        x, y = self.x[idx], self.y[idx]
        if isinstance(x, np.ndarray):
            # Raw image array (assume CIFAR): wrap as PIL for the transforms.
            x = Image.fromarray(x)
        else:
            # Path-backed sample (assume ImageNet): decode, convert BGR -> RGB,
            # and cache the decoded image when a cache slot exists.
            if idx < len(self.share_memory):
                if self.share_memory[idx] is not None:
                    x = self.share_memory[idx]
                else:
                    x = cv2.imread(x)
                    x = x[:, :, ::-1]  # OpenCV loads BGR; flip to RGB.
                    self.share_memory[idx] = x
            else:
                x = cv2.imread(x)
                x = x[:, :, ::-1]
        if 'torch' in self.trsf_type:
            # torchvision pipelines are called positionally...
            x = self.trsf(x)
        else:
            # ...albumentations pipelines take and return a dict.
            x = self.trsf(image=x)['image']
        return x, y
================================================
FILE: examples/Structural_Development/SCA-SNN/inclearn/datasets/dataset.py
================================================
import os.path as osp
import numpy as np
import glob
from albumentations.pytorch import ToTensorV2
from torchvision import datasets, transforms
import torch
from inclearn.tools.cutout import Cutout
from inclearn.tools.autoaugment_extra import ImageNetPolicy
def get_datasets(dataset_names):
    """Resolve a '-'-separated string of dataset names into handler classes."""
    handlers = []
    for name in dataset_names.split("-"):
        handlers.append(get_dataset(name))
    return handlers
def get_dataset(dataset_name):
    """Map a dataset name to its data-handler class.

    Exact matches are checked first; any remaining name containing
    ``'imagenet100'`` resolves to :class:`iImageNet100`.

    :raises NotImplementedError: for unrecognized dataset names.
    """
    if dataset_name == "cifar10":
        return iCIFAR10
    if dataset_name == "cifar100":
        return iCIFAR100
    if "imagenet100" in dataset_name:
        return iImageNet100
    raise NotImplementedError("Unknown dataset {}.".format(dataset_name))
class DataHandler:
    """Base container for dataset metadata; subclasses override these fields."""
    # Underlying torchvision-style dataset instance (set by subclasses).
    base_dataset = None
    # Transforms applied to training samples.
    train_transforms = []
    # Albumentations pipeline shared by all splits (tensor conversion only).
    common_transforms = [ToTensorV2()]
    # Class ordering for incremental tasks (subclasses expose a classmethod).
    class_order = None
class iCIFAR10(DataHandler):
    """CIFAR-10 data handler for incremental learning."""
    base_dataset_cls = datasets.cifar.CIFAR10
    # The transforms below are torchvision pipelines, not albumentations.
    transform_type = 'torchvision'
    train_transforms = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        # transforms.ColorJitter(brightness=63 / 255),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    test_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    def __init__(self, data_folder, train, is_fine_label=False):
        # Download (if needed) and keep raw arrays/labels for incremental splits.
        self.base_dataset = self.base_dataset_cls(data_folder, train=train, download=True)
        self.data = self.base_dataset.data
        self.targets = self.base_dataset.targets
        self.n_cls = 10

    @property
    def is_proc_inc_data(self):
        # Raw data; per-task splitting happens outside this handler.
        return False

    @classmethod
    def class_order(cls, trial_i):
        # Fixed class order; trial_i is accepted for API uniformity but unused.
        return [4, 0, 2, 5, 8, 3, 1, 6, 9, 7]
class iCIFAR100(iCIFAR10):
    """CIFAR-100 data handler; overrides dataset class, stats, and class order."""
    # Human-readable names for the 100 fine labels (index == label id).
    label_list = [
        'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', 'bicycle', 'bottle', 'bowl', 'boy',
        'bridge', 'bus', 'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock',
        'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur', 'dolphin', 'elephant', 'flatfish',
        'forest', 'fox', 'girl', 'hamster', 'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',
        'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse', 'mushroom', 'oak_tree', 'orange',
        'orchid', 'otter', 'palm_tree', 'pear', 'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine',
        'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk',
        'skyscraper', 'snail', 'snake', 'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table', 'tank',
        'telephone', 'television', 'tiger', 'tractor', 'train', 'trout', 'tulip', 'turtle', 'wardrobe', 'whale',
        'willow_tree', 'wolf', 'woman', 'worm'
    ]
    base_dataset_cls = datasets.cifar.CIFAR100
    transform_type = 'torchvision'
    train_transforms = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=63 / 255),
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
    ])
    test_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
    ])

    def __init__(self, data_folder, train, is_fine_label=False):
        self.base_dataset = self.base_dataset_cls(data_folder, train=train, download=True)
        self.data = self.base_dataset.data
        self.targets = self.base_dataset.targets
        self.n_cls = 100
        # NOTE(review): redundant — shadows the identical class attribute above.
        self.transform_type = 'torchvision'

    @property
    def is_proc_inc_data(self):
        # Raw data; per-task splitting happens outside this handler.
        return False

    @classmethod
    def class_order(cls, trial_i):
        # Fixed permutation of the 100 labels; trial_i is unused.
        return [
            87, 0, 52, 58, 44, 91, 68, 97, 51, 15, 94, 92, 10, 72, 49, 78, 61, 14, 8, 86, 84, 96, 18, 24, 32, 45,
            88, 11, 4, 67, 69, 66, 77, 47, 79, 93, 29, 50, 57, 83, 17, 81, 41, 12, 37, 59, 25, 20, 80, 73, 1, 28, 6,
            46, 62, 82, 53, 9, 31, 75, 38, 63, 33, 74, 27, 22, 36, 3, 16, 21, 60, 19, 70, 90, 89, 43, 5, 42, 65, 76,
            40, 30, 23, 85, 2, 95, 56, 48, 71, 64, 98, 13, 99, 7, 34, 55, 54, 26, 35, 39
        ]
# NOTE(review): duplicate of the DataHandler class defined earlier in this
# file; this redefinition shadows the first one for every class declared
# below it. Behavior is identical (the bodies match), but one copy should
# eventually be removed.
class DataHandler:
    """Base container for dataset metadata; subclasses override these fields."""
    base_dataset = None
    train_transforms = []
    common_transforms = [ToTensorV2()]
    class_order = None
class iImageNet100(DataHandler):
    """ImageNet-subset data handler backed by an ImageFolder directory tree.

    NOTE(review): despite the name, ``n_cls`` is 200 and ``class_order``
    spans labels 0-999 (the base 100-entry permutation repeated with +100
    offsets) — this looks tailored to a larger split; confirm against the
    experiment configuration before relying on the "100" in the name.
    """
    base_dataset_cls = datasets.ImageFolder
    transform_type = 'torchvision'
    # NOTE(review): the ToPILImage -> ToTensor -> ToPILImage prefix looks
    # redundant (a round-trip conversion around Cutout); presumably Cutout
    # requires tensor input — verify before simplifying.
    train_transforms = transforms.Compose([
        transforms.ToPILImage(),
        transforms.ToTensor(),
        Cutout(n_holes=1, length=16),
        transforms.ToPILImage(),
        transforms.RandomResizedCrop(32),
        transforms.RandomHorizontalFlip(),
        ImageNetPolicy(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    test_transforms = transforms.Compose([
        transforms.ToPILImage(),
        transforms.ToTensor(),
        transforms.ToPILImage(),
        transforms.Resize(32),
        transforms.CenterCrop(32),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    def __init__(self, data_folder, train, is_fine_label=False):
        # ImageFolder expects data_folder/{train,val}/<class>/<image files>.
        if train is True:
            self.base_dataset = self.base_dataset_cls(osp.join(data_folder, "train"))
        else:
            self.base_dataset = self.base_dataset_cls(osp.join(data_folder, "val"))
        # Keep file paths (not decoded images) plus integer targets.
        self.data, self.targets = zip(*self.base_dataset.samples)
        self.data = np.array(self.data)
        self.targets = np.array(self.targets)
        self.n_cls = 200

    @property
    def is_proc_inc_data(self):
        # Raw data; per-task splitting happens outside this handler.
        return False

    @classmethod
    def class_order(cls, trial_i):
        # Fixed permutation; trial_i is unused. The first 100 entries are the
        # base order; subsequent groups repeat it shifted by +100 per group.
        return [
            68, 56, 78, 8, 23, 84, 90, 65, 74, 76, 40, 89, 3, 92, 55, 9, 26, 80, 43, 38, 58, 70, 77, 1, 85, 19, 17, 50,
            28, 53, 13, 81, 45, 82, 6, 59, 83, 16, 15, 44, 91, 41, 72, 60, 79, 52, 20, 10, 31, 54, 37, 95, 14, 71, 96,
            98, 97, 2, 64, 66, 42, 22, 35, 86, 24, 34, 87, 21, 99, 0, 88, 27, 18, 94, 11, 12, 47, 25, 30, 46, 62, 69,
            36, 61, 7, 63, 75, 5, 32, 4, 51, 48, 73, 93, 39, 67, 29, 49, 57, 33,
            168, 156, 178, 108, 123, 184, 190, 165, 174, 176, 140, 189, 103,
            192, 155, 109, 126, 180, 143, 138, 158, 170, 177, 101, 185, 119,
            117, 150, 128, 153, 113, 181, 145, 182, 106, 159, 183, 116, 115,
            144, 191, 141, 172, 160, 179, 152, 120, 110, 131, 154, 137, 195,
            114, 171, 196, 198, 197, 102, 164, 166, 142, 122, 135, 186, 124,
            134, 187, 121, 199, 100, 188, 127, 118, 194, 111, 112, 147, 125,
            130, 146, 162, 169, 136, 161, 107, 163, 175, 105, 132, 104, 151,
            148, 173, 193, 139, 167, 129, 149, 157, 133,
            268, 256, 278, 208, 223, 284, 290, 265, 274, 276, 240, 289, 203,
            292, 255, 209, 226, 280, 243, 238, 258, 270, 277, 201, 285, 219,
            217, 250, 228, 253, 213, 281, 245, 282, 206, 259, 283, 216, 215,
            244, 291, 241, 272, 260, 279, 252, 220, 210, 231, 254, 237, 295,
            214, 271, 296, 298, 297, 202, 264, 266, 242, 222, 235, 286, 224,
            234, 287, 221, 299, 200, 288, 227, 218, 294, 211, 212, 247, 225,
            230, 246, 262, 269, 236, 261, 207, 263, 275, 205, 232, 204, 251,
            248, 273, 293, 239, 267, 229, 249, 257, 233, 368, 356, 378, 308,
            323, 384, 390, 365, 374, 376, 340, 389, 303, 392, 355, 309, 326,
            380, 343, 338, 358, 370, 377, 301, 385, 319, 317, 350, 328, 353,
            313, 381, 345, 382, 306, 359, 383, 316, 315, 344, 391, 341, 372,
            360, 379, 352, 320, 310, 331, 354, 337, 395, 314, 371, 396, 398,
            397, 302, 364, 366, 342, 322, 335, 386, 324, 334, 387, 321, 399,
            300, 388, 327, 318, 394, 311, 312, 347, 325, 330, 346, 362, 369,
            336, 361, 307, 363, 375, 305, 332, 304, 351, 348, 373, 393, 339,
            367, 329, 349, 357, 333,
            468, 456, 478, 408, 423, 484, 490, 465, 474, 476, 440, 489, 403,
            492, 455, 409, 426, 480, 443, 438, 458, 470, 477, 401, 485, 419,
            417, 450, 428, 453, 413, 481, 445, 482, 406, 459, 483, 416, 415,
            444, 491, 441, 472, 460, 479, 452, 420, 410, 431, 454, 437, 495,
            414, 471, 496, 498, 497, 402, 464, 466, 442, 422, 435, 486, 424,
            434, 487, 421, 499, 400, 488, 427, 418, 494, 411, 412, 447, 425,
            430, 446, 462, 469, 436, 461, 407, 463, 475, 405, 432, 404, 451,
            448, 473, 493, 439, 467, 429, 449, 457, 433, 568, 556, 578, 508,
            523, 584, 590, 565, 574, 576, 540, 589, 503, 592, 555, 509, 526,
            580, 543, 538, 558, 570, 577, 501, 585, 519, 517, 550, 528, 553,
            513, 581, 545, 582, 506, 559, 583, 516, 515, 544, 591, 541, 572,
            560, 579, 552, 520, 510, 531, 554, 537, 595, 514, 571, 596, 598,
            597, 502, 564, 566, 542, 522, 535, 586, 524, 534, 587, 521, 599,
            500, 588, 527, 518, 594, 511, 512, 547, 525, 530, 546, 562, 569,
            536, 561, 507, 563, 575, 505, 532, 504, 551, 548, 573, 593, 539,
            567, 529, 549, 557, 533,
            668, 656, 678, 608, 623, 684, 690, 665, 674, 676, 640, 689, 603,
            692, 655, 609, 626, 680, 643, 638, 658, 670, 677, 601, 685, 619,
            617, 650, 628, 653, 613, 681, 645, 682, 606, 659, 683, 616, 615,
            644, 691, 641, 672, 660, 679, 652, 620, 610, 631, 654, 637, 695,
            614, 671, 696, 698, 697, 602, 664, 666, 642, 622, 635, 686, 624,
            634, 687, 621, 699, 600, 688, 627, 618, 694, 611, 612, 647, 625,
            630, 646, 662, 669, 636, 661, 607, 663, 675, 605, 632, 604, 651,
            648, 673, 693, 639, 667, 629, 649, 657, 633, 768, 756, 778, 708,
            723, 784, 790, 765, 774, 776, 740, 789, 703, 792, 755, 709, 726,
            780, 743, 738, 758, 770, 777, 701, 785, 719, 717, 750, 728, 753,
            713, 781, 745, 782, 706, 759, 783, 716, 715, 744, 791, 741, 772,
            760, 779, 752, 720, 710, 731, 754, 737, 795, 714, 771, 796, 798,
            797, 702, 764, 766, 742, 722, 735, 786, 724, 734, 787, 721, 799,
            700, 788, 727, 718, 794, 711, 712, 747, 725, 730, 746, 762, 769,
            736, 761, 707, 763, 775, 705, 732, 704, 751, 748, 773, 793, 739,
            767, 729, 749, 757, 733,
            868, 856, 878, 808, 823, 884, 890, 865, 874, 876, 840, 889, 803,
            892, 855, 809, 826, 880, 843, 838, 858, 870, 877, 801, 885, 819,
            817, 850, 828, 853, 813, 881, 845, 882, 806, 859, 883, 816, 815,
            844, 891, 841, 872, 860, 879, 852, 820, 810, 831, 854, 837, 895,
            814, 871, 896, 898, 897, 802, 864, 866, 842, 822, 835, 886, 824,
            834, 887, 821, 899, 800, 888, 827, 818, 894, 811, 812, 847, 825,
            830, 846, 862, 869, 836, 861, 807, 863, 875, 805, 832, 804, 851,
            848, 873, 893, 839, 867, 829, 849, 857, 833, 968, 956, 978, 908,
            923, 984, 990, 965, 974, 976, 940, 989, 903, 992, 955, 909, 926,
            980, 943, 938, 958, 970, 977, 901, 985, 919, 917, 950, 928, 953,
            913, 981, 945, 982, 906, 959, 983, 916, 915, 944, 991, 941, 972,
            960, 979, 952, 920, 910, 931, 954, 937, 995, 914, 971, 996, 998,
            997, 902, 964, 966, 942, 922, 935, 986, 924, 934, 987, 921, 999,
            900, 988, 927, 918, 994, 911, 912, 947, 925, 930, 946, 962, 969,
            936, 961, 907, 963, 975, 905, 932, 904, 951, 948, 973, 993, 939,
            967, 929, 949, 957, 933
        ]
================================================
FILE: examples/Structural_Development/SCA-SNN/inclearn/models/__init__.py
================================================
from .incmodel import IncModel
================================================
FILE: examples/Structural_Development/SCA-SNN/inclearn/models/base.py
================================================
import abc
import logging
import torch
import torch.nn.functional as F
import numpy as np
from inclearn.tools.metrics import ClassErrorMeter
# Use getLogger so the logger joins the logging hierarchy (handlers and
# configuration applied via basicConfig/logging.config reach it).
# Instantiating logging.Logger directly creates a detached logger that no
# global logging setup can configure.
LOGGER = logging.getLogger("IncLearn")
LOGGER.setLevel(logging.INFO)
class IncrementalLearner(abc.ABC):
    """Base incremental learner.

    Methods are called in this order (& repeated for each new task):

    1. set_task_info
    2. before_task
    3. train_task
    4. after_task
    5. eval_task

    Each public phase logs, switches train/eval mode, then delegates to a
    ``_``-prefixed hook that subclasses override.
    """

    def __init__(self, *args, **kwargs):
        self._increments = []    # task sizes seen so far
        self._seen_classes = []

    def set_task_info(self, task, total_n_classes, increment, n_train_data, n_test_data, n_tasks):
        """Record the metadata of the upcoming task."""
        self._task = task
        self._task_size = increment
        self._increments.append(self._task_size)
        self._total_n_classes = total_n_classes
        self._n_train_data = n_train_data
        self._n_test_data = n_test_data
        self._n_tasks = n_tasks

    def before_task(self, taski, inc_dataset, mask, min_dist, all_dist):
        LOGGER.info("Before task")
        self.eval()
        self._before_task(taski, inc_dataset, mask, min_dist, all_dist)

    def train_task(self, task_i, train_loader, val_loader, mask, min_dist, all_dist):
        LOGGER.info("train task")
        self.train()
        self._train_task(task_i, train_loader, val_loader, mask, min_dist, all_dist)

    def after_task(self, taski, inc_dataset, mask):
        LOGGER.info("after task")
        self.eval()
        self._after_task(taski, inc_dataset, mask)

    def eval_task(self, task_i, data_loader, mask):
        LOGGER.info("eval task")
        self.eval()
        return self._eval_task(task_i, data_loader, mask)

    def get_memory(self):
        """Return the rehearsal memory, if any (None by default)."""
        return None

    def eval(self):
        raise NotImplementedError

    def train(self):
        raise NotImplementedError

    # Bug fix: the hook signatures below now mirror exactly how the template
    # methods above invoke them. The previous stubs declared fewer parameters
    # (e.g. ``_before_task(self, data_loader)``) and would have raised
    # TypeError had a subclass relied on the default implementation.
    def _before_task(self, taski, inc_dataset, mask, min_dist, all_dist):
        pass

    def _train_task(self, task_i, train_loader, val_loader, mask, min_dist, all_dist):
        raise NotImplementedError

    def _after_task(self, taski, inc_dataset, mask):
        pass

    def _eval_task(self, task_i, data_loader, mask):
        raise NotImplementedError

    @property
    def _new_task_index(self):
        # First class index of the current task (assumes equal task sizes).
        return self._task * self._task_size

    @property
    def _memory_per_class(self):
        """Returns the number of examplars per class."""
        return self._memory_size.mem_per_cls

    def _after_epoch(self, epoch, avg_loss, train_new_accu, train_old_accu, accu):
        """Log per-epoch training loss/accuracy to sacred and tensorboard."""
        self._run.log_scalar(f"train_loss_trial{self._trial_i}_task{self._task}", avg_loss, epoch + 1)
        self._tensorboard.add_scalar(f"trial{self._trial_i}_task{self._task}/train_loss", avg_loss, epoch + 1)
        self._run.log_scalar(f"train_accu_trial{self._trial_i}_task{self._task}", accu.value()[0], epoch + 1)
        self._tensorboard.add_scalar(f"trial{self._trial_i}_task{self._task}/train_accu", accu.value()[0], epoch + 1)
        self._tensorboard.flush()

    def _validation(self, val_loader, epoch):
        """Evaluate on ``val_loader`` every ``_val_per_n_epoch`` epochs and log
        overall / new-class / old-class accuracy plus the validation loss."""
        topk = 5 if self._n_classes >= 5 else self._n_classes
        if self._val_per_n_epoch != -1 and epoch % self._val_per_n_epoch == 0:
            _val_loss = 0
            _val_accu = ClassErrorMeter(accuracy=True, topk=[1, topk])
            _val_new_accu = ClassErrorMeter(accuracy=True)
            _val_old_accu = ClassErrorMeter(accuracy=True)
            self._parallel_network.eval()
            with torch.no_grad():
                for i, (inputs, targets) in enumerate(val_loader, 1):
                    # Samples below the current task's first class are "old".
                    old_classes = targets < (self._n_classes - self._task_size)
                    new_classes = targets >= (self._n_classes - self._task_size)
                    val_loss, _ = self._forward_loss(
                        inputs,
                        targets,
                        old_classes,
                        new_classes,
                        accu=_val_accu,
                        old_accu=_val_old_accu,
                        new_accu=_val_new_accu,
                    )
                    _val_loss += val_loss.item()
            self._ex.logger.info(
                f"epoch{epoch} val acc:{_val_accu.value()[0]:.2f}, val top5acc:{_val_accu.value()[1]:.2f}")
            # Test accu
            self._run.log_scalar(f"test_accu_trial{self._trial_i}_task{self._task}", _val_accu.value()[0], epoch + 1)
            self._run.log_scalar(f"test_5accu_trial{self._trial_i}_task{self._task}", _val_accu.value()[1], epoch + 1)
            self._tensorboard.add_scalar(f"trial{self._trial_i}_task{self._task}/test_accu",
                                         _val_accu.value()[0], epoch + 1)
            self._tensorboard.add_scalar(f"trial{self._trial_i}_task{self._task}/test_5accu",
                                         _val_accu.value()[1], epoch + 1)
            # Test new accu
            self._run.log_scalar(f"test_new_accu_trial{self._trial_i}_task{self._task}",
                                 _val_new_accu.value()[0], epoch + 1)
            self._tensorboard.add_scalar(f"trial{self._trial_i}_task{self._task}/test_new_accu",
                                         _val_new_accu.value()[0], epoch + 1)
            # Test old accu (only meaningful after the first task)
            if self._task != 0:
                self._run.log_scalar(f"test_old_accu_trial{self._trial_i}_task{self._task}",
                                     _val_old_accu.value()[0], epoch + 1)
                self._tensorboard.add_scalar(f"trial{self._trial_i}_task{self._task}/test_old_accu",
                                             _val_old_accu.value()[0], epoch + 1)
            # Test loss
            self._run.log_scalar(f"test_loss_trial{self._trial_i}_task{self._task}", round(_val_loss / i, 3), epoch + 1)
            self._tensorboard.add_scalar(f"trial{self._trial_i}_task{self._task}/test_loss", round(_val_loss / i, 3),
                                         epoch + 1)
            self._tensorboard.close()
================================================
FILE: examples/Structural_Development/SCA-SNN/inclearn/models/incmodel.py
================================================
import numpy as np
import random
import time
import math
import os
from copy import deepcopy
from scipy.spatial.distance import cdist
from torchvision.utils import save_image
import torch
# import pdb
from torch.nn import DataParallel
from torch.nn import functional as F
from torch import nn
from inclearn.convnet import network
from inclearn.models.base import IncrementalLearner
from inclearn.tools import factory, utils
from inclearn.tools.metrics import ClassErrorMeter
from inclearn.tools.memory import MemorySize
from inclearn.tools.scheduler import GradualWarmupScheduler
from inclearn.convnet.utils import extract_features, update_classes_mean, finetune_last_layer
# Constants
EPSILON = 1e-8
class IncModel(IncrementalLearner):
    """Incremental learner combining a dynamically expandable network,
    knowledge distillation, and an exemplar memory (SCA-SNN)."""

    def __init__(self, cfg, trial_i, _run, ex, tensorboard, inc_dataset):
        """
        :param cfg: experiment configuration mapping (sacred config).
        :param trial_i: which class order is used for this trial.
        :param _run: the sacred ``_run`` object, used for metric logging.
        :param ex: the sacred experiment; provides ``ex.logger``.
        :param tensorboard: tensorboard summary writer.
        :param inc_dataset: incremental dataset wrapper.
        """
        super().__init__()
        self._cfg = cfg
        self._device = cfg['device']
        self._ex = ex
        self._run = _run  # the sacred _run object.
        # Data
        self._inc_dataset = inc_dataset
        self._n_classes = 0
        self.classnum_list = []  # number of classes added per task
        self.sample_list = []    # per-class exemplar budget (set in _before_task)
        self._trial_i = trial_i  # which class order is used
        # Optimizer paras
        self._opt_name = cfg["optimizer"]
        self._warmup = cfg['warmup']
        self._lr = cfg["lr"]
        self._weight_decay = cfg["weight_decay"]
        self._n_epochs = cfg["epochs"]
        self._scheduling = cfg["scheduling"]
        self._lr_decay = cfg["lr_decay"]
        # NOTE(review): ``torc`` mirrors the distillation flag; its intended
        # meaning is unclear from here — confirm before renaming.
        self.torc=cfg['distillation']
        self.prune = cfg.get('prune', False)
        # Logging
        self._tensorboard = tensorboard
        if f"trial{self._trial_i}" not in self._run.info:
            self._run.info[f"trial{self._trial_i}"] = {}
        self._val_per_n_epoch = cfg["val_per_n_epoch"]
        # Model
        self._dea = cfg['dea']  # Whether to expand the representation
        self._network = network.BasicNet(
            cfg["convnet"],
            cfg=cfg,
            nf=cfg["channel"],
            device=self._device,
            use_bias=cfg["use_bias"],
            dataset=cfg["dataset"],
        )
        # When only counting parameters, skip the DataParallel wrapper.
        if self._cfg.get("caculate_params", False):
            self._parallel_network = self._network
        else:
            # parallel computation
            # gpus = [0, 1, 2, 3]
            # self._parallel_network = DataParallel(self._network, device_ids=gpus, output_device=gpus[0])
            self._parallel_network = DataParallel(self._network)
        self._train_head = cfg["train_head"]
        self._infer_head = cfg["infer_head"]
        self._old_model = None  # frozen copy of the previous-task model
        # Learning
        self._temperature = cfg["temperature"]
        self._distillation = cfg["distillation"]
        self.lamb = cfg["distlamb"]
        # Memory
        self._memory_size = MemorySize(cfg["mem_size_mode"], inc_dataset, cfg["memory_size"],
                                       cfg["fixed_memory_per_cls"])
        self._herding_matrix = []
        self._coreset_strategy = cfg["coreset_strategy"]
        # Create checkpoint/memory directories up front if saving is enabled.
        if self._cfg["save_ckpt"]:
            save_path = os.path.join(os.getcwd(), f"{self._cfg.exp.saveckpt}")
            if not os.path.exists(save_path):
                os.mkdir(save_path)
            if self._cfg["save_mem"]:
                save_path = os.path.join(os.getcwd(), f"{self._cfg.exp.saveckpt}/mem")
                if not os.path.exists(save_path):
                    os.mkdir(save_path)
def eval(self):
self._parallel_network.eval()
def train(self):
if self._dea:
self._parallel_network.train()
self._parallel_network.module.convnets[-1].train()
if self._task >= 1:
for i in range(self._task):
self._parallel_network.module.convnets[i].eval()
else:
self._parallel_network.train()
def _before_task(self, taski, inc_dataset,mask,min_dist,all_dist):
    """Prepare the model for task ``taski``: update class counts, refresh the
    exemplar-memory bookkeeping, grow the network by one branch, bind the
    pruning mask to the new branch, and build a fresh optimizer.

    :param taski: index of the task about to start.
    :param inc_dataset: incremental dataset wrapper (not used directly here).
    :param mask: pruning/masking helper, re-pointed at the newest sub-network.
    :param min_dist: task-similarity distance forwarded to ``add_classes``.
    :param all_dist: distances to all previous tasks (unused in this hook).
    """
    self._ex.logger.info(f"Begin step {taski}")
    # Update Task info
    self._task = taski
    self._n_classes += self._task_size
    self.classnum_list.append(self._task_size)
    # Exemplar budget: 2000 samples spread over the old classes plus 500 for
    # each of the 10 newest classes. NOTE(review): divides by
    # (_n_classes - 10), which raises ZeroDivisionError when _n_classes == 10
    # — presumably the configuration guarantees more than 10 classes by the
    # time this runs; confirm.
    self.sample_list = [ int(2000/(self._n_classes-10)) for i in range(self._n_classes-10)] + [ 500 for i in range(10)]
    # Memory
    self._memory_size.update_n_classes(self._n_classes)
    self._memory_size.update_memory_per_cls(self._network, self._n_classes, self._task_size)
    self._ex.logger.info("Now {} examplars per class.".format(self._memory_per_class))
    # Grow the network with a head/branch sized for the new task.
    self._network.add_classes(self._task_size,min_dist)
    self._network.task_size = self._task_size
    # Point the mask helper at the newest convnet before (re)initializing it.
    mask.model=self._network.convnets[-1]
    mask.init_length(taski,task_nn=self._network.task_nn)
    self.set_optimizer()
def set_optimizer(self, lr=None):
    """(Re)create the optimizer and LR scheduler for the newest sub-network.

    :param lr: learning-rate override; defaults to the configured ``lr``.
    """
    if lr is None:
        lr = self._lr
    if self._cfg["dynamic_weight_decay"]:
        # used in BiC official implementation
        weight_decay = self._weight_decay * self._cfg["task_max"] / (self._task + 1)
    else:
        weight_decay = self._weight_decay
    self._ex.logger.info("Step {} weight decay {:.5f}".format(self._task, weight_decay))
    # if self._dea and self._task > 0 and not self._cfg.get("caculate_params", False):
    #     for i in range(self._task):
    #         for p in self._parallel_network.module.convnets[i].parameters():
    #             p.requires_grad = False
    # Only the newest convnet's parameters are optimized.
    self._optimizer = factory.get_optimizer(self._network.convnets[-1].parameters(),
                                            self._opt_name, lr, weight_decay)
    # Scheduler: cosine annealing if requested, otherwise multi-step decay.
    if "cos" in self._cfg["scheduler"]:
        self._scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self._optimizer, self._n_epochs)
    else:
        self._scheduler = torch.optim.lr_scheduler.MultiStepLR(self._optimizer,
                                                               self._scheduling,
                                                               gamma=self._lr_decay)
    if self._warmup:
        print("warmup")
        # Wrap the scheduler with a linear warmup phase.
        self._warmup_scheduler = GradualWarmupScheduler(self._optimizer,
                                                        multiplier=1,
                                                        total_epoch=self._cfg['warmup_epochs'],
                                                        after_scheduler=self._scheduler)
def _train_task(self, task_i,train_loader, val_loader,mask,min_dist,all_dist):
    """Train the network on task ``task_i`` for ``self._n_epochs`` epochs,
    combining classification, divergence, and distillation losses, and
    applying the task-similarity-driven pruning mask after each step.

    :param task_i: index of the current task.
    :param train_loader: DataLoader over the current training set.
    :param val_loader: DataLoader used for periodic validation.
    :param mask: pruning/masking helper.
    :param min_dist: minimum task distance (unused directly here).
    :param all_dist: distances to all previous tasks, forwarded to the mask.
    """
    self._ex.logger.info(f"nb {len(train_loader.dataset)}")
    topk = 5 if self._n_classes > 5 else self._task_size
    accu = ClassErrorMeter(accuracy=True, topk=[1, topk])
    train_new_accu = ClassErrorMeter(accuracy=True)
    train_old_accu = ClassErrorMeter(accuracy=True)
    # Dummy optimizer step so schedulers may call .step() before the first
    # real update without a PyTorch ordering warning.
    self._optimizer.zero_grad()
    self._optimizer.step()
    for epoch in range(self._n_epochs):
        # torch.cuda.empty_cache()
        _loss, _loss_div, _loss_trip, _loss_dist, _loss_atmap = 0.0, 0.0, 0.0, 0.0, 0.0
        accu.reset()
        train_new_accu.reset()
        train_old_accu.reset()
        if self._warmup:
            self._warmup_scheduler.step()
            if epoch == self._cfg['warmup_epochs']:
                # Warmup finished: re-initialize the classifier head(s).
                if self.torc:
                    self._network.convnets[-1].classifer.reset_parameters()
                    if self._cfg['use_div_cls']:
                        self._network.convnets[-1].aux_classifier.reset_parameters()
                else:
                    self._network.convnets[task_i].classifer.reset_parameters()
                    if self._cfg['use_div_cls']:
                        self._network.aux_classifier[task_i].reset_parameters()
        for i, (inputs, targets) in enumerate(train_loader, start=1):
            self.train()
            self._optimizer.zero_grad()
            # Split the batch into samples of old vs. current-task classes.
            old_classes = targets < (self._n_classes - self._task_size)
            new_classes = targets >= (self._n_classes - self._task_size)
            loss_ce, loss_div, loss_trip, loss_dist, loss_atmap = self._forward_loss(
                task_i,
                inputs,
                targets,
                old_classes,
                new_classes,
                epoch,
                accu=accu,
                new_accu=train_new_accu,
                old_accu=train_old_accu,
                mask=mask
            )
            loss = loss_ce
            if self._cfg["distillation"] and self._task > 0:
                # trade-off - the lambda from the paper if lamb=-1
                if self.lamb == -1:
                    lamb = (self._n_classes - self._task_size) / self._n_classes
                    loss = (1-lamb) * loss + lamb * loss_dist
                else:
                    loss = loss + self.lamb * loss_dist
            if self._cfg["use_div_cls"] and self._task > 0:
                loss += loss_div
            loss.backward()
            self._optimizer.step()
            if self.torc:
                if self._cfg["postprocessor"]["enable"]:
                    if self._cfg["postprocessor"]["type"].lower() == "cr" or self._cfg["postprocessor"]["type"].lower() == "aver":
                        # Keep classifier weights non-negative for these postprocessors.
                        for p in self._network.convnets[-1].classifer.parameters():
                            p.data.clamp_(0.0)
            _loss += loss_ce
            _loss_trip += loss_trip
            _loss_div += loss_div
            _loss_dist += loss_dist
            _loss_atmap += loss_atmap
            if task_i>0:
                # Refresh and apply the similarity-driven pruning mask.
                mask.init_mask(self._task,epoch,dim_cur=self._network.channel_dim, task_nn=self._network.task_nn,all_dist=all_dist,all_model=self._network.convnets)
                mat=mask.do_mask(self._task)
        _loss = _loss.item()
        _loss_div = _loss_div.item()
        _loss_trip = _loss_trip.item()
        _loss_dist = _loss_dist.item()
        _loss_atmap = _loss_atmap.item()
        if not self._warmup:
            self._scheduler.step()
        self._ex.logger.info(
            "Task {}/{}, Epoch {}/{} => Clf loss: {} Div loss: {}, Knowledge Distllation loss:{}, Train Accu: {}, Train@5 Acc: {}".
            format(
                self._task + 1,
                self._n_tasks,
                epoch + 1,
                self._n_epochs,
                round(_loss / i, 3),
                round(_loss_div / i, 3),
                round(_loss_dist / i, 3),
                round(accu.value()[0], 3),
                round(accu.value()[1], 3),
            ))
        if self._val_per_n_epoch > 0 and epoch % self._val_per_n_epoch == 0:
            self.validate(val_loader)
    if self.torc:
        # For the large-scale dataset, we manage the data in the shared memory.
        self._inc_dataset.shared_data_inc = train_loader.dataset.share_memory
    utils.display_weight_norm(self._ex.logger, self._parallel_network, self._increments, "After training")
    utils.display_feature_norm(task_i,self._ex.logger, self._parallel_network, train_loader, self._n_classes,
                               self._increments, "Trainset",mask=mask)
    self._run.info[f"trial{self._trial_i}"][f"task{self._task}_train_accu"] = round(accu.value()[0], 3)
def _forward_loss(self, task_i,inputs, targets, old_classes, new_classes, epoch, accu=None, new_accu=None, old_accu=None,mask=None):
    """Run one forward pass and compute the per-batch losses.

    Moves the batch to the model device, feeds it through the (parallel)
    network, optionally updates the running accuracy meter, and delegates
    the loss computation to ``_compute_loss``.
    """
    device = self._device
    inputs = inputs.to(device, non_blocking=True)
    targets = targets.to(device, non_blocking=True)
    outputs = self._parallel_network(task_i, inputs, mask)
    if accu is not None:
        accu.add(outputs['logit'], targets)
    return self._compute_loss(task_i, inputs, targets, outputs, old_classes, new_classes, epoch, mask=mask)
def cross_entropy(self, outputs, targets, exp=1.0, size_average=True, eps=1e-5):
    """Calculates cross-entropy with temperature scaling.

    Both ``outputs`` and ``targets`` are logits; each is converted to a
    probability distribution, optionally sharpened by raising to ``exp`` and
    renormalizing, and the prediction is smoothed by ``eps`` before taking
    the cross entropy.
    """
    probs = torch.nn.functional.softmax(outputs, dim=1)
    ref = torch.nn.functional.softmax(targets, dim=1)
    if exp != 1:
        # Sharpen (or soften) both distributions, then renormalize.
        probs = probs.pow(exp)
        probs = probs / probs.sum(1).view(-1, 1).expand_as(probs)
        ref = ref.pow(exp)
        ref = ref / ref.sum(1).view(-1, 1).expand_as(ref)
    # Additive smoothing avoids log(0) on the prediction side.
    probs = probs + eps / probs.size(1)
    probs = probs / probs.sum(1).view(-1, 1).expand_as(probs)
    ce = -(ref * probs.log()).sum(1)
    if size_average:
        return ce.mean()
    return ce
def hcl(self, fstudent, fteacher, targets):
loss_all = 0.0
fs = fstudent
select_teacher = self._cfg.get("select_teacher",False)
if select_teacher:
for i in range(len(fteacher)):
ft = fteacher[i]
if i > 0:
old_classes = np.logical_and((targets < (self._n_classes - self._task_size * (len(fteacher)-i))).cpu(), (targets >= (self._n_classes - self._task_size * (len(fteacher)-i+1))).cpu())
else:
old_classes = (targets < (self._n_classes - self._task_size * len(fteacher))).cpu()
classes_indice = torch.from_numpy(np.where(old_classes==True)[0]).to(self._device)
# targets_old = torch.index_select(targets_old, 0, old_classes_indice)
# log_probs_new = torch.index_select(log_probs_new, 0, old_classes_indice)
fs = torch.index_select(fstudent, 0, classes_indice)
ft = torch.index_select(ft, 0, classes_indice)
n,c,h,w = fs.shape
if n == 0:
break
loss = F.mse_loss(fs, ft, reduction='mean')
cnt = 1.0
tot = 1.0
for l in [4,2,1]:
if l >=h:
continue
tmpfs = F.adaptive_avg_pool2d(fs, (l,l))
tmpft = F.adaptive_avg_pool2d(ft, (l,l))
cnt /= 2.0
loss += F.mse_loss(tmpfs, tmpft, reduction='mean') * cnt
tot += cnt
loss = loss / tot
loss_all = loss_all + loss
else:
for i in range(len(fteacher)):
ft = fteacher[i]
n,c,h,w = fs.shape
if n == 0:
break
loss = F.mse_loss(fs, ft, reduction='mean')
cnt = 1.0
tot = 1.0
for l in [4,2,1]:
if l >=h:
continue
tmpfs = F.adaptive_avg_pool2d(fs, (l,l))
tmpft = F.adaptive_avg_pool2d(ft, (l,l))
cnt /= 2.0
loss += F.mse_loss(tmpfs, tmpft, reduction='mean') * cnt
tot += cnt
loss = loss / tot
loss_all = loss_all + loss
return loss_all
def _compute_loss(self, task_i, inputs, targets, outputs, old_classes, new_classes, epoch,mask=None):
    """Compute the classification, divergence, and distillation losses.

    :param task_i: index of the current task.
    :param inputs: batch inputs already on the model device.
    :param targets: batch labels already on the model device.
    :param outputs: network output dict with 'logit' and optional 'div_logit'.
    :param old_classes: boolean mask over the batch for pre-current-task labels.
    :param new_classes: boolean mask over the batch for current-task labels.
    :param epoch: current epoch (unused here).
    :param mask: pruning mask (unused here; the old model runs unmasked).
    :return: (ce_loss, div_loss, trip_loss, dist_loss, atmap_loss); the
        triplet and attention-map losses are placeholders (always zero).
    """
    loss = F.cross_entropy(outputs['logit'], targets)
    # Placeholder losses, kept for the fixed 5-tuple return signature.
    trip_loss = torch.zeros([1]).cuda()
    atmap_loss = torch.zeros([1]).cuda()
    if outputs['div_logit'] is not None:
        # Remap labels for the auxiliary (divergence) classifier head.
        div_targets = targets.clone()
        if self._cfg["div_type"] == "n+1":
            # All old classes collapse to 0; new classes map to 1..task_size.
            div_targets[old_classes] = 0
            div_targets[new_classes] -= sum(self._inc_dataset.increments[:self._task]) - 1
        elif self._cfg["div_type"] == "1+1":
            # Binary old-vs-new target.
            div_targets[old_classes] = 0
            div_targets[new_classes] = 1
        elif self._cfg["div_type"] == "n+t":
            # New classes keep fine labels (shifted); each old task collapses
            # to its task id.
            div_targets[new_classes] -= sum(self._inc_dataset.increments[:self._task]) - self._task
            for i in range(self._task):
                if i > 0:
                    old_class = np.logical_and((targets < (self._n_classes - self._task_size * (self._task-i))).cpu(), (targets >= (self._n_classes - self._task_size * (self._task-i+1))).cpu())
                else:
                    old_class = (targets < (self._n_classes - self._task_size * self._task)).cpu()
                div_targets[old_class] = i
        # import pdb
        # pdb.set_trace()
        div_loss = F.cross_entropy(outputs['div_logit'], div_targets)
    else:
        div_loss = torch.zeros([1]).cuda()
    if self._cfg["distillation"] and self._old_model is not None:
        # Distillation target: logits of the frozen previous-task model.
        outputs_old = self._old_model(task_i-1,inputs,mask=None)
        targets_old = outputs_old['logit'].detach()
        if self._cfg["disttype"] == "KL":
            # KL between temperature-softened new (old-class slice) and old logits.
            log_probs_new = (outputs['logit'][:, :-self._task_size] / self._temperature).log_softmax(dim=1)
            if self._task > 1 and self._cfg["postprocessor"]["enable"]:
                if self._cfg["postprocessor"]["type"].lower() == "aver":
                    targets_old = self._old_model.module.postprocessor.post_process(targets_old, self._task_size, self.classnum_list[:-1], self._task-1)
                else:
                    targets_old = self._old_model.module.postprocessor.post_process(targets_old, self._task_size)
            modify = self._cfg.get("modify_new",False)
            if modify:
                # Rescale old-model logits on new-class samples by the ratio of
                # old/new classifier weight norms.
                old_weight_norm = torch.norm(self._network.convnets[-1].classifer.weight[:-self._task_size], p=2, dim=1)
                new_weight_norm = torch.norm(self._network.convnets[-1].classifer.weight[-self._task_size:], p=2, dim=1)
                gamma = old_weight_norm.mean() / new_weight_norm.mean()
                targets_old[new_classes,:] = targets_old[new_classes,:] * gamma
            probs_old = (targets_old / self._temperature).softmax(dim=1)
            dist_loss = F.kl_div(log_probs_new, probs_old, reduction="batchmean")
        else:
            # Soft cross-entropy variant with temperature scaling.
            dist_loss = self.cross_entropy(outputs['logit'][:, :-self._task_size], targets_old, exp=1.0 / self._temperature)
    else:
        dist_loss = torch.zeros([1]).cuda()
    return loss, div_loss, trip_loss, dist_loss, atmap_loss
def _after_task(self, taski, inc_dataset,mask=None):
    """Post-task bookkeeping: optional checkpointing, postprocessor
    calibration, prototype/exemplar-memory update, and freezing a copy of
    the network as the distillation teacher for the next task."""
    network = deepcopy(self._parallel_network)
    network.eval()
    if self._cfg["save_ckpt"] and taski >= self._cfg["start_task"] and not self.prune:
        self._ex.logger.info("save model")
        save_path = os.path.join(os.getcwd(), f"{self._cfg.exp.saveckpt}")
        torch.save(network.cpu().state_dict(), "{}/step{}.ckpt".format(save_path, self._task))
    if self.torc:
        if self._cfg["postprocessor"]["enable"]:
            self._update_postprocessor(taski,inc_dataset,mask=mask)
    if self._cfg["infer_head"] == 'NCM':
        self._ex.logger.info("compute prototype")
        self.update_prototype()
    if self._memory_size.memsize != 0:
        self._ex.logger.info("build memory")
        self.build_exemplars(taski,inc_dataset, self._coreset_strategy,mask=mask)
        if self._cfg["save_mem"]:
            save_path = os.path.join(os.getcwd(), f"{self._cfg.exp.saveckpt}/mem")
            memory = {
                'x': inc_dataset.data_memory,
                'y': inc_dataset.targets_memory,
                'herding': self._herding_matrix
            }
            if not os.path.exists(save_path):
                os.makedirs(save_path)
            # Don't overwrite an existing memory checkpoint we intend to load.
            if not (os.path.exists(f"{save_path}/mem_step{self._task}.ckpt") and self._cfg['load_mem']):
                torch.save(memory, "{}/mem_step{}.ckpt".format(save_path, self._task))
                self._ex.logger.info(f"Save step{self._task} memory!")
    # utils.display_weight_norm(self._ex.logger, self._parallel_network, self._increments, "After training")
    self._parallel_network.eval()
    # Keep a frozen copy as the distillation teacher for the next task.
    self._old_model = deepcopy(self._parallel_network)
    if not self._cfg.get("caculate_params", False):
        self._old_model.module.freeze()
    # Release the shared decoded-image cache.
    del self._inc_dataset.shared_data_inc
    self._inc_dataset.shared_data_inc = None
def _eval_task(self,task_i, data_loader,mask):
# if self._cfg.get("caculate_params", False):
# from thop import profile
# self._parallel_network.eval()
# with torch.no_grad():
# input = torch.randn(1, 3, 256, 256).to(self._device, non_blocking=True)
# flops, params = profile(self._parallel_network, inputs=(input,))
# ypred = flops/1000**3
# ytrue = params/1000**2
# from torchstat import stat
# stat(self._parallel_network, (3, 256, 256))
# ypred,ytrue = 0,0
# else:
if self._infer_head == "softmax":
ypred, ytrue = self._compute_accuracy_by_netout(task_i,data_loader,mask)
elif self._infer_head == "NCM":
ypred, ytrue = self._compute_accuracy_by_ncm(data_loader)
else:
raise ValueError()
return ypred, ytrue
def _compute_accuracy_by_netout(self, task_i, data_loader, mask):
    """Run the softmax head over ``data_loader`` and collect logits/labels.

    Returns:
        (preds, targets): stacked numpy arrays of logits and integer labels.
        Special case: when cfg["caculate_params"] is set, returns
        (GFLOPs, M-params) measured with thop on a dummy input instead.
    """
    preds, targets = [], []
    self._parallel_network.eval()
    if self._cfg.get("caculate_params", False):
        # Profiling mode: report model complexity rather than accuracy data.
        with torch.no_grad():
            from thop import profile
            inputs = torch.randn(1, 3, 112, 112)
            flops, params = profile(self._parallel_network, (inputs,))
            preds = flops / 1000**3    # FLOPs in G
            targets = params / 1000**2  # parameters in M
    else:
        with torch.no_grad():
            for i, (inputs, lbls) in enumerate(data_loader):
                inputs = inputs.to(self._device, non_blocking=True)
                _preds = self._parallel_network(task_i, inputs, mask)['logit']
                if self.torc:
                    # Apply the configured logit post-processor after task 0.
                    if self._cfg["postprocessor"]["enable"] and self._task > 0:
                        if self._cfg["postprocessor"]["type"].lower() == "aver":
                            _preds = self._network.postprocessor.post_process(_preds, self._task_size, self.classnum_list, self._task)
                        else:
                            _preds = self._network.postprocessor.post_process(_preds, self._task_size)
                preds.append(_preds.detach().cpu().numpy())
                targets.append(lbls.long().cpu().numpy())
        preds = np.concatenate(preds, axis=0)
        targets = np.concatenate(targets, axis=0)
    return preds, targets
def _compute_accuracy_by_ncm(self, loader):
    """Score samples with the nearest-class-mean (iCaRL) head.

    Features and class means are L2-normalized, then each sample is scored
    by the negative squared Euclidean distance to every class mean.

    Fix: removed a dead computation — the one-hot ``targets`` matrix was
    built on every call but never used (the raw integer labels are what is
    returned).

    Returns:
        (scores, targets_): per-class scores of shape (N, n_classes) and the
        raw labels from ``extract_features``.
    """
    features, targets_ = extract_features(self._parallel_network, loader)
    # L2-normalize per sample (transpose so axis=0 is the feature axis).
    class_means = (self._class_means.T / (np.linalg.norm(self._class_means.T, axis=0) + EPSILON)).T
    features = (features.T / (np.linalg.norm(features.T, axis=0) + EPSILON)).T
    # iCaRL score: negated squared Euclidean distance to each class mean.
    sqd = cdist(class_means, features, "sqeuclidean")
    score_icarl = (-sqd).T
    return score_icarl[:, :self._n_classes], targets_
def _update_postprocessor(self, taski, inc_dataset, mask=None):
    """Refit the logit post-processor ("bic", "cr" or "aver") after a task.

    - "bic": trains the bias-correction layer on a balanced subset of the
      training data seen so far.
    - "cr" / "aver": recalibrate directly from the newest classifier weights.
    """
    if self._cfg["postprocessor"]["type"].lower() == "bic":
        if False:  # self._cfg["postprocessor"]["disalign_resample"] is True: (disabled path kept for reference)
            bic_loader = inc_dataset._get_loader(inc_dataset.data_inc,
                                                 inc_dataset.targets_inc,
                                                 mode="train",
                                                 resample='disalign_resample')
        else:
            # Select all samples of the classes seen so far [0, n_classes).
            xdata, ydata = inc_dataset._select(taski, inc_dataset.data_train,
                                               inc_dataset.targets_train,
                                               low_range=0,
                                               high_range=self._n_classes)
            bic_loader = inc_dataset._get_loader(xdata, ydata, shuffle=True, mode='train')
        bic_loss = None
        self._network.postprocessor.reset(n_classes=self._n_classes)
        self._network.postprocessor.update(self._ex.logger,
                                           self._task_size,
                                           self._parallel_network,
                                           bic_loader,
                                           loss_criterion=bic_loss,
                                           taski=taski,
                                           mask=mask)
    elif self._cfg["postprocessor"]["type"].lower() == "cr":
        self._ex.logger.info("Post processor cr update !")
        self._network.postprocessor.update(self._network.convnets[-1].classifer, self._task_size)
    elif self._cfg["postprocessor"]["type"].lower() == "aver":
        self._ex.logger.info("Post processor aver update !")
        self._network.postprocessor.update(self._network.convnets[-1].classifer, self._task_size, self.classnum_list, self._task)
def update_prototype(self):
    """Recompute per-class mean features (NCM prototypes).

    Fix: the original guarded ``shared_data_inc`` with ``hasattr`` but then
    read the attribute directly in the call below, so the guard's local was
    dead and a missing attribute would still raise AttributeError. We now
    read it once with ``getattr`` and pass that value through.
    """
    # None when the dataset has no pre-loaded shared data.
    shared_data_inc = getattr(self._inc_dataset, 'shared_data_inc', None)
    self._class_means = update_classes_mean(self._parallel_network,
                                            self._inc_dataset,
                                            self._n_classes,
                                            self._task_size,
                                            share_memory=shared_data_inc,
                                            metric='None')
def build_exemplars(self, task_i, inc_dataset, coreset_strategy, mask=None):
    """Construct (or reload from disk) the exemplar memory for this step.

    Args:
        task_i: current task index, forwarded to the selection routine.
        inc_dataset: incremental dataset to draw exemplars from.
        coreset_strategy: "random" or "iCaRL" (herding).
        mask: optional mask forwarded to feature extraction.
    """
    save_path = os.path.join(os.getcwd(), f"{self._cfg.exp.saveckpt}/mem/mem_step{self._task}.ckpt")
    if self._cfg["load_mem"] and os.path.exists(save_path):
        # Reuse a previously saved memory state instead of re-selecting.
        memory_states = torch.load(save_path)
        self._inc_dataset.data_memory = memory_states['x']
        self._inc_dataset.targets_memory = memory_states['y']
        self._herding_matrix = memory_states['herding']
        self._ex.logger.info(f"Load saved step{self._task} memory!")
        return
    if coreset_strategy == "random":
        from inclearn.tools.memory import random_selection
        self._inc_dataset.data_memory, self._inc_dataset.targets_memory = random_selection(
            self._n_classes,
            self._task_size,
            self._parallel_network,
            self._ex.logger,
            inc_dataset,
            self._memory_per_class,
        )
    elif coreset_strategy == "iCaRL":
        from inclearn.tools.memory import herding
        # Prefer the shared (pre-loaded) data when it is available.
        data_inc = self._inc_dataset.shared_data_inc if self._inc_dataset.shared_data_inc is not None else self._inc_dataset.data_inc
        self._inc_dataset.data_memory, self._inc_dataset.targets_memory, self._herding_matrix = herding(
            task_i,
            self._n_classes,
            self._task_size,
            self._parallel_network,
            self._herding_matrix,
            inc_dataset,
            data_inc,
            self._memory_per_class,
            self._ex.logger,
            mask=mask
        )
    else:
        raise ValueError()
def validate(self, data_loader):
    """Evaluate the current network on ``data_loader``; return top-1 accuracy.

    Fix: ``_eval_task`` takes ``(task_i, data_loader, mask)``; the original
    call passed only the loader, which raised a TypeError at runtime. We
    forward the current task index and no mask (matching how other callers
    use ``mask=None``).
    """
    if self._infer_head == 'NCM':
        self.update_prototype()
    # BUG FIX: was self._eval_task(data_loader) — missing task index and mask.
    ypred, ytrue = self._eval_task(self._task, data_loader, None)
    test_acc_stats = utils.compute_accuracy(ypred, ytrue, increments=self._increments, n_classes=self._n_classes)
    self._ex.logger.info(f"test top1acc:{test_acc_stats['top1']}")
    return test_acc_stats['top1']['total']
def after_prune(self, taski, inc_dataset):
    """Rebuild classifier heads (and SE attention) after pruning changed
    the feature dimensions, then re-wrap the network in DataParallel.
    """
    x = torch.randn(1, 3, 32, 32)  # probe input used to measure post-prune feature dims
    self._network = self._network.cpu()
    dim1, dim2 = self._network.caculate_dim(x)
    # Replace the main classifier with one matching the new feature dim.
    del self._network.classifier
    self._network.classifier = self._network._gen_classifier(dim1, self._n_classes)
    if self._network.se is not None:
        # Recreate the attention block for the new dimensionality.
        del self._network.se
        ft_type = self._cfg.get('feature_type', 'ce')
        at_res = self._cfg.get('attention_use_residual', False)
        self._network.se = factory.get_attention(dim1, ft_type, at_res)
    if taski > 0:
        # Auxiliary head: new task's classes plus one "old classes" bucket.
        del self._network.aux_classifier
        self._network.aux_classifier = self._network._gen_classifier(dim2, self._task_size + 1)
    del self._parallel_network
    self._parallel_network = DataParallel(self._network)
class DistillKL(nn.Module):
    """Knowledge-distillation loss: temperature-scaled KL divergence between
    student and teacher logit distributions (Hinton et al.)."""

    def __init__(self, T):
        super(DistillKL, self).__init__()
        self.T = T  # softmax temperature

    def forward(self, y_s, y_t):
        """Return KL(teacher || student) on T-softened distributions,
        scaled by T^2 and averaged over the batch."""
        temp = self.T
        log_p_student = F.log_softmax(y_s / temp, dim=1)
        p_teacher = F.softmax(y_t / temp, dim=1)
        batch = y_s.shape[0]
        return F.kl_div(log_p_student, p_teacher, reduction="sum") * (temp ** 2) / batch
================================================
FILE: examples/Structural_Development/SCA-SNN/inclearn/tools/__init__.py
================================================
================================================
FILE: examples/Structural_Development/SCA-SNN/inclearn/tools/autoaugment_extra.py
================================================
from PIL import Image, ImageEnhance, ImageOps, ImageDraw
import numpy as np
import random
class ImageNetPolicy(object):
    """Randomly choose one of the best 24 Sub-policies on ImageNet.
    Each sub-policy is a pair of (probability, operation, magnitude) steps.
    Example:
    >>> policy = ImageNetPolicy()
    >>> transformed = policy(image)
    Example as a PyTorch Transform:
    >>> transform=transforms.Compose([
    >>>     transforms.Resize(256),
    >>>     ImageNetPolicy(),
    >>>     transforms.ToTensor()])
    """
    def __init__(self, fillcolor=(128, 128, 128)):
        # Fixed pool of sub-policies; one is sampled uniformly per call.
        self.policies = [
            SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor),
            SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
            SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor),
            SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor),
            SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
            SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor),
            SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor),
            SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor),
            SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor),
            SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor),
            SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor),
            SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor),
            SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor),
            SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
            SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
            SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor),
            SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor),
            SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor),
            SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor),
            SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor),
            SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
            SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
            SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
            SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
            SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor),
            SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor),  # set-1
            SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor),
            SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor),
            SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor),
            SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor),
            SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor),  # set-3
            SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor),
            SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor),
            SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor),
            SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor),
            SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor),  # set-11
            SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor),
            SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor),
            SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor),
            SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor),
            SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor),
            SubPolicy(0.2, "equalize", 8, 0.8, "equalize", 4, fillcolor),
            SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor),
            SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor),
            SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor),
            SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor),  # set-2
            SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor),
            SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor),
            SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor),
            SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor),
            SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor),
            SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor),
            SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor),
            SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor),
            SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor),
            SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 1, fillcolor),
            SubPolicy(0.8, "translateY", 9, 0.9, "translateY", 9, fillcolor),
            SubPolicy(0.8, "autocontrast", 0, 0.7, "translateY", 9, fillcolor),
            SubPolicy(0.2, "translateY", 7, 0.9, "color", 6, fillcolor),
            SubPolicy(0.7, "equalize", 6, 0.4, "color", 9, fillcolor),
            SubPolicy(0.3, "brightness", 7, 0.5, "autocontrast", 8, fillcolor),
            SubPolicy(0.9, "autocontrast", 4, 0.5, "autocontrast", 6, fillcolor),
            SubPolicy(0.3, "solarize", 5, 0.6, "equalize", 5, fillcolor),
            SubPolicy(0.2, "translateY", 4, 0.3, "sharpness", 3, fillcolor),
            SubPolicy(0.0, "brightness", 8, 0.8, "color", 8, fillcolor),
            SubPolicy(0.2, "solarize", 6, 0.8, "color", 6, fillcolor),
            SubPolicy(0.2, "solarize", 6, 0.8, "autocontrast", 1, fillcolor),
            SubPolicy(0.4, "solarize", 1, 0.6, "equalize", 5, fillcolor),
            SubPolicy(0.0, "brightness", 0, 0.5, "solarize", 2, fillcolor),
            SubPolicy(0.9, "autocontrast", 5, 0.5, "brightness", 3, fillcolor),
            SubPolicy(0.7, "contrast", 5, 0.0, "brightness", 2, fillcolor),
            SubPolicy(0.2, "solarize", 8, 0.1, "solarize", 5, fillcolor),
            SubPolicy(0.5, "contrast", 1, 0.2, "translateY", 9, fillcolor),
            SubPolicy(0.6, "autocontrast", 5, 0.0, "translateY", 9, fillcolor),
            SubPolicy(0.9, "autocontrast", 4, 0.8, "equalize", 4, fillcolor),
            SubPolicy(0.0, "brightness", 7, 0.4, "equalize", 7, fillcolor),
            SubPolicy(0.2, "solarize", 5, 0.7, "equalize", 5, fillcolor),
            SubPolicy(0.6, "equalize", 8, 0.6, "color", 2, fillcolor),
            SubPolicy(0.3, "color", 7, 0.2, "color", 4, fillcolor),
            SubPolicy(0.5, "autocontrast", 2, 0.7, "solarize", 2, fillcolor),
            SubPolicy(0.2, "autocontrast", 0, 0.1, "equalize", 0, fillcolor),
            SubPolicy(0.6, "shearY", 5, 0.6, "equalize", 5, fillcolor),
            SubPolicy(0.9, "brightness", 3, 0.4, "autocontrast", 1, fillcolor),
            SubPolicy(0.8, "equalize", 8, 0.7, "equalize", 7, fillcolor),
            SubPolicy(0.7, "equalize", 7, 0.5, "solarize", 0, fillcolor),
            SubPolicy(0.8, "equalize", 4, 0.8, "translateY", 9, fillcolor),
            SubPolicy(0.8, "translateY", 9, 0.6, "translateY", 9, fillcolor),
            SubPolicy(0.9, "translateY", 0, 0.5, "translateY", 9, fillcolor),
            SubPolicy(0.5, "autocontrast", 3, 0.3, "solarize", 4, fillcolor),
            SubPolicy(0.5, "solarize", 3, 0.4, "equalize", 4, fillcolor),
            SubPolicy(0.1, "autocontrast", 5, 0.0, "brightness", 0, fillcolor),
            SubPolicy(0.7, "equalize", 7, 0.6, "autocontrast", 4, fillcolor),
            SubPolicy(0.1, "color", 8, 0.2, "shearY", 3, fillcolor),
            SubPolicy(0.4, "shearY", 2, 0.7, "rotate", 0, fillcolor),
            SubPolicy(0.1, "shearY", 3, 0.9, "autocontrast", 5, fillcolor),
            SubPolicy(0.5, "equalize", 0, 0.6, "solarize", 6, fillcolor),
            SubPolicy(0.3, "autocontrast", 5, 0.2, "rotate", 7, fillcolor),
            SubPolicy(0.8, "equalize", 2, 0.4, "invert", 0, fillcolor),
            SubPolicy(0.9, "equalize", 5, 0.7, "color", 0, fillcolor),
            SubPolicy(0.1, "equalize", 1, 0.1, "shearY", 3, fillcolor),
            SubPolicy(0.7, "autocontrast", 3, 0.7, "equalize", 0, fillcolor),
            SubPolicy(0.5, "brightness", 1, 0.1, "contrast", 7, fillcolor),
            SubPolicy(0.1, "contrast", 4, 0.6, "solarize", 5, fillcolor),
            SubPolicy(0.2, "solarize", 3, 0.0, "shearX", 0, fillcolor),
            SubPolicy(0.3, "translateX", 0, 0.6, "translateX", 0, fillcolor),
            SubPolicy(0.5, "equalize", 9, 0.6, "translateY", 7, fillcolor),
            SubPolicy(0.1, "shearX", 0, 0.5, "sharpness", 1, fillcolor),
            SubPolicy(0.8, "equalize", 6, 0.3, "invert", 6, fillcolor),
            SubPolicy(0.4, "shearX", 4, 0.9, "autocontrast", 2, fillcolor),
            SubPolicy(0.0, "shearX", 3, 0.0, "posterize", 3, fillcolor),
            SubPolicy(0.4, "solarize", 3, 0.2, "color", 4, fillcolor),
            SubPolicy(0.1, "equalize", 4, 0.7, "equalize", 6, fillcolor),
            SubPolicy(0.3, "equalize", 8, 0.4, "autocontrast", 3, fillcolor),
            SubPolicy(0.6, "solarize", 4, 0.7, "autocontrast", 6, fillcolor),
            SubPolicy(0.2, "autocontrast", 9, 0.4, "brightness", 8, fillcolor),
            SubPolicy(0.1, "equalize", 0, 0.0, "equalize", 6, fillcolor),
            SubPolicy(0.8, "equalize", 4, 0.0, "equalize", 4, fillcolor),
            SubPolicy(0.5, "equalize", 5, 0.1, "autocontrast", 2, fillcolor),
            SubPolicy(0.5, "solarize", 5, 0.9, "autocontrast", 5, fillcolor),
        ]
    def __call__(self, img):
        """Apply one uniformly-sampled sub-policy to `img` and return it."""
        policy_idx = random.randint(0, len(self.policies) - 1)
        return self.policies[policy_idx](img)
    def __repr__(self):
        return "AutoAugment ImageNet Policy"
class SubPolicy(object):
    """A pair of augmentation operations, each applied with its own probability.

    Args:
        p1, p2 (float): probability of applying operation 1 / operation 2.
        operation1, operation2 (str): operation names (keys of the `func` table).
        magnitude_idx1, magnitude_idx2 (int): index 0-9 into the operation's
            discretized magnitude range.
        fillcolor (tuple): RGB fill used by the affine transforms.
    """
    def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)):
        # Ten discrete magnitude levels per operation.
        ranges = {
            "shearX": np.linspace(0, 0.3, 10),
            "shearY": np.linspace(0, 0.3, 10),
            "translateX": np.linspace(0, 150 / 331, 10),
            "translateY": np.linspace(0, 150 / 331, 10),
            "rotate": np.linspace(0, 30, 10),
            "color": np.linspace(0.0, 0.9, 10),
            "posterize": np.round(np.linspace(8, 4, 10), 0).astype(np.int64),
            "solarize": np.linspace(256, 0, 10),
            "contrast": np.linspace(0.0, 0.9, 10),
            "sharpness": np.linspace(0.0, 0.9, 10),
            "brightness": np.linspace(0.0, 0.9, 10),
            "autocontrast": [0] * 10,
            "equalize": [0] * 10,
            "invert": [0] * 10,
            "cutout": np.linspace(0.0, 0.2, 10),
        }
        def Cutout(img, v):  # v in [0, 0.2]: fraction of image width to cut out
            #assert 0.0 <= v <= 0.2
            if v <= 0.:
                return img
            v = v * img.size[0]  # fraction -> absolute pixel size
            return CutoutAbs(img, v)
        def CutoutAbs(img, v):  # v: absolute patch size in pixels
            # assert 0 <= v <= 20
            if v < 0:
                return img
            w, h = img.size
            # Random patch center, clipped so the patch stays inside the image.
            x0 = np.random.uniform(w)
            y0 = np.random.uniform(h)
            x0 = int(max(0, x0 - v / 2.))
            y0 = int(max(0, y0 - v / 2.))
            x1 = min(w, x0 + v)
            y1 = min(h, y0 + v)
            xy = (x0, y0, x1, y1)
            color = (125, 123, 114)
            img = img.copy()
            ImageDraw.Draw(img).rectangle(xy, color)
            return img
        def rotate_with_fill(img, magnitude):
            # Rotate, then composite over a gray background so the exposed
            # corners are filled instead of black.
            rot = img.convert("RGBA").rotate(magnitude)
            return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode)
        # Dispatch table: operation name -> callable(img, magnitude).
        # Sign of shear/translate/enhance magnitudes is randomized per call.
        func = {
            "shearX": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
                Image.BICUBIC, fillcolor=fillcolor),
            "shearY": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
                Image.BICUBIC, fillcolor=fillcolor),
            "translateX": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0),
                fillcolor=fillcolor),
            "translateY": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])),
                fillcolor=fillcolor),
            "cutout": lambda img, magnitude: Cutout(img, magnitude),
            "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
            "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])),
            "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude),
            "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude),
            "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance(
                1 + magnitude * random.choice([-1, 1])),
            "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance(
                1 + magnitude * random.choice([-1, 1])),
            "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance(
                1 + magnitude * random.choice([-1, 1])),
            "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img),
            "equalize": lambda img, magnitude: ImageOps.equalize(img),
            "invert": lambda img, magnitude: ImageOps.invert(img)
        }
        self.p1 = p1
        self.operation1 = func[operation1]
        self.magnitude1 = ranges[operation1][magnitude_idx1]
        self.p2 = p2
        self.operation2 = func[operation2]
        self.magnitude2 = ranges[operation2][magnitude_idx2]
    def __call__(self, img):
        """Apply op1 with probability p1, then op2 with probability p2."""
        if random.random() < self.p1: img = self.operation1(img, self.magnitude1)
        if random.random() < self.p2: img = self.operation2(img, self.magnitude2)
        return img
================================================
FILE: examples/Structural_Development/SCA-SNN/inclearn/tools/cutout.py
================================================
import torch
import numpy as np
class Cutout(object):
    """Randomly mask out one or more square patches from an image tensor.

    Args:
        n_holes (int): number of square patches to cut out of each image.
        length (int): side length, in pixels, of each square patch.
    """

    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        """Zero out `n_holes` random length x length patches of `img`.

        Args:
            img (Tensor): image of shape (C, H, W).
        Returns:
            Tensor: the image with the patches masked to zero.
        """
        height, width = img.size(1), img.size(2)
        mask = np.ones((height, width), np.float32)
        half = self.length // 2
        for _ in range(self.n_holes):
            # Random patch center; edges are clipped to the image bounds.
            cy = np.random.randint(height)
            cx = np.random.randint(width)
            top = np.clip(cy - half, 0, height)
            bottom = np.clip(cy + half, 0, height)
            left = np.clip(cx - half, 0, width)
            right = np.clip(cx + half, 0, width)
            mask[top:bottom, left:right] = 0.
        keep = torch.from_numpy(mask).expand_as(img)
        return img * keep
================================================
FILE: examples/Structural_Development/SCA-SNN/inclearn/tools/data_utils.py
================================================
import numpy as np
def construct_balanced_subset(x, y):
    """Return a class-balanced random subset of (x, y).

    Every class is down-sampled (with a random shuffle) to the size of the
    smallest class, then the per-class chunks are concatenated.
    """
    per_class_x, per_class_y = [], []
    smallest = np.inf
    for cls in np.unique(y):
        sel = y == cls
        per_class_x.append(x[sel])
        per_class_y.append(y[sel])
        smallest = min(smallest, per_class_y[-1].shape[0])
    for i in range(len(per_class_x)):
        # Shuffle within the class, then keep the first `smallest` samples.
        order = np.arange(per_class_x[i].shape[0])
        np.random.shuffle(order)
        per_class_x[i] = per_class_x[i][order][:smallest]
        per_class_y[i] = per_class_y[i][order][:smallest]
    return np.concatenate(per_class_x, 0), np.concatenate(per_class_y, 0)
================================================
FILE: examples/Structural_Development/SCA-SNN/inclearn/tools/factory.py
================================================
from matplotlib.transforms import Transform
import torch
from torch import nn
from torch import optim
from inclearn import models
from inclearn.convnet import sew_resnet
from inclearn.datasets import data
from inclearn.convnet.resnet import SEFeatureAt
def get_optimizer(params, optimizer, lr, weight_decay=0.0):
    """Build a torch optimizer for `params`.

    Args:
        params: iterable of parameters (or param groups).
        optimizer (str): "adam" or "sgd".
        lr (float): learning rate.
        weight_decay (float): L2 penalty.

    Raises:
        NotImplementedError: for any other optimizer name.
    """
    if optimizer == "adam":
        return optim.Adam(params, lr=lr, weight_decay=weight_decay, betas=(0.9, 0.999))
    if optimizer == "sgd":
        return optim.SGD(params, lr=lr, weight_decay=weight_decay, momentum=0.9)
    raise NotImplementedError
def get_attention(inplane, type, at_res):
    """Build an SE feature-attention module.

    Args:
        inplane: number of input channels.
        type: attention/feature type identifier (note: shadows the builtin `type`).
        at_res: whether the attention uses a residual connection.
    """
    return SEFeatureAt(inplane, type, at_res)
def get_convnet(convnet_type, c_dim=None, cdim_cur=None, **kwargs):
    """Instantiate the backbone network named by `convnet_type`.

    Only "resnet18" (the SEW-ResNet variant) is supported; any other name
    raises NotImplementedError.
    """
    if convnet_type != "resnet18":
        raise NotImplementedError("Unknwon convnet type {}.".format(convnet_type))
    return sew_resnet.sew_resnet18(c_dim, cdim_cur, **kwargs)
def get_model(cfg, trial_i, _run, ex, tensorboard, inc_dataset):
    """Instantiate the incremental-learning model selected by cfg["model"].

    Only "incmodel" is supported; any other value raises NotImplementedError.
    """
    if cfg["model"] != "incmodel":
        raise NotImplementedError(cfg["model"])
    return models.IncModel(cfg, trial_i, _run, ex, tensorboard, inc_dataset)
def get_data(cfg, trial_i):
    """Build the IncrementalDataset described by `cfg` for trial `trial_i`.

    Note: `torc` is read with cfg.get("distillation"), so it is None when
    the key is absent.
    """
    return data.IncrementalDataset(
        trial_i=trial_i,
        dataset_name=cfg["dataset"],
        random_order=cfg["random_classes"],
        shuffle=True,
        batch_size=cfg["batch_size"],
        workers=cfg["workers"],
        validation_split=cfg["validation"],
        resampling=cfg["resampling"],
        increment=cfg["increment"],
        data_folder=cfg["data_folder"],
        start_class=cfg["start_class"],
        torc=cfg.get("distillation")
    )
def set_device(cfg):
    """Resolve cfg["device"] into a torch.device and store it back in cfg.

    A value of -1 selects the CPU; any other value selects that CUDA index.
    Returns the resolved device.
    """
    idx = cfg["device"]
    device = torch.device("cpu") if idx == -1 else torch.device("cuda:{}".format(idx))
    cfg["device"] = device
    return device
================================================
FILE: examples/Structural_Development/SCA-SNN/inclearn/tools/memory.py
================================================
import numpy as np
from copy import deepcopy
import torch
from torch.nn import functional as F
from inclearn.tools.utils import get_class_loss
from inclearn.convnet.utils import extract_features
class MemorySize:
    """Bookkeeping for the exemplar-memory budget.

    Modes (asserted in __init__):
      - "uniform_fixed_per_cls":  a fixed number of exemplars per class.
      - "uniform_fixed_total_mem": a fixed total budget split evenly.
      - "dynamic_fixed_per_cls":  per-class budget, updated dynamically.
    """

    def __init__(self, mode, inc_dataset, total_memory=None, fixed_memory_per_cls=None):
        self.mode = mode
        assert mode.lower() in ["uniform_fixed_per_cls", "uniform_fixed_total_mem", "dynamic_fixed_per_cls"]
        self.total_memory = total_memory
        self.fixed_memory_per_cls = fixed_memory_per_cls
        self._n_classes = 0
        self.mem_per_cls = []
        self._inc_dataset = inc_dataset

    def update_n_classes(self, n_classes):
        """Record the number of classes seen so far."""
        self._n_classes = n_classes

    def update_memory_per_cls_uniform(self, n_classes):
        """Split the budget uniformly over `n_classes`; returns the list."""
        if "fixed_per_cls" in self.mode:
            self.mem_per_cls = [self.fixed_memory_per_cls for _ in range(n_classes)]
        elif "fixed_total_mem" in self.mode:
            self.mem_per_cls = [self.total_memory // n_classes for _ in range(n_classes)]
        return self.mem_per_cls

    def update_memory_per_cls(self, network, n_classes, task_size):
        """Refresh the per-class budget (uniform modes always; others only
        on the first task, when n_classes == task_size)."""
        if "uniform" in self.mode:
            self.update_memory_per_cls_uniform(n_classes)
        else:
            if n_classes == task_size:
                self.update_memory_per_cls_uniform(n_classes)

    @property
    def memsize(self):
        """Total exemplar budget.

        BUG FIX: the original compared ``self.mode == "fixed_total_mem"`` /
        ``== "fixed_per_cls"``, but valid modes are prefixed ("uniform_...",
        "dynamic_..."), so neither branch ever matched and the property
        always returned None. Use substring matching, consistent with
        update_memory_per_cls_uniform above.
        """
        if "fixed_total_mem" in self.mode:
            return self.total_memory
        elif "fixed_per_cls" in self.mode:
            return self.fixed_memory_per_cls * self._n_classes
def compute_examplar_mean(feat_norm, feat_flip, herding_mat, nb_max):
    """Compute the normalized iCaRL exemplar mean from normal and flipped features.

    Only samples whose herding rank is in [1, nb_max] contribute; the mean of
    the normalized normal/flipped features is averaged and re-normalized.

    Returns:
        (mean, selection): unit-norm mean vector and the 0/1 selection mask.
    """
    eps = 1e-8
    feats = feat_norm.T
    feats = feats / (np.linalg.norm(feats, axis=0) + eps)
    feats_flip = feat_flip.T
    feats_flip = feats_flip / (np.linalg.norm(feats_flip, axis=0) + eps)
    # Keep only the first `nb_max` herding picks (ranks 1..nb_max).
    selection = (herding_mat > 0) * (herding_mat < nb_max + 1) * 1.0
    weights = selection / np.sum(selection)
    mean = (np.dot(feats, weights) + np.dot(feats_flip, weights)) / 2
    mean /= np.linalg.norm(mean) + eps
    return mean, selection
def select_examplars(features, nb_max):
    """Greedily pick up to `nb_max` exemplars via iCaRL herding.

    Iteratively chooses the sample whose addition keeps the running mean of
    selected (normalized) features closest to the true class mean.

    Returns:
        (herding_matrix, picked): per-sample 1-based selection ranks (0 for
        unselected) and the list of chosen indices in selection order.
    """
    eps = 1e-8
    cols = features.T
    cols = cols / (np.linalg.norm(cols, axis=0) + eps)
    class_mean = np.mean(cols, axis=1)
    herding_matrix = np.zeros((features.shape[0], ))
    picked = []
    w = class_mean
    rank, steps = 0, 0
    target = min(nb_max, features.shape[0])
    # Stop once `target` samples are chosen, with a hard cap of 1000 steps.
    while np.sum(herding_matrix != 0) != target and steps < 1000:
        best = np.argmax(np.dot(w, cols))
        steps += 1
        if herding_matrix[best] == 0:
            herding_matrix[best] = rank + 1  # store 1-based selection rank
            picked.append(best)
            rank += 1
        w = w + class_mean - cols[:, best]
    return herding_matrix, picked
def random_selection(n_classes, task_size, network, logger, inc_dataset, memory_per_class: list):
    """Randomly sample exemplars for every class seen so far.

    Args:
        n_classes: total number of classes seen so far.
        task_size: number of classes in the current task.
        network: unused here (kept for interface parity with `herding`).
        logger: experiment logger.
        inc_dataset: dataset providing per-class loaders.
        memory_per_class: exemplar budget for each class.

    Returns:
        (data_memory, targets_memory) as concatenated numpy arrays.
    """
    # TODO: Move data_memroy,targets_memory into IncDataset
    logger.info("Building & updating memory.(Random Selection)")
    tmp_data_memory, tmp_targets_memory = [], []
    assert len(memory_per_class) == n_classes
    for class_idx in range(n_classes):
        # Old-class data is read via get_custom_loader_from_memory;
        # new-class data via get_custom_loader.
        if class_idx < n_classes - task_size:
            inputs, targets, loader = inc_dataset.get_custom_loader_from_memory([class_idx])
        else:
            inputs, targets, loader = inc_dataset.get_custom_loader(class_idx, mode="test")
        # Never request more samples than the class actually has.
        memory_this_cls = min(memory_per_class[class_idx], inputs.shape[0])
        idxs = np.random.choice(inputs.shape[0], memory_this_cls, replace=False)
        tmp_data_memory.append(inputs[idxs])
        tmp_targets_memory.append(targets[idxs])
    tmp_data_memory = np.concatenate(tmp_data_memory)
    tmp_targets_memory = np.concatenate(tmp_targets_memory)
    return tmp_data_memory, tmp_targets_memory
def herding(task_i, n_classes, task_size, network, herding_matrix, inc_dataset, shared_data_inc, memory_per_class: list,
            logger, mask=None):
    """Herding matrix: list

    Build the exemplar memory with iCaRL herding selection.

    For each new class, features are extracted and `select_examplars` ranks
    the samples; for old classes the previously stored ranks are reused.
    The first `memory_per_class[c]` ranked samples of each class are kept.

    Returns:
        (data_memory, targets_memory, herding_matrix).
    """
    logger.info("Building & updating memory.(iCaRL)")
    tmp_data_memory, tmp_targets_memory = [], []
    for class_idx in range(n_classes):
        inputs = inc_dataset.data_train[inc_dataset.targets_train == class_idx]
        targets = inc_dataset.targets_train[inc_dataset.targets_train == class_idx]
        if class_idx >= n_classes - task_size:
            # New class: gather its (optionally pre-loaded) samples, extract
            # features, and compute its herding ranks.
            if len(shared_data_inc) > len(inc_dataset.targets_inc):
                share_memory = [shared_data_inc[i] for i in np.where(inc_dataset.targets_inc == class_idx)[0].tolist()]
            else:
                # The shared buffer may be shorter than targets_inc; guard indices.
                share_memory = []
                for i in np.where(inc_dataset.targets_inc == class_idx)[0].tolist():
                    if i < len(shared_data_inc):
                        share_memory.append(shared_data_inc[i])
            loader = inc_dataset._get_loader(inc_dataset.data_inc[inc_dataset.targets_inc == class_idx],
                                             inc_dataset.targets_inc[inc_dataset.targets_inc == class_idx],
                                             share_memory=share_memory,
                                             batch_size=128,
                                             shuffle=False,
                                             mode="test")
            features, _ = extract_features(task_i, network, loader, mask=mask)
            herding_matrix.append(select_examplars(features, memory_per_class[class_idx])[0])
        # Keep samples whose herding rank is in [1, memory_per_class[class_idx]].
        alph = herding_matrix[class_idx]
        alph = (alph > 0) * (alph < memory_per_class[class_idx] + 1) * 1.0
        tmp_data_memory.append(inputs[np.where(alph == 1)[0]])
        tmp_targets_memory.append(targets[np.where(alph == 1)[0]])
    tmp_data_memory = np.concatenate(tmp_data_memory)
    tmp_targets_memory = np.concatenate(tmp_targets_memory)
    return tmp_data_memory, tmp_targets_memory, herding_matrix
================================================
FILE: examples/Structural_Development/SCA-SNN/inclearn/tools/metrics.py
================================================
import numpy as np
import torch
import numbers
import math
class IncConfusionMeter:
    """Maintains a confusion matrix for a given classification problem,
    aggregated per incremental step (task).
    The ConfusionMeter constructs a confusion matrix for a multi-class
    classification problems. It does not support multi-label, multi-class problems:
    for such problems, please use MultiLabelConfusionMeter.
    Args:
        k (int): number of classes in the classification problem
        increments (list[int]): number of classes added at each incremental step
        normalized (boolean): Determines whether or not the confusion matrix
            is normalized or not
    """
    def __init__(self, k, increments, normalized=False):
        self.conf = np.ndarray((k, k), dtype=np.int32)
        self.normalized = normalized
        self.increments = increments
        # Cumulative class boundaries, e.g. increments [2, 3] -> [0, 2, 5].
        self.cum_increments = [0] + [sum(increments[:i + 1]) for i in range(len(increments))]
        self.k = k
        self.reset()
    def reset(self):
        # Clear all accumulated counts.
        self.conf.fill(0)
    def add(self, predicted, target):
        """Computes the confusion matrix of K x K size where K is no of classes
        Args:
            predicted (tensor): Can be an N x K tensor of predicted scores obtained from
                the model for N examples and K classes or an N-tensor of
                integer values between 0 and K-1.
            target (tensor): Can be a N-tensor of integer values assumed to be integer
                values between 0 and K-1 or N x K tensor, where targets are
                assumed to be provided as one-hot vectors
        """
        if isinstance(predicted, torch.Tensor):
            predicted = predicted.cpu().numpy()
        if isinstance(target, torch.Tensor):
            target = target.cpu().numpy()
        assert predicted.shape[0] == target.shape[0], \
            'number of targets and predicted outputs do not match'
        if np.ndim(predicted) != 1:
            # Score matrix: reduce to hard predictions via argmax.
            assert predicted.shape[1] == self.k, \
                'number of predictions does not match size of confusion matrix'
            predicted = np.argmax(predicted, 1)
        else:
            assert (predicted.max() < self.k) and (predicted.min() >= 0), \
                'predicted values are not between 1 and k'
        onehot_target = np.ndim(target) != 1
        if onehot_target:
            # One-hot targets: validate and reduce to class indices.
            assert target.shape[1] == self.k, \
                'Onehot target does not match size of confusion matrix'
            assert (target >= 0).all() and (target <= 1).all(), \
                'in one-hot encoding, target values should be 0 or 1'
            assert (target.sum(1) == 1).all(), \
                'multi-label setting is not supported'
            target = np.argmax(target, 1)
        else:
            assert (predicted.max() < self.k) and (predicted.min() >= 0), \
                'predicted values are not between 0 and k-1'
        # hack for bincounting 2 arrays together
        x = predicted + self.k * target
        bincount_2d = np.bincount(x.astype(np.int32), minlength=self.k**2)
        assert bincount_2d.size == self.k**2
        conf = bincount_2d.reshape((self.k, self.k))
        self.conf += conf
    def value(self):
        """
        Returns:
            Confustion matrix of K rows and K columns, where rows corresponds
            to ground-truth targets and columns corresponds to predicted
            targets. Rows/columns are aggregated per increment; column 0 is
            within-task correct counts, column 1 within-task errors, and
            columns 2+ are the per-increment totals.
        """
        conf = self.conf.astype(np.float32)
        new_conf = np.zeros([len(self.increments), len(self.increments) + 2])
        for i in range(len(self.increments)):
            idxs = range(self.cum_increments[i], self.cum_increments[i + 1])
            # Diagonal inside increment i: correctly classified samples.
            new_conf[i, 0] = conf[idxs, idxs].sum()
            # Off-diagonal inside increment i: within-task confusions.
            new_conf[i, 1] = conf[self.cum_increments[i]:self.cum_increments[i + 1],
                                  self.cum_increments[i]:self.cum_increments[i + 1]].sum() - new_conf[i, 0]
            for j in range(len(self.increments)):
                # Total mass predicted into increment j for targets of increment i.
                new_conf[i, j + 2] = conf[self.cum_increments[i]:self.cum_increments[i + 1],
                                          self.cum_increments[j]:self.cum_increments[j + 1]].sum()
        conf = new_conf
        if self.normalized:
            return conf / conf[:, 2:].sum(1).clip(min=1e-12)[:, None]
        else:
            return conf
class ClassErrorMeter:
    """Top-k classification error (or accuracy) meter.

    Accumulates, across `add` calls, how many samples had their ground-truth
    label outside the top-k predictions, for every k in `topk`.

    Args:
        topk: iterable of k values to track (default: top-1 only).
        accuracy: if True, `value` reports accuracy in percent instead of
            error in percent.
    """

    def __init__(self, topk=(1,), accuracy=False):
        super(ClassErrorMeter, self).__init__()
        # Fix: default used to be the shared mutable list [1]; a tuple is
        # backward compatible (np.sort accepts any sequence).
        self.topk = np.sort(topk)
        self.accuracy = accuracy
        self.reset()

    def reset(self):
        """Clear the per-k error counters and the sample count."""
        self.sum = {v: 0 for v in self.topk}
        self.n = 0

    def add(self, output, target):
        """Accumulate errors for one batch.

        Args:
            output: (N, C) class scores, torch.Tensor or np.ndarray.
            target: (N,) ground-truth labels, torch.Tensor or np.ndarray.
        """
        if isinstance(output, np.ndarray):
            output = torch.Tensor(output)
        if isinstance(target, np.ndarray):
            target = torch.Tensor(target)
        topk = self.topk
        maxk = int(topk[-1])  # seems like Python3 wants int and not np.int64
        no = output.shape[0]
        # Indices of the maxk highest scores per row, best first.
        pred = output.topk(maxk, 1, True, True)[1]
        correct = pred == target.unsqueeze(1).repeat(1, pred.shape[1])
        for k in topk:
            # Fix: use .item() so the counter stays a plain Python int
            # instead of silently becoming a 0-d torch tensor.
            self.sum[k] += no - int(correct[:, 0:k].sum().item())
        self.n += no

    def value(self, k=-1):
        """Return error (or accuracy, see `accuracy`) in percent for `k`.

        With the default k=-1, returns the list of values for every tracked k
        in ascending order. Returns NaN if no samples were added.
        """
        if k != -1:
            assert k in self.sum.keys(), \
                'invalid k (this k was not provided at construction time)'
            if self.n == 0:
                return float('nan')
            if self.accuracy:
                return (1. - float(self.sum[k]) / self.n) * 100.0
            else:
                return float(self.sum[k]) / self.n * 100.0
        else:
            return [self.value(k_) for k_ in self.topk]
class AverageValueMeter:
    """Running mean / standard deviation of scalar values (Welford update).

    `add(value, n)` treats `value` as a sum over `n` observations; `value()`
    reports (mean, std) of everything seen since the last `reset()`.
    """

    def __init__(self):
        super(AverageValueMeter, self).__init__()
        self.reset()
        self.val = 0

    def add(self, value, n=1):
        """Fold one observation (or a pre-summed group of n) into the stats."""
        self.val = value
        self.sum += value
        self.var += value * value
        self.n += n
        if self.n == 0:
            # Nothing observed yet: statistics undefined.
            self.mean, self.std = np.nan, np.nan
        elif self.n == 1:
            # Single observation: mean is the value, spread is unknown.
            self.mean, self.std = self.sum, np.inf
            self.mean_old = self.mean
            self.m_s = 0.0
        else:
            prev_mean = self.mean_old
            self.mean = prev_mean + (value - n * prev_mean) / float(self.n)
            # Welford accumulator for the sum of squared deviations.
            self.m_s += (value - prev_mean) * (value - self.mean)
            self.mean_old = self.mean
            self.std = math.sqrt(self.m_s / (self.n - 1.0))

    def value(self):
        """Return the (mean, std) pair accumulated so far."""
        return self.mean, self.std

    def reset(self):
        """Forget all accumulated statistics."""
        self.n = 0
        self.sum, self.var, self.val = 0.0, 0.0, 0.0
        self.mean, self.std = np.nan, np.nan
        self.mean_old, self.m_s = 0.0, 0.0
================================================
FILE: examples/Structural_Development/SCA-SNN/inclearn/tools/results_utils.py
================================================
import glob
import json
import math
import os
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy
from . import utils
def get_template_results(cfg):
    """Return a fresh results skeleton carrying the experiment config."""
    template = {"config": cfg}
    template["results"] = []
    return template
def save_results(results, label):
    """Serialize a results dict to results/<date>_<label>/<date>_<seed>_.json.

    Side effects: removes the "device" entry from results["config"] IN PLACE
    (presumably because a device handle is not JSON-serializable — confirm),
    and creates the target folder if it does not exist.

    Raises:
        KeyError: if results["config"] has no "device" or "seed" key.
    """
    del results["config"]["device"]
    folder_path = os.path.join("results", "{}_{}".format(utils.get_date(), label))
    if not os.path.exists(folder_path):
        os.makedirs(folder_path)
    # File name carries the date and the run seed so repeats do not collide.
    file_path = "{}_{}_.json".format(utils.get_date(), results["config"]["seed"])
    with open(os.path.join(folder_path, file_path), "w+") as f:
        json.dump(results, f, indent=2)
def compute_avg_inc_acc(results):
    """Computes the average incremental accuracy as defined in iCaRL.

    The average incremental accuracy at task X is the mean of the accuracies
    at tasks 0, 1, ..., X.

    :param results: list of per-step dicts holding "top1" (and optionally
        "top5") accuracy entries, each with a "total" field.
    :return: (avg_top1, avg_top5) — avg_top5 is None when absent.
    """
    top1_scores = [step['top1']["total"] for step in results]
    avg_top1 = sum(top1_scores) / len(top1_scores)
    avg_top5 = None
    if "top5" in results[0]:
        top5_scores = [step['top5']["total"] for step in results]
        avg_top5 = sum(top5_scores) / len(top5_scores)
    return avg_top1, avg_top5
================================================
FILE: examples/Structural_Development/SCA-SNN/inclearn/tools/scheduler.py
================================================
import math
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau
class ConstantTaskLR:
    """Task-level LR schedule that always yields the same learning rate."""

    def __init__(self, lr):
        self._lr = lr

    def get_lr(self, task_i):
        """Return the fixed learning rate; `task_i` is ignored."""
        return self._lr
class CosineAnnealTaskLR:
    """Cosine annealing of the LR over the task index.

    Starts at lr_max for task 0 and decays to lr_min at task task_max
    following half a cosine period.
    """

    def __init__(self, lr_max, lr_min, task_max):
        self._lr_max = lr_max
        self._lr_min = lr_min
        self._task_max = task_max

    def get_lr(self, task_i):
        """LR at task `task_i` on the cosine curve between lr_max and lr_min."""
        span = self._lr_max - self._lr_min
        cosine = 1 + math.cos(math.pi * task_i / self._task_max)
        return self._lr_min + span * cosine / 2
class GradualWarmupScheduler(_LRScheduler):
    """ Gradually warm-up(increasing) learning rate in optimizer.
    https://github.com/ildoonet/pytorch-gradual-warmup-lr
    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
    Args:
        optimizer (Optimizer): Wrapped optimizer.
        multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr.
        total_epoch: target learning rate is reached at total_epoch, gradually
        after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau)
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        self.multiplier = multiplier
        if self.multiplier < 1.:
            raise ValueError('multiplier should be greater thant or equal to 1.')
        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        # Tracks whether control has been handed off to after_scheduler.
        self.finished = False
        # Base-class __init__ calls step() once, so get_lr runs immediately.
        super(GradualWarmupScheduler, self).__init__(optimizer)

    def get_lr(self):
        """LR for the current epoch: linear warm-up, then the wrapped schedule."""
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    # One-time handoff: rescale the wrapped scheduler's bases
                    # so it continues from base_lr * multiplier.
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_last_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]
        if self.multiplier == 1.0:
            # Ramp linearly from 0 to base_lr (so epoch 0 yields lr == 0).
            return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
        else:
            # Ramp linearly from base_lr to base_lr * multiplier.
            return [
                base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.)
                for base_lr in self.base_lrs
            ]

    def step_ReduceLROnPlateau(self, metrics, epoch=None):
        """Step variant used when after_scheduler is a ReduceLROnPlateau."""
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch if epoch != 0 else 1  # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
        if self.last_epoch <= self.total_epoch:
            # Still warming up: write the warm-up LR directly into the groups.
            warmup_lr = [
                base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.)
                for base_lr in self.base_lrs
            ]
            for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
                param_group['lr'] = lr
        else:
            # NOTE(review): `epoch` can no longer be None here (reassigned
            # above), so this first branch is dead code; kept as-is.
            if epoch is None:
                self.after_scheduler.step(metrics, None)
            else:
                self.after_scheduler.step(metrics, epoch - self.total_epoch)

    def step(self, epoch=None, metrics=None):
        """Dispatch to the plateau-aware step or the standard _LRScheduler step."""
        if type(self.after_scheduler) != ReduceLROnPlateau:
            if self.finished and self.after_scheduler:
                # Warm-up done: delegate stepping (epoch shifted past warm-up).
                if epoch is None:
                    self.after_scheduler.step(None)
                else:
                    self.after_scheduler.step(epoch - self.total_epoch)
                self._last_lr = self.after_scheduler.get_last_lr()
            else:
                return super(GradualWarmupScheduler, self).step(epoch)
        else:
            self.step_ReduceLROnPlateau(metrics, epoch)
================================================
FILE: examples/Structural_Development/SCA-SNN/inclearn/tools/similar.py
================================================
import torch.nn as nn
import torchvision.models as models
import numpy as np
import os
from sklearn.utils import shuffle
import torch
import torch.nn.functional as F
class Appr(object):
    """ Class implementing the TALL """
    # Tracks per-task feature statistics (class means) produced by each
    # expert, and pairwise task distances used to decide whether a new task
    # should reuse an existing expert or get a new one.

    def __init__(self, pretrained_feat_extractor, num_task, torc, device='cuda', args=None):
        self.task2expert = []
        # expert2task[e] = list of task ids handled by expert e.
        self.expert2task = []
        # `torc`: flag that switches label-range handling in the feature
        # statistics (task-incremental vs class-incremental labeling —
        # presumably "task or class"; confirm with caller).
        self.torc = torc
        self.num_task = num_task
        # task2mean[t][p] = list of per-class feature means of task t's data
        # as seen by the expert of task p.
        self.task2mean = {}
        self.task2cov = {}
        for i in range(num_task):
            self.task2mean[i] = []
            self.task2cov[i] = []
        # Pairwise task distance matrices (clamped and raw variants).
        # NOTE(review): hard-coded to 'cuda' even though a `device` argument
        # exists — confirm intended.
        self.task_dist = torch.zeros(num_task, num_task).to(device='cuda')
        self.task_dist2 = torch.zeros(num_task, num_task).to(device='cuda')
        self.feat_extractor = pretrained_feat_extractor
        # How per-expert distances are aggregated in strategy(): mean/max/min.
        self.task_relatedness_method = "mean"
        self.reuse_threshold = 0.3
        self.reuse_cell_threshold = 0.75

    def get_mean_cov_feats(self, taski, data, device):
        """compute mean and cov for features of data extracted by the expert of task t

        Runs task `taski`'s validation data through the experts of all
        previous tasks (0..taski-1) and returns, per previous expert:
        per-class feature means, (unused) covariances, and the raw
        per-class feature tensors.
        """
        # data = deepcopy(data) # copy for using different preprocess
        # self.model.requires_grad_(False)
        # self.model.eval()
        # self.model.set_current_task(t)
        # Assumes 100 total classes split evenly across tasks — TODO confirm.
        if self.torc:
            steps = int(100 / self.num_task)
            class_num = steps * taski
            # Labels of the current task occupy [steps*taski, steps*(taski+1)).
            labels = torch.arange(class_num, class_num + steps).view(-1, 1).to(device)
        else:
            steps = int(100 / self.num_task)
            labels = torch.arange(steps).view(-1, 1).to(device)
        all_task_feats = {}
        for t in range(taski):
            all_task_feats[t] = [[] for _ in range(steps)]
        self.feat_extractor.eval()
        with torch.no_grad():
            for batch_idx, (x, y) in enumerate(data):
                x, y = x.to(device), y.to(device)
                index = labels == y.view(1, -1)  # CxC
                for t_p in range(taski):
                    # Features of the current batch from previous expert t_p.
                    feat = self.feat_extractor(t_p, x, mask=None, classify=False)['features']
                    feat = feat.view(feat.size(0), -1)
                    for i in range(steps):
                        all_task_feats[t_p][i].append(feat[index[i]])
        feat_means = {}
        feat_covs = {}
        all_feats_cat = {}
        for t_p in range(taski):
            feat_means[t_p] = []
            feat_covs[t_p] = []
            all_feats_cat[t_p] = [torch.cat(feats, axis=0) for feats in all_task_feats[t_p]]
            for feat in all_feats_cat[t_p]:
                feat_mean, feat_cov = gaussian_mean_cov(feat)
                feat_means[t_p].append(feat_mean)
                # feat_covs[t_p].append(feat_cov)
        # feat_means = [torch.mean(feat, dim=0) for feat in all_feats_cat]
        # NOTE: feat_covs entries stay empty lists (cov tracking disabled).
        return feat_means, feat_covs, all_feats_cat

    def after_get_mean_cov_feats(self, taski, data, device):
        """compute mean and cov for features of data extracted by the expert of task t

        Like get_mean_cov_feats, but uses only task `taski`'s OWN expert
        (called after that expert has been trained).
        """
        # data = deepcopy(data) # copy for using different preprocess
        # self.model.requires_grad_(False)
        # self.model.eval()
        # self.model.set_current_task(t)
        if self.torc:
            steps = int(100 / self.num_task)
            class_num = steps * taski
            labels = torch.arange(class_num, class_num + steps).view(-1, 1).to(device)
        else:
            steps = int(100 / self.num_task)
            labels = torch.arange(steps).view(-1, 1).to(device)
        all_task_feats = [[] for _ in range(steps)]
        self.feat_extractor.eval()
        with torch.no_grad():
            for batch_idx, (x, y) in enumerate(data):
                x, y = x.to(device), y.to(device)
                # forward
                feat = self.feat_extractor(taski, x, mask=None, classify=False)['features']
                feat = feat.view(feat.size(0), -1)
                index = labels == y.view(1, -1)  # CxC
                for i in range(steps):
                    all_task_feats[i].append(feat[index[i]])
        feat_means = []
        feat_covs = []
        all_feats_cat = [torch.cat(feats, axis=0) for feats in all_task_feats]
        for feat in all_feats_cat:
            feat_mean, feat_cov = gaussian_mean_cov(feat)
            feat_means.append(feat_mean)
            # feat_covs.append(feat_cov)
        # feat_means = [torch.mean(feat, dim=0) for feat in all_feats_cat]
        return feat_means, feat_covs, all_feats_cat

    def add_mean_cov(self, taski, mean, cov=None):
        """Append one set of per-class means for task `taski` (cov is ignored)."""
        self.task2mean[taski].append(mean)
        # self.task2cov[taski].append(cov)

    def task_relatedness_knnkl(self, task_id, p_task_id, all_feats):
        """
        Params:
            all_feat: shape C x N_c x D

        kNN/KL-style distance between task `task_id` and previous task
        `p_task_id`, computed in the feature space of p_task_id's expert.
        Returns (clamped distance in [0, 1], raw accumulated distance).
        """
        # means and features of current data from expert of p_task_id
        # feat_means, all_feats = means_and_feats[p_task_id]
        # means of data of p_task_id
        p_feat_means = self.task2mean[p_task_id][p_task_id]  # C x D
        feat_means = self.task2mean[task_id][p_task_id]  # C x D
        # p_feat_cov = self.task2cov[p_task_id][p_task_id] # C x D
        # feat_cov = self.task2cov[task_id][p_task_id] # C x D
        p_feat_means = torch.stack(p_feat_means, dim=0)
        task_dist = 0
        task_dist2 = 0
        d = p_feat_means.shape[-1]
        n = 0
        flag = False
        for i in range(len(feat_means)):
            # for each current class
            n += all_feats[i].shape[0]
            # Within-class distance: samples to their own class mean.
            dist_in = all_feats[i] - feat_means[i]  # N_c x D
            dist_in = torch.sqrt(torch.sum(dist_in ** 2, dim=-1))  # N_c
            # N_c x C x D
            # Cross-task distance: samples to the NEAREST old-task class mean.
            dist_out = torch.unsqueeze(all_feats[i], dim=1) - torch.unsqueeze(p_feat_means, dim=0)
            dist_out, _ = torch.min(torch.sqrt(torch.sum(dist_out ** 2, dim=-1)), dim=-1)  # N_c
            # Log-ratio of out/in distances; 0.9 shrinks dist_in slightly
            # (empirical margin — TODO confirm rationale).
            dist = torch.mean(torch.log(dist_out / (0.9 * dist_in)))
            # dist = torch.mean(torch.log(dist_out))
            # if dist <= 0:
            #     flag = True
            task_dist += torch.maximum(dist, torch.zeros_like(dist))
            task_dist2 += torch.maximum(dist, torch.zeros_like(dist))
        # task_dist = task_dist / len(feat_means)
        # task_dist = 1 - torch.exp(-2*task_dist)
        # Squash into [0, 1) but never exceed the raw accumulated value.
        task_dist = torch.minimum(1 - torch.exp(-2 * task_dist), task_dist)
        if task_dist == 0:
            # Degenerate case: treat as maximally distant rather than identical.
            task_dist = torch.ones_like(task_dist)
        return task_dist, task_dist2

    def get_relatedness(self, task_id, feats):
        """Compute relatedness

        Fills the symmetric distance matrices between `task_id` and all
        previous tasks using task_relatedness_knnkl.
        """
        # the distance between task_id and task_id
        self.task_dist[task_id][task_id] = 0
        self.task_dist2[task_id][task_id] = 0
        for p_task_id in range(task_id):
            # for p_task_id in range(task_id + 1):
            # task_dist = self.task_relatedness_cos(task_id, p_task_id)
            task_dist, task_dist2 = self.task_relatedness_knnkl(task_id, p_task_id, feats[p_task_id])
            # task_dist = self.task_relatedness_CKA(task_id, p_task_id, feats)
            # task_dist = self.task_relatedness_gaussian_kl(task_id, p_task_id)
            self.task_dist[task_id][p_task_id] = task_dist
            self.task_dist[p_task_id][task_id] = task_dist
            self.task_dist2[task_id][p_task_id] = task_dist2
            self.task_dist2[p_task_id][task_id] = task_dist2

    def strategy(self, task_id, num_train_samples):
        """ Find the expert to be reused. If not found, return -1.

        Picks the expert with the smallest aggregated distance to `task_id`;
        returns ("reuse", ...) when under reuse_threshold, else ("new", ...).
        """
        expert = -1
        min_dist = None
        all_dist = []
        all_dist2 = []
        for expert_id, p_tasks in enumerate(self.expert2task):
            d = self.task_dist[task_id, p_tasks]
            dd = self.task_dist2[task_id, p_tasks]
            if self.task_relatedness_method == "mean":
                s = torch.mean(d).item()
                ss = torch.mean(dd).item()
            elif self.task_relatedness_method == "max":
                s = torch.max(d).item()
            elif self.task_relatedness_method == "min":
                s = torch.min(d).item()
            else:
                raise Exception("Unknown reuse strategy !!!")
            all_dist.append(s)
            # NOTE(review): `ss` is only bound in the "mean" branch; with
            # method "max"/"min" this line raises NameError — confirm only
            # "mean" is ever used.
            all_dist2.append(ss)
            if min_dist is None:
                min_dist = s
                expert = expert_id
            elif s < min_dist:
                min_dist = s
                expert = expert_id
        # if num_train_samples <= 25: # for s_long
        #     all_dist = torch.tensor(all_dist)
        #     _, expert_idx =torch.sort(all_dist)
        #     for e in expert_idx:
        #         if self.model.expert2max_train_samples[e] >= 10 * num_train_samples:
        #             return "reuse", e
        if min_dist <= self.reuse_threshold:
            return "reuse", expert, min_dist, all_dist, all_dist2
        # elif min_dist <= self.reuse_cell_threshold:
        #     return "reuse cell", expert
        else:
            return "new", expert, min_dist, all_dist, all_dist2

    def learn(self, task_id, valid_data, batch_size, device):
        """learn a task

        Decides the reuse/new strategy for `task_id` from its validation
        data. Task 0 always gets a new expert.
        """
        if task_id == 0:
            # train
            strategy = 'new'
            expert_id = task_id
            min_dist = 0
            all_dist = 0
            all_dist2 = 0
        else:
            feat_means, feat_covs, all_feats = self.get_mean_cov_feats(
                task_id, valid_data, device=device)
            # Record this task's statistics under every previous expert.
            for t in range(task_id):
                self.add_mean_cov(task_id, feat_means[t], feat_covs[t])
            self.get_relatedness(task_id, all_feats)
            num_train_samples = len(valid_data) * batch_size
            strategy, expert_id, min_dist, all_dist, all_dist2 = self.strategy(task_id, num_train_samples)
        self.expert2task.append([task_id])
        print(self.task_dist)
        print(self.task_dist2)
        return strategy, expert_id, min_dist, all_dist, all_dist2

    def after_learn(self, task_id, valid_data, batch_size, device):
        """learn a task

        After training task `task_id`'s expert, store the task's own-expert
        feature statistics for future relatedness computations.
        """
        feat_means, feat_covs, all_feats = self.after_get_mean_cov_feats(
            task_id, valid_data, device=device)
        self.add_mean_cov(task_id, feat_means, feat_covs)
        print(self.task_dist)
        print(self.task_dist2)
class ResNet_FE(nn.Module):
    """Frozen feature extractor derived from a backbone network.

    Drops the backbone's last child module (for a torchvision ResNet this is
    the final fully-connected classifier), puts the remainder in eval mode,
    and disables gradients so it is used purely for feature extraction.
    """

    def __init__(self, resnet_model):
        super(ResNet_FE, self).__init__()
        trunk = list(resnet_model.children())[:-1]
        self.fe_model = nn.Sequential(*trunk)
        self.fe_model.eval()
        self.fe_model.requires_grad_(False)

    def forward(self, x):
        """Extract features by running `x` through the truncated backbone."""
        return self.fe_model(x)
def get_pretrained_feat_extractor(name):
    """get the feature extractor pretrained on ImageNet

    Only "resnet18" is supported; any other name raises.
    """
    if name == "resnet18":
        # NOTE(review): `weights=True` is not a documented value for the
        # torchvision weights API (it expects a WeightsEnum, a string such as
        # "IMAGENET1K_V1", or None). Verify this actually loads pretrained
        # weights on the pinned torchvision version.
        feat_extractor = ResNet_FE(models.resnet18(weights=True))
        # self.logger.info("Using relatedness feature extractor: ResNet18")
    else:
        raise Exception("Unknown relatedness feature extractor !!!")
    return feat_extractor
def gaussian_mean_cov(X):
    """Mean and diagonal-regularized covariance of a Gaussian fitted to X.

    Params:
        X: N x D tensor of feature vectors (N >= 2 for a finite covariance).
    Returns:
        u: D-dim mean vector.
        cov: D x D tensor equal to diag(sample variances) + I — the unbiased
            covariance is computed, then its off-diagonal entries are zeroed
            and the identity is added as regularization.
    """
    N, D = X.shape
    u = torch.mean(X, dim=0)
    u_row = torch.reshape(u, (1, -1))  # 1 x D
    # Unbiased covariance via the identity  sum((x-u)(x-u)^T) = X^T X - N u u^T.
    cov = torch.matmul(X.T, X) - N * torch.matmul(u_row.T, u_row)  # D x D
    cov = cov / (N - 1)
    # Build the identity once (was constructed twice); removed unused
    # `device` local — the eye is created directly on X's device.
    eye = torch.eye(D, device=X.device)
    cov = cov * eye + eye
    return u, cov
# import torch.nn as nn
# import torchvision.models as models
# import numpy as np
# import os
# from sklearn.utils import shuffle
# import torch
# import torch.nn.functional as F
# class Appr(object):
# """ Class implementing the TALL """
# def __init__(self, pretrained_feat_extractor, num_task, device='cuda', args=None):
# self.task2expert = []
# self.expert2task = []
# self.task2mean = []
# self.task2cov = []
# self.task_dist = torch.zeros(num_task, num_task).to(device='cuda')
# self.feat_extractor = get_pretrained_feat_extractor(pretrained_feat_extractor).to(device='cuda')
# self.task_relatedness_method = "mean"
# self.reuse_threshold=0.3
# self.reuse_cell_threshold=0.75
# def get_mean_cov_feats(self, t, data, device):
# """compute mean and cov for features of data extracted by the expert of task t
# """
# # data = deepcopy(data) # copy for using different preprocess
# # self.model.requires_grad_(False)
# # self.model.eval()
# # self.model.set_current_task(t)
# class_num = 10
# labels = torch.arange(class_num).view(-1, 1).to(device)
# all_feats = [[] for _ in range(class_num)]
# self.feat_extractor.eval()
# with torch.no_grad():
# for batch_idx, (x, y) in enumerate(data):
# x, y = x.to(device), y.to(device)
# # forward
# feat = self.feat_extractor(x)
# feat = feat.view(feat.size(0), -1)
# index = labels == y.view(1, -1) # CxC
# for i in range(class_num):
# all_feats[i].append(feat[index[i]])
# all_feats_cat = [torch.cat(feats, axis=0) for feats in all_feats]
# feat_means = []
# feat_covs = []
# for feat in all_feats_cat:
# feat_mean, feat_cov = gaussian_mean_cov(feat)
# feat_means.append(feat_mean)
# feat_covs.append(feat_cov)
# # feat_means = [torch.mean(feat, dim=0) for feat in all_feats_cat]
# return feat_means, feat_covs, all_feats_cat
# def add_mean_cov(self, mean, cov=None):
# self.task2mean.append(mean)
# self.task2cov.append(cov)
# def task_relatedness_knnkl(self, task_id, p_task_id, all_feats):
# """
# Params:
# all_feat: shape C x N_c x D
# """
# # means and features of current data from expert of p_task_id
# # feat_means, all_feats = means_and_feats[p_task_id]
# # means of data of p_task_id
# p_feat_means = self.task2mean[p_task_id] # C x D
# feat_means = self.task2mean[task_id] # C x D
# p_feat_means = torch.stack(p_feat_means, dim=0)
# task_dist = 0
# d = p_feat_means.shape[-1]
# n = 0
# flag = False
# for i in range(len(feat_means)):
# # for each current class
# n += all_feats[i].shape[0]
# dist_in = all_feats[i] - feat_means[i] # N_c x D
# dist_in = torch.sqrt(torch.sum(dist_in ** 2, dim=-1)) # N_c
# # N_c x C x D
# dist_out = torch.unsqueeze(all_feats[i], dim=1) - torch.unsqueeze(p_feat_means, dim=0)
# dist_out, _ = torch.min(torch.sqrt(torch.sum(dist_out ** 2, dim=-1)), dim=-1) # N_c
# dist = torch.mean(torch.log(dist_out / dist_in))
# # if dist <= 0:
# # flag = True
# task_dist += torch.maximum(dist, torch.zeros_like(dist))
# # task_dist = task_dist / len(feat_means)
# # task_dist = 1 - torch.exp(-2*task_dist)
# task_dist = torch.minimum(1 - torch.exp(-2*task_dist), task_dist)
# if task_dist == 0:
# task_dist = torch.ones_like(task_dist)
# return task_dist
# def get_relatedness(self, task_id, feats):
# """Compute relatedness
# """
# # the distance between task_id and task_id
# self.task_dist[task_id][task_id] = 0
# for p_task_id in range(task_id):
# # for p_task_id in range(task_id + 1):
# # task_dist = self.task_relatedness_cos(task_id, p_task_id)
# task_dist = self.task_relatedness_knnkl(task_id, p_task_id, feats)
# # task_dist = self.task_relatedness_CKA(task_id, p_task_id, feats)
# # task_dist = self.task_relatedness_gaussian_kl(task_id, p_task_id)
# self.task_dist[task_id][p_task_id] = task_dist
# self.task_dist[p_task_id][task_id] = task_dist
# def strategy(self, task_id, num_train_samples):
# """ Find the expert to be reused. If not found, return -1.
# """
# expert = -1
# min_dist = None
# all_dist = []
# for expert_id, p_tasks in enumerate(self.expert2task):
# d = self.task_dist[task_id, p_tasks]
# if self.task_relatedness_method == "mean":
# s = torch.mean(d).item()
# elif self.task_relatedness_method == "max":
# s = torch.max(d).item()
# elif self.task_relatedness_method == "min":
# s = torch.min(d).item()
# else:
# raise Exception("Unknown reuse strategy !!!")
# all_dist.append(s)
# if min_dist is None:
# min_dist = s
# expert = expert_id
# elif s < min_dist:
# min_dist = s
# expert = expert_id
# # if num_train_samples <= 25: # for s_long
# # all_dist = torch.tensor(all_dist)
# # _, expert_idx =torch.sort(all_dist)
# # for e in expert_idx:
# # if self.model.expert2max_train_samples[e] >= 10 * num_train_samples:
# # return "reuse", e
# if min_dist <= self.reuse_threshold:
# return "reuse", expert ,min_dist,all_dist
# # elif min_dist <= self.reuse_cell_threshold:
# # return "reuse cell", expert
# else:
# return "new", expert,min_dist,all_dist
# def learn(self, task_id, valid_data, batch_size,device):
# """learn a task
# """
# feat_means, feat_covs, all_feats = self.get_mean_cov_feats(
# task_id, valid_data, device=device)
# self.add_mean_cov(feat_means)
# self.get_relatedness(task_id, all_feats)
# if task_id == 0:
# # train
# strategy='new'
# expert_id=task_id
# min_dist=0
# all_dist=0
# else:
# num_train_samples=len(valid_data)*batch_size
# strategy, expert_id,min_dist,all_dist = self.strategy(task_id, num_train_samples)
# self.expert2task.append([task_id])
# return strategy, expert_id,min_dist,all_dist
# class ResNet_FE(nn.Module):
# """
# Create a feature extractor model from an Alexnet architecture, that is used to train the autoencoder model
# and get the most related model whilst training a new task in a sequence
# """
# def __init__(self, resnet_model):
# super(ResNet_FE, self).__init__()
# self.fe_model = nn.Sequential(*list(resnet_model.children())[:-1])
# self.fe_model.eval()
# self.fe_model.requires_grad_(False)
# def forward(self, x):
# return self.fe_model(x)
# def get_pretrained_feat_extractor(name):
# """get the feature extractor pretrained on ImageNet
# """
# if name == "resnet18":
# feat_extractor = ResNet_FE(models.resnet18(weights=True))
# # self.logger.info("Using relatedness feature extractor: ResNet18")
# else:
# raise Exception("Unknown relatedness feature extractor !!!")
# return feat_extractor
# def gaussian_mean_cov(X):
# """mean and covariance of Guassian distribution
# Params:
# X: N x D
# """
# device = X.device
# N, D = X.shape[0], X.shape[1]
# u = torch.mean(X, dim=0)
# u_row = torch.reshape(u, (1, -1)) # 1 x D
# cov = torch.matmul(X.T, X) - N * torch.matmul(u_row.T, u_row) # D x D
# cov = cov / (N - 1)
# cov = cov * torch.diag(torch.ones(D)).to(X.device) + (torch.diag(torch.ones(D))).to(X.device)
# return u, cov
================================================
FILE: examples/Structural_Development/SCA-SNN/inclearn/tools/utils.py
================================================
import random
from copy import deepcopy
import numpy as np
import datetime
import torch
from inclearn.tools.metrics import ClassErrorMeter
from sklearn.metrics import classification_report
def get_date():
    """Today's local date formatted as YYYYMMDD."""
    return "{:%Y%m%d}".format(datetime.datetime.now())
def to_onehot(targets, n_classes):
    """One-hot encode integer labels into an (N, n_classes) float tensor."""
    if not hasattr(targets, "device"):
        # numpy input: lift it to a torch tensor first.
        targets = torch.from_numpy(targets)
    encoded = torch.zeros(targets.shape[0], n_classes, device=targets.device)
    rows = targets.long().view(-1, 1)
    encoded.scatter_(dim=1, index=rows, value=1.0)
    return encoded
def get_class_loss(network, cur_n_cls, loader):
    """Per-class mean negative log-likelihood of `network` over `loader`.

    Requires CUDA (inputs are moved with .cuda()). Assumes network(x)
    returns a dict with a 'logit' entry — TODO confirm against the model.

    Returns:
        class_loss: (cur_n_cls,) tensor of average NLL per class.
            NOTE(review): classes absent from the loader divide by zero and
            yield nan/inf entries — confirm callers tolerate this.
    """
    class_loss = torch.zeros(cur_n_cls)
    n_cls_data = torch.zeros(cur_n_cls)  # the num of imgs for cls i.
    EPS = 1e-10  # guards log(0) for zero-probability predictions
    task_size = 10
    network.eval()
    for x, y in loader:
        x, y = x.cuda(), y.cuda()
        preds = network(x)['logit'].softmax(dim=1)
        # preds[:,-task_size:] = preds[:,-task_size:].softmax(dim=1)
        for i, lbl in enumerate(y):
            # Accumulate -log p(true class) sample by sample.
            class_loss[lbl] = class_loss[lbl] - (preds[i, lbl] + EPS).detach().log().cpu()
            n_cls_data[lbl] += 1
    class_loss = class_loss / n_cls_data
    return class_loss
def get_featnorm_grouped_by_class(task_i, network, cur_n_cls, loader, m=None):
    """
    Ret: feat_norms: list of list
        feat_norms[idx] is the list of feature norm of the images for class idx.

    (Actually returns a (cur_n_cls,) numpy array of the MEAN L2 feature norm
    per class; classes with no samples keep 0.)

    Requires CUDA. Assumes network(task_i, x, m) returns a dict with a
    'feature' entry — TODO confirm signature against the model.
    """
    feats = [[] for i in range(cur_n_cls)]
    feat_norms = np.zeros(cur_n_cls)
    network.eval()
    with torch.no_grad():
        for x, y in loader:
            x = x.cuda()
            feat = network(task_i, x, m)['feature'].cpu()
            for i, lbl in enumerate(y):
                # Ignore labels beyond the classes seen so far.
                if lbl >= cur_n_cls:
                    continue
                feats[lbl].append(feat[y == lbl])
    for i in range(len(feats)):
        if len(feats[i]) != 0:
            feat_cls = torch.cat((feats[i]))
            # Mean L2 norm over all collected samples of class i.
            feat_norms[i] = torch.norm(feat_cls, p=2, dim=1).mean().data.numpy()
    return feat_norms
def set_seed(seed):
    """Seed every RNG in play (python, numpy, torch CPU and CUDA) and
    switch cuDNN into deterministic mode."""
    print(f"Set seed {seed}")
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True  # This will slow down training.
    torch.backends.cudnn.benchmark = False
def display_weight_norm(logger, network, increments, tag):
    """Log the average L2 norm of classifier weight rows, grouped by task.

    Assumes network is wrapped (network.module) and that the last convnet
    exposes a `classifer` (sic) linear layer — TODO confirm structure.
    """
    weight_norms = [[] for _ in range(len(increments))]
    # Cumulative class counts give the class-index boundary of each task.
    increments = np.cumsum(np.array(increments))
    for idx in range(network.module.convnets[-1].classifer.weight.shape[0]):
        norm = torch.norm(network.module.convnets[-1].classifer.weight[idx].data, p=2).item()
        # Find the first task whose boundary exceeds this class index;
        # `i` deliberately leaks out of the loop below.
        for i in range(len(weight_norms)):
            if idx < increments[i]:
                break
        weight_norms[i].append(round(norm, 3))
    avg_weight_norm = []
    # all_weight_norms = []
    for idx in range(len(weight_norms)):
        # all_weight_norms += weight_norms[idx]
        # logger.info("task %s: Weight norm per class %s" % (str(idx), str(weight_norms[idx])))
        avg_weight_norm.append(round(np.array(weight_norms[idx]).mean(), 3))
    logger.info("%s: Weight norm per task %s" % (tag, str(avg_weight_norm)))
def display_feature_norm(task_i, logger, network, loader, n_classes, increments, tag, return_norm=False, mask=None):
    """Log the average feature norm per task; optionally return the averages.

    Uses get_featnorm_grouped_by_class to obtain per-class mean feature
    norms, then averages them within each task's class range.
    """
    avg_feat_norm_per_cls = get_featnorm_grouped_by_class(task_i, network, n_classes, loader, m=mask)
    feature_norms = [[] for _ in range(len(increments))]
    # Cumulative class counts mark each task's class-index boundary.
    increments = np.cumsum(np.array(increments))
    for idx in range(len(avg_feat_norm_per_cls)):
        for i in range(len(feature_norms)):
            if idx < increments[i]:  # Find the mapping from class idx to step i.
                break
        feature_norms[i].append(round(avg_feat_norm_per_cls[idx], 3))
    avg_feature_norm = []
    for idx in range(len(feature_norms)):
        avg_feature_norm.append(round(np.array(feature_norms[idx]).mean(), 3))
    # NOTE: message says "per class" but values are per task.
    logger.info("%s: Feature norm per class %s" % (tag, str(avg_feature_norm)))
    if return_norm:
        return avg_feature_norm
    else:
        return
def check_loss(loss):
    """Return True iff `loss` is a non-NaN, non-negative scalar tensor."""
    nan_free = not bool(torch.isnan(loss).item())
    non_negative = bool((loss >= 0.0).item())
    return nan_free and non_negative
def class2task(class_form, classnum):
    """Map class labels to task ids (task = class // 10) without mutating input.

    Works on a deep copy, so numpy arrays passed in are left untouched.
    Labels are first replaced by -(task)-1 (guaranteed distinct from any
    remaining raw label), then flipped back to non-negative task ids.
    """
    labels = deepcopy(class_form)
    for cls_id in range(classnum):
        selected = (labels == cls_id)
        labels[selected] = -(cls_id // 10) - 1
    return (labels + 1) * (-1)
def maskclass(pred, target, classnum, type='new'):
    # type == 'new': mask out the new classes; 'old': mask the old ones;
    # 'all': collapse both old and new groups. (Translated from Chinese.)
    # "New" classes are the last 10 (>= classnum-10); "old" are the rest.
    #
    # Returns (pred_form, target_form) for 'old'/'new', but a 3-tuple
    # (all_err, new_old_err, old_new_err) for 'all'.
    target_form = deepcopy(target)
    pred_form = deepcopy(pred)
    if type == 'old':
        # Collapse every position touching an OLD class to 0.
        mask = np.logical_or(pred_form<(classnum-10), target_form<(classnum-10))
        pred_form[mask] = 0
        target_form[mask] = 0
    if type == 'new':
        # Collapse every position touching a NEW class to the sentinel 1000.
        mask = np.logical_or(pred_form>=(classnum-10), target_form>=(classnum-10))
        pred_form[mask] = 1000
        target_form[mask] = 1000
    if type == 'all':
        # Binarize both arrays: old -> 0, new -> 1000.
        mask = (target_form>=(classnum-10))
        target_form[mask] = 1000
        mask = (pred_form>=(classnum-10))
        pred_form[mask] = 1000
        mask = (target_form<(classnum-10))
        target_form[mask] = 0
        mask = (pred_form<(classnum-10))
        pred_form[mask] = 0
        # Total old-vs-new group mismatches.
        all_err = np.sum(pred_form!=target_form)
        pred_form1 = deepcopy(pred_form)
        # Zero predictions where the target is old: remaining mismatches are
        # NEW samples predicted as OLD.
        mask = (target_form<(classnum-10))
        pred_form[mask] = 0
        new_old_err = np.sum(pred_form!=target_form)
        # Force predictions to 1000 where the target is new: remaining
        # mismatches are OLD samples predicted as NEW.
        mask = (target_form>=(classnum-10))
        pred_form1[mask] = 1000
        old_new_err = np.sum(pred_form1!=target_form)
        return all_err, new_old_err, old_new_err
    return pred_form, target_form
def compute_old_new_mix(ypred, ytrue, increments, n_classes, task_order):
    """Break classification error down into old/new-task confusion terms.

    ypred: (N, C) score matrix; ytrue: (N,) integer labels.
    Assumes tasks of 10 classes each. `increments` and `task_order` are
    unused here — kept for signature parity with compute_accuracy.
    """
    task_means = []
    # Mean score assigned to the TRUE class, per 10-class task block.
    for i in range (n_classes//10):
        taski_mask = np.logical_and(ytrue>=i*10, ytrue<(i+1)*10)
        task_i_mean = ypred[np.arange(ytrue.shape[0]), ytrue][taski_mask].mean().item()
        task_means.append(task_i_mean)
    task_mean = ypred[np.arange(ytrue.shape[0]), ytrue].mean().item()
    classnum = ypred.shape[1]
    ypred = ypred.argmax(1)
    all_err = np.sum(ypred!=ytrue)
    # Errors that cross task boundaries vs errors within the same task.
    ypred_task = class2task(ypred, classnum)
    ytrue_task = class2task(ytrue, classnum)
    err_among_task = np.sum(ypred_task!=ytrue_task)
    err_inner_task = all_err - err_among_task
    # print("all err : {}\n among task err: {}\n inner task err: {}\n".format(all_err, err_among_task, err_inner_task))
    # Errors restricted to new classes (old ones masked out) and vice versa.
    ypred_new, ytrue_new = maskclass(ypred, ytrue, n_classes, 'old')
    new_err = np.sum(ypred_new!=ytrue_new)
    ypred_old, ytrue_old = maskclass(ypred, ytrue, n_classes, 'new')
    old_err = np.sum(ypred_old!=ytrue_old)
    # Old/new group-level confusion (3-tuple return of maskclass 'all').
    all_err, new_old_err, old_new_err = maskclass(ypred, ytrue, n_classes, 'all')
    print("******all_err:****", all_err)
    all_acc = {"task_mean": task_mean, "task_means":task_means, "new_err": new_err, "old_err":old_err, "new_old_err": new_old_err, "old_new_err": old_new_err, "err_among_task": err_among_task, "err_inner_task": err_inner_task}
    return all_acc
def compute_task_accuracy(ypred, ytrue, increments, n_classes, task_order):
    """Per-class and per-task classification reports (sklearn).

    ypred: (N, C) score matrix; ytrue: (N,) labels. `increments`,
    `n_classes`, and `task_order` are unused — kept for signature parity.
    """
    # Mean score assigned to the true class, before argmax.
    task_mean = ypred[np.arange(ytrue.shape[0]), ytrue].mean().item()
    classnum = ypred.shape[1]
    ypred = ypred.argmax(1)
    # Project class labels to task ids (10 classes per task).
    ypred_task = class2task(ypred, classnum)
    ytrue_task = class2task(ytrue, classnum)
    all_acc = {"task_mean": task_mean, "class_info": classification_report(ytrue, ypred), "task_info": classification_report(ytrue_task, ypred_task)}
    return all_acc
def compute_accuracy(ypred, ytrue, increments, n_classes):
    """Top-1/top-5 accuracy overall and per class-increment range.

    ypred: (N, C) score matrix; ytrue: (N,) labels.
    Returns {"top1": {...}, "top5": {...}} keyed by "total" and by
    zero-padded class-range labels such as "00-09".
    """
    all_acc = {"top1": {}, "top5": {}}
    # Cap k at the number of classes actually present.
    topk = 5 if n_classes >= 5 else n_classes
    ncls = np.unique(ytrue).shape[0]
    if topk > ncls:
        topk = ncls
    all_acc_meter = ClassErrorMeter(topk=[1, topk], accuracy=True)
    all_acc_meter.add(ypred, ytrue)
    all_acc["top1"]["total"] = round(all_acc_meter.value()[0], 3)
    all_acc["top5"]["total"] = round(all_acc_meter.value()[1], 3)
    # all_acc["total"] = round((ypred == ytrue).sum() / len(ytrue), 3)
    # for class_id in range(0, np.max(ytrue), task_size):
    start, end = 0, 0
    for i in range(len(increments)):
        if increments[i] <= 0:
            # Empty increment: skip without advancing the class window.
            pass
        else:
            start = end
            end += increments[i]
            # Samples whose true class falls inside [start, end).
            idxes = np.where(np.logical_and(ytrue >= start, ytrue < end))[0]
            topk_ = 5 if increments[i] >= 5 else increments[i]
            ncls = np.unique(ytrue[idxes]).shape[0]
            if topk_ > ncls:
                topk_ = ncls
            cur_acc_meter = ClassErrorMeter(topk=[1, topk_], accuracy=True)
            cur_acc_meter.add(ypred[idxes], ytrue[idxes])
            top1_acc = (ypred[idxes].argmax(1) == ytrue[idxes]).sum() / idxes.shape[0] * 100
            # Zero-padded inclusive range label, e.g. "00-09".
            if start < end:
                label = "{}-{}".format(str(start).rjust(2, "0"), str(end - 1).rjust(2, "0"))
            else:
                label = "{}-{}".format(str(start).rjust(2, "0"), str(end).rjust(2, "0"))
            all_acc["top1"][label] = round(top1_acc, 3)
            all_acc["top5"][label] = round(cur_acc_meter.value()[1], 3)
            # all_acc[label] = round((ypred[idxes] == ytrue[idxes]).sum() / len(idxes), 3)
    return all_acc
def make_logger(log_name, savedir='.logs/'):
    """Set up the logger for saving log file on the disk

    Args:
        log_name: basename of the log file (``<log_name>.log``).
        savedir: folder the log file is written to; created if missing.
    Return:
        logger: the root logger, emitting to both stdout and the file.

    Side effects: truncates any existing log file and reconfigures the
    global root logger via dictConfig.
    """
    import logging
    import os
    from logging.config import dictConfig
    import time
    logging_config = dict(
        version=1,
        formatters={'f_t': {
            'format': '\n %(asctime)s | %(levelname)s | %(name)s \t %(message)s'
        }},
        handlers={
            'stream_handler': {
                'class': 'logging.StreamHandler',
                'formatter': 'f_t',
                'level': logging.INFO
            },
            'file_handler': {
                'class': 'logging.FileHandler',
                'formatter': 'f_t',
                'level': logging.INFO,
                # Placeholder: the real path is filled in below.
                'filename': None,
            }
        },
        root={
            'handlers': ['stream_handler', 'file_handler'],
            'level': logging.DEBUG,
        },
    )
    # set up logger
    log_file = '{}.log'.format(log_name)
    # if folder not exist,create it
    if not os.path.exists(savedir):
        os.makedirs(savedir)
    log_file_path = os.path.join(savedir, log_file)
    logging_config['handlers']['file_handler']['filename'] = log_file_path
    open(log_file_path, 'w').close()  # Clear the content of logfile
    # get logger from dictConfig
    dictConfig(logging_config)
    logger = logging.getLogger()
    return logger
================================================
FILE: examples/Structural_Development/SCA-SNN/main.py
================================================
import sys
import os
import os.path as osp
import copy
import time
import shutil
import cProfile
import logging
from pathlib import Path
import numpy as np
import random
from easydict import EasyDict as edict
from tensorboardX import SummaryWriter
import os
import inclearn.convnet.maskcl2 as Mask
# Script-level environment setup for the SCA-SNN experiment.
os.environ['CUDA_VISIBLE_DEVICES']='0'  # pin the run to GPU 0
repo_name = 'TCIL'
# NOTE(review): machine-specific absolute path; adjust per deployment.
base_dir = '/data1/hanbing/TCIL10/'
sys.path.insert(0, base_dir)  # make the inclearn package importable
from sacred import Experiment
ex = Experiment(base_dir=base_dir, save_git_info=False)
import torch
from inclearn.tools import factory, results_utils, utils
from inclearn.tools.metrics import IncConfusionMeter
from inclearn.tools.similar import Appr
def initialization(config, seed, mode, exp_id):
    """Prepare config, logger and tensorboard writer for one experiment run.

    Args:
        config: raw sacred config dict.
        seed: sacred-injected seed (cfg['seed'] is the one actually used).
        mode: run-mode tag appended to the log file name (e.g. "train").
        exp_id: sacred experiment id; mapped to -1 when running without an
            observer.

    Returns:
        (cfg, logger, tensorboard) tuple.
    """
    torch.backends.cudnn.benchmark = True  # This will result in non-deterministic results.
    cfg = edict(config)
    utils.set_seed(cfg['seed'])
    if exp_id is None:
        exp_id = -1
    cfg.exp.savedir = "./logs_aphal"
    logger = utils.make_logger(str(exp_id) + str(cfg.exp.name) + str(mode), savedir=cfg.exp.savedir)
    # Tensorboard.  BUG FIX: these two strings were missing the f-prefix, so
    # the literal text "{exp_name}" ended up in the tensorboard path.
    exp_name = f'{exp_id}_{cfg["exp"]["name"]}' if exp_id is not None else f'../inbox/{cfg["exp"]["name"]}'
    tensorboard_dir = cfg["exp"]["tensorboard_dir"] + f"/{exp_name}"
    tensorboard = SummaryWriter(tensorboard_dir)
    return cfg, logger, tensorboard
@ex.command
def train(_run, _rnd, _seed):
    # Sacred command entry point; _run/_rnd/_seed are injected by sacred.
    cfg, ex.logger, tensorboard = initialization(_run.config, _seed, "train", _run._id)
    ex.logger.info(cfg)
    # Datasets are expected under <base_dir>/data.
    cfg.data_folder = osp.join(base_dir, "data")
    start_time = time.time()
    _train(cfg, _run, ex, tensorboard)
    ex.logger.info("Training finished in {}s.".format(int(time.time() - start_time)))
def _train(cfg, _run, ex, tensorboard):
    """Run the full incremental-learning training loop over all tasks.

    Args:
        cfg: experiment config (edict).
        _run: sacred Run object (metrics/metadata sink).
        ex: sacred Experiment (used for its logger).
        tensorboard: SummaryWriter for scalar logging.
    """
    device = factory.set_device(cfg)
    trial_i = cfg['trial']
    torc = cfg['distillation']  # distillation flag: selects CIL vs per-task evaluation
    inc_dataset = factory.get_data(cfg, trial_i)
    ex.logger.info("classes_order")
    ex.logger.info(inc_dataset.class_order)
    model = factory.get_model(cfg, trial_i, _run, ex, tensorboard, inc_dataset)
    mask = Mask.Mask(model._network.convnets[-1])
    if _run.meta_info["options"]["--file_storage"] is not None:
        _save_dir = osp.join(_run.meta_info["options"]["--file_storage"], str(_run._id))
    else:
        _save_dir = cfg["exp"]["ckptdir"]
    results = results_utils.get_template_results(cfg)
    appr = Appr(model._network, 10, torc)
    for task_i in range(inc_dataset.n_tasks):
        task_info, train_loader, val_loader, test_loader = inc_dataset.new_task(task_i)
        model.set_task_info(
            task=task_info["task"],
            total_n_classes=task_info["max_class"],
            increment=task_info["increment"],
            n_train_data=task_info["n_train_data"],
            n_test_data=task_info["n_test_data"],
            n_tasks=inc_dataset.n_tasks,
        )
        # Task-similarity-based strategy/expert selection.
        if torc:
            strategy, expert_id, min_dist, all_dist, all_dist2 = appr.learn(task_i, val_loader, cfg['batch_size'], device)
        else:
            strategy, expert_id, min_dist, all_dist, all_dist2 = appr.learn(task_i, test_loader[task_i], cfg['batch_size'], device)
        print("Task:", task_i, strategy, expert_id, min_dist, all_dist, all_dist2)
        model.before_task(task_i, inc_dataset, mask, min_dist, all_dist)
        # TODO: Move to incmodel.py
        if 'min_class' in task_info:
            ex.logger.info("Train on {}->{}.".format(task_info["min_class"], task_info["max_class"]))
        if torc:
            model.train_task(task_i, train_loader, test_loader, mask, min_dist, all_dist)
            model.after_task(task_i, inc_dataset, mask)
            appr.after_learn(task_i, val_loader, cfg['batch_size'], device)
        else:
            model.train_task(task_i, train_loader, val_loader[task_i], mask, min_dist, all_dist)
            appr.after_learn(task_i, test_loader[task_i], cfg['batch_size'], device)
        if torc:
            # Class-incremental evaluation over all classes seen so far.
            ex.logger.info("Eval on {}->{}.".format(0, task_info["max_class"]))
            ypred, ytrue = model.eval_task(task_i, test_loader, mask)
            acc_stats = utils.compute_accuracy(ypred, ytrue, increments=model._increments, n_classes=model._n_classes)
            # Logging.  BUG FIX: these format strings were missing the
            # f-prefix (the else-branch below already used f-strings), so the
            # literal text "trial{trial_i}" was used as the metric name.
            model._tensorboard.add_scalar(f"taskaccu/trial{trial_i}", acc_stats["top1"]["total"], task_i)
            _run.log_scalar(f"trial{trial_i}_taskaccu", acc_stats["top1"]["total"], task_i)
            _run.log_scalar(f"trial{trial_i}_task_top5_accu", acc_stats["top5"]["total"], task_i)
            ex.logger.info("top1:" + str(acc_stats['top1']))
            ex.logger.info("top5:" + str(acc_stats['top5']))
            results["results"].append(acc_stats)
        else:
            # Task-incremental evaluation: each seen task separately.
            for taski in range(task_i + 1):
                ypred, ytrue = model.eval_task(taski, test_loader[taski], mask)
                acc_stats = utils.compute_accuracy(ypred, ytrue, increments=[1], n_classes=model._n_classes)
                model._tensorboard.add_scalar(f"taskaccu/trial{trial_i}", acc_stats["top1"]["total"], taski)
                _run.log_scalar(f"trial{trial_i}_taskaccu", acc_stats["top1"]["total"], taski)
                _run.log_scalar(f"trial{trial_i}_task_top5_accu", acc_stats["top5"]["total"], taski)
                ex.logger.info(f"top1:{acc_stats['top1']}")
                ex.logger.info(f"top5:{acc_stats['top5']}")
                results["results"].append(acc_stats)
    top1_avg_acc, top5_avg_acc = results_utils.compute_avg_inc_acc(results["results"])
    # BUG FIX: these keys were the plain strings "trial{trial_i}".
    # NOTE(review): assumes _run.info[f"trial{trial_i}"] was created earlier
    # (e.g. by the model factory) — confirm before relying on it.
    _run.info[f"trial{trial_i}"]["avg_incremental_accu_top1"] = top1_avg_acc
    _run.info[f"trial{trial_i}"]["avg_incremental_accu_top5"] = top5_avg_acc
    ex.logger.info("Average Incremental Accuracy Top 1: {} Top 5: {}.".format(
        _run.info[f"trial{trial_i}"]["avg_incremental_accu_top1"],
        _run.info[f"trial{trial_i}"]["avg_incremental_accu_top5"],
    ))
    if cfg["exp"]["name"]:
        results_utils.save_results(results, cfg["exp"]["name"])
if __name__ == "__main__":
    # NOTE(review): hard-coded absolute config path; adjust per machine.
    ex.add_config("/data1/hanbing/SCA-SNN/configs/train.yaml")
    ex.run_commandline()
================================================
FILE: examples/Structural_Development/SD-SNN/README.md
================================================
# Adaptive Sparse Structure Development with Pruning and Regeneration for Spiking Neural Networks #
## Requirements ##
* numpy
* timm
* pytorch >= 1.7.0
* collections
* argparse
## Run ##
```CUDA_VISIBLE_DEVICES=0 python main.py```
## Citation ##
If you find the code and dataset useful in your research, please consider citing:
```
@article{han2025adaptive,
title={Adaptive sparse structure development with pruning and regeneration for spiking neural networks},
author={Han, Bing and Zhao, Feifei and Pan, Wenxuan and Zeng, Yi},
journal={Information Sciences},
volume={689},
pages={121481},
year={2025},
publisher={Elsevier}
}
@article{zeng2023braincog,
title={Braincog: A spiking neural network based, brain-inspired cognitive intelligence engine for brain-inspired ai and brain simulation},
author={Zeng, Yi and Zhao, Dongcheng and Zhao, Feifei and Shen, Guobin and Dong, Yiting and Lu, Enmeng and Zhang, Qian and Sun, Yinqian and Liang, Qian and Zhao, Yuxuan and others},
journal={Patterns},
volume={4},
number={8},
year={2023},
publisher={Elsevier},
}
```
Enjoy!
================================================
FILE: examples/Structural_Development/SD-SNN/main.py
================================================
import argparse
import time
import os
from collections import OrderedDict
from contextlib import suppress
from datetime import datetime
import sys
sys.path.append('..')
import logging
import torch
from timm.data import ImageDataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
from timm.models import load_checkpoint, create_model, resume_checkpoint, convert_splitbn_model
from timm.utils import *
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
from timm.optim import create_optimizer
from timm.scheduler import create_scheduler
from timm.utils import ApexScaler, NativeScaler
from braincog.base.node.node import *
from braincog.base.encoder.encoder import *
from braincog.model_zoo.base_module import BaseModule, BaseConvModule, BaseLinearModule
from braincog.base.utils.criterions import *
from braincog.datasets.datasets import *
from prun_and_generation import *
from snn_model import *
from utils import *
# ---- script-level experiment configuration and model construction ----
_logger = logging.getLogger('train')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Output directory name: timestamp plus dataset tag.
exp_name = '-'.join([datetime.now().strftime("%Y%m%d-%H%M%S"), 'c10'])
output_dir = get_outdir('./', 'train', exp_name)
setup_default_logging(log_path=os.path.join(output_dir, 'log.txt'))
_logger.info(exp_name)
# NOTE(review): the ArgumentParser is used purely as an attribute bag for
# timm's create_optimizer/create_scheduler below; no CLI args are parsed.
config_parser = cfg = argparse.ArgumentParser(description='Training Config', add_help=False)
model = 'cifar_convnet'  # NOTE(review): rebound to the network instance below
dataset = 'cifar10'
num_classes = 10
step = 8  # SNN simulation time steps
encode = 'direct'
node_type = 'PLIFNode'
thresh = 0.5
tau = 2.0
torch.backends.cudnn.benchmark = True
devicee = 0  # GPU index
seed = 36
channels = 2
lr = 5e-3
batch_size = 50
epochs = 600
# Linear LR scaling rule relative to a base batch size of 1024.
linear_scaled_lr = lr * batch_size / 1024.0
cfg.opt = 'adamw'
cfg.lr = linear_scaled_lr
cfg.weight_decay = 0.01
cfg.momentum = 0.9
cfg.epochs = epochs
cfg.sched = 'cosine'
cfg.min_lr = 1e-5
cfg.warmup_lr = 1e-6
cfg.warmup_epochs = 5
cfg.cooldown_epochs = 10
cfg.decay_rate = 0.1
epoch_prune = 1
eval_metric = 'top1'
best_test = 0
best_testepoch = 0
best_testprun = 0
best_testepochprun = 0
spines_num = 18  # epoch offset after which pruning/regeneration kicks in
torch.cuda.set_device('cuda:%d' % devicee)
torch.manual_seed(seed)
model = my_cifar_model(step=step, encode_type=encode, node_type=node_type, num_classes=num_classes)
model = model.cuda()
print(model)
optimizer = create_optimizer(cfg, model)
lr_scheduler, num_epochs = create_scheduler(cfg, optimizer)
# Resolves to e.g. get_cifar10_data(...) from braincog.datasets.datasets.
loader_train, loader_eval, mixup_active, mixup_fn = eval('get_%s_data' % dataset)(batch_size=batch_size, step=step)
train_loss_fn = UnilateralMse(1.)
validate_loss_fn = UnilateralMse(1.)
m = Mask(model, spines_num)  # pruning/regeneration controller
def train_epoch(
        epoch, model, loader, optimizer, loss_fn,
        lr_scheduler=None, saver=None, output_dir='', amp_autocast=suppress,
        loss_scaler=None, model_ema=None, mixup_fn=None):
    """Train *model* for one epoch and return the average loss.

    Args:
        epoch: current epoch index (used to derive the scheduler step count).
        model, loader, optimizer, loss_fn: standard training components.
        lr_scheduler: optional scheduler stepped per update.
        saver, output_dir, amp_autocast, loss_scaler, model_ema, mixup_fn:
            accepted for timm-style interface parity; unused in this body.

    Returns:
        OrderedDict with the epoch-average 'loss'.
    """
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    losses_m = AverageMeter()
    top1_m = AverageMeter()
    model.train()
    end = time.time()
    last_idx = len(loader) - 1
    # NOTE(review): num_updates is never incremented inside the loop below,
    # so step_update always receives the epoch-start count — confirm intent.
    num_updates = epoch * len(loader)
    for batch_idx, (inputs, target) in enumerate(loader):
        last_batch = batch_idx == last_idx
        data_time_m.update(time.time() - end)
        inputs, target = inputs.type(torch.FloatTensor).cuda(), target.cuda()
        output = model(inputs)
        loss = loss_fn(output, target)
        acc1, acc5 = accuracy(output, target, topk=(1, 5))  # acc5 unused here
        losses_m.update(loss.item(), inputs.size(0))
        top1_m.update(acc1.item(), inputs.size(0))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_time_m.update(time.time() - end)
        if last_batch or batch_idx %100 == 0:
            # Average LR across parameter groups, for logging only.
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)
            print("Train: epoch:",epoch,batch_idx,"/",len(loader),"loss:",losses_m.avg,"acc1:", top1_m.avg,"lr:",lr)
        if lr_scheduler is not None:
            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
        end = time.time()
        # end for
    if hasattr(optimizer, 'sync_lookahead'):
        optimizer.sync_lookahead()
    return OrderedDict([('loss', losses_m.avg)])
def validate(model, loader, loss_fn, amp_autocast=suppress, log_suffix=''):
    """Evaluate *model* on *loader* and return running-average metrics.

    Args:
        model: network to evaluate (switched to eval mode).
        loader: validation data loader.
        loss_fn: criterion computing the validation loss.
        amp_autocast: autocast context (unused; kept for interface parity).
        log_suffix: log label suffix (unused; kept for interface parity).

    Returns:
        OrderedDict with 'loss', 'top1' and 'top5' averages.
    """
    batch_time_m = AverageMeter()
    losses_m = AverageMeter()
    top1_m = AverageMeter()
    top5_m = AverageMeter()
    model.eval()
    end = time.time()
    last_idx = len(loader) - 1
    with torch.no_grad():
        for batch_idx, (inputs, target) in enumerate(loader):
            last_batch = batch_idx == last_idx
            inputs = inputs.type(torch.FloatTensor).cuda()
            target = target.cuda()
            output = model(inputs)
            if isinstance(output, (tuple, list)):
                output = output[0]
            loss = loss_fn(output, target)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            reduced_loss = loss.data
            torch.cuda.synchronize()
            losses_m.update(reduced_loss.item(), inputs.size(0))
            top1_m.update(acc1.item(), output.size(0))
            # BUG FIX: top5_m was never updated, so the returned 'top5'
            # metric was always 0 despite acc5 being computed.
            top5_m.update(acc5.item(), output.size(0))
            if last_batch or batch_idx % 100 == 0:
                print("Test: loss:", losses_m.avg, "acc1:", top1_m.avg)
            batch_time_m.update(time.time() - end)
            end = time.time()
    metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])
    return metrics
# ---- main training loop with structural pruning and regeneration ----
for epoch in range(epochs):
    train_metrics = train_epoch(
        epoch, model, loader_train, optimizer, train_loss_fn,
        lr_scheduler=lr_scheduler)
    if epoch == 0:
        m.init_length()  # record initial layer sizes once
    if epoch > 0:
        m.model = model
        m.init_mask_dsd()
    if epoch > spines_num:
        matt = m.do_mask_dsd()  # clamp weights into their dendritic ranges
    if epoch > 2 * spines_num:
        m.do_growth_ww(epoch)   # regenerate strong previously-pruned weights
        matt = m.do_pruning_dsd(epoch)
    model = m.model
    cc = m.if_zero()  # per-layer zero-weight (sparsity) counts
    eval_metrics = validate(model, loader_eval, validate_loss_fn)
    top1 = eval_metrics['top1']
    if top1 > best_testprun:
        best_testprun = top1
        best_testepochprun = epoch
    if epoch % 40 == 0:
        print('best acc:', best_testprun, 'best epoch:', best_testepochprun)
    if epoch > 4:
        _logger.info('*** epoch: {0} (pruning rate {1},acc:{2})'.format(epoch, cc, top1))
    if lr_scheduler is not None:
        lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
================================================
FILE: examples/Structural_Development/SD-SNN/prun_and_generation.py
================================================
import numpy as np
import torch
import math
from utils import *
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class Mask:
    """Structural pruning / regeneration bookkeeping for a model's layers.

    NOTE(review): this file appears damaged by extraction — comparison
    operators and whole line fragments are missing (e.g. "if indexr+",
    "if index60:", "if index99:") and at least one "def" header seems to
    have been lost.  The code is kept byte-for-byte below; verify against
    the upstream repository before changing any logic.
    """
    def __init__(self, model,count_thre):
        # Per-layer geometry and pruning state, keyed by layer index.
        self.model_size = {}
        self.model_length = {}
        self.compress_rate = {}
        self.mat = {}
        self.model = model
        self.mask_index = []
        self.distance_rate = {}
        self.filter_small_index = {}
        self.filter_large_index = {}
        self.similar_matrix = {}
        self.norm_matrix = {}
        # dendritic dynamics
        self.cur_range_pos = {} # current positive range for every weight
        self.cur_range_neg = {} # current negative range for every weight
        self.dsum_pos_out = {} # overshoot accumulator above the positive range
        self.dsum_neg_out = {} # overshoot accumulator below the negative range
        self.dsum_pos_in = {} # in-range accumulator (appears unused in visible code)
        self.dsum_neg_in = {} # in-range accumulator
        self.dendritic_previous_pos = {} # the index for beyond-range weight
        self.dendritic_previous_neg = {} # the index for within-range weight
        self.dendritic_previous_in = {} # the index for within-range weight
        self.dendritic_count_pos = {} # the count of beyond range for every weight
        self.dendritic_count_neg = {} # the count of within range for every weight
        self.dendritic_count_in = {} # the count of within range for every weight
        self.out = 0
        self.count_thre=count_thre  # consecutive-epoch threshold for range updates
        self.weight_previous = {}
        self.mask={}
        self.codebook = {}
        self.codebookww={}
        self.his_ind={}
        self.his_groth={}
        self.his_gro_count={}
        self.prune={}
        self.pruncc={}
        self.prunn={}
        self.his_prun={}
        self.convlayer =model.convlayer  # indices of prunable layers
        for index in self.convlayer:
            # NOTE(review): the loop body was truncated by extraction;
            # "if indexr+" is a fragment of a lost line.
            if indexr+
    # NOTE(review): a method header appears to be missing here — the body
    # below references i, weight_torch, length and count_thre that are never
    # bound in this scope (likely a per-layer dendritic-dynamics method such
    # as "def dendritic_dynamics(self, i, weight_torch, length, count_thre):").
        weight_vec = weight_torch.view(-1) # one conv one weight vector
        weight_vec_np = weight_vec.cpu().numpy()
        weight_np_abs = abs(weight_vec_np)
        dendritic_pos_tmp = np.where((weight_vec_np >= self.cur_range_pos[i])) # find the weight beyond range
        pos_index = set(dendritic_pos_tmp[0]) & set(self.dendritic_previous_pos[i]) # intersection with previous epoch
        pos_zero = set([i for i in range(length)]) - pos_index # non-intersection weights restart their count at 0
        pos_index = np.array(list(pos_index))
        pos_zero = np.array(list(pos_zero))
        if pos_zero.size > 0:
            self.dendritic_count_pos[i][pos_zero] = 0
            self.dsum_pos_out[i][pos_zero] = 0
        if pos_index.size > 0:
            self.dendritic_count_pos[i][pos_index] = self.dendritic_count_pos[i][pos_index] + 1 # intersection +1
            self.dsum_pos_out[i][pos_index] += weight_vec_np[pos_index] - self.cur_range_pos[i][pos_index]
        dendritic_index = np.where(self.dendritic_count_pos[i] >= count_thre) # count>threshold
        self.out = self.out + len(dendritic_index[0])
        self.cur_range_pos[i][dendritic_index]+=(self.dsum_pos_out[i][dendritic_index] / count_thre) # self.cur_range_pos[i][dendritic_index]+0.025 #
        self.dendritic_count_pos[i][dendritic_index] = 0 # intersection count set to 0
        self.dsum_pos_out[i][dendritic_index] = 0
        self.dendritic_previous_pos[i] = dendritic_pos_tmp[0] # update previous
        # NOTE(review): the lines computing dendritic_neg_tmp / neg_index /
        # neg_zero (the negative-range mirror of the block above) were lost
        # in extraction; "# 0:" below is a fragment.
        # 0:
        self.dendritic_count_neg[i][neg_zero] = 0
        self.dsum_neg_out[i][neg_zero] = 0
        if neg_index.size > 0:
            self.dendritic_count_neg[i][neg_index] = self.dendritic_count_neg[i][neg_index] + 1 # intersection +1
            self.dsum_neg_out[i][neg_index] += weight_vec_np[neg_index] - self.cur_range_neg[i][neg_index]
        dendritic_index_neg = np.where(self.dendritic_count_neg[i] >= count_thre) # count>threshold
        self.out = self.out + len(dendritic_index_neg[0])
        self.cur_range_neg[i][dendritic_index_neg] += (self.dsum_neg_out[i][dendritic_index_neg] / count_thre) #self.cur_range_neg[i][dendritic_index_neg]-0.025 #
        self.dendritic_count_neg[i][dendritic_index_neg] = 0 # intersection count set to 0
        self.dsum_neg_out[i][dendritic_index_neg] = 0
        self.dendritic_previous_neg[i] = dendritic_neg_tmp[0] # update previous
        # r-~r+  (weights that stay inside the range shrink it)
        dendritic_in_tmp = np.where((weight_np_abs < self.weight_previous[i])) # find the weight within range
        in_index = set(dendritic_in_tmp[0]) & set(self.dendritic_previous_in[i]) # intersection with previous epoch
        in_zero = set([i for i in range(length)]) - in_index # non-intersection weights restart their count at 0
        in_index = np.array(list(in_index))
        in_zero = np.array(list(in_zero))
        if in_zero.size > 0:
            self.dendritic_count_in[i][in_zero] = 0
            self.dsum_neg_in[i][in_zero] = 0
        if in_index.size > 0:
            self.dendritic_count_in[i][in_index] = self.dendritic_count_in[i][in_index] + 1 # intersection +1
            self.dsum_neg_in[i][in_index] += weight_np_abs[in_index] - self.weight_previous[i][in_index]
        dendritic_index_in = np.where(self.dendritic_count_in[i] >= count_thre) # count>threshold
        self.cur_range_pos[i][dendritic_index_in] = 0.75*self.cur_range_pos[i][dendritic_index_in] # self.cur_range_pos[i][dendritic_index_in]-0.025
        self.cur_range_neg[i][dendritic_index_in] = 0.75*self.cur_range_neg[i][dendritic_index_in] # self.cur_range_neg[i][dendritic_index_in]+0.025
        self.dendritic_count_in[i][dendritic_index_in] = 0 # intersection count set to 0
        self.dsum_neg_in[i][dendritic_index_in] = 0
        self.dendritic_previous_in[i] = dendritic_in_tmp[0] # update previous
        self.weight_previous[i] = weight_np_abs
        print('dendritic dynamics done', np.mean(self.cur_range_pos[i]), np.mean(self.cur_range_neg[i][0]))
        return self.cur_range_pos[i], self.cur_range_neg[i]
    def do_mask_dsd(self):
        # Clamp each layer's weights into its per-weight dendritic range.
        for index in self.convlayer:
            # NOTE(review): corrupted — the text between "if index" and the
            # subscripted assignment (likely the ww/a bindings and an index
            # bound check) was lost in extraction.
            if index self.cur_range_pos[index]] = self.cur_range_pos[index][
                a > self.cur_range_pos[index]] # weight beyond range set to range
            a[a < self.cur_range_neg[index]] = self.cur_range_neg[index][a < self.cur_range_neg[index]]
            a = torch.FloatTensor(a).cuda()
            ww.data = a.view(self.model_size[index])
        print("mask Done")
    def do_pruning_dsd(self,epoch):
        # Raise each layer's pruning rate based on how many units of the
        # last layer's codebook are still densely connected.
        for index in self.convlayer:
            # NOTE(review): corrupted — "if index60:" lost its comparison
            # operator (and possibly surrounding lines).
            if index60:
                aphla=0.0005
            sumbook=torch.sum(self.codebook[self.convlayer[-1]],dim=1)
            prun=torch.where(sumbook<50)[0]  # units with fewer than 50 live weights
            prunn=prun.size()[0]
            self.prune[index]=self.prune[index]+aphla*(512-prunn)/10
            print(prunn,self.prune[index])
    def do_growth_ww(self,epoch):
        # Regrow pruned weights whose magnitude stays above a rising
        # percentile threshold for count_thre consecutive epochs.
        for index in self.convlayer:
            # NOTE(review): corrupted — "if index99:" has lost text; judging
            # by the parallel FC branch below, this likely computed
            # "rate=60+1.1**(epoch-2*self.count_thre-1)" capped at 99, plus
            # the ww / p_index bindings that are also missing here.
            if index99:
                rate=99
            gg=np.percentile(ww, rate)
            grow=np.where(ww>gg)[0]
            growth_ind=set(grow) & set(p_index)
            growth_index=growth_ind & set(self.his_groth[index])
            zero_index=set([i for i in range(ww.size)]) - growth_index
            growth_index=np.array(list(growth_index))
            zero_index=np.array(list(zero_index))
            if zero_index.size>0:
                self.his_gro_count[index][zero_index]=0
            if growth_index.size>0:
                self.his_gro_count[index][growth_index]=self.his_gro_count[index][growth_index]+1
            gr_index=np.where(self.his_gro_count[index]> self.count_thre)[0]
            self.codebook[index]=self.codebook[index].view(-1)
            for x in range(len(gr_index)):
                # Re-enable the whole 3x3 kernel (9 weights) per grown filter.
                self.codebook[index][gr_index[x]*9:(gr_index[x]+1)*9]=1
            self.codebook[index]=self.codebook[index].view(self.model_size[index])
            print(len(gr_index),len(growth_ind),len(p_index))
            self.his_groth[index]=growth_ind
            self.his_gro_count[index][gr_index]=0
            if index==self.convlayer[-1]:
                # FC branch: grow individual weights instead of 3x3 kernels.
                # NOTE(review): self.fc1 is not assigned in __init__ — confirm
                # where it is set (possibly in the damaged code above).
                ww=self.fc1.fc.weight
                ww=ww.data
                ww=ww.view(-1).cpu().numpy()
                p_index=np.where(self.codebook[index].view(-1).cpu().numpy()==0)[0]
                rate=60+1.1**(epoch- 2*self.count_thre-1)
                if rate>99:
                    rate=99
                gg=np.percentile(ww, rate)
                grow=np.where(ww>gg)[0]
                growth_ind=set(grow) & set(p_index)
                growth_index=growth_ind & set(self.his_groth[index])
                zero_index=set([i for i in range(ww.size)]) - growth_index
                growth_index=np.array(list(growth_index))
                zero_index=np.array(list(zero_index))
                if zero_index.size>0:
                    self.his_gro_count[index][zero_index]=0
                if growth_index.size>0:
                    self.his_gro_count[index][growth_index]=self.his_gro_count[index][growth_index]+1
                gr_index=np.where(self.his_gro_count[index]> self.count_thre)[0]
                self.codebook[index]=self.codebook[index].view(-1)
                self.codebook[index][gr_index]=1
                self.codebook[index]=self.codebook[index].view(self.model_size[index][0],-1)
                print(len(gr_index),len(growth_ind),len(p_index))
                self.his_groth[index]=growth_ind
                self.his_gro_count[index][gr_index]=0
    def if_zero(self):
        # Report and return the number of zeroed weights per layer.
        cc=[]
        for index in self.convlayer:
            # NOTE(review): corrupted — "if index 1:" lost its operator; the
            # ww binding for this layer is also missing.
            if index 1:
                a = ww.data.view(self.model_length[index])
                b = a.cpu().numpy()
                print(
                    "number of nonzero weight is %d, zero is %d" % (np.count_nonzero(b), len(b) - np.count_nonzero(b)))
                cc.append(len(b) - np.count_nonzero(b))
        return cc
================================================
FILE: examples/Structural_Development/SD-SNN/snn_model.py
================================================
import abc
from functools import partial
from torch.nn import functional as F
import torchvision
from timm.models import register_model
from braincog.base.node.node import *
from braincog.base.encoder.encoder import *
from braincog.model_zoo.base_module import BaseModule, BaseConvModule, BaseLinearModule
from utils import *
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class my_cifar_model(BaseModule):
    """Spiking CNN for CIFAR used by the SD-SNN pruning experiments.

    Six BaseConvModule layers with two max-pools, a prunable 512-unit FC
    layer and a final classifier; logits are averaged over the simulation
    time steps.
    """
    def __init__(self,
                 num_classes=10,
                 step=8,
                 node_type=LIFNode,
                 encode_type='direct',
                 *args,
                 **kwargs):
        super().__init__(step, encode_type, *args, **kwargs)
        self.num_classes = num_classes
        self.feature = nn.Sequential(
            BaseConvModule(3, 128, kernel_size=(3, 3), padding=(1, 1)),
            BaseConvModule(128,128, kernel_size=(3, 3), padding=(1, 1)),
            nn.MaxPool2d(2),
            BaseConvModule(128,256, kernel_size=(3, 3), padding=(1, 1)),
            BaseConvModule(256, 256, kernel_size=(3, 3), padding=(1, 1)),
            nn.MaxPool2d(2),
            BaseConvModule(256, 512, kernel_size=(3, 3), padding=(1, 1)),
            BaseConvModule(512, 512, kernel_size=(3, 3), padding=(1, 1)),
        )
        # Indices of prunable layers.  NOTE(review): index 8 is one past the
        # last element of self.feature — presumably it addresses the FC layer
        # in the pruning code; confirm against prun_and_generation.py.
        self.convlayer = [0,1,3,4,6,7,8]
        self.cfla = self._cflatten()
        self.fc_prun = self._create_fc_prun()
        self.fc = self._create_fc()
    def _cflatten(self):
        # Flatten conv features into a vector.
        fc = nn.Sequential(
            nn.Flatten(),
        )
        return fc
    def _create_fc_prun(self):
        # Prunable hidden FC layer (512*8*8 -> 512).
        fc = nn.Sequential(
            BaseLinearModule(512*8*8, 512)
        )
        return fc
    def _create_fc(self):
        # Output classifier (512 -> num_classes).
        fc = nn.Sequential(
            BaseLinearModule(512, self.num_classes)
        )
        return fc
    def forward(self, inputs):
        # Encode the input into a spike train of length self.step, run the
        # network once per time step, and average the logits over time.
        inputs = self.encoder(inputs)
        self.reset()
        if not self.training:
            self.fire_rate.clear()
        outputs = []
        for t in range(self.step):
            x = inputs[t]
            if x.shape[-1] > 32:
                # Upsample larger-than-CIFAR inputs to 64x64.
                x = F.interpolate(x, size=[64, 64])
            for i in range(len(self.feature)):
                x = self.feature[i](x)
            x = self.cfla(x)
            x = self.fc_prun(x)
            x = self.fc(x)
            outputs.append(x)
        return sum(outputs) / len(outputs)
================================================
FILE: examples/Structural_Development/SD-SNN/utils.py
================================================
import torch
import numpy as np
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def unit(x):
    """Min-max normalise a numpy array into [0, 1].

    A constant array maps to 0.5 everywhere; a 0-d input is returned
    unchanged.
    """
    if len(x.shape) == 0:
        return x
    lo = np.min(x)
    hi = np.max(x)
    span = hi - lo
    if span != 0:
        return np.clip((x - lo) / span, 0, 1)
    return 0.5 * np.ones_like(x)
def unit_tensor(x):
    """Min-max normalise a non-empty tensor into [0, 1].

    A constant tensor maps to 0.5 everywhere; a tensor whose first
    dimension is 0 is returned unchanged.  Unlike ``unit`` this variant
    does not clip the result.
    """
    if x.size()[0] == 0:
        return x
    lo = torch.min(x)
    hi = torch.max(x)
    span = hi - lo
    if span != 0:
        return (x - lo) / span
    return 0.5 * torch.ones_like(x)
================================================
FILE: examples/Structure_Evolution/Adaptive_lsm/BrainCog-Version/README.md
================================================
# Adaptive structure evolution and biologically plausible synaptic plasticity for recurrent spiking neural networks —— Based on BrainCog #
**This is the BrainCog-based version of the paper. To run the original code, please go to [raw](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Structure_Evolution/Adaptive_lsm/raw)**
## Requirements ##
* numpy
* pytorch >= 1.12.0
* BrainCog
## Run ##
```python main.py```
## Citation ##
If you find the code and dataset useful in your research, please consider citing:
```
@article{pan2023adaptive,
title = {Adaptive structure evolution and biologically plausible synaptic plasticity for recurrent spiking neural networks},
author = {Pan, Wenxuan and Zhao, Feifei and Zeng, Yi and Han, Bing},
journal = {Scientific Reports},
volume = {13},
number = {1},
pages = {16924},
year = {2023},
url = {https://doi.org/10.1038/s41598-023-43488-x},
doi = {10.1038/s41598-023-43488-x},
}
@article{zeng2023braincog,
title={BrainCog: A spiking neural network based, brain-inspired cognitive intelligence engine for brain-inspired AI and brain simulation},
author={Zeng, Yi and Zhao, Dongcheng and Zhao, Feifei and Shen, Guobin and Dong, Yiting and Lu, Enmeng and Zhang, Qian and Sun, Yinqian and Liang, Qian and Zhao, Yuxuan and others},
journal={Patterns},
volume={4},
number={8},
year={2023},
publisher={Elsevier}
}
```
================================================
FILE: examples/Structure_Evolution/Adaptive_lsm/BrainCog-Version/brid.py
================================================
import torch, os
import pygame
from pygame.locals import *
from collections import deque
from random import randint
import numpy as np
import nsganet as engine
from pymop.problem import Problem
from pymoo.optimize import minimize
from pymoo.operators.sampling.random_sampling import RandomSampling
from pymoo.operators.mutation.bitflip_mutation import BinaryBitflipMutation
from lsmmodel import SNN
import torch.nn.functional as F
from tools.update_weights import stdp,bcm
import matplotlib.pyplot as plt
os.environ["SDL_VIDEODRIVER"] = "dummy"  # headless SDL: run pygame without a display
steps = 30000  # total simulation time steps
seeds = 50  # number of independent runs
result = np.zeros([seeds, int(steps / 2)])  # per-run result buffer
t = [i for i in range(steps)]  # time axis (e.g. for plotting)
def randbool(size, p):
    """Return a boolean tensor of shape *size* whose entries are True with
    probability *p* (Bernoulli sampling via a uniform threshold)."""
    noise = torch.rand(*size)
    return noise < p
class Evolve(Problem):
    """NAS problem (pymop Problem subclass) scored by matrix rank of each genome."""
    # first define the NAS problem (inherit from pymop)
    def __init__(self, n_var=20, n_obj=1, n_constr=0, lb=None, ub=None):
        super().__init__(n_var=n_var, n_obj=n_obj, n_constr=n_constr, type_var=np.int64)
        self.xl = lb  # per-variable lower bounds
        self.xu = ub  # per-variable upper bounds
        self._n_evaluated = 0 # keep track of how many architectures are sampled
    def _evaluate(self, x, out, *args, **kwargs):
        # NOTE(review): objs is allocated with self.n_obj columns but
        # objs[i, 1] is always written — this indexes out of bounds when
        # n_obj == 1 (the constructor default).  Confirm the intended n_obj
        # before using this class.
        objs = np.full((x.shape[0], self.n_obj), np.nan)
        for i in range(x.shape[0]):
            arch_id = self._n_evaluated + 1
            print('Network= {}'.format(arch_id))
            # Objective 0: rank of the genome's connection matrix.
            objs[i, 0] = np.linalg.matrix_rank(x[i])
            objs[i, 1] = 0
            self._n_evaluated += 1
        out["F"] = objs
# if your NAS problem has constraints, use the following line to set constraints
# out["G"] = np.column_stack([g1, g2, g3, g4, g5, g6]) in case 6 constraints
# ---------------------------------------------------------------------------------------------------------
# Define what statistics to print or save for each generation
# ---------------------------------------------------------------------------------------------------------
def do_every_generations(algorithm):
    """Per-generation callback: print population statistics.

    Invoked once per generation with access to the whole algorithm state.
    """
    gen = algorithm.n_gen
    pop_var = algorithm.pop.get("X")
    pop_obj = algorithm.pop.get("F")
    errors = pop_obj[:, 0]
    # report generation info to files
    print("generation = {}".format(gen))
    print("population error: best = {}, mean = {}, "
          "median = {}, worst = {}".format(np.min(errors), np.mean(errors),
                                           np.median(errors), np.max(errors)))
    print('Best Genome id= {}'.format(np.argmin(errors)))
def load_images():
    """Load the Flappy Bird sprite images.

    Returns:
        dict mapping sprite names to loaded pygame images.
    """
    def load_image(img_file_name):
        # NOTE(review): hard-coded absolute path; adjust per machine.
        file_name = os.path.join('/home/panwenxuan/raw', 'birdimages', img_file_name)
        img = pygame.image.load(file_name)
        # converting all images before use speeds up blitting
        img.convert()
        return img
    return {'background': load_image('background.png'),
            'pipe-end': load_image('pipe_end.png'),
            'pipe-body': load_image('pipe_body.png'),
            # images for animating the flapping bird -- animated GIFs are
            # not supported in pygame
            'bird-wingup': load_image('bird_wing_up.png'),
            'bird-wingdown': load_image('bird_wing_down.png'), }
class Bird(pygame.sprite.Sprite):
    """The Flappy Bird player sprite."""
    WIDTH = HEIGHT = 32          # sprite size in pixels
    SINK_SPEED = 0.2             # normal fall speed factor
    Fail_SINk_SPEED = 0.6        # faster fall speed used after failure
    CLIMB_SPEED = 0.25           # climb speed factor while flapping
    CLIMB_DURATION = 333.3       # climb duration per flap (ms)
    REGION = CLIMB_DURATION / 3
    NEAR_COLLIDE = 30            # margin (px) treated as "close to a pipe edge"
    NEAR_PIPE = 0
    def __init__(self, x, y, msec_to_climb, images):
        super(Bird, self).__init__()
        self.x, self.y = x, y
        self.msec_to_climb = msec_to_climb  # remaining climb time (ms)
        self._img_wingup, self._img_wingdown = images
        # Pre-computed collision masks for both animation frames.
        self._mask_wingup = pygame.mask.from_surface(self._img_wingup)
        self._mask_wingdown = pygame.mask.from_surface(self._img_wingdown)
    def update(self, action, state, delta_frames=1):
        """Update the bird's vertical position.

        Args:
            action: 1 to climb (flap); anything else sinks.
            state: discrete game state; states 2-5 double the speed.
            delta_frames: number of frames elapsed (default 1).

        Returns:
            None
        """
        if self.msec_to_climb > 0 and action == 1:
            if state == 4 or state == 5 or state == 2 or state == 3:
                self.y -= (2 * Bird.CLIMB_SPEED * (1000.0 * delta_frames / 60))
            else:
                self.y -= (Bird.CLIMB_SPEED * (1000.0 * delta_frames / 60))
        else:
            if state == 4 or state == 5 or state == 2 or state == 3:
                self.y += 2 * Bird.SINK_SPEED * (1000.0 * delta_frames / 60)
            else:
                self.y += Bird.SINK_SPEED * (1000.0 * delta_frames / 60)
    def sink(self, delta_frames=1):
        # Faster fall used after a collision (death animation).
        self.y += Bird.Fail_SINk_SPEED * (1000.0 * delta_frames / 60)
    @property
    def image(self):
        # Alternate wing-up / wing-down every 250 ms for a flap animation.
        if pygame.time.get_ticks() % 500 >= 250:
            return self._img_wingup
        else:
            return self._img_wingdown
    @property
    def mask(self):
        # Collision mask matching the currently displayed frame.
        if pygame.time.get_ticks() % 500 >= 250:
            return self._mask_wingup
        else:
            return self._mask_wingdown
    @property
    def rect(self):
        # Bounding box used for coarse collision / positioning.
        return Rect(self.x, self.y, Bird.WIDTH, Bird.HEIGHT)
class PipePair(pygame.sprite.Sprite):
    """A top/bottom pipe pair obstacle in Flappy Bird."""
    WIDTH = 80
    PIECE_HEIGHT = 32
    ADD_INTERVAL = 2000                  # ms between pipe spawns
    ADD_EVENT = pygame.USEREVENT + 1     # pygame timer event id for spawning
    ROOM_HIGHT = 2 * Bird.HEIGHT + 2 * PIECE_HEIGHT  # vertical gap size
    def __init__(self, pipe_end_img, pipe_body_img):
        # NOTE(review): WIN_WIDTH / WIN_HEIGHT are module-level globals
        # defined elsewhere in this file (outside this view).
        self.x = float(WIN_WIDTH - 1)
        self.score_counted = False
        self.isNewPipe = True
        self.image = pygame.Surface((PipePair.WIDTH, WIN_HEIGHT), SRCALPHA)
        self.image.convert() # speeds up blitting
        self.image.fill((0, 0, 0, 0))
        total_pipe_body_pieces = int(
            (WIN_HEIGHT - # fill window from top to bottom
             3 * Bird.HEIGHT - # make room for bird to fit through
             3 * PipePair.PIECE_HEIGHT) / # 2 end pieces + 1 body piece
            PipePair.PIECE_HEIGHT # to get number of pipe pieces
        )
        # Randomly split the body pieces between the bottom and top pipes.
        self.bottom_pieces = randint(1, total_pipe_body_pieces)
        self.top_pieces = total_pipe_body_pieces - self.bottom_pieces
        # bottom pipe
        for i in range(1, self.bottom_pieces + 1):
            piece_pos = (0, WIN_HEIGHT - i * PipePair.PIECE_HEIGHT)
            self.image.blit(pipe_body_img, piece_pos)
        bottom_pipe_end_y = WIN_HEIGHT - self.bottom_height_px
        bottom_end_piece_pos = (0, bottom_pipe_end_y - PipePair.PIECE_HEIGHT)
        self.image.blit(pipe_end_img, bottom_end_piece_pos)
        # top pipe
        for i in range(self.top_pieces):
            self.image.blit(pipe_body_img, (0, i * PipePair.PIECE_HEIGHT))
        top_pipe_end_y = self.top_height_px
        self.image.blit(pipe_end_img, (0, top_pipe_end_y))
        self.center = (top_pipe_end_y + bottom_pipe_end_y) / 2  # gap centre y
        # compensate for added end pieces
        self.top_pieces += 1
        self.bottom_pieces += 1
        # for collision detection
        self.mask = pygame.mask.from_surface(self.image)
        self.top_y = top_pipe_end_y
        self.bottom_y = bottom_pipe_end_y
    @property
    def top_height_px(self):
        # Height of the top pipe in pixels.
        return self.top_pieces * PipePair.PIECE_HEIGHT
    @property
    def bottom_height_px(self):
        # Height of the bottom pipe in pixels.
        return self.bottom_pieces * PipePair.PIECE_HEIGHT
    @property
    def visible(self):
        # True while any part of the pipe is on screen.
        return -PipePair.WIDTH < self.x < WIN_WIDTH
    @property
    def rect(self):
        return Rect(self.x, 0, PipePair.WIDTH, PipePair.PIECE_HEIGHT)
    def update(self, delta_frames=1):
        # Scroll the pipe leftwards at a fixed speed.
        self.x -= 0.18 * 1000.0 * delta_frames / 60
    def collides_with(self, bird):
        # Pixel-perfect collision test against the bird's mask.
        return pygame.sprite.collide_mask(self, bird)
def judgeState(bird, pipes, collide):
    """Classify the bird's situation relative to the nearest upcoming pipe.

    :param bird: bird sprite; ``bird.x``/``bird.y`` are the image's top-left corner
    :param pipes: iterable of PipePair objects currently on screen
    :param collide: whether a collision has already happened this frame
    :return: tuple ``(state, dist, isNew)`` — ``state`` is the discrete
        situation code (0-8, 8 = collision, -1 = undecided), ``dist`` the
        signed vertical distance to the target height, ``isNew`` whether a
        not-yet-seen pipe was encountered
    """
    # Default target is the window centre (used when no pipe is relevant).
    dist = bird.y + Bird.HEIGHT / 2 - WIN_HEIGHT / 2
    isNew = False
    index = -1
    state = -1
    if collide:
        # BUGFIX: previously returned a bare int here while every caller
        # unpacks three values; return the full tuple instead.
        return 8, dist, isNew
    for p in pipes:
        if p.x + PipePair.WIDTH - Bird.HEIGHT / 4 < bird.x and not p.score_counted:
            continue  # pipe already behind the bird
        if p.x - Bird.NEAR_PIPE <= bird.x + Bird.HEIGHT and \
                p.x + PipePair.WIDTH - Bird.HEIGHT / 4 >= bird.x:
            # Bird is horizontally at / close to this pipe's gap.
            p_top_y = p.top_y + PipePair.PIECE_HEIGHT        # lower edge of top cap
            p_bottom_y = p.bottom_y - PipePair.PIECE_HEIGHT  # upper edge of bottom cap
            if p.center - bird.y - Bird.HEIGHT / 2 >= 0 and bird.y >= p_top_y + Bird.NEAR_COLLIDE / 2:
                state = 0  # safely above the gap centre
            elif bird.y - p.center + Bird.HEIGHT / 2 > 0 and bird.y + Bird.HEIGHT <= p_bottom_y - Bird.NEAR_COLLIDE / 2:
                state = 1  # safely below the gap centre
            elif bird.y < p_top_y + Bird.NEAR_COLLIDE / 2 and bird.y > p_top_y - 10:
                state = 6  # dangerously close to the top pipe
            elif bird.y + Bird.HEIGHT > p_bottom_y - Bird.NEAR_COLLIDE / 2 and bird.y + Bird.HEIGHT < p_bottom_y + 10:
                state = 7  # dangerously close to the bottom pipe
            if state > -0.5:
                index = 1
        elif p.x > bird.x + Bird.HEIGHT + Bird.NEAR_PIPE:
            # Pipe still far ahead: grade position against its gap centre.
            state = blankState(bird, p.center)
            if p.isNewPipe:
                isNew = True
                p.isNewPipe = False
            index = 1
        if index > 0:  # only judge the nearest and not passed pipe
            dist = bird.y + Bird.HEIGHT / 2 - p.center
            break
    if index < -0.5:  # no pipe left, keep the bird in the middle
        pos = WIN_HEIGHT / 2
        dist = bird.y + Bird.HEIGHT / 2 - pos
        state = blankState(bird, pos)
    return state, dist, isNew
def blankState(bird, center):
    """Grade the bird's vertical position against a target gap centre.

    Helper for judgeState(), used when no pipe is immediately at the bird.

    :param bird: bird sprite (``bird.y`` is the top of its image)
    :param center: y coordinate of the gap centre to aim for
    :return: state code 0-5 (0/1 on target, 2/3 drifting, 4/5 far off),
        or -1 as a defensive fallback (see below)
    """
    # Half of the usable gap once the bird's own height is subtracted.
    realHeight = (PipePair.ROOM_HIGHT - Bird.HEIGHT) / 2
    # Signed offsets of the bird's centre from the target centre (mirror pair).
    above = center - bird.y - Bird.HEIGHT / 2
    below = bird.y - center + Bird.HEIGHT / 2
    # Width of the "on target" band.
    safe = realHeight - Bird.NEAR_COLLIDE / 2
    # BUGFIX: initialise state so a float-rounding gap between the bands below
    # (states 4/5 re-associate the arithmetic) can no longer leave it unbound.
    state = -1
    if above >= 0 and above < safe:
        state = 0
    elif below >= 0 and below < safe:
        state = 1
    elif above >= safe and above < safe + Bird.REGION:
        state = 2
    elif below >= safe and below < safe + Bird.REGION:
        state = 3
    elif bird.y + Bird.HEIGHT / 2 <= center - (safe + Bird.REGION):
        state = 4
    elif bird.y + Bird.HEIGHT / 2 >= center + realHeight - Bird.NEAR_COLLIDE / 2 + Bird.REGION:
        state = 5
    return state
def getReward(state, lastState, smallerError, isNewPipe):
    """Reward for a state transition of the flappy-bird agent.

    :param state: state after the action (0-8, 8 = collision)
    :param lastState: state before the action
    :param smallerError: whether the distance-to-target shrank
    :param isNewPipe: whether a new pipe appeared between the two states
    :return: scalar reward
    """
    if state == 0 or state == 1:
        # Inside the safe band around the gap centre.
        reward = 6
    elif state == 2 or state == 3:
        # Slightly off-centre: reward closing in, punish drifting.
        if lastState == state and not isNewPipe:
            reward = 3 if smallerError else -5
        else:
            reward = -3
    elif state == 4 or state == 5:
        # Far off-centre: stronger punishment for moving away.
        if lastState == state and not isNewPipe:
            reward = 3 if smallerError else -8
        else:
            reward = -5
    elif state == 6 or state == 7:
        # Hugging a pipe edge.
        if lastState == state and not isNewPipe:
            reward = 3 if smallerError else -3
        else:
            reward = -3
    elif state == 8:  # collide
        reward = -100
    else:
        # BUGFIX: unknown states (e.g. judgeState's -1 sentinel) previously
        # raised UnboundLocalError; treat them as neutral instead.
        reward = 0
    return reward
if __name__ == "__main__":
    # ---- hyper-parameters of the evolved LSM + reward-modulated agent ----
    n_agent=1
    num = 8                        # reservoir side length (liquid is num*num)
    p_amount = int(num * num / 10)
    s_amount = 4
    num_state = 9                  # discrete states produced by judgeState (0-8)
    num_action = 2
    weight_exc = 1
    weight_inh = -0.5
    trace_decay = 0.8
    gens=1000
    # NOTE(review): ``seeds`` is not defined anywhere in this file's visible
    # scope — presumably a module-level constant defined earlier; confirm.
    for seed in range(seeds):
        # ---- evolve a reservoir connectivity matrix with NSGA-Net ----
        kkk = Evolve(n_var=num*num,
                     n_obj=2, n_constr=0)
        method = engine.nsganet(pop_size=n_agent,
                                sampling=RandomSampling(var_type='custom'),
                                mutation=BinaryBitflipMutation(),
                                n_offsprings=10,
                                eliminate_duplicates=True)
        kres=minimize(kkk,
                      method,
                      callback=do_every_generations,
                      termination=('n_gen', gens))
        pop_var = kres.X
        pop_obj = kres.F
        # Best individual (lowest first objective) becomes the liquid mask.
        lm=torch.from_numpy(pop_var[np.argmin(pop_obj[:, 0])].reshape(num,num))
        model = SNN(ins=9,num_classes=2,n_agent=n_agent,device='cuda:0',liquid_size=num,lsm_tau=2,lsm_th=0.2,connectivity_matrix=lm.to('cuda:0').float())
        model.to('cuda:0')
        # One excitatory link from each state to both of its (state, action) pairs.
        con_matrix1 = torch.zeros((num_state, num_state * num_action), dtype=torch.float)
        for i in range(num_state):
            for j in range(num_action):
                con_matrix1[i, i * num_action + j] = weight_exc
        weight_trace_d1 = torch.zeros(con_matrix1.shape, dtype=torch.float)
        weight_trace_d2 = torch.zeros(con_matrix1.shape, dtype=torch.float)
        # ---- pygame setup ----
        pygame.init()
        WIN_HEIGHT = 512
        WIN_WIDTH = 284 * 2
        heighest = 0
        contTime = 0
        display_frame = 0
        display_surface = pygame.display.set_mode((WIN_WIDTH, WIN_HEIGHT))
        pygame.display.set_caption('Flappy Bird')
        images = load_images()
        bird = Bird(250, int(WIN_HEIGHT / 2 - Bird.HEIGHT / 2), 2,
                    (images['bird-wingup'], images['bird-wingdown']))
        clock = pygame.time.Clock()
        score_font = pygame.font.SysFont(None, 25, bold=True)
        info_font = pygame.font.SysFont(None, 50, bold=True)
        collide = paused = False
        frame_clock = 0
        pipes = deque()
        score = 0
        lastDist = 0
        lastState = 0  # init
        state = lastState
        i = 0
        num_reward = []
        num_score = []
        reward=1
        # ---- main game / learning loop (one episode, ends on collision) ----
        while not collide:
            i = i + 1
            # NOTE(review): ``steps`` is not defined in this file's visible
            # scope — presumably a module-level constant; confirm.
            if i > steps:
                break
            clock.tick(60)
            if frame_clock % 2 == 0 or frame_clock == 1:
                # Observe every other frame and let the SNN pick an action.
                state, dist, isNewPipe = judgeState(bird, pipes, collide)
                lastState = state
                lastDist = dist
                action= model(F.one_hot(torch.tensor([state]), num_classes=num_state).to('cuda:0').float()).cpu().detach().numpy()
                action=int(np.argmax(action,axis=1))
            print(i)
            if not (paused or frame_clock % (60 * PipePair.ADD_INTERVAL / 1000.0)):
                pygame.event.post(pygame.event.Event(PipePair.ADD_EVENT))
            for e in pygame.event.get():
                if e.type == QUIT or (e.type == KEYUP and e.key == K_ESCAPE):
                    collide = True
                elif e.type == KEYUP and e.key in (K_PAUSE, K_p):
                    paused = not paused
                elif e.type == PipePair.ADD_EVENT:
                    pp = PipePair(images['pipe-end'], images['pipe-body'])
                    pipes.append(pp)
            if paused:
                continue  # don't draw anything
            pipe_collision = any(p.collides_with(bird) for p in pipes)
            if pipe_collision or 0 >= bird.y or bird.y >= WIN_HEIGHT - Bird.HEIGHT:
                collide = True
            for x in (0, WIN_WIDTH / 2):
                display_surface.blit(images['background'], (x, 0))
            while pipes and not pipes[0].visible:
                pipes.popleft()
            for p in pipes:
                p.update()
                display_surface.blit(p.image, p.rect)
            bird.update(action, state)
            display_surface.blit(bird.image, bird.rect)
            if frame_clock % 2 == 0 or frame_clock == 1 or collide:
                # Evaluate the transition and compute the reward signal.
                dist = 0
                if collide:
                    nextState = 8
                    isNewPipe = False
                else:
                    nextState, dist, isNewPipe = judgeState(bird, pipes, collide)  # judge the bird's state
                print("next state:", nextState)
                print("lastdist, dist:", lastDist, dist)
                isSmallerError = False
                if state == nextState:
                    isSmallerError = False
                    if lastDist <= 0:
                        if lastDist < dist:
                            isSmallerError = True
                    else:
                        if lastDist > dist:
                            isSmallerError = True
                if frame_clock > 0 and not collide:
                    reward = getReward(nextState, state, isSmallerError, isNewPipe)
                    print("reward:", reward)
                num_reward.append(reward)
                bcmreward=np.array([reward, reward])
                # bcm(model,bcmreward, input=input)
                state = nextState  # going on the next state
                weight_trace_d1 = torch.zeros(con_matrix1.shape, dtype=torch.float)
                weight_trace_d2 = torch.zeros(con_matrix1.shape, dtype=torch.float)
                model.reset()
                display_frame += 1
            for p in pipes:
                if p.x + PipePair.WIDTH < bird.x and not p.score_counted:
                    score += 1
                    p.score_counted = True
                    num_score.append(score)
            score_surface = score_font.render('Current score: ' + str(score), True, (0, 0, 0))  # current score
            score_x = WIN_WIDTH / 2 - 3 * score_surface.get_width() / 4
            display_surface.blit(score_surface, (score_x, PipePair.PIECE_HEIGHT))
            if heighest < score:
                heighest = score
            score_surface_h = score_font.render('Highest score: ' + str(heighest), True,
                                                (0, 0, 0))  # heighest score
            score_x_h = 4 * WIN_WIDTH / 5 - 1.2 * score_surface.get_width() / 3
            display_surface.blit(score_surface_h, (score_x_h, PipePair.PIECE_HEIGHT))
            score_surface_i = score_font.render('Attempts: 0', True, (0, 0, 0))  # heighest score
            score_x_i = 10
            display_surface.blit(score_surface_i, (score_x_i, PipePair.PIECE_HEIGHT))
            frame_clock += 1
            pygame.display.flip()
        # if collide, display the fail information, for 2 frames
        cct = 0
        # Death animation: let the bird sink to the ground while showing scores.
        while (bird.y < WIN_HEIGHT - Bird.HEIGHT - 3):
            clock.tick(60)
            for x in (0, WIN_WIDTH / 2):
                display_surface.blit(images['background'], (x, 0))
            while pipes and not pipes[0].visible:
                pipes.popleft()
            for p in pipes:
                display_surface.blit(p.image, p.rect)
            if cct >= 6:
                bird.sink()
            display_surface.blit(bird.image, bird.rect)
            fail_infor = info_font.render('Game over !', True, (255, 60, 30))  # current score
            pos_x = WIN_WIDTH / 2 - fail_infor.get_width() / 2
            pos_y = WIN_HEIGHT / 2 - 100
            display_surface.blit(fail_infor, (pos_x, pos_y))
            # display the score
            score_surface = score_font.render('Current score: ' + str(score), True, (0, 0, 0))  # current score
            score_x = WIN_WIDTH / 2 - 3 * score_surface.get_width() / 4
            display_surface.blit(score_surface, (score_x, PipePair.PIECE_HEIGHT))
            if heighest < score:
                heighest = score
            score_surface_h = score_font.render('Highest score: ' + str(heighest), True,
                                                (0, 0, 0))  # heighest score
            score_x_h = 4 * WIN_WIDTH / 5 - 1.2 * score_surface.get_width() / 3
            display_surface.blit(score_surface_h, (score_x_h, PipePair.PIECE_HEIGHT))
            score_surface_i = score_font.render('Attempts: 0', True, (0, 0, 0))  # heighest score
            score_x_i = 10
            display_surface.blit(score_surface_i, (score_x_i, PipePair.PIECE_HEIGHT))
            pygame.display.flip()
            cct += 1
        if heighest < score:
            heighest = score
        contTime += 1
        # Record the per-step reward history for this seed.
        num_reward_np = np.array(num_reward)
        print(num_reward_np)
        k=num_reward_np.shape[0]
        # NOTE(review): ``result`` is not defined in this file's visible scope
        # — presumably a pre-allocated (seeds, steps) array; confirm.
        result[seed,:k]=num_reward_np
================================================
FILE: examples/Structure_Evolution/Adaptive_lsm/BrainCog-Version/lsmmodel.py
================================================
from functools import partial
from torch.nn import functional as F
from torch import nn as nn
import torchvision, pprint
from copy import deepcopy
from timm.models import register_model
from braincog.base.node.node import *
from braincog.base.connection.layer import *
from braincog.base.encoder.encoder import *
from braincog.model_zoo.base_module import BaseModule, BaseConvModule, BaseLinearModule
from braincog.base.brainarea.BrainArea import BrainArea
from braincog.base.connection.CustomLinear import *
from braincog.base.learningrule.STDP import *
from braincog.base.learningrule.BCM import *
import matplotlib.pyplot as plt
@register_model
class SNN(BaseModule):
    """Liquid-state-machine SNN whose recurrent reservoir weights are masked
    by an evolved connectivity matrix and adapted online with BCM.

    BUGFIX / generalization: the input projection used to be hard-coded to 4
    features; it is now the ``ins`` parameter (default 4, so existing callers
    are unaffected).  Callers that pass ``ins=9`` (the flappy-bird script)
    previously hit a shape mismatch in the first linear layer because ``ins``
    was silently swallowed by ``**kwargs``.  NOTE(review): ``ins`` is no
    longer forwarded to the node factories via kwargs — confirm LIFNode did
    not rely on it.
    """

    def __init__(self,
                 liquid_size,
                 n_agent,
                 device,
                 connectivity_matrix,
                 num_classes=3,
                 step=1,
                 node_type=LIFNode,
                 encode_type='direct',
                 lsm_th=0.3,
                 fc_th=0.3,
                 lsm_tau=3,
                 fc_tau=3,
                 tw=100,
                 ins=4,
                 *args,
                 **kwargs):
        """Create the reservoir + readout.

        :param liquid_size: number of neurons in the liquid (reservoir)
        :param n_agent: batch size (number of parallel agents)
        :param device: torch device string, e.g. ``'cuda:0'``
        :param connectivity_matrix: (liquid_size, liquid_size) 0/1 mask for
            the recurrent weights
        :param num_classes: number of readout outputs (actions)
        :param ins: input feature width of the first projection (was 4,
            hard-coded)
        """
        super().__init__(step, encode_type, *args, **kwargs)
        self.batchsize = n_agent
        # Node factories: reservoir and readout neurons with separate tau/threshold.
        self.node_lsm = partial(node_type, **kwargs, step=step, tau=lsm_tau, threshold=lsm_th)
        self.node_fc = partial(node_type, **kwargs, step=step, tau=fc_tau, threshold=fc_th)
        self.liquid_size = liquid_size
        self.out = torch.zeros(self.batchsize, liquid_size).to(device)
        self.device = device
        self.con = []
        self.learning_rule = []
        self.connectivity_matrix = connectivity_matrix
        # Input projection.  BUGFIX: width was hard-coded to 4; now ``ins``.
        w1tmp = nn.Linear(ins, liquid_size, bias=False).to(device)
        self.con.append(w1tmp)
        # Recurrent reservoir weights, masked by the evolved 0/1 connectivity matrix.
        w2tmp = nn.Linear(liquid_size, liquid_size, bias=False).to(device)
        self.liquid_weight = w2tmp.weight.data
        w2tmp.weight.data = w2tmp.weight.data * self.connectivity_matrix
        self.con.append(w2tmp)
        self.learning_rule.append(BCM(self.node_lsm(), [self.con[0], self.con[1]]))  # pm
        self.fc = nn.Linear(liquid_size, num_classes).to(device)
        self.learning_rule.append(BCM(self.node_fc(), [self.fc]))  # pm

    def forward(self, x):
        """Run the reservoir for a fixed 20-step window.

        :param x: (batch, ins) input, flattened if higher-rank
        :return: (batch, num_classes) mean readout spike rate over the window
        """
        x = x.reshape(x.shape[0], -1)
        sum_spike = 0
        # NOTE(review): the constructor's ``tw`` argument is ignored here; the
        # window stays fixed at 20 steps to preserve existing behaviour.
        time_window = 20
        self.tw = time_window
        self.firing_tw = torch.zeros(time_window, self.batchsize, self.liquid_size).to(self.device)
        self.out = torch.zeros(self.batchsize, self.liquid_size).to(self.device)
        for t in range(time_window):
            # One BCM step through the input + recurrent connections.
            self.out, self.dw = self.learning_rule[0](x, self.out)
            # self.con[0].weight += self.dw[0]
            self.con[1].weight.data += self.dw[1]  # online update of reservoir weights
            out_liquid = self.out[:, 0:self.liquid_size]
            xout, dw = self.learning_rule[1](out_liquid)
            self.fc.weight.data += dw[0]  # online update of readout weights
            sum_spike = sum_spike + xout
            self.firing_tw[t] = out_liquid
        outputs = sum_spike / time_window
        return outputs
================================================
FILE: examples/Structure_Evolution/Adaptive_lsm/BrainCog-Version/maze.py
================================================
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import math
from matplotlib import pyplot as plt
import matplotlib
import seaborn as sns
from lsmmodel import SNN
from tools.ExperimentEnvGlobalNetworkSurvival import ExperimentEnvGlobalNetworkSurvival
from tools.MazeTurnEnvVec import MazeTurnEnvVec
import torch
import brewer2mpl
from cycler import cycler
import nsganet as engine
from pymop.problem import Problem
from pymoo.optimize import minimize
from pymoo.operators.sampling.random_sampling import RandomSampling
from pymoo.operators.mutation.bitflip_mutation import BinaryBitflipMutation
def randbool(size, p):
    """Return a boolean tensor of the given size whose entries are True with
    probability ``p`` (i.i.d. Bernoulli via uniform thresholding)."""
    sample = torch.rand(*size)
    return sample < p
class Evolve(Problem):
    """Structure-evolution problem for pymop/pymoo: each individual is a
    flattened connectivity genome; the first objective is its matrix rank."""

    def __init__(self, n_var=20, n_obj=1, n_constr=0, lb=None, ub=None):
        super().__init__(n_var=n_var, n_obj=n_obj, n_constr=n_constr, type_var=np.int64)
        # Decision-variable bounds (None leaves them unset for this encoding).
        self.xl = lb
        self.xu = ub
        self._n_evaluated = 0  # keep track of how many architectures are sampled

    def _evaluate(self, x, out, *args, **kwargs):
        """Fill out["F"] with one objective row per sampled individual."""
        n_rows = x.shape[0]
        objs = np.full((n_rows, self.n_obj), np.nan)
        for row in range(n_rows):
            arch_id = self._n_evaluated + 1
            print('Network= {}'.format(arch_id))
            # NOTE(review): x[row] is a flat 1-D genome here, so matrix_rank
            # yields 0 or 1; confirm whether a reshape to a square matrix was
            # intended.  Only column 0 is filled — extra objectives stay NaN.
            objs[row, 0] = np.linalg.matrix_rank(x[row])
            self._n_evaluated += 1
        out["F"] = objs
        # if your NAS problem has constraints, use the following line to set constraints
        # out["G"] = np.column_stack([g1, g2, g3, g4, g5, g6]) in case 6 constraints
# ---------------------------------------------------------------------------------------------------------
# Define what statistics to print or save for each generation
# ---------------------------------------------------------------------------------------------------------
def do_every_generations(algorithm):
    """Per-generation callback: print progress statistics of the population.

    Invoked once per generation by pymoo/nsganet with the running algorithm
    object, which exposes the generation counter and the current population.
    """
    gen = algorithm.n_gen
    pop_var = algorithm.pop.get("X")
    pop_obj = algorithm.pop.get("F")
    first_obj = pop_obj[:, 0]
    # report generation info to files
    print("generation = {}".format(gen))
    print("population error: best = {}, mean = {}, "
          "median = {}, worst = {}".format(np.min(first_obj), np.mean(first_obj),
                                           np.median(first_obj), np.max(first_obj)))
    print('Best Genome id= {}'.format(np.argmin(first_obj)))
if __name__ == '__main__':
    device = 'cuda:8'
    num = 8
    n_agent = 20
    steps = 500
    liquid_size=80
    # Two identical vectorized T-maze environments wrapped for survival fitness.
    env = MazeTurnEnvVec(n_agent, n_steps=steps)
    newenv=MazeTurnEnvVec(n_agent, n_steps=steps)
    data_env = ExperimentEnvGlobalNetworkSurvival(env)
    newdata_env = ExperimentEnvGlobalNetworkSurvival(newenv)
    gens=100
    seed=0
    sum_of_env = np.zeros([gens, n_agent])
    env_r=np.zeros([steps,n_agent])
    # Initial random sparse reservoir masks, one per agent.
    population = torch.zeros(n_agent,liquid_size,liquid_size)
    for i in range(n_agent):
        population[i] = randbool([liquid_size, liquid_size],p=0.01).to(device).float()
    kkk = Evolve(n_var=liquid_size*liquid_size,
                 n_obj=1, n_constr=2)
    method = engine.nsganet(pop_size=n_agent,
                            sampling=RandomSampling(var_type='custom'),
                            mutation=BinaryBitflipMutation(),
                            n_offsprings=10,
                            eliminate_duplicates=True)
    kres=minimize(kkk,
                  method,
                  callback=do_every_generations,
                  termination=('n_gen', gens))
    # lm=evolve(population, gens)
    # NOTE(review): the evolution result ``kres`` is not used below — the
    # model is built from a fresh random mask; confirm this is intentional.
    model = SNN(ins=4,n_agent=n_agent,device=device,liquid_size=liquid_size,lsm_tau=2,lsm_th=0.2,connectivity_matrix=randbool([liquid_size, liquid_size],p=0.01).to(device).float())
    model.to(device)
    old_dis = np.ones([n_agent,])*13  # initial distance-to-food estimate
    X = data_env.reset()
    envreward = np.zeros([n_agent, ])
    fit=np.zeros([n_agent])
    for i in range(steps):
        model.reset()
        # NOTE(review): ``X`` is never advanced to ``X_next`` inside this loop,
        # so the network always sees the reset observation; confirm intent.
        out = model(torch.from_numpy(X+1).float().to(device)).cpu().detach().numpy()
        X_next, envreward, fitness, infos = data_env.step(np.argmax(out,axis=1))
        food_pos = data_env.env.food_pos[:, 0, :2]
        agent_pos = data_env.env.agents_pos
        print(agent_pos)
        # Shaped reward: +1 for moving closer to food, -1 otherwise, overridden
        # by +3 / -3 whenever the environment itself rewards or punishes.
        dis = ((agent_pos - food_pos) ** 2).sum(1)
        reward =np.array((np.sqrt(old_dis)-np.sqrt(dis))>0,dtype=int)
        aa=np.ones_like(reward)*-1
        bb = np.ones_like(reward)*3
        cc = np.ones_like(reward)*-3
        reward=np.where(reward == 0 , aa, reward)
        reward=np.where(envreward == 1, bb, reward)
        reward = np.where(envreward == -1, cc, reward)
        old_dis= dis
        env_r[i]=reward
================================================
FILE: examples/Structure_Evolution/Adaptive_lsm/BrainCog-Version/tools/EnuGlobalNetwork.py
================================================
import pickle
import time
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import gridspec
from AbstractLayerBMM import AbstractLayerBMM
from EvolvableNeuralUnitStacked import EvolvableNeuralUnitStacked
from Tools import get_data_path
sns.set_style("darkgrid")
class EnuGlobalNetwork(AbstractLayerBMM):
    """Network of ENUs implementation in PyTorch, where each synapse and neuron is modeled as an ENU. """
    def __init__(self, n_offspring, n_pseudo_env, n_input_neurons, n_hidden_neurons, n_output_neurons, n_syn_per_neuron):
        """Build the neuron/synapse ENU stacks and their connection tables.

        :param n_offspring: number of evolved networks simulated in parallel
        :param n_pseudo_env: number of pseudo-environments (stored, unused here)
        :param n_input_neurons: count of input neurons fed directly by X
        :param n_hidden_neurons: count of hidden neurons
        :param n_output_neurons: count of output neurons
        :param n_syn_per_neuron: identical fan-in for every neuron
        """
        # offspring
        self.n_offspring = n_offspring
        self.n_pseudo_env = n_pseudo_env
        # input channels
        n_input_channels = 16
        self.n_input_channels = n_input_channels
        n_dynamic_param = 32
        # total neurons
        n_neurons = n_output_neurons + n_hidden_neurons
        self.n_neurons = n_neurons
        super().__init__(n_offspring, n_neurons, n_input_neurons, n_output_neurons)
        # Deterministic random connectivity across runs.
        torch.random.manual_seed(0)
        #NOTE: batch dimension holds output of each neuron/synapse, allowing fast GPU MM
        #NOTE neurons far less than synapses, so can be relatively bigger rnn for little cost
        n_input_channels_neuron = 16
        n_input_neuron, n_output_neuron = n_input_channels_neuron, n_input_channels
        self.neurons = EvolvableNeuralUnitStacked(n_offspring, batch_size=self.n_neurons, n_input=n_input_neuron, n_dynamic_param=n_dynamic_param, n_output=n_output_neuron)
        #self.n_syn = next_power_of_2(int(n_neurons * (rel_connectivity*n_neurons)))
        self.n_syn_per_neuron = n_syn_per_neuron
        self.n_syn = n_neurons * n_syn_per_neuron
        n_input_syn, n_output_syn = n_input_channels * 2, n_input_channels_neuron # * 2 for neuron feedback (which same channel as n_channel input)
        self.synapses = EvolvableNeuralUnitStacked(n_offspring, batch_size=self.n_syn, n_input=n_input_syn, n_dynamic_param=n_dynamic_param, n_output=n_output_syn)
        # just randomly connect synapses to neurons
        self.synapse_connections = torch.randint(n_input_neurons + n_neurons, size=(n_neurons, n_syn_per_neuron), device='cuda', dtype=torch.long)
        # fixed predefined connection patterns
        if n_input_neurons==2 and n_output_neurons==2 and n_hidden_neurons==2:
            print("Fixed connection Network 2-2-2")
            self.synapse_connections = torch.tensor([[0, 1],
                                                     [0, 1],
                                                     [2, 3],
                                                     [2, 3]], device='cuda', dtype=torch.long)
        elif n_input_neurons == 4 and n_output_neurons == 3 and n_hidden_neurons == 3 and n_syn_per_neuron==3:
            print("Fixed connection Network 4-3-3 (3syn)")
            self.synapse_connections = torch.tensor([[0, 1, 3],# hidden connections #4
                                                     [0, 2, 3], #5
                                                     [1, 2, 3],# 6
                                                     [4, 5, 6], # output connections #7
                                                     [4, 5, 6],#8
                                                     [4, 5, 6]#9
                                                     ], device='cuda', dtype=torch.long)
        elif n_input_neurons==5 and n_hidden_neurons==0 and n_output_neurons==4:
            print("Fixed connection Network 5-0-4 (5syn)")
            # neuron i connected to neuron j and k, neuron 0..input_neurons is index
            self.synapse_connections = torch.tensor([[0, 1, 2, 3, 4],# output connections
                                                     [0, 1, 2, 3, 4],
                                                     [0, 1, 2, 3, 4],
                                                     [0, 1, 2, 3, 4]
                                                     ], device='cuda', dtype=torch.long)
        elif n_input_neurons==1 and n_hidden_neurons==0 and n_output_neurons==2:
            print("Fixed connection Network 1-0-2 (1syn)")
            # neuron i connected to neuron j and k, neuron 0..input_neurons is index
            self.synapse_connections = torch.tensor([[0],# output connections
                                                     [0]
                                                     ], device='cuda', dtype=torch.long)
        elif n_input_neurons==4 and n_hidden_neurons==0 and n_output_neurons==3 and n_syn_per_neuron==4:
            print("Sparse connection Network 4-0-3 (4syn)")
            # neuron i connected to neuron j and k, neuron 0..input_neurons is index
            self.synapse_connections = torch.tensor([[0, 1, 2, 3],# output connections #4
                                                     [0, 1, 2, 3], #5
                                                     [0, 1, 2, 3],# 6
                                                     ], device='cuda', dtype=torch.long)
        elif n_input_neurons == 4 and n_hidden_neurons == 3 and n_output_neurons == 3 and n_syn_per_neuron == 4:
            # NOTE(review): condition says 4 synapses per neuron but the rows
            # below have only 3 entries (and the banner says "3syn"); this
            # mismatches n_syn = n_neurons * n_syn_per_neuron — confirm.
            print("Sparse connection Network 4-3-3 (3syn)")
            # neuron i connected to neuron j and k, neuron 0..input_neurons is index
            self.synapse_connections = torch.tensor([[0, 1, 3], # hidden connections #4
                                                     [0, 2, 3], # 5
                                                     [1, 2, 3], # 6
                                                     [4, 5, 3], # output connections #7
                                                     [4, 6, 3], # 8
                                                     [5, 6, 3] # 9
                                                     ], device='cuda', dtype=torch.long)
        elif n_input_neurons==4 and n_hidden_neurons==3 and n_output_neurons==3 and n_syn_per_neuron==8:
            print("Sparse connection Network 4-3-3 (8syn)")
            # neuron i connected to neuron j and k, neuron 0..input_neurons is index
            self.synapse_connections = torch.tensor([[0, 1, 5, 6, 7, 8, 3, 4],# hidden connections #4
                                                     [0, 2, 4, 6, 7, 9, 3, 5], #5
                                                     [1, 2, 4, 5, 8, 9, 3, 6],# 6
                                                     [4, 5, 8, 9, 0, 1, 3, 7], # output connections #7
                                                     [4, 6, 7, 9, 0, 2, 3, 8],#8
                                                     [5, 6, 7, 8, 1, 2, 3, 9]#9
                                                     ], device='cuda', dtype=torch.long)
        elif n_input_neurons==4 and n_hidden_neurons==4 and n_output_neurons==4 and n_syn_per_neuron==8:
            print("Fixed connection Network 4-4-4 (8syn)")
            # neuron i connected to neuron j and k, neuron 0..input_neurons is index
            self.synapse_connections = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7],# hidden connections
                                                     [0, 1, 2, 3, 4, 5, 6, 7],
                                                     [0, 1, 2, 3, 4, 5, 6, 7],
                                                     [0, 1, 2, 3, 4, 5, 6, 7],
                                                     [4, 5, 6, 7, 8, 9, 10, 11], # output connections
                                                     [4, 5, 6, 7, 8, 9, 10, 11],
                                                     [4, 5, 6, 7, 8, 9, 10, 11],
                                                     [4, 5, 6, 7, 8, 9, 10, 11],
                                                     ], device='cuda', dtype=torch.long)
        elif n_input_neurons==5 and n_hidden_neurons==5 and n_output_neurons==4:
            print("Fixed connection Network 5-5-4 (5syn)")
            # neuron i connected to neuron j and k, neuron 0..input_neurons is index
            self.synapse_connections = torch.tensor([[0, 1, 2, 3, 4],# hidden connections
                                                     [0, 1, 2, 3, 4],
                                                     [0, 1, 2, 3, 4],
                                                     [0, 1, 2, 3, 4],
                                                     [0, 1, 2, 3, 4],
                                                     [5, 6, 7, 8, 9], # output connections
                                                     [5, 6, 7, 8, 9],
                                                     [5, 6, 7, 8, 9],
                                                     [5, 6, 7, 8, 9],
                                                     ], device='cuda', dtype=torch.long)
        elif n_input_neurons==1 and n_hidden_neurons==0 and n_output_neurons==1:
            print("Fixed connection Single")
            self.synapse_connections = torch.tensor([[0]], device='cuda', dtype=torch.long)
        else:
            print("Random connections")
        # each synapse is connected also to its post-synaptic neuron, to allow STDP type learning to emerge
        self.synapse_connections_post = torch.arange(n_neurons, device='cuda', dtype=torch.long).reshape(n_neurons, -1).repeat(1, n_syn_per_neuron)
        # define compartments
        self.compartments = [self.neurons, self.synapses]
        self.trainable_layers = self.neurons.trainable_layers + self.synapses.trainable_layers
        self.track_data = False  # when True, forward() records activity snapshots
    def dump_model(self, e, exp_name):
        """Dump model to restore"""
        with open(get_data_path(e, exp_name, "Model"), 'wb') as f:
            parameters = {}
            parameters["neuron"] = [layer.base_parameters.cpu().numpy() for layer in self.neurons.trainable_layers]
            parameters["synapse"] = [layer.base_parameters.cpu().numpy() for layer in self.synapses.trainable_layers]
            pickle.dump(parameters, f)
    def restore_model(self, e, exp_name):
        """Restore model"""
        with open(get_data_path(e, exp_name, "Model"), 'rb') as f:
            parameters = pickle.load(f)
        #TODO: refactor to dump/restore at ENU level and just call those functions
        assert len(self.neurons.trainable_layers) == len(parameters["neuron"])
        for i in range(len(parameters["neuron"])):
            self.neurons.trainable_layers[i].base_parameters = torch.from_numpy(parameters["neuron"][i].astype(np.float32)).cuda()
        assert len(self.synapses.trainable_layers) == len(parameters["synapse"])
        for i in range(len(parameters["synapse"])):
            self.synapses.trainable_layers[i].base_parameters = torch.from_numpy(parameters["synapse"][i].astype(np.float32)).cuda()
    @staticmethod
    def plot_weights(e, exp_name):
        """Visualize weights of ENU gates"""
        sns.set_style("dark")
        def calc_average(start, stop):
            # Accumulate neuron gate weights over checkpoints [start, stop).
            # NOTE(review): despite the name this is a running SUM — it never
            # divides by the number of checkpoints; confirm intent.
            weights_average = None
            for e in range(start, stop, 1000):
                with open(get_data_path(e, exp_name, "Model"), 'rb') as f:
                    parameters = pickle.load(f)
                weights = []
                for i in range(len(parameters["neuron"])):
                    weights += [parameters["neuron"][i].astype(np.float32)]
                if weights_average is None:
                    weights_average = weights
                else:
                    for i in range(len(weights_average)):
                        weights_average[i] += weights[i]
            return weights_average
        weights_mean1 = calc_average(20000, 30000)
        fig, ax = plt.subplots(1, 2, sharex='col', sharey='row')
        for i in range(len(weights_mean1)):
            ax[i].imshow(weights_mean1[i], cmap="gray")
        weights_mean2 = calc_average(30000, 40000)
        fig, ax = plt.subplots(1, 2, sharex='col', sharey='row')
        for i in range(len(weights_mean2)):
            ax[i].imshow(weights_mean2[i], cmap="gray")
        fig, ax = plt.subplots(1, 2, sharex='col', sharey='row')
        for i in range(len(weights_mean2)):
            # NOTE(review): the **5 power emphasises sign-preserving differences;
            # confirm the exponent is intentional.
            ax[i].imshow((weights_mean2[i] - weights_mean1[i])**5, cmap="gray")
        plt.show()
    def dump_network_activity(self, e, exp_name):
        """Dump raw data for visualization"""
        with open(get_data_path(e, exp_name, "GlobalNetwork"), 'wb') as f:
            pickle.dump(self.vis_data, f)
    def print(self):
        """Print a summary of the neuron and synapse compartments."""
        print("--Neurons--")
        self.neurons.print()
        print("--Synapses--")
        self.synapses.print()
    def reset(self):
        """Clear tracked activity and reset all compartments' internal state."""
        self.vis_data = []
        if self.track_data:
            print("Tracking network activity")
        for compartment in self.compartments:
            compartment.reset()
    def forward(self, X):
        """Main computation forward pass"""
        # transfer to GPU
        X_raw_gpu = torch.from_numpy(X.astype(np.float32)).cuda()
        X_gpu = torch.zeros((X.shape[0], X.shape[1], self.n_input_channels), device='cuda', dtype=torch.float32)
        X_gpu[:, :, :X_raw_gpu.shape[2]] = X_raw_gpu
        # first compute synapses, set input to previous output of connected neuron
        # concat our input spiking pattern directly to input to our synapses (the neurons)
        # NOTE: this concats in batch dimension, meaning it feeds into input neurons directly spiking pattern, while rest receive input from network
        input_to_synapses = torch.cat([X_gpu, self.neurons.out_mem], dim=1)
        # connect each synapse randomly to multiple inputs
        input_to_synapses_connected = input_to_synapses[:, self.synapse_connections.flatten(), :]
        # need feedback connection from neuron to synapse, to allow stdp type rules to emerge (else it has to do it through feedback connections, but less guarentee on connections and cannot distinguise type)
        # one synapse has 1 pre-synaptic neuron and 1 post-synaptic neuron, connectection defined in synapse_connections, synapse_connections[i, :] gives all input synapses of that neuron
        # so feedback to all it's input synapses through broadcasting backwards
        post_neuron_backprop_connected = self.neurons.out_mem[:, self.synapse_connections_post.flatten(), :]
        input_to_synapses_connected = torch.cat([input_to_synapses_connected, post_neuron_backprop_connected], dim=-1)
        # compute synapse
        self.synapses.forward(input_to_synapses_connected)
        # then integrate(sum) all outputs of a neurons input synapses, can just reshape into valid shape, since we already randomly connected when computing synapses
        # NOTE: each neuron then requires same number of synapses, then reshape by modifying batch dim (which contains syn outputs)
        integration = torch.sum(self.synapses.out.reshape((self.n_offspring, self.n_neurons, -1, self.synapses.shape[-1])), dim=2)
        # scale by number of synapses
        integration /= self.n_syn_per_neuron
        self.out_integration = integration
        # finally set neuron input to summated connected synapses output
        input_to_neurons = integration
        out = self.neurons.forward(input_to_neurons)
        # output is last neuron output, NOTE: just first channel is returned, since we reshape neurons to channels
        self.out = out[:, -self.n_output:, 0].reshape(self.n_offspring, self.n_output)
        if self.track_data:
            self._track_vis_data(X, input_to_synapses_connected, input_to_neurons)
        return self.out
    def _track_vis_data(self, X, input_to_synapses_connected, input_to_neurons):
        # Record a CPU snapshot of the first offspring's activity for plotting.
        offspring_idx = 0
        self.vis_data += [(X[offspring_idx], input_to_neurons[offspring_idx].cpu().numpy(), self.neurons.out[offspring_idx].cpu().numpy(),
                           input_to_synapses_connected[offspring_idx].cpu().numpy(), self.synapses.out[offspring_idx].cpu().numpy())]
    @staticmethod
    def plot_network_activity(e, exp_name):
        """Load activity dumped by dump_network_activity() and plot it."""
        with open(get_data_path(e, exp_name, "GlobalNetwork"), 'rb') as f:
            vis_data = pickle.load(f)
        X, input_to_neurons, neurons_out, input_to_synapses, synapses_out = map(np.array, zip(*vis_data))
        def plot_enu_activity(input, output, title):
            # One column per cell (capped at 10): inputs on top row, outputs below.
            n_cells = output.shape[1]
            n_cells = np.minimum(10, output.shape[1])
            fig, grid = plt.subplots(2, n_cells, sharex='col', sharey='row')
            if n_cells==1:
                grid[0].plot(input[:, 0, :])
                grid[1].plot(output[:, 0, :])
            else:
                for i in range(n_cells):
                    grid[0, i].plot(input[:, i, :])
                    grid[1, i].plot(output[:, i, :])
            plt.xlabel("t")
            plt.title(title)
            #plt.ylabel("")
            plt.legend()
        plt.figure()
        plt.plot(X[:, :, 0])
        plot_enu_activity(input_to_neurons, neurons_out, "ENU neuron activity")
        plot_enu_activity(input_to_synapses, synapses_out, "ENU synapse activity")
        plt.figure()
        # Raster plot of neuron spikes (first channel > 0).
        spike_points = np.where(neurons_out[:, :, 0] > 0)
        plt.scatter(spike_points[0], spike_points[1], marker='|')
        plt.show()
================================================
FILE: examples/Structure_Evolution/Adaptive_lsm/BrainCog-Version/tools/ExperimentEnvGlobalNetworkSurvival.py
================================================
import pickle
import numpy as np
from tools.Tools import get_data_path
class ExperimentEnvGlobalNetworkSurvival:
"""Wrapper around a given RL environment for a Network of ENUs model,
turns reward into fitness and dumps relevant data"""
def __init__(self, env, exp_name='maze'):
self.env = env
self.exp_name = exp_name
self.n_output = self.env.n_actions
#NOTE: +1 reward neuron
self.n_input_neurons = self.env.n_obs + 1
self.n_agents = self.env.n_agents
def _convert_obs(self, obs, rewards):
n_input_channels_used = 3
X = np.zeros((self.n_agents, self.n_input_neurons, n_input_channels_used))
#X[:, :obs.shape[1], 0] = obs
# Shuffle only obs to avoid topology exploitation, reward neuron linked to EnuGlobal synapse connectivity
X[:, :obs.shape[1], 0] = np.take_along_axis(obs, self.obs_shuffle, axis=1)
# split pos and negative reward to different channels, And set to last input neuron
if rewards is not None:
X[rewards>0, -1, 1] = np.abs(rewards[rewards>0])
X[rewards<=0, -1, 2] = np.abs(rewards[rewards<=0])
return X
def _convert_reward(self, obs, actions, rewards, infos, dones):
fitness = np.copy(rewards)
# first poison is considered positive reward, since learning to learn
#NOTE: dead by env means less reward can be obtained so should implictely reduce overall fitness automatically
fitness[np.logical_and(self._prev_reward_count == 1, rewards != 0)] = 1
# include episode length as extra fitness, since not taking poison would allow survive longer, so should try avoid take poison
fitness[dones==0] += 0.1/4
return fitness
def step(self, y):
# if self.t % 3 != 0:
# actions = np.zeros((self.n_agents), dtype=np.int32) - 1
# else:
# winner take all, in given time window
actions = y
# if all same output, do nothing
# equal_actions = self.y_hist.shape[1] == np.sum(self.y_hist == np.take_along_axis(self.y_hist, actions.reshape(-1, 1), axis=1), axis=-1)
# actions[equal_actions] = -1
# self.y_hist[:] = 0
# take env step
allobs, obs, rewards, dones, infos = self.env.step(actions)
# X = self._convert_obs(obs, rewards)
X=allobs
self._prev_reward_count += rewards!=0
fitness = self._convert_reward(obs, actions, rewards, infos, dones)
self._prev_action = actions
self._prev_obs = obs
return X, rewards, fitness, None
def reset(self):
self.t = 0
self.y_hist = np.zeros((self.n_agents, self.n_output), dtype=np.float32)
self._prev_action = None
self._prev_obs = None
self._prev_reward_count = np.zeros((self.n_agents), dtype=np.float32)
# each time different input/output neurons should have different meaning, to have learning to learn
self.obs_shuffle = np.argsort(np.random.randn(self.n_agents, self.n_input_neurons - 1), axis=1, kind='mergesort')
self.action_shuffle = np.argsort(np.random.randn(self.n_agents, self.n_output), axis=1, kind='mergesort')
# reset env
self.allobs,self.obs = self.env.reset()
# return self._convert_obs(self.obs, None)
return self.allobs
def render(self):
if self.t%4==0:
self.env.render()
def track_vis_data(self, vis_data, model, X, y_est, t):
n_fetch = 128
# TODO: also get our gates from the model
vis_data+=[(X[:n_fetch, :], y_est[:n_fetch, :])]
def dump_vis_data(self, vis_data, fitness_per_offspring, e):
with open(get_data_path(e, self.exp_name, "output"), 'wb') as f:
pickle.dump((vis_data, fitness_per_offspring), f)
@staticmethod
def load_vis_data(e, exp_name):
with open(get_data_path(e, exp_name, "output"), 'rb') as f:
vis_data, fitness_per_offspring = pickle.load(f)
return vis_data, fitness_per_offspring
    @staticmethod
    def plot_vis_data(e, exp_name):
        # Load the pickled visualisation data for experiment ``exp_name`` at
        # generation ``e``. NOTE(review): the loaded data is currently unused
        # here — the plotting itself appears unimplemented (or truncated).
        vis_data, fitness_per_offspring = ExperimentEnvGlobalNetworkSurvival.load_vis_data(e, exp_name)
================================================
FILE: examples/Structure_Evolution/Adaptive_lsm/BrainCog-Version/tools/MazeTurnEnvVec.py
================================================
import pickle
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tools.Tools import save_fig, get_data_path
# np.random.seed(0)
class MazeTurnEnvVec:
    """Vectorized RL T-Maze environment written in pure Numpy. We require an efficient environment since we need to evaluate
    and run up to thousands of offspring in parallel"""

    def __init__(self, n_agents, n_steps):
        """Create ``n_agents`` parallel mazes, each episode lasting ``n_steps``.

        Maze cell codes: 1 = wall, 0 = corridor, 2 = food, 3 = poison.
        """
        # 4 important points, start point, decision point, food point, dead point.
        # just generate a very large matrix that could fit any maze of any size, then can generate smaller maze as well
        self.n_actions = 3
        self.n_obs = 3
        self.max_size = 7
        self.n_agents = n_agents
        self.n_steps = n_steps
        # NOTE(review): opening a figure at construction time is a side effect
        # that is only needed when render() is actually used.
        self.window = plt.figure()
        self.t_maze = True
        self.turn_based = False
        # steps can be longer if poison and need to turn around
        self.steps_to_food = 2
        if self.t_maze:
            self.steps_to_food = 3
        self.steps_to_food += self.steps_to_food*2
        # give some extra leniency
        self.steps_to_food *= 2

    def step(self, actions):
        """Advance all agents one tick.

        Returns ``(allobs, obs, rewards, dones, None)`` where ``allobs`` is the
        raw 4-direction sensor block and ``obs`` the one-hot front observation.
        """
        # L R U D
        # TODO: check legal action or not..
        pos_copy = np.copy(self.agents_pos)
        actions = np.copy(actions)
        # GIVE TIME UPDATE WEIGHTS
        self.agents_reset[self.agents_reset > 0] -= 1
        # if turn based: 0 = forward, 1 = turn left, 2 = turn right
        if self.turn_based:
            Forward = actions == 0
            self.agents_pos[np.logical_and(Forward, self.agent_directions == 0), 1] += 1
            self.agents_pos[np.logical_and(Forward, self.agent_directions == 2), 1] -= 1
            self.agents_pos[np.logical_and(Forward, self.agent_directions == 1), 0] -= 1
            self.agents_pos[np.logical_and(Forward, self.agent_directions == 3), 0] += 1
            L = actions == 1
            self.agent_directions[L] += 1
            R = actions == 2
            self.agent_directions[R] -= 1
            # wrap facing direction into [0, 3]
            self.agent_directions[self.agent_directions > 3] = 0
            self.agent_directions[self.agent_directions < 0] = 3
        else:
            # or just direct movement
            U = actions == 2
            D = actions == 1
            R = actions == 0
            if self.agents_pos[U].size>0:
                self.agents_pos[U, 0] += 1
                self.agent_directions[U] = 3
            if self.agents_pos[D].size>0:
                self.agents_pos[D, 0] -= 1
                self.agent_directions[D] = 1
            if self.t_maze and self.agents_pos[R].size>0:
                self.agents_pos[R, 1] += 1
                self.agent_directions[R] = 0
        # UNDO MOVES THAT GOT AGENT INTO WALL (cell value 1 == wall)
        self.current_cells = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0], self.agents_pos[:, 1]]
        self.agents_pos[self.current_cells==1] = pos_copy[self.current_cells==1]
        movement_loss = np.prod(self.agents_pos==pos_copy, axis=-1)
        # CHECK IF FOOD CONSUMED, is reward + pos reset
        # FIX: use builtin bool — the np.bool alias was removed in NumPy 1.24
        consumed_food = np.prod(self.agents_pos==self.food_pos[:, 0, :2], axis=-1).astype(bool)
        consumed_pois = np.prod(self.agents_pos==self.food_pos[:, 1, :2], axis=-1).astype(bool)
        self.consumed_count += consumed_food.astype(np.int32)
        self.consumed_count_total += consumed_food.astype(np.int32)
        self.consumed_count_pois += consumed_pois.astype(np.int32)
        self._reset_pos(np.logical_or(consumed_food, consumed_pois))
        # reset food for agents that ate food, and swap with some probability
        self._reset_food(self.consumed_count==5, prob=0.5)
        self.rewards = consumed_food.astype(np.float32) - consumed_pois.astype(np.float32) #* 0.5 #- movement_loss.astype(np.float32) * 0.01
        # get observation from current position of each agent
        self.agent_allobs,self.obs = self._get_obs_from_pos()
        # instant dead on second poison: energy is refilled on food while
        # poison count <= 1, otherwise the agent's energy is zeroed out
        self.agent_energy += np.where(self.consumed_count_pois<=1, np.abs(self.rewards) * self.steps_to_food, -self.agent_energy * np.abs(self.rewards))
        # energy decay to encourage exploration, agent dies if running out of energy
        self.agent_energy = np.minimum(self.agent_energy, self.steps_to_food)
        self.agent_energy -= 1.0/4
        dones = self.agent_energy<=0
        return self.agent_allobs,self.obs, self.rewards, dones, None

    def _reset_pos(self, idxs):
        """Send the selected agents back to the start cell, facing right."""
        self.agents_pos[idxs] = [self.start_point, 2] # set X pos
        self.agent_directions[idxs] = 0
        if self.max_size==5:
            self.agent_directions[idxs] = 1
        self.agents_reset[idxs] = 0
        self.agents_reset_count[idxs] += 1

    def _reset_pos_pois(self, idxs):
        """Freeze repeated poison-eaters for a long penalty period."""
        #NOTE: 2 since we call reset twice!
        #NOTE: turning around already cost 8 steps, then 4x4 more is 16+8, 24, so reset should be much worse
        self.agents_reset[np.logical_and(idxs, self.agents_reset_count>2)] = 64

    def _reset_food(self, idxs, prob=0.5):
        # swap food with some probability, avoids agent overfitting on environment
        # NOTE(review): ``prob`` is unused — the swap decision comes from the
        # pre-drawn random_swap_matrix so all agents share the same schedule.
        swap = np.take_along_axis(self.random_swap_matrix, self.consumed_count_total.reshape(-1, 1), axis=1).ravel()
        swap_idxs = swap * idxs
        food_loc = np.copy(self.food_pos[swap_idxs, 0, :])
        pois_loc = np.copy(self.food_pos[swap_idxs, 1, :])
        self.food_pos[swap_idxs, 0, :] = pois_loc
        self.food_pos[swap_idxs, 1, :] = food_loc
        # set maze value (2 == food, 3 == poison)
        self.mazes[np.arange(self.mazes.shape[0]), self.food_pos[:, 0, 0], self.food_pos[:, 0, 1]] = 2
        self.mazes[np.arange(self.mazes.shape[0]), self.food_pos[:, 1, 0], self.food_pos[:, 1, 1]] = 3
        self.consumed_count[idxs] = 0

    def reset(self):
        """Rebuild mazes, food and agents; returns initial (allobs, obs)."""
        self.consumed_count = np.zeros((self.n_agents), dtype=np.int32)
        self.consumed_count_total = np.zeros_like(self.consumed_count)
        # consistent swapping such that if agent eat food once for all agents swapped with same seed, fair fitness comparison
        max_eat = self.n_steps
        self.random_swap_matrix = np.random.uniform(0, 1, size=(1, max_eat)) >= 0.5
        self.random_swap_matrix = np.repeat(self.random_swap_matrix, int(self.n_agents), axis=0)
        self.agent_energy = np.zeros((self.n_agents), dtype=np.float32) + self.steps_to_food
        self.consumed_count_pois = np.zeros_like(self.consumed_count)
        self.mazes = np.ones((self.n_agents, self.max_size, self.max_size), dtype=np.int32)
        #TODO: support variable maze length
        self.start_point = int(self.max_size/2)
        if self.t_maze:
            # carve the T: a horizontal arm plus the vertical bar on the right
            self.mazes[:, self.start_point, 2:-1] = 0
            self.mazes[:, 1:-1, -2] = 0
            # two foods: x, y, value — one food, one poison at the arm ends
            self.food_pos = np.zeros((self.n_agents, 2, 2), dtype=np.int32)
            self.food_pos[:, :, 1] = self.max_size - 2
            self.food_pos[:, 1, 0] = 1
            self.food_pos[:, 0, 0] = self.max_size - 2
            # FIX: np.bool alias removed in NumPy 1.24 — use builtin bool
            self._reset_food(np.ones(self.food_pos.shape[0], dtype=bool), prob=0.5)
        else:
            self.mazes[:, 1:-1, 1] = 0
            # two foods: x, y, value
            self.food_pos = np.zeros((self.n_agents, 2, 2), dtype=np.int32)
            self.food_pos[:, :, 1] = 1
            self.food_pos[:, 0, 0] = 1
            self.food_pos[:, 1, 0] = (self.max_size - 2)
            # FIX: np.bool alias removed in NumPy 1.24 — use builtin bool
            self._reset_food(np.ones(self.food_pos.shape[0], dtype=bool), prob=0.5)
        # AGENT
        self.agents_pos = np.ones((self.n_agents, 2), dtype=np.int32)
        self.agents_reset = np.zeros((self.n_agents), dtype=np.int32)
        self.agents_reset_count = np.zeros_like(self.agents_reset)
        self.agent_directions = np.zeros((self.n_agents), dtype=np.int32)
        self._reset_pos(np.arange(self.agents_pos.shape[0]))
        # OBS
        self.agent_allobs,self.obs = self._get_obs_from_pos()
        return self.agent_allobs,self.obs

    def _get_obs_from_pos(self):
        """Sensor readout per agent, relative to its facing direction.

        Returns ``(allobs, obs)``: ``allobs`` stacks the left/front/right/back
        cell codes, ``obs`` one-hot encodes the front cell (wall/food/poison).
        """
        # obs is neighbouring cell states around agent
        obs = np.zeros((self.n_agents, self.n_obs), dtype=np.float32)
        raw_obs = np.zeros(self.n_agents, dtype=np.int32)
        leftobs = np.zeros(self.n_agents)
        rightobs = np.zeros(self.n_agents)
        backobs = np.zeros(self.n_agents)
        # get observation based on direction agent is facing (0=E, 1=N, 2=W, 3=S)
        D = self.agent_directions == 0
        raw_obs[D] = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0], self.agents_pos[:, 1] + 1][D]
        leftobs[D] = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0] - 1, self.agents_pos[:, 1]][D]
        rightobs[D] = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0] + 1, self.agents_pos[:, 1]][D]
        backobs[D] = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0], self.agents_pos[:, 1]-1][D]
        D = self.agent_directions == 2
        raw_obs[D] = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0], self.agents_pos[:, 1] - 1][D]
        leftobs[D] = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0] + 1, self.agents_pos[:, 1]][D]
        rightobs[D] = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0] - 1, self.agents_pos[:, 1]][D]
        backobs[D] = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0], self.agents_pos[:, 1]+1][D]
        D = self.agent_directions == 1
        raw_obs[D] = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0] - 1, self.agents_pos[:, 1]][D]
        leftobs[D] = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0], self.agents_pos[:, 1] - 1][D]
        rightobs[D] = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0], self.agents_pos[:, 1] + 1][D]
        backobs[D] = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0]+1, self.agents_pos[:, 1]][D]
        D = self.agent_directions == 3
        raw_obs[D] = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0] + 1, self.agents_pos[:, 1]][D]
        leftobs[D] = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0], self.agents_pos[:, 1] + 1][D]
        rightobs[D] = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0], self.agents_pos[:, 1] - 1][D]
        backobs[D] = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0]-1, self.agents_pos[:, 1]][D]
        # mark what was observed at different index: wall / food / poison
        obs[raw_obs == 1, 0] = 1
        obs[raw_obs == 2, 1] = 1
        obs[raw_obs == 3, 2] = 1
        allobs=np.squeeze(np.dstack((leftobs,raw_obs,rightobs,backobs)))
        return allobs, obs

    def render(self):
        """Render one agent's maze with matplotlib."""
        plt.clf()
        sns.set_style("white")
        #TODO: support render all mazes? can reshape to square?
        max_render = 1
        flattened_render = np.dstack(np.split(self.mazes[18, :], max_render, axis=0)).reshape(self.mazes.shape[1], -1)
        flattened_render[flattened_render>1] = 0
        plt.axis('off')
        plt.imshow(flattened_render,cmap='bone')
        for j in range(1):
            # NOTE(review): agent index 18 is hard-coded; rendering requires
            # n_agents > 18.
            i=18
            marker = ">"
            if self.agent_directions[i] == 1:
                marker = "^"
            if self.agent_directions[i] == 2:
                marker = "<"
            if self.agent_directions[i] == 3:
                marker = "v"
            # colour the agent marker by what it currently observes
            obs_color = "black"
            if self.obs[i, 0] == 1:
                obs_color = "gray"
            if self.obs[i, 1] == 1:
                obs_color = "green"
            if self.obs[i, 2] == 1:
                obs_color = "red"
            alpha = 1
            if self.agent_energy[i]<=0:
                alpha = 1
            plt.scatter(self.agents_pos[i, 1] + j * self.mazes.shape[1], self.agents_pos[i, 0], color="skyblue", alpha=alpha, marker=marker)
            plt.scatter(self.agents_pos[i, 1] + j * self.mazes.shape[1], self.agents_pos[i, 0], color=obs_color,alpha=alpha, marker=marker, s=3)
            plt.scatter(self.food_pos[i, 0, 1] + j * self.mazes.shape[1], self.food_pos[i, 0, 0], color="green", alpha=1, marker="o")
            plt.scatter(self.food_pos[i, 1, 1] + j * self.mazes.shape[1], self.food_pos[i, 1, 0], color="red", alpha=1, marker="o")
        plt.pause(0.001)
        #plt.pause(2)

    @staticmethod
    def load_vis_data(e, exp_name):
        """Load the pickled (vis_data, fitness_per_offspring) tuple from disk."""
        with open(get_data_path(e, exp_name, "output"), 'rb') as f:
            vis_data, fitness_per_offspring = pickle.load(f)
        return vis_data, fitness_per_offspring

    @staticmethod
    def plot_vis_data(e, exp_name):
        """Plot one offspring's input/output traces for a single episode."""
        import matplotlib.pyplot as plt
        from cycler import cycler
        import seaborn as sns
        sns.set_style("whitegrid")
        vis_data, fitness_per_offspring = MazeTurnEnvVec.load_vis_data(e, exp_name)
        offspring_idx = 0
        #x, y_est, y = np.array(vis_data).transpose(1, 2, 0, 3)
        X, Y_est = map(np.array, zip(*vis_data))
        X_base, y_est_base = X[:, offspring_idx], Y_est[:, offspring_idx]
        # --- normal output single example---
        plt.rc('axes', prop_cycle=(cycler('color', ['gray', '#ff7f0e', '#9467bd', '#8c564b', '#e377c2', '#17becf'])))
        # OLD METHOD!
        plt.figure()
        #NOTE: last neuron is always reward neuron
        plt.plot(-X_base[:, :-1, 0], label="N-ENUs input", alpha=0.7)
        plt.plot(- np.max(X_base, axis=1)[:, 1], label="Positive reward", alpha=0.7, color='#2ca02c')
        plt.plot(- np.max(X_base, axis=1)[:, 2], label="Negative reward", alpha=0.7, color='#d62728')
        plt.plot(y_est_base[:, :], label="N-ENUs output", alpha=0.8)
        plt.legend(loc='upper right')
        save_fig(e, exp_name, "single_episode")
        plt.rc('axes', prop_cycle=(cycler('color', ['gray', '#ff7f0e', '#9467bd', '#8c564b', '#e377c2', '#17becf'])))
        # new method: separate input and output panels sharing the x axis
        fig, grid = plt.subplots(2, sharex=True)
        # input
        grid[1].set_prop_cycle(cycler('color', ['gray', '#F5B041', '#2E86C1']))
        grid[1].plot(X_base[:, :-1, 0], label="N-ENUs input", alpha=0.7, linewidth=2)
        grid[1].plot(np.max(X_base, axis=1)[:, 1], label="Positive reward", alpha=0.7, color='#1ABC9C', linewidth=2)
        grid[1].plot(np.max(X_base, axis=1)[:, 2], label="Negative reward", alpha=0.7, color='#CB4335', linewidth=2)
        grid[1].legend(['Sensor (wall)', 'Sensor (red)', 'Sensor (green)', 'Positive reward', 'Negative reward'], loc='upper right')
        grid[1].set_ylabel('Neuron output')
        # output
        grid[0].set_prop_cycle(cycler('color', ['#9467bd', '#e377c2', '#17becf','#8c564b']))
        grid[0].plot(y_est_base[:, :], label="N-ENUs output", alpha=0.8, linewidth=2)
        grid[0].legend(['ENU-NN (left)', 'ENU-NN (right)', 'ENU-NN (forward)'],loc='upper right')
        plt.xlabel('t')
        grid[0].set_ylabel('ENU neuron output')
        plt.xlim(-5, X_base.shape[0]+10)
        save_fig(e, exp_name, "single_episode_dual")
        #plt.show()

    @staticmethod
    def plot_rollout_data(e, exp_name):
        """Stitch a hand-picked subset of saved rollout frames into one figure."""
        #TODO: dump rollout as array not the actual plots
        import matplotlib.pyplot as plt
        import seaborn as sns
        sns.set_style("white")
        import os
        rollout_path = "./" + get_data_path(e, exp_name, "rollout").split(".")[1][:-1]+"/"
        rollout_files = sorted(os.listdir(rollout_path))
        rollouts = []
        for i in range(4, 200, 4):
            file = "rollout_{}_.png".format(i)
            print(file)
            rollout = plt.imread(rollout_path + file)
            # crop to the maze area and blank out the near-white background
            rollout = rollout[80:450, 200:500]
            rollout = np.where(rollout[:, :, [0]]> 0.98, 1, rollout)
            rollouts.append(rollout)
        vert_bar = np.zeros((rollout.shape[0], 5, 4))
        vert_bar[::] += 0.5
        # get red -> learned go other way
        # food swapped -> sees red -> learned turn around
        rollouts1 = [rollouts[1], rollouts[9], rollouts[10], rollouts[11], rollouts[17], rollouts[18]]
        rollouts2 = [rollouts[25], rollouts[29], rollouts[30], rollouts[31], rollouts[32], rollouts[33]]
        plt.axis('off')
        plt.imshow(np.vstack([np.column_stack(rollouts1), np.column_stack(rollouts2)]), cmap='gray')
        save_fig(e, exp_name, "rollout_combined")
        #plt.show()
if __name__ == '__main__':
    # Manual smoke test: drive random agents through the maze and render.
    n_offspring = 1024
    envs = MazeTurnEnvVec(n_offspring, n_steps=400)
    envs.n_pseudo_env = 8
    while True:
        envs.reset()
        # hand-crafted near-optimal action sequences (forward x4, turn, forward)
        opt_actions_up = [0,0,0,0,1,0]
        opt_actions_down = [0,0,0,0,2,0]
        opt_actions = [opt_actions_up, opt_actions_down]
        opt_current = 0
        total_reward = 0
        rewards = np.zeros((n_offspring))
        rewards_all = np.zeros_like(rewards)
        k = 0
        n_steps = 200
        for i in range(n_steps):
            actions = np.random.randint(0, envs.n_actions, size=n_offspring)
            #actions[:] = opt_actions_up[i%len(opt_actions_up)]
            #actions[0] = opt_actions[opt_current][k % len(opt_actions_up)]
            # BUGFIX: MazeTurnEnvVec.step returns FIVE values
            # (allobs, obs, rewards, dones, infos); unpacking into four
            # raised a ValueError on the first iteration.
            allobs, obs, rewards, dones, _ = envs.step(actions=actions)
            total_reward += (rewards[0] * 100) + 1
            rewards_all += (rewards * 100) + 1
            envs.render()
            plt.pause(0.5)
            k += 1
            if rewards[0] < 0:
                # after hitting poison, switch to the other scripted policy
                if opt_current==0:
                    opt_current = 1
                else:
                    opt_current = 0
            if rewards[0] != 0:
                k = 0
        total_reward/=n_steps
        rewards_all/=n_steps
        print(rewards_all[0], np.mean(rewards_all), np.std(rewards_all), np.std(rewards_all)/np.mean(rewards_all))
        print(rewards_all[:10])
================================================
FILE: examples/Structure_Evolution/Adaptive_lsm/BrainCog-Version/tools/nsganet.py
================================================
import numpy as np
from pymoo.algorithms.genetic_algorithm import GeneticAlgorithm
from pymoo.docs import parse_doc_string
from pymoo.model.individual import Individual
from pymoo.model.survival import Survival
from pymoo.operators.crossover.point_crossover import PointCrossover
from pymoo.operators.mutation.polynomial_mutation import PolynomialMutation
from pymoo.operators.mutation.bitflip_mutation import BinaryBitflipMutation
from pymoo.operators.sampling.random_sampling import RandomSampling
from pymoo.operators.selection.tournament_selection import compare, TournamentSelection
from pymoo.util.display import disp_multi_objective
from pymoo.util.dominator import Dominator
from pymoo.util.non_dominated_sorting import NonDominatedSorting
from pymoo.util.randomized_argsort import randomized_argsort
# =========================================================================================================
# Implementation
# based on nsga2 from https://github.com/msu-coinlab/pymoo
# =========================================================================================================
class NSGANet(GeneticAlgorithm):
    """NSGA-II style genetic algorithm with rank/crowding-aware individuals."""

    def __init__(self, **kwargs):
        # Always inject a default individual that carries rank and crowding
        # distance, then defer everything else to the base GeneticAlgorithm.
        kwargs = {**kwargs, 'individual': Individual(rank=np.inf, crowding=-1)}
        super().__init__(**kwargs)
        # domination + crowding is the classic NSGA-II binary tournament
        self.tournament_type = 'comp_by_dom_and_crowding'
        self.func_display_attrs = disp_multi_objective
# ---------------------------------------------------------------------------------------------------------
# Binary Tournament Selection Function
# ---------------------------------------------------------------------------------------------------------
def binary_tournament(pop, P, algorithm, **kwargs):
    """Binary tournament selection for NSGA-II.

    ``P`` is an ``(n_tournaments, 2)`` matrix of competitor indices into
    ``pop``; the winner index of each tournament is returned as an
    ``(n, 1)`` int array. Infeasible solutions lose on constraint violation;
    feasible pairs are compared by domination (or rank, depending on
    ``algorithm.tournament_type``) with crowding distance as tie-breaker.
    """
    if P.shape[1] != 2:
        raise ValueError("Only implemented for binary tournament!")
    tournament_type = algorithm.tournament_type
    S = np.full(P.shape[0], np.nan)
    for i in range(P.shape[0]):
        a, b = P[i, 0], P[i, 1]
        # if at least one solution is infeasible, smaller violation wins
        if pop[a].CV > 0.0 or pop[b].CV > 0.0:
            S[i] = compare(a, pop[a].CV, b, pop[b].CV, method='smaller_is_better', return_random_if_equal=True)
        # both solutions are feasible
        else:
            if tournament_type == 'comp_by_dom_and_crowding':
                rel = Dominator.get_relation(pop[a].F, pop[b].F)
                if rel == 1:
                    S[i] = a
                elif rel == -1:
                    S[i] = b
            elif tournament_type == 'comp_by_rank_and_crowding':
                S[i] = compare(a, pop[a].rank, b, pop[b].rank,
                               method='smaller_is_better')
            else:
                raise Exception("Unknown tournament type.")
        # if rank or domination relation didn't make a decision compare by crowding
        if np.isnan(S[i]):
            S[i] = compare(a, pop[a].get("crowding"), b, pop[b].get("crowding"),
                           method='larger_is_better', return_random_if_equal=True)
    # FIX: np.int (deprecated alias of builtin int) was removed in NumPy 1.24
    return S[:, None].astype(int)
# ---------------------------------------------------------------------------------------------------------
# Survival Selection
# ---------------------------------------------------------------------------------------------------------
class RankAndCrowdingSurvival(Survival):
    """NSGA-II survival: non-dominated rank first, crowding distance second."""

    def __init__(self) -> None:
        super().__init__(True)

    def _do(self, pop, n_survive, D=None, **kwargs):
        """Select ``n_survive`` individuals from ``pop``."""
        # objective space values
        F = pop.get("F")
        survivors = []
        # non-dominated sorting, stopping once enough individuals are ranked
        fronts = NonDominatedSorting().do(F, n_stop_if_ranked=n_survive)
        for rank, front in enumerate(fronts):
            # crowding distance within this front
            crowding = calc_crowding_distance(F[front, :])
            # annotate individuals so tournament selection can reuse the values
            for pos, idx in enumerate(front):
                pop[idx].set("rank", rank)
                pop[idx].set("crowding", crowding[pos])
            if len(survivors) + len(front) > n_survive:
                # splitting front: keep the most widely spread individuals
                order = randomized_argsort(crowding, order='descending', method='numpy')
                order = order[:(n_survive - len(survivors))]
            else:
                # whole front fits — take it unsorted
                order = np.arange(len(front))
            survivors.extend(front[order])
        return pop[survivors]
def calc_crowding_distance(F):
    """Crowding distance of each row of objective matrix ``F`` (NSGA-II).

    Boundary points — and every point of a front with two or fewer members —
    receive a large finite value (1e14) instead of infinity so downstream
    arithmetic stays well defined.
    """
    INF = 1e+14
    n_points, n_obj = F.shape
    if n_points <= 2:
        # with at most two points everyone is a boundary point
        return np.full(n_points, INF)
    # per-objective sort order (stable, so ties keep their input order)
    order = np.argsort(F, axis=0, kind='mergesort')
    F = F[order, np.arange(n_obj)]
    # neighbour gaps in sorted order, padded with +/- inf at the boundaries
    dist = np.concatenate([F, np.full((1, n_obj), np.inf)]) \
        - np.concatenate([np.full((1, n_obj), -np.inf), F])
    # zero gaps (duplicate objective values) inherit the neighbouring gap
    zeros = np.where(dist == 0)
    dist_to_last = np.copy(dist)
    for i, j in zip(*zeros):
        dist_to_last[i, j] = dist_to_last[i - 1, j]
    dist_to_next = np.copy(dist)
    for i, j in reversed(list(zip(*zeros))):
        dist_to_next[i, j] = dist_to_next[i + 1, j]
    # normalise by the per-objective range; flat objectives contribute zero
    span = np.max(F, axis=0) - np.min(F, axis=0)
    span[span == 0] = np.nan
    dist_to_last = dist_to_last[:-1] / span
    dist_to_next = dist_to_next[1:] / span
    dist_to_last[np.isnan(dist_to_last)] = 0.0
    dist_to_next[np.isnan(dist_to_next)] = 0.0
    # undo the per-objective sort and average the two neighbour gaps
    inverse = np.argsort(order, axis=0)
    crowding = np.sum(dist_to_last[inverse, np.arange(n_obj)]
                      + dist_to_next[inverse, np.arange(n_obj)], axis=1) / n_obj
    # boundary points would be infinite — clamp to the large finite value
    crowding[np.isinf(crowding)] = INF
    return crowding
# =========================================================================================================
# Interface
# =========================================================================================================
def nsganet(
        pop_size=100,
        # FIX: np.int (a deprecated alias of builtin int) was removed in
        # NumPy 1.24; builtin int is semantically identical here.
        sampling=RandomSampling(var_type=int),
        selection=TournamentSelection(func_comp=binary_tournament),
        crossover=PointCrossover(n_points=2),
        mutation=PolynomialMutation(eta=3, var_type=int),
        eliminate_duplicates=True,
        n_offsprings=None,
        **kwargs):
    """
    Parameters
    ----------
    pop_size : {pop_size}
    sampling : {sampling}
    selection : {selection}
    crossover : {crossover}
    mutation : {mutation}
    eliminate_duplicates : {eliminate_duplicates}
    n_offsprings : {n_offsprings}
    Returns
    -------
    nsganet : :class:`~pymoo.model.algorithm.Algorithm`
        Returns an NSGANet algorithm object.
    """
    # NOTE: the operator defaults above are instantiated once at import time,
    # which is the established pymoo convention for algorithm factories.
    return NSGANet(pop_size=pop_size,
                   sampling=sampling,
                   selection=selection,
                   crossover=crossover,
                   mutation=mutation,
                   survival=RankAndCrowdingSurvival(),
                   eliminate_duplicates=eliminate_duplicates,
                   n_offsprings=n_offsprings,
                   **kwargs)
# fill the {placeholders} in the docstring from pymoo's shared doc snippets
parse_doc_string(nsganet)
================================================
FILE: examples/Structure_Evolution/Adaptive_lsm/raw/BCM.py
================================================
import argparse, math, os, sys
import numpy as np
import gym
from gym import wrappers
import matplotlib.pyplot as plt
import nsganet as engine
from pymop.problem import Problem
from pymoo.optimize import minimize
from pymoo.operators.sampling.random_sampling import RandomSampling
from pymoo.operators.mutation.bitflip_mutation import BinaryBitflipMutation
from tools.ExperimentEnvGlobalNetworkSurvival import ExperimentEnvGlobalNetworkSurvival
from tools.MazeTurnEnvVec import MazeTurnEnvVec
import torch
import torch.nn.utils as utils
from torch.distributions import Categorical
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from itertools import product
from functools import partial
import torchvision, pprint
from timm.models import register_model
from braincog.base.node.node import *
from braincog.base.connection.layer import *
from braincog.model_zoo.base_module import BaseModule
from braincog.base.learningrule.BCM import *
from braincog.base.learningrule.STDP import *
# ---- experiment configuration (module-level side effects) ----
parser = argparse.ArgumentParser(description='PyTorch REINFORCE example')
parser.add_argument('--gamma', type=float, default=0.98, metavar='G')
parser.add_argument('--seed', type=int, default=1, metavar='N')
parser.add_argument('--num_steps', type=int, default=500, metavar='N')
parser.add_argument('--num_episodes', type=int, default=100, metavar='N')
parser.add_argument('--render', action='store_true')
args = parser.parse_args()
# single agent, 500 steps per episode, 64 liquid (reservoir) neurons
n_agent = 1
steps = 500
hidden_size=64
# wrap the vectorized T-maze in the survival-experiment adapter
env = MazeTurnEnvVec(n_agent, n_steps=steps)
data_env = ExperimentEnvGlobalNetworkSurvival(env)
# observation: 4 directional sensors; action space: 3 discrete moves
s_dim = 4
a_dim = 3
def randbool(size, p):
    """Return a boolean tensor of shape ``size``, True with probability ``p``."""
    draws = torch.rand(*size)
    return draws < p
def fit(agent):
    """Fitness of a reservoir: rank of its liquid-state response matrix.

    Feeds all 16 binary 4-bit inputs through the (reset) model and returns
    the matrix rank of the stacked liquid states — a separability measure.
    """
    responses = []
    for bits in product([0, 1], repeat=4):
        agent.model.reset()
        stimulus = torch.tensor(bits).float().reshape(1, -1)
        _, liquid_state = agent.model(Variable(stimulus.float()).reshape(-1,))
        responses.append(liquid_state.detach().numpy())
    return np.linalg.matrix_rank(np.vstack(responses))
@register_model
class SNN(BaseModule):
    """Liquid-state spiking network: a BCM-plastic reservoir ("liquid") whose
    recurrent connectivity is masked by an evolved binary matrix, followed by
    a BCM-plastic fully-connected readout."""
    def __init__(self,
                 hidden_size,
                 n_agent,
                 connectivity_matrix,
                 num_classes=3,
                 step=1,
                 node_type=LIFNode,
                 encode_type='direct',
                 ins=4,
                 lsm_th=0.3,
                 fc_th=0.3,
                 lsm_tau=3,
                 fc_tau=3,
                 tw=100,
                 *args,
                 **kwargs):
        # hidden_size: number of liquid neurons; connectivity_matrix: binary
        # mask applied to the recurrent weights; lsm_*/fc_* are the neuron
        # parameters (tau, threshold) of the liquid and readout respectively.
        super().__init__(step, encode_type, *args, **kwargs)
        # NOTE(review): linear1/linear2 are constructed but never used in
        # forward(); they reference the module-level globals s_dim/a_dim.
        self.linear1 = nn.Linear(s_dim, hidden_size)
        self.node=partial(node_type, **kwargs, step=step,tau=lsm_tau,threshold=lsm_th)
        self.linear2 = nn.Linear(hidden_size, a_dim)
        # factories for the liquid and readout spiking nodes
        self.node_lsm=partial(node_type, **kwargs, step=step,tau=lsm_tau,threshold=lsm_th)
        self.node_fc = partial(node_type, **kwargs, step=step,tau=fc_tau,threshold=fc_th)
        self.hidden_size=hidden_size
        self.out = torch.zeros(hidden_size)
        self.con=[]
        self.learning_rule=[]
        self.connectivity_matrix=connectivity_matrix
        # input -> liquid projection (no bias)
        w1tmp=nn.Linear(ins,hidden_size,bias=False)
        self.con.append(w1tmp)
        # recurrent liquid weights, masked by the evolved connectivity matrix
        w2tmp=nn.Linear(hidden_size,hidden_size,bias=False)
        self.liquid_weight=w2tmp.weight.data
        w2tmp.weight.data=w2tmp.weight.data*self.connectivity_matrix
        self.con.append(w2tmp)
        # BCM plasticity over the (input, recurrent) liquid connections
        self.learning_rule.append(BCM(self.node_lsm(), [self.con[0], self.con[1]]))
        # readout layer with its own BCM rule
        self.fc = nn.Linear(hidden_size,num_classes)
        self.learning_rule.append(BCM(self.node_fc(), [self.fc]))
    def forward(self, x):
        # Run the liquid for a fixed simulation window, applying BCM weight
        # updates in place every tick, and accumulate the readout spikes.
        sum_spike=0
        time_window=20
        self.tw=time_window
        self.firing_tw=torch.zeros(time_window, self.hidden_size)
        self.out = torch.zeros(self.hidden_size)
        for t in range(time_window):
            # liquid update: returns new liquid state and per-connection dW
            self.out, self.dw = self.learning_rule[0](x, self.out)
            self.con[1].weight.data+=self.dw[1]
            out_liquid=self.out[0:self.hidden_size]
            # readout update (also BCM-plastic)
            xout,dw = self.learning_rule[1](out_liquid)
            self.fc.weight.data+=dw[0]
            sum_spike=sum_spike+xout
            self.firing_tw[t]=out_liquid
        # NOTE(review): by operator precedence this is
        # sum_spike + (0.0001 / time_window), i.e. a tiny epsilon that keeps
        # the output probabilities non-zero — confirm whether
        # (sum_spike + 0.0001) / time_window was intended instead.
        outputs = sum_spike+0.0001 / time_window
        return outputs,out_liquid
class REINFORCE:
    """Policy-gradient wrapper around the evolved spiking reservoir."""

    def __init__(self, lm):
        # ``lm``: binary liquid connectivity mask, hidden_size x hidden_size
        self.model = SNN(ins=4, n_agent=n_agent, hidden_size=hidden_size,
                         lsm_tau=2, lsm_th=0.2, connectivity_matrix=lm)
        self.model.train()

    def select_action(self, state):
        """Sample an action from the SNN output treated as unnormalised probs."""
        prob, _ = self.model(Variable(state).reshape(-1,))
        policy = Categorical(probs=prob)
        action = policy.sample()
        entropy = policy.entropy()
        # log-probability of the sampled action (for a REINFORCE update)
        log_prob = prob[action.item()].log()
        return action, log_prob, entropy
class Evolve(Problem):
    """Structure-evolution problem: each genome is a flattened binary
    hidden_size x hidden_size liquid connectivity mask, and the objective is
    the separability rank returned by ``fit``."""
    # first define the NAS problem (inherit from pymop)
    def __init__(self, n_var=20, n_obj=1, n_constr=0, lb=None, ub=None):
        super().__init__(n_var=n_var, n_obj=n_obj, n_constr=n_constr, type_var=np.int64)
        self.xl = lb
        self.xu = ub
        self._n_evaluated = 0  # keep track of how many architectures are sampled
    def _evaluate(self, x, out, *args, **kwargs):
        # x: (pop, n_var) genomes; results are written into out["F"]
        objs = np.full((x.shape[0], self.n_obj), np.nan)
        for i in range(x.shape[0]):
            arch_id = self._n_evaluated + 1
            print('Network= {}'.format(arch_id))
            # build an agent whose reservoir uses this genome as its mask
            agent = REINFORCE(torch.from_numpy(x[i].reshape(hidden_size,hidden_size)).float())
            log_reward = []
            log_smooth = []
            # gamma=np.linspace(0.9,1.0,100)
            gam=0.9
            # for gam in gamma:
            for i_episode in range(100):
                state = torch.tensor(data_env.reset()).unsqueeze(0)
                entropies = []
                log_probs = []
                rewards = []
                # NOTE(review): old_dis, reawrd_perstep, ss and allrewards are
                # initialised but never used below — presumably leftovers.
                old_dis = np.ones([1,])*13
                reawrd_perstep=[]
                ss=0
                allrewards=[]
                for t in range(500):
                    action, log_prob, entropy = agent.select_action(state.float())
                    action=action.unsqueeze(0).numpy()
                    next_state, envreward, done, _ = data_env.step(action)
                    entropies.append(entropy)
                    log_probs.append(log_prob)
                    state = torch.Tensor([next_state])
                    rewards.append(envreward[0])
                print("Episode: {}, reward: {}".format(i_episode, np.sum(rewards)))
                log_reward.append(np.sum(rewards))
                # exponentially smoothed reward curve for live plotting
                if i_episode == 0:
                    log_smooth.append(log_reward[-1])
                else:
                    log_smooth.append(log_smooth[-1]*0.99+0.01*np.sum(rewards))
                plt.plot(log_smooth)
                plt.plot(log_reward)
                plt.pause(1e-5)
            # fitness = rank of the liquid's responses to all binary inputs
            objs[i, 0] = fit(agent)
            self._n_evaluated += 1
        out["F"] = objs
def do_every_generations(algorithm):
    """Per-generation callback hook; currently only inspects the population."""
    generation = algorithm.n_gen
    design_vars = algorithm.pop.get("X")
    objectives = algorithm.pop.get("F")
    # NOTE(review): the fetched values are unused — placeholder for logging.
if __name__ == "__main__":
    n_agent=1
    # evolve a binary hidden_size x hidden_size connectivity mask
    kkk = Evolve(n_var=hidden_size*hidden_size,
                 n_obj=1, n_constr=0)
    # NSGA-Net engine with bit-flip mutation over the binary genome
    method = engine.nsganet(pop_size=n_agent,
                            sampling=RandomSampling(var_type='custom'),
                            mutation=BinaryBitflipMutation(),
                            n_offsprings=10,
                            eliminate_duplicates=True)
    kres=minimize(kkk,
                  method,
                  callback=do_every_generations,
                  termination=('n_gen', 1000))
================================================
FILE: examples/Structure_Evolution/Adaptive_lsm/raw/README.md
================================================
# Adaptive structure evolution and biologically plausible synaptic plasticity for recurrent spiking neural networks #
## Requirements ##
* numpy
* pytorch >= 1.12.0
## Run ##
```python BCM.py```
## Citation ##
If you find the code and dataset useful in your research, please consider citing:
```
@article{pan2023adaptive,
title = {Adaptive structure evolution and biologically plausible synaptic plasticity for recurrent spiking neural networks},
author = {Pan, Wenxuan and Zhao, Feifei and Zeng, Yi and Han, Bing},
journal = {Scientific Reports},
volume = {13},
number = {1},
pages = {16924},
year = {2023},
url = {https://doi.org/10.1038/s41598-023-43488-x},
doi = {10.1038/s41598-023-43488-x},
}
@article{zeng2023braincog,
title={BrainCog: A spiking neural network based, brain-inspired cognitive intelligence engine for brain-inspired AI and brain simulation},
author={Zeng, Yi and Zhao, Dongcheng and Zhao, Feifei and Shen, Guobin and Dong, Yiting and Lu, Enmeng and Zhang, Qian and Sun, Yinqian and Liang, Qian and Zhao, Yuxuan and others},
journal={Patterns},
volume={4},
number={8},
year={2023},
publisher={Elsevier}
}
```
================================================
FILE: examples/Structure_Evolution/Adaptive_lsm/raw/lstm.py
================================================
import argparse, math, os, sys
from re import S
from aiohttp import ServerDisconnectedError
import numpy as np
import gym
from gym import wrappers
import matplotlib.pyplot as plt
from tools.ExperimentEnvGlobalNetworkSurvival import ExperimentEnvGlobalNetworkSurvival
from tools.MazeTurnEnvVec import MazeTurnEnvVec
import torch
from torch.autograd import Variable
import torch.autograd as autograd
import torch.nn.utils as utils
from torch.distributions import Categorical
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# ---- experiment configuration (module-level side effects) ----
parser = argparse.ArgumentParser(description='PyTorch REINFORCE example')
parser.add_argument('--gamma', type=float, default=0.98, metavar='G')
parser.add_argument('--seed', type=int, default=598, metavar='N')
parser.add_argument('--num_steps', type=int, default=500, metavar='N')
parser.add_argument('--num_episodes', type=int, default=1000, metavar='N')
parser.add_argument('--hidden_size', type=int, default=128, metavar='N')
parser.add_argument('--render', action='store_true')
args = parser.parse_args()
# single agent, 500 steps per episode, wrapped in the survival adapter
n_agent = 1
steps = 500
env = MazeTurnEnvVec(n_agent, n_steps=steps)
data_env = ExperimentEnvGlobalNetworkSurvival(env)
# observation: 4 directional sensors; action space: 3 discrete moves
s_dim = 4
a_dim = 3
class Policy(nn.Module):
    """Recurrent stochastic policy: an LSTM encoder followed by a two-layer
    head that emits a softmax distribution over the action space."""

    def __init__(self, hidden_size, s_dim, a_dim):
        super(Policy, self).__init__()
        # Attribute names kept identical; external code may reference them.
        self.lstm = nn.LSTM(s_dim, hidden_size, batch_first=True)
        self.linear1 = nn.Linear(hidden_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, a_dim)

    def forward(self, x, hidden):
        """Encode `x` with the LSTM, return (action probabilities, new hidden state)."""
        encoded, hidden = self.lstm(x, hidden)
        features = F.relu(self.linear1(encoded))
        logits = self.linear2(features)
        return F.softmax(logits, -1), hidden
class REINFORCE:
    """REINFORCE (policy-gradient) agent wrapping the recurrent `Policy`."""

    def __init__(self, hidden_size, s_dim, a_dim):
        self.model = Policy(hidden_size, s_dim, a_dim)
        self.optimizer = optim.Adam(self.model.parameters(), lr=1e-2)
        self.model.train()
        self.pi = Variable(torch.FloatTensor([math.pi]))

    def select_action(self, state, hx, cx):
        """Sample an action from the policy.

        Returns (action tensor, log-probability of the sampled action,
        distribution entropy). The updated LSTM state is discarded.
        """
        prob, (hx, cx) = self.model(Variable(state), (hx, cx))
        dist = Categorical(probs=prob)
        action = dist.sample()
        log_prob = prob[0][0, action.item()].log()
        entropy = dist.entropy()
        return action, log_prob, entropy

    def update_parameters(self, rewards, log_probs, entropies, gamma):
        """One entropy-regularised policy-gradient step over a finished episode."""
        R = torch.tensor(0)
        total = 0
        # Accumulate discounted returns backwards through the episode.
        for step in reversed(range(len(rewards))):
            R = gamma * R + rewards[step]
            total = total - log_probs[step] * Variable(R) - 0.005 * entropies[step][0]
        total = total / len(rewards)
        self.optimizer.zero_grad()
        total.backward()
        utils.clip_grad_norm_(self.model.parameters(), 2)
        self.optimizer.step()
# Sweep 20 seeds x 100 discount factors; for each (seed, gamma) pair train a
# fresh agent with episodic REINFORCE and plot a smoothed reward curve.
seeds = 20
for seed in range(seeds):
    log_reward = []
    log_smooth = []
    gamma = np.linspace(0.9, 1.0, 100)  # discount factors to sweep
    for g in range(100):
        agent = REINFORCE(args.hidden_size, s_dim, a_dim)
        result = np.zeros([100, args.num_steps])
        for i_episode in range(args.num_episodes):
            state = torch.tensor(data_env.reset()).unsqueeze(0)
            entropies = []
            log_probs = []
            rewards = []
            old_dis = np.ones([1, ]) * 13   # NOTE(review): unused in this loop
            reawrd_perstep = []             # NOTE(review): unused (sic: misspelled)
            allrewards = []                 # NOTE(review): never appended to — see below
            # Fresh LSTM state at the start of each episode.
            hx = torch.zeros(args.hidden_size).unsqueeze(0).unsqueeze(0)
            cx = torch.zeros(args.hidden_size).unsqueeze(0).unsqueeze(0)
            for t in range(args.num_steps):
                # NOTE(review): select_action does not return the updated (hx, cx),
                # so the LSTM hidden state is never propagated across steps — confirm intent.
                action, log_prob, entropy = agent.select_action(state.unsqueeze(0).float(), hx, cx)
                action = action.cpu().numpy()
                next_state, envreward, done, _ = data_env.step(action[0])
                entropies.append(entropy)
                log_probs.append(log_prob)
                state = torch.Tensor([next_state])
                rewards.append(envreward[0])
            agent.update_parameters(rewards, log_probs, entropies, gamma[g])
            print("Episode: {}, reward: {}".format(i_episode, np.sum(rewards)))
            log_reward.append(np.sum(rewards))
            # Exponentially smoothed reward trace for plotting.
            if i_episode == 0:
                log_smooth.append(log_reward[-1])
            else:
                log_smooth.append(log_smooth[-1] * 0.99 + 0.01 * np.sum(rewards))
            plt.plot(log_smooth)
            plt.plot(log_reward)
            plt.pause(1e-5)
        # NOTE(review): allrewards is always empty here, so squeeze(1) on an empty
        # array will raise — this line looks like dead/broken bookkeeping to fix.
        result[g] = np.array(allrewards).squeeze(1)
        np.save('./lstm.npy', result)
================================================
FILE: examples/Structure_Evolution/Adaptive_lsm/raw/main.py
================================================
import argparse, math, os, sys
import numpy as np
import gym
from gym import wrappers
import matplotlib.pyplot as plt
from tools.ExperimentEnvGlobalNetworkSurvival import ExperimentEnvGlobalNetworkSurvival
from tools.MazeTurnEnvVec import MazeTurnEnvVec
import torch
from torch.autograd import Variable
import torch.autograd as autograd
import torch.nn.utils as utils
from torch.distributions import Categorical
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# Command-line configuration for the feed-forward REINFORCE baseline.
parser = argparse.ArgumentParser(description='PyTorch REINFORCE example')
parser.add_argument('--gamma', type=float, default=0.98, metavar='G')    # discount factor (unused: a gamma sweep is done below)
parser.add_argument('--seed', type=int, default=1, metavar='N')          # NOTE(review): parsed but never applied to torch/numpy
parser.add_argument('--num_steps', type=int, default=500, metavar='N')   # max steps per episode
parser.add_argument('--num_episodes', type=int, default=100, metavar='N')
parser.add_argument('--hidden_size', type=int, default=128, metavar='N')
parser.add_argument('--render', action='store_true')
args = parser.parse_args()

# Single-agent T-maze survival environment (project-local tools package).
n_agent = 1
steps = 500
env = MazeTurnEnvVec(n_agent, n_steps=steps)
data_env = ExperimentEnvGlobalNetworkSurvival(env)
s_dim = 4   # observation dimension fed to the policy
a_dim = 3   # number of discrete actions
class Policy(nn.Module):
    """Feed-forward stochastic policy: one ReLU hidden layer, softmax output."""

    def __init__(self, hidden_size, s_dim, a_dim):
        super(Policy, self).__init__()
        self.linear1 = nn.Linear(s_dim, hidden_size)
        self.linear2 = nn.Linear(hidden_size, a_dim)

    def forward(self, x):
        """Map a state batch to a probability distribution over actions."""
        hidden = F.relu(self.linear1(x))
        return F.softmax(self.linear2(hidden), -1)
class REINFORCE:
    """Vanilla REINFORCE (policy-gradient) agent over the feed-forward `Policy`."""

    def __init__(self, hidden_size, s_dim, a_dim):
        self.model = Policy(hidden_size, s_dim, a_dim)
        self.optimizer = optim.Adam(self.model.parameters(), lr=1e-2)
        self.model.train()
        self.pi = Variable(torch.FloatTensor([math.pi]))

    def select_action(self, state):
        """Sample an action; return (action, its log-probability, distribution entropy)."""
        prob = self.model(Variable(state))
        dist = Categorical(probs=prob)
        action = dist.sample()
        log_prob = prob[0, action.item()].log()
        entropy = dist.entropy()
        return action, log_prob, entropy

    def update_parameters(self, rewards, log_probs, entropies, gamma):
        """One policy-gradient step over a finished episode.

        Builds discounted returns backwards, forms the entropy-regularised
        REINFORCE loss, clips gradients, and applies an optimizer step.
        """
        R = torch.tensor(0)
        total = 0
        for step in reversed(range(len(rewards))):
            R = gamma * R + rewards[step]
            total = total - log_probs[step] * Variable(R) - 0.005 * entropies[step][0]
        total = total / len(rewards)
        self.optimizer.zero_grad()
        total.backward()
        utils.clip_grad_norm_(self.model.parameters(), 2)
        self.optimizer.step()
# Sweep 20 seeds; for each seed train one agent across a sweep of 100 discount
# factors, plotting raw and exponentially smoothed episode rewards as we go.
seeds = 20
for seed in range(seeds):
    # torch.manual_seed(args.seed)
    # np.random.seed(args.seed)
    agent = REINFORCE(args.hidden_size, s_dim, a_dim)
    log_reward = []
    log_smooth = []
    gamma = np.linspace(0.9, 1.0, 100)  # discount factors to sweep
    for gam in gamma:
        for i_episode in range(args.num_episodes):
            state = torch.tensor(data_env.reset()).unsqueeze(0)
            entropies = []
            log_probs = []
            rewards = []
            old_dis = np.ones([1, ]) * 13   # NOTE(review): unused in this loop
            reawrd_perstep = []             # NOTE(review): unused (sic: misspelled)
            ss = 0                          # NOTE(review): unused
            allrewards = []                 # NOTE(review): never appended to
            for t in range(args.num_steps):
                action, log_prob, entropy = agent.select_action(state.float())
                action = action.cpu().numpy()
                next_state, envreward, done, _ = data_env.step(action)
                entropies.append(entropy)
                log_probs.append(log_prob)
                state = torch.Tensor([next_state])
                rewards.append(envreward[0])
            # Episodic (Monte-Carlo) update once the rollout is complete.
            agent.update_parameters(rewards, log_probs, entropies, gam)
            print("Episode: {}, reward: {}".format(i_episode, np.sum(rewards)))
            log_reward.append(np.sum(rewards))
            # Exponentially smoothed reward trace for plotting.
            if i_episode == 0:
                log_smooth.append(log_reward[-1])
            else:
                log_smooth.append(log_smooth[-1] * 0.99 + 0.01 * np.sum(rewards))
            plt.plot(log_smooth)
            plt.plot(log_reward)
            plt.pause(1e-5)
================================================
FILE: examples/Structure_Evolution/Adaptive_lsm/raw/pltbcm.py
================================================
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import pyplot
import matplotlib as mpl
from scipy.ndimage import gaussian_filter1d
# Compare reward curves of several learning-rule / evolution variants.
# Defect fixed: the mean ± std band plotting was copy-pasted seven times with
# only the data file, color, label and sigma varying — deduplicated into one
# helper. Unused locals (sum1) removed. All runtime strings, colors and sigmas
# are preserved exactly.
sigm = 3  # smoothing sigma shared by all curves except the DA-BCM one (40)
plt.style.use('seaborn-whitegrid')
plt.figure(figsize=(8, 8))
ax = plt.subplot()
palette = pyplot.get_cmap('Set1')
font1 = {'family': 'Times New Roman',
         'weight': 'normal',
         'size': 14,
         }
steps = 500
t = list(range(steps))


def _plot_band(data, color_idx, label, print_name, sigma):
    """Plot the Gaussian-smoothed mean of `data` (runs x steps) with a ±1 std band.

    Also prints `print_name` followed by the final mean and the band's
    half-widths, matching the original per-curve console output.
    """
    avg = np.mean(data, axis=0)
    std = np.std(data, axis=0)
    lower = gaussian_filter1d(avg - std, sigma=sigma)
    upper = gaussian_filter1d(avg + std, sigma=sigma)
    smoothed = gaussian_filter1d(avg, sigma=sigma)
    color = palette(color_idx)
    ax.plot(t, smoothed, color=color, label=label, linewidth=3.0)
    ax.fill_between(t, lower, upper, color=color, alpha=0.2)
    print(print_name)
    print(avg[-1], avg[-1] - lower[-1], upper[-1] - avg[-1])


# ---- Evolved model with DA-BCM: per run, pick the best agent and cumulate ----
bcm = np.load('./10rewards.npy')
for e in range(bcm.shape[0]):
    # agent with the highest total reward over this run
    best_agent_id = np.argmax(np.sum(bcm[e, :, :], axis=0))
    best_agent = bcm[e, :, best_agent_id][:steps]
    # cumulative reward over steps (equivalent to the original running-sum loop)
    bcm[e] = np.cumsum(best_agent)
_plot_band(bcm, 0, "Evolved model with DA-BCM", "Evolved model with DA-BCM", sigma=40)

# ---- Remaining curves: (file, palette index, legend label, printed name) ----
_plot_band(np.load('./unevolved_with_bcm.npy'), 1,
           "Unevolved model with DA-BCM", "Unevolved model with DA-BCM+DA-BCM", sigma=sigm)
_plot_band(np.load('./none_bcm.npy'), 2,
           "Evolved model with NONE+DA-BCM", "Evolved model with none+DA-BCM", sigma=sigm)
_plot_band(np.load('./stdp_bcm.npy'), 5,
           "Evolved model with STDP+DA-BCM", "Evolved model with STDP+DA-BCM", sigma=sigm)
_plot_band(np.load('./lstm.npy'), 3, "LSTM", "LSTM", sigma=sigm)
_plot_band(np.load('./ql.npy'), 6, "Q-learning", "Q-learning", sigma=sigm)
_plot_band(np.load('./inac.npy'), 4, "Evolved STDP", "Evolved STDP", sigma=sigm)

# ---- Figure cosmetics, save and show ----
ax.tick_params(labelsize=16)
for side in ('right', 'top', 'left', 'bottom'):
    ax.spines[side].set_color('black')
ax.legend(loc='upper left', prop=font1)
plt.xlabel('Steps', fontsize=18)
plt.ylabel('Average Reward', fontsize=18)
plt.savefig('./bcm.png')
plt.show()
================================================
FILE: examples/Structure_Evolution/Adaptive_lsm/raw/pltrank.py
================================================
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import pyplot
import matplotlib as mpl
from scipy.ndimage import gaussian_filter1d
# Plot the evolution of average population fitness (stored in rank.npy) as a
# smoothed mean curve with a ±1 std band, and save the figure to rank.png.
plt.figure(figsize=(8, 8))
n_generations = 1000
xs = [g for g in range(n_generations)]
plt.style.use('seaborn-whitegrid')
palette = pyplot.get_cmap('Set1')
legend_font = {'family': 'Times New Roman',
               'weight': 'normal',
               'size': 18,
               }
fitness = np.load('./rank.npy')
mean_curve = np.mean(fitness, axis=0)
std_curve = np.std(fitness, axis=0)
band_lo = gaussian_filter1d(mean_curve - std_curve, sigma=20)
band_hi = gaussian_filter1d(mean_curve + std_curve, sigma=20)
smoothed = gaussian_filter1d(mean_curve, sigma=20)
line_color = palette(0)
ax = plt.subplot()
ax.plot(xs, smoothed, color=line_color, label="Average Fitness", linewidth=3.0)
ax.fill_between(xs, band_lo, band_hi, color=line_color, alpha=0.2)
ax.tick_params(labelsize=18)
for side in ('right', 'top', 'left', 'bottom'):
    ax.spines[side].set_color('black')
ax.legend(loc='lower right', prop=legend_font)
plt.xlabel('generations', fontsize=18)
plt.ylabel('SP', fontsize=18)
plt.savefig('./rank.png')
plt.show()
================================================
FILE: examples/Structure_Evolution/Adaptive_lsm/raw/q_l.py
================================================
import random
import time
import tkinter as tk
import pandas as pd
from tools.ExperimentEnvGlobalNetworkSurvival import ExperimentEnvGlobalNetworkSurvival
from tools.MazeTurnEnvVec import MazeTurnEnvVec
import numpy as np
from matplotlib import pyplot as plt
# Episode length and matching x-axis values for reward plots.
steps = 500
t = [i for i in range(steps)]
class Agent(object):
    """Tabular Q-learning agent for the T-maze survival task."""
    # Maze dimensions (rows/columns); not referenced elsewhere in this class.
    MAZE_R = 6
    MAZE_C = 6

    def __init__(self, env, alpha=0.1, gamma=0.9):
        """Initialize the agent's Q-table and hyperparameters.

        env:   the wrapped environment (kept only for interface symmetry; not stored).
        alpha: learning rate.
        gamma: discount factor.
        """
        self.states = {}   # observation string -> small integer state id
        self.actions = 3   # number of discrete actions
        self.alpha = alpha
        self.gamma = gamma
        # NOTE(review): a fixed capacity of 32 states is assumed; overflowing
        # state ids would index out of bounds — confirm against the env.
        self.q_table = np.zeros([32, 3])

    def choose_action(self, state, epsilon=0.8):
        """Epsilon-greedy selection: random with probability (1 - epsilon),
        otherwise a uniformly random choice among the max-Q actions."""
        if random.uniform(0, 1) > epsilon:
            action = random.choice([0, 1, 2])
        else:
            # indices of all actions tied for the maximum Q-value
            max_index = (self.q_table[state] == self.q_table[state].max()).nonzero()
            if len(max_index) == 1:
                max_qvalue_actions = max_index[0]
            else:
                max_qvalue_actions = max_index[:][1]
            action = random.choice(np.array(max_qvalue_actions))
        return np.array([action])

    def update_q_value(self, state, action, next_state_reward, next_state_q_values):
        # Standard Q-learning TD update: Q += alpha * (r + gamma*max(Q') - Q).
        self.q_table[state, action] += self.alpha * (
            next_state_reward + self.gamma * next_state_q_values.max() - self.q_table[state, action])

    def add_state(self, X_next):
        """Map an observation vector to a state id, assigning a fresh id on first sight."""
        x_str = ','.join(str(i) for i in X_next.astype(int))
        if (x_str in self.states) == False:
            self.states[x_str] = max(self.states.values()) + 1
        return self.states[x_str]

    def learn(self, env, episode=100, epsilon=0.8):
        """Q-learning training loop: run `episode` episodes, updating the Q-table
        until a non-zero reward is seen or 1000 steps elapse per episode."""
        env.reset()
        X = np.array([0, 1, 0, 0])
        sss = ','.join(str(i) for i in X.astype(int))
        self.states[sss] = 0  # seed the state table so max() in add_state is defined
        for i in range(episode):
            steps = 0
            current_state = np.array([0])
            env.env.current_cell = np.array([0])  # reset agent position in the maze
            X_next, envreward, fitness, infos = env.step(current_state)
            self.add_state(X_next)
            next_state_reward = 0
            while next_state_reward == 0 and steps < 1000:
                current_action = self.choose_action(current_state, epsilon)
                X_next, next_state_reward, fitness, infos = env.step(current_action)
                next_state_number = self.add_state(X_next)
                next_state_q_values = self.q_table[next_state_number]
                self.update_q_value(current_state, current_action, next_state_reward, next_state_q_values)
                current_state = next_state_number
                steps += 1

    def play(self, env):
        """Train, then run an evaluation episode and return the cumulative reward per step."""
        step = 0
        self.learn(env, epsilon=0.8)
        current_state = np.array([0])
        env.env.current_cell = np.array([0])
        X_next, envreward, fitness, infos = env.step(current_state)
        self.add_state(X_next)
        env_r = []
        rsum = 0
        old_dis = 13
        # NOTE(review): the next line is syntactically invalid — the `while`
        # condition and part of its body were lost (apparently corrupted during
        # extraction; presumably `while step < steps:` plus an env.step(...) call
        # producing `envreward`/`reward`). Reconstruct from upstream history before use.
        while step0,dtype=int)[0]
        if reward == 0:
            reward = -1
        elif reward == 1:
            reward = 1
        if envreward == 1:
            reward = 3
        elif envreward == -1:
            reward = -3
        next_state_number = self.add_state(X_next)
        rsum += reward
        current_state = next_state_number
        env_r.append(rsum)
        step += 1
        return np.array(env_r)
def QQ():
    """Run one Q-learning experiment on the T-maze and return the per-step
    cumulative reward array produced by Agent.play."""
    n_steps = 500
    maze = MazeTurnEnvVec(1, n_steps=n_steps)
    wrapped = ExperimentEnvGlobalNetworkSurvival(maze)
    return Agent(wrapped).play(wrapped)
# Run the Q-learning baseline once and persist its reward curve for plotting (pltbcm.py).
np.save('./ql.npy',QQ())
================================================
FILE: examples/Structure_Evolution/Adaptive_lsm/raw/tools/EnuGlobalNetwork.py
================================================
import pickle
import time
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import gridspec
from AbstractLayerBMM import AbstractLayerBMM
from EvolvableNeuralUnitStacked import EvolvableNeuralUnitStacked
from Tools import get_data_path
sns.set_style("darkgrid")
class EnuGlobalNetwork(AbstractLayerBMM):
    """Network of ENUs implementation in PyTorch, where each synapse and neuron is modeled as an ENU. """

    def __init__(self, n_offspring, n_pseudo_env, n_input_neurons, n_hidden_neurons, n_output_neurons, n_syn_per_neuron):
        """Build the neuron/synapse ENU stacks and the synapse wiring table.

        n_offspring:      number of evolutionary offspring evaluated in parallel.
        n_pseudo_env:     number of pseudo-environments (stored, not used here).
        n_input_neurons:  number of input neurons feeding the network.
        n_hidden_neurons / n_output_neurons: network topology.
        n_syn_per_neuron: fan-in per neuron (every neuron has the same fan-in).
        """
        # offspring
        self.n_offspring = n_offspring
        self.n_pseudo_env = n_pseudo_env
        # input channels
        n_input_channels = 16
        self.n_input_channels = n_input_channels
        n_dynamic_param = 32
        # total neurons
        n_neurons = n_output_neurons + n_hidden_neurons
        self.n_neurons = n_neurons
        super().__init__(n_offspring, n_neurons, n_input_neurons, n_output_neurons)
        torch.random.manual_seed(0)  # deterministic random wiring fallback
        #NOTE: batch dimension holds output of each neuron/synapse, allowing fast GPU MM
        #NOTE neurons far less than synapses, so can be relatively bigger rnn for little cost
        n_input_channels_neuron = 16
        n_input_neuron, n_output_neuron = n_input_channels_neuron, n_input_channels
        self.neurons = EvolvableNeuralUnitStacked(n_offspring, batch_size=self.n_neurons, n_input=n_input_neuron, n_dynamic_param=n_dynamic_param, n_output=n_output_neuron)
        #self.n_syn = next_power_of_2(int(n_neurons * (rel_connectivity*n_neurons)))
        self.n_syn_per_neuron = n_syn_per_neuron
        self.n_syn = n_neurons * n_syn_per_neuron
        n_input_syn, n_output_syn = n_input_channels * 2, n_input_channels_neuron # * 2 for neuron feedback (which same channel as n_channel input)
        self.synapses = EvolvableNeuralUnitStacked(n_offspring, batch_size=self.n_syn, n_input=n_input_syn, n_dynamic_param=n_dynamic_param, n_output=n_output_syn)
        # just randomly connect synapses to neurons
        self.synapse_connections = torch.randint(n_input_neurons + n_neurons, size=(n_neurons, n_syn_per_neuron), device='cuda', dtype=torch.long)
        # fixed predefined connection patterns: each known topology overrides the
        # random wiring above; synapse_connections[i, :] lists the source indices
        # (input neurons first, then network neurons) feeding neuron i.
        if n_input_neurons==2 and n_output_neurons==2 and n_hidden_neurons==2:
            print("Fixed connection Network 2-2-2")
            self.synapse_connections = torch.tensor([[0, 1],
                                                     [0, 1],
                                                     [2, 3],
                                                     [2, 3]], device='cuda', dtype=torch.long)
        elif n_input_neurons == 4 and n_output_neurons == 3 and n_hidden_neurons == 3 and n_syn_per_neuron==3:
            print("Fixed connection Network 4-3-3 (3syn)")
            self.synapse_connections = torch.tensor([[0, 1, 3],# hidden connections #4
                                                     [0, 2, 3], #5
                                                     [1, 2, 3],# 6
                                                     [4, 5, 6], # output connections #7
                                                     [4, 5, 6],#8
                                                     [4, 5, 6]#9
                                                     ], device='cuda', dtype=torch.long)
        elif n_input_neurons==5 and n_hidden_neurons==0 and n_output_neurons==4:
            print("Fixed connection Network 5-0-4 (5syn)")
            # neuron i connected to neuron j and k, neuron 0..input_neurons is index
            self.synapse_connections = torch.tensor([[0, 1, 2, 3, 4],# output connections
                                                     [0, 1, 2, 3, 4],
                                                     [0, 1, 2, 3, 4],
                                                     [0, 1, 2, 3, 4]
                                                     ], device='cuda', dtype=torch.long)
        elif n_input_neurons==1 and n_hidden_neurons==0 and n_output_neurons==2:
            print("Fixed connection Network 1-0-2 (1syn)")
            # neuron i connected to neuron j and k, neuron 0..input_neurons is index
            self.synapse_connections = torch.tensor([[0],# output connections
                                                     [0]
                                                     ], device='cuda', dtype=torch.long)
        elif n_input_neurons==4 and n_hidden_neurons==0 and n_output_neurons==3 and n_syn_per_neuron==4:
            print("Sparse connection Network 4-0-3 (4syn)")
            # neuron i connected to neuron j and k, neuron 0..input_neurons is index
            self.synapse_connections = torch.tensor([[0, 1, 2, 3],# output connections #4
                                                     [0, 1, 2, 3], #5
                                                     [0, 1, 2, 3],# 6
                                                     ], device='cuda', dtype=torch.long)
        elif n_input_neurons == 4 and n_hidden_neurons == 3 and n_output_neurons == 3 and n_syn_per_neuron == 4:
            print("Sparse connection Network 4-3-3 (3syn)")
            # neuron i connected to neuron j and k, neuron 0..input_neurons is index
            self.synapse_connections = torch.tensor([[0, 1, 3], # hidden connections #4
                                                     [0, 2, 3], # 5
                                                     [1, 2, 3], # 6
                                                     [4, 5, 3], # output connections #7
                                                     [4, 6, 3], # 8
                                                     [5, 6, 3] # 9
                                                     ], device='cuda', dtype=torch.long)
        elif n_input_neurons==4 and n_hidden_neurons==3 and n_output_neurons==3 and n_syn_per_neuron==8:
            print("Sparse connection Network 4-3-3 (8syn)")
            # neuron i connected to neuron j and k, neuron 0..input_neurons is index
            self.synapse_connections = torch.tensor([[0, 1, 5, 6, 7, 8, 3, 4],# hidden connections #4
                                                     [0, 2, 4, 6, 7, 9, 3, 5], #5
                                                     [1, 2, 4, 5, 8, 9, 3, 6],# 6
                                                     [4, 5, 8, 9, 0, 1, 3, 7], # output connections #7
                                                     [4, 6, 7, 9, 0, 2, 3, 8],#8
                                                     [5, 6, 7, 8, 1, 2, 3, 9]#9
                                                     ], device='cuda', dtype=torch.long)
        elif n_input_neurons==4 and n_hidden_neurons==4 and n_output_neurons==4 and n_syn_per_neuron==8:
            print("Fixed connection Network 4-4-4 (8syn)")
            # neuron i connected to neuron j and k, neuron 0..input_neurons is index
            self.synapse_connections = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7],# hidden connections
                                                     [0, 1, 2, 3, 4, 5, 6, 7],
                                                     [0, 1, 2, 3, 4, 5, 6, 7],
                                                     [0, 1, 2, 3, 4, 5, 6, 7],
                                                     [4, 5, 6, 7, 8, 9, 10, 11], # output connections
                                                     [4, 5, 6, 7, 8, 9, 10, 11],
                                                     [4, 5, 6, 7, 8, 9, 10, 11],
                                                     [4, 5, 6, 7, 8, 9, 10, 11],
                                                     ], device='cuda', dtype=torch.long)
        elif n_input_neurons==5 and n_hidden_neurons==5 and n_output_neurons==4:
            print("Fixed connection Network 5-5-4 (5syn)")
            # neuron i connected to neuron j and k, neuron 0..input_neurons is index
            self.synapse_connections = torch.tensor([[0, 1, 2, 3, 4],# hidden connections
                                                     [0, 1, 2, 3, 4],
                                                     [0, 1, 2, 3, 4],
                                                     [0, 1, 2, 3, 4],
                                                     [0, 1, 2, 3, 4],
                                                     [5, 6, 7, 8, 9], # output connections
                                                     [5, 6, 7, 8, 9],
                                                     [5, 6, 7, 8, 9],
                                                     [5, 6, 7, 8, 9],
                                                     ], device='cuda', dtype=torch.long)
        elif n_input_neurons==1 and n_hidden_neurons==0 and n_output_neurons==1:
            print("Fixed connection Single")
            self.synapse_connections = torch.tensor([[0]], device='cuda', dtype=torch.long)
        else:
            print("Random connections")
        # each synapse is connected also to its post-synaptic neuron, to allow STDP type learning to emerge
        self.synapse_connections_post = torch.arange(n_neurons, device='cuda', dtype=torch.long).reshape(n_neurons, -1).repeat(1, n_syn_per_neuron)
        # define compartments
        self.compartments = [self.neurons, self.synapses]
        self.trainable_layers = self.neurons.trainable_layers + self.synapses.trainable_layers
        self.track_data = False  # when True, forward() records activity for visualization

    def dump_model(self, e, exp_name):
        """Dump model to restore"""
        with open(get_data_path(e, exp_name, "Model"), 'wb') as f:
            parameters = {}
            parameters["neuron"] = [layer.base_parameters.cpu().numpy() for layer in self.neurons.trainable_layers]
            parameters["synapse"] = [layer.base_parameters.cpu().numpy() for layer in self.synapses.trainable_layers]
            pickle.dump(parameters, f)

    def restore_model(self, e, exp_name):
        """Restore model"""
        with open(get_data_path(e, exp_name, "Model"), 'rb') as f:
            parameters = pickle.load(f)
        #TODO: refactor to dump/restore at ENU level and just call those functions
        assert len(self.neurons.trainable_layers) == len(parameters["neuron"])
        for i in range(len(parameters["neuron"])):
            self.neurons.trainable_layers[i].base_parameters = torch.from_numpy(parameters["neuron"][i].astype(np.float32)).cuda()
        assert len(self.synapses.trainable_layers) == len(parameters["synapse"])
        for i in range(len(parameters["synapse"])):
            self.synapses.trainable_layers[i].base_parameters = torch.from_numpy(parameters["synapse"][i].astype(np.float32)).cuda()

    @staticmethod
    def plot_weights(e, exp_name):
        """Visualize weights of ENU gates"""
        sns.set_style("dark")

        def calc_average(start, stop):
            # Sum neuron parameter matrices over checkpoints start..stop (step 1000).
            # NOTE(review): despite the name, this returns the SUM, not the mean —
            # it is only used for relative comparison below, so presumably intended.
            weights_average = None
            for e in range(start, stop, 1000):  # shadows the outer `e` argument
                with open(get_data_path(e, exp_name, "Model"), 'rb') as f:
                    parameters = pickle.load(f)
                weights = []
                for i in range(len(parameters["neuron"])):
                    weights += [parameters["neuron"][i].astype(np.float32)]
                if weights_average is None:
                    weights_average = weights
                else:
                    for i in range(len(weights_average)):
                        weights_average[i] += weights[i]
            return weights_average

        weights_mean1 = calc_average(20000, 30000)
        fig, ax = plt.subplots(1, 2, sharex='col', sharey='row')
        for i in range(len(weights_mean1)):
            ax[i].imshow(weights_mean1[i], cmap="gray")
        weights_mean2 = calc_average(30000, 40000)
        fig, ax = plt.subplots(1, 2, sharex='col', sharey='row')
        for i in range(len(weights_mean2)):
            ax[i].imshow(weights_mean2[i], cmap="gray")
        # difference between the two windows, amplified by the 5th power
        fig, ax = plt.subplots(1, 2, sharex='col', sharey='row')
        for i in range(len(weights_mean2)):
            ax[i].imshow((weights_mean2[i] - weights_mean1[i])**5, cmap="gray")
        plt.show()

    def dump_network_activity(self, e, exp_name):
        """Dump raw data for visualization"""
        with open(get_data_path(e, exp_name, "GlobalNetwork"), 'wb') as f:
            pickle.dump(self.vis_data, f)

    def print(self):
        # Print internal state of both ENU stacks (delegates to the stacks).
        print("--Neurons--")
        self.neurons.print()
        print("--Synapses--")
        self.synapses.print()

    def reset(self):
        """Clear tracked activity and reset all compartments (neurons + synapses)."""
        self.vis_data = []
        if self.track_data:
            print("Tracking network activity")
        for compartment in self.compartments:
            compartment.reset()

    def forward(self, X):
        """Main computation forward pass"""
        # transfer to GPU
        X_raw_gpu = torch.from_numpy(X.astype(np.float32)).cuda()
        X_gpu = torch.zeros((X.shape[0], X.shape[1], self.n_input_channels), device='cuda', dtype=torch.float32)
        X_gpu[:, :, :X_raw_gpu.shape[2]] = X_raw_gpu
        # first compute synapses, set input to previous output of connected neuron
        # concat our input spiking pattern directly to input to our synapses (the neurons)
        # NOTE: this concats in batch dimension, meaning it feeds into input neurons directly spiking pattern, while rest receive input from network
        input_to_synapses = torch.cat([X_gpu, self.neurons.out_mem], dim=1)
        # connect each synapse randomly to multiple inputs
        input_to_synapses_connected = input_to_synapses[:, self.synapse_connections.flatten(), :]
        # need feedback connection from neuron to synapse, to allow stdp type rules to emerge (else it has to do it through feedback connections, but less guarentee on connections and cannot distinguise type)
        # one synapse has 1 pre-synaptic neuron and 1 post-synaptic neuron, connectection defined in synapse_connections, synapse_connections[i, :] gives all input synapses of that neuron
        # so feedback to all it's input synapses through broadcasting backwards
        post_neuron_backprop_connected = self.neurons.out_mem[:, self.synapse_connections_post.flatten(), :]
        input_to_synapses_connected = torch.cat([input_to_synapses_connected, post_neuron_backprop_connected], dim=-1)
        # compute synapse
        self.synapses.forward(input_to_synapses_connected)
        # then integrate(sum) all outputs of a neurons input synapses, can just reshape into valid shape, since we already randomly connected when computing synapses
        # NOTE: each neuron then requires same number of synapses, then reshape by modifying batch dim (which contains syn outputs)
        integration = torch.sum(self.synapses.out.reshape((self.n_offspring, self.n_neurons, -1, self.synapses.shape[-1])), dim=2)
        # scale by number of synapses
        integration /= self.n_syn_per_neuron
        self.out_integration = integration
        # finally set neuron input to summated connected synapses output
        input_to_neurons = integration
        out = self.neurons.forward(input_to_neurons)
        # output is last neuron output, NOTE: just first channel is returned, since we reshape neurons to channels
        self.out = out[:, -self.n_output:, 0].reshape(self.n_offspring, self.n_output)
        if self.track_data:
            self._track_vis_data(X, input_to_synapses_connected, input_to_neurons)
        return self.out

    def _track_vis_data(self, X, input_to_synapses_connected, input_to_neurons):
        # Record one timestep of activity for offspring 0 only (for plotting).
        offspring_idx = 0
        self.vis_data += [(X[offspring_idx], input_to_neurons[offspring_idx].cpu().numpy(), self.neurons.out[offspring_idx].cpu().numpy(),
                           input_to_synapses_connected[offspring_idx].cpu().numpy(), self.synapses.out[offspring_idx].cpu().numpy())]

    @staticmethod
    def plot_network_activity(e, exp_name):
        """Load activity dumped by dump_network_activity and plot input, neuron and synapse traces."""
        with open(get_data_path(e, exp_name, "GlobalNetwork"), 'rb') as f:
            vis_data = pickle.load(f)
        X, input_to_neurons, neurons_out, input_to_synapses, synapses_out = map(np.array, zip(*vis_data))

        def plot_enu_activity(input, output, title):
            # Two-row grid: inputs on top, outputs below, one column per cell (max 10).
            n_cells = output.shape[1]
            n_cells = np.minimum(10, output.shape[1])
            fig, grid = plt.subplots(2, n_cells, sharex='col', sharey='row')
            if n_cells==1:
                grid[0].plot(input[:, 0, :])
                grid[1].plot(output[:, 0, :])
            else:
                for i in range(n_cells):
                    grid[0, i].plot(input[:, i, :])
                    grid[1, i].plot(output[:, i, :])
            plt.xlabel("t")
            plt.title(title)
            #plt.ylabel("")
            plt.legend()

        plt.figure()
        plt.plot(X[:, :, 0])
        plot_enu_activity(input_to_neurons, neurons_out, "ENU neuron activity")
        plot_enu_activity(input_to_synapses, synapses_out, "ENU synapse activity")
        # spike raster: times/indices where the first output channel is positive
        plt.figure()
        spike_points = np.where(neurons_out[:, :, 0] > 0)
        plt.scatter(spike_points[0], spike_points[1], marker='|')
        plt.show()
================================================
FILE: examples/Structure_Evolution/Adaptive_lsm/raw/tools/ExperimentEnvGlobalNetworkSurvival.py
================================================
import pickle
import numpy as np
from tools.Tools import get_data_path
class ExperimentEnvGlobalNetworkSurvival:
"""Wrapper around a given RL environment for a Network of ENUs model,
turns reward into fitness and dumps relevant data"""
def __init__(self, env, exp_name='maze'):
self.env = env
self.exp_name = exp_name
self.n_output = self.env.n_actions
#NOTE: +1 reward neuron
self.n_input_neurons = self.env.n_obs + 1
self.n_agents = self.env.n_agents
def _convert_obs(self, obs, rewards):
n_input_channels_used = 3
X = np.zeros((self.n_agents, self.n_input_neurons, n_input_channels_used))
#X[:, :obs.shape[1], 0] = obs
# Shuffle only obs to avoid topology exploitation, reward neuron linked to EnuGlobal synapse connectivity
X[:, :obs.shape[1], 0] = np.take_along_axis(obs, self.obs_shuffle, axis=1)
# split pos and negative reward to different channels, And set to last input neuron
if rewards is not None:
X[rewards>0, -1, 1] = np.abs(rewards[rewards>0])
X[rewards<=0, -1, 2] = np.abs(rewards[rewards<=0])
return X
def _convert_reward(self, obs, actions, rewards, infos, dones):
fitness = np.copy(rewards)
# first poison is considered positive reward, since learning to learn
#NOTE: dead by env means less reward can be obtained so should implictely reduce overall fitness automatically
fitness[np.logical_and(self._prev_reward_count == 1, rewards != 0)] = 1
# include episode length as extra fitness, since not taking poison would allow survive longer, so should try avoid take poison
fitness[dones==0] += 0.1/4
return fitness
def step(self, y):
# if self.t % 3 != 0:
# actions = np.zeros((self.n_agents), dtype=np.int32) - 1
# else:
# winner take all, in given time window
actions = y
# if all same output, do nothing
# equal_actions = self.y_hist.shape[1] == np.sum(self.y_hist == np.take_along_axis(self.y_hist, actions.reshape(-1, 1), axis=1), axis=-1)
# actions[equal_actions] = -1
# self.y_hist[:] = 0
# take env step
allobs, obs, rewards, dones, infos = self.env.step(actions)
# X = self._convert_obs(obs, rewards)
X=allobs
self._prev_reward_count += rewards!=0
fitness = self._convert_reward(obs, actions, rewards, infos, dones)
self._prev_action = actions
self._prev_obs = obs
return X, rewards, fitness, None
def reset(self):
self.t = 0
self.y_hist = np.zeros((self.n_agents, self.n_output), dtype=np.float32)
self._prev_action = None
self._prev_obs = None
self._prev_reward_count = np.zeros((self.n_agents), dtype=np.float32)
# each time different input/output neurons should have different meaning, to have learning to learn
self.obs_shuffle = np.argsort(np.random.randn(self.n_agents, self.n_input_neurons - 1), axis=1, kind='mergesort')
self.action_shuffle = np.argsort(np.random.randn(self.n_agents, self.n_output), axis=1, kind='mergesort')
# reset env
self.allobs,self.obs = self.env.reset()
# return self._convert_obs(self.obs, None)
return self.allobs
def render(self):
if self.t%4==0:
self.env.render()
def track_vis_data(self, vis_data, model, X, y_est, t):
n_fetch = 128
# TODO: also get our gates from the model
vis_data+=[(X[:n_fetch, :], y_est[:n_fetch, :])]
def dump_vis_data(self, vis_data, fitness_per_offspring, e):
with open(get_data_path(e, self.exp_name, "output"), 'wb') as f:
pickle.dump((vis_data, fitness_per_offspring), f)
@staticmethod
def load_vis_data(e, exp_name):
    """Load pickled ``(vis_data, fitness_per_offspring)`` for epoch ``e``."""
    path = get_data_path(e, exp_name, "output")
    with open(path, 'rb') as f:
        payload = pickle.load(f)
    vis_data, fitness_per_offspring = payload
    return vis_data, fitness_per_offspring
@staticmethod
def plot_vis_data(e, exp_name):
    # Load stored visualization data for epoch ``e``.
    # NOTE(review): the loaded values are not used here -- plotting code appears
    # to have been removed or this file was truncated; verify against history.
    vis_data, fitness_per_offspring = ExperimentEnvGlobalNetworkSurvival.load_vis_data(e, exp_name)
================================================
FILE: examples/Structure_Evolution/Adaptive_lsm/raw/tools/MazeTurnEnvVec.py
================================================
import pickle
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tools.Tools import save_fig, get_data_path
# np.random.seed(0)
class MazeTurnEnvVec:
    """Vectorized RL T-Maze environment written in pure Numpy. We require an efficient environment since we need to evaluate
    and run up to thousands of offspring in parallel.

    Maze cell codes: 1 = wall, 0 = corridor, 2 = food, 3 = poison.
    All per-agent state is stored in (n_agents, ...) arrays so every agent is
    stepped simultaneously with vectorized indexing.
    """

    def __init__(self, n_agents, n_steps):
        # 4 important points, start point, decision point, food point, dead point.
        # just generate a very large matrix that could fit any maze of any size, then can generate smaller maze as well
        self.n_actions = 3
        self.n_obs = 3
        self.max_size = 7
        self.n_agents = n_agents
        self.n_steps = n_steps
        self.window = plt.figure()
        self.t_maze = True
        self.turn_based = False
        # steps can be longer if poison and need to turn around
        self.steps_to_food = 2
        if self.t_maze:
            self.steps_to_food = 3
        self.steps_to_food += self.steps_to_food * 2
        # give some extra leniency
        self.steps_to_food *= 2

    def step(self, actions):
        """Advance all agents one step.

        ``actions``: int array of shape (n_agents,). In direct-movement mode:
        0 = right, 1 = down, 2 = up; in turn-based mode: 0 = forward,
        1 = turn left, 2 = turn right.
        Returns ``(allobs, obs, rewards, dones, infos=None)``.
        """
        # L R U D
        # TODO: check legal action or not..
        pos_copy = np.copy(self.agents_pos)
        actions = np.copy(actions)
        # GIVE TIME UPDATE WEIGHTS
        # actions[self.agents_reset > 0] = -1
        # actions[self.agent_energy<=0] = -1
        self.agents_reset[self.agents_reset > 0] -= 1
        # if turn based
        if self.turn_based:
            Forward = actions == 0
            self.agents_pos[np.logical_and(Forward, self.agent_directions == 0), 1] += 1
            self.agents_pos[np.logical_and(Forward, self.agent_directions == 2), 1] -= 1
            self.agents_pos[np.logical_and(Forward, self.agent_directions == 1), 0] -= 1
            self.agents_pos[np.logical_and(Forward, self.agent_directions == 3), 0] += 1
            L = actions == 1
            self.agent_directions[L] += 1
            R = actions == 2
            self.agent_directions[R] -= 1
            # wrap direction into [0, 3]
            self.agent_directions[self.agent_directions > 3] = 0
            self.agent_directions[self.agent_directions < 0] = 3
        else:
            # or just direct movement
            U = actions == 2
            D = actions == 1
            R = actions == 0
            if self.agents_pos[U].size > 0:
                self.agents_pos[U, 0] += 1
                self.agent_directions[U] = 3
            if self.agents_pos[D].size > 0:
                self.agents_pos[D, 0] -= 1
                self.agent_directions[D] = 1
            if self.t_maze and self.agents_pos[R].size > 0:
                self.agents_pos[R, 1] += 1
                self.agent_directions[R] = 0
        # UNDO MOVES THAT GOT AGENT INTO WALL
        self.current_cells = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0], self.agents_pos[:, 1]]
        self.agents_pos[self.current_cells == 1] = pos_copy[self.current_cells == 1]
        movement_loss = np.prod(self.agents_pos == pos_copy, axis=-1)
        # CHECK IF FOOD CONSUMED, is reward + pos reset
        # BUGFIX: np.bool was a deprecated alias removed in NumPy 1.24; use the builtin bool.
        consumed_food = np.prod(self.agents_pos == self.food_pos[:, 0, :2], axis=-1).astype(bool)
        consumed_pois = np.prod(self.agents_pos == self.food_pos[:, 1, :2], axis=-1).astype(bool)
        self.consumed_count += consumed_food.astype(np.int32)
        self.consumed_count_total += consumed_food.astype(np.int32)
        self.consumed_count_pois += consumed_pois.astype(np.int32)
        self._reset_pos(np.logical_or(consumed_food, consumed_pois))
        # self._reset_pos_pois(consumed_pois)
        # self._reset_food(self.consumed_count==self.swap_limit, prob=0.0)
        # reset food for agents that ate food, and swap with some probability
        self._reset_food(self.consumed_count == 5, prob=0.5)
        self.rewards = consumed_food.astype(np.float32) - consumed_pois.astype(np.float32)  # * 0.5 #- movement_loss.astype(np.float32) * 0.01
        # get observation from current position of each agent
        self.agent_allobs, self.obs = self._get_obs_from_pos()
        # instant dead on second poison
        self.agent_energy += np.where(self.consumed_count_pois <= 1, np.abs(self.rewards) * self.steps_to_food, -self.agent_energy * np.abs(self.rewards))
        # energy decay to encourage exploration, agent dies if running out of energy
        self.agent_energy = np.minimum(self.agent_energy, self.steps_to_food)
        self.agent_energy -= 1.0 / 4
        dones = self.agent_energy <= 0
        return self.agent_allobs, self.obs, self.rewards, dones, None

    def _reset_pos(self, idxs):
        """Send agents selected by ``idxs`` (mask or indices) back to the start cell."""
        self.agents_pos[idxs] = [self.start_point, 2]  # set X pos
        self.agent_directions[idxs] = 0
        if self.max_size == 5:
            self.agent_directions[idxs] = 1
        self.agents_reset[idxs] = 0
        self.agents_reset_count[idxs] += 1

    def _reset_pos_pois(self, idxs):
        """Penalize repeated poison eaters with a long freeze (currently unused; see step)."""
        # NOTE: 2 since we call reset twice!
        # NOTE: turning around already cost 8 steps, then 4x4 more is 16+8, 24, so reset should be much worse
        self.agents_reset[np.logical_and(idxs, self.agents_reset_count > 2)] = 64

    def _reset_food(self, idxs, prob=0.5):
        """Possibly swap food and poison locations for agents flagged in ``idxs``.

        NOTE(review): ``prob`` is currently unused -- swapping is driven by the
        pre-drawn ``random_swap_matrix`` so every agent follows the same swap
        schedule (fair fitness comparison across offspring).
        """
        # swap food with some probability, avoids agent overfitting on environment
        swap = np.take_along_axis(self.random_swap_matrix, self.consumed_count_total.reshape(-1, 1), axis=1).ravel()
        swap_idxs = swap * idxs
        food_loc = np.copy(self.food_pos[swap_idxs, 0, :])
        pois_loc = np.copy(self.food_pos[swap_idxs, 1, :])
        self.food_pos[swap_idxs, 0, :] = pois_loc
        self.food_pos[swap_idxs, 1, :] = food_loc
        # set maze value
        self.mazes[np.arange(self.mazes.shape[0]), self.food_pos[:, 0, 0], self.food_pos[:, 0, 1]] = 2
        self.mazes[np.arange(self.mazes.shape[0]), self.food_pos[:, 1, 0], self.food_pos[:, 1, 1]] = 3
        self.consumed_count[idxs] = 0

    def reset(self):
        """Build fresh mazes, food/poison placement and agent state for all agents.

        Returns ``(allobs, obs)`` -- the initial observations.
        """
        self.consumed_count = np.zeros((self.n_agents), dtype=np.int32)
        self.consumed_count_total = np.zeros_like(self.consumed_count)
        # consistent swapping such that if agent eat food once for all agents swapped with same seed, fair fitness comparison
        max_eat = self.n_steps
        self.random_swap_matrix = np.random.uniform(0, 1, size=(1, max_eat)) >= 0.5
        self.random_swap_matrix = np.repeat(self.random_swap_matrix, int(self.n_agents), axis=0)
        self.agent_energy = np.zeros((self.n_agents), dtype=np.float32) + self.steps_to_food
        self.consumed_count_pois = np.zeros_like(self.consumed_count)
        # self.swap_limit = np.random.randint(1, 5, size=1)
        # self.swap_limit = np.random.randint(1, 4, size=self.n_agents)
        self.mazes = np.ones((self.n_agents, self.max_size, self.max_size), dtype=np.int32)
        # TODO: support variable maze length
        self.start_point = int(self.max_size / 2)
        if self.t_maze:
            # carve the T: horizontal corridor plus the vertical bar near the far edge
            self.mazes[:, self.start_point, 2:-1] = 0
            self.mazes[:, 1:-1, -2] = 0
            # FOOD either at -1,-1 or -1,1?
            # two foods: x, y, value
            self.food_pos = np.zeros((self.n_agents, 2, 2), dtype=np.int32)
            self.food_pos[:, :, 1] = self.max_size - 2
            self.food_pos[:, 1, 0] = 1
            self.food_pos[:, 0, 0] = self.max_size - 2
            # BUGFIX: np.bool removed in NumPy 1.24 -- builtin bool instead.
            self._reset_food(np.ones(self.food_pos.shape[0], dtype=bool), prob=0.5)
        else:
            self.mazes[:, 1:-1, 1] = 0
            # two foods: x, y, value
            self.food_pos = np.zeros((self.n_agents, 2, 2), dtype=np.int32)
            self.food_pos[:, :, 1] = 1
            self.food_pos[:, 0, 0] = 1
            self.food_pos[:, 1, 0] = (self.max_size - 2)
            self._reset_food(np.ones(self.food_pos.shape[0], dtype=bool), prob=0.5)
        # AGENT
        self.agents_pos = np.ones((self.n_agents, 2), dtype=np.int32)
        self.agents_reset = np.zeros((self.n_agents), dtype=np.int32)
        self.agents_reset_count = np.zeros_like(self.agents_reset)
        self.agent_directions = np.zeros((self.n_agents), dtype=np.int32)
        self._reset_pos(np.arange(self.agents_pos.shape[0]))
        # OBS
        self.agent_allobs, self.obs = self._get_obs_from_pos()
        return self.agent_allobs, self.obs

    def _get_obs_from_pos(self):
        """Per-agent local view relative to facing direction.

        Returns ``(allobs, obs)`` where ``allobs`` stacks (left, front, right,
        back) cell codes and ``obs`` one-hot encodes the front cell
        (wall / food / poison).
        """
        # obs is neighbouring cell states around agent
        obs = np.zeros((self.n_agents, self.n_obs), dtype=np.float32)
        raw_obs = np.zeros(self.n_agents, dtype=np.int32)
        # get observation based on direction agent is facing
        leftobs = np.zeros(self.n_agents)
        rightobs = np.zeros(self.n_agents)
        backobs = np.zeros(self.n_agents)
        # hoisted loop-invariant row index (was rebuilt for every fancy index)
        rows = np.arange(self.mazes.shape[0])
        # get observation based on direction agent is facing
        D = self.agent_directions == 0
        raw_obs[D] = self.mazes[rows, self.agents_pos[:, 0], self.agents_pos[:, 1] + 1][D]
        leftobs[D] = self.mazes[rows, self.agents_pos[:, 0] - 1, self.agents_pos[:, 1]][D]
        rightobs[D] = self.mazes[rows, self.agents_pos[:, 0] + 1, self.agents_pos[:, 1]][D]
        backobs[D] = self.mazes[rows, self.agents_pos[:, 0], self.agents_pos[:, 1] - 1][D]
        D = self.agent_directions == 2
        raw_obs[D] = self.mazes[rows, self.agents_pos[:, 0], self.agents_pos[:, 1] - 1][D]
        leftobs[D] = self.mazes[rows, self.agents_pos[:, 0] + 1, self.agents_pos[:, 1]][D]
        rightobs[D] = self.mazes[rows, self.agents_pos[:, 0] - 1, self.agents_pos[:, 1]][D]
        backobs[D] = self.mazes[rows, self.agents_pos[:, 0], self.agents_pos[:, 1] + 1][D]
        D = self.agent_directions == 1
        raw_obs[D] = self.mazes[rows, self.agents_pos[:, 0] - 1, self.agents_pos[:, 1]][D]
        leftobs[D] = self.mazes[rows, self.agents_pos[:, 0], self.agents_pos[:, 1] - 1][D]
        rightobs[D] = self.mazes[rows, self.agents_pos[:, 0], self.agents_pos[:, 1] + 1][D]
        backobs[D] = self.mazes[rows, self.agents_pos[:, 0] + 1, self.agents_pos[:, 1]][D]
        D = self.agent_directions == 3
        raw_obs[D] = self.mazes[rows, self.agents_pos[:, 0] + 1, self.agents_pos[:, 1]][D]
        leftobs[D] = self.mazes[rows, self.agents_pos[:, 0], self.agents_pos[:, 1] + 1][D]
        rightobs[D] = self.mazes[rows, self.agents_pos[:, 0], self.agents_pos[:, 1] - 1][D]
        backobs[D] = self.mazes[rows, self.agents_pos[:, 0] - 1, self.agents_pos[:, 1]][D]
        # mark what was observed at different index
        obs[raw_obs == 1, 0] = 1
        obs[raw_obs == 2, 1] = 1
        obs[raw_obs == 3, 2] = 1
        allobs = np.squeeze(np.dstack((leftobs, raw_obs, rightobs, backobs)))
        return allobs, obs

    def render(self):
        """Draw one maze with its agent, food (green) and poison (red)."""
        plt.clf()
        sns.set_style("white")
        # TODO: support render all mazes? can reshape to square?
        max_render = 1
        # NOTE(review): agent index 18 is hard-coded here; rendering requires n_agents > 18.
        flattened_render = np.dstack(np.split(self.mazes[18, :], max_render, axis=0)).reshape(self.mazes.shape[1], -1)
        flattened_render[flattened_render > 1] = 0
        plt.axis('off')
        plt.imshow(flattened_render, cmap='bone')
        for j in range(1):
            i = 18
            marker = ">"
            if self.agent_directions[i] == 1:
                marker = "^"
            if self.agent_directions[i] == 2:
                marker = "<"
            if self.agent_directions[i] == 3:
                marker = "v"
            obs_color = "black"
            if self.obs[i, 0] == 1:
                obs_color = "gray"
            if self.obs[i, 1] == 1:
                obs_color = "green"
            if self.obs[i, 2] == 1:
                obs_color = "red"
            alpha = 1
            if self.agent_energy[i] <= 0:
                alpha = 1
            plt.scatter(self.agents_pos[i, 1] + j * self.mazes.shape[1], self.agents_pos[i, 0], color="skyblue", alpha=alpha, marker=marker)
            plt.scatter(self.agents_pos[i, 1] + j * self.mazes.shape[1], self.agents_pos[i, 0], color=obs_color, alpha=alpha, marker=marker, s=3)
            plt.scatter(self.food_pos[i, 0, 1] + j * self.mazes.shape[1], self.food_pos[i, 0, 0], color="green", alpha=1, marker="o")
            plt.scatter(self.food_pos[i, 1, 1] + j * self.mazes.shape[1], self.food_pos[i, 1, 0], color="red", alpha=1, marker="o")
        plt.pause(0.001)
        # plt.pause(2)

    @staticmethod
    def load_vis_data(e, exp_name):
        """Load pickled ``(vis_data, fitness_per_offspring)`` for epoch ``e``."""
        with open(get_data_path(e, exp_name, "output"), 'rb') as f:
            vis_data, fitness_per_offspring = pickle.load(f)
        return vis_data, fitness_per_offspring

    @staticmethod
    def plot_vis_data(e, exp_name):
        """Plot a single-episode trace of inputs/outputs for offspring 0."""
        import matplotlib.pyplot as plt
        from cycler import cycler
        import seaborn as sns
        sns.set_style("whitegrid")
        vis_data, fitness_per_offspring = MazeTurnEnvVec.load_vis_data(e, exp_name)
        offspring_idx = 0
        # x, y_est, y = np.array(vis_data).transpose(1, 2, 0, 3)
        X, Y_est = map(np.array, zip(*vis_data))
        X_base, y_est_base = X[:, offspring_idx], Y_est[:, offspring_idx]
        # X_base, y_est_base = X_base[:300], y_est_base[:300]
        # X_base = np.max(X_base, axis=1)
        # --- normal output single example---
        plt.rc('axes', prop_cycle=(cycler('color', ['gray', '#ff7f0e', '#9467bd', '#8c564b', '#e377c2', '#17becf'])))
        # OLD METHOD!
        plt.figure()
        # NOTE: last neuron is always reward neuron
        plt.plot(-X_base[:, :-1, 0], label="N-ENUs input", alpha=0.7)
        plt.plot(- np.max(X_base, axis=1)[:, 1], label="Positive reward", alpha=0.7, color='#2ca02c')
        plt.plot(- np.max(X_base, axis=1)[:, 2], label="Negative reward", alpha=0.7, color='#d62728')
        # plt.gca().set_color_cycle(['orange', 'purple', 'brown'])
        plt.plot(y_est_base[:, :], label="N-ENUs output", alpha=0.8)
        plt.legend(loc='upper right')
        save_fig(e, exp_name, "single_episode")
        plt.rc('axes', prop_cycle=(cycler('color', ['gray', '#ff7f0e', '#9467bd', '#8c564b', '#e377c2', '#17becf'])))
        # new method!
        fig, grid = plt.subplots(2, sharex=True)
        # input
        # grid[1].set_prop_cycle(cycler('color', ['gray', '#ff7f0e', '#1f77b4']))
        grid[1].set_prop_cycle(cycler('color', ['gray', '#F5B041', '#2E86C1']))
        grid[1].plot(X_base[:, :-1, 0], label="N-ENUs input", alpha=0.7, linewidth=2)
        grid[1].plot(np.max(X_base, axis=1)[:, 1], label="Positive reward", alpha=0.7, color='#1ABC9C', linewidth=2)
        grid[1].plot(np.max(X_base, axis=1)[:, 2], label="Negative reward", alpha=0.7, color='#CB4335', linewidth=2)
        grid[1].legend(['Sensor (wall)', 'Sensor (red)', 'Sensor (green)', 'Positive reward', 'Negative reward'], loc='upper right')
        grid[1].set_ylabel('Neuron output')
        # output
        grid[0].set_prop_cycle(cycler('color', ['#9467bd', '#e377c2', '#17becf', '#8c564b']))
        grid[0].plot(y_est_base[:, :], label="N-ENUs output", alpha=0.8, linewidth=2)
        grid[0].legend(['ENU-NN (left)', 'ENU-NN (right)', 'ENU-NN (forward)'], loc='upper right')
        plt.xlabel('t')
        grid[0].set_ylabel('ENU neuron output')
        plt.xlim(-5, X_base.shape[0] + 10)
        save_fig(e, exp_name, "single_episode_dual")
        # plt.show()

    @staticmethod
    def plot_rollout_data(e, exp_name):
        """Stitch selected saved rollout frames into one combined figure."""
        # TODO: dump rollout as array not the actual plots
        import matplotlib.pyplot as plt
        import seaborn as sns
        sns.set_style("white")
        import os
        rollout_path = "./" + get_data_path(e, exp_name, "rollout").split(".")[1][:-1] + "/"
        rollout_files = sorted(os.listdir(rollout_path))
        rollouts = []
        for i in range(4, 200, 4):
            file = "rollout_{}_.png".format(i)
            print(file)
            rollout = plt.imread(rollout_path + file)
            rollout = rollout[80:450, 200:500]
            rollout = np.where(rollout[:, :, [0]] > 0.98, 1, rollout)
            rollouts.append(rollout)
            # plt.imshow(rollout)
            # plt.show()
        # for rollout in rollouts:
        vert_bar = np.zeros((rollout.shape[0], 5, 4))
        vert_bar[::] += 0.5
        # get red -> learned go other way
        # food swapped -> sees red -> learned turn around
        rollouts1 = [rollouts[1], rollouts[9], rollouts[10], rollouts[11], rollouts[17], rollouts[18]]  # , rollouts[19]
        # , rollouts[28]
        # rollouts[24],
        rollouts2 = [rollouts[25], rollouts[29], rollouts[30], rollouts[31], rollouts[32], rollouts[33]]
        plt.axis('off')
        plt.imshow(np.vstack([np.column_stack(rollouts1), np.column_stack(rollouts2)]), cmap='gray')
        save_fig(e, exp_name, "rollout_combined")
        # plt.show()
if __name__ == '__main__':
    # Smoke test: drive the vectorized maze with random actions and report
    # per-agent average reward statistics.
    n_offspring = 1024
    envs = MazeTurnEnvVec(n_offspring, n_steps=400)
    envs.n_pseudo_env = 8
    while True:
        envs.reset()
        opt_actions_up = [0, 0, 0, 0, 1, 0]
        opt_actions_down = [0, 0, 0, 0, 2, 0]
        opt_actions = [opt_actions_up, opt_actions_down]
        opt_current = 0
        total_reward = 0
        rewards = np.zeros((n_offspring))
        rewards_all = np.zeros_like(rewards)
        k = 0
        n_steps = 200
        for i in range(n_steps):
            actions = np.random.randint(0, envs.n_actions, size=n_offspring)
            # actions[:] = opt_actions_up[i%len(opt_actions_up)]
            # actions[0] = opt_actions[opt_current][k % len(opt_actions_up)]
            # BUGFIX: step() returns five values (allobs, obs, rewards, dones, infos);
            # the previous 4-target unpacking raised ValueError at runtime.
            _, obs, rewards, dones, _ = envs.step(actions=actions)
            total_reward += (rewards[0] * 100) + 1
            rewards_all += (rewards * 100) + 1
            # print(rewards[0])
            # print(obs[0], rewards[0])
            envs.render()
            plt.pause(0.5)
            k += 1
            if rewards[0] < 0:
                # print(rewards_last[0])
                if opt_current == 0:
                    opt_current = 1
                else:
                    opt_current = 0
            if rewards[0] != 0:
                k = 0
        total_reward /= n_steps
        rewards_all /= n_steps
        print(rewards_all[0], np.mean(rewards_all), np.std(rewards_all), np.std(rewards_all) / np.mean(rewards_all))
        print(rewards_all[:10])
================================================
FILE: examples/Structure_Evolution/EB-NAS/acc_predictor/adaptive_switching.py
================================================
import utils
import numpy as np
from acc_predictor.factory import get_acc_predictor
class AdaptiveSwitching:
    """Ensemble surrogate model.

    Fits every model in the pool and keeps the one with the best n-fold
    cross-validated Kendall tau (mean minus std across folds).
    """

    def __init__(self, n_fold=10):
        # self.model_pool = ['rbf', 'gp', 'mlp', 'carts']
        self.model_pool = ['rbf', 'gp', 'carts']
        self.n_fold = n_fold
        self.name = 'adaptive switching'
        self.model = None

    def fit(self, train_data, train_target):
        """Select the best surrogate via cross-validation and fit it on all data."""
        self._n_fold_validation(train_data, train_target, n=self.n_fold)

    def _n_fold_validation(self, train_data, train_target, n=10):
        """Score every pool member with n-fold CV and re-fit the winner.

        Sets ``self.winner`` (name) and ``self.model`` (fitted predictor).
        """
        n_samples = len(train_data)
        perm = np.random.permutation(n_samples)
        kendall_tau = np.full((n, len(self.model_pool)), np.nan)
        for i, tst_split in enumerate(np.array_split(perm, n)):
            trn_split = np.setdiff1d(perm, tst_split, assume_unique=True)
            # loop over all considered surrogate model in pool
            for j, model in enumerate(self.model_pool):
                acc_predictor = get_acc_predictor(model, train_data[trn_split], train_target[trn_split])
                result = acc_predictor.predict(train_data[tst_split])
                rmse, rho, tau = utils.get_correlation(result, train_target[tst_split])
                kendall_tau[i, j] = tau
        # pick the model with the best mean-minus-std Kendall tau
        winner = int(np.argmax(np.mean(kendall_tau, axis=0) - np.std(kendall_tau, axis=0)))
        print("winner model = {}, tau = {}".format(self.model_pool[winner],
                                                   np.mean(kendall_tau, axis=0)[winner]))
        self.winner = self.model_pool[winner]
        # BUGFIX: self.model was never assigned (the re-fit below was commented
        # out), so predict() always failed. Re-fit the winner on the full data.
        self.model = get_acc_predictor(self.model_pool[winner], train_data, train_target)

    def predict(self, test_data):
        """Predict with the cross-validation winner; requires fit() first."""
        return self.model.predict(test_data)
================================================
FILE: examples/Structure_Evolution/EB-NAS/acc_predictor/carts.py
================================================
# implementation based on
# https://github.com/yn-sun/e2epp/blob/master/build_predict_model.py
# and https://github.com/HandingWang/RF-CMOCO
import numpy as np
from sklearn.tree import DecisionTreeRegressor
class CART:
    """Classification and Regression Tree ensemble (random-forest style):
    each tree is fit on shuffled data restricted to a random feature subset,
    and predictions are the mean over all trees."""

    def __init__(self, n_tree=1000):
        self.n_tree = n_tree
        self.name = 'carts'
        self.model = None

    @staticmethod
    def _make_decision_trees(train_data, train_label, n_tree):
        """Fit ``n_tree`` trees; returns (tree_record, feature_record)."""
        feature_record = []
        tree_record = []
        for i in range(n_tree):
            # reshuffle samples for every tree (bagging-like diversity)
            sample_idx = np.arange(train_data.shape[0])
            np.random.shuffle(sample_idx)
            train_data = train_data[sample_idx, :]
            train_label = train_label[sample_idx]
            # random feature subset of random size
            feature_idx = np.arange(train_data.shape[1])
            np.random.shuffle(feature_idx)
            n_feature = np.random.randint(1, train_data.shape[1] + 1)
            selected_feature_ids = feature_idx[0:n_feature]
            feature_record.append(selected_feature_ids)
            dt = DecisionTreeRegressor()
            dt.fit(train_data[:, selected_feature_ids], train_label)
            tree_record.append(dt)
        return tree_record, feature_record

    def fit(self, train_data, train_label):
        self.model = self._make_decision_trees(train_data, train_label, self.n_tree)

    def predict(self, test_data):
        """Predict each row as the mean of all per-tree predictions."""
        assert self.model is not None, "carts does not exist, call fit to obtain cart first"
        trees, features = self.model[0], self.model[1]
        test_num, n_tree = len(test_data), len(trees)
        predict_labels = np.zeros((test_num, 1))
        for i in range(test_num):
            this_test_data = test_data[i, :]
            predict_this_list = np.zeros(n_tree)
            for j, (tree, feature) in enumerate(zip(trees, features)):
                predict_this_list[j] = tree.predict([this_test_data[feature]])[0]
            # BUGFIX/cleanup: the old code sorted+reversed the per-tree
            # predictions ("top 100" comment) but then averaged *all* of them,
            # so the sort was dead O(n log n) work; take the mean directly.
            predict_labels[i, 0] = np.mean(predict_this_list)
        return predict_labels
================================================
FILE: examples/Structure_Evolution/EB-NAS/acc_predictor/factory.py
================================================
def get_acc_predictor(model, inputs, targets):
    """Construct and fit an accuracy predictor of the requested kind.

    Supported kinds: 'rbf', 'carts', 'gp', 'mlp', 'as'.
    Raises NotImplementedError for anything else.
    """
    if model == 'rbf':
        from acc_predictor.rbf import RBF
        predictor = RBF()
        predictor.fit(inputs, targets)
        return predictor
    if model == 'carts':
        from acc_predictor.carts import CART
        predictor = CART(n_tree=5000)
        predictor.fit(inputs, targets)
        return predictor
    if model == 'gp':
        from acc_predictor.gp import GP
        predictor = GP()
        predictor.fit(inputs, targets)
        return predictor
    if model == 'mlp':
        from acc_predictor.mlp import MLP
        predictor = MLP(n_feature=inputs.shape[1])
        predictor.fit(x=inputs, y=targets)
        return predictor
    if model == 'as':
        from acc_predictor.adaptive_switching import AdaptiveSwitching
        predictor = AdaptiveSwitching()
        predictor.fit(inputs, targets)
        return predictor
    raise NotImplementedError
================================================
FILE: examples/Structure_Evolution/EB-NAS/acc_predictor/gp.py
================================================
from pydacefit.regr import regr_constant
from pydacefit.dace import DACE, regr_linear, regr_quadratic
from pydacefit.corr import corr_gauss, corr_cubic, corr_exp, corr_expg, corr_spline, corr_spherical
class GP:
    """Gaussian Process (Kriging) surrogate backed by pydacefit's DACE."""

    def __init__(self, regr='linear', corr='gauss'):
        self.regr = regr
        self.corr = corr
        self.name = 'gp'
        self.model = None

    def fit(self, train_data, train_label):
        """Fit a DACE model with the configured regression and correlation."""
        # validate configuration up-front, then dispatch through lookup tables
        if self.regr not in ('linear', 'constant', 'quadratic'):
            raise NotImplementedError("unknown GP regression")
        if self.corr not in ('gauss', 'cubic', 'exp', 'expg', 'spline', 'spherical'):
            raise NotImplementedError("unknown GP correlation")
        regr = {
            'linear': regr_linear,
            'constant': regr_constant,
            'quadratic': regr_quadratic,
        }[self.regr]
        corr = {
            'gauss': corr_gauss,
            'cubic': corr_cubic,
            'exp': corr_exp,
            'expg': corr_expg,
            'spline': corr_spline,
            'spherical': corr_spherical,
        }[self.corr]
        self.model = DACE(
            regr=regr, corr=corr, theta=1.0, thetaL=0.00001, thetaU=100)
        self.model.fit(train_data, train_label)

    def predict(self, test_data):
        assert self.model is not None, "GP does not exist, call fit to obtain GP first"
        return self.model.predict(test_data)
================================================
FILE: examples/Structure_Evolution/EB-NAS/acc_predictor/mlp.py
================================================
import copy
import torch
import numpy as np
import torch.nn as nn
from utils import get_correlation
class Net(nn.Module):
    """N-layer MLP regressor: stem -> hidden stack -> dropout -> linear head."""

    def __init__(self, n_feature, n_layers=2, n_hidden=300, n_output=1, drop=0.2):
        super(Net, self).__init__()
        self.stem = nn.Sequential(nn.Linear(n_feature, n_hidden), nn.ReLU())
        blocks = []
        for _ in range(n_layers):
            blocks.extend([nn.Linear(n_hidden, n_hidden), nn.ReLU()])
        self.hidden = nn.Sequential(*blocks)
        self.regressor = nn.Linear(n_hidden, n_output)  # output layer
        self.drop = nn.Dropout(p=drop)

    def forward(self, x):
        # dropout is applied after the hidden stack, before the linear head
        features = self.hidden(self.stem(x))
        return self.regressor(self.drop(features))

    @staticmethod
    def init_weights(m):
        """Uniform(-1/sqrt(fan_in), +1/sqrt(fan_in)) init for Linear layers; zero bias."""
        if type(m) == nn.Linear:
            bound = 1.0 / np.sqrt(m.in_features)
            m.weight.data.uniform_(-bound, bound)
            m.bias.data.fill_(0)
class MLP:
    """Multi Layer Perceptron surrogate wrapping the module-level train/predict helpers."""

    def __init__(self, **kwargs):
        self.name = 'mlp'
        self.model = Net(**kwargs)

    def fit(self, **kwargs):
        self.model = train(self.model, **kwargs)

    def predict(self, test_data, device='cpu'):
        return predict(self.model, test_data, device=device)
def train(net, x, y, trn_split=0.8, pretrained=None, device='cpu',
          lr=8e-4, epochs=2000, verbose=False):
    """Train ``net`` on numpy features ``x`` / targets ``y`` with Adam + cosine LR.

    The data is split into train/validation by ``trn_split``; the network with
    the lowest validation loss across epochs is returned (on CPU).
    ``pretrained``: optional path to a state dict to initialize from.
    """
    n_samples = x.shape[0]
    target = torch.zeros(n_samples, 1)
    perm = torch.randperm(target.size(0))
    trn_idx = perm[:int(n_samples * trn_split)]
    vld_idx = perm[int(n_samples * trn_split):]
    inputs = torch.from_numpy(x).float()
    target[:, 0] = torch.from_numpy(y).float()
    # back-propagation training of a NN
    if pretrained is not None:
        print("Constructing MLP surrogate model with pre-trained weights")
        init = torch.load(pretrained, map_location='cpu')
        net.load_state_dict(init)
        best_net = copy.deepcopy(net)
    else:
        # print("Constructing MLP surrogate model with "
        #       "sample size = {}, epochs = {}".format(x.shape[0], epochs))
        # initialize the weights
        # net.apply(Net.init_weights)
        # NOTE(review): branch structure reconstructed from mangled indentation;
        # training below runs in both cases -- confirm against original file.
        net = net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    criterion = nn.SmoothL1Loss()
    # criterion = nn.MSELoss()
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, int(epochs), eta_min=0)
    best_loss = 1e33
    for epoch in range(epochs):
        trn_inputs = inputs[trn_idx]
        trn_labels = target[trn_idx]
        loss_trn = train_one_epoch(net, trn_inputs, trn_labels, criterion, optimizer, device)
        loss_vld = infer(net, inputs[vld_idx], target[vld_idx], criterion, device)
        scheduler.step()
        # if epoch % 500 == 0 and verbose:
        #     print("Epoch {:4d}: trn loss = {:.4E}, vld loss = {:.4E}".format(epoch, loss_trn, loss_vld))
        # keep the checkpoint with the best validation loss
        if loss_vld < best_loss:
            best_loss = loss_vld
            best_net = copy.deepcopy(net)
    # final correlation report on the full dataset (result intentionally discarded)
    validate(best_net, inputs, target, device=device)
    return best_net.to('cpu')
def train_one_epoch(net, data, target, criterion, optimizer, device):
    """Run one full-batch gradient step; returns the scalar training loss."""
    net.train()
    optimizer.zero_grad()
    batch = data.to(device)
    labels = target.to(device)
    loss = criterion(net(batch), labels)
    loss.backward()
    optimizer.step()
    return loss.item()
def infer(net, data, target, criterion, device):
    """Evaluate ``criterion`` on ``net(data)`` in eval mode, without gradients."""
    net.eval()
    with torch.no_grad():
        batch = data.to(device)
        labels = target.to(device)
        loss = criterion(net(batch), labels)
        return loss.item()
def validate(net, data, target, device):
    """Compute correlation statistics between predictions and targets.

    Returns (rmse, rho, tau, pred, target) as numpy values.
    """
    net.eval()
    with torch.no_grad():
        batch = data.to(device)
        labels = target.to(device)
        pred = net(batch)
    pred_np = pred.cpu().detach().numpy()
    target_np = labels.cpu().detach().numpy()
    rmse, rho, tau = get_correlation(pred_np, target_np)
    # print("Validation RMSE = {:.4f}, Spearman's Rho = {:.4f}, Kendall’s Tau = {:.4f}".format(rmse, rho, tau))
    return rmse, rho, tau, pred_np, target_np
def predict(net, query, device):
    """Predict for a single sample (1-D array) or a batch (2-D array) of numpy features."""
    tensor = torch.from_numpy(query).float()
    if query.ndim < 2:
        # promote a single sample to a batch of one
        tensor = tensor.unsqueeze(0)
    net = net.to(device)
    net.eval()
    with torch.no_grad():
        out = net(tensor.to(device))
    return out.cpu().detach().numpy()
================================================
FILE: examples/Structure_Evolution/EB-NAS/acc_predictor/rbf.py
================================================
from pySOT.surrogate import RBFInterpolant, CubicKernel, TPSKernel, LinearTail, ConstantTail
class RBF:
    """Radial Basis Function interpolant surrogate (pySOT-backed)."""

    def __init__(self, kernel='cubic', tail='linear'):
        self.kernel = kernel
        self.tail = tail
        self.name = 'rbf'
        self.model = None

    def fit(self, train_data, train_label):
        """Build the interpolant and feed it every training point."""
        # validate configuration first, then resolve the pySOT classes
        if self.kernel not in ('cubic', 'tps'):
            raise NotImplementedError("unknown RBF kernel")
        if self.tail not in ('linear', 'constant'):
            raise NotImplementedError("unknown RBF tail")
        kernel = CubicKernel if self.kernel == 'cubic' else TPSKernel
        tail = LinearTail if self.tail == 'linear' else ConstantTail
        dim = train_data.shape[1]
        self.model = RBFInterpolant(dim=dim, kernel=kernel(), tail=tail(dim))
        for row, label in zip(train_data, train_label):
            self.model.add_points(row, label)

    def predict(self, test_data):
        assert self.model is not None, "RBF model does not exist, call fit to obtain rbf model first"
        return self.model.predict(test_data)
================================================
FILE: examples/Structure_Evolution/EB-NAS/cellmodel.py
================================================
import os
from functools import partial
from typing import List, Type
from operations import *
from motifs import *
from utils import drop_path
from timm.models import register_model
from braincog.base.node.node import *
from braincog.base.connection.layer import *
from braincog.model_zoo.base_module import BaseModule
from torchvision import transforms
EVO=True
class EvoCell2(nn.Module):
    """Evolved NAS cell taking two predecessor feature maps (s0, s1).

    The cell is built from a ``motif`` (a list of (op name, input index)
    pairs plus a concat list). Op names containing '_back' define backward
    connections that feed intermediate sums back into earlier states.
    A reduction cell degenerates to a single FactorizedReduce applied to s1.
    """
    def __init__(self, motif, C_prev_prev, C_prev, C, reduction, reduction_prev, act_fun):
        # print(C_prev_prev, C_prev, C, reduction)
        super(EvoCell2, self).__init__()
        self.act_fun = act_fun
        self.reduction = reduction
        self.motif = motif
        self.back_connection = False
        if reduction:
            # reduction cell: a single stride-2 projection of s1, tripled channels
            self.fun = FactorizedReduce(
                C_prev, C * 3, act_fun=act_fun
            )
            self.multiplier = 3
        else:
            # align both inputs to C channels (stride-2 if the previous cell reduced)
            if reduction_prev:
                self.preprocess0 = FactorizedReduce(
                    C_prev_prev, C, act_fun=act_fun)
            else:
                self.preprocess0 = ReLUConvBN(
                    C_prev_prev, C, 1, 1, 0, act_fun=act_fun)
            self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, act_fun=act_fun)
        # NOTE(review): placement of the motif compilation relative to the
        # if/else above is reconstructed from mangled indentation -- confirm.
        op_names, indices = zip(*motif.normal)
        concat = motif.normal_concat
        self._compile(C, op_names, indices, concat, reduction)

    def _compile(self, C, op_names, indices, concat, reduction):
        """Instantiate forward ops (and optional backward ops) from the motif lists."""
        assert len(op_names) == len(indices)
        # self._steps = len(op_names) // 2
        self._concat = concat
        self.multiplier = len(concat)
        self._ops = nn.ModuleList()
        self._ops_back = nn.ModuleList()
        back_begin_index = 0
        for i, (name, index) in enumerate(zip(op_names, indices)):
            # print(name, index)
            # the first '_back' op marks the start of the backward-connection block
            if '_back' in name:
                self.back_connection = True
                back_begin_index = i
                break
            stride = 2 if reduction and index < 2 else 1
            op = OPS[name](C, stride, True, act_fun=self.act_fun)
            self._ops += [op]
        if self.back_connection:
            # backward ops reuse the forward op type (name without '_back'), stride 1
            for name, index in zip(op_names[back_begin_index:], indices[back_begin_index:]):
                op = OPS[name.replace('_back', '')](
                    C, 1, True, act_fun=self.act_fun)
                self._ops_back += [op]
        if self.back_connection:
            self._indices_forward = indices[:back_begin_index]
            self._indices_backward = indices[back_begin_index:]
        else:
            self._indices_backward = []
            self._indices_forward = indices
        # each step consumes two forward ops / two input indices
        self._steps = len(self._indices_forward) // 2

    def forward(self, s0, s1, drop_prob):
        """Apply the cell; ``drop_prob`` enables drop-path on non-identity op outputs."""
        if self.reduction:
            return self.fun(s1)
        # print('s0',s0.shape)
        s0 = self.preprocess0(s0)
        # print(s0.shape)
        # print('s1',s1.shape)
        s1 = self.preprocess1(s1)
        # print(s1.shape)
        states = [s0, s1]
        for i in range(self._steps):
            i1 = self._indices_forward[2 * i]
            i2 = self._indices_forward[2 * i + 1]
            h1 = states[i1]
            h2 = states[i2]
            op1 = self._ops[2 * i]
            op2 = self._ops[2 * i + 1]
            h1 = op1(h1)
            h2 = op2(h2)
            if self.training and drop_prob > 0.:
                if not isinstance(op1, Identity):
                    h1 = drop_path(h1, drop_prob)
                if not isinstance(op2, Identity):
                    h2 = drop_path(h2, drop_prob)
            s = h1 + h2
            if self.back_connection:
                # feed the new state back into an earlier state (skip first step)
                if i != 0:
                    s_back = self._ops_back[i - 1](s)
                    states[self._indices_backward[i - 1]] = states[self._indices_backward[i - 1]] + s_back
            states += [s]
        outputs = torch.cat([states[i]
                             for i in self._concat], dim=1)  # N,C,H, W
        return outputs
        # return self.node(outputs)
class EvoCell3(nn.Module):
    """Evolved NAS cell taking three predecessor feature maps (s0, s1, s2).

    Same structure as EvoCell2 but each step consumes three ops/inputs, and
    the earliest input may need a double reduction (F0) depending on how many
    of the two preceding cells reduced.
    """
    def __init__(self, motif, C_prev_prev_prev, C_prev_prev, C_prev, C, reduction, reduction_prev, reduction_prev_prev, act_fun):
        # print(C_prev_prev_prev,C_prev_prev, C_prev, C, reduction,reduction_prev, reduction_prev_prev)
        super(EvoCell3, self).__init__()
        self.act_fun = act_fun
        self.reduction = reduction
        self.motif = motif
        self.back_connection = False
        if reduction:
            # reduction cell: single stride-2 projection of the latest input
            self.fun = FactorizedReduce(C_prev, C * 3, act_fun=act_fun)
            self.multiplier = 3
        else:
            if reduction_prev:
                self.preprocess1 = FactorizedReduce(C_prev_prev, C, act_fun=act_fun)
            else:
                self.preprocess1 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, act_fun=act_fun)
            # s0 lags two cells behind: reduce it once or twice to match spatial size
            if int(reduction_prev_prev) + int(reduction_prev) == 1:
                self.preprocess0 = FactorizedReduce(C_prev_prev_prev, C, act_fun=act_fun)
            elif int(reduction_prev_prev) + int(reduction_prev) == 2:
                self.preprocess0 = F0(C_prev_prev_prev, C, act_fun=act_fun)
            else:
                self.preprocess0 = ReLUConvBN(C_prev_prev_prev, C, 1, 1, 0, act_fun=act_fun)
            self.preprocess2 = ReLUConvBN(C_prev, C, 1, 1, 0, act_fun=act_fun)
        # NOTE(review): placement of the motif compilation relative to the
        # if/else above is reconstructed from mangled indentation -- confirm.
        op_names, indices = zip(*motif.normal)
        concat = motif.normal_concat
        self._compile(C, op_names, indices, concat, reduction)

    def _compile(self, C, op_names, indices, concat, reduction):
        """Instantiate forward ops (and optional backward ops) from the motif lists."""
        assert len(op_names) == len(indices)
        # self._steps = len(op_names) // 2
        self._concat = concat
        self.multiplier = len(concat)
        self._ops = nn.ModuleList()
        self._ops_back = nn.ModuleList()
        back_begin_index = 0
        for i, (name, index) in enumerate(zip(op_names, indices)):
            # print(name, index)
            # the first '_back' op marks the start of the backward-connection block
            if '_back' in name:
                self.back_connection = True
                back_begin_index = i
                break
            stride = 2 if reduction and index < 2 else 1
            op = OPS[name](C, stride, True, act_fun=self.act_fun)
            self._ops += [op]
        if self.back_connection:
            for name, index in zip(op_names[back_begin_index:], indices[back_begin_index:]):
                op = OPS[name.replace('_back', '')](
                    C, 1, True, act_fun=self.act_fun)
                self._ops_back += [op]
        if self.back_connection:
            self._indices_forward = indices[:back_begin_index]
            self._indices_backward = indices[back_begin_index:]
        else:
            self._indices_backward = []
            self._indices_forward = indices
        # each step consumes three forward ops / three input indices
        self._steps = len(self._indices_forward) // 3

    def forward(self, s0, s1, s2, drop_prob):
        """Apply the cell; ``drop_prob`` enables drop-path on non-identity op outputs."""
        if self.reduction:
            return self.fun(s2)
        s0 = self.preprocess0(s0)
        s1 = self.preprocess1(s1)
        s2 = self.preprocess2(s2)
        states = [s0, s1, s2]
        for i in range(self._steps):
            i1 = self._indices_forward[3 * i]
            i2 = self._indices_forward[3 * i + 1]
            i3 = self._indices_forward[3 * i + 2]
            h1 = states[i1]
            h2 = states[i2]
            h3 = states[i3]
            op1 = self._ops[3 * i]
            op2 = self._ops[3 * i + 1]
            op3 = self._ops[3 * i + 2]
            h1 = op1(h1)
            h2 = op2(h2)
            h3 = op3(h3)
            if self.training and drop_prob > 0.:
                if not isinstance(op1, Identity):
                    h1 = drop_path(h1, drop_prob)
                if not isinstance(op2, Identity):
                    h2 = drop_path(h2, drop_prob)
                if not isinstance(op3, Identity):
                    h3 = drop_path(h3, drop_prob)
            s = h1 + h2 + h3
            if self.back_connection:
                # feed the new state back into an earlier state (skip first step)
                if i != 0:
                    s_back = self._ops_back[i - 1](s)
                    states[self._indices_backward[i - 1]] = states[self._indices_backward[i - 1]] + s_back
            states += [s]
        outputs = torch.cat([states[i] for i in self._concat], dim=1)  # N,C,H, W
        return outputs
        # return self.node(outputs)
class EvoCell4(nn.Module):
    """Evolved cell that consumes the outputs of the FOUR previous cells.

    Same scheme as ``EvoCell3`` extended by one more input: normal cells align
    all four inputs to ``C`` channels, reduction cells downsample only the most
    recent input.

    NOTE(review): indentation was reconstructed from syntax; confirm against
    the original file layout.
    """

    def __init__(self,motif, C_prev_prev_prev_prev,C_prev_prev_prev, C_prev_prev, C_prev, C, reduction, reduction_prev, reduction_prev_prev,reduction_prev_prev_prev, act_fun):
        # print(C_prev_prev_prev_prev,C_prev_prev_prev,C_prev_prev, C_prev, C, reduction,reduction_prev, reduction_prev_prev,reduction_prev_prev_prev)
        super(EvoCell4, self).__init__()
        self.act_fun = act_fun
        self.reduction = reduction
        self.motif=motif
        # Set to True by _compile() when the motif contains '_back' ops.
        self.back_connection=False
        if reduction:
            # Reduction cell: forward() short-circuits to this single module.
            self.fun = FactorizedReduce(C_prev, C * 3, act_fun=act_fun)
            self.multiplier = 3
        else:
            # The number of reductions that happened since each input was
            # produced determines how many times it must be downsampled here.
            if reduction_prev:
                self.preprocess2 = FactorizedReduce(C_prev_prev, C, act_fun=act_fun)
            else:
                self.preprocess2 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, act_fun=act_fun)
            if int(reduction_prev_prev)+int(reduction_prev)==1:
                self.preprocess1 = FactorizedReduce(C_prev_prev_prev, C, act_fun=act_fun)
            elif int(reduction_prev_prev)+int(reduction_prev)==2:
                self.preprocess1 = F0(C_prev_prev_prev, C, act_fun=act_fun)
            else:
                self.preprocess1 = ReLUConvBN(C_prev_prev_prev, C, 1, 1, 0, act_fun=act_fun)
            if int(reduction_prev_prev_prev)+int(reduction_prev_prev)+int(reduction_prev)==1:
                self.preprocess0 = FactorizedReduce(C_prev_prev_prev_prev, C, act_fun=act_fun)
            elif int(reduction_prev_prev_prev)+int(reduction_prev_prev)+int(reduction_prev)==2:
                self.preprocess0 = F0(C_prev_prev_prev_prev, C, act_fun=act_fun)
            elif int(reduction_prev_prev_prev)+int(reduction_prev_prev)+int(reduction_prev)==3:
                # Three reductions since s0 -> triple downsample (F1).
                self.preprocess0 = F1(C_prev_prev_prev_prev, C, act_fun=act_fun)
            else:
                self.preprocess0 = ReLUConvBN(C_prev_prev_prev_prev, C, 1, 1, 0, act_fun=act_fun)
            self.preprocess3 = ReLUConvBN(C_prev, C, 1, 1, 0, act_fun=act_fun)
        op_names, indices = zip(*motif.normal)
        # print(self.preprocess0)
        # print(self.preprocess1)
        # print(self.preprocess2)
        # print(self.preprocess3)
        concat = motif.normal_concat
        self._compile(C, op_names, indices, concat, reduction)

    def _compile(self, C, op_names, indices, concat, reduction):
        """Instantiate forward (and optional backward) ops from the motif genotype."""
        assert len(op_names) == len(indices)
        # self._steps = len(op_names) // 2
        self._concat = concat
        # NOTE(review): overwrites multiplier=3 set in __init__ for reduction
        # cells — confirm len(concat) == 3 for reduction motifs.
        self.multiplier = len(concat)
        self._ops = nn.ModuleList()
        self._ops_back = nn.ModuleList()
        back_begin_index = 0
        for i, (name, index) in enumerate(zip(op_names, indices)):
            # print(name, index)
            if '_back' in name:
                # All entries from here on describe backward (feedback) edges.
                self.back_connection=True
                back_begin_index = i
                break
            stride = 2 if reduction and index < 2 else 1
            op = OPS[name](C, stride, True, act_fun=self.act_fun)
            self._ops += [op]
        if self.back_connection:
            # Backward ops reuse the forward op table, always with stride 1.
            for name, index in zip(op_names[back_begin_index:], indices[back_begin_index:]):
                op = OPS[name.replace('_back', '')](
                    C, 1, True, act_fun=self.act_fun)
                self._ops_back += [op]
        if self.back_connection:
            self._indices_forward = indices[:back_begin_index]
            self._indices_backward = indices[back_begin_index:]
        else:
            self._indices_backward = []
            self._indices_forward = indices
        # Four input edges per intermediate node.
        self._steps = len(self._indices_forward) // 4

    def forward(self, s0, s1, s2, s3, drop_prob):
        """Run the cell; ``s3`` is the most recent predecessor output."""
        if self.reduction:
            # Reduction cells ignore s0/s1/s2 entirely.
            return self.fun(s3)
        s0 = self.preprocess0(s0)
        s3 = self.preprocess3(s3)
        s1 = self.preprocess1(s1)
        s2 = self.preprocess2(s2)
        # if s1.shape[1]!=s3.shape[1]:
        # s1 = nn.Conv2d(s1.shape[1], s3.shape[1], 3, stride=2, padding=1, bias=False)
        states = [s0, s1, s2,s3]
        for i in range(self._steps):
            # Each step consumes four states through four ops and sums them.
            i1=self._indices_forward[4 * i]
            i2=self._indices_forward[4 * i + 1]
            i3=self._indices_forward[4 * i + 2]
            i4=self._indices_forward[4 * i + 3]
            h1 = states[i1]
            h2 = states[i2]
            h3 = states[i3]
            h4 = states[i4]
            op1 = self._ops[4 * i]
            op2 = self._ops[4 * i + 1]
            op3 = self._ops[4 * i + 2]
            op4 = self._ops[4 * i + 3]
            h1 = op1(h1)
            h2 = op2(h2)
            h3 = op3(h3)
            h4 = op4(h4)
            if self.training and drop_prob > 0.:
                # Drop-path regularization; Identity edges are never dropped.
                if not isinstance(op1, Identity):
                    h1 = drop_path(h1, drop_prob)
                if not isinstance(op2, Identity):
                    h2 = drop_path(h2, drop_prob)
                if not isinstance(op3, Identity):
                    h3 = drop_path(h3, drop_prob)
                if not isinstance(op4, Identity):
                    h4= drop_path(h4, drop_prob)
            s = h1 + h2 + h3 + h4
            if self.back_connection:
                if i != 0:
                    # Feed the new state back into an earlier state slot.
                    s_back = self._ops_back[i - 1](s)
                    states[self._indices_backward[i - 1]] = states[self._indices_backward[i - 1]] + s_back
            states += [s]
        outputs = torch.cat([states[i] for i in self._concat], dim=1)  # N,C,H, W
        return outputs
        # return self.node(outputs)
@register_model
class NetworkCIFAR(BaseModule):
    """CIFAR/DVS evaluation network stacking evolved cells.

    ``cell_type`` selects how many predecessor cell outputs each cell sees
    (2, 3 or 4, i.e. EvoCell2/3/4). The network is simulated for ``step``
    time steps and logits are averaged over time.

    NOTE(review): indentation reconstructed from syntax; confirm against the
    original file layout.
    """

    def __init__(self,
                 C,
                 num_classes,
                 layers,
                 auxiliary,
                 motif,
                 cell_type,
                 parse_method='darts',
                 step=5,
                 node_type='ReLUNode',
                 **kwargs):
        super(NetworkCIFAR, self).__init__(
            step=step,
            num_classes=num_classes,
            **kwargs
        )
        self.node_type=node_type
        if isinstance(node_type, str):
            # NOTE(review): eval() on a config string — assumed to come from a
            # trusted node-type name, not untrusted input.
            self.act_fun = eval(node_type)
        else:
            self.act_fun = node_type
        self.act_fun = partial(self.act_fun, **kwargs)
        self.spike_output = kwargs['spike_output'] if 'spike_output' in kwargs else True
        self.dataset = kwargs['dataset']
        if self.layer_by_layer:
            self.flatten = nn.Flatten(start_dim=1)
        else:
            self.flatten = nn.Flatten()
        self._layers = layers
        self.cell_type = cell_type
        self._auxiliary = auxiliary
        self.drop_path_prob = 0
        stem_multiplier = 3
        C_curr = stem_multiplier * C
        # Event datasets have 2 input channels (polarity), images have 3.
        if self.dataset == 'dvsg' or self.dataset == 'dvsc10' or self.dataset == 'NCALTECH101':
            self.stem = nn.Sequential(
                nn.Conv2d(2 * self.init_channel_mul, C_curr, 3, padding=1, bias=False),
                nn.BatchNorm2d(C_curr),
            )
            # self.reduce_idx = [
            # layers // 4,
            # layers // 2,
            # 3 * layers // 4
            # ]
            # Fixed reduction positions for the (larger) event inputs.
            self.reduce_idx = [1, 3, 5, 7]
        else:
            self.stem = nn.Sequential(
                nn.Conv2d(3 * self.init_channel_mul, C_curr, 3, padding=1, bias=False),
                nn.BatchNorm2d(C_curr),
            )
            self.reduce_idx = [layers // 4,
                               layers // 2,
                               3 * layers // 4]
        C_prev_prev_prev = C_curr
        C_prev_prev_prev_prev = C_curr
        C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
        self.cells = nn.ModuleList()
        reduction_prev = False
        reduction_prev_prev = False
        reduction_prev_prev_prev = False
        for i in range(layers):
            if i in self.reduce_idx:
                C_curr *= 2
                reduction = True
            else:
                reduction = False
            if cell_type==2:
                # print(C_prev_prev, C_prev, C_curr)
                cell = EvoCell2(motif[i], C_prev_prev, C_prev, C_curr,reduction, reduction_prev,act_fun=self.act_fun)
                self.cells += [cell]
                C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr
            if cell_type==3:
                cell = EvoCell3(motif[i], C_prev_prev_prev, C_prev_prev, C_prev, C_curr,reduction, reduction_prev,reduction_prev_prev,act_fun=self.act_fun)
                self.cells += [cell]
                # Shift the channel/reduction history window by one cell.
                C_prev_prev_prev = C_prev_prev
                reduction_prev_prev = reduction_prev
                C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr
            if cell_type==4:
                cell = EvoCell4(motif[i], C_prev_prev_prev_prev,C_prev_prev_prev, C_prev_prev, C_prev, C_curr,reduction, reduction_prev,reduction_prev_prev,reduction_prev_prev_prev,act_fun=self.act_fun)
                self.cells += [cell]
                C_prev_prev_prev_prev = C_prev_prev_prev
                C_prev_prev_prev = C_prev_prev
                reduction_prev_prev_prev = reduction_prev_prev
                reduction_prev_prev = reduction_prev
                C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr
            reduction_prev = reduction
        self.global_pooling = nn.Sequential(
            self.act_fun(), nn.AdaptiveAvgPool2d(1))
        if self.spike_output:
            # 10 output neurons per class, merged by the voting layer.
            self.classifier = nn.Sequential(
                nn.Linear(C_prev, 10 * num_classes),
                self.act_fun())
            self.vote = VotingLayer(10)
        else:
            self.classifier = nn.Linear(C_prev, num_classes)
            self.vote = nn.Identity()
        # self.classifier = nn.Linear(C_prev, num_classes)
        # self.vote = nn.Identity()

    def forward(self, inputs):
        """Simulate the network over ``self.step`` time steps and average logits."""
        logits_aux = None
        inputs = self.encoder(inputs)
        if not self.layer_by_layer:
            outputs = []
            output_aux = []
            self.reset()
            if self.cell_type==2:
                for t in range(self.step):
                    x = inputs[t]
                    s0 = s1 = self.stem(x)
                    for i, cell in enumerate(self.cells):
                        s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
                    out = self.global_pooling(s1)
                    out = self.classifier(self.flatten(out))
                    logits = self.vote(out)
                    outputs.append(logits)
                    # logits_aux stays None here, so this list is unused.
                    output_aux.append(logits_aux)
                return sum(outputs) / len(outputs)
            if self.cell_type==3:
                for t in range(self.step):
                    x = inputs[t]
                    s0 = s1 = s2= self.stem(x)
                    for i, cell in enumerate(self.cells):
                        s0, s1, s2 = s1, s2, cell(s0, s1, s2, self.drop_path_prob)
                    out = self.global_pooling(s2)
                    out = self.classifier(self.flatten(out))
                    logits = self.vote(out)
                    outputs.append(logits)
                    output_aux.append(logits_aux)
                return sum(outputs) / len(outputs)
            if self.cell_type==4:
                for t in range(self.step):
                    x = inputs[t]
                    s0 = s1 = s2= s3=self.stem(x)
                    for i, cell in enumerate(self.cells):
                        s0, s1, s2,s3= s1, s2, s3,cell(s0, s1, s2,s3 ,self.drop_path_prob)
                    out = self.global_pooling(s3)
                    out = self.classifier(self.flatten(out))
                    logits = self.vote(out)
                    outputs.append(logits)
                    output_aux.append(logits_aux)
                return sum(outputs) / len(outputs)
            # logits_aux if logits_aux is None else (sum(output_aux) / len(output_aux))
        else:
            # Layer-by-layer mode: time is folded into the batch dimension.
            s0 = s1 = self.stem(inputs)
            for i, cell in enumerate(self.cells):
                s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
                if i == 2 * self._layers // 3:
                    if self._auxiliary and self.training:
                        # NOTE(review): self.auxiliary_head is never defined in
                        # __init__ — this raises AttributeError if reached; confirm.
                        logits_aux = self.auxiliary_head(s1)
            out = self.global_pooling(s1)
            out = self.classifier(self.flatten(out))
            out = rearrange(out, '(t b) c -> t b c', t=self.step).mean(0)
            logits = self.vote(out)
            return logits
@register_model
class NetworkCIFAR_(BaseModule):
    """Variant of ``NetworkCIFAR`` built only from two-input cells, with an
    extra ``glob`` adjacency matrix describing cross-group skip connections.

    NOTE(review): the body of ``forward`` is visibly corrupted in this source
    (see the marked line below) — a span of code between an ``elif`` condition
    and a ``rearrange`` call was lost. The code is kept byte-identical here;
    restore it from the repository before use.
    """

    def __init__(self,
                 C,
                 num_classes,
                 layers,
                 glob,
                 auxiliary,
                 motif,
                 parse_method='darts',
                 step=5,
                 node_type='ReLUNode',
                 **kwargs):
        super(NetworkCIFAR_, self).__init__(
            step=step,
            num_classes=num_classes,
            **kwargs
        )
        self.node_type=node_type
        if isinstance(node_type, str):
            # NOTE(review): eval() on a config string — assumed trusted.
            self.act_fun = eval(node_type)
        else:
            self.act_fun = node_type
        self.act_fun = partial(self.act_fun, **kwargs)
        self.spike_output = kwargs['spike_output'] if 'spike_output' in kwargs else True
        self.dataset = kwargs['dataset']
        if self.layer_by_layer:
            self.flatten = nn.Flatten(start_dim=1)
        else:
            self.flatten = nn.Flatten()
        # Adjacency matrix of global (cross-group) connections, indexed per
        # group of 5 cells in forward().
        self.glob = glob
        self._layers = layers
        self._auxiliary = auxiliary
        self.drop_path_prob = 0
        stem_multiplier = 3
        C_curr = stem_multiplier * C
        # Event datasets have 2 input channels (polarity), images have 3.
        if self.dataset == 'dvsg' or self.dataset == 'dvsc10' or self.dataset == 'NCALTECH101':
            self.stem = nn.Sequential(
                nn.Conv2d(2 * self.init_channel_mul, C_curr, 3, padding=1, bias=False),
                nn.BatchNorm2d(C_curr),
            )
            # self.reduce_idx = [
            # layers // 4,
            # layers // 2,
            # 3 * layers // 4
            # ]
            self.reduce_idx = [1, 3, 5, 7]
        else:
            self.stem = nn.Sequential(
                nn.Conv2d(3 * self.init_channel_mul, C_curr, 3, padding=1, bias=False),
                nn.BatchNorm2d(C_curr),
            )
            self.reduce_idx = [layers // 4,
                               layers // 2,
                               3 * layers // 4]
        C_prev_prev_prev = C_curr
        C_prev_prev_prev_prev = C_curr
        C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
        self.cells = nn.ModuleList()
        reduction_prev = False
        reduction_prev_prev = False
        reduction_prev_prev_prev = False
        for i in range(layers):
            if i in self.reduce_idx:
                C_curr *= 2
                reduction = True
            else:
                reduction = False
            cell = EvoCell2(motif[i], C_prev_prev, C_prev, C_curr,reduction, reduction_prev,act_fun=self.act_fun)
            self.cells += [cell]
            C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr
            reduction_prev = reduction
        self.global_pooling = nn.Sequential(
            self.act_fun(), nn.AdaptiveAvgPool2d(1))
        if self.spike_output:
            # 10 output neurons per class, merged by the voting layer.
            self.classifier = nn.Sequential(
                nn.Linear(C_prev, 10 * num_classes),
                self.act_fun())
            self.vote = VotingLayer(10)
        else:
            self.classifier = nn.Linear(C_prev, num_classes)
            self.vote = nn.Identity()
        # self.classifier = nn.Linear(C_prev, num_classes)
        # self.vote = nn.Identity()

    def forward(self, inputs):
        logits_aux = None
        inputs = self.encoder(inputs)
        if not self.layer_by_layer:
            outputs = []
            output_aux = []
            self.reset()
            # zzz/kkk presumably cache per-group features from t-1; the code
            # that appends to them was lost in the corruption below — verify.
            zzz=[]
            kkk=[]
            for t in range(self.step):
                x = inputs[t]
                s0 = s1 = self.stem(x)
                # print(s1.shape)
                for i, cell in enumerate(self.cells):
                    if t>0 and i%5==4:
                        # Global connections targeting this 5-cell group.
                        qw = np.where(self.glob[:,int(i//5)]==1)
                        if qw[0].shape[0]!=0:
                            for m in qw:
                                if zzz[m[0]].shape[-1]>s1.shape[-1]:
                                    # Build an ad-hoc conv to match spatial sizes.
                                    ks=zzz[m[0]].shape[-1] - (s1.shape[-1]-1)*2+2
                                    if ks<0:
                                        ks=zzz[m[0]].shape[-1] - (s1.shape[-1]-1)+2
                                        bb=nn.Conv2d(zzz[m[0]].shape[1], s1.shape[1], kernel_size=ks, stride=1,padding=1, bias=False).to(zzz[m[0]].device)
                                    else:
                                        bb=nn.Conv2d(zzz[m[0]].shape[1], s1.shape[1], kernel_size=ks, stride=2,padding=1, bias=False).to(zzz[m[0]].device)
                                    aa=bb(zzz[m[0]])
                                    s1=aa+s1
                                # NOTE(review): the following line is corrupted in
                                # this source — the elif condition and a large span
                                # of code up to a rearrange() call were lost.
                                elif zzz[m[0]].shape[-1] t b c', t=self.step).mean(0)
        logits = self.vote(out)
        return logits
def occumpy_mem(cuda_device):
    """Grab ~85% of the free memory on GPU ``cuda_device`` and release it.

    Queries ``nvidia-smi`` for total/used memory (in MiB) of the given device
    index, then allocates a throwaway CUDA tensor sized to 85% of the free
    amount. A (256, 1024) float32 slab is exactly 1 MiB, so ``block_mem``
    is interpreted directly as MiB.

    :param cuda_device: GPU index (int or str convertible to int).
    :raises subprocess.CalledProcessError: if nvidia-smi fails.
    """
    import subprocess
    # Fixed: use an argument list (no shell) and avoid the unclosed os.popen
    # pipe of the original implementation.
    smi_out = subprocess.check_output(
        ['/usr/bin/nvidia-smi',
         '--query-gpu=memory.total,memory.used',
         '--format=csv,nounits,noheader']).decode()
    total, used = smi_out.strip().split("\n")[int(cuda_device)].split(',')
    total = int(total)
    used = int(used)
    # 85% of the free memory, in MiB (the dead `max_mem = int(total * 1)`
    # no-op from the original was dropped).
    block_mem = int((total - used) * 0.85)
    # Allocate then free: leaves the memory in PyTorch's caching allocator.
    x = torch.cuda.FloatTensor(256, 1024, block_mem)
    del x
if __name__ == '__main__':
    # Smoke test: build a NetworkCIFAR_ from hand-picked motifs (mm1..mm5 are
    # defined earlier in this file) and run one forward pass on GPU 3.
    torch.cuda.set_device('cuda:3')
    # occumpy_mem(str(3))
    x = torch.rand(128, 3, 32, 32)
    # Several candidate global-connection matrices; only the LAST assignment
    # takes effect — the earlier ones are experiment leftovers.
    glob = np.array([[0,1,0,0],[1,0,1,0],[1,0,0,1],[0,1,0,0]])
    # glob = np.array([[0,1],[1,0]])
    glob = np.array([[0,1,0,0],[0,0,0,0],[0,0,0,1],[0,0,0,0]])
    glob = np.array([[0,1,1,1],[1,0,1,1],[1,1,0,1],[1,1,1,0]])
    glob = np.array([[0,1,0],[1,0,0],[1,1,1]])
    motifs=[mm2,mm3,mm4,mm5,mm1,mm5,mm3,mm4,mm2,mm1,mm2,mm3,mm4,mm5,mm1]
    # motifs=[m1,m2,m3,m1,m5,m4,m1,m2,m3,m1,m5,m4,m5,m4,m1]##3
    # motifs=[t2,t3,t4,t5,t1,t5,t3,t4,t2,t1,t2,t3,t4,t5,t1,t5,t3,t4,t2,t1]
    # motifs=[t2,t3,t4,t5,t1]
    # motifs=[subnet,subnet,subnet]
    net=NetworkCIFAR_(C=12,num_classes=10,motif=motifs,layers=len(motifs),auxiliary=True,dataset='cifar10',glob=glob)
    # net=NetworkCIFAR(C=12,num_classes=10,motif=motifs,layers=len(motifs),auxiliary=True,dataset='cifar10',cell_type=2)
    net=net.cuda()
    # One group of global connections per 5 cells.
    layers=int(len(motifs)/5)
    out=net(x.to('cuda:3'))
    print(out.shape)
================================================
FILE: examples/Structure_Evolution/EB-NAS/ebnas.py
================================================
import sys
import numpy as np
import argparse
import time
import timm.models
import yaml
import os
import logging
from random import choice
from collections import OrderedDict
from contextlib import suppress
from datetime import datetime
from micro_encoding import ops
from braincog.base.node.node import *
from braincog.utils import *
from braincog.base.utils.criterions import *
from braincog.datasets.datasets import *
from braincog.model_zoo.resnet import *
from braincog.model_zoo.convnet import *
# from braincog.model_zoo.reactnet import *
# from braincog.model_zoo.convxnet import *
from braincog.utils import save_feature_map, setup_seed
from braincog.base.utils.visualization import plot_tsne_3d, plot_tsne, plot_confusion_matrix
import micro_encoding
import nsganet as engine
from pymop.problem import Problem
import torch
import torch.nn as nn
import torchvision.utils
from torch.nn.parallel import DistributedDataParallel as NativeDDP
from pymoo.optimize import minimize
from tm import train_motifs
from timm.data import create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
from timm.models import load_checkpoint, create_model, resume_checkpoint, convert_splitbn_model
from timm.utils import *
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
from timm.optim import create_optimizer
from timm.scheduler import create_scheduler
from timm.utils import ApexScaler, NativeScaler
from cellmodel import NetworkCIFAR
# Module-level configuration for the EB-NAS search script.
# Number of bits per motif genome segment (passed through to train_motifs).
bits=20
# from ptflops import get_model_complexity_info
# from thop import profile, clever_format
torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('')
# The first arg parser parses out only thei --config argument, this argument is used to
# load a yaml file containing key-values that override the defaults for the main parser below
config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)
# GPU device indices used for evaluation; number of NSGA generations.
devices=[4]
max_gen = 100
# Main command-line interface: training hyper-parameters, dataset/augmentation
# options, SNN simulation settings and NAS-specific flags.
parser = argparse.ArgumentParser(description='SNN Training and Evaluating')
# Model parameters
parser.add_argument('--seed', type=int, default=99, metavar='S',
                    help='random seed (default: 42)')
parser.add_argument('--eval_epochs', type=int, default=1)
parser.add_argument('--bns', action='store_true', default=True)
parser.add_argument('--mid', type=int, default=5)
parser.add_argument('--trainning_epochs', type=int, default=600, metavar='N',help='number of epochs to train (default: 2)')
parser.add_argument('--cooldown-epochs', type=int, default=0, metavar='N',help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
parser.add_argument('--init-channels', type=int, default=36)
parser.add_argument('--layers', type=int, default=2)
parser.add_argument('--output', default='', type=str, metavar='PATH')
parser.add_argument('--spike-rate', action='store_true', default=False)
parser.add_argument('--n_gens', type=int, default=max_gen, help='population size')
parser.add_argument('--bs', type=int, default=100)
parser.add_argument('--n_offspring', type=int, default=60, help='number of offspring created per generation')
parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
                    help='YAML config file specifying default arguments')
parser.add_argument('--dataset', default='cifar10', type=str)
parser.add_argument('--num-classes', type=int, default=10, metavar='N')
parser.add_argument('--model', default='NetworkCIFAR', type=str, metavar='MODEL',
                    help='Name of model to train (default: "countception"')
parser.add_argument('--pretrained', action='store_true', default=False,
                    help='Start with pretrained version of specified network (if avail)')
parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
                    help='Initialize model from this checkpoint (default: none)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='Resume full model and optimizer state from checkpoint (default: none)')
parser.add_argument('--eval_checkpoint', default='', type=str, metavar='PATH',
                    help='path to eval checkpoint (default: none)')
parser.add_argument('--no-resume-opt', action='store_true', default=False,
                    help='prevent resume of optimizer state when resuming model')
parser.add_argument('--gp', default=None, type=str, metavar='POOL',
                    help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
# Dataset parameters for static datasets
parser.add_argument('--img-size', type=int, default=224, metavar='N',
                    help='Image patch size (default: None => model default)')
parser.add_argument('--crop-pct', default=None, type=float,
                    metavar='N', help='inputs image center crop percent (for validation only)')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                    help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
                    help='Override std deviation of of dataset')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
                    help='Image resize interpolation type (overrides model)')
# Serialized default genome: 21 genes per motif, concatenated.
parser.add_argument('--strgenome', default='4,0,0,1,1,1,0,1,0,1,0,0,0,1,0,1,0,1,0,0,4,1,1,1,1,1,0,1,1,0,1,1,0,1,1,1,0,1,0,0,4,1,1,0,1,0,0,1,0,1,1,0,1,1,0,0,1,0,1,2,4,1,0,0,1,1,1,0,0,1,0,0,1,0,0,1,0,0,1,1,4,0,1,1,0,1,0,0,1,0,1,0,1,1,1,0,1,0,1,3,4,1,0,1,0,0,1,1,1,0,0,1,0,1,0,0,0,1,0,0,3', type=str)
# Dataloader parameters
parser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',
                    help='inputs batch size for training (default: 128)')
parser.add_argument('-vb', '--validation-batch-size-multiplier', type=int, default=1, metavar='N',
                    help='ratio of validation batch size to training batch size (default: 1)')
# Optimizer parameters
parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                    help='Optimizer (default: "adamw"')
parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
                    help='Optimizer Epsilon (default: None, use opt default)')
parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                    help='Optimizer Betas (default: None, use opt default)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='Optimizer momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=0.01,
                    help='weight decay (default: 0.01 for adamw)')
parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
                    help='Clip gradient norm (default: None, no clipping)')
parser.add_argument('--adam-epoch', type=int, default=1000, help='lamb switch to adamw')
# Learning rate schedule parameters
parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
                    help='LR scheduler (default: "cosine"')
parser.add_argument('--lr', type=float, default=5e-3, metavar='LR',
                    help='learning rate (default: 0.01)')
parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                    help='learning rate noise on/off epoch percentages')
parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                    help='learning rate noise limit percent (default: 0.67)')
parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                    help='learning rate noise std-dev (default: 1.0)')
parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
                    help='learning rate cycle len multiplier (default: 1.0)')
parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
                    help='learning rate cycle limit')
parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
                    help='warmup learning rate (default: 0.0001)')
parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                    help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
                    help='epoch interval to decay LR')
parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
                    help='epochs to warmup LR, if scheduler supports')
parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                    help='patience epochs for Plateau LR scheduler (default: 10')
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                    help='LR decay rate (default: 0.1)')
parser.add_argument('--power', type=int, default=1, help='power')
# Augmentation & regularization parameters ONLY FOR IMAGE NET
parser.add_argument('--no-aug', action='store_true', default=False,
                    help='Disable all training augmentation, override other train aug args')
parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
                    help='Random resize scale (default: 0.08 1.0)')
parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
                    help='Random resize aspect ratio (default: 0.75 1.33)')
parser.add_argument('--hflip', type=float, default=0.5,
                    help='Horizontal flip training aug probability')
parser.add_argument('--vflip', type=float, default=0.,
                    help='Vertical flip training aug probability')
parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                    help='Color jitter factor (default: 0.4)')
parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
                    help='Use AutoAugment policy. "v0" or "original". (default: None)'),
parser.add_argument('--aug-splits', type=int, default=0,
                    help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
parser.add_argument('--jsd', action='store_true', default=False,
                    help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
                    help='Random erase prob (default: 0.25)')
parser.add_argument('--remode', type=str, default='pixel',
                    help='Random erase mode (default: "const")')
parser.add_argument('--recount', type=int, default=1,
                    help='Random erase count (default: 1)')
parser.add_argument('--resplit', action='store_true', default=False,
                    help='Do not random erase first (clean) augmentation split')
parser.add_argument('--mixup', type=float, default=0.8,
                    help='mixup alpha, mixup enabled if > 0. (default: 0.)')
parser.add_argument('--cutmix', type=float, default=1.0,
                    help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')
parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
                    help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
parser.add_argument('--mixup-prob', type=float, default=1.0,
                    help='Probability of performing mixup or cutmix when either/both is enabled')
parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
                    help='Probability of switching to cutmix when both mixup and cutmix enabled')
parser.add_argument('--mixup-mode', type=str, default='batch',
                    help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
                    help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
parser.add_argument('--smoothing', type=float, default=0.1,
                    help='Label smoothing (default: 0.1)')
parser.add_argument('--train-interpolation', type=str, default='random',
                    help='Training interpolation (random, bilinear, bicubic default: "random")')
parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                    help='Dropout rate (default: 0.0)')
parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT',
                    help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',
                    help='Drop path rate (default: None)')
parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
                    help='Drop block rate (default: None)')
parser.add_argument('--newton-maxiter', default=20, type=int,
                    help='max iterration in newton method')
parser.add_argument('--reset-drop', action='store_true', default=False,
                    help='whether to reset drop')
parser.add_argument('--kernel-method', type=str, default='cuda', choices=['torch', 'cuda'],
                    help='The implementation way of gaussian kernel method, choose from "cuda" and "torch"')
# Batch norm parameters (only works with gen_efficientnet based models currently)
parser.add_argument('--bn-tf', action='store_true', default=False,
                    help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')
parser.add_argument('--bn-momentum', type=float, default=None,
                    help='BatchNorm momentum override (if not None)')
parser.add_argument('--bn-eps', type=float, default=None,
                    help='BatchNorm epsilon override (if not None)')
parser.add_argument('--sync-bn', action='store_true',
                    help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
parser.add_argument('--dist-bn', type=str, default='',
                    help='Distribute BatchNorm stats between node after each epoch ("broadcast", "reduce", or "")')
parser.add_argument('--split-bn', action='store_true',
                    help='Enable separate BN layers per augmentation split.')
# Model Exponential Moving Average
parser.add_argument('--model-ema', action='store_true', default=False,
                    help='Enable tracking moving average of model weights')
parser.add_argument('--model-ema-force-cpu', action='store_true', default=False,
                    help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
parser.add_argument('--model-ema-decay', type=float, default=0.99996,
                    help='decay factor for model weights moving average (default: 0.9998)')
# Misc
parser.add_argument('--log-interval', type=int, default=50, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
                    help='how many batches to wait before writing recovery checkpoint')
parser.add_argument('-j', '--workers', type=int, default=8, metavar='N',
                    help='how many training processes to use (default: 1)')
parser.add_argument('--num-gpu', type=int, default=len(devices),
                    help='Number of GPUS to use')
parser.add_argument('--save-images', action='store_true', default=False,
                    help='save images of inputs bathes every log interval for debugging')
parser.add_argument('--amp', action='store_true', default=False,
                    help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
parser.add_argument('--apex-amp', action='store_true', default=False,
                    help='Use NVIDIA Apex AMP mixed precision')
parser.add_argument('--native-amp', action='store_true', default=False,
                    help='Use Native Torch AMP mixed precision')
parser.add_argument('--channels-last', action='store_true', default=False,
                    help='Use channels_last memory layout')
parser.add_argument('--pin-mem', action='store_true', default=False,
                    help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--no-prefetcher', action='store_true', default=False,
                    help='disable fast prefetcher')
parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
                    help='Best metric (default: "top1"')
parser.add_argument('--tta', type=int, default=0, metavar='N',
                    help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
parser.add_argument('--local_rank', default=0, type=int)
parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
                    help='use the multi-epochs-loader to save time at the beginning of every epoch')
parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
parser.add_argument('--device', type=int, default=devices[0])
# Spike parameters
parser.add_argument('--step', type=int, default=10, help='Simulation time step (default: 10)')
parser.add_argument('--encode', type=str, default='direct', help='Input encode method (default: direct)')
parser.add_argument('--temporal-flatten', action='store_true',
                    help='Temporal flatten to channels. ONLY FOR EVENT DATA TRAINING BY ANN')
parser.add_argument('--adaptive-node', action='store_true')
parser.add_argument('--critical-loss', action='store_true')
# neuron type
parser.add_argument('--node-type', type=str, default='LIFNode', help='Node type in network (default: PLIF)')
parser.add_argument('--act-fun', type=str, default='QGateGrad',
                    help='Surogate Function in node. Only for Surrogate nodes (default: AtanGrad)')
parser.add_argument('--threshold', type=float, default=.5, help='Firing threshold (default: 0.5)')
parser.add_argument('--tau', type=float, default=2., help='Attenuation coefficient (default: 2.)')
parser.add_argument('--requires-thres-grad', action='store_true')
parser.add_argument('--sigmoid-thres', action='store_true')
parser.add_argument('--loss-fn', type=str, default='ce', help='loss function (default: ce)')
parser.add_argument('--noisy-grad', type=float, default=0.,
                    help='Add noise to backward, sometime will make higher accuracy (default: 0.)')
parser.add_argument('--spike-output', action='store_true', default=False,
                    help='Using mem output or spike output (default: False)')
parser.add_argument('--n_groups', type=int, default=1)
# EventData Augmentation
parser.add_argument('--mix-up', action='store_true', help='Mix-up for event data (default: False)')
parser.add_argument('--cut-mix', action='store_true', help='CutMix for event data (default: False)')
parser.add_argument('--event-mix', action='store_true', help='EventMix for event data (default: False)')
parser.add_argument('--cutmix_beta', type=float, default=1.0, help='cutmix_beta (default: 1.)')
parser.add_argument('--cutmix_prob', type=float, default=0.5, help='cutmix_prib for event data (default: .5)')
parser.add_argument('--cutmix_num', type=int, default=1, help='cutmix_num for event data (default: 1)')
parser.add_argument('--cutmix_noise', type=float, default=0.,
                    help='Add Pepper noise after mix, sometimes work (default: 0.)')
parser.add_argument('--rand-aug', action='store_true',
                    help='Rand Augment for Event data (default: False)')
parser.add_argument('--randaug_n', type=int, default=3,
                    help='Rand Augment times n (default: 3)')
parser.add_argument('--randaug_m', type=int, default=15,
                    help='Rand Augment times n (default: 15) (0-30)')
parser.add_argument('--train-portion', type=float, default=0.9,
                    help='Dataset portion, only for datasets which do not have validation set (default: 0.9)')
parser.add_argument('--event-size', default=48, type=int,
                    help='Event size. Resize event data before process (default: 48)')
parser.add_argument('--layer-by-layer', action='store_true',
                    help='forward step-by-step or layer-by-layer. '
                         'Larger Model with layer-by-layer will be faster (default: False)')
parser.add_argument('--node-resume', type=str, default='',
                    help='resume weights in node for adaptive node. (default: False)')
parser.add_argument('--node-trainable', action='store_true')
# visualize
parser.add_argument('--visualize', action='store_true',
                    help='Visualize spiking map for each layer, only for validate (default: False)')
parser.add_argument('--tsne', action='store_true')
parser.add_argument('--conf-mat', action='store_true')
# DARTS parameters
parser.add_argument('--auxiliary', action='store_true', default=False, help='use auxiliary tower')
# parser.add_argument('--arch', default='dvsc10_new_skip19', type=str)
# parser.add_argument('--motif', default='m1', type=str)
parser.add_argument('--parse_method', default='darts', type=str)
parser.add_argument('--drop_path_prob', type=float, default=0.2, help='drop path probability')
# parser.add_argument('--back-connection', action='store_true',default=True)
parser.add_argument('--suffix', type=str, default='',
                    help='Add an additional suffix to the save path (default: \'\')')
# Capability detection: NVIDIA Apex (optional dependency) and native
# torch.cuda.amp mixed-precision support.
try:
    from apex import amp
    from apex.parallel import DistributedDataParallel as ApexDDP
    from apex.parallel import convert_syncbn_model
    has_apex = True
except ImportError:
    has_apex = False

has_native_amp = False
try:
    # torch.cuda.amp.autocast exists on torch >= 1.6.
    if getattr(torch.cuda.amp, 'autocast') is not None:
        has_native_amp = True
except AttributeError:
    pass
class NAS(Problem):
    """NAS problem wrapper for the pymop/pymoo optimizer.

    Each decision vector encodes one architecture's motif genome; evaluating
    it trains the candidate via ``train_motifs`` and reports
    ``1000 - performance`` as the minimised objective.
    """

    # first define the NAS problem (inherit from pymop)
    def __init__(self, args, n_var=20, n_obj=1, n_constr=0, lb=None, ub=None,
                 init_channels=24, layers=8):
        super().__init__(n_var=n_var, n_obj=n_obj, n_constr=n_constr, type_var=np.int64)
        self.xl = lb  # per-variable lower bounds
        self.xu = ub  # per-variable upper bounds
        self._lr = args.lr  # saved so every evaluation restarts from the same lr
        self._n_evaluated = 0  # keep track of how many architectures are sampled
        self.args = args

    def _evaluate(self, x, out, *args, **kwargs):
        """Train every genome in the population `x` and fill out["F"]."""
        objs = np.full((x.shape[0], self.n_obj), np.nan)
        for i in range(x.shape[0]):
            arch_id = self._n_evaluated + 1
            print('\n')
            _logger.info('Network= {}'.format(arch_id))
            genome = x[i, :]
            # Bug fix: os.path.join requires string components; `i` is an int,
            # so the original call raised TypeError at the first evaluation.
            arch_dir = os.path.join(self.args.output_dir, str(i))
            if os.path.exists(arch_dir) is False:
                os.makedirs(arch_dir, exist_ok=True)
            # reset the learning rate mutated by the previous training run
            self.args.lr = self._lr
            performance, acc = train_motifs(args=self.args, gen=0, arch_dir=arch_dir, genome=genome,
                                            _logger=_logger, args_text=args_text, devices=devices, bits=bits)
            objs[i, 0] = 1000 - performance  # minimisation <=> maximise performance
            _logger.info('performance= {}'.format(objs[i, 0]))
            self._n_evaluated += 1
        # NOTE(review): when n_obj > 1 the remaining objective columns stay NaN - confirm intended
        out["F"] = objs
        # if your NAS problem has constraints, use the following line to set constraints
        # out["G"] = np.column_stack([g1, g2, g3, g4, g5, g6]) in case 6 constraints
def do_every_generations(algorithm):
    """Per-generation callback: log error statistics of the current population."""
    generation = algorithm.n_gen
    genomes = algorithm.pop.get("X")
    errors = algorithm.pop.get("F")[:, 0]
    _logger.info("generation = {}".format(generation))
    _logger.info("population error: best = {}, mean = {}, "
                 "median = {}, worst = {}".format(np.min(errors), np.mean(errors),
                                                  np.median(errors), np.max(errors)))
    _logger.info('Best Genome= {}'.format(genomes[np.argmin(errors)]))
def _parse_args():
    """Two-pass CLI parsing (config file first); returns (args, yaml dump of args)."""
    config_args, rest = config_parser.parse_known_args()
    parsed = parser.parse_args(rest)
    parsed_text = yaml.safe_dump(parsed.__dict__, default_flow_style=False)
    return parsed, parsed_text
if __name__ == '__main__':
    args, args_text = _parse_args()
    # spiking output head is forcibly disabled for this search script
    # args.no_spike_output = args.no_spike_output | args.cut_mix
    args.no_spike_output = True
    output_dir = ''
    if args.local_rank == 0:
        # rank 0 creates the timestamped experiment directory and the file logger
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"),
            # args.model,
            # args.dataset,
            str(args.layers) + 'layers',
            str(args.init_channels) + 'channels',
            'motif' + str(args.mid),
            str(args.step) + 'steps',
            # args.suffix
            # str(args.img_size)
        ])
        output_dir = get_outdir(output_base, str(args.dataset), exp_name)
        args.output_dir = output_dir
        setup_default_logging(log_path=os.path.join(output_dir, 'log.txt'))
    else:
        setup_default_logging()
    args.prefetcher = not args.no_prefetcher
    # distributed mode is inferred from the launcher-provided WORLD_SIZE env var
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
        if args.distributed and args.num_gpu > 1:
            _logger.warning(
                'Using more than one GPU per process in distributed mode is not allowed.Setting num_gpu to 1.')
            args.num_gpu = 1
    # args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.num_gpu = 1
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    else:
        # NOTE(review): '%d' % args.device assumes args.device is an int GPU index - confirm
        torch.cuda.set_device('cuda:%d' % args.device)
    assert args.rank >= 0
    if args.distributed:
        _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
                     % (args.rank, args.world_size))
    else:
        _logger.info('Training with a single process on %d GPUs.' % args.num_gpu)
    # torch.manual_seed(args.seed + args.rank)
    setup_seed(args.seed + args.rank)
    defalut_lr = args.lr
    # shuffle the five base motif ids to form the initial subnet ordering
    sn = np.arange(1, 6)
    np.random.shuffle(sn)
    args.subnet = sn
    args.layers *= 5
    # genome length: one gene block of `bits` entries per layer
    len_motifs = args.layers * bits
    low = np.zeros(len_motifs)
    up = []
    for i in range(0, args.layers * bits, bits):
        # first gene of each block is the motif id; the rest select operations
        t = [args.mid]
        low[i] = args.mid
        t = t + [(ops - 1) for j in range(bits - 1)]
        # last slot may also address backward ops, hence the doubled bound
        t[-1] = 2 * (ops - 1)
        up.extend(t)
    up = np.array(up).reshape(-1, )
    # NOTE(review): n_obj=2 here, but NAS._evaluate only fills objective 0 - confirm
    kkk = NAS(args, n_var=len_motifs,
              n_obj=2, n_constr=0, lb=low, ub=up,
              init_channels=args.init_channels, layers=args.layers)
    method = engine.nsganet(pop_size=args.pop_size,
                            n_offsprings=args.n_offspring,
                            eliminate_duplicates=True)
    kres = minimize(kkk,
                    method,
                    callback=do_every_generations,
                    termination=('n_gen', args.n_gens))
================================================
FILE: examples/Structure_Evolution/EB-NAS/micro_encoding.py
================================================
# NASNet Search Space https://arxiv.org/pdf/1707.07012.pdf
# code modified from DARTS https://github.com/quark0/darts
import numpy as np
from collections import namedtuple
from random import choice
from numpy.linalg import matrix_rank
import itertools
import torch
# from models.micro_models import NetworkCIFAR as Network
import motifs
# Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
# Genotype_norm = namedtuple('Genotype', 'normal normal_concat')
# Genotype_redu = namedtuple('Genotype', 'reduce reduce_concat')
# Two-field genotype: list of (op_name, input_index) pairs plus concat range.
Genotype = namedtuple('Genotype', 'normal normal_concat')

# what you want to search should be defined here and in micro_operations
PRIMITIVES = [
    'max_pool_3x3',
    'avg_pool_3x3',
    'skip_connect',
    'sep_conv_3x3',
    'sep_conv_5x5',
    'dil_conv_3x3',
    'dil_conv_5x5',
    'sep_conv_7x7',
    'conv_7x1_1x7',
]

# backward (feedback) operation choices; commented entries are disabled
OPERATIONS_back = [
    # 'max_pool_3x3_p_back',
    # 'avg_pool_3x3_p_back',
    'conv_3x3_p_back',
    'conv_5x5_p_back',
    # 'avg_pool_3x3_n_back',
    'conv_3x3_n_back',
    'conv_5x5_n_back',
    # 'sep_conv_3x3_p_back',
    # 'sep_conv_5x5_p_back',
    # 'dil_conv_3x3_p_back',
    # 'dil_conv_5x5_p_back',
    # 'def_conv_3x3_p_back',
    # 'def_conv_5x5_p_back',
]

# forward "positive branch" operation choices
OPERATIONS_p = [
    # 'max_pool_3x3_p',
    # 'avg_pool_3x3_p',
    'conv_3x3_p',
    'conv_5x5_p',
    # 'sep_conv_3x3_p',
    # 'sep_conv_5x5_p',
    # 'dil_conv_3x3_p',
    # 'dil_conv_5x5_p',
    # 'def_conv_3x3_p',
    # 'def_conv_5x5_p',
]
ops = len(OPERATIONS_p)  # number of selectable forward ops per genome slot

# forward "negative branch" operation choices
OPERATIONS_n = [
    # 'max_pool_3x3_n',
    # 'avg_pool_3x3_n',
    'conv_3x3_n',
    'conv_5x5_n',
    # 'sep_conv_3x3_n',
    # 'sep_conv_5x5_n',
    # 'dil_conv_3x3_n',
    # 'dil_conv_5x5_n',
    # 'def_conv_3x3_n',
    # 'def_conv_5x5_n',
    # 'transformer',
]

# Base motif ids.  NOTE: the first tuple is dead code - it is immediately
# overridden by the second assignment.
mids = (3, 3, 2, 2, 1)
mids = (1, 2, 3, 4, 5)
ms = len(mids)
# all distinct orderings of the motif ids, numbered from 1
permutations = list({}.fromkeys(list(itertools.permutations(mids))).keys())
motifdict_c = dict(enumerate(permutations, 1))  # id -> permutation tuple
motifdict = dict(zip(motifdict_c.values(), motifdict_c.keys()))  # permutation tuple -> id
def convert_cell(cell_bit_string):
    """Group a flat cell bit-string into blocks of two (op, input) unit pairs."""
    units = []
    for pos in range(0, len(cell_bit_string), 2):
        units.append(cell_bit_string[pos:pos + 2])
    blocks = []
    for pos in range(0, len(units), 2):
        blocks.append(units[pos:pos + 2])
    return blocks
def filt(sn):
    """Clamp every motif index in `sn` to the valid range [1, 120].

    Replaces the original element-wise Python loop with a single vectorized
    ``np.clip``.  ``out=sn`` keeps the original semantics: the input array is
    mutated in place and the same object is returned.
    """
    return np.clip(sn, 1, 120, out=sn)
# def convert(bit_string):
# # convert network bit-string (norm_cell + redu_cell) to genome
# norm_gene = convert_cell(bit_string[:len(bit_string)//2])
# redu_gene = convert_cell(bit_string[len(bit_string)//2:])
# return [norm_gene, redu_gene]
def shuffle_along_axis(a, axis):
    """Independently permute the entries of `a` along the given axis."""
    random_keys = np.random.rand(*a.shape)
    order = random_keys.argsort(axis=axis)
    return np.take_along_axis(a, order, axis=axis)
def sample(pops, layers, bits):
    """Draw a random initial population.

    Returns (genome, sn, bigmotifs, ms, glob): the flattened per-individual
    genomes, the shuffled motif-id columns, the random binary motif bits, the
    motif count, and the random layer-connection bits.
    """
    # one shuffled copy of the base motif ids per (individual, layer)
    sn = np.tile(mids, pops * layers).reshape(-1, ms)
    sn = shuffle_along_axis(sn, axis=1)
    sn = sn.reshape(-1, 1)
    # random binary genes for the remaining bits of every layer block
    bigmotifs = (np.random.rand(pops * layers, bits - 1) < 0.5).astype(int).reshape(-1, bits - 1)
    sn = sn.reshape(pops * layers, -1)
    genome = np.concatenate((sn, bigmotifs), axis=1).reshape(pops, -1)
    # random layer-to-layer connection matrix, one per individual
    glob = (np.random.rand(pops, layers * layers) < 0.5).astype(int).reshape(pops, layers, layers)
    genome = np.concatenate((genome, glob.reshape(pops, layers * layers)), axis=1).reshape(pops, -1)
    return genome, sn, bigmotifs, ms, glob.reshape(pops, layers * layers)
def reencode(sn, bigmotifs, pops):
    """Compress motif-id permutations into single numbers and append them to the bits."""
    motif_rows = sn.reshape(-1, ms)
    # look up the canonical number of each id-permutation
    numbers = np.array([motifdict[tuple(row)] for row in motif_rows]).reshape(pops, -1)
    return np.concatenate((bigmotifs.reshape(pops, -1), numbers), axis=1).reshape(pops, -1)
def convert(mgenome, layers, bits):
    """Expand compact genomes back to the full per-motif representation.

    The trailing `layers` entries of each row are permutation numbers; they
    are clamped to the valid range and expanded via motifdict_c.  A constant
    column of 2s is appended.
    """
    bigmotifs = mgenome[:, 0:-layers].reshape(-1, bits - 1)
    ids = filt(mgenome[:, -layers:].reshape(-1))
    expanded = np.array([list(motifdict_c[key]) for key in ids]).reshape(-1, 1)
    full = np.concatenate((expanded, np.repeat(bigmotifs, ms, axis=0)), axis=1)
    full = full.reshape(-1, layers * bits * ms)
    return np.c_[full, np.ones((full.shape[0], 1)) * 2]
def convert_single(mgenome, layers, bits):
    """Expand one compact genome vector to its full per-motif representation."""
    bigmotifs = mgenome[0:-layers].reshape(-1, bits - 1)
    ids = filt(mgenome[-layers:].reshape(-1))
    expanded = np.array([list(motifdict_c[key]) for key in ids]).reshape(-1, 1)
    flat = list(np.concatenate((expanded, np.repeat(bigmotifs, ms, axis=0)),
                               axis=1).reshape(layers * bits * ms))
    return np.array(flat)
def c_single(mgenome, layers, bits):
    """Split a flat genome into per-layer op entries plus the connection matrix.

    Returns a Python list of scalar genes whose final element is the
    (layers x layers) connection-matrix array.
    """
    glob = mgenome[-layers * layers:]
    per_layer = mgenome[:-layers * layers].reshape(layers, -1)
    sn = per_layer[:, 0:ms]
    bigmotifs = per_layer[:, ms:]
    entries = np.concatenate(
        (sn.reshape(-1, 1), np.repeat(bigmotifs, ms, axis=0)), axis=1
    ).reshape(-1, ).tolist()
    entries.append(glob.reshape(layers, layers))
    return entries
# def decode_cell(genome, norm=True):
# cell, cell_concat = [], list(range(2, len(genome)+2))
# for block in genome:
# for unit in block:
# cell.append((PRIMITIVES[unit[0]], unit[1]))
# if unit[1] in cell_concat:
# cell_concat.remove(unit[1])
# if norm:
# return Genotype_norm(normal=cell, normal_concat=cell_concat)
# else:
# return Genotype_redu(reduce=cell, reduce_concat=cell_concat)
def decode(genome):
    """Decode a [normal_cell, reduce_cell] genome into a genotype record.

    Bug fix: the module-level ``Genotype`` namedtuple only has the fields
    ``normal``/``normal_concat`` (the four-field variant is commented out at
    the top of the file), so the original
    ``Genotype(..., reduce=..., reduce_concat=...)`` call always raised
    ``TypeError``.  A local four-field namedtuple restores the intended
    return value without disturbing the two-field ``Genotype`` used by the
    motif code.
    """
    full_genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
    normal_cell = genome[0]
    reduce_cell = genome[1]
    normal, normal_concat = [], list(range(2, len(normal_cell) + 2))
    reduce, reduce_concat = [], list(range(2, len(reduce_cell) + 2))
    for block in normal_cell:
        for unit in block:
            normal.append((PRIMITIVES[unit[0]], unit[1]))
            # a node that feeds another block is no longer an output node
            if unit[1] in normal_concat:
                normal_concat.remove(unit[1])
    for block in reduce_cell:
        for unit in block:
            reduce.append((PRIMITIVES[unit[0]], unit[1]))
            if unit[1] in reduce_concat:
                reduce_concat.remove(unit[1])
    return full_genotype(
        normal=normal, normal_concat=normal_concat,
        reduce=reduce, reduce_concat=reduce_concat
    )
def decode_motif(layers, bits, genome):
    """Decode a flat genome into motif Genotypes plus their template ids.

    Each ``bits``-wide slice of ``genome`` describes one layer: the first
    entry selects a motif template (``motifs.mm<k>``) and the remaining
    entries index concrete operations for the template's editable slots.
    """
    motif_list = []
    motif_ids = []
    for b in range(0, layers * bits, bits):
        motif_id = 'mm' + str(genome[b])
        motif_ids.append(genome[b])
        # look up the template genotype by name in the motifs module
        # (eval on an internally generated name, not user input)
        normalcell = eval('motifs.%s' % motif_id)
        newnormal = []
        for i in range(0, len(normalcell.normal)):
            op = normalcell.normal[i]
            if 'skip' in op[0]:
                # plain skip connections are kept verbatim
                newnormal.append(op)
                continue
            elif 'back' in op[0]:
                # backward (feedback) ops draw from OPERATIONS_back
                # NOTE(review): the gene index ignores `i` (always the slot after
                # the forward entries) - confirm this is intended
                newnormal.append((OPERATIONS_back[genome[b + 1 + len(normalcell.normal) - 1]], op[1]))
                continue
            elif '_n' in op[0]:
                # negative-branch op selected by the i-th gene of this block
                newnormal.append((OPERATIONS_n[genome[b + 1 + i]], op[1]))
                continue
            elif '_p' in op[0]:
                # positive-branch op selected by the i-th gene of this block
                newnormal.append((OPERATIONS_p[genome[b + 1 + i]], op[1]))
                continue
        m = Genotype(normal=newnormal, normal_concat=normalcell.normal_concat, )
        motif_list.append(m)
    return motif_list, motif_ids
def compare_cell(cell_string1, cell_string2):
    """Return True if the two cell bit-strings encode the same multiset of blocks.

    Two blocks match when equal as-is or with their two units swapped.  The
    small ``convert_cell`` helper is inlined here.
    """
    def as_blocks(bits):
        units = [bits[i:i + 2] for i in range(0, len(bits), 2)]
        return [units[i:i + 2] for i in range(0, len(units), 2)]

    remaining = as_blocks(cell_string2)[:]
    for block in as_blocks(cell_string1):
        for candidate in remaining:
            if block == candidate or block == candidate[::-1]:
                remaining.remove(candidate)
                break
    return len(remaining) == 0
def compare(string1, string2):
    """True when both halves (normal and reduce cells) of the strings match."""
    half1 = len(string1) // 2
    half2 = len(string2) // 2
    if not compare_cell(string1[:half1], string2[:half2]):
        return False
    if not compare_cell(string1[half1:], string2[half2:]):
        return False
    return True
# def debug():
# # design to debug the encoding scheme
# seed = 0
# np.random.seed(seed)
# budget = 2000
# B, n_ops, n_cell = 5, 7, 2
# networks = []
# design_id = 1
# while len(networks) < budget:
# bit_string = []
# for c in range(n_cell):
# for b in range(B):
# bit_string += [np.random.randint(n_ops),
# np.random.randint(b + 2),
# np.random.randint(n_ops),
# np.random.randint(b + 2)
# ]
# genome = convert(bit_string)
# # check against evaluated networks in case of duplicates
# doTrain = True
# for network in networks:
# if compare(genome, network):
# doTrain = False
# break
# if doTrain:
# genotype = decode(genome)
# model = Network(16, 10, 8, False, genotype)
# model.drop_path_prob = 0.0
# data = torch.randn(1, 3, 32, 32)
# output, output_aux = model(torch.autograd.Variable(data))
# networks.append(genome)
# design_id += 1
# print(design_id)
if __name__ == "__main__":
    # smoke test of the sampling -> re-encoding -> conversion round-trip
    cell_bit_string = [3, 0, 3, 1, 3, 0, 3, 1, 3, 1, 2, 0, 2, 0, 5, 2]
    # print(decode_cell(convert_cell(cell_bit_string), norm=False))
    # Bug fix: sample() returns FIVE values (genome, sn, bigmotifs, ms, glob);
    # the original four-name unpacking raised ValueError at runtime.
    genome, sn, bigmotifs, ms, glob = sample(195, 2, 20)
    mgeno = reencode(sn, bigmotifs, 195)
    convert(mgeno, 2, 20)
================================================
FILE: examples/Structure_Evolution/EB-NAS/motifs.py
================================================
from collections import namedtuple
import torch
# Two-field genotype: (op_name, input_node_index) pairs plus the concat range.
Genotype = namedtuple('Genotype', 'normal normal_concat')
"""
Operation sets
"""
# Operation names referenced by the motif genotypes below; commented entries
# are alternatives that were explored but are currently disabled.
PRIMITIVES = [
    'skip_connect',
    # 'max_pool_3x3',
    # 'avg_pool_3x3',
    # 'def_conv_3x3',
    # 'def_conv_5x5',
    # 'sep_conv_3x3',
    # 'sep_conv_5x5',
    # 'dil_conv_3x3',
    # 'dil_conv_5x5',
    # 'max_pool_3x3_p',
    # 'avg_pool_3x3_p',
    'conv_3x3_p',
    'conv_5x5_p',
    # 'skip_connect_p',
    # 'sep_conv_3x3_p',
    # 'sep_conv_5x5_p',
    # 'dil_conv_3x3_p',
    # 'dil_conv_5x5_p',
    # 'def_conv_3x3_p',
    # 'def_conv_5x5_p',
    # 'max_pool_3x3_n',
    # 'avg_pool_3x3_n',
    'conv_3x3_n',
    'conv_5x5_n',
    # 'skip_connect_n',
    # 'sep_conv_3x3_n',
    # 'sep_conv_5x5_n',
    # 'dil_conv_3x3_n',
    # 'dil_conv_5x5_n',
    # 'def_conv_3x3_n',
    # 'def_conv_5x5_n',
    # 'transformer',
]
# Hand-designed motif genotypes.
# Each `normal` entry is an (operation_name, input_node_index) pair; the
# trailing `#k` comments label the intermediate node each group produces.
# `normal_concat` names the nodes whose outputs are concatenated.
# `*_back` entries denote backward (feedback) connections.
m0 = Genotype(
    normal=[
        ('skip', 0), ('skip', 1), ('skip', 2),
    ],
    normal_concat=range(3, 4)
)
mm0 = Genotype(
    normal=[
        ('skip', 0), ('skip', 1), ('skip', 2),
    ],
    normal_concat=range(2, 3)
)
mm1 = Genotype(
    normal=[
        ('conv_3x3_p', 0), ('conv_5x5_p', 1),
        ('skip_connect', 0), ('conv_5x5_p', 2),
    ],
    normal_concat=range(2, 4)
)
mm2 = Genotype(
    normal=[
        ('conv_5x5_p', 0), ('conv_5x5_p', 1),
        ('skip_connect', 0), ('conv_5x5_n', 2),
        ('conv_5x5_p', 2), ('conv_3x3_n', 3),
    ],
    normal_concat=range(2, 5)
)
mm4 = Genotype(
    normal=[
        ('conv_3x3_p', 0), ('conv_5x5_p', 1),  # 2
        ('conv_3x3_p', 0), ('conv_3x3_p', 1),  # 3
        ('conv_5x5_p', 2), ('conv_5x5_p', 3),  # 4
        ('skip_connect', 0), ('conv_3x3_p', 4),  # 5
        ('skip_connect', 0), ('conv_3x3_p', 4),  # 6
    ],
    normal_concat=range(2, 7)
)
mm3 = Genotype(
    normal=[
        ('conv_3x3_p', 0), ('conv_5x5_p', 1),  # 2
        ('skip_connect', 0), ('conv_5x5_n', 2),  # 3
        ('skip_connect', 0), ('conv_5x5_p', 3),  # 4
        ('skip_connect_back', 2),  # 3
        ('conv_3x3_p_back', 3),  # 4
    ],
    normal_concat=range(2, 5)
)
mm5 = Genotype(
    normal=[
        ('conv_3x3_p', 0), ('conv_5x5_p', 1),  # 2
        ('skip_connect', 0), ('conv_5x5_p', 2),  # 3
        ('skip_connect_back', 2),  # 3
    ],
    normal_concat=range(2, 4)
)
m1 = Genotype(
    normal=[
        ('conv_3x3_p', 0), ('conv_5x5_p', 1), ('conv_5x5_p', 2),  # B3
        ('skip', 0), ('conv_5x5_p', 3), ('skip', 1),  # C4
    ],
    normal_concat=range(3, 5)
)
m2 = Genotype(
    normal=[
        ('conv_5x5_p', 0), ('conv_5x5_p', 1), ('conv_5x5_p', 2),  # B3
        ('skip', 0), ('conv_5x5_n', 3), ('skip', 1),  # C4
        ('conv_5x5_p', 3), ('conv_3x3_n', 4), ('skip', 1),  # D5
    ],
    normal_concat=range(3, 6)
)
m4 = Genotype(
    normal=[
        ('conv_3x3_p', 0), ('conv_5x5_p', 1), ('conv_5x5_p', 2),  # 3
        ('conv_3x3_p', 0), ('conv_3x3_p', 1), ('conv_5x5_p', 2),  # 4
        ('skip', 0), ('conv_5x5_p', 3), ('conv_5x5_p', 4),  # 5
        ('skip', 0), ('conv_3x3_p', 3), ('conv_3x3_n', 5),  # 6
        ('skip', 0), ('conv_3x3_p', 4), ('conv_3x3_n', 5),  # 7
    ],
    normal_concat=range(3, 8)
)
m3 = Genotype(
    normal=[
        ('conv_3x3_p', 0), ('conv_5x5_p', 1), ('conv_3x3_p', 2),  # 3
        ('skip', 0), ('conv_5x5_p', 3), ('skip', 1),  # 4
        ('skip', 0), ('conv_5x5_p', 3), ('skip', 1),  # 5
        ('conv_3x3_n_back', 3),  # 4
        ('skip_back', 2),  # 5
    ],
    normal_concat=range(3, 6)
)
m5 = Genotype(
    normal=[
        ('conv_3x3_p', 0), ('conv_5x5_p', 1), ('conv_5x5_p', 2),  # 3
        ('skip', 0), ('skip', 1), ('conv_5x5_n', 3),  # 4
        ('skip_connect_back', 3),  # 4
    ],
    normal_concat=range(3, 5)
)
t1 = Genotype(
    normal=[
        ('conv_3x3_p', 0), ('conv_5x5_p', 1), ('conv_5x5_p', 2), ('conv_5x5_p', 3),  # 4
        ('skip', 0), ('conv_5x5_p', 4), ('skip', 1), ('skip', 2),  # 5
        ('skip', 0), ('conv_5x5_p', 5), ('skip', 1), ('skip', 2),  # 6
        ('skip', 0), ('conv_5x5_p', 5), ('skip', 1), ('skip', 2),  # 7
    ],
    normal_concat=range(4, 8)
)
t2 = Genotype(
    normal=[
        ('conv_5x5_p', 0), ('conv_5x5_p', 1), ('conv_5x5_p', 2), ('conv_5x5_p', 3),  # 4
        ('skip', 0), ('conv_5x5_n', 4), ('skip', 1), ('skip', 2),  # 5
        ('conv_5x5_p', 4), ('conv_3x3_n', 5), ('skip', 1), ('skip', 2),  # 6
    ],
    normal_concat=range(4, 7)
)
t4 = Genotype(
    normal=[
        ('conv_3x3_p', 0), ('conv_5x5_p', 1), ('conv_5x5_p', 2), ('conv_5x5_p', 3),  # 4
        ('conv_5x5_p', 0), ('skip', 1), ('conv_5x5_n', 4), ('skip', 3),  # 5
        ('skip', 0), ('conv_5x5_p', 3), ('conv_5x5_n', 4), ('skip', 2),  # 6
    ],
    normal_concat=range(4, 7)
)
t3 = Genotype(
    normal=[
        ('conv_3x3_p', 0), ('conv_5x5_p', 1), ('conv_3x3_p', 2), ('conv_3x3_p', 3),  # 4
        ('skip', 0), ('skip', 2), ('skip', 1), ('skip', 3), ('conv_3x3_p', 4),  # 5
        ('skip', 0), ('conv_5x5_p', 4), ('skip', 1), ('skip', 2),  # 6
        ('conv_3x3_n_back', 4),  # 5
        ('skip_back', 4),  # 6
    ],
    normal_concat=range(4, 7)
)
t5 = Genotype(
    normal=[
        ('conv_3x3_p', 0), ('conv_5x5_p', 1), ('conv_5x5_p', 2), ('conv_5x5_p', 3),  # 4
        ('skip', 0), ('skip', 1), ('skip', 2), ('conv_5x5_n', 4),  # 5
        ('skip', 0), ('skip', 1), ('conv_5x5_n', 4), ('conv_5x5_n', 5),  # 6
        ('skip', 0), ('skip', 1), ('conv_5x5_n', 4), ('conv_5x5_n', 5),  # 7
        ('conv_3x3_n_back', 4),  # 5
        ('skip_back', 4),  # 6
        ('skip_back', 5),  # 7
    ],
    normal_concat=range(4, 8)
)
================================================
FILE: examples/Structure_Evolution/EB-NAS/nsganet.py
================================================
import numpy as np
from pymoo.algorithms.genetic_algorithm import GeneticAlgorithm
from pymoo.docs import parse_doc_string
from pymoo.model.individual import Individual
from pymoo.model.survival import Survival
from pymoo.operators.crossover.point_crossover import PointCrossover
from pymoo.operators.mutation.polynomial_mutation import PolynomialMutation
from pymoo.operators.sampling.random_sampling import RandomSampling
from pymoo.operators.selection.tournament_selection import compare, TournamentSelection
from pymoo.util.display import disp_multi_objective
from pymoo.util.dominator import Dominator
from pymoo.util.non_dominated_sorting import NonDominatedSorting
from pymoo.util.randomized_argsort import randomized_argsort
# =========================================================================================================
# Implementation
# based on nsga2 from https://github.com/msu-coinlab/pymoo
# =========================================================================================================
class NSGANet(GeneticAlgorithm):
    """NSGA-Net: a genetic algorithm using NSGA-II style selection/survival."""

    def __init__(self, **kwargs):
        # every individual carries a Pareto rank and a crowding distance
        kwargs['individual'] = Individual(rank=np.inf, crowding=-1)
        super().__init__(**kwargs)
        # comparison strategy consumed by binary_tournament()
        self.tournament_type = 'comp_by_dom_and_crowding'
        self.func_display_attrs = disp_multi_objective
# ---------------------------------------------------------------------------------------------------------
# Binary Tournament Selection Function
# ---------------------------------------------------------------------------------------------------------
def binary_tournament(pop, P, algorithm, **kwargs):
    """Binary tournament selection for NSGA-Net.

    Infeasible solutions compare by constraint violation; feasible ones by
    Pareto dominance (or rank) and, as a tie-breaker, by crowding distance.

    Bug fix: ``np.int`` was removed in NumPy 1.24, so the final
    ``astype(np.int)`` raised AttributeError; the builtin ``int`` (which
    ``np.int`` aliased) is used instead.
    """
    if P.shape[1] != 2:
        raise ValueError("Only implemented for binary tournament!")
    tournament_type = algorithm.tournament_type
    S = np.full(P.shape[0], np.nan)
    for i in range(P.shape[0]):
        a, b = P[i, 0], P[i, 1]
        # if at least one solution is infeasible
        if pop[a].CV > 0.0 or pop[b].CV > 0.0:
            S[i] = compare(a, pop[a].CV, b, pop[b].CV, method='smaller_is_better', return_random_if_equal=True)
        # both solutions are feasible
        else:
            if tournament_type == 'comp_by_dom_and_crowding':
                rel = Dominator.get_relation(pop[a].F, pop[b].F)
                if rel == 1:
                    S[i] = a
                elif rel == -1:
                    S[i] = b
            elif tournament_type == 'comp_by_rank_and_crowding':
                S[i] = compare(a, pop[a].rank, b, pop[b].rank,
                               method='smaller_is_better')
            else:
                raise Exception("Unknown tournament type.")
            # if rank or domination relation didn't make a decision compare by crowding
            if np.isnan(S[i]):
                S[i] = compare(a, pop[a].get("crowding"), b, pop[b].get("crowding"),
                               method='larger_is_better', return_random_if_equal=True)
    return S[:, None].astype(int)
# ---------------------------------------------------------------------------------------------------------
# Survival Selection
# ---------------------------------------------------------------------------------------------------------
class RankAndCrowdingSurvival(Survival):
    """NSGA-II survival: non-dominated rank first, crowding distance second."""

    def __init__(self) -> None:
        super().__init__(True)

    def _do(self, pop, n_survive, D=None, **kwargs):
        objectives = pop.get("F")
        survivors = []
        # non-dominated sorting, stopping once enough individuals are ranked
        fronts = NonDominatedSorting().do(objectives, n_stop_if_ranked=n_survive)
        for rank, front in enumerate(fronts):
            crowding_of_front = calc_crowding_distance(objectives[front, :])
            # store rank and crowding on each individual for later selection
            for pos, idx in enumerate(front):
                pop[idx].set("rank", rank)
                pop[idx].set("crowding", crowding_of_front[pos])
            if len(survivors) + len(front) > n_survive:
                # splitting front: keep the least crowded individuals
                keep = randomized_argsort(crowding_of_front, order='descending', method='numpy')
                keep = keep[:(n_survive - len(survivors))]
            else:
                # the whole front fits
                keep = np.arange(len(front))
            survivors.extend(front[keep])
        return pop[survivors]
def calc_crowding_distance(F):
    """NSGA-II crowding distance for every row of the objective matrix F."""
    INF = 1e+14
    n_points, n_obj = F.shape[0], F.shape[1]
    if n_points <= 2:
        # fronts of one or two points consist only of extremes
        return np.full(n_points, INF)
    else:
        # stable per-objective ordering, then materialise the sorted values
        order = np.argsort(F, axis=0, kind='mergesort')
        F = F[order, np.arange(n_obj)]
        # neighbour gaps, padded with +/- inf at the extremes
        gaps = np.concatenate([F, np.full((1, n_obj), np.inf)]) \
            - np.concatenate([np.full((1, n_obj), -np.inf), F])
        # duplicate values inherit the gap of their nearest distinct neighbour
        zero_at = np.where(gaps == 0)
        to_last = np.copy(gaps)
        for r, c in zip(*zero_at):
            to_last[r, c] = to_last[r - 1, c]
        to_next = np.copy(gaps)
        for r, c in reversed(list(zip(*zero_at))):
            to_next[r, c] = to_next[r + 1, c]
        # normalise by the per-objective range (nan marks a constant column)
        span = np.max(F, axis=0) - np.min(F, axis=0)
        span[span == 0] = np.nan
        to_last, to_next = to_last[:-1] / span, to_next[1:] / span
        to_last[np.isnan(to_last)] = 0.0
        to_next[np.isnan(to_next)] = 0.0
        # undo the sort and average the two half-distances over the objectives
        inverse = np.argsort(order, axis=0)
        crowding = np.sum(to_last[inverse, np.arange(n_obj)]
                          + to_next[inverse, np.arange(n_obj)], axis=1) / n_obj
    # extremes get a large finite number instead of infinity
    crowding[np.isinf(crowding)] = INF
    return crowding
# =========================================================================================================
# Interface
# =========================================================================================================
def nsganet(
        pop_size=100,
        sampling=RandomSampling(var_type=int),
        selection=TournamentSelection(func_comp=binary_tournament),
        crossover=PointCrossover(n_points=2),
        mutation=PolynomialMutation(eta=3, var_type=int),
        eliminate_duplicates=True,
        n_offsprings=None,
        **kwargs):
    """

    Parameters
    ----------
    pop_size : {pop_size}
    sampling : {sampling}
    selection : {selection}
    crossover : {crossover}
    mutation : {mutation}
    eliminate_duplicates : {eliminate_duplicates}
    n_offsprings : {n_offsprings}

    Returns
    -------
    nsganet : :class:`~pymoo.model.algorithm.Algorithm`
        Returns an NSGANet algorithm object.

    """
    # Bug fix: np.int (an alias of the builtin int) was removed in NumPy 1.24,
    # which made importing this module crash while evaluating the default
    # arguments above; `int` is value-identical and supported.
    return NSGANet(pop_size=pop_size,
                   sampling=sampling,
                   selection=selection,
                   crossover=crossover,
                   mutation=mutation,
                   survival=RankAndCrowdingSurvival(),
                   eliminate_duplicates=eliminate_duplicates,
                   n_offsprings=n_offsprings,
                   **kwargs)


parse_doc_string(nsganet)
================================================
FILE: examples/Structure_Evolution/EB-NAS/operations.py
================================================
import numpy as np
import torch
import torch.nn as nn
from torch.nn import *
import torch.nn.functional as F
from torch import einsum
from einops import rearrange
# from braincog.model_zoo.base_module import DeformConvPack
from braincog.model_zoo.base_module import BaseLinearModule
# from mmcv.ops import ModulatedDeformConv2dPack
def si_relu(x, positive):
    """Signed ReLU: keep the positive part (positive=1), the negative part
    (positive=-1), or pass the tensor through unchanged (positive=0).

    Raises ValueError for any other mode.
    """
    if positive == 0:
        return x
    if positive == 1:
        return torch.where(x > 0., x, torch.zeros_like(x))
    if positive == -1:
        return torch.where(x < 0., x, torch.zeros_like(x))
    raise ValueError
class SiReLU(nn.Module):
    """Module wrapper around :func:`si_relu` with a fixed sign mode."""

    def __init__(self, positive=0):
        super().__init__()
        # sign mode: 1 keeps positives, -1 keeps negatives, 0 is identity
        self.positive = positive

    def forward(self, x):
        return si_relu(x, self.positive)
def weight_init(m):
    """Xavier-initialise Conv2d weights (gain 0.1) and zero their biases.

    Fixes two issues in the original:

    * ``torch.nn.init.xavier_normal`` / ``constant`` are the deprecated
      non-in-place aliases; the supported in-place ``xavier_normal_`` /
      ``constant_`` variants are used instead.
    * convolutions created with ``bias=False`` (most convs in this file)
      have ``m.bias is None``; the original unconditionally dereferenced
      ``m.bias.data`` and crashed with AttributeError.
    """
    if isinstance(m, nn.Conv2d):
        torch.nn.init.xavier_normal_(m.weight.data, gain=0.1)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias.data, 0.)
# MLP-style (1-D) operation factories: name -> fn(C, act_fun) -> nn.Module.
OPS_Mlp = {
    'mlp': lambda C, act_fun:
        SiMLP(C, C, act_fun=act_fun, positive=0),
    'mlp_p': lambda C, act_fun:
        SiMLP(C, C, act_fun=act_fun, positive=1),
    'mlp_n': lambda C, act_fun:
        SiMLP(C, C, act_fun=act_fun, positive=-1),
    'skip_connect': lambda C, act_fun:
        Identity(positive=0),
    'skip_connect_p': lambda C, act_fun:
        Identity(positive=1),
    'skip_connect_n': lambda C, act_fun:
        Identity(positive=-1),
}

# Convolutional operation factories: name -> fn(C, stride, affine, act_fun).
# Suffix convention: '_p' keeps only the positive output part, '_n' only the
# negative part (see si_relu); no suffix passes the output through unchanged.
OPS = {
    'avg_pool_3x3': lambda C, stride, affine, act_fun: nn.AvgPool2d(3, stride=stride, padding=1,
                                                                    count_include_pad=False),
    'conv_3x3': lambda C, stride, affine, act_fun:
        ReLUConvBN(C_in=C, C_out=C, kernel_size=3, padding=1, stride=stride, affine=affine, act_fun=act_fun, positive=0),
    'conv_5x5': lambda C, stride, affine, act_fun:
        ReLUConvBN(C_in=C, C_out=C, kernel_size=5, padding=2, stride=stride, affine=affine, act_fun=act_fun, positive=0),
    'max_pool_3x3': lambda C, stride, affine, act_fun: nn.MaxPool2d(3, stride=stride, padding=1),
    'skip_connect': lambda C, stride, affine, act_fun:
        Identity(positive=0) if stride == 1 else FactorizedReduce(C, C, affine=affine, act_fun=act_fun),
    'sep_conv_3x3': lambda C, stride, affine, act_fun:
        SepConv(C, C, 3, stride, 1, affine=affine, act_fun=act_fun, positive=0),
    'sep_conv_5x5': lambda C, stride, affine, act_fun:
        SepConv(C, C, 5, stride, 2, affine=affine, act_fun=act_fun, positive=0),
    'sep_conv_7x7': lambda C, stride, affine, act_fun:
        SepConv(C, C, 7, stride, 3, affine=affine, act_fun=act_fun, positive=0),
    'dil_conv_3x3': lambda C, stride, affine, act_fun:
        DilConv(C, C, 3, stride, 2, 2, affine=affine, act_fun=act_fun, positive=0),
    'dil_conv_5x5': lambda C, stride, affine, act_fun:
        DilConv(C, C, 5, stride, 4, 2, affine=affine, act_fun=act_fun, positive=0),
    'def_conv_3x3': lambda C, stride, affine, act_fun:
        DeformConv(C, C, 3, stride, 1, affine=affine, act_fun=act_fun, positive=0),
    'def_conv_5x5': lambda C, stride, affine, act_fun:
        DeformConv(C, C, 5, stride, 2, affine=affine, act_fun=act_fun, positive=0),
    # positive-branch ('_p') variants
    'avg_pool_3x3_p': lambda C, stride, affine, act_fun: nn.Sequential(
        nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False),
        SiReLU(positive=1)
    ),
    'max_pool_3x3_p': lambda C, stride, affine, act_fun: nn.Sequential(
        nn.MaxPool2d(3, stride=stride, padding=1),
        SiReLU(positive=1)
    ),
    'conv_3x3_p': lambda C, stride, affine, act_fun:
        ReLUConvBN(C_in=C, C_out=C, kernel_size=3, padding=1, stride=stride, affine=affine, act_fun=act_fun, positive=1),
    'conv_5x5_p': lambda C, stride, affine, act_fun:
        ReLUConvBN(C_in=C, C_out=C, kernel_size=5, padding=2, stride=stride, affine=affine, act_fun=act_fun, positive=1),
    'skip_connect_p': lambda C, stride, affine, act_fun:
        Identity(positive=1) if stride == 1 else FactorizedReduce(C, C, affine=affine, act_fun=act_fun, positive=1),
    'sep_conv_3x3_p': lambda C, stride, affine, act_fun:
        SepConv(C, C, 3, stride, 1, affine=affine, act_fun=act_fun, positive=1),
    'sep_conv_5x5_p': lambda C, stride, affine, act_fun:
        SepConv(C, C, 5, stride, 2, affine=affine, act_fun=act_fun, positive=1),
    'sep_conv_7x7_p': lambda C, stride, affine, act_fun:
        SepConv(C, C, 7, stride, 3, affine=affine, act_fun=act_fun, positive=1),
    'dil_conv_3x3_p': lambda C, stride, affine, act_fun:
        DilConv(C, C, 3, stride, 2, 2, affine=affine, act_fun=act_fun, positive=1),
    'dil_conv_5x5_p': lambda C, stride, affine, act_fun:
        DilConv(C, C, 5, stride, 4, 2, affine=affine, act_fun=act_fun, positive=1),
    'def_conv_3x3_p': lambda C, stride, affine, act_fun:
        DeformConv(C, C, 3, stride, 1, affine=affine, act_fun=act_fun, positive=1),
    'def_conv_5x5_p': lambda C, stride, affine, act_fun:
        DeformConv(C, C, 5, stride, 2, affine=affine, act_fun=act_fun, positive=1),
    # negative-branch ('_n') variants
    'avg_pool_3x3_n': lambda C, stride, affine, act_fun: nn.Sequential(
        nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False),
        SiReLU(positive=-1)
    ),
    'max_pool_3x3_n': lambda C, stride, affine, act_fun: nn.Sequential(
        nn.MaxPool2d(3, stride=stride, padding=1),
        SiReLU(positive=-1)
    ),
    'conv_3x3_n': lambda C, stride, affine, act_fun:
        ReLUConvBN(C_in=C, C_out=C, kernel_size=3, padding=1, stride=stride, affine=affine, act_fun=act_fun, positive=-1),
    'conv_5x5_n': lambda C, stride, affine, act_fun:
        ReLUConvBN(C_in=C, C_out=C, kernel_size=5, padding=2, stride=stride, affine=affine, act_fun=act_fun, positive=-1),
    'skip_connect_n': lambda C, stride, affine, act_fun:
        Identity(positive=-1) if stride == 1 else FactorizedReduce(C, C, affine=affine, act_fun=act_fun, positive=-1),
    'sep_conv_3x3_n': lambda C, stride, affine, act_fun:
        SepConv(C, C, 3, stride, 1, affine=affine, act_fun=act_fun, positive=-1),
    'sep_conv_5x5_n': lambda C, stride, affine, act_fun:
        SepConv(C, C, 5, stride, 2, affine=affine, act_fun=act_fun, positive=-1),
    'sep_conv_7x7_n': lambda C, stride, affine, act_fun:
        SepConv(C, C, 7, stride, 3, affine=affine, act_fun=act_fun, positive=-1),
    'dil_conv_3x3_n': lambda C, stride, affine, act_fun:
        DilConv(C, C, 3, stride, 2, 2, affine=affine, act_fun=act_fun, positive=-1),
    'dil_conv_5x5_n': lambda C, stride, affine, act_fun:
        DilConv(C, C, 5, stride, 4, 2, affine=affine, act_fun=act_fun, positive=-1),
    'def_conv_3x3_n': lambda C, stride, affine, act_fun:
        DeformConv(C, C, 3, stride, 1, affine=affine, act_fun=act_fun, positive=-1),
    'def_conv_5x5_n': lambda C, stride, affine, act_fun:
        DeformConv(C, C, 5, stride, 2, affine=affine, act_fun=act_fun, positive=-1),
    'conv_7x1_1x7': lambda C, stride, affine, act_fun: nn.Sequential(
        # nn.ReLU(inplace=False),
        act_fun(),
        nn.Conv2d(C, C, (1, 7), stride=(1, stride),
                  padding=(0, 3), bias=False),
        nn.Conv2d(C, C, (7, 1), stride=(stride, 1),
                  padding=(3, 0), bias=False),
        nn.BatchNorm2d(C, affine=affine)
    ),
    # 'skip' zeroes its input (Zero) at stride 1; otherwise reduces spatially
    'skip': lambda C, stride, affine, act_fun:
        Zero(stride) if stride == 1 else FactorizedReduce(C, C, affine=affine, act_fun=act_fun, positive=1),
    'transformer': lambda C, stride, affine, act_fun:
        FactorizedReduce(
            C, C, affine=affine, act_fun=act_fun) if stride != 1 else TransformerEncoderLayer(C),
}
class SiMLP(nn.Module):
    """Linear layer + activation applied after a signed-ReLU gate on the input."""

    def __init__(self, c_in, c_out, act_fun=nn.ReLU, positive=0, *args, **kwargs):
        super(SiMLP, self).__init__()
        self.op = nn.Sequential(
            nn.Linear(c_in, c_out, bias=True),
            act_fun()
        )
        # sign mode applied to the INPUT before the linear layer
        self.positive = positive

    def forward(self, x):
        gated = si_relu(x, self.positive)
        return self.op(gated)
class DilConv(nn.Module):
    """Dilated depthwise-separable convolution op.

    Pipeline: act_fun -> depthwise dilated Conv2d -> pointwise 1x1 Conv2d ->
    BatchNorm2d, with ``si_relu`` applied to the output using ``positive``.
    """

    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True, act_fun=nn.ReLU, positive=0):
        super(DilConv, self).__init__()
        # groups=C_in makes the first conv depthwise.
        depthwise = nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride,
                              padding=padding, dilation=dilation, groups=C_in, bias=False)
        pointwise = nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False)
        self.op = nn.Sequential(
            act_fun(),
            depthwise,
            pointwise,
            nn.BatchNorm2d(C_out, affine=affine),
        )
        self.positive = positive

    def forward(self, x):
        return si_relu(self.op(x), self.positive)
class SepConv(nn.Module):
    """Separable convolution op applied twice (DARTS-style).

    Pipeline: act_fun -> depthwise Conv2d (stride) -> 1x1 Conv2d -> BN ->
    act_fun -> depthwise Conv2d (stride 1) -> 1x1 Conv2d -> BN, followed by
    ``si_relu`` on the output with the configured sign.
    """

    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True, act_fun=nn.ReLU, positive=0):
        super(SepConv, self).__init__()
        self.op = nn.Sequential(
            act_fun(),
            nn.Conv2d(C_in, C_in, kernel_size=kernel_size,
                      stride=stride, padding=padding, groups=C_in, bias=False),
            nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_in, affine=affine),
            # BUGFIX: was a hard-coded nn.ReLU(inplace=False). Every other op in
            # this file (including the first stage above and DilConv) routes
            # activations through act_fun; with a spiking act_fun the plain ReLU
            # left here from the DARTS original broke that contract.
            act_fun(),
            nn.Conv2d(C_in, C_in, kernel_size=kernel_size,
                      stride=1, padding=padding, groups=C_in, bias=False),
            nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_out, affine=affine),
        )
        self.positive = positive

    def forward(self, x):
        out = self.op(x)
        return si_relu(out, self.positive)
class Identity(nn.Module):
    """Pass-through op; the only transformation is ``si_relu`` with ``positive``."""

    def __init__(self, positive=0):
        super(Identity, self).__init__()
        self.positive = positive

    def forward(self, x):
        # No learnable parameters; sign-indicated rectification only.
        return si_relu(x, self.positive)
class Zero(nn.Module):
    """Op that emits all-zeros; with stride > 1 it also subsamples spatially."""

    def __init__(self, stride):
        super(Zero, self).__init__()
        self.stride = stride

    def forward(self, x):
        # Input layout is N * C * H * W; strided slicing downsamples before zeroing.
        if self.stride != 1:
            x = x[:, :, ::self.stride, ::self.stride]
        return x.mul(0.)
class FactorizedReduce(nn.Module):
    """Halve spatial resolution with two pixel-offset stride-2 convs.

    The two convs each produce C_out // 2 channels; their outputs are
    concatenated, batch-normed, and passed through ``si_relu``.
    """

    def __init__(self, C_in, C_out, affine=True, act_fun=nn.ReLU, positive=0):
        super(FactorizedReduce, self).__init__()
        assert C_out % 2 == 0
        self.activation = act_fun()
        self.conv_1 = nn.Conv2d(C_in, C_out // 2, 3, stride=2, padding=1, bias=False)
        self.conv_2 = nn.Conv2d(C_in, C_out // 2, 3, stride=2, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(C_out, affine=affine)
        self.positive = positive

    def forward(self, x):
        y = self.activation(x)
        # Second branch sees the input shifted by one pixel in H and W so the
        # pair of stride-2 convs together covers all spatial positions.
        halves = [self.conv_1(y), self.conv_2(y[:, :, 1:, 1:])]
        return si_relu(self.bn(torch.cat(halves, dim=1)), self.positive)
class F0(nn.Module):
    """FactorizedReduce-style halving plus one extra stride-2 conv (4x total).

    NOTE(review): structurally identical to ``F1``; they appear to be kept as
    separate named entries for the search space.
    """

    def __init__(self, C_in, C_out, affine=True, act_fun=nn.ReLU, positive=0):
        super(F0, self).__init__()
        assert C_out % 2 == 0
        self.activation = act_fun()
        # Final extra downsampling conv applied after the factorized reduce.
        self.op = nn.Conv2d(C_out, C_out, 3, stride=2, padding=1, bias=False)
        self.conv_1 = nn.Conv2d(C_in, C_out // 2, 3, stride=2, padding=1, bias=False)
        self.conv_2 = nn.Conv2d(C_in, C_out // 2, 3, stride=2, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(C_out, affine=affine)
        self.positive = positive

    def forward(self, x):
        y = self.activation(x)
        merged = torch.cat([self.conv_1(y), self.conv_2(y[:, :, 1:, 1:])], dim=1)
        rectified = si_relu(self.bn(merged), self.positive)
        return self.op(rectified)
class F1(nn.Module):
    """FactorizedReduce-style halving plus one extra stride-2 conv (4x total).

    NOTE(review): structurally identical to ``F0``; they appear to be kept as
    separate named entries for the search space.
    """

    def __init__(self, C_in, C_out, affine=True, act_fun=nn.ReLU, positive=0):
        super(F1, self).__init__()
        assert C_out % 2 == 0
        self.activation = act_fun()
        # Final extra downsampling conv applied after the factorized reduce.
        self.op = nn.Conv2d(C_out, C_out, 3, stride=2, padding=1, bias=False)
        self.conv_1 = nn.Conv2d(C_in, C_out // 2, 3, stride=2, padding=1, bias=False)
        self.conv_2 = nn.Conv2d(C_in, C_out // 2, 3, stride=2, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(C_out, affine=affine)
        self.positive = positive

    def forward(self, x):
        y = self.activation(x)
        merged = torch.cat([self.conv_1(y), self.conv_2(y[:, :, 1:, 1:])], dim=1)
        rectified = si_relu(self.bn(merged), self.positive)
        return self.op(rectified)
class ReLUConvBN(nn.Module):
    """Activation -> Conv2d -> BatchNorm2d, then ``si_relu`` on the output."""

    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True, act_fun=nn.ReLU, positive=0):
        super(ReLUConvBN, self).__init__()
        conv = nn.Conv2d(C_in, C_out, kernel_size, stride=stride,
                         padding=padding, bias=False)
        self.op = nn.Sequential(act_fun(), conv, nn.BatchNorm2d(C_out, affine=affine))
        self.positive = positive

    def forward(self, x):
        return si_relu(self.op(x), self.positive)
# class DeformConv(nn.Module):
# def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True, act_fun=nn.ReLU, positive=0):
# super(DeformConv, self).__init__()
# self.op = nn.Sequential(
# # nn.ReLU(inplace=False),
# act_fun(),
# DeformConvPack(C_in, C_out, kernel_size=kernel_size,
# stride=stride, padding=padding, bias=True),
# nn.BatchNorm2d(C_out, affine=affine)
# )
# self.positive = positive
# # if positive == -1:
# # weight_init(self.op)
# def forward(self, x):
# out = self.op(x)
# return si_relu(out, self.positive)
class Attention(Module):
    """Multi-head self-attention over a (B, N, C) token sequence.

    Obtained from: github.com:rwightman/pytorch-image-models
    """

    def __init__(self, dim, num_heads=4, attention_dropout=0.1, projection_dropout=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        # Single projection produces q, k and v in one matmul.
        self.qkv = Linear(dim, dim * 3, bias=False)
        self.attn_drop = Dropout(attention_dropout)
        self.proj = Linear(dim, dim)
        self.proj_drop = Dropout(projection_dropout)

    def forward(self, x):
        B, N, C = x.shape
        head_dim = C // self.num_heads
        # (B, N, 3C) -> (3, B, heads, N, head_dim); split into q, k, v.
        packed = self.qkv(x).reshape(B, N, 3, self.num_heads, head_dim)
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(0)
        weights = (q @ k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(weights.softmax(dim=-1))
        merged = (weights @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(merged))
class TransformerEncoderLayer(Module):
    """
    Pre-norm transformer encoder layer over a b x d x r x c feature map.

    Inspired by torch.nn.TransformerEncoderLayer and
    rwightman's timm package.
    """

    def __init__(self, d_model, nhead=4, dim_feedforward=256, dropout=0.1,
                 attention_dropout=0.1, drop_path_rate=0.1):
        super(TransformerEncoderLayer, self).__init__()
        self.pre_norm = LayerNorm(d_model)
        self.self_attn = Attention(dim=d_model, num_heads=nhead,
                                   attention_dropout=attention_dropout, projection_dropout=dropout)
        # NOTE(review): this overrides the ``dim_feedforward`` argument — the MLP
        # hidden width is always d_model, so the parameter is effectively ignored.
        # Confirm whether that is intentional (e.g. to bound parameter count).
        dim_feedforward = d_model
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout1 = Dropout(dropout)
        self.norm1 = LayerNorm(d_model)
        self.linear2 = Linear(dim_feedforward, d_model)
        self.dropout2 = Dropout(dropout)
        # NOTE(review): with drop_path_rate == 0 this falls back to the *custom*
        # Identity defined in this file (which applies si_relu with positive=0),
        # not nn.Identity — verify that is the intended behavior.
        self.drop_path = DropPath(
            drop_path_rate) if drop_path_rate > 0 else Identity()
        self.activation = F.gelu

    def forward(self, src: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # Flatten the spatial grid (r, c) into a token sequence of length r*c;
        # channels d become the embedding dimension.
        # print(src.shape)
        c = src.shape[-1]
        src = rearrange(src, 'b d r c -> b (r c) d')
        # print(src.shape)
        # Pre-norm self-attention with a stochastic-depth residual connection.
        src = src + self.drop_path(self.self_attn(self.pre_norm(src)))
        src = self.norm1(src)
        # Feed-forward block: Linear -> GELU -> Dropout -> Linear.
        src2 = self.linear2(self.dropout1(self.activation(self.linear1(src))))
        src = src + self.drop_path(self.dropout2(src2))
        # Restore the original b x d x r x c layout (c is the saved width).
        src = rearrange(src, 'b (r c) d -> b d r c', c=c)
        return src
def drop_path(x, drop_prob: float = 0., training: bool = False):
    """
    Obtained from: github.com:rwightman/pytorch-image-models
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    # Inference or zero probability: identity (return the input unchanged).
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample; the (1,)*rest shape broadcasts over all
    # non-batch dimensions so each sample is kept or dropped wholesale.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = torch.rand(mask_shape, dtype=x.dtype, device=x.device).add_(keep_prob).floor_()
    # Rescale survivors by 1/keep_prob so the expected activation is unchanged.
    return x.div(keep_prob) * mask
class DropPath(Module):
    """
    Obtained from: github.com:rwightman/pytorch-image-models
    Module wrapper around :func:`drop_path` (stochastic depth per sample).
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # Only active while self.training is True; identity in eval mode.
        return drop_path(x, self.drop_prob, self.training)
================================================
FILE: examples/Structure_Evolution/EB-NAS/readme.md
================================================
# Brain-Inspired Evolutionary Architectures for Spiking Neural Networks —— Based on BrainCog #
## Requirements ##
* numpy
* pytorch >= 1.12.0
* pymoo = 0.4.0
* BrainCog
## Run ##
```python ebnas.py```
## Citation ##
If you find the code and dataset useful in your research, please consider citing:
```
@article{pan2024brain,
title={Brain-inspired Evolutionary Architectures for Spiking Neural Networks},
author={Pan, Wenxuan and Zhao, Feifei and Zhao, Zhuoya and Zeng, Yi},
journal={IEEE Transactions on Artificial Intelligence},
year={2024},
publisher={IEEE}
}
@article{zeng2023braincog,
title={BrainCog: A spiking neural network based, brain-inspired cognitive intelligence engine for brain-inspired AI and brain simulation},
author={Zeng, Yi and Zhao, Dongcheng and Zhao, Feifei and Shen, Guobin and Dong, Yiting and Lu, Enmeng and Zhang, Qian and Sun, Yinqian and Liang, Qian and Zhao, Yuxuan and others},
journal={Patterns},
volume={4},
number={8},
year={2023},
publisher={Elsevier}
}
```
================================================
FILE: examples/Structure_Evolution/EB-NAS/single_genome.py
================================================
import os
import json
import shutil
import argparse
import subprocess
import numpy as np
import torch
from tm import get_net_info
import micro_encoding
import logging
import yaml
from braincog.utils import save_feature_map, setup_seed
from tm import train_motifs
from datetime import datetime
from timm.utils import *
from pymoo.optimize import minimize
from pymoo.model.problem import Problem
from pymoo.factory import get_performance_indicator
from pymoo.algorithms.so_genetic_algorithm import GA
from pymoo.util.nds.non_dominated_sorting import NonDominatedSorting
from pymoo.factory import get_algorithm, get_crossover, get_mutation
from search_space.ofa import OFASearchSpace
from acc_predictor.factory import get_acc_predictor
_DEBUG = True
# Scatter is only needed for debug visualisation of search results.
if _DEBUG: from pymoo.visualization.scatter import Scatter
# Logical GPU ids used below; CUDA_VISIBLE_DEVICES='3' remaps physical GPU 3 to
# logical id 0, so devices=[0] actually selects physical GPU 3.
devices=[0]
os.environ['CUDA_VISIBLE_DEVICES']='3'
# Root logger; configured later via timm's setup_default_logging in __main__.
_logger = logging.getLogger('')
# Two-stage argparse pattern (timm style): config_parser is intended to
# pre-parse a YAML config; parser holds all SNN training/search options.
config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)
parser = argparse.ArgumentParser(description='SNN Training and Evaluating')
parser.add_argument('--seed', type=int, default=42, metavar='S',
help='random seed (default: 42)')
parser.add_argument('--ocrate', type=float, default=0.0)
parser.add_argument('--dataset', type=str, default='dvsg',
help='imagenet, cifar10, cifar100, dvsg, dvsc10')
parser.add_argument('--step', type=int, default=8, help='Simulation time step (default: 10)')
parser.add_argument('--num-classes', type=int, default=11, metavar='N')
parser.add_argument('--layers', type=int, default=2)
parser.add_argument('--bits', type=int, default=20)
parser.add_argument('--eval_epochs', type=int, default=6000)
parser.add_argument('--trainning_epochs', type=int, default=600)
parser.add_argument('--iterations', type=int, default=20,
help='number of search iterations')
parser.add_argument('--n_doe', type=int, default=100,
help='initial sample size for DOE')
parser.add_argument('--bns', action='store_true', default=True)
parser.add_argument('--init-channels', type=int, default=16)
parser.add_argument('--spike-rate', action='store_true', default=False)
parser.add_argument('--save', type=str, default='',
help='location of dir to save')
parser.add_argument('--sec_obj', type=str, default='flops',
help='second objective to optimize simultaneously')
parser.add_argument('--n_iter', type=int, default=8,
help='number of architectures to high-fidelity eval (low level) in each iteration')
parser.add_argument('--predictor', type=str, default='rbf',
help='which accuracy predictor model to fit (rbf/gp/carts/mlp/as)')
parser.add_argument('--n_gpus', type=int, default=len(devices),
help='total number of available gpus')
parser.add_argument('--gpu', type=int, default=1,
help='number of gpus per evaluation job')
parser.add_argument('--supernet_path', type=str, default='./ofa_mbv3_d234_e346_k357_w1.0',
help='file path to supernet weights')
parser.add_argument('--n_workers', type=int, default=1,
help='number of workers for dataloader per evaluation job')
parser.add_argument('--vld_size', type=int, default=5000,
help='validation set size, randomly sampled from training set')
parser.add_argument('--trn_batch_size', type=int, default=128,
help='train batch size for training')
parser.add_argument('--vld_batch_size', type=int, default=200,
help='test batch size for inference')
parser.add_argument('--cooldown-epochs', type=int, default=0, metavar='N',help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
parser.add_argument('--test', action='store_true', default=False,
help='evaluation performance on testing set')
parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
help='YAML config file specifying default arguments')
parser.add_argument('--model', default='NetworkCIFAR_', type=str, metavar='MODEL',
help='Name of model to train (default: "countception"')
parser.add_argument('--pretrained', action='store_true', default=False,
help='Start with pretrained version of specified network (if avail)')
parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
help='Initialize model from this checkpoint (default: none)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
help='Resume full model and optimizer state from checkpoint (default: none)')
parser.add_argument('--eval_checkpoint', default='', type=str, metavar='PATH',
help='path to eval checkpoint (default: none)')
parser.add_argument('--no-resume-opt', action='store_true', default=False,
help='prevent resume of optimizer state when resuming model')
parser.add_argument('--gp', default=None, type=str, metavar='POOL',
help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
# Dataset parameters for static datasets
parser.add_argument('--img-size', type=int, default=224, metavar='N',
help='Image patch size (default: None => model default)')
parser.add_argument('--crop-pct', default=None, type=float,
metavar='N', help='inputs image center crop percent (for validation only)')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
help='Override std deviation of of dataset')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
help='Image resize interpolation type (overrides model)')
# Dataloader parameters
parser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',
help='inputs batch size for training (default: 128)')
parser.add_argument('-vb', '--validation-batch-size-multiplier', type=int, default=1, metavar='N',
help='ratio of validation batch size to training batch size (default: 1)')
# Optimizer parameters
parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
help='Optimizer (default: "adamw"')
parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
help='Optimizer Epsilon (default: None, use opt default)')
parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
help='Optimizer Betas (default: None, use opt default)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
help='Optimizer momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=0.01,
help='weight decay (default: 0.01 for adamw)')
parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
help='Clip gradient norm (default: None, no clipping)')
parser.add_argument('--adam-epoch', type=int, default=1000, help='lamb switch to adamw')
# Learning rate schedule parameters
parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
help='LR scheduler (default: "cosine"')
parser.add_argument('--lr', type=float, default=5e-3, metavar='LR',
help='learning rate (default: 0.01)')
parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
help='learning rate noise on/off epoch percentages')
parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
help='learning rate noise limit percent (default: 0.67)')
parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
help='learning rate noise std-dev (default: 1.0)')
parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
help='learning rate cycle len multiplier (default: 1.0)')
parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
help='learning rate cycle limit')
parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
help='warmup learning rate (default: 0.0001)')
parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
help='manual epoch number (useful on restarts)')
parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
help='epoch interval to decay LR')
parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
help='epochs to warmup LR, if scheduler supports')
parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
help='patience epochs for Plateau LR scheduler (default: 10')
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
help='LR decay rate (default: 0.1)')
parser.add_argument('--power', type=int, default=1, help='power')
# Augmentation & regularization parameters ONLY FOR IMAGE NET
parser.add_argument('--no-aug', action='store_true', default=False,
help='Disable all training augmentation, override other train aug args')
parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
help='Random resize scale (default: 0.08 1.0)')
parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
help='Random resize aspect ratio (default: 0.75 1.33)')
parser.add_argument('--hflip', type=float, default=0.5,
help='Horizontal flip training aug probability')
parser.add_argument('--vflip', type=float, default=0.,
help='Vertical flip training aug probability')
parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
help='Color jitter factor (default: 0.4)')
parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
help='Use AutoAugment policy. "v0" or "original". (default: None)'),
parser.add_argument('--aug-splits', type=int, default=0,
help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
parser.add_argument('--jsd', action='store_true', default=False,
help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
help='Random erase prob (default: 0.25)')
parser.add_argument('--remode', type=str, default='pixel',
help='Random erase mode (default: "const")')
parser.add_argument('--recount', type=int, default=1,
help='Random erase count (default: 1)')
parser.add_argument('--resplit', action='store_true', default=False,
help='Do not random erase first (clean) augmentation split')
parser.add_argument('--mixup', type=float, default=0.8,
help='mixup alpha, mixup enabled if > 0. (default: 0.)')
parser.add_argument('--cutmix', type=float, default=1.0,
help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')
parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
parser.add_argument('--mixup-prob', type=float, default=1.0,
help='Probability of performing mixup or cutmix when either/both is enabled')
parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
help='Probability of switching to cutmix when both mixup and cutmix enabled')
parser.add_argument('--mixup-mode', type=str, default='batch',
help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
parser.add_argument('--smoothing', type=float, default=0.1,
help='Label smoothing (default: 0.1)')
parser.add_argument('--train-interpolation', type=str, default='random',
help='Training interpolation (random, bilinear, bicubic default: "random")')
parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
help='Dropout rate (default: 0.0)')
parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT',
help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',
help='Drop path rate (default: None)')
parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
help='Drop block rate (default: None)')
parser.add_argument('--newton-maxiter', default=20, type=int,
help='max iterration in newton method')
parser.add_argument('--reset-drop', action='store_true', default=False,
help='whether to reset drop')
parser.add_argument('--kernel-method', type=str, default='cuda', choices=['torch', 'cuda'],
help='The implementation way of gaussian kernel method, choose from "cuda" and "torch"')
# Batch norm parameters (only works with gen_efficientnet based models currently)
parser.add_argument('--bn-tf', action='store_true', default=False,
help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')
parser.add_argument('--bn-momentum', type=float, default=None,
help='BatchNorm momentum override (if not None)')
parser.add_argument('--bn-eps', type=float, default=None,
help='BatchNorm epsilon override (if not None)')
parser.add_argument('--sync-bn', action='store_true',
help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
parser.add_argument('--dist-bn', type=str, default='',
help='Distribute BatchNorm stats between node after each epoch ("broadcast", "reduce", or "")')
parser.add_argument('--split-bn', action='store_true',
help='Enable separate BN layers per augmentation split.')
# Model Exponential Moving Average
parser.add_argument('--model-ema', action='store_true', default=False,
help='Enable tracking moving average of model weights')
parser.add_argument('--model-ema-force-cpu', action='store_true', default=False,
help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
parser.add_argument('--model-ema-decay', type=float, default=0.99996,
help='decay factor for model weights moving average (default: 0.9998)')
# Misc
parser.add_argument('--log-interval', type=int, default=50, metavar='N',
help='how many batches to wait before logging training status')
parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
help='how many batches to wait before writing recovery checkpoint')
parser.add_argument('-j', '--workers', type=int, default=8, metavar='N',
help='how many training processes to use (default: 1)')
parser.add_argument('--num-gpu', type=int, default=len(devices),
help='Number of GPUS to use')
parser.add_argument('--save-images', action='store_true', default=False,
help='save images of inputs bathes every log interval for debugging')
parser.add_argument('--amp', action='store_true', default=False,
help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
parser.add_argument('--apex-amp', action='store_true', default=False,
help='Use NVIDIA Apex AMP mixed precision')
parser.add_argument('--native-amp', action='store_true', default=False,
help='Use Native Torch AMP mixed precision')
parser.add_argument('--channels-last', action='store_true', default=False,
help='Use channels_last memory layout')
parser.add_argument('--pin-mem', action='store_true', default=False,
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--no-prefetcher', action='store_true', default=False,
help='disable fast prefetcher')
parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
help='Best metric (default: "top1"')
parser.add_argument('--tta', type=int, default=0, metavar='N',
help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
parser.add_argument('--local_rank', default=0, type=int)
parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
help='use the multi-epochs-loader to save time at the beginning of every epoch')
parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
parser.add_argument('--device', type=int, default=devices[0])
# Spike parameters
parser.add_argument('--encode', type=str, default='direct', help='Input encode method (default: direct)')
parser.add_argument('--temporal-flatten', action='store_true',
help='Temporal flatten to channels. ONLY FOR EVENT DATA TRAINING BY ANN')
parser.add_argument('--adaptive-node', action='store_true')
parser.add_argument('--critical-loss', action='store_true')
# neuron type
parser.add_argument('--node-type', type=str, default='LIFNode', help='Node type in network (default: PLIF)')
parser.add_argument('--act-fun', type=str, default='QGateGrad',
help='Surogate Function in node. Only for Surrogate nodes (default: AtanGrad)')
parser.add_argument('--threshold', type=float, default=.5, help='Firing threshold (default: 0.5)')
parser.add_argument('--tau', type=float, default=2., help='Attenuation coefficient (default: 2.)')
parser.add_argument('--requires-thres-grad', action='store_true')
parser.add_argument('--sigmoid-thres', action='store_true')
parser.add_argument('--loss-fn', type=str, default='ce', help='loss function (default: ce)')
parser.add_argument('--noisy-grad', type=float, default=0.,
help='Add noise to backward, sometime will make higher accuracy (default: 0.)')
parser.add_argument('--spike-output', action='store_true', default=False,
help='Using mem output or spike output (default: False)')
parser.add_argument('--n_groups', type=int, default=1)
# EventData Augmentation
parser.add_argument('--mix-up', action='store_true', help='Mix-up for event data (default: False)')
parser.add_argument('--cut-mix', action='store_true', help='CutMix for event data (default: False)')
parser.add_argument('--event-mix', action='store_true', help='EventMix for event data (default: False)')
parser.add_argument('--cutmix_beta', type=float, default=1.0, help='cutmix_beta (default: 1.)')
parser.add_argument('--cutmix_prob', type=float, default=0.5, help='cutmix_prib for event data (default: .5)')
parser.add_argument('--cutmix_num', type=int, default=1, help='cutmix_num for event data (default: 1)')
parser.add_argument('--cutmix_noise', type=float, default=0.,
help='Add Pepper noise after mix, sometimes work (default: 0.)')
parser.add_argument('--rand-aug', action='store_true',
help='Rand Augment for Event data (default: False)')
parser.add_argument('--randaug_n', type=int, default=3,
help='Rand Augment times n (default: 3)')
parser.add_argument('--randaug_m', type=int, default=15,
help='Rand Augment times n (default: 15) (0-30)')
parser.add_argument('--train-portion', type=float, default=0.9,
help='Dataset portion, only for datasets which do not have validation set (default: 0.9)')
parser.add_argument('--event-size', default=48, type=int,
help='Event size. Resize event data before process (default: 48)')
parser.add_argument('--layer-by-layer', action='store_true',
help='forward step-by-step or layer-by-layer. '
'Larger Model with layer-by-layer will be faster (default: False)')
parser.add_argument('--node-resume', type=str, default='',
help='resume weights in node for adaptive node. (default: False)')
parser.add_argument('--node-trainable', action='store_true')
# visualize
parser.add_argument('--visualize', action='store_true',
help='Visualize spiking map for each layer, only for validate (default: False)')
parser.add_argument('--tsne', action='store_true')
parser.add_argument('--conf-mat', action='store_true')
# DARTS parameters
parser.add_argument('--auxiliary', action='store_true', default=False, help='use auxiliary tower')
parser.add_argument('--parse_method', default='darts', type=str)
parser.add_argument('--drop_path_prob', type=float, default=0.2, help='drop path probability')
parser.add_argument('--suffix', type=str, default='',
help='Add an additional suffix to the save path (default: \'\')')
def _parse_args():
    """Parse CLI arguments; return (args, YAML dump of the resolved config).

    NOTE(review): ``config_parser``'s result is discarded — a ``--config`` YAML
    file is never loaded into the parser defaults (compare timm's train.py,
    where the pre-parsed config feeds ``parser.set_defaults``). Confirm whether
    that omission is intentional.
    """
    _, remaining = config_parser.parse_known_args()
    args = parser.parse_args(remaining)
    # Snapshot of the resolved configuration, useful for logging/reproducibility.
    args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
    return args, args_text
if __name__ == '__main__':
    args, args_text = _parse_args()
    args.no_spike_output = True
    output_dir = ''
    # Choose the network implementation depending on the --bns flag.
    if args.bns:
        from cellmodel import NetworkCIFAR_
    else:
        from cell123model import NetworkCIFAR_
    # if 'dvs' in args.dataset:
    #     args.step=10
    # else:
    #     args.step=4
    # Only the rank-0 process creates the output directory and the log file.
    if args.local_rank == 0:
        output_base = args.save
        # Experiment directory name: timestamp plus the key search hyper-params.
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"),
            # args.model,
            args.dataset,
            str(args.layers)+'layers',
            str(args.init_channels)+'channels',
            str(args.step)+'steps',
            # args.suffix
            # str(args.img_size)
        ])
        output_dir = get_outdir(output_base,str(args.dataset),exp_name)
        args.output_dir = output_dir
        setup_default_logging(log_path=os.path.join(output_dir, 'log.txt'))
    else:
        setup_default_logging()
    args.prefetcher = not args.no_prefetcher
    # Distributed mode is inferred from the torchrun-provided WORLD_SIZE env var.
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    if args.distributed and args.num_gpu > 1:
        _logger.warning(
            'Using more than one GPU per process in distributed mode is not allowed.Setting num_gpu to 1.')
        args.num_gpu = 1
    # args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.num_gpu = 1
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    else:
        # Single-process path: pin to logical GPU 0 (physical GPU set via
        # CUDA_VISIBLE_DEVICES at module import time).
        torch.cuda.set_device('cuda:0')
    assert args.rank >= 0
    if args.distributed:
        _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
                     % (args.rank, args.world_size))
    else:
        _logger.info('Training with a single process on %d GPUs.' % args.num_gpu)
    # Per-rank seed offset keeps data orders distinct across processes.
    setup_seed(args.seed + args.rank)
    defalut_lr = args.lr
    # Sample an initial design-of-experiments population of genome encodings.
    arch_doe,sn,bigmotifs,ms ,glob_con= micro_encoding.sample(pops=10, layers=args.layers, bits=20)
    i=0
    # Train each sampled genome once (generation 0) in its own subdirectory.
    for geno in arch_doe:
        arch_dir=os.path.join('train',args.output_dir,str(i))
        if os.path.exists(arch_dir) is False:
            os.makedirs(arch_dir,exist_ok = True)
        # Reset the LR in case train_motifs mutated it on a previous iteration.
        args.lr=defalut_lr
        geno=micro_encoding.c_single(geno,layers=args.layers,bits=args.bits)
        # Last genome element is the global-connection gene; the rest is the motif genome.
        acc,info=train_motifs(args=args,gen=0,arch_dir=arch_dir,genome=np.array(geno[:-1]),_logger=_logger,args_text=args_text,devices=devices,ms=ms,glob=geno[-1])
        i+=1
================================================
FILE: examples/Structure_Evolution/EB-NAS/tm.py
================================================
import sys
import numpy as np
import argparse
import time
import timm.models
import yaml
import os
import logging
from random import choice
from collections import OrderedDict
from contextlib import suppress
from datetime import datetime
from micro_encoding import ops
from braincog.base.node.node import *
from braincog.utils import *
from braincog.base.utils.criterions import *
from braincog.datasets.datasets import *
from braincog.model_zoo.resnet import *
from braincog.model_zoo.convnet import *
from braincog.utils import save_feature_map, setup_seed
from braincog.base.utils.visualization import plot_tsne_3d, plot_tsne, plot_confusion_matrix
import micro_encoding
from pymop.problem import Problem
import torch
from thop import profile
import torch.nn as nn
import torchvision.utils
from torch.nn.parallel import DistributedDataParallel as NativeDDP
from pymoo.optimize import minimize
from timm.data import create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
from timm.models import load_checkpoint, create_model, resume_checkpoint, convert_splitbn_model
from timm.utils import *
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
from timm.optim import create_optimizer
from timm.scheduler import create_scheduler
from timm.utils import ApexScaler, NativeScaler
from torchprofile import profile_macs
import copy
import torch.backends.cudnn as cudnn
import warnings
warnings.simplefilter("ignore")
def train_motifs(args,gen,arch_dir,genome,_logger,args_text,devices,ms,glob):
    """Build a candidate SNN from an evolved genome, train it, and return its fitness.

    Parameters
    ----------
    args : argparse.Namespace
        Full training configuration (model, dataset, optimizer, AMP flags, ...).
        Mutated in place (``args.epochs``, ``args.lr``, ``args.channels``).
    gen : int
        Generation index (currently unused; see commented-out branch below).
    arch_dir : str
        Directory for checkpoints, summaries and the genome/fitness dumps.
    genome : numpy.ndarray
        Encoded architecture, decoded via ``micro_encoding.decode_motif``.
    _logger : logging.Logger
        Logger for all progress output.
    args_text : str
        YAML dump of ``args`` written next to the checkpoints.
    devices : list
        GPU ids for ``nn.DataParallel`` when ``args.num_gpu > 1``.
    ms : int
        Multiplier on ``args.layers`` giving the total number of cells.
    glob : object
        Global-connection descriptor forwarded to ``create_model``.

    Returns
    -------
    tuple
        ``(best_metric, spikes)`` on success; ``(-10000, 0)`` as a penalty
        fitness when training aborts with ``MemoryError``/``RuntimeError``.
    """
    # Select the cell implementation according to the batch-norm variant flag.
    if args.bns:
        from cellmodel import NetworkCIFAR_
    else:
        from cell123model import NetworkCIFAR_
    # qw=np.where(args.glob_con[0]==1)
    # ccc=np.array([1,0,0,0])
    # for i in qw:
    # ccc[i[0]]=1
    # ddd=np.where(args.glob_con[i[0]]==1)
    # if len(ddd[0])!=0:
    # for j in ddd:
    # ccc[j[0]]=1
    # www=np.where(args.glob_con[j[0]]==1)
    # if len(www[0])!=0:
    # for k in www:
    # ccc[k[0]]=1
    # Decode the genome into the concrete motif list (one motif per cell).
    test_motifs,ids = micro_encoding.decode_motif(args.layers*ms,args.bits,genome.astype(int))
    # if gen==-1:
    # Candidate networks are always trained for the (short) evaluation budget.
    args.epochs=args.eval_epochs
    # else:
    # args.epochs=args.eval_epochs
    # NOTE(review): MemoryError/RuntimeError raised anywhere between model
    # construction and the training loop fall through to the penalty return
    # at the bottom of this function.
    try:
        model = create_model(
            args.model,
            pretrained=args.pretrained,
            num_classes=args.num_classes,
            dataset=args.dataset,
            step=args.step,
            encode_type=args.encode,
            node_type=eval(args.node_type),
            threshold=args.threshold,
            tau=args.tau,
            sigmoid_thres=args.sigmoid_thres,
            requires_thres_grad=args.requires_thres_grad,
            spike_output=not args.no_spike_output,
            C=args.init_channels,
            layers=args.layers*ms,
            auxiliary=args.auxiliary,
            motif=test_motifs,
            parse_method=args.parse_method,
            act_fun=args.act_fun,
            temporal_flatten=args.temporal_flatten,
            layer_by_layer=args.layer_by_layer,
            n_groups=args.n_groups,
            glob=glob,
        )
        # Input channels follow the dataset family: event (DVS) data is
        # 2-channel, MNIST-like data 1-channel, RGB otherwise.
        if 'dvs' in args.dataset:
            args.channels = 2
        elif 'mnist' in args.dataset:
            args.channels = 1
        else:
            args.channels = 3
        # flops, params = profile(model, inputs=(torch.randn(1, args.channels, args.event_size, args.event_size),), verbose=False)
        # _logger.info('flops = %fM', flops / 1e6)
        # _logger.info('param size = %fM', params / 1e6)
        # thop profiling disabled above; placeholders keep the names defined.
        flops=0
        params=0
        # Linear LR scaling rule with respect to the global batch size.
        linear_scaled_lr = args.lr * args.batch_size * args.world_size / 1024.0
        args.lr = linear_scaled_lr
        _logger.info("learning rate is %f" % linear_scaled_lr)
        if args.local_rank == 0:
            sumpram=sum([m.numel() for m in model.parameters()])
            _logger.info('Model %s created, param count: %d' %
                         (args.model, sumpram))
        num_aug_splits = 0
        if args.aug_splits > 0:
            assert args.aug_splits > 1, 'A split of 1 makes no sense'
            num_aug_splits = args.aug_splits
        if args.split_bn:
            assert num_aug_splits > 1 or args.resplit
            model = convert_splitbn_model(model, max(num_aug_splits, 2))
        # Resolve the AMP backend: apex preferred, then native torch AMP.
        use_amp = None
        if args.amp:
            # for backwards compat, `--amp` arg tries apex before native amp
            if has_apex:
                args.apex_amp = True
            elif has_native_amp:
                args.native_amp = True
        if args.apex_amp and has_apex:
            use_amp = 'apex'
        elif args.native_amp and has_native_amp:
            use_amp = 'native'
        elif args.apex_amp or args.native_amp:
            _logger.warning("Neither APEX or native Torch AMP is available, using float32. "
                            "Install NVIDA apex or upgrade to PyTorch 1.6")
        if args.num_gpu > 1:
            if use_amp == 'apex':
                _logger.warning(
                    'Apex AMP does not work well with nn.DataParallel, disabling. Use DDP or Torch AMP.')
                use_amp = None
            model = nn.DataParallel(model, device_ids=devices).cuda()
            assert not args.channels_last, "Channels last not supported with DP, use DDP."
        else:
            model = model.cuda()
            if args.channels_last:
                model = model.to(memory_format=torch.channels_last)
        optimizer = create_optimizer(args, model)
        amp_autocast = suppress  # do nothing
        loss_scaler = None
        if use_amp == 'apex':
            model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
            loss_scaler = ApexScaler()
            if args.local_rank == 0:
                _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
        elif use_amp == 'native':
            amp_autocast = torch.cuda.amp.autocast
            loss_scaler = NativeScaler()
            if args.local_rank == 0:
                _logger.info('Using native Torch AMP. Training in mixed precision.')
        else:
            if args.local_rank == 0:
                _logger.info('AMP not enabled. Training in float32.')
        # optionally resume from a checkpoint
        resume_epoch = None
        if args.resume and args.eval_checkpoint == '':
            args.eval_checkpoint = args.resume
        if args.resume:
            args.eval = True
            # checkpoint = torch.load(args.resume, map_location='cpu')
            # model.load_state_dict(checkpoint['state_dict'], False)
            resume_epoch = resume_checkpoint(
                model, args.resume,
                optimizer=None if args.no_resume_opt else optimizer,
                loss_scaler=None if args.no_resume_opt else loss_scaler,
                log_info=args.local_rank == 0)
        # print(model.get_attr('mu'))
        # print(model.get_attr('sigma'))
        # Feature-map recording is required by critical loss / spike statistics.
        if args.critical_loss or args.spike_rate:
            if args.num_gpu>1:
                model.module.set_requires_fp(True)
            else:
                model.set_requires_fp(True)
        model_ema = None
        if args.model_ema:
            # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
            model_ema = ModelEma(
                model,
                decay=args.model_ema_decay,
                device='cpu' if args.model_ema_force_cpu else '',
                resume=args.resume)
        if args.node_resume:
            ckpt = torch.load(args.node_resume, map_location='cpu')
            model.load_node_weight(ckpt, args.node_trainable)
        model_without_ddp = model
        if args.distributed:
            if args.sync_bn:
                assert not args.split_bn
                try:
                    if has_apex and use_amp != 'native':
                        # Apex SyncBN preferred unless native amp is activated
                        model = convert_syncbn_model(model)
                    else:
                        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
                    if args.local_rank == 0:
                        _logger.info(
                            'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                            'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
                except Exception as e:
                    _logger.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')
            if has_apex and use_amp != 'native':
                # Apex DDP preferred unless native amp is activated
                if args.local_rank == 0:
                    _logger.info("Using NVIDIA APEX DistributedDataParallel.")
                model = ApexDDP(model, delay_allreduce=True)
            else:
                if args.local_rank == 0:
                    _logger.info("Using native Torch DistributedDataParallel.")
                model = NativeDDP(model, device_ids=[args.local_rank],
                                  find_unused_parameters=True)  # can use device str in Torch >= 1.1
            model_without_ddp = model.module
        # NOTE: EMA model does not need to be wrapped by DDP
        lr_scheduler, num_epochs = create_scheduler(args, optimizer)
        start_epoch = 0
        if args.start_epoch is not None:
            # a specified start_epoch will always override the resume epoch
            start_epoch = args.start_epoch
        elif resume_epoch is not None:
            start_epoch = resume_epoch
        if lr_scheduler is not None and start_epoch > 0:
            lr_scheduler.step(start_epoch)
        if args.local_rank == 0:
            _logger.info('Scheduled epochs: {}'.format(num_epochs))
        # now config only for imnet
        data_config = resolve_data_config(vars(args), model=model, verbose=False)
        # Look up the dataset-specific loader factory, e.g. get_cifar10_data.
        loader_train, loader_eval, mixup_active, mixup_fn = eval('get_%s_data' % args.dataset)(
            batch_size=args.batch_size,
            step=args.step,
            args=args,
            _logge=_logger,
            data_config=data_config,
            num_aug_splits=num_aug_splits,
            size=args.event_size,
            mix_up=args.mix_up,
            cut_mix=args.cut_mix,
            event_mix=args.event_mix,
            beta=args.cutmix_beta,
            prob=args.cutmix_prob,
            num=args.cutmix_num,
            noise=args.cutmix_noise,
            num_classes=args.num_classes,
            rand_aug=args.rand_aug,
            randaug_n=args.randaug_n,
            randaug_m=args.randaug_m,
            temporal_flatten=args.temporal_flatten,
            portion=args.train_portion,
            _logger=_logger,
        )
        # Loss selection: MSE variant, JSD / soft-target / label-smoothing CE,
        # or plain cross-entropy; optionally wrapped by MixLoss.
        if args.loss_fn == 'mse':
            train_loss_fn = UnilateralMse(1.)
            validate_loss_fn = UnilateralMse(1.)
        else:
            if args.jsd:
                assert num_aug_splits > 1  # JSD only valid with aug splits set
                train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()
            elif mixup_active:
                # smoothing is handled with mixup target transform
                train_loss_fn = SoftTargetCrossEntropy().cuda()
            elif args.smoothing:
                train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
            else:
                train_loss_fn = nn.CrossEntropyLoss().cuda()
            validate_loss_fn = nn.CrossEntropyLoss().cuda()
        if args.loss_fn == 'mix':
            train_loss_fn = MixLoss(train_loss_fn)
            validate_loss_fn = MixLoss(validate_loss_fn)
        eval_metric = args.eval_metric
        best_metric = None
        best_epoch = None
        if args.eval:  # evaluate the model
            if args.distributed:
                state_dict = torch.load(args.eval_checkpoint)['state_dict_ema']
                new_state_dict = OrderedDict()
                # add module prefix for DDP
                for k, v in state_dict.items():
                    k = 'module.' + k
                    new_state_dict[k] = v
                model.load_state_dict(new_state_dict)
            # else:
            # load_checkpoint(model, args.eval_checkpoint, args.model_ema)
            for i in range(1):
                # NOTE(review): this call does not pass ``_logger`` although
                # ``validate`` declares it as a required positional parameter
                # after ``arch_dir`` — this path looks like it would raise
                # TypeError; confirm before relying on --eval.
                val_metrics,_ = validate(start_epoch, model, loader_eval, validate_loss_fn, args,arch_dir,
                                         visualize=args.visualize, spike_rate=args.spike_rate,
                                         tsne=args.tsne, conf_mat=args.conf_mat)
                print(f"Top-1 accuracy of the model is: {val_metrics['top1']:.1f}%")
            # return
        saver = None
        if args.local_rank == 0:
            decreasing = True if eval_metric == 'loss' else False
            saver = CheckpointSaver(
                model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
                checkpoint_dir=arch_dir, recovery_dir=arch_dir, decreasing=decreasing)
            with open(os.path.join(arch_dir, 'args.yaml'), 'w') as f:
                f.write(args_text)
        # Append this candidate's raw genome to the architecture directory.
        f=open(os.path.join(arch_dir, 'direct_genome.txt'), 'a')
        f.write(",".join(str(k) for k in genome))
        f.write('\n')
        f.close()
        try:  # train the model
            if args.reset_drop:
                model_without_ddp.reset_drop_path(0.0)
            for epoch in range(start_epoch, args.epochs):
                if epoch == 0 and args.reset_drop:
                    model_without_ddp.reset_drop_path(args.drop_path)
                if args.distributed:
                    loader_train.sampler.set_epoch(epoch)
                train_metrics = train_epoch(
                    epoch, model, loader_train, optimizer, train_loss_fn, args,_logger=_logger,
                    lr_scheduler=lr_scheduler, saver=saver, output_dir=arch_dir,
                    amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)
                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                    if args.local_rank == 0:
                        _logger.info("Distributing BatchNorm running means and vars")
                    distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
                eval_metrics,_ = validate(epoch, model, loader_eval, validate_loss_fn, args, arch_dir,amp_autocast=amp_autocast,_logger=_logger,
                                          visualize=args.visualize, spike_rate=args.spike_rate,
                                          tsne=args.tsne, conf_mat=args.conf_mat)
                if model_ema is not None and not args.model_ema_force_cpu:
                    if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                        distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
                    # When an EMA model is kept, its metrics replace the raw ones.
                    ema_eval_metrics,_ = validate(
                        epoch, model_ema.ema, loader_eval, validate_loss_fn, args, arch_dir,amp_autocast=amp_autocast, log_suffix=' (EMA)',_logger=_logger,
                        visualize=args.visualize, spike_rate=args.spike_rate,
                        tsne=args.tsne, conf_mat=args.conf_mat)
                    eval_metrics = ema_eval_metrics
                if lr_scheduler is not None:
                    # step LR for next epoch
                    lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
                update_summary(
                    epoch, train_metrics, eval_metrics, os.path.join(arch_dir, 'summary.csv'),
                    write_header=best_metric is None)
                # if saver is not None and epoch >= args.n_warm_up:
                if saver is not None:
                    # save proper checkpoint with eval metric
                    save_metric = eval_metrics[eval_metric]
                    best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)
                # The latest epoch's metric is reported as the fitness —
                # it deliberately overwrites the saver's notion of "best".
                best_metric, best_epoch = eval_metrics[eval_metric],epoch
                _logger.info('Train: {0} '.format(best_metric))
                f=open(os.path.join(arch_dir, 'direct.txt'), 'a')
                f.write(str(best_metric))
                f.write('\n')
                f.close()
        except KeyboardInterrupt:
            pass
    except MemoryError:
        # Penalty fitness: candidate exhausted host memory.
        return -10000, 0
    except RuntimeError:
        # return -10000, {'flops': flops / 1e6, 'param': params / 1e6}
        # Penalty fitness: CUDA OOM and similar runtime failures.
        return -10000, 0
    # if best_metric is not None:
    # _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
    # info=get_net_info(model)
    # Final evaluation collects the total spike count reported with the fitness.
    val_metrics,spikes = validate(start_epoch, model, loader_eval, validate_loss_fn, args,arch_dir,
                                  visualize=args.visualize, spike_rate=args.spike_rate,
                                  tsne=args.tsne, conf_mat=args.conf_mat,_logger=_logger,)
    return best_metric,spikes
def train_epoch(
        epoch, model, loader, optimizer, loss_fn, args,_logger,
        lr_scheduler=None, saver=None, output_dir='', amp_autocast=suppress,
        loss_scaler=None, model_ema=None, mixup_fn=None):
    """Run one training epoch and return ``OrderedDict([('loss', avg)])``.

    Derived from the timm training loop, extended with SNN extras:
    critical loss, per-layer firing-rate logging and noisy gradients.
    """
    # Turn mixup off once the configured cut-off epoch is reached.
    if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
        if args.prefetcher and loader.mixup_enabled:
            loader.mixup_enabled = False
        elif mixup_fn is not None:
            mixup_fn.mixup_enabled = False
    # Linearly ramp the drop-path probability over training.
    model.drop_path_prob = args.drop_path_prob * epoch / args.epochs
    second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    losses_m = AverageMeter()
    closses_m = AverageMeter()  # critical-loss meter
    top1_m = AverageMeter()
    top5_m = AverageMeter()
    model.train()
    # t, k = adjust_surrogate_coeff(100, args.epochs)
    # model.set_attr('t', t)
    # model.set_attr('k', k)
    end = time.time()
    last_idx = len(loader) - 1
    num_updates = epoch * len(loader)
    for batch_idx, (inputs, target) in enumerate(loader):
        last_batch = batch_idx == last_idx
        data_time_m.update(time.time() - end)
        # Without a prefetcher the host tensors must be cast/moved here; the
        # 'imnet' loader is assumed to have already done so.
        if not args.prefetcher or args.dataset != 'imnet':
            inputs, target = inputs.type(torch.FloatTensor).cuda(), target.cuda()
            if mixup_fn is not None:
                inputs, target = mixup_fn(inputs, target)
        if args.channels_last:
            inputs = inputs.contiguous(memory_format=torch.channels_last)
        with amp_autocast():
            output = model(inputs)
            loss = loss_fn(output, target)
        # Accuracy against mixed targets is meaningless, so it is zeroed out
        # whenever any mixing augmentation is active.
        if not (args.cut_mix | args.mix_up | args.event_mix) and args.dataset != 'imnet':
            # print(output.shape, target.shape)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            # acc1, = accuracy(output, target)
        else:
            acc1, acc5 = torch.tensor([0.]), torch.tensor([0.])
        closs = torch.tensor([0.], device=loss.device)
        if args.critical_loss:
            # Auxiliary regularizer, weighted 0.1 into the total loss.
            closs = calc_critical_loss(model)
            loss = loss + .1 * closs
        spike_rate_avg_layer_str = ''
        threshold_str = ''
        if not args.distributed:
            losses_m.update(loss.item(), inputs.size(0))
            top1_m.update(acc1.item(), inputs.size(0))
            top5_m.update(acc5.item(), inputs.size(0))
            closses_m.update(closs.item(), inputs.size(0))
            # Per-layer firing rates / thresholds for the log line below.
            if args.num_gpu>1:
                spike_rate_avg_layer = model.module.get_fire_rate().tolist()
                spike_rate_avg_layer_str = ['{:.3f}'.format(i) for i in spike_rate_avg_layer]
                threshold = model.module.get_threshold()
            else:
                spike_rate_avg_layer = model.get_fire_rate().tolist()
                spike_rate_avg_layer_str = ['{:.3f}'.format(i) for i in spike_rate_avg_layer]
                threshold = model.get_threshold()
            threshold_str = ['{:.3f}'.format(i) for i in threshold]
        optimizer.zero_grad()
        if loss_scaler is not None:
            # AMP path: the scaler handles backward, unscale and clipping.
            loss_scaler(
                loss, optimizer, clip_grad=args.clip_grad, parameters=model.parameters(), create_graph=second_order)
        else:
            loss.backward(create_graph=second_order)
            if args.noisy_grad != 0.:
                random_gradient(model, args.noisy_grad)
            if args.clip_grad is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
            if args.opt == 'lamb':
                optimizer.step(epoch=epoch)
            else:
                optimizer.step()
        torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)
        num_updates += 1
        batch_time_m.update(time.time() - end)
        if last_batch or batch_idx % args.log_interval == 0:
            # Mean learning rate across parameter groups, for logging only.
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)
            mu_str = ''
            sigma_str = ''
            if not args.distributed:
                if 'Noise' in args.node_type:
                    mu, sigma = model.get_noise_param()
                    mu_str = ['{:.3f}'.format(i.detach()) for i in mu]
                    sigma_str = ['{:.3f}'.format(i.detach()) for i in sigma]
            if args.distributed:
                # Meters were not updated above in distributed mode; update
                # them here with the all-reduced loss.
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                losses_m.update(reduced_loss.item(), inputs.size(0))
                closses_m.update(reduced_loss.item(), inputs.size(0))
            if args.local_rank == 0:
                if args.distributed:
                    _logger.info(
                        'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
                        'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) '
                        'cLoss: {closs.val:>9.6f} ({closs.avg:>6.4f}) '
                        'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
                        'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f}) '
                        'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
                        '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                        'LR: {lr:.3e} '
                        'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
                            epoch,
                            batch_idx, len(loader),
                            100. * batch_idx / last_idx,
                            loss=losses_m,
                            closs=closses_m,
                            top1=top1_m,
                            top5=top5_m,
                            batch_time=batch_time_m,
                            rate=inputs.size(0) * args.world_size / batch_time_m.val,
                            rate_avg=inputs.size(0) * args.world_size / batch_time_m.avg,
                            lr=lr,
                            data_time=data_time_m
                        ))
                else:
                    _logger.info(
                        'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
                        'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) '
                        'cLoss: {closs.val:>9.6f} ({closs.avg:>6.4f}) '
                        'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
                        'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f}) '
                        'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
                        '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                        'LR: {lr:.3e} '
                        'Data: {data_time.val:.3f} ({data_time.avg:.3f})\n'
                        'Fire_rate: {spike_rate}\n'
                        # 'Thres: {threshold}\n'
                        # 'Mu: {mu_str}\n'
                        # 'Sigma: {sigma_str}\n'
                        .format(
                            epoch,
                            batch_idx, len(loader),
                            100. * batch_idx / last_idx,
                            loss=losses_m,
                            closs=closses_m,
                            top1=top1_m,
                            top5=top5_m,
                            batch_time=batch_time_m,
                            rate=inputs.size(0) * args.world_size / batch_time_m.val,
                            rate_avg=inputs.size(0) * args.world_size / batch_time_m.avg,
                            lr=lr,
                            data_time=data_time_m,
                            spike_rate=spike_rate_avg_layer_str,
                            # threshold=threshold_str,
                            # mu_str=mu_str,
                            # sigma_str=sigma_str
                        ))
                if args.save_images and output_dir:
                    torchvision.utils.save_image(
                        inputs,
                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
                        padding=0,
                        normalize=True)
        if saver is not None and args.recovery_interval and (
                last_batch or (batch_idx + 1) % args.recovery_interval == 0):
            saver.save_recovery(epoch, batch_idx=batch_idx)
        if lr_scheduler is not None:
            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
        end = time.time()
    # end for
    if hasattr(optimizer, 'sync_lookahead'):
        optimizer.sync_lookahead()
    return OrderedDict([('loss', losses_m.avg)])
def validate(epoch, model, loader, loss_fn, args, arch_dir,_logger,amp_autocast=suppress,
             log_suffix='', visualize=False, spike_rate=False, tsne=False, conf_mat=False):
    """Evaluate ``model`` on ``loader`` and optionally dump diagnostics.

    Returns ``(metrics, tot_spike)`` where ``metrics`` is an OrderedDict
    holding the 'loss' and 'top1' averages and ``tot_spike`` is the total
    spike count read from the model on the last batch.

    NOTE(review): ``tot_spike`` is only assigned inside the
    ``not args.distributed`` branch of the batch loop — in distributed mode
    (or with an empty loader) the final ``return`` would raise NameError;
    confirm distributed evaluation is never routed through here.
    """
    batch_time_m = AverageMeter()
    losses_m = AverageMeter()
    closses_m = AverageMeter()
    top1_m = AverageMeter()
    top5_m = AverageMeter()
    model.eval()
    # Accumulators for the optional t-SNE / confusion-matrix diagnostics.
    feature_vec = []
    feature_cls = []
    logits_vec = []
    labels_vec = []
    end = time.time()
    last_idx = len(loader) - 1
    with torch.no_grad():
        for batch_idx, (inputs, target) in enumerate(loader):
            # inputs = inputs.type(torch.float64)
            last_batch = batch_idx == last_idx
            if not args.prefetcher or args.dataset != 'imnet':
                inputs = inputs.type(torch.FloatTensor).cuda()
                target = target.cuda()
            if args.channels_last:
                inputs = inputs.contiguous(memory_format=torch.channels_last)
            if not args.distributed:
                # Enable feature-map recording only when a diagnostic needs it
                # (critical loss already keeps it enabled).
                if (visualize or spike_rate or tsne or conf_mat) and not args.critical_loss:
                    if args.num_gpu>1:
                        model.module.set_requires_fp(True)
                    else:
                        model.set_requires_fp(True)
            # if not args.critical_loss:
            # model.set_requires_fp(False)
            with amp_autocast():
                output = model(inputs)
            if isinstance(output, (tuple, list)):
                output = output[0]
            if not args.distributed:
                if visualize:
                    x = model.get_fp()
                    feature_path = os.path.join(arch_dir, 'feature_map')
                    if os.path.exists(feature_path) is False:
                        os.mkdir(feature_path)
                    save_feature_map(x, feature_path)
                # if not args.critical_loss:
                # model_config.set_requires_fp(False)
                if tsne:
                    # Pool the last feature map into one embedding per sample.
                    x = model.get_fp(temporal_info=False)[-1]
                    x = torch.nn.AdaptiveAvgPool2d((1, 1))(x)
                    x = x.reshape(x.shape[0], -1)
                    feature_vec.append(x)
                    feature_cls.append(target)
                if conf_mat:
                    logits_vec.append(output)
                    labels_vec.append(target)
                if spike_rate:
                    if args.num_gpu>1:
                        avg, var, spike, avg_per_step = model.module.get_spike_info()
                    else:
                        avg, var, spike, avg_per_step = model.get_spike_info()
                    save_spike_info(
                        os.path.join(arch_dir, 'spike_info.csv'),
                        epoch, batch_idx,
                        args.step, avg, var,
                        spike, avg_per_step)
            # augmentation reduction
            reduce_factor = args.tta
            if reduce_factor > 1:
                output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
                target = target[0:target.size(0):reduce_factor]
            loss = loss_fn(output, target)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            # acc1, = accuracy(output, target)
            closs = torch.tensor([0.], device=loss.device)
            if not args.distributed:
                # Per-layer firing statistics for the log line below.
                if args.num_gpu>1:
                    spike_rate_avg_layer = model.module.get_fire_rate().tolist()
                    threshold = model.module.get_threshold()
                    threshold_str = ['{:.3f}'.format(i) for i in threshold]
                    spike_rate_avg_layer_str = ['{:.3f}'.format(i) for i in spike_rate_avg_layer]
                    tot_spike = model.module.get_tot_spike()
                else:
                    spike_rate_avg_layer = model.get_fire_rate().tolist()
                    threshold = model.get_threshold()
                    threshold_str = ['{:.3f}'.format(i) for i in threshold]
                    spike_rate_avg_layer_str = ['{:.3f}'.format(i) for i in spike_rate_avg_layer]
                    tot_spike = model.get_tot_spike()
            if args.critical_loss:
                closs = calc_critical_loss(model)
                loss = loss + .1 * closs
            if args.distributed:
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                acc1 = reduce_tensor(acc1, args.world_size)
                acc5 = reduce_tensor(acc5, args.world_size)
            else:
                reduced_loss = loss.data
            torch.cuda.synchronize()
            losses_m.update(reduced_loss.item(), inputs.size(0))
            top1_m.update(acc1.item(), output.size(0))
            top5_m.update(acc5.item(), output.size(0))
            closses_m.update(closs.item(), inputs.size(0))
            batch_time_m.update(time.time() - end)
            end = time.time()
            if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
                log_name = 'Test' + log_suffix
                mu_str = ''
                sigma_str = ''
                if not args.distributed:
                    if 'Noise' in args.node_type:
                        mu, sigma = model.get_noise_param()
                        mu_str = ['{:.3f}'.format(i.detach()) for i in mu]
                        sigma_str = ['{:.3f}'.format(i.detach()) for i in sigma]
                if args.distributed:
                    _logger.info(
                        '{0}: [{1:>4d}/{2}] '
                        'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                        'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                        'cLoss: {closs.val:>7.4f} ({closs.avg:>6.4f}) '
                        'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})'
                        'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
                            log_name,
                            batch_idx,
                            last_idx,
                            batch_time=batch_time_m,
                            loss=losses_m,
                            closs=closses_m,
                            top1=top1_m,
                            top5=top5_m,
                        ))
                else:
                    _logger.info(
                        '{0}: [{1:>4d}/{2}] '
                        'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                        'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                        'cLoss: {closs.val:>7.4f} ({closs.avg:>6.4f}) '
                        'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})'
                        'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})\n'
                        'Fire_rate: {spike_rate}\n'
                        'Tot_spike: {tot_spike}\n'
                        'Thres: {threshold}\n'
                        'Mu: {mu_str}\n'
                        'Sigma: {sigma_str}\n'.format(
                            log_name,
                            batch_idx,
                            last_idx,
                            batch_time=batch_time_m,
                            loss=losses_m,
                            closs=closses_m,
                            top1=top1_m,
                            top5=top5_m,
                            spike_rate=spike_rate_avg_layer_str,
                            tot_spike=tot_spike,
                            threshold=threshold_str,
                            mu_str=mu_str,
                            sigma_str=sigma_str
                        ))
    # metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])
    metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg)])
    if not args.distributed:
        if tsne:
            feature_vec = torch.cat(feature_vec)
            feature_cls = torch.cat(feature_cls)
            plot_tsne(feature_vec, feature_cls, os.path.join(arch_dir, 't-sne-2d.eps'))
            plot_tsne_3d(feature_vec, feature_cls, os.path.join(arch_dir, 't-sne-3d.eps'))
        if conf_mat:
            logits_vec = torch.cat(logits_vec)
            labels_vec = torch.cat(labels_vec)
            plot_confusion_matrix(logits_vec, labels_vec, os.path.join(arch_dir, 'confusion_matrix.eps'))
    return metrics,tot_spike
def get_net_info(args, gen, genome, ms):
    """
    Modified from https://github.com/mit-han-lab/once-for-all/blob/
    35ddcb9ca30905829480770a6a282d49685aa282/ofa/imagenet_codebase/utils/pytorch_utils.py#L139

    Build the network described by ``genome`` and report its parameter count
    and MAC count as ``{'params': ..., 'flops': ...}``.
    """
    from ofa.imagenet_codebase.utils.pytorch_utils import count_parameters, measure_net_latency
    # The cell implementation module depends on the batch-norm variant flag.
    if args.bns:
        from cellmodel import NetworkCIFAR
    else:
        from cell123model import NetworkCIFAR
    # Decode the genome into concrete motifs, one per cell.
    motifs, _ids = micro_encoding.decode_motif(args.layers * ms, args.bits, genome.astype(int))
    net = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        dataset=args.dataset,
        step=args.step,
        encode_type=args.encode,
        node_type=eval(args.node_type),
        threshold=args.threshold,
        tau=args.tau,
        sigmoid_thres=args.sigmoid_thres,
        requires_thres_grad=args.requires_thres_grad,
        spike_output=not args.no_spike_output,
        C=args.init_channels,
        layers=args.layers * ms,
        auxiliary=args.auxiliary,
        motif=motifs,
        parse_method=args.parse_method,
        act_fun=args.act_fun,
        temporal_flatten=args.temporal_flatten,
        layer_by_layer=args.layer_by_layer,
        n_groups=args.n_groups,
        cell_type=genome[-1],
    )
    # Input channel count follows the dataset family (event / grayscale / RGB).
    if 'dvs' in args.dataset:
        args.channels = 2
    elif 'mnist' in args.dataset:
        args.channels = 1
    else:
        args.channels = 3
    # artificial input data used only for profiling
    probe = torch.randn(1, args.channels, 224, 224)
    # move network to GPU if available
    if torch.cuda.is_available():
        gpu = torch.device('cuda:0')
        net = net.to(gpu)
        cudnn.benchmark = True
        probe = probe.to(gpu)
    # Unwrap a DataParallel container so counting sees the bare module.
    if isinstance(net, nn.DataParallel):
        net = net.module
    stats = {}
    stats['params'] = count_parameters(net)
    # Profile a deep copy so tracing cannot perturb the live network.
    stats['flops'] = int(profile_macs(copy.deepcopy(net), probe))
    return stats
================================================
FILE: examples/Structure_Evolution/ELSM/README.md
================================================
# Emergence of Brain-inspired Small-world Spiking Neural Network through Neuroevolution —— Based on BrainCog #
## Requirements ##
* numpy
* pytorch >= 1.12.0
* BrainCog
## Run ##
```python evolve.py```
## Citation ##
If you find the code and dataset useful in your research, please consider citing:
```
@article{pan2024emergence,
title={Emergence of Brain-inspired Small-world Spiking Neural Network through Neuroevolution},
author={Pan, Wenxuan and Zhao, Feifei and Han, Bing and Dong, Yiting and Zeng, Yi},
journal={iScience},
year={2024},
publisher={Elsevier}
}
@article{zeng2023braincog,
title={BrainCog: A spiking neural network based, brain-inspired cognitive intelligence engine for brain-inspired AI and brain simulation},
author={Zeng, Yi and Zhao, Dongcheng and Zhao, Feifei and Shen, Guobin and Dong, Yiting and Lu, Enmeng and Zhang, Qian and Sun, Yinqian and Liang, Qian and Zhao, Yuxuan and others},
journal={Patterns},
volume={4},
number={8},
year={2023},
publisher={Elsevier}
}
```
================================================
FILE: examples/Structure_Evolution/ELSM/evolve.py
================================================
import time
import threading
from threading import Thread
import os
import networkx as nx
import numpy as np
from population import *
import nsganet as engine
from pymop.problem import Problem
from pymoo.optimize import minimize
from pymoo.operators.sampling.random_sampling import RandomSampling
from pymoo.operators.mutation.bitflip_mutation import BinaryBitflipMutation
import logging
from model import *
from spikes import calc_f2
from multiprocessing import Process,Pool
from datetime import datetime
import time
# Root logger for this script; handlers are attached in __main__ via
# setup_default_logging().
_logger = logging.getLogger('')
# Two-stage CLI parsing (timm-style): ``config_parser`` consumes known config
# options first; the main ``parser`` handles whatever remains. Both names
# initially alias the same object before ``parser`` is rebound below.
config_parser = parser = argparse.ArgumentParser(description='Evolution Config', add_help=False)
parser = argparse.ArgumentParser(description='SNN Evoving')
parser.add_argument('--device', type=int, default=2)  # CUDA device index used for fitness evaluation
parser.add_argument('--seed', type=int, default=68, metavar='S')  # RNG seed
parser.add_argument('--datapath', default='/data/', type=str, metavar='PATH')  # dataset root
parser.add_argument('--output', default='/data/LSM/Eresult/new', type=str, metavar='PATH')  # results root (timestamped subdir created in __main__)
parser.add_argument('--liquid-size', type=int, default=8000)  # neurons per liquid; genome is liquid_size^2 bits
parser.add_argument('--pop-size', type=int, default=20)  # NSGA population size
parser.add_argument('--up', type=int, default=32000000)  # upper bound on total connections (constraint g1)
parser.add_argument('--low', type=int, default=320000)  # lower bound on total connections (constraint g2)
parser.add_argument('--n_offspring', type=int, default=200)  # offspring per generation
parser.add_argument('--n_gens', type=int, default=2000)  # number of generations
parser.add_argument('--arand', type=float, default=285)  # random-graph clustering baseline for small-world objective
parser.add_argument('--brand', type=float, default=1.8)  # random-graph path-length baseline for small-world objective
def _parse_args():
    """Run the two-stage CLI parse and return the resulting namespace."""
    _config_args, leftover = config_parser.parse_known_args()
    return parser.parse_args(leftover)
def calc_f1(dirs):
    """Score the pickled graph at path *dirs* for small-worldness inputs.

    Restricts the graph to its largest connected component and returns a
    tuple ``(sum of per-node clustering coefficients, average shortest
    path length)``.
    """
    graph = nx.read_gpickle(dirs)
    giant = max(nx.connected_components(graph), key=len)
    graph = graph.subgraph(giant)
    clustering_total = sum(nx.clustering(graph, node) for node in graph.nodes)
    # The path-length computation dominates runtime, so bracket it with
    # timestamps for progress monitoring.
    print("start")
    print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
    avg_path_len = nx.average_shortest_path_length(graph)
    print("end")
    print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
    return clustering_total, avg_path_len
def mul_f1(pop, steps, rootdir):
    """Run ``calc_f1`` over ``pop`` pickled graphs in parallel batches.

    Parameters
    ----------
    pop : int
        Number of candidate graphs, stored as ``<rootdir>/<i>.pkl`` for
        ``i`` in ``range(pop)``.
    steps : int
        Batch size, i.e. number of worker processes per batch.
    rootdir : str
        Directory containing the pickled graphs.

    Returns
    -------
    list
        ``calc_f1`` results in index order, one per graph.
    """
    result = []
    for start in range(0, pop, steps):
        p = Pool(steps)
        # Cap the batch at ``pop``: the original code always took a full
        # ``steps``-sized slice, so when ``pop % steps != 0`` the last batch
        # referenced pickle files that were never written. Using a distinct
        # inner name also avoids shadowing the outer loop variable.
        dirs = [os.path.join(rootdir, str(idx) + '.pkl')
                for idx in range(start, min(start + steps, pop))]
        ret = p.map(calc_f1, dirs)
        result.extend(ret)
        print(ret)
        p.close()
        p.join()
    return result
class Evolve(Problem):
    """Multi-objective evolution problem (pymop ``Problem``) over
    liquid-state-machine connectivity.

    Decision variables: a flattened ``liquid_size`` x ``liquid_size`` binary
    adjacency matrix per candidate. Objectives (both minimized): the negated
    normalized small-world coefficient and the distance of the criticality
    measure from 1. Two constraints bound the total connection count
    between ``args.low`` and ``args.up``.
    """
    # first define the NAS problem (inherit from pymop)
    def __init__(self, args,n_var=20, n_obj=1, n_constr=0, lb=None, ub=None):
        super().__init__(n_var=n_var, n_obj=n_obj, n_constr=n_constr, type_var=np.int64)
        self.xl = lb  # lower variable bounds
        self.xu = ub  # upper variable bounds
        self._n_evaluated = 0  # keep track of how many architectures are sampled
        self.args=args

    def _evaluate(self, x, out, *args, **kwargs):
        # One row of ``x`` per candidate; objectives/constraints start as NaN.
        objs = np.full((x.shape[0], self.n_obj), np.nan)
        g1 = np.full((x.shape[0]), np.nan)  # connections <= args.up  (g1 <= 0 feasible)
        g2 = np.full((x.shape[0]), np.nan)  # connections >= args.low (g2 <= 0 feasible)
        # NOTE(review): 'generaion' is a long-standing typo in the on-disk
        # directory name; kept as-is because the path is runtime behavior.
        gen_dir=os.path.join(self.args.output,'generaion'+str(kwargs['algorithm'].n_gen))
        os.makedirs(gen_dir,exist_ok = True)
        # np.save(os.path.join(gen_dir,"x.npy"),x)
        lsms = x.reshape(x.shape[0],self.args.liquid_size,self.args.liquid_size)
        # Dump every candidate graph so the worker pool can score it from disk.
        for i in range(x.shape[0]):
            temp_G = nx.Graph(lsms[i])
            nx.write_gpickle(temp_G, os.path.join(gen_dir,str(i)+".pkl"))
        # Small-world statistics per candidate: (clustering sum, avg path length).
        self.ob1=mul_f1(pop=x.shape[0],steps=10,rootdir=gen_dir)
        for i in range(x.shape[0]):
            arch_id = self._n_evaluated + 1
            print('\n')
            _logger.info('Network= {}'.format(arch_id))
            genome = x[i, :]
            g1[i]= genome.sum()-self.args.up
            g2[i]= self.args.low-genome.sum()
            lsmm = genome.reshape(self.args.liquid_size,self.args.liquid_size)
            small_coe_a,small_coe_b=self.ob1[i]
            lsmm=torch.tensor(lsmm,device='cuda:%d' % self.args.device).float()
            # Criticality measure of the liquid; objective is its distance from 1.
            crit = calc_f2(lsmm,'cuda:%d' % self.args.device)
            # NOTE(review): writing objs[:, 1] assumes n_obj >= 2 — the
            # __main__ block constructs this problem with n_obj=2, but the
            # default n_obj=1 in __init__ would fail here; verify callers.
            objs[i, 1] = abs(crit-1)
            # all objectives assume to be MINIMIZED !!!!!
            # Small-world coefficient normalized by the random-graph baselines
            # (arand, brand), negated for minimization.
            objs[i, 0] = -(small_coe_a/self.args.arand)/(small_coe_b/self.args.brand)
            _logger.info('small word= {}'.format(objs[i, 0]))
            _logger.info('criticality= {}'.format(objs[i, 1]))
            self._n_evaluated += 1
        out["F"] = objs
        out["G"] = np.column_stack([g1,g2])
        # if your NAS problem has constraints, use the following line to set constraints
        # out["G"] = np.column_stack([g1, g2, g3, g4, g5, g6]) in case 6 constraints
# ---------------------------------------------------------------------------------------------------------
# Define what statistics to print or save for each generation
# ---------------------------------------------------------------------------------------------------------
def do_every_generations(algorithm):
    """Per-generation pymoo callback.

    Logs population statistics for both objectives and, every 20
    generations, saves the best genome of each objective (with a
    timestamped filename) under ``/data/save/genome``.
    """
    gen = algorithm.n_gen
    pop_var = algorithm.pop.get("X")
    pop_obj = algorithm.pop.get("F")
    _logger.info("generation = {}".format(gen))
    # Log best/mean/median/worst for each objective column.
    for col, stats_msg, best_msg in (
            (0,
             "population error1: best = {}, mean = {}, "
             "median1 = {}, worst1 = {}",
             'Best1 Genome id= {}'),
            (1,
             "population error2: best = {}, mean = {}, "
             "median2 = {}, worst2 = {}",
             'Best2 Genome id= {}')):
        vals = pop_obj[:, col]
        _logger.info(stats_msg.format(np.min(vals), np.mean(vals),
                                      np.median(vals), np.max(vals)))
        _logger.info(best_msg.format(np.argmin(vals)))
    if gen % 20 == 0:
        # Checkpoint the best individual per objective.  The filename
        # encodes generation plus both objective values of that genome.
        for best_id in (np.argmin(pop_obj[:, 0]), np.argmin(pop_obj[:, 1])):
            tag = '-'.join([
                'gen' + str(gen),
                's' + str(float('%.4f' % pop_obj[best_id, 0])),
                'c' + str(float('%.4f' % pop_obj[best_id, 1])),
            ])
            np.save(os.path.join('/data/save/genome',
                                 tag + datetime.now().strftime("%Y%m%d-%H%M%S")),
                    pop_var[best_id])
if __name__ == '__main__':
    args = _parse_args()
    # Timestamped run directory so repeated runs never collide.
    out_base_dir= os.path.join(args.output, datetime.now().strftime("%Y%m%d-%H%M%S"))
    os.makedirs(out_base_dir,exist_ok = True)
    args.output=out_base_dir
    setup_default_logging(log_path=os.path.join(out_base_dir, 'log.txt'))
    # One binary decision variable per entry of the liquid adjacency matrix;
    # two objectives (small-worldness, criticality), two constraints
    # (connection-count bounds).
    kkk = Evolve(args,n_var=args.liquid_size*args.liquid_size,
                 n_obj=2, n_constr=2)
    method = engine.nsganet(pop_size=args.pop_size,
                            sampling=RandomSampling(var_type='custom'),
                            mutation=BinaryBitflipMutation(),
                            n_offsprings=args.n_offspring,
                            eliminate_duplicates=True)
    kres=minimize(kkk,
                  method,
                  callback=do_every_generations,
                  termination=('n_gen', args.n_gens))
================================================
FILE: examples/Structure_Evolution/ELSM/lsm.py
================================================
from __future__ import print_function
import torchvision
import torchvision.transforms as transforms
import os
import time
import numpy as np
import torch
from torch import nn as nn
from mnistmodel import *
from tqdm import tqdm
import argparse
from datetime import datetime
import logging
from timm.utils import *
from timm.loss import LabelSmoothingCrossEntropy
from braincog.base.utils import UnilateralMse, MixLoss
from braincog.base.learningrule.STDP import *
device='cuda:7'
def lr_scheduler(optimizer, epoch, init_lr=0.1, lr_decay_epoch=50):
    """Multiply every param group's learning rate by 0.1 once per
    ``lr_decay_epoch`` epochs (epochs 0 and 1 never trigger a decay).

    :param optimizer: torch optimizer whose ``param_groups`` are updated in place
    :param epoch: current epoch index
    :param init_lr: unused here; kept for call-site compatibility
    :param lr_decay_epoch: decay period in epochs
    :return: the same optimizer object
    """
    decay_due = epoch > 1 and epoch % lr_decay_epoch == 0
    if decay_due:
        for group in optimizer.param_groups:
            group['lr'] *= 0.1
    return optimizer
# ---- training configuration ----
batch_size=100
liquid_size=8000  # number of neurons in the liquid (reservoir)
learning_rate = 1e-3
num_epochs = 100 # max epoch
data_path = '/data'
# NOTE(review): load_path is empty — torch.load(load_path) below fails
# until a real checkpoint path is supplied.
load_path=''
# MNIST train/test loaders (download=False: data must already be present).
train_dataset = torchvision.datasets.MNIST(root=data_path, train=True, download=False, transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8)
test_set = torchvision.datasets.MNIST(root=data_path, train=False, download=False, transform=transforms.ToTensor())
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2)
# Rebuild the liquid-state machine and restore every weight tensor from the
# checkpoint; only the readout (fc) is trained below.
# NOTE(review): lsm_tau and lsm_th are not defined anywhere in this file —
# this raises NameError as written.  They presumably come from the evolved
# configuration (spikes.py uses lsm_tau=2.0, lsm_th=0.20); confirm.
snn = SNN(ins=784,
          batchsize=batch_size,
          device=device,
          liquid_size=liquid_size,
          lsm_tau=lsm_tau,
          lsm_th=lsm_th)
snn.load_state_dict(torch.load(load_path)['fc'])
snn.learning_rule=[]  # drop the rule built in __init__; rebuilt below with loaded weights
snn.con[0].load_state_dict(torch.load(load_path)['lsm0'])
w2tmp=nn.Linear(liquid_size,liquid_size,bias=False,device=device)
snn.connectivity_matrix=torch.load(load_path)['connectivity_matrix'].to(device)
# Recurrent weights masked by the evolved connectivity matrix.
w2tmp.weight.data=(torch.load(load_path)['liquid_weight'].to(device))*snn.connectivity_matrix
snn.learning_rule.append(MutliInputSTDP(snn.node_lsm(), [snn.con[0], w2tmp])) # pm
snn.eval()
snn.to(device)
class LabelSmoothingBCEWithLogitsLoss(nn.Module):
    """BCE-with-logits loss with label smoothing.

    The loss is a convex blend ``(1 - smoothing) * BCE(x, one_hot(target))
    + smoothing * BCE(x, uniform)`` where ``uniform`` gives every class
    probability ``1 / num_classes``.
    """

    def __init__(self, smoothing=0.1):
        """
        Constructor for the LabelSmoothing module.
        :param smoothing: label smoothing factor (must be < 1.0)
        """
        super(LabelSmoothingBCEWithLogitsLoss, self).__init__()
        assert smoothing < 1.0
        self.smoothing = smoothing
        self.confidence = 1. - smoothing
        self.BCELoss = nn.BCEWithLogitsLoss()

    def forward(self, x, target):
        num_classes = x.shape[-1]
        # One-hot encode the integer class targets on the logits' device.
        hard_target = torch.eye(num_classes, device=x.device)[target]
        # Uniform distribution used as the smoothing target.
        smooth_target = torch.ones_like(x) / num_classes
        hard_term = self.BCELoss(x, hard_target)
        smooth_term = self.BCELoss(x, smooth_target)
        return hard_term * self.confidence + smooth_term * self.smoothing
# Select the training criterion by short name (only 'mse' is active here).
ls = 'mse'
if ls == 'ce':
    criterion = nn.CrossEntropyLoss()
elif ls == 'bce':
    criterion = nn.BCEWithLogitsLoss()
elif ls == 'mse':
    criterion = UnilateralMse(1.)
elif ls == 'sce':
    criterion = LabelSmoothingCrossEntropy()
elif ls == 'sbce':
    criterion = LabelSmoothingBCEWithLogitsLoss()
elif ls == 'umse':
    criterion = UnilateralMse(.5)
# Only the readout layer is optimised; the liquid weights stay fixed.
optimizer = torch.optim.AdamW(snn.fc.parameters(),lr=0.001, weight_decay=1e-4)
l=[]         # per-epoch history of the best test accuracy so far
best_acc=0
for epoch in range(num_epochs):
    running_loss = 0
    start_time = time.time()
    # ---- train one epoch ----
    for i, (images, labels) in enumerate(tqdm(train_loader)):
        snn.zero_grad()
        optimizer.zero_grad()
        images = images.float().to(device)
        outputs = snn(images)
        labels=labels.to(device)
        loss = criterion(outputs, labels)
        running_loss += loss.item()
        loss.backward()
        optimizer.step()
        snn.reset()  # clear membrane state between batches
        if (i + 1) % 100 == 0:
            running_loss = 0
    # ---- evaluate on the test set ----
    correct = 0
    total = 0
    optimizer = lr_scheduler(optimizer, epoch, learning_rate, 40)
    # NOTE(review): evaluation runs without torch.no_grad(); gradients are
    # built and discarded each batch — wasteful but harmless.
    for batch_idx, (inputs, targets) in enumerate(test_loader):
        inputs = inputs.float().to(device)
        snn.zero_grad()
        optimizer.zero_grad()
        outputs = snn(inputs)
        targets=targets.to(device)
        loss = criterion(outputs, targets)
        _, predicted = outputs.max(1)
        total += float(targets.size(0))
        correct += float(predicted.eq(targets).sum().item())
        snn.reset()
        if batch_idx % 100 == 0:
            acc = 100. * float(correct) / float(total)
            print(batch_idx, len(test_loader), ' Acc: %.5f' % acc)
    print('Test Accuracy: %.3f' % (100 * correct / total))
    acc = 100. * float(correct) / float(total)
    if best_acc < acc:
        best_acc = acc
        print(best_acc)
    l.append(best_acc)
================================================
FILE: examples/Structure_Evolution/ELSM/model.py
================================================
from functools import partial
from torch.nn import functional as F
from torch import nn as nn
import torchvision, pprint
from copy import deepcopy
from timm.models import register_model
from braincog.base.node.node import *
from braincog.base.connection.layer import *
from braincog.base.encoder.encoder import *
from braincog.model_zoo.base_module import BaseModule, BaseConvModule, BaseLinearModule
from braincog.base.brainarea.BrainArea import BrainArea
from braincog.base.connection.CustomLinear import *
from braincog.base.learningrule.STDP import *
import matplotlib.pyplot as plt
@register_model
class nSNN(BaseModule):
    """Liquid-state-machine SNN with an evolved recurrent topology.

    Inputs are projected into a recurrent 'liquid' whose recurrent weights
    are masked by ``connectivity_matrix``; the liquid state is updated by a
    multi-input STDP rule and a linear + spiking readout (``fc``) produces
    the time-averaged class output.
    """

    def __init__(self,
                 batchsize,
                 liquid_size,
                 device,
                 connectivity_matrix,
                 num_classes=10,
                 step=1,
                 node_type=LIFNode,
                 encode_type='direct',
                 lsm_th=0.3,
                 fc_th=0.3,
                 lsm_tau=3,
                 fc_tau=3,
                 ins=1156,
                 *args,
                 **kwargs):
        super().__init__(step, encode_type, *args, **kwargs)
        self.batchsize=batchsize
        self.ins=ins
        # Node factories: same neuron type, separate tau/threshold for the
        # liquid and the readout.
        self.node_lsm=partial(node_type, **kwargs, step=step,tau=lsm_tau,threshold=lsm_th)
        self.node_fc = partial(node_type, **kwargs, step=step,tau=fc_tau,threshold=fc_th)
        self.liquid_size=liquid_size
        self.device=device
        # NOTE(review): ``con`` is a plain Python list, so these Linear
        # layers are NOT registered submodules (excluded from
        # state_dict()/parameters()); callers load their weights explicitly.
        self.con=[]
        self.learning_rule=[]
        self.connectivity_matrix=connectivity_matrix
        w1tmp=nn.Linear(ins,liquid_size,bias=False).to(device)          # input -> liquid
        self.con.append(w1tmp)
        w2tmp=nn.Linear(liquid_size,liquid_size,bias=False).to(device)  # liquid recurrence
        # Keep the unmasked recurrent weights; checkpoints save/restore them
        # under the key 'liquid_weight' (see lsm.py / spikes.py).
        self.liquid_weight=w2tmp.weight.data
        w2tmp.weight.data=w2tmp.weight.data*self.connectivity_matrix    # apply evolved mask
        self.con.append(w2tmp)
        self.learning_rule.append(MutliInputSTDP(self.node_lsm(), [self.con[0], self.con[1]])) # pm
        self.fc = nn.Sequential(
            nn.Linear(liquid_size,num_classes),
            self.node_fc()
        )

    def forward(self, x):
        """Run the liquid over the time dimension of ``x``.

        :param x: input with time on dim 1 — assumed (batch, T, ins);
            TODO confirm against callers.
        :return: (batch, num_classes) readout spikes averaged over T.
        """
        sum_spike=0
        self.out = torch.zeros(x.shape[0], self.liquid_size).to(self.device)
        tw=x.shape[1]
        self.tw=tw
        # Per-timestep liquid activity, kept for post-hoc analysis.
        self.firing_tw=torch.zeros(tw, self.batchsize, self.liquid_size).to(self.device)
        for t in range(tw):
            # STDP rule returns the new liquid state and the weight update.
            self.out, self.dw = self.learning_rule[0](x[:,t,:], self.out)
            out_liquid=self.out[:,0:self.liquid_size]  # no-op slice unless out is wider
            xout = self.fc(out_liquid)
            sum_spike=sum_spike+xout
            self.firing_tw[t]=out_liquid
        outputs = sum_spike / tw
        return outputs
@register_model
class mSNN(BaseModule):
    """MNIST variant of the liquid-state-machine SNN.

    Flattens each image to 784 inputs and presents it (direct encoding)
    for a fixed window of 20 timesteps while the liquid evolves via STDP.
    """

    def __init__(self,
                 batchsize,
                 liquid_size,
                 device,
                 connectivity_matrix,
                 num_classes=10,
                 step=1,
                 node_type=LIFNode,
                 encode_type='direct',
                 lsm_th=0.3,
                 fc_th=0.3,
                 lsm_tau=3,
                 fc_tau=3,
                 tw=100,
                 *args,
                 **kwargs):
        # NOTE(review): the ``tw`` parameter is accepted but never stored or
        # used; forward() hard-codes time_window=20 — confirm intent.
        super().__init__(step, encode_type, *args, **kwargs)
        self.batchsize=batchsize
        # Node factories: same neuron type, separate tau/threshold for the
        # liquid and the readout.
        self.node_lsm=partial(node_type, **kwargs, step=step,tau=lsm_tau,threshold=lsm_th)
        self.node_fc = partial(node_type, **kwargs, step=step,tau=fc_tau,threshold=fc_th)
        self.liquid_size=liquid_size
        self.out = torch.zeros(self.batchsize, liquid_size).to(device)
        self.device=device
        # NOTE(review): plain list — these Linears are not registered
        # submodules; their weights are managed explicitly by callers.
        self.con=[]
        self.learning_rule=[]
        self.connectivity_matrix=connectivity_matrix
        w1tmp=nn.Linear(784,liquid_size,bias=False).to(device)          # input -> liquid
        self.con.append(w1tmp)
        w2tmp=nn.Linear(liquid_size,liquid_size,bias=False).to(device)  # liquid recurrence
        self.liquid_weight=w2tmp.weight.data  # unmasked copy kept for checkpointing
        w2tmp.weight.data=w2tmp.weight.data*self.connectivity_matrix    # apply evolved mask
        self.con.append(w2tmp)
        self.learning_rule.append(MutliInputSTDP(self.node_lsm(), [self.con[0], self.con[1]])) # pm
        self.fc = nn.Sequential(
            nn.Linear(liquid_size,num_classes),
            self.node_fc()
        )

    def forward(self, x):
        """Present the flattened image for 20 timesteps.

        :param x: (batch, ...) image tensor, flattened to (batch, 784)
        :return: (batch, num_classes) readout averaged over the window
        """
        x = x.reshape(x.shape[0], -1)
        sum_spike=0
        time_window=20
        self.tw=time_window
        # Per-timestep liquid activity, kept for post-hoc analysis.
        self.firing_tw=torch.zeros(time_window, self.batchsize, self.liquid_size).to(self.device)
        self.out = torch.zeros(self.batchsize, self.liquid_size).to(self.device)
        for t in range(time_window):
            # Same (static) input every step; the liquid state carries the dynamics.
            self.out, self.dw = self.learning_rule[0](x, self.out)
            out_liquid=self.out[:,0:self.liquid_size]  # no-op slice unless out is wider
            xout = self.fc(out_liquid)
            sum_spike=sum_spike+xout
            self.firing_tw[t]=out_liquid
            # print(out_liquid.sum())
            # print(xout.sum())
        outputs = sum_spike / time_window
        return outputs
================================================
FILE: examples/Structure_Evolution/ELSM/nsganet.py
================================================
import numpy as np
from pymoo.algorithms.genetic_algorithm import GeneticAlgorithm
from pymoo.docs import parse_doc_string
from pymoo.model.individual import Individual
from pymoo.model.survival import Survival
from pymoo.operators.crossover.point_crossover import PointCrossover
from pymoo.operators.mutation.polynomial_mutation import PolynomialMutation
from pymoo.operators.mutation.bitflip_mutation import BinaryBitflipMutation
from pymoo.operators.sampling.random_sampling import RandomSampling
from pymoo.operators.selection.tournament_selection import compare, TournamentSelection
from pymoo.util.display import disp_multi_objective
from pymoo.util.dominator import Dominator
from pymoo.util.non_dominated_sorting import NonDominatedSorting
from pymoo.util.randomized_argsort import randomized_argsort
# =========================================================================================================
# Implementation
# based on nsga2 from https://github.com/msu-coinlab/pymoo
# =========================================================================================================
class NSGANet(GeneticAlgorithm):
    """NSGA-II style genetic algorithm (based on msu-coinlab/pymoo's nsga2).

    Individuals carry ``rank`` and ``crowding`` attributes that are filled
    by RankAndCrowdingSurvival and consumed by binary_tournament.
    """

    def __init__(self, **kwargs):
        # Template individual with the attributes the operators below expect.
        kwargs['individual'] = Individual(rank=np.inf, crowding=-1)
        super().__init__(**kwargs)
        self.tournament_type = 'comp_by_dom_and_crowding'
        self.func_display_attrs = disp_multi_objective
# ---------------------------------------------------------------------------------------------------------
# Binary Tournament Selection Function
# ---------------------------------------------------------------------------------------------------------
def binary_tournament(pop, P, algorithm, **kwargs):
    """Binary tournament selection for NSGA-Net.

    Infeasible competitors are compared by constraint violation; feasible
    ones by Pareto dominance (or rank), with crowding distance as the
    tie-breaker.

    :param pop: current population (pymoo individuals with CV/F/rank/crowding)
    :param P: (n_tournaments, 2) matrix of competitor indices
    :param algorithm: running algorithm; provides ``tournament_type``
    :return: (n_tournaments, 1) integer array of winning indices
    :raises ValueError: if P does not describe binary tournaments
    """
    if P.shape[1] != 2:
        raise ValueError("Only implemented for binary tournament!")
    tournament_type = algorithm.tournament_type
    S = np.full(P.shape[0], np.nan)
    for i in range(P.shape[0]):
        a, b = P[i, 0], P[i, 1]
        # if at least one solution is infeasible, prefer lower constraint violation
        if pop[a].CV > 0.0 or pop[b].CV > 0.0:
            S[i] = compare(a, pop[a].CV, b, pop[b].CV, method='smaller_is_better', return_random_if_equal=True)
        # both solutions are feasible
        else:
            if tournament_type == 'comp_by_dom_and_crowding':
                rel = Dominator.get_relation(pop[a].F, pop[b].F)
                if rel == 1:
                    S[i] = a
                elif rel == -1:
                    S[i] = b
            elif tournament_type == 'comp_by_rank_and_crowding':
                S[i] = compare(a, pop[a].rank, b, pop[b].rank,
                               method='smaller_is_better')
            else:
                raise Exception("Unknown tournament type.")
            # if rank or domination relation didn't make a decision compare by crowding
            if np.isnan(S[i]):
                S[i] = compare(a, pop[a].get("crowding"), b, pop[b].get("crowding"),
                               method='larger_is_better', return_random_if_equal=True)
    # FIX: np.int was deprecated in NumPy 1.20 and removed in 1.24; the
    # builtin ``int`` is the exact type it aliased.
    return S[:, None].astype(int)
# ---------------------------------------------------------------------------------------------------------
# Survival Selection
# ---------------------------------------------------------------------------------------------------------
class RankAndCrowdingSurvival(Survival):
    """NSGA-II survival: non-dominated rank first, crowding distance second."""

    def __init__(self) -> None:
        super().__init__(True)

    def _do(self, pop, n_survive, D=None, **kwargs):
        objectives = pop.get("F")
        survivors = []
        # Non-dominated sorting, stopping once enough solutions are ranked.
        fronts = NonDominatedSorting().do(objectives, n_stop_if_ranked=n_survive)
        for rank, front in enumerate(fronts):
            # Crowding distance within this front.
            crowding = calc_crowding_distance(objectives[front, :])
            # Annotate individuals for later tournament selection.
            for pos, idx in enumerate(front):
                pop[idx].set("rank", rank)
                pop[idx].set("crowding", crowding[pos])
            remaining = n_survive - len(survivors)
            if len(front) > remaining:
                # Splitting front: keep the most spread-out individuals.
                keep = randomized_argsort(crowding, order='descending',
                                          method='numpy')[:remaining]
            else:
                # The whole front fits.
                keep = np.arange(len(front))
            survivors.extend(front[keep])
        return pop[survivors]
def calc_crowding_distance(F):
    """Compute the NSGA-II crowding distance for each row of ``F``.

    :param F: (n_points, n_obj) objective matrix
    :return: (n_points,) crowding distances; boundary points get 1e14
    """
    INF = 1e+14
    n_points, n_obj = F.shape[0], F.shape[1]
    # With two or fewer points every point is a boundary point.
    if n_points <= 2:
        return np.full(n_points, INF)
    # Sort each objective column independently (stable sort) and keep the
    # permutation so results can be mapped back to the original order.
    order = np.argsort(F, axis=0, kind='mergesort')
    F = F[order, np.arange(n_obj)]
    # Gaps between consecutive sorted values, padded with +/- infinity so
    # the extremes pick up an infinite distance on one side.
    dist = np.concatenate([F, np.full((1, n_obj), np.inf)]) \
        - np.concatenate([np.full((1, n_obj), -np.inf), F])
    # Duplicated values produce zero gaps: propagate the neighbouring gap
    # downward for the "last" distances and upward for the "next" ones.
    zero_idx = np.where(dist == 0)
    dist_to_last = np.copy(dist)
    for r, c in zip(*zero_idx):
        dist_to_last[r, c] = dist_to_last[r - 1, c]
    dist_to_next = np.copy(dist)
    for r, c in reversed(list(zip(*zero_idx))):
        dist_to_next[r, c] = dist_to_next[r + 1, c]
    # Normalize by each objective's range; a zero range becomes NaN and is
    # zeroed out below (all values equal contribute nothing).
    span = np.max(F, axis=0) - np.min(F, axis=0)
    span[span == 0] = np.nan
    dist_to_last = dist_to_last[:-1] / span
    dist_to_next = dist_to_next[1:] / span
    dist_to_last[np.isnan(dist_to_last)] = 0.0
    dist_to_next[np.isnan(dist_to_next)] = 0.0
    # Undo the per-column sort, average both sides over the objectives.
    back = np.argsort(order, axis=0)
    crowding = np.sum(dist_to_last[back, np.arange(n_obj)]
                      + dist_to_next[back, np.arange(n_obj)], axis=1) / n_obj
    # Boundary points carry infinity; clamp to a large finite number.
    crowding[np.isinf(crowding)] = INF
    return crowding
# =========================================================================================================
# Interface
# =========================================================================================================
def nsganet(
        pop_size=100,
        # FIX: np.int was deprecated in NumPy 1.20 and removed in 1.24,
        # which made these defaults raise AttributeError at import time;
        # the builtin ``int`` is the exact type np.int aliased.
        sampling=RandomSampling(var_type=int),
        selection=TournamentSelection(func_comp=binary_tournament),
        crossover=PointCrossover(n_points=2),
        mutation=PolynomialMutation(eta=3, var_type=int),
        eliminate_duplicates=True,
        n_offsprings=None,
        **kwargs):
    """

    Parameters
    ----------
    pop_size : {pop_size}
    sampling : {sampling}
    selection : {selection}
    crossover : {crossover}
    mutation : {mutation}
    eliminate_duplicates : {eliminate_duplicates}
    n_offsprings : {n_offsprings}

    Returns
    -------
    nsganet : :class:`~pymoo.model.algorithm.Algorithm`
        Returns an NSGANet algorithm object.
    """
    return NSGANet(pop_size=pop_size,
                   sampling=sampling,
                   selection=selection,
                   crossover=crossover,
                   mutation=mutation,
                   survival=RankAndCrowdingSurvival(),
                   eliminate_duplicates=eliminate_duplicates,
                   n_offsprings=n_offsprings,
                   **kwargs)
parse_doc_string(nsganet)
================================================
FILE: examples/Structure_Evolution/ELSM/spikes.py
================================================
from __future__ import print_function
import torchvision
import torchvision.transforms as transforms
import os
import numpy as np
import torch
from torch import nn as nn
from model import *
from tqdm import tqdm
import argparse
from datetime import datetime
import logging
from timm.utils import *
from spikingjelly.datasets.n_mnist import NMNIST
from timm.loss import LabelSmoothingCrossEntropy
from braincog.base.utils.criterions import *
import networkx as nx
import time
from braincog.base.learningrule.STDP import *
def randbool(size, p=0.5):
    """Return a boolean tensor of shape ``size`` whose entries are
    independently True with probability ``p``.

    :param size: sequence of ints, unpacked into ``torch.rand``
    :param p: probability of a True entry
    """
    sample = torch.rand(*size)
    return torch.lt(sample, p)
def calc_f2(con,device):
batch_size=1
liquid_size=8000
images=torch.load('/1000images.pt')
labels=torch.load('/1000labels.pt')
load_path='970.t7'
snn = nSNN(ins=2312,
batchsize=batch_size,
device=device,
liquid_size=liquid_size,
lsm_tau=2.0,
lsm_th=0.20,
connectivity_matrix=randbool([liquid_size, liquid_size],p=0.01).to(device).int())
snn.load_state_dict(torch.load(load_path,map_location={'cuda:2':device})['fc'])
snn.con[0].load_state_dict(torch.load(load_path,map_location={'cuda:2':device})['lsm0'])
snn.to(device)
criterion = UnilateralMse(1.)
optimizer = torch.optim.AdamW(snn.fc.parameters(),lr=0.001, weight_decay=1e-4)
k=0
sbr=0
snn.connectivity_matrix=con
snn.learning_rule=[]
w2tmp=nn.Linear(liquid_size,liquid_size,bias=False,device=device)
w2tmp.weight.data=(torch.load(load_path,map_location={'cuda:2':device})['liquid_weight'])*snn.connectivity_matrix
snn.learning_rule.append(MutliInputSTDP(snn.node_lsm(), [snn.con[0], w2tmp]))
snn.eval()
for label,data in zip(labels,images):
running_loss = 0
snn.zero_grad()
optimizer.zero_grad()
data = data.to(device)
label = label.to(device)
data=data.reshape(batch_size,data.shape[0],-1)
output = snn(data)
# print(torch.argmax(output)==label)
out_liquid=snn.firing_tw.squeeze(-2)
mupost=torch.matmul(con,out_liquid.unsqueeze(-1))
mupre=torch.matmul(con.t(),out_liquid.unsqueeze(-1))
for t in range(snn.tw):
if t>5 and t 0.:
if not isinstance(op1, Identity):
h1 = drop_path(h1, drop_prob)
if not isinstance(op2, Identity):
h2 = drop_path(h2, drop_prob)
s = h1 + h2
if self.back_connection:
if i != 0:
s_back = self._ops_back[i - 1](s)
states[self._indices_backward[i - 1]] = states[self._indices_backward[i - 1]] + s_back
states += [s]
outputs = torch.cat([states[i]
for i in self._concat], dim=1) # N,C,H, W
return outputs
# return self.node(outputs)
class EvoCell3(nn.Module):
    """Evolved NAS cell that consumes THREE predecessor feature maps.

    A reduction cell degenerates to a single FactorizedReduce of the most
    recent input; a normal cell preprocesses each predecessor to a common
    channel count C, then applies the evolved ops listed in ``motif.normal``
    (three ops/inputs per step), with optional backward ('_back') ops that
    feed intermediate results back into earlier states.
    """

    def __init__(self,motif, C_prev_prev_prev, C_prev_prev, C_prev, C, reduction, reduction_prev, reduction_prev_prev, act_fun):
        # print(C_prev_prev_prev,C_prev_prev, C_prev, C, reduction,reduction_prev, reduction_prev_prev)
        super(EvoCell3, self).__init__()
        self.act_fun = act_fun
        self.reduction = reduction
        self.motif=motif
        self.back_connection=False
        if reduction:
            # Reduction cell: downsample the last input, tripling channels.
            self.fun = FactorizedReduce(C_prev, C * 3, act_fun=act_fun)
            self.multiplier = 3
        else:
            # Preprocess each predecessor so all three match (C, current
            # resolution); the operator depends on how many reductions sit
            # between that predecessor and this cell (0: 1x1 conv, 1:
            # FactorizedReduce, 2: F0 — presumably a double reduce; confirm).
            if reduction_prev:
                self.preprocess1 = FactorizedReduce(C_prev_prev, C, act_fun=act_fun)
            else:
                self.preprocess1 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, act_fun=act_fun)
            if int(reduction_prev_prev)+int(reduction_prev)==1:
                self.preprocess0 = FactorizedReduce(C_prev_prev_prev, C, act_fun=act_fun)
            elif int(reduction_prev_prev)+int(reduction_prev)==2:
                self.preprocess0 = F0(C_prev_prev_prev, C, act_fun=act_fun)
            else:
                self.preprocess0 = ReLUConvBN(C_prev_prev_prev, C, 1, 1, 0, act_fun=act_fun)
            self.preprocess2 = ReLUConvBN(C_prev, C, 1, 1, 0, act_fun=act_fun)
            op_names, indices = zip(*motif.normal)
            concat = motif.normal_concat
            self._compile(C, op_names, indices, concat, reduction)

    def _compile(self, C, op_names, indices, concat, reduction):
        """Instantiate ops from the motif; names containing '_back' mark the
        start of the backward-connection op list."""
        assert len(op_names) == len(indices)
        # self._steps = len(op_names) // 2
        self._concat = concat
        self.multiplier = len(concat)
        self._ops = nn.ModuleList()
        self._ops_back = nn.ModuleList()
        back_begin_index = 0
        for i, (name, index) in enumerate(zip(op_names, indices)):
            # print(name, index)
            if '_back' in name:
                # Everything from here on is a backward connection.
                self.back_connection=True
                back_begin_index = i
                break
            stride = 2 if reduction and index < 2 else 1
            op = OPS[name](C, stride, True, act_fun=self.act_fun)
            self._ops += [op]
        if self.back_connection:
            # Backward ops reuse the base op (suffix stripped), stride 1.
            for name, index in zip(op_names[back_begin_index:], indices[back_begin_index:]):
                op = OPS[name.replace('_back', '')](
                    C, 1, True, act_fun=self.act_fun)
                self._ops_back += [op]
        if self.back_connection:
            self._indices_forward = indices[:back_begin_index]
            self._indices_backward = indices[back_begin_index:]
        else:
            self._indices_backward = []
            self._indices_forward = indices
        # Three inputs per intermediate node.
        self._steps = len(self._indices_forward) // 3

    def forward(self, s0, s1, s2, drop_prob):
        """Apply the cell; reduction cells shortcut to a single reduce of s2."""
        if self.reduction:
            return self.fun(s2)
        s0 = self.preprocess0(s0)
        s1 = self.preprocess1(s1)
        s2 = self.preprocess2(s2)
        states = [s0, s1, s2]
        for i in range(self._steps):
            # Each step mixes three (op, state-index) pairs.
            i1=self._indices_forward[3 * i]
            i2=self._indices_forward[3 * i + 1]
            i3=self._indices_forward[3 * i + 2]
            h1 = states[i1]
            h2 = states[i2]
            h3 = states[i3]
            op1 = self._ops[3 * i]
            op2 = self._ops[3 * i + 1]
            op3 = self._ops[3 * i + 2]
            h1 = op1(h1)
            h2 = op2(h2)
            h3 = op3(h3)
            if self.training and drop_prob > 0.:
                # DropPath regularization, skipping identity branches.
                if not isinstance(op1, Identity):
                    h1 = drop_path(h1, drop_prob)
                if not isinstance(op2, Identity):
                    h2 = drop_path(h2, drop_prob)
                if not isinstance(op3, Identity):
                    h3 = drop_path(h3, drop_prob)
            s = h1 + h2 + h3
            if self.back_connection:
                # Feed the new state back into an earlier state in-place.
                if i != 0:
                    s_back = self._ops_back[i - 1](s)
                    states[self._indices_backward[i - 1]] = states[self._indices_backward[i - 1]] + s_back
            states += [s]
        outputs = torch.cat([states[i] for i in self._concat], dim=1)  # N,C,H, W
        return outputs
        # return self.node(outputs)
class EvoCell4(nn.Module):
    """Evolved NAS cell that consumes FOUR predecessor feature maps.

    Same scheme as EvoCell3 but with four inputs per step: a reduction cell
    is a single FactorizedReduce of the most recent input, otherwise the
    four predecessors are preprocessed to channel count C (with the
    operator chosen by how many reductions occurred in between) and the
    evolved ops from ``motif.normal`` are applied, optionally with backward
    ('_back') connections.
    """

    def __init__(self,motif, C_prev_prev_prev_prev,C_prev_prev_prev, C_prev_prev, C_prev, C, reduction, reduction_prev, reduction_prev_prev,reduction_prev_prev_prev, act_fun):
        # print(C_prev_prev_prev_prev,C_prev_prev_prev,C_prev_prev, C_prev, C, reduction,reduction_prev, reduction_prev_prev,reduction_prev_prev_prev)
        super(EvoCell4, self).__init__()
        self.act_fun = act_fun
        self.reduction = reduction
        self.motif=motif
        self.back_connection=False
        if reduction:
            # Reduction cell: downsample the last input, tripling channels.
            self.fun = FactorizedReduce(C_prev, C * 3, act_fun=act_fun)
            self.multiplier = 3
        else:
            # Choose each preprocess by the number of intervening reductions
            # (0: 1x1 conv, 1: FactorizedReduce, 2: F0, 3: F1 — presumably
            # double/triple reduces; confirm in connection.layer).
            if reduction_prev:
                self.preprocess2 = FactorizedReduce(C_prev_prev, C, act_fun=act_fun)
            else:
                self.preprocess2 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, act_fun=act_fun)
            if int(reduction_prev_prev)+int(reduction_prev)==1:
                self.preprocess1 = FactorizedReduce(C_prev_prev_prev, C, act_fun=act_fun)
            elif int(reduction_prev_prev)+int(reduction_prev)==2:
                self.preprocess1 = F0(C_prev_prev_prev, C, act_fun=act_fun)
            else:
                self.preprocess1 = ReLUConvBN(C_prev_prev_prev, C, 1, 1, 0, act_fun=act_fun)
            if int(reduction_prev_prev_prev)+int(reduction_prev_prev)+int(reduction_prev)==1:
                self.preprocess0 = FactorizedReduce(C_prev_prev_prev_prev, C, act_fun=act_fun)
            elif int(reduction_prev_prev_prev)+int(reduction_prev_prev)+int(reduction_prev)==2:
                self.preprocess0 = F0(C_prev_prev_prev_prev, C, act_fun=act_fun)
            elif int(reduction_prev_prev_prev)+int(reduction_prev_prev)+int(reduction_prev)==3:
                self.preprocess0 = F1(C_prev_prev_prev_prev, C, act_fun=act_fun)
            else:
                self.preprocess0 = ReLUConvBN(C_prev_prev_prev_prev, C, 1, 1, 0, act_fun=act_fun)
            self.preprocess3 = ReLUConvBN(C_prev, C, 1, 1, 0, act_fun=act_fun)
            op_names, indices = zip(*motif.normal)
            # print(self.preprocess0)
            # print(self.preprocess1)
            # print(self.preprocess2)
            # print(self.preprocess3)
            concat = motif.normal_concat
            self._compile(C, op_names, indices, concat, reduction)

    def _compile(self, C, op_names, indices, concat, reduction):
        """Instantiate ops from the motif; names containing '_back' mark the
        start of the backward-connection op list."""
        assert len(op_names) == len(indices)
        # self._steps = len(op_names) // 2
        self._concat = concat
        self.multiplier = len(concat)
        self._ops = nn.ModuleList()
        self._ops_back = nn.ModuleList()
        back_begin_index = 0
        for i, (name, index) in enumerate(zip(op_names, indices)):
            # print(name, index)
            if '_back' in name:
                # Everything from here on is a backward connection.
                self.back_connection=True
                back_begin_index = i
                break
            stride = 2 if reduction and index < 2 else 1
            op = OPS[name](C, stride, True, act_fun=self.act_fun)
            self._ops += [op]
        if self.back_connection:
            # Backward ops reuse the base op (suffix stripped), stride 1.
            for name, index in zip(op_names[back_begin_index:], indices[back_begin_index:]):
                op = OPS[name.replace('_back', '')](
                    C, 1, True, act_fun=self.act_fun)
                self._ops_back += [op]
        if self.back_connection:
            self._indices_forward = indices[:back_begin_index]
            self._indices_backward = indices[back_begin_index:]
        else:
            self._indices_backward = []
            self._indices_forward = indices
        # Four inputs per intermediate node.
        self._steps = len(self._indices_forward) // 4

    def forward(self, s0, s1, s2, s3, drop_prob):
        """Apply the cell; reduction cells shortcut to a single reduce of s3."""
        if self.reduction:
            return self.fun(s3)
        s0 = self.preprocess0(s0)
        s3 = self.preprocess3(s3)
        s1 = self.preprocess1(s1)
        s2 = self.preprocess2(s2)
        # if s1.shape[1]!=s3.shape[1]:
        #     s1 = nn.Conv2d(s1.shape[1], s3.shape[1], 3, stride=2, padding=1, bias=False)
        states = [s0, s1, s2,s3]
        for i in range(self._steps):
            # Each step mixes four (op, state-index) pairs.
            i1=self._indices_forward[4 * i]
            i2=self._indices_forward[4 * i + 1]
            i3=self._indices_forward[4 * i + 2]
            i4=self._indices_forward[4 * i + 3]
            h1 = states[i1]
            h2 = states[i2]
            h3 = states[i3]
            h4 = states[i4]
            op1 = self._ops[4 * i]
            op2 = self._ops[4 * i + 1]
            op3 = self._ops[4 * i + 2]
            op4 = self._ops[4 * i + 3]
            h1 = op1(h1)
            h2 = op2(h2)
            h3 = op3(h3)
            h4 = op4(h4)
            if self.training and drop_prob > 0.:
                # DropPath regularization, skipping identity branches.
                if not isinstance(op1, Identity):
                    h1 = drop_path(h1, drop_prob)
                if not isinstance(op2, Identity):
                    h2 = drop_path(h2, drop_prob)
                if not isinstance(op3, Identity):
                    h3 = drop_path(h3, drop_prob)
                if not isinstance(op4, Identity):
                    h4= drop_path(h4, drop_prob)
            s = h1 + h2 + h3 + h4
            if self.back_connection:
                # Feed the new state back into an earlier state in-place.
                if i != 0:
                    s_back = self._ops_back[i - 1](s)
                    states[self._indices_backward[i - 1]] = states[self._indices_backward[i - 1]] + s_back
            states += [s]
        outputs = torch.cat([states[i] for i in self._concat], dim=1)  # N,C,H, W
        return outputs
        # return self.node(outputs)
@register_model
class NetworkCIFAR(BaseModule):
    """CIFAR-scale spiking network assembled from evolved cells.

    Builds a stem (dataset-dependent), a stack of EvoCell2/3/4 cells chosen
    by ``cell_type``, global pooling and a (optionally spiking + voting)
    classifier.  The forward pass runs either one pass per timestep
    (``self.step``) or in layer-by-layer mode with the time dimension folded
    into the batch.
    """

    def __init__(self,
                 C,
                 num_classes,
                 layers,
                 auxiliary,
                 motif,
                 cell_type,
                 parse_method='darts',
                 step=5,
                 node_type='ReLUNode',
                 **kwargs):
        """
        :param C: base channel count
        :param num_classes: number of output classes
        :param layers: number of cells
        :param auxiliary: whether to use the auxiliary head (training only)
        :param motif: per-cell evolved genotypes
        :param cell_type: 2, 3 or 4 — how many predecessors each cell sees
        :param step: number of simulation timesteps
        :param node_type: neuron class name (evaluated) or class
        """
        super(NetworkCIFAR, self).__init__(
            step=step,
            num_classes=num_classes,
            **kwargs
        )
        self.node_type=node_type
        # node_type may arrive as a string (from CLI/config) or a class.
        if isinstance(node_type, str):
            self.act_fun = eval(node_type)
        else:
            self.act_fun = node_type
        self.act_fun = partial(self.act_fun, **kwargs)
        self.spike_output = kwargs['spike_output'] if 'spike_output' in kwargs else True
        self.dataset = kwargs['dataset']
        # layer_by_layer folds time into the batch dim; flatten must then
        # keep dim 0 intact.
        if self.layer_by_layer:
            self.flatten = nn.Flatten(start_dim=1)
        else:
            self.flatten = nn.Flatten()
        self._layers = layers
        self.cell_type = cell_type
        self._auxiliary = auxiliary
        self.drop_path_prob = 0
        stem_multiplier = 3
        C_curr = stem_multiplier * C
        if self.dataset == 'dvsg' or self.dataset == 'dvsc10' or self.dataset == 'NCALTECH101':
            # Event-camera datasets: 2 polarity channels (times
            # init_channel_mul — presumably set by BaseModule; confirm).
            self.stem = nn.Sequential(
                nn.Conv2d(2 * self.init_channel_mul, C_curr, 3, padding=1, bias=False),
                nn.BatchNorm2d(C_curr),
            )
            # self.reduce_idx = [
            #     layers // 4,
            #     layers // 2,
            #     3 * layers // 4
            # ]
            self.reduce_idx = [1, 3, 5, 7]  # fixed reduction positions for DVS inputs
        else:
            # Static grayscale images: 1 input channel expanded in two convs.
            self.stem = nn.Sequential(
                nn.Conv2d(1, 3 * self.init_channel_mul, 3, padding=1, bias=False),
                nn.Conv2d(3 * self.init_channel_mul, C_curr, 3, padding=1, bias=False),
                nn.BatchNorm2d(C_curr),
            )
            self.reduce_idx = [layers // 4,
                               layers // 2,
                               3 * layers // 4]
        # Channel/reduction bookkeeping for up to four predecessor cells.
        C_prev_prev_prev = C_curr
        C_prev_prev_prev_prev = C_curr
        C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
        self.cells = nn.ModuleList()
        reduction_prev = False
        reduction_prev_prev = False
        reduction_prev_prev_prev = False
        for i in range(layers):
            if i in self.reduce_idx:
                C_curr *= 2
                reduction = True
            else:
                reduction = False
            if cell_type==2:
                # print(C_prev_prev, C_prev, C_curr)
                cell = EvoCell2(motif[i], C_prev_prev, C_prev, C_curr,reduction, reduction_prev,act_fun=self.act_fun)
                self.cells += [cell]
                C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr
            if cell_type==3:
                cell = EvoCell3(motif[i], C_prev_prev_prev, C_prev_prev, C_prev, C_curr,reduction, reduction_prev,reduction_prev_prev,act_fun=self.act_fun)
                self.cells += [cell]
                # Shift the 3-deep channel/reduction history.
                C_prev_prev_prev = C_prev_prev
                reduction_prev_prev = reduction_prev
                C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr
            if cell_type==4:
                cell = EvoCell4(motif[i], C_prev_prev_prev_prev,C_prev_prev_prev, C_prev_prev, C_prev, C_curr,reduction, reduction_prev,reduction_prev_prev,reduction_prev_prev_prev,act_fun=self.act_fun)
                self.cells += [cell]
                # Shift the 4-deep channel/reduction history.
                C_prev_prev_prev_prev = C_prev_prev_prev
                C_prev_prev_prev = C_prev_prev
                reduction_prev_prev_prev = reduction_prev_prev
                reduction_prev_prev = reduction_prev
                C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr
            reduction_prev = reduction
        self.global_pooling = nn.Sequential(
            self.act_fun(), nn.AdaptiveAvgPool2d(1))
        if self.spike_output:
            # Spiking head: 10 spiking units per class, merged by voting.
            self.classifier = nn.Sequential(
                nn.Linear(C_prev, 10 * num_classes),
                self.act_fun())
            self.vote = VotingLayer(10)
        else:
            self.classifier = nn.Linear(C_prev, num_classes)
            self.vote = nn.Identity()
        # self.classifier = nn.Linear(C_prev, num_classes)
        # self.vote = nn.Identity()

    def forward(self, inputs):
        """Run the network; averages logits over timesteps (step mode) or
        rearranges the folded time dimension (layer-by-layer mode)."""
        logits_aux = None
        inputs = self.encoder(inputs)
        if not self.layer_by_layer:
            outputs = []
            output_aux = []
            self.reset()  # clear neuron state before the time loop
            if self.cell_type==2:
                for t in range(self.step):
                    x = inputs[t]
                    s0 = s1 = self.stem(x)
                    for i, cell in enumerate(self.cells):
                        s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
                    out = self.global_pooling(s1)
                    out = self.classifier(self.flatten(out))
                    logits = self.vote(out)
                    outputs.append(logits)
                    output_aux.append(logits_aux)
                return sum(outputs) / len(outputs)
            if self.cell_type==3:
                for t in range(self.step):
                    x = inputs[t]
                    s0 = s1 = s2= self.stem(x)
                    for i, cell in enumerate(self.cells):
                        s0, s1, s2 = s1, s2, cell(s0, s1, s2, self.drop_path_prob)
                    out = self.global_pooling(s2)
                    out = self.classifier(self.flatten(out))
                    logits = self.vote(out)
                    outputs.append(logits)
                    output_aux.append(logits_aux)
                return sum(outputs) / len(outputs)
            if self.cell_type==4:
                for t in range(self.step):
                    x = inputs[t]
                    s0 = s1 = s2= s3=self.stem(x)
                    for i, cell in enumerate(self.cells):
                        s0, s1, s2,s3= s1, s2, s3,cell(s0, s1, s2,s3 ,self.drop_path_prob)
                    out = self.global_pooling(s3)
                    out = self.classifier(self.flatten(out))
                    logits = self.vote(out)
                    outputs.append(logits)
                    output_aux.append(logits_aux)
                return sum(outputs) / len(outputs)
            # logits_aux if logits_aux is None else (sum(output_aux) / len(output_aux))
        else:
            # Time folded into batch: single pass, then average over time.
            self.reset()
            if self.cell_type==2:
                s0 = s1 = self.stem(inputs)
                for i, cell in enumerate(self.cells):
                    s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
                    if i == 2 * self._layers // 3:
                        if self._auxiliary and self.training:
                            logits_aux = self.auxiliary_head(s1)
                out = self.global_pooling(s1)
                out = self.classifier(self.flatten(out))
                out = rearrange(out, '(t b) c -> t b c', t=self.step).mean(0)
                logits = self.vote(out)
                return logits
            if self.cell_type==3:
                s0 = s1 = s2= self.stem(inputs)
                for i, cell in enumerate(self.cells):
                    s0, s1, s2 = s1, s2, cell(s0, s1, s2, self.drop_path_prob)
                    if i == 2 * self._layers // 3:
                        if self._auxiliary and self.training:
                            logits_aux = self.auxiliary_head(s1)
                out = self.global_pooling(s2)
                out = self.classifier(self.flatten(out))
                out = rearrange(out, '(t b) c -> t b c', t=self.step).mean(0)
                logits = self.vote(out)
                return logits
            if self.cell_type==4:
                s0 = s1 = s2=s3= self.stem(inputs)
                for i, cell in enumerate(self.cells):
                    s0, s1, s2,s3= s1, s2, s3,cell(s0, s1, s2,s3 ,self.drop_path_prob)
                    if i == 2 * self._layers // 3:
                        if self._auxiliary and self.training:
                            logits_aux = self.auxiliary_head(s1)
                out = self.global_pooling(s3)
                out = self.classifier(self.flatten(out))
                out = rearrange(out, '(t b) c -> t b c', t=self.step).mean(0)
                logits = self.vote(out)
                return logits
@register_model
class NetworkImageNet(BaseModule):
    """Evolved-cell network for ImageNet-scale (224x224 RGB) inputs.

    Stacks ``layers`` EvoCell2 cells — one motif genome per cell — on top of a
    two-part convolutional stem, doubling the channel width (with a reduction
    cell) at the 1/3 and 2/3 depth marks.  The forward pass runs the network
    for ``step`` time steps on the same input (direct encoding) and averages
    the per-step logits.

    Args:
        C: base channel count produced by the stem.
        num_classes: number of output classes.
        layers: number of stacked cells.
        auxiliary: stored flag; unlike the CIFAR variant, no auxiliary head is
            built in this class.
        motif: sequence of per-cell motif genomes, one entry per layer.
        step: number of simulation time steps.
        node_type: neuron/activation class, given either as a class object or
            as the name of a class visible in this module's namespace.
        **kwargs: forwarded to BaseModule and to the node constructor
            (e.g. ``threshold``, ``tau``, ``back_connection``, ``spike_output``).
    """

    def __init__(self,
                 C,
                 num_classes,
                 layers,
                 auxiliary,
                 motif,
                 step=1,
                 node_type='ReLUNode',
                 **kwargs):
        super(NetworkImageNet, self).__init__(
            step=step,
            num_classes=num_classes,
            **kwargs)
        if isinstance(node_type, str):
            # NOTE(review): eval-based lookup — only safe for trusted,
            # module-level node class names; mirrors the CIFAR variant.
            self.act_fun = eval(node_type)
        else:
            self.act_fun = node_type
        # Bind node hyper-parameters (threshold, tau, ...) once.
        self.act_fun = partial(self.act_fun, **kwargs)
        self.back_connection = kwargs.get('back_connection', False)
        self.spike_output = kwargs.get('spike_output', True)
        # self.layer_by_layer is presumably set by BaseModule from kwargs — TODO confirm.
        if self.layer_by_layer:
            self.flatten = nn.Flatten(start_dim=1)
        else:
            self.flatten = nn.Flatten()
        self._layers = layers
        self._auxiliary = auxiliary
        self.drop_path_prob = 0
        # Stem part 0: 3 -> C//2 -> C via two stride-2 convs (4x downsample).
        self.stem0 = nn.Sequential(
            nn.Conv2d(3, C // 2, kernel_size=3,
                      stride=2, padding=1, bias=False),
            nn.BatchNorm2d(C // 2),
            self.act_fun(),
            nn.Conv2d(C // 2, C, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(C),
        )
        # Stem part 1: one more stride-2 conv (8x total downsample).
        self.stem1 = nn.Sequential(
            self.act_fun(),
            nn.Conv2d(C, C, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(C),
        )
        C_prev_prev, C_prev, C_curr = C, C, C
        self.cells = nn.ModuleList()
        reduction_prev = True
        for i in range(layers):
            # Double the width (with a reduction cell) at 1/3 and 2/3 depth.
            if i in [layers // 3, 2 * layers // 3]:
                C_curr *= 2
                reduction = True
            else:
                reduction = False
            cell = EvoCell2(motif[i], C_prev_prev, C_prev, C_curr, reduction,
                            reduction_prev, act_fun=self.act_fun)
            reduction_prev = reduction
            self.cells += [cell]
            C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr
        self.global_pooling = nn.AvgPool2d(7)
        self.classifier = nn.Linear(C_prev, num_classes)

    def forward(self, inputs):
        """Run ``step`` time steps on the same input and return averaged logits."""
        outputs = []
        self.reset()  # reset neuron state between samples (provided by BaseModule)
        for t in range(self.step):
            s0 = self.stem0(inputs)
            s1 = self.stem1(s0)
            for i, cell in enumerate(self.cells):
                s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
            out = self.global_pooling(s1)
            logits = self.classifier(self.flatten(out))
            outputs.append(logits)
        return sum(outputs) / len(outputs)
if __name__ == '__main__':
    # Smoke test: build a cell_type-4 evolved NetworkCIFAR and run one forward pass.
    x = torch.rand(128, 36, 32, 32)  # NOTE(review): unused below — presumably a leftover probe tensor
    extra_edge=np.array([[3,5],[4,1]])
    # sort based on head
    extra_edge = extra_edge[extra_edge[:,0].argsort()]  # NOTE(review): extra_edge is not used afterwards
    # motifs=[mm1,mm2,mm3,mm1,mm5,mm4]
    # motifs=[m1,m2,m3,m1,m5,m4,m1,m2,m3,m1,m5,m4,m5,m4,m1]
    motifs=[t2,t3,t4,t5,t5,t4,t3,t4]  # t2..t5 are motif genomes defined elsewhere in this module
    net=NetworkCIFAR(C=12,num_classes=10,motif=motifs,layers=len(motifs),auxiliary=True,dataset='cifar10',cell_type=4)
    out=net(torch.rand(128, 3, 32, 32))  # batch of 128 CIFAR-sized RGB images
    print(out.shape)
================================================
FILE: examples/Structure_Evolution/MSE-NAS/evolution.py
================================================
import sys
import numpy as np
import argparse
import time
import obj
import timm.models
import yaml
import os
import logging
from random import choice
from collections import OrderedDict
from contextlib import suppress
from datetime import datetime
from micro_encoding import ops
from braincog.base.node.node import *
from braincog.utils import *
from braincog.base.utils.criterions import *
from braincog.datasets.datasets import *
from braincog.model_zoo.resnet import *
from braincog.model_zoo.convnet import *
from braincog.utils import save_feature_map, setup_seed
from braincog.base.utils.visualization import plot_tsne_3d, plot_tsne, plot_confusion_matrix
import micro_encoding
import nsganet as engine
from pymop.problem import Problem
import torch
import torch.nn as nn
import torchvision.utils
from torch.nn.parallel import DistributedDataParallel as NativeDDP
from pymoo.optimize import minimize
from utils import data_transforms
from tm import train_motifs
from timm.data import create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
from timm.models import load_checkpoint, create_model, resume_checkpoint, convert_splitbn_model
from timm.utils import *
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
from timm.optim import create_optimizer
from timm.scheduler import create_scheduler
from timm.utils import ApexScaler, NativeScaler
# os.environ['CUDA_VISIBLE_DEVICES']='3'
# Number of genome bits that encode one motif (consumed by NAS below and train_motifs).
bits=20
torch.backends.cudnn.benchmark = True
# Root logger used throughout this script.
_logger = logging.getLogger('')
# The first arg parser parses out only the --config argument, this argument is used to
# load a yaml file containing key-values that override the defaults for the main parser below
# NOTE(review): `parser` is rebound to a new ArgumentParser a few lines down, so `-c/--config`
# (added later) ends up on the main parser only; `config_parser` itself defines no options.
# Confirm whether the `--config` option was meant to be registered before the rebinding.
config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)
# GPU device ids used by this script.
devices=[1]
# Maximum number of evolutionary generations.
max_gen = 100
parser = argparse.ArgumentParser(description='SNN Training and Evaluating')
# Model parameters
# NOTE(review): several help strings below quote stale defaults (e.g. seed 42 vs 99,
# epochs 2 vs 600, remode 'const' vs 'pixel') — the `default=` values are authoritative.
parser.add_argument('--seed', type=int, default=99, metavar='S',
                    help='random seed (default: 42)')
parser.add_argument('--eval_epochs', type=int, default=10)
parser.add_argument('--bns', action='store_true', default=True)
parser.add_argument('--mid', type=int, default=3)
# NOTE(review): '--trainning_epochs' is a typo, but renaming would break the CLI.
parser.add_argument('--trainning_epochs', type=int, default=600, metavar='N',help='number of epochs to train (default: 2)')
parser.add_argument('--cooldown-epochs', type=int, default=0, metavar='N',help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
parser.add_argument('--init-channels', type=int, default=48)
parser.add_argument('--layers', type=int, default=6)
parser.add_argument('--pop_size', type=int, default=50, help='population size of networks')
parser.add_argument('--output', default='', type=str, metavar='PATH')
parser.add_argument('--spike-rate', action='store_true', default=False)
parser.add_argument('--n_gens', type=int, default=max_gen, help='population size')
parser.add_argument('--bs', type=int, default=100)
parser.add_argument('--n_offspring', type=int, default=50, help='number of offspring created per generation')
parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
                    help='YAML config file specifying default arguments')
parser.add_argument('--dataset', default='dvsg', type=str)
parser.add_argument('--num-classes', type=int, default=11, metavar='N',
                    help='number of label classes (default: 1000)')
parser.add_argument('--model', default='NetworkCIFAR', type=str, metavar='MODEL',
                    help='Name of model to train (default: "countception"')
parser.add_argument('--pretrained', action='store_true', default=False,
                    help='Start with pretrained version of specified network (if avail)')
parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
                    help='Initialize model from this checkpoint (default: none)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='Resume full model and optimizer state from checkpoint (default: none)')
parser.add_argument('--eval_checkpoint', default='', type=str, metavar='PATH',
                    help='path to eval checkpoint (default: none)')
parser.add_argument('--no-resume-opt', action='store_true', default=False,
                    help='prevent resume of optimizer state when resuming model')
parser.add_argument('--gp', default=None, type=str, metavar='POOL',
                    help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
# Dataset parameters for static datasets
parser.add_argument('--img-size', type=int, default=224, metavar='N',
                    help='Image patch size (default: None => model default)')
parser.add_argument('--crop-pct', default=None, type=float,
                    metavar='N', help='inputs image center crop percent (for validation only)')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                    help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
                    help='Override std deviation of of dataset')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
                    help='Image resize interpolation type (overrides model)')
# Seed genome: layers of `bits`-long motif encodings plus one trailing gene.
parser.add_argument('--strgenome', default='4,0,0,1,1,1,0,1,0,1,0,0,0,1,0,1,0,1,0,0,4,1,1,1,1,1,0,1,1,0,1,1,0,1,1,1,0,1,0,0,4,1,1,0,1,0,0,1,0,1,1,0,1,1,0,0,1,0,1,2,4,1,0,0,1,1,1,0,0,1,0,0,1,0,0,1,0,0,1,1,4,0,1,1,0,1,0,0,1,0,1,0,1,1,1,0,1,0,1,3,4,1,0,1,0,0,1,1,1,0,0,1,0,1,0,0,0,1,0,0,3', type=str)
# Dataloader parameters
parser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',
                    help='inputs batch size for training (default: 128)')
parser.add_argument('-vb', '--validation-batch-size-multiplier', type=int, default=1, metavar='N',
                    help='ratio of validation batch size to training batch size (default: 1)')
# Optimizer parameters
parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                    help='Optimizer (default: "adamw"')
parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
                    help='Optimizer Epsilon (default: None, use opt default)')
parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                    help='Optimizer Betas (default: None, use opt default)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='Optimizer momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=0.01,
                    help='weight decay (default: 0.01 for adamw)')
parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
                    help='Clip gradient norm (default: None, no clipping)')
parser.add_argument('--adam-epoch', type=int, default=1000, help='lamb switch to adamw')
# Learning rate schedule parameters
parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
                    help='LR scheduler (default: "cosine"')
parser.add_argument('--lr', type=float, default=5e-3, metavar='LR',
                    help='learning rate (default: 0.01)')
parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                    help='learning rate noise on/off epoch percentages')
parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                    help='learning rate noise limit percent (default: 0.67)')
parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                    help='learning rate noise std-dev (default: 1.0)')
parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
                    help='learning rate cycle len multiplier (default: 1.0)')
parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
                    help='learning rate cycle limit')
parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
                    help='warmup learning rate (default: 0.0001)')
parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                    help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
                    help='epoch interval to decay LR')
parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
                    help='epochs to warmup LR, if scheduler supports')
parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                    help='patience epochs for Plateau LR scheduler (default: 10')
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                    help='LR decay rate (default: 0.1)')
parser.add_argument('--power', type=int, default=1, help='power')
# Augmentation & regularization parameters ONLY FOR IMAGE NET
parser.add_argument('--no-aug', action='store_true', default=False,
                    help='Disable all training augmentation, override other train aug args')
parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
                    help='Random resize scale (default: 0.08 1.0)')
parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
                    help='Random resize aspect ratio (default: 0.75 1.33)')
parser.add_argument('--hflip', type=float, default=0.5,
                    help='Horizontal flip training aug probability')
parser.add_argument('--vflip', type=float, default=0.,
                    help='Vertical flip training aug probability')
parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                    help='Color jitter factor (default: 0.4)')
# NOTE(review): stray trailing comma below makes this statement a 1-tuple (harmless).
parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
                    help='Use AutoAugment policy. "v0" or "original". (default: None)'),
parser.add_argument('--aug-splits', type=int, default=0,
                    help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
parser.add_argument('--jsd', action='store_true', default=False,
                    help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
                    help='Random erase prob (default: 0.25)')
parser.add_argument('--remode', type=str, default='pixel',
                    help='Random erase mode (default: "const")')
parser.add_argument('--recount', type=int, default=1,
                    help='Random erase count (default: 1)')
parser.add_argument('--resplit', action='store_true', default=False,
                    help='Do not random erase first (clean) augmentation split')
parser.add_argument('--mixup', type=float, default=0.8,
                    help='mixup alpha, mixup enabled if > 0. (default: 0.)')
parser.add_argument('--cutmix', type=float, default=1.0,
                    help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')
parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
                    help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
parser.add_argument('--mixup-prob', type=float, default=1.0,
                    help='Probability of performing mixup or cutmix when either/both is enabled')
parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
                    help='Probability of switching to cutmix when both mixup and cutmix enabled')
parser.add_argument('--mixup-mode', type=str, default='batch',
                    help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
                    help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
parser.add_argument('--smoothing', type=float, default=0.1,
                    help='Label smoothing (default: 0.1)')
parser.add_argument('--train-interpolation', type=str, default='random',
                    help='Training interpolation (random, bilinear, bicubic default: "random")')
parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                    help='Dropout rate (default: 0.0)')
parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT',
                    help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',
                    help='Drop path rate (default: None)')
parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
                    help='Drop block rate (default: None)')
parser.add_argument('--newton-maxiter', default=20, type=int,
                    help='max iterration in newton method')
parser.add_argument('--reset-drop', action='store_true', default=False,
                    help='whether to reset drop')
parser.add_argument('--kernel-method', type=str, default='cuda', choices=['torch', 'cuda'],
                    help='The implementation way of gaussian kernel method, choose from "cuda" and "torch"')
# Batch norm parameters (only works with gen_efficientnet based models currently)
parser.add_argument('--bn-tf', action='store_true', default=False,
                    help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')
parser.add_argument('--bn-momentum', type=float, default=None,
                    help='BatchNorm momentum override (if not None)')
parser.add_argument('--bn-eps', type=float, default=None,
                    help='BatchNorm epsilon override (if not None)')
parser.add_argument('--sync-bn', action='store_true',
                    help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
parser.add_argument('--dist-bn', type=str, default='',
                    help='Distribute BatchNorm stats between node after each epoch ("broadcast", "reduce", or "")')
parser.add_argument('--split-bn', action='store_true',
                    help='Enable separate BN layers per augmentation split.')
# Model Exponential Moving Average
parser.add_argument('--model-ema', action='store_true', default=False,
                    help='Enable tracking moving average of model weights')
parser.add_argument('--model-ema-force-cpu', action='store_true', default=False,
                    help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
parser.add_argument('--model-ema-decay', type=float, default=0.99996,
                    help='decay factor for model weights moving average (default: 0.9998)')
# Misc
parser.add_argument('--log-interval', type=int, default=50, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
                    help='how many batches to wait before writing recovery checkpoint')
parser.add_argument('-j', '--workers', type=int, default=8, metavar='N',
                    help='how many training processes to use (default: 1)')
parser.add_argument('--num-gpu', type=int, default=len(devices),
                    help='Number of GPUS to use')
parser.add_argument('--save-images', action='store_true', default=False,
                    help='save images of inputs bathes every log interval for debugging')
parser.add_argument('--amp', action='store_true', default=False,
                    help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
parser.add_argument('--apex-amp', action='store_true', default=False,
                    help='Use NVIDIA Apex AMP mixed precision')
parser.add_argument('--native-amp', action='store_true', default=False,
                    help='Use Native Torch AMP mixed precision')
parser.add_argument('--channels-last', action='store_true', default=False,
                    help='Use channels_last memory layout')
parser.add_argument('--pin-mem', action='store_true', default=False,
                    help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--no-prefetcher', action='store_true', default=False,
                    help='disable fast prefetcher')
parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
                    help='Best metric (default: "top1"')
parser.add_argument('--tta', type=int, default=0, metavar='N',
                    help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
parser.add_argument('--local_rank', default=0, type=int)
parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
                    help='use the multi-epochs-loader to save time at the beginning of every epoch')
parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
parser.add_argument('--device', type=int, default=devices[0])
# Spike parameters
parser.add_argument('--step', type=int, default=10, help='Simulation time step (default: 10)')
parser.add_argument('--encode', type=str, default='direct', help='Input encode method (default: direct)')
parser.add_argument('--temporal-flatten', action='store_true',
                    help='Temporal flatten to channels. ONLY FOR EVENT DATA TRAINING BY ANN')
parser.add_argument('--adaptive-node', action='store_true')
parser.add_argument('--critical-loss', action='store_true')
# neuron type
parser.add_argument('--node-type', type=str, default='LIFNode', help='Node type in network (default: PLIF)')
parser.add_argument('--act-fun', type=str, default='QGateGrad',
                    help='Surogate Function in node. Only for Surrogate nodes (default: AtanGrad)')
parser.add_argument('--threshold', type=float, default=.5, help='Firing threshold (default: 0.5)')
parser.add_argument('--tau', type=float, default=2., help='Attenuation coefficient (default: 2.)')
parser.add_argument('--requires-thres-grad', action='store_true')
parser.add_argument('--sigmoid-thres', action='store_true')
parser.add_argument('--loss-fn', type=str, default='ce', help='loss function (default: ce)')
parser.add_argument('--noisy-grad', type=float, default=0.,
                    help='Add noise to backward, sometime will make higher accuracy (default: 0.)')
parser.add_argument('--spike-output', action='store_true', default=False,
                    help='Using mem output or spike output (default: False)')
parser.add_argument('--n_groups', type=int, default=1)
# EventData Augmentation
parser.add_argument('--mix-up', action='store_true', help='Mix-up for event data (default: False)')
parser.add_argument('--cut-mix', action='store_true', help='CutMix for event data (default: False)')
parser.add_argument('--event-mix', action='store_true', help='EventMix for event data (default: False)')
parser.add_argument('--cutmix_beta', type=float, default=1.0, help='cutmix_beta (default: 1.)')
parser.add_argument('--cutmix_prob', type=float, default=0.5, help='cutmix_prib for event data (default: .5)')
parser.add_argument('--cutmix_num', type=int, default=1, help='cutmix_num for event data (default: 1)')
parser.add_argument('--cutmix_noise', type=float, default=0.,
                    help='Add Pepper noise after mix, sometimes work (default: 0.)')
parser.add_argument('--rand-aug', action='store_true',
                    help='Rand Augment for Event data (default: False)')
parser.add_argument('--randaug_n', type=int, default=3,
                    help='Rand Augment times n (default: 3)')
parser.add_argument('--randaug_m', type=int, default=15,
                    help='Rand Augment times n (default: 15) (0-30)')
parser.add_argument('--train-portion', type=float, default=0.9,
                    help='Dataset portion, only for datasets which do not have validation set (default: 0.9)')
parser.add_argument('--event-size', default=48, type=int,
                    help='Event size. Resize event data before process (default: 48)')
parser.add_argument('--layer-by-layer', action='store_true',
                    help='forward step-by-step or layer-by-layer. '
                         'Larger Model with layer-by-layer will be faster (default: False)')
parser.add_argument('--node-resume', type=str, default='',
                    help='resume weights in node for adaptive node. (default: False)')
parser.add_argument('--node-trainable', action='store_true')
# visualize
parser.add_argument('--visualize', action='store_true',
                    help='Visualize spiking map for each layer, only for validate (default: False)')
parser.add_argument('--tsne', action='store_true')
parser.add_argument('--conf-mat', action='store_true')
# DARTS parameters
parser.add_argument('--auxiliary', action='store_true', default=False, help='use auxiliary tower')
# parser.add_argument('--arch', default='dvsc10_new_skip19', type=str)
# parser.add_argument('--motif', default='m1', type=str)
parser.add_argument('--parse_method', default='darts', type=str)
parser.add_argument('--drop_path_prob', type=float, default=0.2, help='drop path probability')
# parser.add_argument('--back-connection', action='store_true',default=True)
parser.add_argument('--suffix', type=str, default='',
                    help='Add an additional suffix to the save path (default: \'\')')
# Optional NVIDIA Apex support for mixed precision / distributed training.
try:
    from apex import amp
    from apex.parallel import DistributedDataParallel as ApexDDP
    from apex.parallel import convert_syncbn_model
    has_apex = True
except ImportError:
    has_apex = False
# Detect native torch.cuda.amp availability (PyTorch >= 1.6).
has_native_amp = False
try:
    if getattr(torch.cuda.amp, 'autocast') is not None:
        has_native_amp = True
except AttributeError:
    pass
def check_mem(cuda_device):
    """Return (total, used) memory of one GPU in MiB, as strings, via nvidia-smi."""
    query = '"/usr/bin/nvidia-smi" --query-gpu=memory.total,memory.used --format=csv,nounits,noheader'
    per_device = os.popen(query).read().strip().split("\n")
    # One CSV line per GPU; pick the requested device's row.
    total, used = per_device[int(cuda_device)].split(',')
    return total, used
def occumpy_mem(cuda_device):
    """Grab ~30% of the currently free memory on a GPU by allocating a dummy tensor.

    The tensor is deleted immediately; the CUDA caching allocator keeps the
    memory reserved for this process.
    """
    total, used = check_mem(cuda_device)
    total_mib = int(total)
    used_mib = int(used)
    capacity = int(total_mib * 1)
    block_mem = int((capacity - used_mib) * 0.3)
    placeholder = torch.cuda.FloatTensor(256, 1024, block_mem)
    del placeholder
class NAS(Problem):
    # first define the NAS problem (inherit from pymop)
    """Pymop optimization problem wrapping SNN-architecture evaluation.

    Each decision vector `x[i]` is an integer genome; its fitness is
    1000 - performance, where performance comes from `train_motifs`.
    """

    def __init__(self, args, n_var=20, n_obj=1, n_constr=0, lb=None, ub=None,
                 init_channels=24, layers=8):
        # NOTE(review): init_channels and layers are accepted but never used here —
        # presumably consumed via `args` instead; confirm before removing.
        super().__init__(n_var=n_var, n_obj=n_obj, n_constr=n_constr, type_var=np.int64)
        self.xl = lb  # lower bounds per gene
        self.xu = ub  # upper bounds per gene
        self._lr = args.lr  # remember the initial LR so each evaluation starts fresh
        self._n_evaluated = 0  # keep track of how many architectures are sampled
        self.args = args

    def _evaluate(self, x, out, *args, **kwargs):
        """Evaluate a population `x` (shape [pop, n_var]); write fitness into out['F']."""
        objs = np.full((x.shape[0], self.n_obj), np.nan)
        # NOTE(review): the dataset is re-loaded on every _evaluate call and the
        # loaders are unused below (only the commented-out obj.LSP path used them).
        # `eval('get_%s_data' ...)` dispatches by dataset name — trusted input only.
        train_data, loader_eval, mixup_active, mixup_fn = eval('get_%s_data' % self.args.dataset)(
            batch_size=self.args.batch_size,
            step=self.args.step,
            args=self.args,
            _logge=_logger,
            size=self.args.event_size,
            mix_up=self.args.mix_up,
            cut_mix=self.args.cut_mix,
            event_mix=self.args.event_mix,
            beta=self.args.cutmix_beta,
            prob=self.args.cutmix_prob,
            num=self.args.cutmix_num,
            noise=self.args.cutmix_noise,
            num_classes=self.args.num_classes,
            rand_aug=self.args.rand_aug,
            randaug_n=self.args.randaug_n,
            randaug_m=self.args.randaug_m,
            temporal_flatten=self.args.temporal_flatten,
            portion=self.args.train_portion,
            _logger=_logger)
        for i in range(x.shape[0]):
            arch_id = self._n_evaluated + 1
            print('\n')
            _logger.info('Network= {}'.format(arch_id))
            genome = x[i, :]
            # Relies on args.output_dir being set in __main__ — TODO confirm for rank != 0.
            arch_dir = os.path.join(self.args.output_dir)
            if os.path.exists(arch_dir) is False:
                os.makedirs(arch_dir, exist_ok=True)
            # Restore the original LR (train_motifs may have mutated args.lr).
            self.args.lr = self._lr
            # args_text/devices/bits are module-level globals set at import/__main__ time.
            performance, acc = train_motifs(args=self.args, gen=0, arch_dir=arch_dir, genome=genome, _logger=_logger, args_text=args_text, devices=devices, bits=bits)
            objs[i, 0] = 1000 - performance
            # sJ,J,Cm,Ccosine,Cpe,K = obj.LSP(self.args,genome,train_data)
            # objs[i, 0] = 1000 - Cm
            _logger.info('performance= {}'.format(objs[i, 0]))
            self._n_evaluated += 1
        # NOTE(review): only column 0 is filled — with n_obj=2 (as in __main__)
        # the second objective stays NaN; confirm intended.
        out["F"] = objs
    # if your NAS problem has constraints, use the following line to set constraints
    # out["G"] = np.column_stack([g1, g2, g3, g4, g5, g6]) in case 6 constraints
def do_every_generations(algorithm):
    """Per-generation callback: log error statistics of the current population."""
    generation = algorithm.n_gen
    genomes = algorithm.pop.get("X")
    fitness = algorithm.pop.get("F")
    errors = fitness[:, 0]
    _logger.info("generation = {}".format(generation))
    _logger.info("population error: best = {}, mean = {}, "
                 "median = {}, worst = {}".format(np.min(errors), np.mean(errors),
                                                  np.median(errors), np.max(errors)))
    _logger.info('Best Genome= {}'.format(genomes[np.argmin(errors)]))
def _parse_args():
    """Parse CLI arguments, honoring an optional YAML config file.

    Values from the YAML file referenced by ``--config`` override the parser
    defaults (but not explicit CLI flags).  The original code parsed the config
    argument and then silently ignored it; this restores the documented
    behavior.  ``getattr`` is used defensively because ``config_parser`` as
    defined above registers no ``--config`` option of its own.

    Returns:
        (args, args_text): the resolved Namespace and its YAML dump for logging.
    """
    args_config, remaining = config_parser.parse_known_args()
    cfg_path = getattr(args_config, 'config', '')
    if cfg_path:
        with open(cfg_path, 'r') as f:
            parser.set_defaults(**yaml.safe_load(f))
    args = parser.parse_args(remaining)
    # Snapshot of the fully-resolved arguments, for reproducibility logs.
    args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
    return args, args_text
if __name__ == '__main__':
    args, args_text = _parse_args()
    args.no_spike_output = True
    output_dir = ''
    # Choose the cell implementation variant; NetworkCIFAR is presumably used
    # downstream by train_motifs — TODO confirm (it is not referenced below).
    if args.bns:
        from cellmodel import NetworkCIFAR
    else:
        from cell123model import NetworkCIFAR
    if args.local_rank == 0:
        # Build a timestamped experiment directory name on the primary process.
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"),
            # args.model,
            # args.dataset,
            str(args.layers)+'layers',
            str(args.init_channels)+'channels',
            'motif'+str(args.mid),
            str(args.step)+'steps',
            # args.suffix
            # str(args.img_size)
        ])
        output_dir = get_outdir(output_base,str(args.dataset),exp_name)
        args.output_dir = output_dir
        setup_default_logging(log_path=os.path.join(output_dir, 'log.txt'))
    else:
        setup_default_logging()
    args.prefetcher = not args.no_prefetcher
    # Distributed setup (timm-style): WORLD_SIZE is set by the launcher.
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    if args.distributed and args.num_gpu > 1:
        _logger.warning(
            'Using more than one GPU per process in distributed mode is not allowed.Setting num_gpu to 1.')
        args.num_gpu = 1
    # args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.num_gpu = 1
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    else:
        torch.cuda.set_device('cuda:%d' % args.device)
    assert args.rank >= 0
    if args.distributed:
        _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
                     % (args.rank, args.world_size))
    else:
        _logger.info('Training with a single process on %d GPUs.' % args.num_gpu)
    # torch.manual_seed(args.seed + args.rank)
    setup_seed(args.seed + args.rank)
    defalut_lr = args.lr  # NOTE(review): assigned but never read below (typo of "default_lr")
    # Pre-claim GPU memory so other jobs don't grab it mid-run.
    occumpy_mem(str(args.device))
    # eval-based dispatch to get_<dataset>_data — dataset name is trusted input.
    train_data, loader_eval, mixup_active, mixup_fn = eval('get_%s_data' % args.dataset)(
        batch_size=args.batch_size,
        step=args.step,
        args=args,
        _logge=_logger,
        size=args.event_size,
        mix_up=args.mix_up,
        cut_mix=args.cut_mix,
        event_mix=args.event_mix,
        beta=args.cutmix_beta,
        prob=args.cutmix_prob,
        num=args.cutmix_num,
        noise=args.cutmix_noise,
        num_classes=args.num_classes,
        rand_aug=args.rand_aug,
        randaug_n=args.randaug_n,
        randaug_m=args.randaug_m,
        temporal_flatten=args.temporal_flatten,
        portion=args.train_portion,
        _logger=_logger,
    )
    # Genome layout: `bits` genes per layer plus one trailing gene in [2, 3].
    len_motifs=args.layers*bits+1
    low = np.zeros(len_motifs)
    low[-1] = 2
    up=[]
    for i in range(0,args.layers*bits,bits):
        # Per layer: first gene = motif id (<= args.mid), middle genes bounded by
        # ops-1, last gene bounded by 2*(ops-1).
        t=[args.mid]
        t=t+[(ops-1) for j in range(bits-1)]
        t[-1]=2*(ops-1)
        up.extend(t)
    up.append(3)
    up=np.array(up).reshape(-1,)
    # NOTE(review): n_obj=2 but NAS._evaluate fills only objective 0 — confirm.
    kkk = NAS(args,n_var=len_motifs,
              n_obj=2, n_constr=0, lb=low, ub=up,
              init_channels=args.init_channels, layers=args.layers)
    method = engine.nsganet(pop_size=args.pop_size,
                            n_offsprings=args.n_offspring,
                            eliminate_duplicates=True)
    kres=minimize(kkk,
                  method,
                  callback=do_every_generations,
                  termination=('n_gen', args.n_gens))
================================================
FILE: examples/Structure_Evolution/MSE-NAS/loss_f.py
================================================
import torch
import torch.nn.functional as f
def psp(inputs, n_steps, tau_s):
    """Postsynaptic-potential filter: exponentially smooth a spike train over time.

    Implements syn[t] = syn[t-1] * (1 - 1/tau_s) + inputs[..., t] and records
    syn[t] / tau_s at every step.

    Args:
        inputs: tensor of shape (B, C, H, W, T) whose last dim is time.
        n_steps: number of time steps to process (expected <= T).
        tau_s: synaptic time constant.

    Returns:
        Tensor of shape (B, C, H, W, n_steps) with the filtered trace.

    Note: the original hard-coded ``.cuda()`` buffers (and contained no-op
    self-assignments of ``n_steps``/``tau_s``); buffers now follow the input's
    device and dtype, which is identical on CUDA inputs and additionally works
    on CPU.
    """
    shape = inputs.shape
    syn = torch.zeros(shape[0], shape[1], shape[2], shape[3],
                      device=inputs.device, dtype=inputs.dtype)
    syns = torch.zeros(shape[0], shape[1], shape[2], shape[3], n_steps,
                       device=inputs.device, dtype=inputs.dtype)
    for t in range(n_steps):
        syn = syn - syn / tau_s + inputs[..., t]
        syns[..., t] = syn / tau_s
    return syns
class SpikeLoss(torch.nn.Module):
    """
    This class defines different spike based loss modules that can be used to optimize the SNN.
    """

    def __init__(self, desired_count, undesired_count):
        """Store the target spike counts for correct / incorrect output neurons.

        Bug fix: the second assignment previously overwrote ``desired_count``
        (``self.desired_count = undesired_count``) instead of setting
        ``self.undesired_count``.
        """
        super(SpikeLoss, self).__init__()
        self.desired_count = desired_count
        self.undesired_count = undesired_count
        self.criterion = torch.nn.CrossEntropyLoss()

    def spike_count(self, outputs, target, desired_count, undesired_count):
        """Squared spike-count loss via the straight-through loss_count function."""
        delta = loss_count.apply(outputs, target, desired_count, undesired_count)
        return 1 / 2 * torch.sum(delta ** 2)

    def spike_kernel(self, outputs, target, desired_count, undesired_count):
        """Squared PSP-kernel loss via the straight-through loss_kernel function."""
        delta = loss_kernel.apply(outputs, target, desired_count, undesired_count)
        return 1 / 2 * torch.sum(delta ** 2)

    def spike_soft_max(self, outputs, target):
        """Cross-entropy over time-summed spike counts.

        NOTE(review): log_softmax is applied before CrossEntropyLoss, which
        itself applies log_softmax — effectively a double softmax; confirm
        whether NLLLoss was intended.
        """
        delta = f.log_softmax(outputs.sum(dim=4).squeeze(-1).squeeze(-1), dim=1)
        return self.criterion(delta, target)
class loss_count(torch.autograd.Function):
    """Straight-through spike-count error.

    forward returns a per-timestep error signal ``delta`` between the output
    spike count and ``target``; backward passes the incoming gradient through
    unchanged (straight-through estimator).
    """

    @staticmethod
    def forward(ctx, outputs, target, desired_count, undesired_count):
        # outputs: spike tensor whose 5th dimension is time — (B, C, H, W, T).
        desired_count = desired_count
        undesired_count = undesired_count
        shape = outputs.shape
        n_steps = shape[4]
        # Total spikes per unit, and its mean per-step deviation from target.
        out_count = torch.sum(outputs, dim=4)
        delta = (out_count - target) / n_steps
        # NOTE(review): the two mask passes below zero the error except where the
        # count overshoots an undesired target or undershoots a desired target —
        # i.e. only deviations in the "wrong" direction are penalized. Confirm
        # that `target` holds the desired/undesired count values per unit.
        mask = torch.ones_like(out_count)
        mask[target == undesired_count] = 0
        mask[delta < 0] = 0
        delta[mask == 1] = 0
        mask = torch.ones_like(out_count)
        mask[target == desired_count] = 0
        mask[delta > 0] = 0
        delta[mask == 1] = 0
        # Broadcast the surviving per-unit error across every time step.
        delta = delta.unsqueeze_(-1).repeat(1, 1, 1, 1, n_steps)
        return delta

    @staticmethod
    def backward(ctx, grad):
        # Straight-through: gradient flows only to `outputs`; the other three
        # forward inputs receive no gradient.
        return grad, None, None, None
class loss_kernel(torch.autograd.Function):
    """Straight-through PSP-kernel error: outputs minus the PSP-filtered target.

    backward passes the incoming gradient through unchanged for ``outputs``.
    """

    @staticmethod
    def forward(ctx, outputs, target, n_steps, tau_s):
        # NOTE(review): SpikeLoss.spike_kernel passes (desired_count,
        # undesired_count) into the (n_steps, tau_s) slots — confirm intended.
        # out_psp = psp(outputs, network_config)
        target_psp = psp(target, n_steps, tau_s)
        delta = outputs - target_psp
        return delta

    @staticmethod
    def backward(ctx, grad):
        # Bug fix: forward takes four inputs, so backward must return four
        # gradients (the original returned only three, which raises a RuntimeError
        # when autograd unpacks them).
        return grad, None, None, None
================================================
FILE: examples/Structure_Evolution/MSE-NAS/micro_encoding.py
================================================
# NASNet Search Space https://arxiv.org/pdf/1707.07012.pdf
# code modified from DARTS https://github.com/quark0/darts
import numpy as np
from collections import namedtuple
import torch
# from models.micro_models import NetworkCIFAR as Network
import motifs
# Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
# Genotype_norm = namedtuple('Genotype', 'normal normal_concat')
# Genotype_redu = namedtuple('Genotype', 'reduce reduce_concat')
Genotype = namedtuple('Genotype', 'normal normal_concat')
# what you want to search should be defined here and in micro_operations
# Candidate operation names for the NASNet-style search space.  Genomes
# store indices into these lists; the concrete nn.Modules are looked up
# by name in operations.py.
PRIMITIVES = [
    'max_pool_3x3',
    'avg_pool_3x3',
    'skip_connect',
    'sep_conv_3x3',
    'sep_conv_5x5',
    'dil_conv_3x3',
    'dil_conv_5x5',
    'sep_conv_7x7',
    'conv_7x1_1x7',
]
# Feedback ("back") connection ops.  The commented-out entries were
# pruned from the search space but kept for reference.
OPERATIONS_back = [
    # 'max_pool_3x3_p_back',
    # 'avg_pool_3x3_p_back',
    'conv_3x3_p_back',
    'conv_5x5_p_back',
    # 'avg_pool_3x3_n_back',
    'conv_3x3_n_back',
    'conv_5x5_n_back',
    # 'sep_conv_3x3_p_back',
    # 'sep_conv_5x5_p_back',
    # 'dil_conv_3x3_p_back',
    # 'dil_conv_5x5_p_back',
    # 'def_conv_3x3_p_back',
    # 'def_conv_5x5_p_back',
]
# Excitatory ('_p') forward ops.
OPERATIONS_p = [
    # 'max_pool_3x3_p',
    # 'avg_pool_3x3_p',
    'conv_3x3_p',
    'conv_5x5_p',
    # 'sep_conv_3x3_p',
    # 'sep_conv_5x5_p',
    # 'dil_conv_3x3_p',
    # 'dil_conv_5x5_p',
    # 'def_conv_3x3_p',
    # 'def_conv_5x5_p',
]
ops = len(OPERATIONS_p)  # number of active excitatory ops
# Inhibitory ('_n') forward ops.
OPERATIONS_n = [
    # 'max_pool_3x3_n',
    # 'avg_pool_3x3_n',
    'conv_3x3_n',
    'conv_5x5_n',
    # 'sep_conv_3x3_n',
    # 'sep_conv_5x5_n',
    # 'dil_conv_3x3_n',
    # 'dil_conv_5x5_n',
    # 'def_conv_3x3_n',
    # 'def_conv_5x5_n',
    # 'transformer',
]
def convert_cell(cell_bit_string):
    """Convert a flat cell bit-string into its genome representation.

    Every two consecutive entries form one (op, input) unit and every two
    units form one block, so a string of length 4k yields k blocks.
    """
    units = [cell_bit_string[pos:pos + 2]
             for pos in range(0, len(cell_bit_string), 2)]
    return [units[pos:pos + 2] for pos in range(0, len(units), 2)]
def convert(bit_string):
    """Split a network bit-string into [normal-cell, reduction-cell] genomes.

    The first half encodes the normal cell, the second half the
    reduction cell; each half is decoded with convert_cell().
    """
    half = len(bit_string) // 2
    return [convert_cell(bit_string[:half]), convert_cell(bit_string[half:])]
# def decode_cell(genome, norm=True):
# cell, cell_concat = [], list(range(2, len(genome)+2))
# for block in genome:
# for unit in block:
# cell.append((PRIMITIVES[unit[0]], unit[1]))
# if unit[1] in cell_concat:
# cell_concat.remove(unit[1])
# if norm:
# return Genotype_norm(normal=cell, normal_concat=cell_concat)
# else:
# return Genotype_redu(reduce=cell, reduce_concat=cell_concat)
def decode(genome):
    """Decode a [normal_cell, reduce_cell] genome into a DARTS-style genotype.

    Each cell is a list of blocks of (op_index, input_node) units; node
    indices 2..N+1 start as concat candidates and any node later consumed
    as an input is dropped from the concat list.

    Bug fix: the module-level Genotype namedtuple only has the fields
    'normal normal_concat', so constructing it with reduce=... raised a
    TypeError.  A local four-field namedtuple (matching the commented-out
    original definition at the top of this file) is used instead.
    """
    FullGenotype = namedtuple('Genotype',
                              'normal normal_concat reduce reduce_concat')

    def _decode_cell(cell):
        # One pass over all units: collect (primitive_name, input) pairs
        # and prune consumed nodes from the concat candidates.
        cell_ops, concat = [], list(range(2, len(cell) + 2))
        for block in cell:
            for unit in block:
                cell_ops.append((PRIMITIVES[unit[0]], unit[1]))
                if unit[1] in concat:
                    concat.remove(unit[1])
        return cell_ops, concat

    normal, normal_concat = _decode_cell(genome[0])
    reduce, reduce_concat = _decode_cell(genome[1])
    return FullGenotype(normal=normal, normal_concat=normal_concat,
                        reduce=reduce, reduce_concat=reduce_concat)
def decode_motif(layers, bits, genome):
    """Decode a genome into one motif genotype per layer.

    The genome is laid out as `layers` segments of `bits` genes: gene 0
    of each segment selects the base motif template (from motifs.py) and
    the following genes pick the concrete op for each slot of that
    template.  The genome's LAST gene selects the motif family:
    2 -> 'mm*', 3 -> 'm*', anything else -> 't*'.

    Returns (motif_list, motif_ids): the instantiated Genotypes and the
    raw template indices.
    """
    motif_list = []
    motif_ids = []
    for b in range(0, layers * bits, bits):
        if genome[-1] == 2:
            motif_id = 'mm' + str(genome[b])
        elif genome[-1] == 3:
            motif_id = 'm' + str(genome[b])
        else:
            motif_id = 't' + str(genome[b])
        motif_ids.append(genome[b])
        # Look up the template genotype by attribute name in motifs.py.
        normalcell = eval('motifs.%s' % motif_id)
        newnormal = []
        for i in range(0, len(normalcell.normal)):
            op = normalcell.normal[i]
            if 'skip' in op[0]:
                # Skip connections are kept verbatim.
                newnormal.append(op)
                continue
            elif 'back' in op[0]:
                # NOTE(review): every 'back' op of a motif reads the SAME
                # gene (index b + len(normal)); if a template has several
                # back ops they all share one choice -- confirm intended.
                newnormal.append((OPERATIONS_back[genome[b + 1 + len(normalcell.normal) - 1]], op[1]))
                continue
            elif '_n' in op[0]:
                # Inhibitory slot: pick from OPERATIONS_n.
                newnormal.append((OPERATIONS_n[genome[b + 1 + i]], op[1]))
                continue
            elif '_p' in op[0]:
                # Excitatory slot: pick from OPERATIONS_p.
                newnormal.append((OPERATIONS_p[genome[b + 1 + i]], op[1]))
                continue
        m = Genotype(normal=newnormal, normal_concat=normalcell.normal_concat,)
        motif_list.append(m)
    return motif_list, motif_ids
def compare_cell(cell_string1, cell_string2):
    """Return True when two cell bit-strings encode the same cell.

    The two units inside a block match irrespective of their order, and
    block order within a cell is ignored (each block of cell 1 consumes
    at most one matching block of cell 2).
    """
    remaining = convert_cell(cell_string2)
    for block in convert_cell(cell_string1):
        for candidate in remaining:
            if candidate == block or candidate == block[::-1]:
                remaining.remove(candidate)
                break
    return not remaining
def compare(string1, string2):
    """True when both halves (normal + reduction cell) of two network
    bit-strings encode equivalent cells."""
    half1 = len(string1) // 2
    half2 = len(string2) // 2
    return (compare_cell(string1[:half1], string2[:half2])
            and compare_cell(string1[half1:], string2[half2:]))
# def debug():
# # design to debug the encoding scheme
# seed = 0
# np.random.seed(seed)
# budget = 2000
# B, n_ops, n_cell = 5, 7, 2
# networks = []
# design_id = 1
# while len(networks) < budget:
# bit_string = []
# for c in range(n_cell):
# for b in range(B):
# bit_string += [np.random.randint(n_ops),
# np.random.randint(b + 2),
# np.random.randint(n_ops),
# np.random.randint(b + 2)
# ]
# genome = convert(bit_string)
# # check against evaluated networks in case of duplicates
# doTrain = True
# for network in networks:
# if compare(genome, network):
# doTrain = False
# break
# if doTrain:
# genotype = decode(genome)
# model = Network(16, 10, 8, False, genotype)
# model.drop_path_prob = 0.0
# data = torch.randn(1, 3, 32, 32)
# output, output_aux = model(torch.autograd.Variable(data))
# networks.append(genome)
# design_id += 1
# print(design_id)
if __name__ == "__main__":
    # Ad-hoc smoke-test area for the encoding utilities; the original
    # drivers (debug(), genome/bit-string comparisons) are kept only as
    # the commented-out code above.  Example: one cell of 4 blocks with
    # two (op, input) units each.
    cell_bit_string = [3, 0, 3, 1, 3, 0, 3, 1, 3, 1, 2, 0, 2, 0, 5, 2]
    # print(decode_cell(convert_cell(cell_bit_string), norm=False))
================================================
FILE: examples/Structure_Evolution/MSE-NAS/motifs.py
================================================
from collections import namedtuple
import torch
# A motif genotype: `normal` is a flat list of (op_name, input_node)
# edges and `normal_concat` names the intermediate nodes whose outputs
# are concatenated to form the motif output.  Naming convention:
#   mm* -> 2-input motifs, m* -> 3-input motifs, t* -> 4-input motifs;
#   '_p'/'_n' ops keep positive/negative activations, '*_back' edges are
#   feedback connections.  The trailing '#k' comments mark the
#   intermediate node that each group of edges feeds.
Genotype = namedtuple('Genotype', 'normal normal_concat')
m0 = Genotype(
    normal=[
        ('skip', 0), ('skip', 1), ('skip', 2),
    ],
    normal_concat=range(3, 4)
)
mm0 = Genotype(
    normal=[
        ('skip', 0), ('skip', 1), ('skip', 2),
    ],
    normal_concat=range(2, 3)
)
mm1 = Genotype(
    normal=[
        ('conv_3x3_p', 0), ('conv_5x5_p', 1),
        ('skip_connect', 0), ('conv_5x5_p', 2),
    ],
    normal_concat=range(2, 4)
)
mm2 = Genotype(
    normal=[
        ('conv_5x5_p', 0), ('conv_5x5_p', 1),
        ('skip_connect', 0), ('conv_5x5_n', 2),
        ('conv_5x5_p', 2), ('conv_3x3_n', 3),
    ],
    normal_concat=range(2, 5)
)
mm4 = Genotype(
    normal=[
        ('conv_3x3_p', 0), ('conv_5x5_p', 1),  #2
        ('conv_3x3_p', 0), ('conv_3x3_p', 1),  #3
        ('conv_5x5_p', 2), ('conv_5x5_p', 3),  #4
        ('skip_connect', 0), ('conv_3x3_p', 4),  #5
        ('skip_connect', 0), ('conv_3x3_p', 4),  #6
    ],
    normal_concat=range(2, 7)
)
mm3 = Genotype(
    normal=[
        ('conv_3x3_p', 0), ('conv_5x5_p', 1),  #2
        ('skip_connect', 0), ('conv_5x5_n', 2),  #3
        ('skip_connect', 0), ('conv_5x5_p', 3),  #4
        ('skip_connect_back', 2),  #3
        ('conv_3x3_p_back', 3),  #4
    ],
    normal_concat=range(2, 5)
)
mm5 = Genotype(
    normal=[
        ('conv_3x3_p', 0), ('conv_5x5_p', 1),  #2
        ('skip_connect', 0), ('conv_5x5_p', 2),  #3
        ('skip_connect_back', 2),  #3
    ],
    normal_concat=range(2, 4)
)
m1 = Genotype(
    normal=[
        ('conv_3x3_p', 0), ('conv_5x5_p', 1), ('conv_5x5_p', 2),  #B3
        ('skip', 0), ('conv_5x5_p', 3), ('skip', 1),  #C4
    ],
    normal_concat=range(3, 5)
)
m2 = Genotype(
    normal=[
        ('conv_5x5_p', 0), ('conv_5x5_p', 1), ('conv_5x5_p', 2),  #B3
        ('skip', 0), ('conv_5x5_n', 3), ('skip', 1),  #C4
        ('conv_5x5_p', 3), ('conv_3x3_n', 4), ('skip', 1),  #D5
    ],
    normal_concat=range(3, 6)
)
m4 = Genotype(
    normal=[
        ('conv_3x3_p', 0), ('conv_5x5_p', 1), ('conv_5x5_p', 2),  #3
        ('conv_3x3_p', 0), ('conv_3x3_p', 1), ('conv_5x5_p', 2),  #4
        ('skip', 0), ('conv_5x5_p', 3), ('conv_5x5_p', 4),  #5
        ('skip', 0), ('conv_3x3_p', 3), ('conv_3x3_n', 5),  #6
        ('skip', 0), ('conv_3x3_p', 4), ('conv_3x3_n', 5),  #7
    ],
    normal_concat=range(3, 8)
)
m3 = Genotype(
    normal=[
        ('conv_3x3_p', 0), ('conv_5x5_p', 1), ('conv_3x3_p', 2),  #3
        ('skip', 0), ('conv_5x5_p', 3), ('skip', 1),  #4
        ('skip', 0), ('conv_5x5_p', 3), ('skip', 1),  #5
        ('conv_3x3_n_back', 3),  #4
        ('skip_back', 2),  #5
    ],
    normal_concat=range(3, 6)
)
m5 = Genotype(
    normal=[
        ('conv_3x3_p', 0), ('conv_5x5_p', 1), ('conv_5x5_p', 2),  #3
        ('skip', 0), ('skip', 1), ('conv_5x5_n', 3),  #4
        ('skip_connect_back', 3),  #4
    ],
    normal_concat=range(3, 5)
)
t1 = Genotype(
    normal=[
        ('conv_3x3_p', 0), ('conv_5x5_p', 1), ('conv_5x5_p', 2), ('conv_5x5_p', 3),  #4
        ('skip', 0), ('conv_5x5_p', 4), ('skip', 1), ('skip', 2),  #5
        ('skip', 0), ('conv_5x5_p', 5), ('skip', 1), ('skip', 2),  #6
        ('skip', 0), ('conv_5x5_p', 5), ('skip', 1), ('skip', 2),  #7
    ],
    normal_concat=range(4, 8)
)
t2 = Genotype(
    normal=[
        ('conv_5x5_p', 0), ('conv_5x5_p', 1), ('conv_5x5_p', 2), ('conv_5x5_p', 3),  #4
        ('skip', 0), ('conv_5x5_n', 4), ('skip', 1), ('skip', 2),  #5
        ('conv_5x5_p', 4), ('conv_3x3_n', 5), ('skip', 1), ('skip', 2),  #6
    ],
    normal_concat=range(4, 7)
)
t4 = Genotype(
    normal=[
        ('conv_3x3_p', 0), ('conv_5x5_p', 1), ('conv_5x5_p', 2), ('conv_5x5_p', 3),  #4
        ('conv_5x5_p', 0), ('skip', 1), ('conv_5x5_n', 4), ('skip', 3),  #5
        ('skip', 0), ('conv_5x5_p', 3), ('conv_5x5_n', 4), ('skip', 2),  #6
    ],
    normal_concat=range(4, 7)
)
t3 = Genotype(
    normal=[
        ('conv_3x3_p', 0), ('conv_5x5_p', 1), ('conv_3x3_p', 2), ('conv_3x3_p', 3),  #4
        ('skip', 0), ('skip', 2), ('skip', 1), ('skip', 3), ('conv_3x3_p', 4),  #5
        ('skip', 0), ('conv_5x5_p', 4), ('skip', 1), ('skip', 2),  #6
        ('conv_3x3_n_back', 4),  #5
        ('skip_back', 4),  #6
    ],
    normal_concat=range(4, 7)
)
t5 = Genotype(
    normal=[
        ('conv_3x3_p', 0), ('conv_5x5_p', 1), ('conv_5x5_p', 2), ('conv_5x5_p', 3),  #4
        ('skip', 0), ('skip', 1), ('skip', 2), ('conv_5x5_n', 4),  #5
        ('skip', 0), ('skip', 1), ('conv_5x5_n', 4), ('conv_5x5_n', 5),  #6
        ('skip', 0), ('skip', 1), ('conv_5x5_n', 4), ('conv_5x5_n', 5),  #7
        ('conv_3x3_n_back', 4),  #5
        ('skip_back', 4),  #6
        ('skip_back', 5),  #7
    ],
    normal_concat=range(4, 8)
)
================================================
FILE: examples/Structure_Evolution/MSE-NAS/nsganet.py
================================================
import numpy as np
from pymoo.algorithms.genetic_algorithm import GeneticAlgorithm
from pymoo.docs import parse_doc_string
from pymoo.model.individual import Individual
from pymoo.model.survival import Survival
from pymoo.operators.crossover.point_crossover import PointCrossover
from pymoo.operators.mutation.polynomial_mutation import PolynomialMutation
from pymoo.operators.sampling.random_sampling import RandomSampling
from pymoo.operators.selection.tournament_selection import compare, TournamentSelection
from pymoo.util.display import disp_multi_objective
from pymoo.util.dominator import Dominator
from pymoo.util.non_dominated_sorting import NonDominatedSorting
from pymoo.util.randomized_argsort import randomized_argsort
# =========================================================================================================
# Implementation
# based on nsga2 from https://github.com/msu-coinlab/pymoo
# =========================================================================================================
class NSGANet(GeneticAlgorithm):
    """Genetic algorithm with NSGA-II-style selection/survival (pymoo)."""

    def __init__(self, **kwargs):
        # Every individual tracks a non-domination rank and a crowding
        # distance; both are filled in by RankAndCrowdingSurvival.
        kwargs['individual'] = Individual(rank=np.inf, crowding=-1)
        super().__init__(**kwargs)
        self.tournament_type = 'comp_by_dom_and_crowding'
        self.func_display_attrs = disp_multi_objective
# ---------------------------------------------------------------------------------------------------------
# Binary Tournament Selection Function
# ---------------------------------------------------------------------------------------------------------
def binary_tournament(pop, P, algorithm, **kwargs):
    """Binary tournament selection (NSGA-II style).

    For each pair in P: an infeasible solution loses to a feasible (or
    less-violating) one; between feasible solutions, domination (or rank)
    decides; ties fall back to the larger crowding distance.

    Bug fix: the return used `astype(np.int)` -- the np.int alias was
    deprecated in NumPy 1.20 and removed in 1.24, so this crashed on any
    modern NumPy.  The builtin int is used instead (same semantics).

    Returns the winner indices as an (n, 1) integer array.
    """
    if P.shape[1] != 2:
        raise ValueError("Only implemented for binary tournament!")
    tournament_type = algorithm.tournament_type
    S = np.full(P.shape[0], np.nan)
    for i in range(P.shape[0]):
        a, b = P[i, 0], P[i, 1]
        # if at least one solution is infeasible, the smaller constraint
        # violation wins
        if pop[a].CV > 0.0 or pop[b].CV > 0.0:
            S[i] = compare(a, pop[a].CV, b, pop[b].CV, method='smaller_is_better', return_random_if_equal=True)
        # both solutions are feasible
        else:
            if tournament_type == 'comp_by_dom_and_crowding':
                rel = Dominator.get_relation(pop[a].F, pop[b].F)
                if rel == 1:
                    S[i] = a
                elif rel == -1:
                    S[i] = b
            elif tournament_type == 'comp_by_rank_and_crowding':
                S[i] = compare(a, pop[a].rank, b, pop[b].rank,
                               method='smaller_is_better')
            else:
                raise Exception("Unknown tournament type.")
            # if rank or domination relation didn't decide, compare by
            # crowding distance (larger is better)
            if np.isnan(S[i]):
                S[i] = compare(a, pop[a].get("crowding"), b, pop[b].get("crowding"),
                               method='larger_is_better', return_random_if_equal=True)
    return S[:, None].astype(int)
# ---------------------------------------------------------------------------------------------------------
# Survival Selection
# ---------------------------------------------------------------------------------------------------------
class RankAndCrowdingSurvival(Survival):
    """NSGA-II survival: non-dominated rank first, crowding distance second."""

    def __init__(self) -> None:
        # True -> filter infeasible individuals before survival selection.
        super().__init__(True)

    def _do(self, pop, n_survive, D=None, **kwargs):
        # get the objective space values and objects
        F = pop.get("F")
        # the final indices of surviving individuals
        survivors = []
        # do the non-dominated sorting until splitting front
        fronts = NonDominatedSorting().do(F, n_stop_if_ranked=n_survive)
        for k, front in enumerate(fronts):
            # calculate the crowding distance of the front
            crowding_of_front = calc_crowding_distance(F[front, :])
            # save rank and crowding in the individual class
            for j, i in enumerate(front):
                pop[i].set("rank", k)
                pop[i].set("crowding", crowding_of_front[j])
            # current front sorted by crowding distance if splitting
            if len(survivors) + len(front) > n_survive:
                I = randomized_argsort(crowding_of_front, order='descending', method='numpy')
                I = I[:(n_survive - len(survivors))]
            # otherwise take the whole front unsorted
            else:
                I = np.arange(len(front))
            # extend the survivors by all or selected individuals
            survivors.extend(front[I])
        return pop[survivors]
def calc_crowding_distance(F):
    """NSGA-II crowding distance for an objective matrix F (n_points, n_obj).

    Boundary points (and fronts of <= 2 points) receive the large finite
    value 1e14 instead of infinity so downstream sorting stays
    well-defined.  Duplicate objective values are handled by propagating
    the neighbouring non-zero gap.
    """
    infinity = 1e+14
    n_points = F.shape[0]
    n_obj = F.shape[1]
    if n_points <= 2:
        return np.full(n_points, infinity)
    else:
        # sort each column and get index
        I = np.argsort(F, axis=0, kind='mergesort')
        # now really sort the whole array
        F = F[I, np.arange(n_obj)]
        # get the distance to the last element in sorted list and replace zeros with actual values
        dist = np.concatenate([F, np.full((1, n_obj), np.inf)]) \
            - np.concatenate([np.full((1, n_obj), -np.inf), F])
        index_dist_is_zero = np.where(dist == 0)
        # forward/backward sweeps copy the nearest non-zero gap over runs
        # of duplicate values
        dist_to_last = np.copy(dist)
        for i, j in zip(*index_dist_is_zero):
            dist_to_last[i, j] = dist_to_last[i - 1, j]
        dist_to_next = np.copy(dist)
        for i, j in reversed(list(zip(*index_dist_is_zero))):
            dist_to_next[i, j] = dist_to_next[i + 1, j]
        # normalize all the distances
        norm = np.max(F, axis=0) - np.min(F, axis=0)
        norm[norm == 0] = np.nan
        dist_to_last, dist_to_next = dist_to_last[:-1] / norm, dist_to_next[1:] / norm
        # if we divided by zero because all values in one columns are equal replace by none
        dist_to_last[np.isnan(dist_to_last)] = 0.0
        dist_to_next[np.isnan(dist_to_next)] = 0.0
        # sum up the distance to next and last and norm by objectives - also reorder from sorted list
        J = np.argsort(I, axis=0)
        crowding = np.sum(dist_to_last[J, np.arange(n_obj)] + dist_to_next[J, np.arange(n_obj)], axis=1) / n_obj
        # replace infinity with a large number
        crowding[np.isinf(crowding)] = infinity
        return crowding
# =========================================================================================================
# Interface
# =========================================================================================================
def nsganet(
        pop_size=100,
        sampling=RandomSampling(var_type=int),
        selection=TournamentSelection(func_comp=binary_tournament),
        crossover=PointCrossover(n_points=2),
        mutation=PolynomialMutation(eta=3, var_type=int),
        eliminate_duplicates=True,
        n_offsprings=None,
        **kwargs):
    """
    Factory for the NSGA-Net search algorithm.

    Bug fix: the sampling/mutation defaults used np.int, a deprecated
    alias of the builtin int that NumPy 1.24 removed; because default
    arguments are evaluated at import time, importing this module crashed
    on any modern NumPy.  The builtin int is equivalent and portable.

    Parameters
    ----------
    pop_size : {pop_size}
    sampling : {sampling}
    selection : {selection}
    crossover : {crossover}
    mutation : {mutation}
    eliminate_duplicates : {eliminate_duplicates}
    n_offsprings : {n_offsprings}

    Returns
    -------
    nsganet : :class:`~pymoo.model.algorithm.Algorithm`
        Returns an NSGANet algorithm object.
    """
    return NSGANet(pop_size=pop_size,
                   sampling=sampling,
                   selection=selection,
                   crossover=crossover,
                   mutation=mutation,
                   survival=RankAndCrowdingSurvival(),
                   eliminate_duplicates=eliminate_duplicates,
                   n_offsprings=n_offsprings,
                   **kwargs)


parse_doc_string(nsganet)
================================================
FILE: examples/Structure_Evolution/MSE-NAS/obj.py
================================================
import sys
import os
import numpy as np
import torch
import logging
import argparse
import torch.nn as nn
import torch.utils
# import torchvision.datasets as dset
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
from timm.models import create_model
from cell123model import NetworkCIFAR
from braincog.base.node.node import *
from braincog.utils import *
from braincog.base.utils.criterions import *
from braincog.datasets.datasets import *
from braincog.model_zoo.resnet import *
from braincog.model_zoo.convnet import *
from braincog.model_zoo.reactnet import *
from braincog.model_zoo.convxnet import *
from scipy.stats import kendalltau
from misc import utils
import micro_encoding
from misc.flops_counter import add_flops_counting_methods
from utils import data_transforms
from datetime import datetime
bits = 20  # genes per layer segment in the motif genome encoding


def logdet(K):
    """Log-absolute-determinant of K via slogdet (the sign is ignored)."""
    return torch.linalg.slogdet(K)[1]
def LSP(args, genome, train_data):
    """Score an architecture genome without training (zero-cost proxies).

    Decodes `genome` into motifs, builds the model, hooks every spiking
    node, runs `repeat` forward batches, and accumulates several pairwise
    batch-similarity statistics over the binarized node activations.

    Returns the per-repeat means of: (sum of the Jaccard-like matrix J,
    logdet of J, Manhattan-distance sum, cosine-similarity sum,
    Pearson-correlation sum, logdet of the kernel K).

    NOTE(review): requires CUDA -- inputs and the hook's temporaries use
    .cuda() while the accumulators live on args.device; confirm the two
    always agree.
    """
    with torch.no_grad():
        test_motifs, ids = micro_encoding.decode_motif(layers=args.layers, bits=bits, genome=genome)
        pmodel = create_model(
            args.model,
            pretrained=args.pretrained,
            num_classes=args.num_classes,
            dataset=args.dataset,
            step=args.step,
            encode_type=args.encode,
            node_type=eval(args.node_type),
            threshold=args.threshold,
            tau=args.tau,
            sigmoid_thres=args.sigmoid_thres,
            requires_thres_grad=args.requires_thres_grad,
            spike_output=not args.no_spike_output,
            C=args.init_channels,
            layers=args.layers,
            auxiliary=args.auxiliary,
            motif=test_motifs,
            parse_method=args.parse_method,
            act_fun=args.act_fun,
            temporal_flatten=args.temporal_flatten,
            layer_by_layer=args.layer_by_layer,
            n_groups=args.n_groups,
            cell_type=genome[-1]
        )
        pmodel.to(args.device)
        # Pairwise (batch_size x batch_size) accumulators, filled by the
        # forward hook below and reset at the start of each repeat.
        pmodel.K = torch.zeros(args.batch_size, args.batch_size, device=args.device)
        pmodel.J = torch.zeros(args.batch_size, args.batch_size, device=args.device)
        pmodel.Ccosine = torch.zeros(args.batch_size, args.batch_size, device=args.device)
        pmodel.Cm = torch.zeros(args.batch_size, args.batch_size, device=args.device)
        pmodel.Cpe = torch.zeros(args.batch_size, args.batch_size, device=args.device)
        pmodel.num_actfun_C = 0
        pmodel.num_actfun_K = 0

        def computing_LSP(module, inp, out):
            # Forward hook on every spiking node: fold this layer's binary
            # activation pattern into the model-level accumulators.
            if isinstance(out, tuple):
                out = out[0]
            out = out.view(out.size(0), -1)
            batch_num, neuron_num = out.size()
            x = (out > 0).float()  # binarized activations, one row per sample
            full_matrix = torch.ones((args.batch_size, args.batch_size)).cuda() * neuron_num
            sparsity = (x.sum(1) / neuron_num).unsqueeze(1)
            # Expected number of disagreeing neurons under independence,
            # used to rescale the observed disagreement counts.
            norm_K = ((sparsity @ (1 - sparsity.t())) + ((1 - sparsity) @ sparsity.t())) * neuron_num
            rescale_factor = torch.div(0.5 * torch.ones((args.batch_size, args.batch_size)).cuda(), norm_K + 1e-3)
            # Pairwise agreement/disagreement counts between samples.
            K1_0 = (x @ (1 - x.t()))
            K0_1 = ((1 - x) @ x.t())
            K0_0 = (1 - x) @ (1 - x).t()
            # NOTE(review): K1_1 duplicates K0_0 here; "both active" would
            # be x @ x.t() -- looks like a bug, confirm intended.
            K1_1 = (1 - x) @ (1 - x).t()
            K_total = (full_matrix - rescale_factor * (K0_1 + K1_0))
            J_total = (K1_1 + K0_0) / (K0_1 + K1_0 + K1_1)
            pmodel.K = pmodel.K + K_total
            pmodel.J = pmodel.J + J_total
            pmodel.num_actfun_K += 1
            dis_man = torch.zeros_like(pmodel.Cm)
            dis_cosine = torch.zeros_like(pmodel.Ccosine)
            ou_dist = nn.PairwiseDistance(p=2)
            m_dist = nn.PairwiseDistance(p=1)
            cos = nn.CosineSimilarity(dim=1, eps=1e-6)
            # Row i holds sample i's distance/similarity to every sample.
            for i in range(args.batch_size):
                temp = x[i].repeat(args.batch_size, 1)
                dis_cosine[i] = cos(x, temp)
                dis_man[i] = m_dist(x, temp)
            pmodel.Ccosine = pmodel.Ccosine + dis_cosine
            pmodel.Cm = pmodel.Cm + dis_man
            pmodel.Cpe = pmodel.Cpe + torch.corrcoef(x / torch.norm(x, dim=-1, keepdim=True))
            pmodel.num_actfun_C += 1
            # NOTE(review): num_actfun_K is incremented TWICE per hook call
            # (see above) while K/J are added once -- confirm intended.
            pmodel.num_actfun_K += 1

        s_ou = []
        s_m = []
        s_pe = []
        s_cos = []
        s_k = []
        s_jac = []
        s_sum_j = []
        repeat = 2
        # Attach the hook to every module of the configured node type.
        for name, module in pmodel.named_modules():
            if args.node_type in str(type(module)):
                handle = module.register_forward_hook(computing_LSP)
        for j in range(repeat):
            # Reset accumulators, then push one fresh batch through the net.
            pmodel.K = torch.zeros(args.batch_size, args.batch_size, device=args.device)
            pmodel.J = torch.zeros(args.batch_size, args.batch_size, device=args.device)
            pmodel.Ccosine = torch.zeros(args.batch_size, args.batch_size, device=args.device)
            pmodel.Cm = torch.zeros(args.batch_size, args.batch_size, device=args.device)
            pmodel.Cpe = torch.zeros(args.batch_size, args.batch_size, device=args.device)
            pmodel.num_actfun_C = 0
            pmodel.num_actfun_K = 0
            data_iterator = iter(train_data)
            inputs, targets = next(data_iterator)
            inputs, targets = inputs.cuda(), targets.cuda()
            outputs = pmodel(inputs)
            # Per-hook-call averages for this repeat.
            tc = pmodel.Ccosine / pmodel.num_actfun_C
            tp = pmodel.Cpe / pmodel.num_actfun_C
            tm = pmodel.Cm / pmodel.num_actfun_C
            tj = pmodel.J / (pmodel.num_actfun_K)
            # NaNs (e.g. from all-zero activation rows) are treated as 0.
            Ccos = torch.where(torch.isnan(tc), torch.full_like(tc, 0), tc)
            Cpe = torch.where(torch.isnan(tp), torch.full_like(tp, 0), tp)
            Cm = torch.where(torch.isnan(tm), torch.full_like(tm, 0), tm)
            s_k.append(float(logdet(pmodel.K / (pmodel.num_actfun_K))))
            s_jac.append(float(logdet(tj)))
            s_sum_j.append(float(tj.sum()))
            s_m.append(float(Cm.sum()))
            s_cos.append(float(Ccos.sum()))
            s_pe.append(float(Cpe.sum()))
        return np.mean(np.array(s_sum_j)), np.mean(np.array(s_jac)), np.mean(np.array(s_m)), np.mean(np.array(s_cos)), np.mean(np.array(s_pe)), np.mean(np.array(s_k))
================================================
FILE: examples/Structure_Evolution/MSE-NAS/operations.py
================================================
import numpy as np
import torch
import torch.nn as nn
from torch.nn import *
import torch.nn.functional as F
from torch import einsum
from einops import rearrange
# from braincog.model_zoo.base_module import DeformConvPack
from braincog.model_zoo.base_module import BaseLinearModule
# from mmcv.ops import ModulatedDeformConv2dPack
def si_relu(x, positive):
    """Signed one-sided ReLU.

    positive=1 keeps only positive values (zeroing the rest), positive=-1
    keeps only negative values, positive=0 passes x through unchanged.
    Any other value raises ValueError.
    """
    if positive == 0:
        return x
    if positive == 1:
        return torch.where(x > 0., x, torch.zeros_like(x))
    if positive == -1:
        return torch.where(x < 0., x, torch.zeros_like(x))
    raise ValueError


class SiReLU(nn.Module):
    """Module wrapper around si_relu with a fixed sign setting."""

    def __init__(self, positive=0):
        super().__init__()
        self.positive = positive

    def forward(self, x):
        return si_relu(x, self.positive)
def weight_init(m):
    """Initializer for use with nn.Module.apply(): Xavier-normal conv
    weights (gain 0.1) and zero conv biases; other module types untouched.

    Fixes: uses the in-place init functions (xavier_normal_/constant_ --
    the non-underscore variants are long deprecated) and tolerates
    bias-free convolutions (bias=False left m.bias as None, which made
    the original crash on m.bias.data).
    """
    if isinstance(m, nn.Conv2d):
        torch.nn.init.xavier_normal_(m.weight.data, gain=0.1)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias.data, 0.)
# MLP-cell op factory: name -> (C, act_fun) -> nn.Module.  The '_p'/'_n'
# suffixes select which activation sign the si_relu gate keeps
# (excitatory / inhibitory); no suffix passes both signs through.
OPS_Mlp = {
    'mlp': lambda C, act_fun:
    SiMLP(C, C, act_fun=act_fun, positive=0),
    'mlp_p': lambda C, act_fun:
    SiMLP(C, C, act_fun=act_fun, positive=1),
    'mlp_n': lambda C, act_fun:
    SiMLP(C, C, act_fun=act_fun, positive=-1),
    'skip_connect': lambda C, act_fun:
    Identity(positive=0),
    'skip_connect_p': lambda C, act_fun:
    Identity(positive=1),
    'skip_connect_n': lambda C, act_fun:
    Identity(positive=-1),
}
# Conv-cell op factory: name -> (C, stride, affine, act_fun) -> nn.Module.
# Suffix convention as in OPS_Mlp: '_p' keeps only positive outputs,
# '_n' only negative, no suffix keeps both (see si_relu).  Lambdas are
# lazy, so entries referencing classes that are commented out elsewhere
# in this file (DeformConv) only fail if actually instantiated.
OPS = {
    'avg_pool_3x3': lambda C, stride, affine, act_fun: nn.AvgPool2d(3, stride=stride, padding=1,
                                                                    count_include_pad=False),
    'conv_3x3': lambda C, stride, affine, act_fun:
    ReLUConvBN(C_in=C, C_out=C, kernel_size=3, padding=1, stride=stride, affine=affine, act_fun=act_fun, positive=0),
    'conv_5x5': lambda C, stride, affine, act_fun:
    ReLUConvBN(C_in=C, C_out=C, kernel_size=5, padding=2, stride=stride, affine=affine, act_fun=act_fun, positive=0),
    'max_pool_3x3': lambda C, stride, affine, act_fun: nn.MaxPool2d(3, stride=stride, padding=1),
    'skip_connect': lambda C, stride, affine, act_fun:
    Identity(positive=0) if stride == 1 else FactorizedReduce(C, C, affine=affine, act_fun=act_fun),
    'sep_conv_3x3': lambda C, stride, affine, act_fun:
    SepConv(C, C, 3, stride, 1, affine=affine, act_fun=act_fun, positive=0),
    'sep_conv_5x5': lambda C, stride, affine, act_fun:
    SepConv(C, C, 5, stride, 2, affine=affine, act_fun=act_fun, positive=0),
    'sep_conv_7x7': lambda C, stride, affine, act_fun:
    SepConv(C, C, 7, stride, 3, affine=affine, act_fun=act_fun, positive=0),
    'dil_conv_3x3': lambda C, stride, affine, act_fun:
    DilConv(C, C, 3, stride, 2, 2, affine=affine, act_fun=act_fun, positive=0),
    'dil_conv_5x5': lambda C, stride, affine, act_fun:
    DilConv(C, C, 5, stride, 4, 2, affine=affine, act_fun=act_fun, positive=0),
    'def_conv_3x3': lambda C, stride, affine, act_fun:
    DeformConv(C, C, 3, stride, 1, affine=affine, act_fun=act_fun, positive=0),
    'def_conv_5x5': lambda C, stride, affine, act_fun:
    DeformConv(C, C, 5, stride, 2, affine=affine, act_fun=act_fun, positive=0),
    # Excitatory ('_p') variants: output gated through SiReLU(positive=1).
    'avg_pool_3x3_p': lambda C, stride, affine, act_fun: nn.Sequential(
        nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False),
        SiReLU(positive=1)
    ),
    'max_pool_3x3_p': lambda C, stride, affine, act_fun: nn.Sequential(
        nn.MaxPool2d(3, stride=stride, padding=1),
        SiReLU(positive=1)
    ),
    'conv_3x3_p': lambda C, stride, affine, act_fun:
    ReLUConvBN(C_in=C, C_out=C, kernel_size=3, padding=1, stride=stride, affine=affine, act_fun=act_fun, positive=1),
    'conv_5x5_p': lambda C, stride, affine, act_fun:
    ReLUConvBN(C_in=C, C_out=C, kernel_size=5, padding=2, stride=stride, affine=affine, act_fun=act_fun, positive=1),
    'skip_connect_p': lambda C, stride, affine, act_fun:
    Identity(positive=1) if stride == 1 else FactorizedReduce(C, C, affine=affine, act_fun=act_fun, positive=1),
    'sep_conv_3x3_p': lambda C, stride, affine, act_fun:
    SepConv(C, C, 3, stride, 1, affine=affine, act_fun=act_fun, positive=1),
    'sep_conv_5x5_p': lambda C, stride, affine, act_fun:
    SepConv(C, C, 5, stride, 2, affine=affine, act_fun=act_fun, positive=1),
    'sep_conv_7x7_p': lambda C, stride, affine, act_fun:
    SepConv(C, C, 7, stride, 3, affine=affine, act_fun=act_fun, positive=1),
    'dil_conv_3x3_p': lambda C, stride, affine, act_fun:
    DilConv(C, C, 3, stride, 2, 2, affine=affine, act_fun=act_fun, positive=1),
    'dil_conv_5x5_p': lambda C, stride, affine, act_fun:
    DilConv(C, C, 5, stride, 4, 2, affine=affine, act_fun=act_fun, positive=1),
    'def_conv_3x3_p': lambda C, stride, affine, act_fun:
    DeformConv(C, C, 3, stride, 1, affine=affine, act_fun=act_fun, positive=1),
    'def_conv_5x5_p': lambda C, stride, affine, act_fun:
    DeformConv(C, C, 5, stride, 2, affine=affine, act_fun=act_fun, positive=1),
    # Inhibitory ('_n') variants: output gated through SiReLU(positive=-1).
    'avg_pool_3x3_n': lambda C, stride, affine, act_fun: nn.Sequential(
        nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False),
        SiReLU(positive=-1)
    ),
    'max_pool_3x3_n': lambda C, stride, affine, act_fun: nn.Sequential(
        nn.MaxPool2d(3, stride=stride, padding=1),
        SiReLU(positive=-1)
    ),
    'conv_3x3_n': lambda C, stride, affine, act_fun:
    ReLUConvBN(C_in=C, C_out=C, kernel_size=3, padding=1, stride=stride, affine=affine, act_fun=act_fun, positive=-1),
    'conv_5x5_n': lambda C, stride, affine, act_fun:
    ReLUConvBN(C_in=C, C_out=C, kernel_size=5, padding=2, stride=stride, affine=affine, act_fun=act_fun, positive=-1),
    'skip_connect_n': lambda C, stride, affine, act_fun:
    Identity(positive=-1) if stride == 1 else FactorizedReduce(C, C, affine=affine, act_fun=act_fun, positive=-1),
    'sep_conv_3x3_n': lambda C, stride, affine, act_fun:
    SepConv(C, C, 3, stride, 1, affine=affine, act_fun=act_fun, positive=-1),
    'sep_conv_5x5_n': lambda C, stride, affine, act_fun:
    SepConv(C, C, 5, stride, 2, affine=affine, act_fun=act_fun, positive=-1),
    'sep_conv_7x7_n': lambda C, stride, affine, act_fun:
    SepConv(C, C, 7, stride, 3, affine=affine, act_fun=act_fun, positive=-1),
    'dil_conv_3x3_n': lambda C, stride, affine, act_fun:
    DilConv(C, C, 3, stride, 2, 2, affine=affine, act_fun=act_fun, positive=-1),
    'dil_conv_5x5_n': lambda C, stride, affine, act_fun:
    DilConv(C, C, 5, stride, 4, 2, affine=affine, act_fun=act_fun, positive=-1),
    'def_conv_3x3_n': lambda C, stride, affine, act_fun:
    DeformConv(C, C, 3, stride, 1, affine=affine, act_fun=act_fun, positive=-1),
    'def_conv_5x5_n': lambda C, stride, affine, act_fun:
    DeformConv(C, C, 5, stride, 2, affine=affine, act_fun=act_fun, positive=-1),
    'conv_7x1_1x7': lambda C, stride, affine, act_fun: nn.Sequential(
        # nn.ReLU(inplace=False),
        act_fun(),
        nn.Conv2d(C, C, (1, 7), stride=(1, stride),
                  padding=(0, 3), bias=False),
        nn.Conv2d(C, C, (7, 1), stride=(stride, 1),
                  padding=(3, 0), bias=False),
        nn.BatchNorm2d(C, affine=affine)
    ),
    # NOTE(review): at stride 1 'skip' maps to Zero (which DROPS the
    # edge), not an identity -- confirm intended.
    'skip': lambda C, stride, affine, act_fun:
    Zero(stride) if stride == 1 else FactorizedReduce(C, C, affine=affine, act_fun=act_fun, positive=1),
    # TransformerEncoderLayer comes from `from torch.nn import *` above.
    'transformer': lambda C, stride, affine, act_fun:
    FactorizedReduce(
        C, C, affine=affine, act_fun=act_fun) if stride != 1 else TransformerEncoderLayer(C),
}
class SiMLP(nn.Module):
    """Sign-gated MLP block: the input first passes through an si_relu
    gate (positive selects the sign kept), then Linear + act_fun."""

    def __init__(self, c_in, c_out, act_fun=nn.ReLU, positive=0, *args, **kwargs):
        super(SiMLP, self).__init__()
        self.op = nn.Sequential(
            nn.Linear(c_in, c_out, bias=True),
            act_fun()
        )
        self.positive = positive

    def forward(self, x):
        gated = si_relu(x, self.positive)
        return self.op(gated)
class DilConv(nn.Module):
    """
    Dilation Convolution : act_fun -> depthwise dilated conv -> 1x1 conv
    -> BatchNorm2d, with an si_relu sign gate on the output.
    """

    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True, act_fun=nn.ReLU, positive=0):
        super(DilConv, self).__init__()
        depthwise = nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride,
                              padding=padding, dilation=dilation, groups=C_in, bias=False)
        pointwise = nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False)
        self.op = nn.Sequential(
            act_fun(),
            depthwise,
            pointwise,
            nn.BatchNorm2d(C_out, affine=affine),
        )
        self.positive = positive

    def forward(self, x):
        return si_relu(self.op(x), self.positive)
class SepConv(nn.Module):
    """Two stacked depthwise-separable convolutions (DARTS-style
    sep_conv); only the first stage carries the stride.  The output
    passes through an si_relu sign gate."""

    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True, act_fun=nn.ReLU, positive=0):
        super(SepConv, self).__init__()
        stage_one = [
            act_fun(),
            nn.Conv2d(C_in, C_in, kernel_size=kernel_size,
                      stride=stride, padding=padding, groups=C_in, bias=False),
            nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_in, affine=affine),
        ]
        stage_two = [
            nn.ReLU(inplace=False),
            nn.Conv2d(C_in, C_in, kernel_size=kernel_size,
                      stride=1, padding=padding, groups=C_in, bias=False),
            nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_out, affine=affine),
        ]
        self.op = nn.Sequential(*stage_one, *stage_two)
        self.positive = positive

    def forward(self, x):
        return si_relu(self.op(x), self.positive)
class Identity(nn.Module):
    """Identity op whose output is sign-gated by si_relu (positive=0
    passes everything through unchanged)."""

    def __init__(self, positive=0):
        super(Identity, self).__init__()
        self.positive = positive

    def forward(self, inputs):
        return si_relu(inputs, self.positive)
class Zero(nn.Module):
    """Zero op: returns an all-zero tensor shaped like the input,
    spatially subsampled by `stride` when stride != 1."""

    def __init__(self, stride):
        super(Zero, self).__init__()
        self.stride = stride

    def forward(self, x):
        if self.stride != 1:
            # N * C * W * H -> keep every stride-th row/column
            x = x[:, :, ::self.stride, ::self.stride]
        return x.mul(0.)
class FactorizedReduce(nn.Module):
    """Halve the spatial resolution: two stride-2 3x3 convs -- the second
    applied to a one-pixel-shifted view -- are concatenated channel-wise,
    batch-normed, then passed through an si_relu sign gate."""

    def __init__(self, C_in, C_out, affine=True, act_fun=nn.ReLU, positive=0):
        super(FactorizedReduce, self).__init__()
        assert C_out % 2 == 0
        half = C_out // 2
        self.activation = act_fun()
        self.conv_1 = nn.Conv2d(C_in, half, 3, stride=2, padding=1, bias=False)
        self.conv_2 = nn.Conv2d(C_in, half, 3, stride=2, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(C_out, affine=affine)
        self.positive = positive

    def forward(self, x):
        x = self.activation(x)
        shifted = x[:, :, 1:, 1:]
        merged = torch.cat([self.conv_1(x), self.conv_2(shifted)], dim=1)
        return si_relu(self.bn(merged), self.positive)
class F0(nn.Module):
    """Factorized reduce followed by an extra stride-2 conv.

    Same two-branch stride-2 reduction as :class:`FactorizedReduce`
    (BN + ``si_relu`` included), then one more 3x3 stride-2 convolution,
    so the spatial size is reduced twice overall.
    """

    def __init__(self, C_in, C_out, affine=True, act_fun=nn.ReLU, positive=0):
        super(F0, self).__init__()
        assert C_out % 2 == 0  # each reduce branch yields half the channels
        self.activation = act_fun()
        # Submodules are created in the original order to keep parameter
        # initialization (global RNG consumption) identical.
        self.op = nn.Conv2d(C_out, C_out, 3, stride=2, padding=1, bias=False)
        self.conv_1 = nn.Conv2d(C_in, C_out // 2, 3, stride=2, padding=1, bias=False)
        self.conv_2 = nn.Conv2d(C_in, C_out // 2, 3, stride=2, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(C_out, affine=affine)
        self.positive = positive

    def forward(self, x):
        x = self.activation(x)
        reduced = torch.cat((self.conv_1(x), self.conv_2(x[:, :, 1:, 1:])), dim=1)
        reduced = si_relu(self.bn(reduced), self.positive)
        return self.op(reduced)
class ReLUConvBN(nn.Module):
    """
    Activation -> Conv2d -> BatchNorm2d, finished with ``si_relu``.
    """

    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True, act_fun=nn.ReLU, positive=0):
        super(ReLUConvBN, self).__init__()
        stages = (
            act_fun(),
            nn.Conv2d(C_in, C_out, kernel_size, stride=stride,
                      padding=padding, bias=False),
            nn.BatchNorm2d(C_out, affine=affine),
        )
        self.op = nn.Sequential(*stages)
        self.positive = positive

    def forward(self, x):
        return si_relu(self.op(x), self.positive)
# class DeformConv(nn.Module):
# def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True, act_fun=nn.ReLU, positive=0):
# super(DeformConv, self).__init__()
# self.op = nn.Sequential(
# # nn.ReLU(inplace=False),
# act_fun(),
# DeformConvPack(C_in, C_out, kernel_size=kernel_size,
# stride=stride, padding=padding, bias=True),
# nn.BatchNorm2d(C_out, affine=affine)
# )
# self.positive = positive
# # if positive == -1:
# # weight_init(self.op)
# def forward(self, x):
# out = self.op(x)
# return si_relu(out, self.positive)
class Attention(Module):
    """
    Multi-head self-attention (adapted from github.com:rwightman/pytorch-image-models).

    Input and output have shape (B, N, C) with C == dim.
    """

    def __init__(self, dim, num_heads=4, attention_dropout=0.1, projection_dropout=0.1):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // self.num_heads
        self.scale = head_dim ** -0.5  # 1/sqrt(d_head) scaling
        self.qkv = Linear(dim, dim * 3, bias=False)
        self.attn_drop = Dropout(attention_dropout)
        self.proj = Linear(dim, dim)
        self.proj_drop = Dropout(projection_dropout)

    def forward(self, x):
        B, N, C = x.shape
        heads = self.num_heads
        # Single projection to q/k/v, then split heads:
        # (B, N, 3C) -> (3, B, heads, N, C // heads)
        qkv = self.qkv(x).reshape(B, N, 3, heads, C // heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        scores = ((q @ k.transpose(-2, -1)) * self.scale).softmax(dim=-1)
        scores = self.attn_drop(scores)
        # Merge heads back into the channel dimension.
        out = (scores @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(out))
class TransformerEncoderLayer(Module):
    """
    Inspired by torch.nn.TransformerEncoderLayer and
    rwightman's timm package.

    Pre-norm transformer block operating on (B, D, R, C) feature maps:
    the forward pass flattens the R x C grid into a token sequence,
    applies self-attention and an MLP (both with residual connections
    and stochastic depth), then restores the (B, D, R, C) layout.
    """
    def __init__(self, d_model, nhead=4, dim_feedforward=256, dropout=0.1,
                 attention_dropout=0.1, drop_path_rate=0.1):
        super(TransformerEncoderLayer, self).__init__()
        self.pre_norm = LayerNorm(d_model)
        self.self_attn = Attention(dim=d_model, num_heads=nhead,
                                   attention_dropout=attention_dropout, projection_dropout=dropout)
        # NOTE(review): this overrides the ``dim_feedforward`` argument, so the
        # MLP hidden width is always d_model regardless of what the caller
        # passed — confirm this is intentional.
        dim_feedforward = d_model
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout1 = Dropout(dropout)
        self.norm1 = LayerNorm(d_model)
        self.linear2 = Linear(dim_feedforward, d_model)
        self.dropout2 = Dropout(dropout)
        # NOTE(review): when drop_path_rate == 0 this falls back to the local
        # ``Identity`` class defined above (which applies si_relu), not
        # torch.nn.Identity — verify that is the intended behavior.
        self.drop_path = DropPath(
            drop_path_rate) if drop_path_rate > 0 else Identity()
        self.activation = F.gelu
    def forward(self, src: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # Flatten the spatial grid (r, c) into the token dimension.
        c = src.shape[-1]
        src = rearrange(src, 'b d r c -> b (r c) d')
        # Attention sub-block: pre-norm, residual, stochastic depth.
        src = src + self.drop_path(self.self_attn(self.pre_norm(src)))
        src = self.norm1(src)
        # Feed-forward sub-block with GELU activation.
        src2 = self.linear2(self.dropout1(self.activation(self.linear1(src))))
        src = src + self.drop_path(self.dropout2(src2))
        # Restore the (B, D, R, C) layout using the remembered column count.
        src = rearrange(src, 'b (r c) d -> b d r c', c=c)
        return src
def drop_path(x, drop_prob: float = 0., training: bool = False):
    """
    Obtained from: github.com:rwightman/pytorch-image-models
    Stochastic Depth: zero out whole samples of ``x`` (any rank, batch
    first) with probability ``drop_prob``, scaling survivors by
    1/keep_prob so the expected value is unchanged. A no-op when
    ``drop_prob`` is 0 or ``training`` is False. (Named 'drop path'
    rather than DropConnect — see tensorflow/tpu#494.)
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all trailing dims.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = keep_prob + torch.rand(mask_shape, dtype=x.dtype, device=x.device)
    mask.floor_()  # binarize to {0, 1}
    return x.div(keep_prob) * mask
class DropPath(Module):
    """
    Obtained from: github.com:rwightman/pytorch-image-models
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    Module wrapper around the functional ``drop_path`` above.
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        # Probability of dropping each sample; forwarded to drop_path().
        self.drop_prob = drop_prob
    def forward(self, x):
        # Active only while the module is in training mode.
        return drop_path(x, self.drop_prob, self.training)
================================================
FILE: examples/Structure_Evolution/MSE-NAS/readme.md
================================================
# Brain-Inspired Multi-scale Evolutionary Architectures for Spiking Neural Networks —— Based on BrainCog #
## Requirements ##
* numpy
* pytorch >= 1.12.0
* pymoo = 0.4.0
* BrainCog
## Run ##
```python evolution.py```
## Citation ##
If you find the code and dataset useful in your research, please consider citing:
```
@article{pan2024brain,
title={Brain-Inspired Multi-Scale Evolutionary Neural Architecture Search for Deep Spiking Neural Networks},
author={Pan, Wenxuan and Zhao, Feifei and Shen, Guobin and Han, Bing and Zeng, Yi},
journal={IEEE Transactions on Evolutionary Computation},
year={2024},
publisher={IEEE}
}
@article{zeng2023braincog,
title={BrainCog: A spiking neural network based, brain-inspired cognitive intelligence engine for brain-inspired AI and brain simulation},
author={Zeng, Yi and Zhao, Dongcheng and Zhao, Feifei and Shen, Guobin and Dong, Yiting and Lu, Enmeng and Zhang, Qian and Sun, Yinqian and Liang, Qian and Zhao, Yuxuan and others},
journal={Patterns},
volume={4},
number={8},
year={2023},
publisher={Elsevier}
}
```
================================================
FILE: examples/Structure_Evolution/MSE-NAS/tm.py
================================================
import sys
sys.path.insert(0, '/home/panwenxuan/back')
import numpy as np
import argparse
import time
import obj
import timm.models
import yaml
import os
import logging
from random import choice
from collections import OrderedDict
from contextlib import suppress
from datetime import datetime
from micro_encoding import ops
from braincog.base.node.node import *
from braincog.utils import *
from braincog.base.utils.criterions import *
from braincog.datasets.datasets import *
from braincog.model_zoo.resnet import *
from braincog.model_zoo.convnet import *
from braincog.utils import save_feature_map, setup_seed
from braincog.base.utils.visualization import plot_tsne_3d, plot_tsne, plot_confusion_matrix
import micro_encoding
from pymop.problem import Problem
import torch
import torch.nn as nn
import torchvision.utils
from torch.nn.parallel import DistributedDataParallel as NativeDDP
from pymoo.optimize import minimize
from utils import data_transforms
from timm.data import create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
from timm.models import load_checkpoint, create_model, resume_checkpoint, convert_splitbn_model
from timm.utils import *
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
from timm.optim import create_optimizer
from timm.scheduler import create_scheduler
from timm.utils import ApexScaler, NativeScaler
# from tn import TinyImageNet
import copy
from sklearn.metrics import confusion_matrix,roc_auc_score
from sklearn.preprocessing import label_binarize
def train_motifs(args,gen,arch_dir,genome,_logger,args_text,devices,bits):
    """Decode ``genome`` into a cell architecture, build the SNN and train/evaluate it.

    :param args: parsed config namespace; mutated in place (epochs, lr, channels)
    :param gen: evolution generation index; -1 selects final retraining
        (``args.trainning_epochs``) instead of the short ``args.eval_epochs``
    :param arch_dir: output directory for checkpoints, summary and logs
    :param genome: encoded architecture; its last element selects the cell type
    :param _logger: logger used for all progress output
    :param args_text: YAML dump of args, written to ``arch_dir/args.yaml``
    :param devices: GPU ids for nn.DataParallel when ``args.num_gpu > 1``
    :param bits: motif bit-width for ``micro_encoding.decode_motif``
    :return: (best_metric, all_best) — the last epoch's eval metric and the
        per-epoch history; (-10000, all_best) on MemoryError/RuntimeError.
    """
    # NOTE(review): NetworkCIFAR is imported but the model below is built via
    # timm's create_model — presumably create_model resolves to it; verify.
    if args.bns:
        from cellmodel import NetworkCIFAR
    else:
        from cell123model import NetworkCIFAR
    test_motifs,ids = micro_encoding.decode_motif(args.layers,bits,genome)
    # gen == -1 means "final retraining"; otherwise a short fitness evaluation.
    if gen==-1:
        args.epochs=args.trainning_epochs
    else:
        args.epochs=args.eval_epochs
    all_best=[]
    try:
        model = create_model(
            args.model,
            pretrained=args.pretrained,
            num_classes=args.num_classes,
            dataset=args.dataset,
            step=args.step,
            encode_type=args.encode,
            node_type=eval(args.node_type),
            threshold=args.threshold,
            tau=args.tau,
            sigmoid_thres=args.sigmoid_thres,
            requires_thres_grad=args.requires_thres_grad,
            spike_output=not args.no_spike_output,
            C=args.init_channels,
            layers=args.layers,
            auxiliary=args.auxiliary,
            motif=test_motifs,
            parse_method=args.parse_method,
            act_fun=args.act_fun,
            temporal_flatten=args.temporal_flatten,
            layer_by_layer=args.layer_by_layer,
            n_groups=args.n_groups,
            cell_type=genome[-1],
        )
        # DVS (event) datasets have 2 polarity channels; RGB datasets have 3.
        if 'dvs' in args.dataset:
            args.channels = 2
        # elif 'mnist' in args.dataset:
        #     args.channels = 1
        else:
            args.channels = 3
        # flops, params = profile(model, inputs=(torch.randn(1, args.channels, args.event_size, args.event_size),), verbose=False)
        # _logger.info('flops = %fM', flops / 1e6)
        # _logger.info('param size = %fM', params / 1e6)
        # _logger.info(model)
        # Linear LR scaling with global batch size (reference batch 1024).
        linear_scaled_lr = args.lr * args.batch_size * args.world_size / 1024.0
        args.lr = linear_scaled_lr
        _logger.info("learning rate is %f" % linear_scaled_lr)
        if args.local_rank == 0:
            sumpram=sum([m.numel() for m in model.parameters()])
            _logger.info('Model %s created, param count: %d' %
                         (args.model, sumpram))
            # return
            # if sumpram > 15000000:
            #     return 0,0
        num_aug_splits = 0
        if args.aug_splits > 0:
            assert args.aug_splits > 1, 'A split of 1 makes no sense'
            num_aug_splits = args.aug_splits
        if args.split_bn:
            assert num_aug_splits > 1 or args.resplit
            model = convert_splitbn_model(model, max(num_aug_splits, 2))
        use_amp = None
        if args.amp:
            # for backwards compat, `--amp` arg tries apex before native amp
            if has_apex:
                args.apex_amp = True
            elif has_native_amp:
                args.native_amp = True
        if args.apex_amp and has_apex:
            use_amp = 'apex'
        elif args.native_amp and has_native_amp:
            use_amp = 'native'
        elif args.apex_amp or args.native_amp:
            _logger.warning("Neither APEX or native Torch AMP is available, using float32. "
                            "Install NVIDA apex or upgrade to PyTorch 1.6")
        optimizer = create_optimizer(args, model)
        amp_autocast = suppress  # do nothing
        loss_scaler = None
        if use_amp == 'apex':
            model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
            loss_scaler = ApexScaler()
            if args.local_rank == 0:
                _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
        elif use_amp == 'native':
            amp_autocast = torch.cuda.amp.autocast
            loss_scaler = NativeScaler()
            if args.local_rank == 0:
                _logger.info('Using native Torch AMP. Training in mixed precision.')
        else:
            if args.local_rank == 0:
                _logger.info('AMP not enabled. Training in float32.')
        # optionally resume from a checkpoint
        resume_epoch = None
        if args.resume and args.eval_checkpoint == '':
            args.eval_checkpoint = args.resume
        if args.resume:
            args.eval = True
            checkpoint = torch.load(args.resume, map_location='cpu')
            model.load_state_dict(checkpoint['state_dict'], False)
            resume_epoch = resume_checkpoint(
                model, args.resume,
                optimizer=None if args.no_resume_opt else optimizer,
                loss_scaler=None if args.no_resume_opt else loss_scaler,
                log_info=args.local_rank == 0)
            # print(model.get_attr('mu'))
            # print(model.get_attr('sigma'))
        if args.num_gpu > 1:
            if use_amp == 'apex':
                _logger.warning(
                    'Apex AMP does not work well with nn.DataParallel, disabling. Use DDP or Torch AMP.')
                use_amp = None
            model = nn.DataParallel(model, device_ids=devices).cuda()
            assert not args.channels_last, "Channels last not supported with DP, use DDP."
        else:
            model = model.cuda()
            if args.channels_last:
                model = model.to(memory_format=torch.channels_last)
        # NOTE(review): optimizer is re-created here, replacing the one built
        # before the model was moved to GPU — confirm the first call is needed.
        optimizer = create_optimizer(args, model)
        if args.critical_loss or args.spike_rate:
            if args.num_gpu>1:
                model.module.set_requires_fp(True)
            else:
                model.set_requires_fp(True)
        model_ema = None
        if args.model_ema:
            # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
            model_ema = ModelEma(
                model,
                decay=args.model_ema_decay,
                device='cpu' if args.model_ema_force_cpu else '',
                resume=args.resume)
        if args.node_resume:
            ckpt = torch.load(args.node_resume, map_location='cpu')
            model.load_node_weight(ckpt, args.node_trainable)
        model_without_ddp = model
        if args.distributed:
            if args.sync_bn:
                assert not args.split_bn
                try:
                    if has_apex and use_amp != 'native':
                        # Apex SyncBN preferred unless native amp is activated
                        model = convert_syncbn_model(model)
                    else:
                        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
                    if args.local_rank == 0:
                        _logger.info(
                            'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                            'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
                except Exception as e:
                    _logger.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')
            if has_apex and use_amp != 'native':
                # Apex DDP preferred unless native amp is activated
                if args.local_rank == 0:
                    _logger.info("Using NVIDIA APEX DistributedDataParallel.")
                model = ApexDDP(model, delay_allreduce=True)
            else:
                if args.local_rank == 0:
                    _logger.info("Using native Torch DistributedDataParallel.")
                model = NativeDDP(model, device_ids=[args.local_rank],
                                  find_unused_parameters=True)  # can use device str in Torch >= 1.1
            model_without_ddp = model.module
        # NOTE: EMA model does not need to be wrapped by DDP
        lr_scheduler, num_epochs = create_scheduler(args, optimizer)
        start_epoch = 0
        if args.start_epoch is not None:
            # a specified start_epoch will always override the resume epoch
            start_epoch = args.start_epoch
        elif resume_epoch is not None:
            start_epoch = resume_epoch
        if lr_scheduler is not None and start_epoch > 0:
            lr_scheduler.step(start_epoch)
        if args.local_rank == 0:
            _logger.info('Scheduled epochs: {}'.format(num_epochs))
        # now config only for imnet
        data_config = resolve_data_config(vars(args), model=model, verbose=False)
        # Dispatch to the dataset-specific loader factory by name.
        loader_train, loader_eval, mixup_active, mixup_fn = eval('get_%s_data' % args.dataset)(
            batch_size=args.batch_size,
            step=args.step,
            args=args,
            _logge=_logger,
            data_config=data_config,
            num_aug_splits=num_aug_splits,
            size=args.event_size,
            mix_up=args.mix_up,
            cut_mix=args.cut_mix,
            event_mix=args.event_mix,
            beta=args.cutmix_beta,
            prob=args.cutmix_prob,
            num=args.cutmix_num,
            noise=args.cutmix_noise,
            num_classes=args.num_classes,
            rand_aug=args.rand_aug,
            randaug_n=args.randaug_n,
            randaug_m=args.randaug_m,
            temporal_flatten=args.temporal_flatten,
            portion=args.train_portion,
            _logger=_logger,
        )
        # Select training/validation loss functions.
        if args.loss_fn == 'mse':
            train_loss_fn = UnilateralMse(1.)
            validate_loss_fn = UnilateralMse(1.)
        else:
            if args.jsd:
                assert num_aug_splits > 1  # JSD only valid with aug splits set
                train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()
            elif mixup_active:
                # smoothing is handled with mixup target transform
                train_loss_fn = SoftTargetCrossEntropy().cuda()
            elif args.smoothing:
                train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
            else:
                train_loss_fn = nn.CrossEntropyLoss().cuda()
            validate_loss_fn = nn.CrossEntropyLoss().cuda()
        if args.loss_fn == 'mix':
            train_loss_fn = MixLoss(train_loss_fn)
            validate_loss_fn = MixLoss(validate_loss_fn)
        eval_metric = args.eval_metric
        best_metric = None
        best_epoch = None
        if args.eval:  # evaluate the model
            if args.distributed:
                state_dict = torch.load(args.eval_checkpoint)['state_dict_ema']
                new_state_dict = OrderedDict()
                # add module prefix for DDP
                for k, v in state_dict.items():
                    k = 'module.' + k
                    new_state_dict[k] = v
                model.load_state_dict(new_state_dict)
            # else:
            #     load_checkpoint(model, args.eval_checkpoint, args.model_ema)
            for i in range(1):
                val_metrics = validate(start_epoch, model, loader_eval, validate_loss_fn, args,_logger,arch_dir,
                                       visualize=args.visualize, spike_rate=args.spike_rate,
                                       tsne=args.tsne, conf_mat=args.conf_mat)
                print(f"Top-1 accuracy of the model is: {val_metrics['top1']:.1f}%")
            # return
        saver = None
        if args.local_rank == 0:
            decreasing = True if eval_metric == 'loss' else False
            saver = CheckpointSaver(
                model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
                checkpoint_dir=arch_dir, recovery_dir=arch_dir, decreasing=decreasing)
            with open(os.path.join(arch_dir, 'args.yaml'), 'w') as f:
                f.write(args_text)
        try:  # train the model
            if args.reset_drop:
                model_without_ddp.reset_drop_path(0.0)
            for epoch in range(start_epoch, args.epochs):
                if epoch == 0 and args.reset_drop:
                    model_without_ddp.reset_drop_path(args.drop_path)
                # if epoch == 3 and best_metric<5:
                #     return 0,0
                if args.distributed:
                    loader_train.sampler.set_epoch(epoch)
                train_metrics = train_epoch(
                    epoch, model, loader_train, optimizer, train_loss_fn, args,_logger=_logger,
                    lr_scheduler=lr_scheduler, saver=saver, output_dir=arch_dir,
                    amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)
                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                    if args.local_rank == 0:
                        _logger.info("Distributing BatchNorm running means and vars")
                    distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
                eval_metrics = validate(epoch, model, loader_eval, validate_loss_fn, args,_logger, arch_dir,amp_autocast=amp_autocast,
                                        visualize=args.visualize, spike_rate=args.spike_rate,
                                        tsne=args.tsne, conf_mat=args.conf_mat)
                if model_ema is not None and not args.model_ema_force_cpu:
                    if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                        distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
                    ema_eval_metrics = validate(
                        epoch, model_ema.ema, loader_eval, validate_loss_fn, args, _logger,arch_dir,amp_autocast=amp_autocast, log_suffix=' (EMA)',
                        visualize=args.visualize, spike_rate=args.spike_rate,
                        tsne=args.tsne, conf_mat=args.conf_mat)
                    eval_metrics = ema_eval_metrics
                if lr_scheduler is not None:
                    # step LR for next epoch
                    lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
                update_summary(
                    epoch, train_metrics, eval_metrics, os.path.join(arch_dir, 'summary.csv'),
                    write_header=best_metric is None)
                # NOTE(review): "best" here is just the latest epoch's metric,
                # not a running maximum — all_best keeps the full history.
                best_metric, best_epoch = eval_metrics[eval_metric],epoch
                _logger.info('Test: {0} '.format(best_metric))
                all_best.append(best_metric)
            # Append the final metric and the genome to the run-wide logs.
            f=open(os.path.join(arch_dir, 'direct.txt'), 'a')
            f.write(str(best_metric))
            f.write('\n')
            f.close()
            f=open(os.path.join(arch_dir, 'direct_genome.txt'), 'a')
            f.write(",".join(str(k) for k in genome))
            f.write('\n')
            f.close()
        except KeyboardInterrupt:
            pass
    # Treat OOM / runtime failures as a very low fitness instead of crashing
    # the whole evolutionary search.
    except MemoryError:
        return -10000, all_best
    except RuntimeError:
        return -10000, all_best
    return best_metric,all_best
def train_epoch(
        epoch, model, loader, optimizer, loss_fn, args,_logger,
        lr_scheduler=None, saver=None, output_dir='', amp_autocast=suppress,
        loss_scaler=None, model_ema=None, mixup_fn=None):
    """Train ``model`` for one epoch over ``loader``.

    Handles mixup scheduling, optional AMP (via ``amp_autocast`` /
    ``loss_scaler``), critical loss, gradient noise/clipping, EMA updates,
    periodic logging and recovery checkpoints.

    :return: ``OrderedDict([('loss', average_training_loss)])``
    """
    # Disable mixup once the configured cutoff epoch is reached.
    if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
        if args.prefetcher and loader.mixup_enabled:
            loader.mixup_enabled = False
        elif mixup_fn is not None:
            mixup_fn.mixup_enabled = False
    # Linearly ramp the drop-path probability over the run.
    model.drop_path_prob = args.drop_path_prob * epoch / args.epochs
    second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    losses_m = AverageMeter()
    closses_m = AverageMeter()
    top1_m = AverageMeter()
    top5_m = AverageMeter()
    model.train()
    end = time.time()
    last_idx = len(loader) - 1
    num_updates = epoch * len(loader)
    for batch_idx, (inputs, target) in enumerate(loader):
        last_batch = batch_idx == last_idx
        data_time_m.update(time.time() - end)
        if not args.prefetcher or args.dataset != 'imnet':
            inputs, target = inputs.type(torch.FloatTensor).cuda(), target.cuda()
            if mixup_fn is not None:
                inputs, target = mixup_fn(inputs, target)
        if args.channels_last:
            inputs = inputs.contiguous(memory_format=torch.channels_last)
        with amp_autocast():
            output = model(inputs)
            loss = loss_fn(output, target)
        # Accuracy is only meaningful with hard (non-mixed) labels.
        if not (args.cut_mix | args.mix_up | args.event_mix) and args.dataset != 'imnet':
            # print(output.shape, target.shape)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            # acc1, = accuracy(output, target)
        else:
            acc1, acc5 = torch.tensor([0.]), torch.tensor([0.])
        closs = torch.tensor([0.], device=loss.device)
        if args.critical_loss:
            closs = calc_critical_loss(model)
            loss = loss + .1 * closs
        spike_rate_avg_layer_str = ''
        threshold_str = ''
        if not args.distributed:
            losses_m.update(loss.item(), inputs.size(0))
            top1_m.update(acc1.item(), inputs.size(0))
            top5_m.update(acc5.item(), inputs.size(0))
            closses_m.update(closs.item(), inputs.size(0))
            if args.num_gpu>1:
                spike_rate_avg_layer = model.module.get_fire_rate().tolist()
                spike_rate_avg_layer_str = ['{:.3f}'.format(i) for i in spike_rate_avg_layer]
                threshold = model.module.get_threshold()
            else:
                spike_rate_avg_layer = model.get_fire_rate().tolist()
                spike_rate_avg_layer_str = ['{:.3f}'.format(i) for i in spike_rate_avg_layer]
                threshold = model.get_threshold()
            threshold_str = ['{:.3f}'.format(i) for i in threshold]
        optimizer.zero_grad()
        if loss_scaler is not None:
            # AMP path: scaler handles backward, clipping and the step.
            loss_scaler(
                loss, optimizer, clip_grad=args.clip_grad, parameters=model.parameters(), create_graph=second_order)
        else:
            loss.backward(create_graph=second_order)
            if args.noisy_grad != 0.:
                random_gradient(model, args.noisy_grad)
            if args.clip_grad is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
            if args.opt == 'lamb':
                optimizer.step(epoch=epoch)
            else:
                optimizer.step()
        torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)
        num_updates += 1
        batch_time_m.update(time.time() - end)
        if last_batch or batch_idx % args.log_interval == 0:
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)
            mu_str = ''
            sigma_str = ''
            if not args.distributed:
                if 'Noise' in args.node_type:
                    mu, sigma = model.get_noise_param()
                    mu_str = ['{:.3f}'.format(i.detach()) for i in mu]
                    sigma_str = ['{:.3f}'.format(i.detach()) for i in sigma]
            if args.distributed:
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                losses_m.update(reduced_loss.item(), inputs.size(0))
                # NOTE(review): closses_m is fed the reduced *loss* here, not
                # the critical loss — verify this is intentional.
                closses_m.update(reduced_loss.item(), inputs.size(0))
            if args.local_rank == 0:
                if args.distributed:
                    _logger.info(
                        'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
                        'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) '
                        'cLoss: {closs.val:>9.6f} ({closs.avg:>6.4f}) '
                        'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
                        'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f}) '
                        'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
                        '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                        'LR: {lr:.3e} '
                        'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
                            epoch,
                            batch_idx, len(loader),
                            100. * batch_idx / last_idx,
                            loss=losses_m,
                            closs=closses_m,
                            top1=top1_m,
                            top5=top5_m,
                            batch_time=batch_time_m,
                            rate=inputs.size(0) * args.world_size / batch_time_m.val,
                            rate_avg=inputs.size(0) * args.world_size / batch_time_m.avg,
                            lr=lr,
                            data_time=data_time_m
                        ))
                else:
                    _logger.info(
                        'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
                        'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) '
                        'cLoss: {closs.val:>9.6f} ({closs.avg:>6.4f}) '
                        'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
                        'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f}) '
                        'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
                        '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                        'LR: {lr:.3e} '
                        'Data: {data_time.val:.3f} ({data_time.avg:.3f})\n'
                        # 'Fire_rate: {spike_rate}\n'
                        # 'Thres: {threshold}\n'
                        # 'Mu: {mu_str}\n'
                        # 'Sigma: {sigma_str}\n'
                        .format(
                            epoch,
                            batch_idx, len(loader),
                            100. * batch_idx / last_idx,
                            loss=losses_m,
                            closs=closses_m,
                            top1=top1_m,
                            top5=top5_m,
                            batch_time=batch_time_m,
                            rate=inputs.size(0) * args.world_size / batch_time_m.val,
                            rate_avg=inputs.size(0) * args.world_size / batch_time_m.avg,
                            lr=lr,
                            data_time=data_time_m,
                            # spike_rate=spike_rate_avg_layer_str,
                            # threshold=threshold_str,
                            # mu_str=mu_str,
                            # sigma_str=sigma_str
                        ))
                if args.save_images and output_dir:
                    torchvision.utils.save_image(
                        inputs,
                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
                        padding=0,
                        normalize=True)
        if saver is not None and args.recovery_interval and (
                last_batch or (batch_idx + 1) % args.recovery_interval == 0):
            saver.save_recovery(epoch, batch_idx=batch_idx)
        if lr_scheduler is not None:
            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
        end = time.time()
    if hasattr(optimizer, 'sync_lookahead'):
        optimizer.sync_lookahead()
    return OrderedDict([('loss', losses_m.avg)])
def validate(epoch, model, loader, loss_fn, args,_logger, arch_dir,amp_autocast=suppress,
             log_suffix='', visualize=False, spike_rate=False, tsne=False, conf_mat=False):
    """Evaluate ``model`` on ``loader`` and return loss / top-1 / macro AUC.

    Optional side outputs written under ``arch_dir``: feature maps
    (``visualize``), t-SNE plots (``tsne``), confusion matrix (``conf_mat``)
    and per-layer spike statistics (``spike_rate``).

    :return: ``OrderedDict([('loss', ...), ('top1', ...), ('auc', ...)])``
    """
    batch_time_m = AverageMeter()
    losses_m = AverageMeter()
    closses_m = AverageMeter()
    top1_m = AverageMeter()
    top5_m = AverageMeter()
    model.eval()
    feature_vec = []
    feature_cls = []
    logits_vec = []
    labels_vec = []
    end = time.time()
    last_idx = len(loader) - 1
    # Accumulators for the multi-class AUC computed after the loop.
    all_probs = np.array([]).reshape(0, args.num_classes)
    all_targets = np.array([])
    # Hard-coded switch for adversarial (PGD) evaluation.
    attack=False
    with torch.no_grad():
        for batch_idx, (inputs, target) in enumerate(loader):
            # inputs = inputs.type(torch.float64)
            last_batch = batch_idx == last_idx
            if not args.prefetcher or args.dataset != 'imnet':
                inputs = inputs.type(torch.FloatTensor).cuda()
                target = target.cuda()
            if attack:
                data2 = copy.deepcopy(inputs)
                inputs = pgd_attack(model, data2, target, target.device, nn.CrossEntropyLoss())
            if args.channels_last:
                inputs = inputs.contiguous(memory_format=torch.channels_last)
            if not args.distributed:
                # Feature-map capture is needed by any of the diagnostics.
                if (visualize or spike_rate or tsne or conf_mat) and not args.critical_loss:
                    if args.num_gpu>1:
                        model.module.set_requires_fp(True)
                    else:
                        model.set_requires_fp(True)
                # if not args.critical_loss:
                #     model.set_requires_fp(False)
            with amp_autocast():
                output = model(inputs)
            if isinstance(output, (tuple, list)):
                output = output[0]
            if not args.distributed:
                if visualize:
                    x = model.get_fp()
                    feature_path = os.path.join(arch_dir, 'feature_map')
                    if os.path.exists(feature_path) is False:
                        os.mkdir(feature_path)
                    save_feature_map(x, feature_path)
                # if not args.critical_loss:
                #     model_config.set_requires_fp(False)
                if tsne:
                    # Global-average-pool the last feature map per sample.
                    x = model.get_fp(temporal_info=False)[-1]
                    x = torch.nn.AdaptiveAvgPool2d((1, 1))(x)
                    x = x.reshape(x.shape[0], -1)
                    feature_vec.append(x)
                    feature_cls.append(target)
                if conf_mat:
                    logits_vec.append(output)
                    labels_vec.append(target)
                if spike_rate:
                    if args.num_gpu>1:
                        avg, var, spike, avg_per_step = model.module.get_spike_info()
                    else:
                        avg, var, spike, avg_per_step = model.get_spike_info()
                    save_spike_info(
                        os.path.join(arch_dir, 'spike_info.csv'),
                        epoch, batch_idx,
                        args.step, avg, var,
                        spike, avg_per_step)
            # augmentation reduction
            reduce_factor = args.tta
            if reduce_factor > 1:
                output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
                target = target[0:target.size(0):reduce_factor]
            loss = loss_fn(output, target)
            probs = output.softmax(dim=1)  #
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            # acc1, = accuracy(output, target)
            all_probs = np.vstack([all_probs, probs.detach().cpu().numpy()])
            all_targets = np.concatenate([all_targets, target.detach().cpu().numpy()])
            closs = torch.tensor([0.], device=loss.device)
            if not args.distributed:
                if args.num_gpu>1:
                    spike_rate_avg_layer = model.module.get_fire_rate().tolist()
                    threshold = model.module.get_threshold()
                    threshold_str = ['{:.3f}'.format(i) for i in threshold]
                    spike_rate_avg_layer_str = ['{:.3f}'.format(i) for i in spike_rate_avg_layer]
                    tot_spike = model.module.get_tot_spike()
                else:
                    spike_rate_avg_layer = model.get_fire_rate().tolist()
                    threshold = model.get_threshold()
                    threshold_str = ['{:.3f}'.format(i) for i in threshold]
                    spike_rate_avg_layer_str = ['{:.3f}'.format(i) for i in spike_rate_avg_layer]
                    tot_spike = model.get_tot_spike()
                if args.critical_loss:
                    closs = calc_critical_loss(model)
                    loss = loss + .1 * closs
            if args.distributed:
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                acc1 = reduce_tensor(acc1, args.world_size)
                acc5 = reduce_tensor(acc5, args.world_size)
            else:
                reduced_loss = loss.data
            torch.cuda.synchronize()
            losses_m.update(reduced_loss.item(), inputs.size(0))
            top1_m.update(acc1.item(), output.size(0))
            top5_m.update(acc5.item(), output.size(0))
            closses_m.update(closs.item(), inputs.size(0))
            batch_time_m.update(time.time() - end)
            end = time.time()
            if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
                log_name = 'Test' + log_suffix
                mu_str = ''
                sigma_str = ''
                if not args.distributed:
                    if 'Noise' in args.node_type:
                        mu, sigma = model.get_noise_param()
                        mu_str = ['{:.3f}'.format(i.detach()) for i in mu]
                        sigma_str = ['{:.3f}'.format(i.detach()) for i in sigma]
                if args.distributed:
                    _logger.info(
                        '{0}: [{1:>4d}/{2}] '
                        'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                        'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                        'cLoss: {closs.val:>7.4f} ({closs.avg:>6.4f}) '
                        'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})'
                        'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
                            log_name,
                            batch_idx,
                            last_idx,
                            batch_time=batch_time_m,
                            loss=losses_m,
                            closs=closses_m,
                            top1=top1_m,
                            top5=top5_m,
                        ))
                else:
                    _logger.info(
                        '{0}: [{1:>4d}/{2}] '
                        'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                        'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                        'cLoss: {closs.val:>7.4f} ({closs.avg:>6.4f}) '
                        'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})'
                        'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})\n'
                        'Fire_rate: {spike_rate}\n'
                        'Tot_spike: {tot_spike}\n'
                        'Thres: {threshold}\n'
                        'Mu: {mu_str}\n'
                        'Sigma: {sigma_str}\n'.format(
                            log_name,
                            batch_idx,
                            last_idx,
                            batch_time=batch_time_m,
                            loss=losses_m,
                            closs=closses_m,
                            top1=top1_m,
                            top5=top5_m,
                            spike_rate=spike_rate_avg_layer_str,
                            tot_spike=tot_spike,
                            threshold=threshold_str,
                            mu_str=mu_str,
                            sigma_str=sigma_str
                        ))
    # metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])
    if not args.distributed:
        if tsne:
            feature_vec = torch.cat(feature_vec)
            feature_cls = torch.cat(feature_cls)
            plot_tsne(feature_vec, feature_cls, os.path.join(arch_dir, 't-sne-2d.eps'))
            plot_tsne_3d(feature_vec, feature_cls, os.path.join(arch_dir, 't-sne-3d.eps'))
        if conf_mat:
            logits_vec = torch.cat(logits_vec)
            labels_vec = torch.cat(labels_vec)
            plot_confusion_matrix(logits_vec, labels_vec, os.path.join(arch_dir, 'confusion_matrix.eps'))
    # Binarize the ground-truth labels (one binary column per class).
    all_targets_binarized = label_binarize(all_targets, classes=range(args.num_classes))
    # Macro-averaged one-vs-rest AUC over all classes.
    auc = roc_auc_score(all_targets_binarized, all_probs, multi_class='ovr', average='macro')
    # print("Mean AUC: {:.2f}".format(auc))
    return OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('auc', auc)])
================================================
FILE: examples/Structure_Evolution/MSE-NAS/utils.py
================================================
import json
import matplotlib.pyplot as plt
import os
import numpy as np
import torch
import shutil
import torchvision.transforms as transforms
from torch.autograd import Variable
from auto_augment import CIFAR10Policy
from braincog.model_zoo.darts.genotypes import PRIMITIVES
# Edge-count constants for the 3-node motif encoding.
# Node i (i in 0..2) has (2 + i) candidate forward edges -> 2 + 3 + 4 = 9.
forward_edge_num = sum(1 for i in range(3) for n in range(2 + i))
# Node i has i candidate backward edges -> 0 + 1 + 2 = 3.
backward_edge_num = sum(1 for i in range(3) for n in range(i))
num_ops = len(PRIMITIVES)
# Half the primitive count — presumably the number of op *types*; verify
# against how PRIMITIVES is organized in genotypes.py.
type_num = len(PRIMITIVES) // 2
# edge_num = [2, 3, 4]
# Candidate-edge counts per node position used by the encoder.
edge_num = [2, 3, 4, 1, 2]
def drop_path(x, drop_prob):
    """Apply stochastic depth to ``x`` in place and return it.

    With probability ``drop_prob`` an entire sample (one row of the batch)
    is zeroed; surviving samples are scaled by 1/keep_prob so the
    expectation is unchanged. ``drop_prob <= 0`` is a no-op.

    :param x: tensor of shape (N, C, H, W) — the mask broadcasts over the
        three trailing dimensions
    :param drop_prob: per-sample drop probability
    """
    if drop_prob > 0.:
        keep_prob = 1. - drop_prob
        # Draw the mask on x's own device. The previous implementation used
        # torch.cuda.FloatTensor (CUDA-only, crashes on CPU machines)
        # wrapped in the long-deprecated autograd.Variable API.
        mask = torch.empty(
            x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device
        ).bernoulli_(keep_prob)
        x.div_(keep_prob)
        x.mul_(mask)
    return x
class AvgrageMeter(object):
    """Track the weighted running average of a scalar metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.avg = 0
        self.sum = 0
        self.cnt = 0

    def update(self, val, n=1):
        """Fold in ``val`` observed ``n`` times and refresh the average."""
        self.cnt = self.cnt + n
        self.sum = self.sum + val * n
        self.avg = self.sum / self.cnt
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy (in percent) for each requested ``k``.

    :param output: logits of shape (N, num_classes).
    :param target: ground-truth labels of shape (N,).
    :param topk: iterable of k values to evaluate.
    :return: list of scalar tensors, one accuracy per k.
    """
    batch_size = target.size(0)
    # Indices of the ``max(topk)`` highest-scoring classes per sample,
    # transposed to shape (maxk, N) so row k-1 holds the k-th guess.
    _, pred = output.topk(max(topk), 1, True, True)
    pred = pred.t()
    hits = pred.eq(target.view(1, -1).expand_as(pred))
    return [hits[:k].reshape(-1).float().sum(0).mul_(100.0 / batch_size)
            for k in topk]
class Cutout(object):
    """Zero out one randomly-placed square patch of a CHW image tensor."""

    def __init__(self, length):
        # Side length of the square patch to erase.
        self.length = length

    def __call__(self, img):
        height, width = img.size(1), img.size(2)
        # Draw the patch centre (same RNG call order as before: row, then
        # column) and clamp its corners to the image bounds.
        cy = np.random.randint(height)
        cx = np.random.randint(width)
        half = self.length // 2
        top = np.clip(cy - half, 0, height)
        bottom = np.clip(cy + half, 0, height)
        left = np.clip(cx - half, 0, width)
        right = np.clip(cx + half, 0, width)
        mask = np.ones((height, width), np.float32)
        mask[top:bottom, left:right] = 0.
        img *= torch.from_numpy(mask).expand_as(img)
        return img
def _data_transforms_cifar(args):
    """Build train/valid torchvision pipelines for CIFAR-10/100.

    Normalisation constants are selected by ``args.dataset``; the training
    pipeline optionally appends AutoAugment (``args.auto_aug``) and Cutout
    (``args.cutout`` with ``args.cutout_length``).
    """
    if args.dataset == 'cifar10':
        mean = [0.49139968, 0.48215827, 0.44653124]
        std = [0.24703233, 0.24348505, 0.26158768]
    else:
        mean = [0.50707519, 0.48654887, 0.44091785]
        std = [0.26733428, 0.25643846, 0.27615049]
    # Spatial augmentations run on PIL images, before tensor conversion.
    augment = [transforms.RandomCrop(32, padding=4),
               transforms.RandomHorizontalFlip()]
    if args.auto_aug:
        augment.append(CIFAR10Policy())
    # Cutout operates on the normalised tensor, so it comes last.
    tail = [transforms.ToTensor(), transforms.Normalize(mean, std)]
    if args.cutout:
        tail.append(Cutout(args.cutout_length))
    train_transform = transforms.Compose(augment + tail)
    valid_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    return train_transform, valid_transform
def count_parameters_in_MB(model):
    """Return the parameter count of ``model`` divided by 1e6 (millions).

    Uses the builtin ``sum`` over ``numel()``: calling ``np.sum`` on a
    generator is deprecated in NumPy and silently fell back to the builtin
    anyway, so this is both correct and explicit.
    """
    return sum(p.numel() for p in model.parameters()) / 1e6
def save_checkpoint(state, is_best, save):
    """Write ``state`` to ``save/checkpoint.pth.tar``; when ``is_best`` is
    true, also mirror it to ``save/model_best.pth.tar``."""
    checkpoint_path = os.path.join(save, 'checkpoint.pth.tar')
    torch.save(state, checkpoint_path)
    if is_best:
        shutil.copyfile(checkpoint_path,
                        os.path.join(save, 'model_best.pth.tar'))
def save(model, model_path):
    """Serialize only the model weights (its state dict) to ``model_path``."""
    state = model.state_dict()
    torch.save(state, model_path)
def load(model, model_path):
    """Restore model weights previously written by ``save``."""
    state = torch.load(model_path)
    model.load_state_dict(state)
def drop_path(x, drop_prob):
if drop_prob > 0.:
keep_prob = 1. - drop_prob
mask = Variable(torch.cuda.FloatTensor(
x.size(0), 1, 1, 1).bernoulli_(keep_prob))
x.div_(keep_prob)
x.mul_(mask)
return x
def create_exp_dir(path, scripts_to_save=None):
    """Create an experiment directory and optionally snapshot script files.

    :param path: experiment directory to create (parents included).
    :param scripts_to_save: optional iterable of file paths to copy into
        ``path/scripts`` for reproducibility.
    """
    os.makedirs(path, exist_ok=True)
    print('Experiment dir : {}'.format(path))
    if scripts_to_save is not None:
        script_dir = os.path.join(path, 'scripts')
        # exist_ok so re-running into an existing experiment dir does not die
        # with FileExistsError (the original unguarded makedirs did).
        os.makedirs(script_dir, exist_ok=True)
        for script in scripts_to_save:
            dst_file = os.path.join(script_dir, os.path.basename(script))
            shutil.copyfile(script, dst_file)
def calc_time(seconds):
    """Split a duration in seconds into days, hours, minutes and seconds.

    Only the seconds field is truncated to ``int``; the other fields keep
    whatever type ``divmod`` yields for the input.
    """
    minutes, secs = divmod(seconds, 60)
    hours, minutes = divmod(minutes, 60)
    days, hours = divmod(hours, 24)
    return {'day': days, 'hour': hours, 'minute': minutes, 'second': int(secs)}
def save_file(recoder, path='./', back_connection=False):
    """Plot the architecture-weight history and dump it to JSON.

    :param recoder: mapping from edge/op name to a sequence of weight values,
        one subplot per entry.  Values must be JSON-serializable -- presumably
        plain floats; TODO confirm tensors are converted upstream.
    :param path: output directory, created if missing.
    :param back_connection: include backward edges in the subplot grid height.
    """
    # Grid: one row per edge, one column per candidate operation.  The ternary
    # binds loosely: (forward + backward) if back_connection else forward.
    size = (forward_edge_num +
            backward_edge_num if back_connection else forward_edge_num, num_ops)
    fig, axs = plt.subplots(*size, figsize=(36, 98))
    row = 0
    col = 0
    # Fill the grid left-to-right, top-to-bottom, one recoder entry per cell.
    for (k, v) in recoder.items():
        axs[row, col].set_title(k)
        axs[row, col].plot(v, 'r+')
        if col == num_ops - 1:
            col = 0
            row += 1
        else:
            col += 1
    if not os.path.exists(path):
        os.makedirs(path)
    # NOTE(review): the figure is written before plt.tight_layout(), so the
    # layout call has no effect on the saved image -- confirm intent.
    fig.savefig(os.path.join(path, 'output.png'), bbox_inches='tight')
    plt.tight_layout()
    print('save history weight in {}'.format(os.path.join(path, 'output.png')))
    with open(os.path.join(path, 'history_weight.json'), 'w') as outf:
        json.dump(recoder, outf)
    print('save history weight in {}'.format(
        os.path.join(path, 'history_weight.json')))
def data_transforms(args):
    """Build train/valid transform pipelines for cifar10/cifar100/tinyimagenet.

    :param args: namespace with a ``dataset`` attribute.
    :return: ``(train_transform, valid_transform)`` tuple.
    :raises ValueError: for an unsupported ``args.dataset`` (the original code
        failed later with an opaque ``NameError`` on the undefined MEAN).
    """
    if args.dataset == 'cifar10':
        mean = [0.4913, 0.4821, 0.4465]
        std = [0.2470, 0.2434, 0.2615]
    elif args.dataset == 'cifar100':
        mean = [0.5071, 0.4867, 0.4408]
        std = [0.2673, 0.2564, 0.2762]
    elif args.dataset == 'tinyimagenet':
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
    else:
        raise ValueError('Unsupported dataset: {}'.format(args.dataset))
    # TinyImageNet images are 64x64; CIFAR images are 32x32.
    crop_size, padding = (64, 8) if args.dataset == 'tinyimagenet' else (32, 4)
    train_transform = transforms.Compose([
        transforms.RandomCrop(crop_size, padding=padding),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    valid_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    return train_transform, valid_transform
================================================
FILE: examples/TIM/README.md
================================================
# TIM: An Efficient Temporal Interaction Module for Spiking Transformer, [IJCAI2024](https://arxiv.org/abs/2401.11687)

## Reference
```
@misc{shen2024tim,
title={TIM: An Efficient Temporal Interaction Module for Spiking Transformer},
author={Sicheng Shen and Dongcheng Zhao and Guobin Shen and Yi Zeng},
year={2024},
eprint={2401.11687},
archivePrefix={arXiv},
primaryClass={cs.NE}
}
```
This is the official implementation of TIM. The code is based on PyTorch and [BrainCog](https://github.com/BrainCog-X/Brain-Cog).
## Requirements
### Create Braincog Virtual Environment
```
conda create -n braincog python=3.8
conda activate braincog
pip install braincog
```
### Dataset Preparation
**Datasets Needed**: CIFAR10-DVS, N-CALTECH101, UCF101DVS, NCARS, HMDB51DVS, SHD
Please unzip data to ```/data/datasets``` so that dataset.py may directly load corresponding dataset for training
## Model Training
For most DVS data, we use an event-frame size of 64 rather than 128 here.
Please adjust your hyperparameters as needed; the default number of time steps is 10.
```
@register_model
def spikformer_dvs(pretrained=False, **kwargs):
model = Spikformer(TIM_alpha=0.5,step=10,if_UCF=False,num_classes=10,
# img_size_h=64, img_size_w=64,
# patch_size=16, embed_dims=256, num_heads=16, mlp_ratios=4,
# in_channels=2, qkv_bias=False,
# depths=2, sr_ratios=1,
**kwargs
)
model.default_cfg = _cfg()
return model
```
### Training on CIFAR10-DVS
```
python main.py --model spikformer_dvs --dataset dvsc10 --epoch 500 --batch-size 16 --event-size 64
```
### Training on N-CALTECH101
```num_classes``` should be set to 101
```
python main.py --model spikformer_dvs --dataset NCALTECH101 --epoch 500 --batch-size 16 --event-size 64 --num_classes 101
```
### Training on NCARS
```num_classes``` should be set to 2
```
python main.py --model spikformer_dvs --dataset NCARS --epoch 500 --batch-size 16 --event-size 64 --num_classes 2
```
### Training on UCF101DVS
```num_classes``` should be set to 101, and ```if_UCF``` should be set to ```True```
```
python main.py --model spikformer_dvs --dataset UCF101DVS --epoch 500 --batch-size 16 --event-size 64 --num_classes 101
```
### Training on HMDB51DVS
```
python main.py --model spikformer_dvs --dataset HMDBDVS --epoch 500 --batch-size 16 --event-size 64 --num_classes 51
```
### Training on SHD
```num_classes``` should be set to 20
```
python main.py --model spikformer_shd --dataset SHD --epoch 500 --batch-size 16 --num_classes 20
```
================================================
FILE: examples/TIM/main.py
================================================
import argparse
import time
import timm.models
import yaml
import os
import random as buildin_random
import logging
from collections import OrderedDict
from contextlib import suppress
from datetime import datetime
from braincog.base.node.node import *
from braincog.utils import *
from braincog.base.utils.criterions import *
# from braincog.datasets.datasets import *
from utils.datasets import *
from braincog.model_zoo.resnet import *
from braincog.model_zoo.convnet import *
from braincog.model_zoo.vgg_snn import VGG_SNN, SNN5
from braincog.model_zoo.resnet19_snn import resnet19
from braincog.utils import save_feature_map, setup_seed
from braincog.base.utils.visualization import plot_tsne_3d, plot_tsne, plot_confusion_matrix, plot_mem_distribution
import torch
import torch.nn as nn
import torchvision.utils
from torch.nn.parallel import DistributedDataParallel as NativeDDP
from timm.data import ImageDataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
from timm.models import load_checkpoint, create_model, resume_checkpoint, convert_splitbn_model, register_model
from timm.utils import *
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
from timm.optim import create_optimizer
from timm.scheduler import create_scheduler
from timm.utils import ApexScaler, NativeScaler
from torch.utils.tensorboard import SummaryWriter
from models.spikformer_braincog_DVS import spikformer_dvs
from models.spikformer_braincog_DVS import spikformer_shd
# Pin training to the first GPU and let cuDNN benchmark/select the fastest
# convolution kernels for the fixed input shapes used here.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('train')
# The first arg parser parses out only the --config argument, this argument is used to
# load a yaml file containing key-values that override the defaults for the main parser below
config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)
parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
                    help='YAML config file specifying default arguments')
# Main parser. ``parser`` is rebound here, so the --config flag above lives
# only on ``config_parser``; both are consumed by _parse_args().
parser = argparse.ArgumentParser(description='SNN Training and Evaluating')
# Model parameters
# (help texts below corrected to match the actual defaults declared here)
parser.add_argument('--dataset', default='dvsc10', type=str)
parser.add_argument('--model', default='spikformer', type=str, metavar='MODEL',
                    help='Name of model to train (default: "spikformer")')
parser.add_argument('--pretrained', action='store_true', default=False,
                    help='Start with pretrained version of specified network (if avail)')
parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
                    help='Initialize model from this checkpoint (default: none)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='Resume full model and optimizer state from checkpoint (default: none)')
parser.add_argument('--eval_checkpoint', default='', type=str, metavar='PATH',
                    help='path to eval checkpoint (default: none)')
parser.add_argument('--no-resume-opt', action='store_true', default=False,
                    help='prevent resume of optimizer state when resuming model')
parser.add_argument('--num-classes', type=int, default=10, metavar='N',
                    help='number of label classes (default: 10)')
parser.add_argument('--gp', default=None, type=str, metavar='POOL',
                    help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
# Dataset parameters for static datasets
parser.add_argument('--img-size', type=int, default=224, metavar='N',
                    help='Image patch size (default: 224)')
parser.add_argument('--crop-pct', default=None, type=float,
                    metavar='N', help='inputs image center crop percent (for validation only)')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                    help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
                    help='Override std deviation of dataset')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
                    help='Image resize interpolation type (overrides model)')
# Dataloader parameters
parser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',
                    help='inputs batch size for training (default: 128)')
parser.add_argument('-vb', '--validation-batch-size-multiplier', type=int, default=1, metavar='N',
                    help='ratio of validation batch size to training batch size (default: 1)')
# Optimizer parameters
# (help texts below corrected to match the actual defaults declared here)
parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                    help='Optimizer (default: "adamw")')
parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
                    help='Optimizer Epsilon (default: 1e-8)')
parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                    help='Optimizer Betas (default: None, use opt default)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='Optimizer momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=1e-4,
                    help='weight decay (default: 1e-4)')
parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
                    help='Clip gradient norm (default: None, no clipping)')
parser.add_argument('--adam-epoch', type=int, default=1000, help='lamb switch to adamw')
# Learning rate schedule parameters
parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
                    help='LR scheduler (default: "cosine")')
parser.add_argument('--lr', type=float, default=5e-3, metavar='LR',
                    help='learning rate (default: 5e-3)')
parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                    help='learning rate noise on/off epoch percentages')
parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                    help='learning rate noise limit percent (default: 0.67)')
parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                    help='learning rate noise std-dev (default: 1.0)')
parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
                    help='learning rate cycle len multiplier (default: 1.0)')
parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
                    help='learning rate cycle limit')
parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
                    help='warmup learning rate (default: 1e-6)')
parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                    help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
parser.add_argument('--epochs', type=int, default=400, metavar='N',
                    help='number of epochs to train (default: 400)')
parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
                    help='epoch interval to decay LR')
parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
                    help='epochs to warmup LR, if scheduler supports')
parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
                    help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                    help='patience epochs for Plateau LR scheduler (default: 10)')
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                    help='LR decay rate (default: 0.1)')
parser.add_argument('--power', type=int, default=1, help='power')
# Augmentation & regularization parameters ONLY FOR IMAGE NET
# (help texts below corrected to match the actual defaults declared here;
#  a stray trailing comma after the --aa argument was removed)
parser.add_argument('--no-aug', action='store_true', default=False,
                    help='Disable all training augmentation, override other train aug args')
parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
                    help='Random resize scale (default: 0.08 1.0)')
parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
                    help='Random resize aspect ratio (default: 0.75 1.33)')
parser.add_argument('--hflip', type=float, default=0.5,
                    help='Horizontal flip training aug probability')
parser.add_argument('--vflip', type=float, default=0.,
                    help='Vertical flip training aug probability')
parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                    help='Color jitter factor (default: 0.4)')
parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
                    help='Use AutoAugment policy. "v0" or "original". (default: "rand-m9-mstd0.5-inc1")')
parser.add_argument('--aug-splits', type=int, default=0,
                    help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
parser.add_argument('--jsd', action='store_true', default=False,
                    help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
                    help='Random erase prob (default: 0.25)')
parser.add_argument('--remode', type=str, default='pixel',
                    help='Random erase mode (default: "pixel")')
parser.add_argument('--recount', type=int, default=1,
                    help='Random erase count (default: 1)')
parser.add_argument('--resplit', action='store_true', default=False,
                    help='Do not random erase first (clean) augmentation split')
parser.add_argument('--mixup', type=float, default=0.,
                    help='mixup alpha, mixup enabled if > 0. (default: 0.)')
parser.add_argument('--cutmix', type=float, default=0.,
                    help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')
parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
                    help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
parser.add_argument('--mixup-prob', type=float, default=0.,
                    help='Probability of performing mixup or cutmix when either/both is enabled')
parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
                    help='Probability of switching to cutmix when both mixup and cutmix enabled')
parser.add_argument('--mixup-mode', type=str, default='batch',
                    help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
                    help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
parser.add_argument('--smoothing', type=float, default=0.1,
                    help='Label smoothing (default: 0.1)')
parser.add_argument('--train-interpolation', type=str, default='random',
                    help='Training interpolation (random, bilinear, bicubic default: "random")')
parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                    help='Dropout rate (default: 0.0)')
parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT',
                    help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',
                    help='Drop path rate (default: 0.1)')
parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
                    help='Drop block rate (default: None)')
parser.add_argument('--newton-maxiter', default=20, type=int,
                    help='max iteration in newton method')
parser.add_argument('--reset-drop', action='store_true', default=False,
                    help='whether to reset drop')
parser.add_argument('--kernel-method', type=str, default='cuda', choices=['torch', 'cuda'],
                    help='The implementation way of gaussian kernel method, choose from "cuda" and "torch"')
# Batch norm parameters (only works with gen_efficientnet based models currently)
parser.add_argument('--bn-tf', action='store_true', default=False,
                    help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')
parser.add_argument('--bn-momentum', type=float, default=None,
                    help='BatchNorm momentum override (if not None)')
parser.add_argument('--bn-eps', type=float, default=None,
                    help='BatchNorm epsilon override (if not None)')
parser.add_argument('--sync-bn', action='store_true',
                    help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
parser.add_argument('--dist-bn', type=str, default='',
                    help='Distribute BatchNorm stats between node after each epoch ("broadcast", "reduce", or "")')
parser.add_argument('--split-bn', action='store_true',
                    help='Enable separate BN layers per augmentation split.')
# Model Exponential Moving Average
parser.add_argument('--model-ema', action='store_true', default=False,
                    help='Enable tracking moving average of model weights')
parser.add_argument('--model-ema-force-cpu', action='store_true', default=False,
                    help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
parser.add_argument('--model-ema-decay', type=float, default=0.99996,
                    help='decay factor for model weights moving average (default: 0.99996)')
# Misc
# (help texts below corrected for typos and to match the actual defaults)
parser.add_argument('--seed', type=int, default=42, metavar='S',
                    help='random seed (default: 42)')
parser.add_argument('--log-interval', type=int, default=50, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
                    help='how many batches to wait before writing recovery checkpoint')
parser.add_argument('-j', '--workers', type=int, default=8, metavar='N',
                    help='how many training processes to use (default: 8)')
parser.add_argument('--num-gpu', type=int, default=1,
                    help='Number of GPUS to use')
parser.add_argument('--save-images', action='store_true', default=False,
                    help='save images of inputs batches every log interval for debugging')
parser.add_argument('--amp', action='store_true', default=False,
                    help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
parser.add_argument('--apex-amp', action='store_true', default=False,
                    help='Use NVIDIA Apex AMP mixed precision')
parser.add_argument('--native-amp', action='store_true', default=False,
                    help='Use Native Torch AMP mixed precision')
parser.add_argument('--channels-last', action='store_true', default=False,
                    help='Use channels_last memory layout')
parser.add_argument('--pin-mem', action='store_true', default=False,
                    help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--no-prefetcher', action='store_true', default=False,
                    help='disable fast prefetcher')
parser.add_argument('--output', default='/home/shensicheng/code/TIM/logs', type=str, metavar='PATH',
                    help='path to output folder (default: none, current dir)')
parser.add_argument('--tensorboard-dir', default='./runs', type=str)
parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
                    help='Best metric (default: "top1")')
parser.add_argument('--tta', type=int, default=0, metavar='N',
                    help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
parser.add_argument('--local_rank', default=0, type=int)
parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
                    help='use the multi-epochs-loader to save time at the beginning of every epoch')
parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
parser.add_argument('--device', type=int, default=0)
# Spike parameters
parser.add_argument('--step', type=int, default=10, help='Simulation time step (default: 10)')
parser.add_argument('--encode', type=str, default='direct', help='Input encode method (default: direct)')
parser.add_argument('--temporal-flatten', action='store_true',
                    help='Temporal flatten to channels. ONLY FOR EVENT DATA TRAINING BY ANN')
parser.add_argument('--adaptive-node', action='store_true')
parser.add_argument('--critical-loss', action='store_true')
parser.add_argument('--conv-type', type=str, default='normal')
parser.add_argument('--sew-cnf', type=str, default='ADD')
parser.add_argument('--rand-step', action='store_true')
# neuron type
parser.add_argument('--node-type', type=str, default='LIFNode', help='Node type in network (default: LIFNode)')
parser.add_argument('--act-fun', type=str, default='QGateGrad',
                    help='Surrogate Function in node. Only for Surrogate nodes (default: QGateGrad)')
parser.add_argument('--threshold', type=float, default=.5, help='Firing threshold (default: 0.5)')
parser.add_argument('--tau', type=float, default=2., help='Attenuation coefficient (default: 2.)')
parser.add_argument('--requires-thres-grad', action='store_true')
parser.add_argument('--sigmoid-thres', action='store_true')
parser.add_argument('--loss-fn', type=str, default='ce', help='loss function (default: ce)')
parser.add_argument('--noisy-grad', type=float, default=0.,
                    help='Add noise to backward, sometime will make higher accuracy (default: 0.)')
parser.add_argument('--spike-output', action='store_true', default=False,
                    help='Using mem output or spike output (default: False)')
parser.add_argument('--n_groups', type=int, default=1)
parser.add_argument('--n-encode-type', type=str, default='linear')
parser.add_argument('--n-preact', action='store_true')
parser.add_argument('--layer-by-layer', action='store_true',
                    help='forward step-by-step or layer-by-layer. '
                         'Larger Model with layer-by-layer will be faster (default: False)')
parser.add_argument('--tet-loss', action='store_true')
# EventData Augmentation
parser.add_argument('--mix-up', action='store_true', help='Mix-up for event data (default: False)')
parser.add_argument('--cut-mix', action='store_true', help='CutMix for event data (default: False)')
parser.add_argument('--event-mix', action='store_true', help='EventMix for event data (default: False)')
parser.add_argument('--cutmix_beta', type=float, default=2.0, help='cutmix_beta (default: 2.0)')
parser.add_argument('--cutmix_prob', type=float, default=0.5, help='cutmix_prob for event data (default: .5)')
parser.add_argument('--cutmix_num', type=int, default=1, help='cutmix_num for event data (default: 1)')
parser.add_argument('--cutmix_noise', type=float, default=0.,
                    help='Add Pepper noise after mix, sometimes work (default: 0.)')
parser.add_argument('--gaussian-n', type=int, default=3)
parser.add_argument('--rand-aug', action='store_true',
                    help='Rand Augment for Event data (default: False)')
parser.add_argument('--randaug_n', type=int, default=3,
                    help='Rand Augment times n (default: 3)')
parser.add_argument('--randaug_m', type=int, default=15,
                    help='Rand Augment magnitude m (default: 15) (0-30)')
parser.add_argument('--train-portion', type=float, default=0.9,
                    help='Dataset portion, only for datasets which do not have validation set (default: 0.9)')
parser.add_argument('--event-size', default=48, type=int,
                    help='Event size. Resize event data before process (default: 48)')
parser.add_argument('--node-resume', type=str, default='',
                    help='resume weights in node for adaptive node. (default: \'\')')
# visualize
parser.add_argument('--visualize', action='store_true',
                    help='Visualize spiking map for each layer, only for validate (default: False)')
parser.add_argument('--spike-rate', action='store_true',
                    help='Print spiking rate for each layer, only for validate(default: False)')
parser.add_argument('--tsne', action='store_true')
parser.add_argument('--conf-mat', action='store_true')
parser.add_argument('--mem-dist', action='store_true')
parser.add_argument('--adaptation-info', action='store_true')
parser.add_argument('--suffix', type=str, default='',
                    help='Add an additional suffix to the save path (default: \'\')')
# Optional NVIDIA Apex support: fall back gracefully when Apex is absent.
try:
    from apex import amp
    from apex.parallel import DistributedDataParallel as ApexDDP
    from apex.parallel import convert_syncbn_model
    has_apex = True
except ImportError:
    has_apex = False
has_native_amp = False
# Native AMP (torch.cuda.amp.autocast) only exists on newer PyTorch builds;
# probe for it instead of pinning a version.
try:
    if getattr(torch.cuda.amp, 'autocast') is not None:
        has_native_amp = True
except AttributeError:
    pass
def _parse_args():
    """Parse CLI arguments, letting an optional YAML config override defaults.

    :return: ``(args, args_text)`` where ``args_text`` is a YAML dump of the
        parsed namespace, saved alongside training output for reproducibility.
    """
    # First pass: extract only --config so a YAML file can reset the defaults
    # of the main parser.
    args_config, remaining = config_parser.parse_known_args()
    if args_config.config:
        with open(args_config.config, 'r') as f:
            parser.set_defaults(**yaml.safe_load(f))
    # Second pass: the main parser consumes everything else; explicit CLI
    # flags still win over config-file defaults.
    args = parser.parse_args(remaining)
    args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
    return args, args_text
def main():
    """Entry point: parse args, build the SNN model, then train or evaluate.

    Handles single-GPU, nn.DataParallel and distributed (DDP) execution,
    optional AMP (Apex or native), checkpoint resume, weight EMA, and
    TensorBoard logging.  All behavior is driven by the module-level
    argparse parser.
    """
    args, args_text = _parse_args()
    # args.no_spike_output = args.no_spike_output | args.cut_mix
    args.no_spike_output = True

    # ---- output directory / logging / TensorBoard (rank 0 only) ----
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            args.model,
            args.dataset,
            args.node_type,
            str(args.step),
            args.suffix,
            datetime.now().strftime("%Y%m%d-%H%M%S"),
            # str(args.img_size)
        ])
        output_dir = get_outdir(output_base, 'train', exp_name)
        args.output_dir = output_dir
        setup_default_logging(log_path=os.path.join(output_dir, 'log.txt'))
        summary_writer = SummaryWriter(log_dir=os.path.join(args.tensorboard_dir, exp_name))
        args.tensorboard_prefix = os.path.join(args.dataset, args.model)
    else:
        summary_writer = None
        setup_default_logging()

    # ---- distributed setup ----
    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
        if args.distributed and args.num_gpu > 1:
            _logger.warning(
                'Using more than one GPU per process in distributed mode is not allowed.Setting num_gpu to 1.')
            args.num_gpu = 1

    # args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.num_gpu = 1
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    else:
        # NOTE(review): formats args.device into 'cuda:%d', which assumes
        # args.device holds an integer GPU index in the non-distributed
        # path -- confirm against the parser's --device definition.
        torch.cuda.set_device('cuda:%d' % args.device)
    assert args.rank >= 0

    if args.distributed:
        _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
                     % (args.rank, args.world_size))
    else:
        _logger.info('Training with a single process on %d GPUs.' % args.num_gpu)

    # torch.manual_seed(args.seed + args.rank)
    # Per-rank seed offset keeps data shuffles distinct across processes.
    setup_seed(args.seed + args.rank)

    # ---- model creation ----
    model = create_model(
        args.model,
        # pretrained=args.pretrained,
        # num_classes=args.num_classes,
        # dataset=args.dataset,
        # step=args.step,
        # encode_type=args.encode,
        # node_type=eval(args.node_type),
        # threshold=args.threshold,
        # tau=args.tau,
        # sigmoid_thres=args.sigmoid_thres,
        # requires_thres_grad=args.requires_thres_grad,
        # spike_output=not args.no_spike_output,
        # act_fun=args.act_fun,
        # temporal_flatten=args.temporal_flatten,
        # layer_by_layer=args.layer_by_layer,
        # n_groups=args.n_groups,
        # n_encode_type=args.n_encode_type,
        # n_preact=args.n_preact,
        # tet_loss=args.tet_loss,
        # sew_cnf=args.sew_cnf,
        # conv_type=args.conv_type,
    )
    _logger.info('[MODEL ARCH]\n{}'.format(model))

    # Input channel count inferred from the dataset family name.
    if 'dvs' in args.dataset:
        args.channels = 2
    elif 'mnist' in args.dataset:
        args.channels = 1
    else:
        args.channels = 3
    # flops, params = profile(model, inputs=(torch.randn(1, args.channels, args.event_size, args.event_size),), verbose=False)
    # _logger.info('flops = %fM', flops / 1e6)
    # _logger.info('param size = %fM', params / 1e6)

    # Linear LR scaling with the global batch size (reference batch = 1024).
    linear_scaled_lr = args.lr * args.batch_size * args.world_size / 1024.0
    args.lr = linear_scaled_lr
    _logger.info("learning rate is %f" % linear_scaled_lr)

    if args.local_rank == 0:
        _logger.info('Model %s created, param count: %d' %
                     (args.model, sum([m.numel() for m in model.parameters()])))

    # ---- augmentation splits / split BatchNorm ----
    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits
    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    # ---- resolve which AMP backend (if any) to use ----
    use_amp = None
    if args.amp:
        # for backwards compat, `--amp` arg tries apex before native amp
        if has_apex:
            args.apex_amp = True
        elif has_native_amp:
            args.native_amp = True
    if args.apex_amp and has_apex:
        use_amp = 'apex'
    elif args.native_amp and has_native_amp:
        use_amp = 'native'
    elif args.apex_amp or args.native_amp:
        _logger.warning("Neither APEX or native Torch AMP is available, using float32. "
                        "Install NVIDA apex or upgrade to PyTorch 1.6")

    # ---- device placement: DataParallel vs single device ----
    if args.num_gpu > 1:
        if use_amp == 'apex':
            _logger.warning(
                'Apex AMP does not work well with nn.DataParallel, disabling. Use DDP or Torch AMP.')
            use_amp = None
        model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
        assert not args.channels_last, "Channels last not supported with DP, use DDP."
    else:
        model = model.cuda()
        if args.channels_last:
            model = model.to(memory_format=torch.channels_last)

    optimizer = create_optimizer(args, model)
    _logger.info('[OPTIMIZER]\n{}'.format(optimizer))

    # ---- AMP initialisation ----
    amp_autocast = suppress  # do nothing
    loss_scaler = None
    if use_amp == 'apex':
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        loss_scaler = ApexScaler()
        if args.local_rank == 0:
            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif use_amp == 'native':
        amp_autocast = torch.cuda.amp.autocast
        loss_scaler = NativeScaler()
        if args.local_rank == 0:
            _logger.info('Using native Torch AMP. Training in mixed precision.')
    else:
        if args.local_rank == 0:
            _logger.info('AMP not enabled. Training in float32.')

    # optionally resume from a checkpoint
    resume_epoch = None
    if args.resume and args.eval_checkpoint == '':
        args.eval_checkpoint = args.resume
    if args.resume:
        # NOTE: --resume forces evaluation mode below via args.eval.
        args.eval = True
        # checkpoint = torch.load(args.resume, map_location='cpu')
        # model.load_state_dict(checkpoint['state_dict'], False)
        resume_epoch = resume_checkpoint(
            model, args.resume,
            optimizer=None if args.no_resume_opt else optimizer,
            loss_scaler=None if args.no_resume_opt else loss_scaler,
            log_info=args.local_rank == 0)
        # print(model.get_attr('mu'))
        # print(model.get_attr('sigma'))

    if hasattr(model, 'set_threshold'):
        model.set_threshold(args.threshold)
    # Record firing patterns when needed by the loss or for spike statistics.
    if args.critical_loss or args.spike_rate:
        model.set_requires_fp(True)

    # ---- EMA of model weights ----
    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(
            model,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else '',
            resume=args.resume)

    # Optionally restore only the adaptive-node weights from a checkpoint.
    if args.node_resume:
        ckpt = torch.load(args.node_resume, map_location='cpu')
        model.load_node_weight(ckpt, args.node_trainable)

    model_without_ddp = model
    # ---- distributed wrappers (SyncBN + DDP) ----
    if args.distributed:
        if args.sync_bn:
            assert not args.split_bn
            try:
                if has_apex and use_amp != 'native':
                    # Apex SyncBN preferred unless native amp is activated
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
                if args.local_rank == 0:
                    _logger.info(
                        'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                        'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
            except Exception as e:
                _logger.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')
        if has_apex and use_amp != 'native':
            # Apex DDP preferred unless native amp is activated
            if args.local_rank == 0:
                _logger.info("Using NVIDIA APEX DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                _logger.info("Using native Torch DistributedDataParallel.")
            model = NativeDDP(model.cuda(), device_ids=[args.local_rank],
                              find_unused_parameters=True)  # can use device str in Torch >= 1.1
        model_without_ddp = model.module
    # NOTE: EMA model does not need to be wrapped by DDP

    # ---- LR schedule / start epoch ----
    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        _logger.info('Scheduled epochs: {}'.format(num_epochs))

    # ---- data loaders ----
    # now config only for imnet
    data_config = resolve_data_config(vars(args), model=model, verbose=False)
    # Dispatches to a get_<dataset>_data factory selected by name via eval().
    # NOTE(review): both `_logge=` and `_logger=` are passed; `_logge` looks
    # like a typo -- confirm the factory's signature before removing it.
    loader_train, loader_eval, mixup_active, mixup_fn = eval('get_%s_data' % args.dataset)(
        batch_size=args.batch_size,
        step=args.step,
        args=args,
        _logge=_logger,
        data_config=data_config,
        num_aug_splits=num_aug_splits,
        size=args.event_size,
        mix_up=args.mix_up,
        cut_mix=args.cut_mix,
        event_mix=args.event_mix,
        beta=args.cutmix_beta,
        prob=args.cutmix_prob,
        gaussian_n=args.gaussian_n,
        num=args.cutmix_num,
        noise=args.cutmix_noise,
        num_classes=args.num_classes,
        rand_aug=args.rand_aug,
        randaug_n=args.randaug_n,
        randaug_m=args.randaug_m,
        portion=args.train_portion,
        _logger=_logger,
    )
    # _logger.info('train_loader:\n{}\nval_loader:\n{}'.format(loader_train, loader_eval))

    # ---- loss function selection ----
    if args.loss_fn == 'mse':
        train_loss_fn = UnilateralMse(1.)
        validate_loss_fn = UnilateralMse(1.)
    elif args.loss_fn == 'onehot-mse':
        train_loss_fn = OnehotMse(args.num_classes)
        validate_loss_fn = OnehotMse(args.num_classes)
    else:
        if args.jsd:
            assert num_aug_splits > 1  # JSD only valid with aug splits set
            train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()
        elif mixup_active:
            # smoothing is handled with mixup target transform
            train_loss_fn = SoftTargetCrossEntropy().cuda()
        elif args.smoothing:
            train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
        else:
            train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    # Optional loss wrappers applied on top of the base criterion.
    if args.loss_fn == 'mix':
        train_loss_fn = MixLoss(train_loss_fn)
        validate_loss_fn = MixLoss(validate_loss_fn)
    if args.tet_loss:
        train_loss_fn = TetLoss(train_loss_fn)
        validate_loss_fn = TetLoss(validate_loss_fn)

    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None

    if args.eval:  # evaluate the model
        # if args.distributed:
        #     raise NotImplementedError('eval not has not been verified for distributed')
        # else:
        #     load_checkpoint(model, args.eval_checkpoint, args.model_ema)
        model.eval()
        # Sweep the simulation step count to probe accuracy vs. latency.
        for t in range(1, args.step * 3):
            # for t in range(args.step, args.step + 1):
            model.set_attr('step', t)
            val_metrics = validate(start_epoch, model, loader_eval, validate_loss_fn, args,
                                   visualize=args.visualize, spike_rate=args.spike_rate,
                                   tsne=args.tsne, conf_mat=args.conf_mat, summary_writer=summary_writer)
            print(f"[STEP:{t}], Top-1 accuracy of the model is: {val_metrics['top1']:.1f}%")
        return

    # ---- checkpoint saver (rank 0 only) ----
    saver = None
    if args.local_rank == 0:
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(
            model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
            checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, max_history=3)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    try:  # train the model
        if args.reset_drop:
            model_without_ddp.reset_drop_path(0.0)
        for epoch in range(start_epoch, args.epochs):
            if epoch == 0 and args.reset_drop:
                model_without_ddp.reset_drop_path(args.drop_path)
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_epoch(
                epoch, model, loader_train, optimizer, train_loss_fn, args,
                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                amp_autocast=amp_autocast, loss_scaler=loss_scaler,
                model_ema=model_ema, mixup_fn=mixup_fn, summary_writer=summary_writer
            )

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    _logger.info("Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            eval_metrics = validate(epoch, model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast,
                                    visualize=args.visualize, spike_rate=args.spike_rate,
                                    tsne=args.tsne, conf_mat=args.conf_mat, summary_writer=summary_writer)

            if model_ema is not None and not args.model_ema_force_cpu:
                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                    distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
                ema_eval_metrics = validate(
                    epoch, model_ema.ema, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)',
                    visualize=args.visualize, spike_rate=args.spike_rate,
                    tsne=args.tsne, conf_mat=args.conf_mat, summary_writer=summary_writer
                )
                # EMA metrics drive the checkpoint decision when EMA is on.
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            update_summary(
                epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
                write_header=best_metric is None)

            # if saver is not None and epoch >= args.n_warm_up:
            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)
    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
def train_epoch(
        epoch, model, loader, optimizer, loss_fn, args,
        lr_scheduler=None, saver=None, output_dir='', amp_autocast=suppress,
        loss_scaler=None, model_ema=None, mixup_fn=None, summary_writer=None):
    """Run one training epoch and return an OrderedDict with the loss average.

    Supports AMP (via ``amp_autocast``/``loss_scaler``), EMA updates,
    distributed metric reduction, recovery checkpoints and per-batch
    TensorBoard logging.
    """
    # Disable mixup after the configured epoch, either on the prefetching
    # loader or on the mixup callable itself.
    if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
        if args.prefetcher and loader.mixup_enabled:
            loader.mixup_enabled = False
        elif mixup_fn is not None:
            mixup_fn.mixup_enabled = False

    second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    losses_m = AverageMeter()
    # closses_m = AverageMeter()
    top1_m = AverageMeter()
    top5_m = AverageMeter()

    model.train()
    # t, k = adjust_surrogate_coeff(100, args.epochs)
    # model.set_attr('t', t)
    # model.set_attr('k', k)

    end = time.time()
    last_idx = len(loader) - 1
    num_updates = epoch * len(loader)
    iters_per_epoch = len(loader)
    for batch_idx, (inputs, target) in enumerate(loader):
        last_batch = batch_idx == last_idx
        # Optionally randomize the simulation step count per batch.
        if args.rand_step:
            step = buildin_random.randint(1, args.step + 2)
            model.set_attr('step', step)
        data_time_m.update(time.time() - end)
        if not args.prefetcher or args.dataset != 'imnet':
            inputs, target = inputs.type(torch.FloatTensor).cuda(), target.cuda()
            if mixup_fn is not None:
                inputs, target = mixup_fn(inputs, target)
        if args.channels_last:
            inputs = inputs.contiguous(memory_format=torch.channels_last)

        with amp_autocast():
            output = model(inputs)
            loss = loss_fn(output, target)

        # TET loss keeps the time dimension; average it away for accuracy.
        if args.tet_loss:
            output = output.mean(0)
        # Accuracy is meaningless against mixed/soft targets, so report 0.
        if not (args.cut_mix | args.mix_up | args.event_mix | (args.cutmix != 0.) | (args.mixup != 0.)):
            # print(output.shape, target.shape)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            # acc1, = accuracy(output, target)
        else:
            acc1, acc5 = torch.tensor([0.]), torch.tensor([0.])

        optimizer.zero_grad()
        if loss_scaler is not None:
            # AMP path: scaler handles backward, clipping and the step.
            loss_scaler(
                loss, optimizer, clip_grad=args.clip_grad, parameters=model.parameters(), create_graph=second_order)
        else:
            loss.backward(create_graph=second_order)
            if args.noisy_grad != 0.:
                random_gradient(model, args.noisy_grad)
            if args.clip_grad is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
            # if args.opt == 'lamb':
            #     optimizer.step(epoch=epoch)
            # else:
            optimizer.step()

        torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)
        num_updates += 1

        batch_time_m.update(time.time() - end)
        if args.local_rank == 0:
            summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'batch/train/top1'), acc1.item(), epoch * iters_per_epoch + batch_idx)
            summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'batch/train/top5'), acc5.item(), epoch * iters_per_epoch + batch_idx)
            summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'batch/train/loss'), loss.item(), epoch * iters_per_epoch + batch_idx)

        if last_batch or batch_idx % args.log_interval == 0:
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)

            if args.distributed:
                loss = reduce_tensor(loss.data, args.world_size)
                acc1 = reduce_tensor(acc1, args.world_size)
                acc5 = reduce_tensor(acc5, args.world_size)

            # NOTE: the epoch meters only accumulate on logged batches, so
            # the reported averages cover log-interval batches, not all.
            losses_m.update(loss.item(), inputs.size(0))
            top1_m.update(acc1.item(), output.size(0))
            top5_m.update(acc5.item(), output.size(0))
            # closses_m.update(reduced_loss.item(), inputs.size(0))

            if args.local_rank == 0:
                # if args.distributed:
                _logger.info(
                    'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
                    'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) '
                    'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
                    'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f}) '
                    'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
                    '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                    'LR: {lr:.3e} '
                    'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
                        epoch,
                        batch_idx, len(loader),
                        100. * batch_idx / last_idx,
                        loss=losses_m,
                        top1=top1_m,
                        top5=top5_m,
                        batch_time=batch_time_m,
                        rate=inputs.size(0) * args.world_size / batch_time_m.val,
                        rate_avg=inputs.size(0) * args.world_size / batch_time_m.avg,
                        lr=lr,
                        data_time=data_time_m
                    ))

                if args.save_images and output_dir:
                    torchvision.utils.save_image(
                        inputs,
                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
                        padding=0,
                        normalize=True)

        if saver is not None and args.recovery_interval and (
                last_batch or (batch_idx + 1) % args.recovery_interval == 0):
            saver.save_recovery(epoch, batch_idx=batch_idx)

        if lr_scheduler is not None:
            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)

        end = time.time()
        # end for

    if hasattr(optimizer, 'sync_lookahead'):
        optimizer.sync_lookahead()

    if args.local_rank == 0:
        summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'epoch/train/top1'), top1_m.avg, epoch)
        summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'epoch/train/top5'), top5_m.avg, epoch)
        summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'epoch/train/loss'), losses_m.avg, epoch)

    # Restore the configured step count after random-step training.
    if args.rand_step:
        model.set_attr('step', args.step)

    return OrderedDict([('loss', losses_m.avg)])
def validate(epoch, model, loader, loss_fn, args, amp_autocast=suppress,
             log_suffix='', visualize=False, spike_rate=False, tsne=False, conf_mat=False, summary_writer=None):
    """Evaluate the model on ``loader`` and return {'loss', 'top1'} averages.

    Optionally records firing patterns for visualization / spike-rate
    statistics and logs per-batch and per-epoch scalars to TensorBoard.
    """
    batch_time_m = AverageMeter()
    losses_m = AverageMeter()
    # closses_m = AverageMeter()
    top1_m = AverageMeter()
    top5_m = AverageMeter()
    spike_m = AverageMeter()

    model.eval()

    # Buffers for optional analyses (t-SNE, confusion matrix, membrane dist).
    feature_vec = []
    feature_cls = []
    logits_vec = []
    labels_vec = []
    mem_vec = []

    end = time.time()
    last_idx = len(loader) - 1
    iters_per_epoch = len(loader)
    with torch.no_grad():
        for batch_idx, (inputs, target) in enumerate(loader):
            # inputs = inputs.type(torch.float64)
            last_batch = batch_idx == last_idx
            if not args.prefetcher or args.dataset != 'imnet':
                inputs = inputs.type(torch.FloatTensor).cuda()
                target = target.cuda()
            if args.channels_last:
                inputs = inputs.contiguous(memory_format=torch.channels_last)
            if not args.distributed:
                # Enable firing-pattern capture only when an analysis needs it.
                if (visualize or spike_rate or tsne or conf_mat or args.mem_dist) and not args.critical_loss:
                    model.set_requires_fp(True)

            with amp_autocast():
                output = model(inputs)
            if isinstance(output, (tuple, list)):
                output = output[0]

            # augmentation reduction
            reduce_factor = args.tta
            if reduce_factor > 1:
                output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
                target = target[0:target.size(0):reduce_factor]
            # print(args.rank, output.shape, target.shape, max(target))
            loss = loss_fn(output, target)
            # TET loss keeps the time dimension; average it away for accuracy.
            if args.tet_loss:
                output = output.mean(0)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))

            if args.distributed:
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                acc1 = reduce_tensor(acc1, args.world_size)
                acc5 = reduce_tensor(acc5, args.world_size)
            else:
                reduced_loss = loss.data

            torch.cuda.synchronize()

            losses_m.update(reduced_loss.item(), inputs.size(0))
            top1_m.update(acc1.item(), output.size(0))
            top5_m.update(acc5.item(), output.size(0))
            # closses_m.update(closs, inputs.size(0))

            batch_time_m.update(time.time() - end)
            end = time.time()
            if args.local_rank == 0:
                summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'batch/val/top1'), acc1.item(), epoch * iters_per_epoch + batch_idx)
                summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'batch/val/top5'), acc5.item(), epoch * iters_per_epoch + batch_idx)
                summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'batch/val/loss'), loss.item(), epoch * iters_per_epoch + batch_idx)
            if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
                # NOTE(review): log_name is built but the message below uses a
                # literal 'Eval' prefix -- the suffix is effectively unused here.
                log_name = 'Test' + log_suffix
                if not args.distributed and spike_rate:
                    # Average number of spikes per sample for this batch.
                    spike_m.update(model.get_tot_spike() / output.size(0), output.size(0))
                if not args.distributed and spike_rate:
                    _logger.info(
                        '[Spike Info]: {spike.val} ({spike.avg})'.format(
                            spike=spike_m
                        )
                    )
                if last_batch or batch_idx % args.log_interval == 0:
                    _logger.info(
                        'Eval : {} '
                        'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                        'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                        'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})'
                        'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
                            epoch,
                            batch_idx,
                            last_idx,
                            batch_time=batch_time_m,
                            loss=losses_m,
                            top1=top1_m,
                            top5=top5_m,
                        ))

    # metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])
    metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg)])
    if args.local_rank == 0:
        summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'epoch/val/top1'), top1_m.avg, epoch)
        summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'epoch/val/top5'), top5_m.avg, epoch)
        summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'epoch/val/loss'), losses_m.avg, epoch)

    return metrics
# Standard script entry point.
if __name__ == '__main__':
    main()
================================================
FILE: examples/TIM/models/TIM.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from braincog.model_zoo.base_module import BaseModule
from braincog.base.node.node import *
from braincog.base.connection.layer import *
from braincog.base.strategy.surrogate import *
from utils.MyNode import *
class TIM(BaseModule):
    """Temporal Interaction Module.

    At each time step (after the first), the previous step's output is passed
    through a 1-D convolution and a spiking node, then blended with the current
    input using ``TIM_alpha``, and re-spiked.  Input and output layout:
    ``[T, B, H, N, C/H]``.
    """

    def __init__(self, dim=256, encode_type='direct', in_channels=16, TIM_alpha=0.5):
        super().__init__(step=1, encode_type=encode_type)
        # in_channels must match the per-head token dimension of the input.
        self.interactor = nn.Conv1d(in_channels=in_channels, out_channels=in_channels,
                                    kernel_size=5, stride=1, padding=2, bias=True)
        self.in_lif = MyNode(tau=2.0, v_threshold=0.3, layer_by_layer=False, step=1)   # spike-driven
        self.out_lif = MyNode(tau=2.0, v_threshold=0.5, layer_by_layer=False, step=1)  # spike-driven
        self.tim_alpha = TIM_alpha

    def forward(self, x):
        self.reset()
        n_step, batch, heads, tokens, ch_per_head = x.shape

        # Step 0 passes through untouched; every later step mixes the previous
        # output with the current input.
        prev = x[0]
        outputs = [prev]
        for t in range(1, n_step):
            mixed = self.interactor(prev.flatten(0, 1)).reshape(batch, heads, tokens, ch_per_head).contiguous()
            prev = self.in_lif(mixed) * self.tim_alpha + x[t] * (1 - self.tim_alpha)
            prev = self.out_lif(prev)
            outputs.append(prev)

        return torch.stack(outputs)  # [T, B, H, N, C/H]
================================================
FILE: examples/TIM/models/spikformer_braincog_DVS.py
================================================
import torch
import torch.nn as nn
from timm.models.layers import to_2tuple, trunc_normal_, DropPath
from timm.models.registry import register_model
from timm.models.vision_transformer import _cfg
import torch.nn.functional as F
from braincog.model_zoo.base_module import BaseModule
from braincog.base.node.node import *
from braincog.base.connection.layer import *
from braincog.base.strategy.surrogate import *
from functools import partial
from torchvision import transforms
from utils.MyNode import *
from models.TIM import *
__all__ = ['spikformer']
class MLP(BaseModule):
    """Two-stage spiking MLP over token features.

    Each stage is a 1x1 Conv1d + BatchNorm1d followed by a spiking node.
    Input/output layout is ``[T, B, C, N]`` (time, batch, channels, tokens);
    the second stage's reshape assumes ``out_features == C`` — confirm when
    passing an explicit ``out_features``.
    """

    def __init__(self, in_features, step=10, encode_type='direct', hidden_features=None, out_features=None, drop=0.):
        # Bug fix: the base class was initialised with hard-coded
        # ``step=10, encode_type='direct'``, silently ignoring the constructor
        # arguments.  Propagate them instead (defaults are unchanged, so
        # existing callers behave identically).
        super().__init__(step=step, encode_type=encode_type)
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1_conv = nn.Conv1d(in_features, hidden_features, kernel_size=1, stride=1)
        self.fc1_bn = nn.BatchNorm1d(hidden_features)
        self.fc1_lif = MyNode(step=step, tau=2.0)

        self.fc2_conv = nn.Conv1d(hidden_features, out_features, kernel_size=1, stride=1)
        self.fc2_bn = nn.BatchNorm1d(out_features)
        self.fc2_lif = MyNode(step=step, tau=2.0)

        self.c_hidden = hidden_features
        self.c_output = out_features

    def forward(self, x):
        self.reset()  # clear membrane state between sequences
        T, B, C, N = x.shape
        x = self.fc1_conv(x.flatten(0, 1))
        x = self.fc1_bn(x).reshape(T, B, self.c_hidden, N).contiguous()  # T B C N
        x = self.fc1_lif(x.flatten(0, 1)).reshape(T, B, self.c_hidden, N).contiguous()
        x = self.fc2_conv(x.flatten(0, 1))
        x = self.fc2_bn(x).reshape(T, B, C, N).contiguous()
        x = self.fc2_lif(x.flatten(0, 1)).reshape(T, B, C, N).contiguous()
        return x
class SSA(BaseModule):
    """Spiking Self-Attention with the Temporal Interaction Module (TIM)
    applied to the query path.

    Operates on ``[T, B, C, N]`` inputs; Q/K/V are produced by 1x1 Conv1d +
    BN + spiking nodes, attention is plain per-head matmuls (no softmax)
    scaled by a fixed factor, then re-spiked and projected.
    """

    def __init__(self, dim, step=10, encode_type='direct', num_heads=16, TIM_alpha=0.5, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, sr_ratio=1):
        # Bug fix: propagate ``step``/``encode_type`` to the base class instead
        # of the previously hard-coded ``step=10, encode_type='direct'``
        # (defaults are unchanged, so existing callers behave identically).
        super().__init__(step=step, encode_type=encode_type)
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
        self.dim = dim
        self.num_heads = num_heads
        self.in_channels = dim // num_heads
        self.scale = 0.25  # fixed attention scale (spiking attention uses no softmax)

        self.q_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)
        self.q_bn = nn.BatchNorm1d(dim)
        self.q_lif = MyNode(step=step, tau=2.0)

        self.k_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)
        self.k_bn = nn.BatchNorm1d(dim)
        self.k_lif = MyNode(step=step, tau=2.0)

        self.v_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)
        self.v_bn = nn.BatchNorm1d(dim)
        self.v_lif = MyNode(step=step, tau=2.0)

        # NOTE: attn_drop and res_lif are not used in forward(); they are kept
        # so existing checkpoints keep loading (state-dict compatibility).
        self.attn_drop = nn.Dropout(0.2)
        self.res_lif = MyNode(step=step, tau=2.0)
        self.attn_lif = MyNode(step=step, tau=2.0, v_threshold=0.5,)

        self.proj_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)
        self.proj_bn = nn.BatchNorm1d(dim)
        self.proj_lif = MyNode(step=step, tau=2.0,)

        self.TIM = TIM(TIM_alpha=TIM_alpha, in_channels=self.in_channels)

    def forward(self, x):
        self.reset()  # clear membrane state between sequences
        T, B, C, N = x.shape
        x_for_qkv = x.flatten(0, 1)

        q_conv_out = self.q_conv(x_for_qkv)
        q_conv_out = self.q_bn(q_conv_out).reshape(T, B, C, N).contiguous()
        q_conv_out = self.q_lif(q_conv_out.flatten(0, 1)).reshape(T, B, C, N).transpose(-2, -1)
        q = q_conv_out.reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4).contiguous()

        k_conv_out = self.k_conv(x_for_qkv)
        k_conv_out = self.k_bn(k_conv_out).reshape(T, B, C, N).contiguous()
        k_conv_out = self.k_lif(k_conv_out.flatten(0, 1)).reshape(T, B, C, N).transpose(-2, -1)
        k = k_conv_out.reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4).contiguous()

        v_conv_out = self.v_conv(x_for_qkv)
        v_conv_out = self.v_bn(v_conv_out).reshape(T, B, C, N).contiguous()
        v_conv_out = self.v_lif(v_conv_out.flatten(0, 1)).reshape(T, B, C, N).transpose(-2, -1)
        v = v_conv_out.reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4).contiguous()

        # Temporal Interaction Module on the query path only.
        q = self.TIM(q)

        # Spiking self-attention: matmuls, fixed scaling, then re-spike.
        attn = (q @ k.transpose(-2, -1))
        x = (attn @ v) * self.scale

        x = x.transpose(3, 4).reshape(T, B, C, N).contiguous()
        x = self.attn_lif(x.flatten(0, 1))
        x = self.proj_lif(self.proj_bn(self.proj_conv(x))).reshape(T, B, C, N)
        return x
class Block(nn.Module):
    """One Spikformer encoder layer: spiking self-attention (SSA) followed by
    a spiking MLP, each wrapped in a residual connection.

    ``norm1``/``norm2`` are constructed but not applied in ``forward`` (the
    spiking pipeline works on raw spike tensors); they are kept so existing
    checkpoints remain loadable.
    """

    def __init__(self, dim, num_heads, step=10, TIM_alpha=0.5, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, sr_ratio=1):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = SSA(dim, step=step, TIM_alpha=TIM_alpha, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                        attn_drop=attn_drop, sr_ratio=sr_ratio)
        self.norm2 = norm_layer(dim)
        self.mlp = MLP(in_features=dim, step=step, hidden_features=int(dim * mlp_ratio), drop=drop)

    def forward(self, x):
        # Residual attention, then residual MLP.
        x = self.attn(x) + x
        x = self.mlp(x) + x
        return x
class SPS(BaseModule):
    """Spiking Patch Splitting: four conv-BN-spike stages, each followed by a
    2x maxpool (16x total downsampling), plus a spiking relative position
    embedding added as a residual.

    Input is ``[T, B, C, H, W]``; output is ``[T, B, embed_dims, N]`` with
    ``N = (H // 16) * (W // 16)``.
    """

    def __init__(self, step=10, encode_type='direct', img_size_h=64, img_size_w=64, patch_size=4, in_channels=2,
                 embed_dims=256, if_UCF=False):
        # Bug fix: propagate ``step``/``encode_type`` to the base class instead
        # of the previously hard-coded ``step=10, encode_type='direct'``
        # (defaults are unchanged, so existing callers behave identically).
        super().__init__(step=step, encode_type=encode_type)
        self.image_size = [img_size_h, img_size_w]
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size
        self.C = in_channels
        self.H, self.W = self.image_size[0] // patch_size[0], self.image_size[1] // patch_size[1]
        self.num_patches = self.H * self.W
        self.if_UCF = if_UCF

        self.proj_conv = nn.Conv2d(in_channels, embed_dims // 8, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj_bn = nn.BatchNorm2d(embed_dims // 8)
        self.proj_lif = MyNode(step=step, tau=2.0)
        self.maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)

        self.proj_conv1 = nn.Conv2d(embed_dims // 8, embed_dims // 4, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj_bn1 = nn.BatchNorm2d(embed_dims // 4)
        self.proj_lif1 = MyNode(step=step, tau=2.0)
        self.maxpool1 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)

        self.proj_conv2 = nn.Conv2d(embed_dims // 4, embed_dims // 2, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj_bn2 = nn.BatchNorm2d(embed_dims // 2)
        self.proj_lif2 = MyNode(step=step, tau=2.0)
        self.maxpool2 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)

        self.proj_conv3 = nn.Conv2d(embed_dims // 2, embed_dims, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj_bn3 = nn.BatchNorm2d(embed_dims)
        self.proj_lif3 = MyNode(step=step, tau=2.0)
        self.maxpool3 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)

        # Relative position embedding branch (conv on the pooled feature map).
        self.rpe_conv = nn.Conv2d(embed_dims, embed_dims, kernel_size=3, stride=1, padding=1, bias=False)
        self.rpe_bn = nn.BatchNorm2d(embed_dims)
        self.rpe_lif = MyNode(step=step, tau=2.0)

    def forward(self, x):
        self.reset()  # clear membrane state between sequences
        T, B, C, H, W = x.shape
        # UCF101DVS frames are resized so the 16x downsampling yields 4x4 maps.
        if self.if_UCF:
            x = F.adaptive_avg_pool2d(x.flatten(0, 1), output_size=(64, 64)).reshape(T, B, C, 64, 64)
            T, B, C, H, W = x.shape
        x = self.proj_conv(x.flatten(0, 1))  # have some fire value
        x = self.proj_bn(x).reshape(T, B, -1, H, W).contiguous()
        x = self.proj_lif(x.flatten(0, 1)).contiguous()
        x = self.maxpool(x)

        x = self.proj_conv1(x)
        x = self.proj_bn1(x).reshape(T, B, -1, H // 2, W // 2).contiguous()
        x = self.proj_lif1(x.flatten(0, 1)).contiguous()
        x = self.maxpool1(x)

        x = self.proj_conv2(x)
        x = self.proj_bn2(x).reshape(T, B, -1, H // 4, W // 4).contiguous()
        x = self.proj_lif2(x.flatten(0, 1)).contiguous()
        x = self.maxpool2(x)

        x = self.proj_conv3(x)
        x = self.proj_bn3(x).reshape(T, B, -1, H // 8, W // 8).contiguous()
        x = self.proj_lif3(x.flatten(0, 1)).contiguous()
        x = self.maxpool3(x)

        # Residual spiking position embedding on the final feature map.
        x_rpe = self.rpe_bn(self.rpe_conv(x)).reshape(T, B, -1, H // 16, W // 16).contiguous()
        x_rpe = self.rpe_lif(x_rpe.flatten(0, 1)).contiguous()
        x = x + x_rpe
        x = x.reshape(T, B, -1, (H // 16) * (W // 16)).contiguous()
        return x  # T B C N
class Spikformer(nn.Module):
    """Spikformer backbone: SPS patch embedding, a stack of spiking
    transformer blocks (with TIM on the query path), and a linear head.

    ``forward`` expects ``[B, T, C, H, W]`` input and returns ``[B, num_classes]``
    logits (averaged over time steps).
    """

    def __init__(self, step=10, TIM_alpha=0.5, if_UCF=False,
                 img_size_h=64, img_size_w=64, patch_size=16, in_channels=2, num_classes=10,
                 embed_dims=256, num_heads=16, mlp_ratios=4, qkv_bias=False, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
                 depths=2, sr_ratios=4,
                 ):
        super().__init__()
        self.T = step  # time step
        self.num_classes = num_classes
        self.depths = depths

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depths)]  # stochastic depth decay rule

        patch_embed = SPS(step=step,
                          if_UCF=if_UCF,
                          img_size_h=img_size_h,
                          img_size_w=img_size_w,
                          patch_size=patch_size,
                          in_channels=in_channels,
                          embed_dims=embed_dims)

        block = nn.ModuleList([Block(step=step, TIM_alpha=TIM_alpha,
                                     dim=embed_dims, num_heads=num_heads, mlp_ratio=mlp_ratios, qkv_bias=qkv_bias,
                                     qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[j],
                                     norm_layer=norm_layer, sr_ratio=sr_ratios)
                               for j in range(depths)])

        # Registered via setattr to mirror the multi-stage Spikformer layout.
        setattr(self, f"patch_embed", patch_embed)
        setattr(self, f"block", block)

        # classification head
        self.head = nn.Linear(embed_dims, num_classes) if num_classes > 0 else nn.Identity()
        self.apply(self._init_weights)

    @torch.jit.ignore
    def _get_pos_embed(self, pos_embed, patch_embed, H, W):
        # NOTE(review): references self.patch_embed1, which is never defined
        # in this class; this helper appears to be dead code (never called).
        if H * W == self.patch_embed1.num_patches:
            return pos_embed
        else:
            return F.interpolate(
                pos_embed.reshape(1, patch_embed.H, patch_embed.W, -1).permute(0, 3, 1, 2),
                size=(H, W), mode="bilinear").reshape(1, -1, H * W).permute(0, 2, 1)

    def _init_weights(self, m):
        # Truncated-normal init for linear weights; standard LayerNorm init.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_features(self, x):
        # Embed patches, run the block stack, then mean-pool over tokens.
        block = getattr(self, f"block")
        patch_embed = getattr(self, f"patch_embed")

        x = patch_embed(x)
        for blk in block:
            x = blk(x)
        return x.mean(3)

    def forward(self, x):
        # [B, T, C, H, W] -> [T, B, C, H, W] for the time-major pipeline.
        x = x.permute(1, 0, 2, 3, 4)
        x = self.forward_features(x)
        # Average over time steps before the classification head.
        x = self.head(x.mean(0))
        return x
# Hyperparams could be adjust here
@register_model
def spikformer_dvs(pretrained=False, **kwargs):
    """Build the DVS Spikformer variant (10 classes by default).

    :param pretrained: accepted for timm API compatibility; unused.
    :param kwargs: forwarded to ``Spikformer``; values here override defaults.
    :return: configured ``Spikformer`` with ``default_cfg`` attached.
    """
    # Merge defaults UNDER the caller's kwargs.  The original passed explicit
    # keywords alongside **kwargs, which raised a duplicate-keyword TypeError
    # whenever a caller (e.g. timm create_model) supplied one of them.
    defaults = dict(TIM_alpha=0.5, step=10, if_UCF=False, num_classes=10)
    for key, value in defaults.items():
        kwargs.setdefault(key, value)
    model = Spikformer(**kwargs)
    model.default_cfg = _cfg()
    return model
================================================
FILE: examples/TIM/models/spikformer_braincog_SHD.py
================================================
import torch
import torch.nn as nn
from timm.models.layers import to_2tuple, trunc_normal_, DropPath
from timm.models.registry import register_model
from timm.models.vision_transformer import _cfg
import torch.nn.functional as F
from braincog.model_zoo.base_module import BaseModule
from braincog.base.node.node import *
from braincog.base.connection.layer import *
from braincog.base.strategy.surrogate import *
from functools import partial
from torchvision import transforms
from utils.MyNode import *
from models.TIM import *
__all__ = ['spikformer']
class MLP(BaseModule):
    """Spiking MLP block over (T, B, C, N) tensors.

    Two 1x1 Conv1d layers (per-token linear maps), each followed by
    BatchNorm1d and a LIF neuron.
    """

    def __init__(self, in_features, step=10, encode_type='direct', hidden_features=None, out_features=None, drop=0.):
        # Bug fix: the original called super().__init__(step=10,
        # encode_type='direct'), silently ignoring the caller-supplied
        # step/encode_type parameters.
        super().__init__(step=step, encode_type=encode_type)
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1_conv = nn.Conv1d(in_features, hidden_features, kernel_size=1, stride=1)
        self.fc1_bn = nn.BatchNorm1d(hidden_features)
        self.fc1_lif = MyNode(step=step, tau=2.0)
        self.fc2_conv = nn.Conv1d(hidden_features, out_features, kernel_size=1, stride=1)
        self.fc2_bn = nn.BatchNorm1d(out_features)
        self.fc2_lif = MyNode(step=step, tau=2.0)
        self.c_hidden = hidden_features
        self.c_output = out_features
        # NOTE(review): `drop` is accepted for API compatibility but unused.

    def forward(self, x):
        self.reset()  # clear neuron membrane state between sequences
        T, B, C, N = x.shape
        x = self.fc1_conv(x.flatten(0, 1))  # fold time into batch for Conv1d
        x = self.fc1_bn(x).reshape(T, B, self.c_hidden, N).contiguous()  # T B C N
        x = self.fc1_lif(x.flatten(0, 1)).reshape(T, B, self.c_hidden, N).contiguous()
        x = self.fc2_conv(x.flatten(0, 1))
        # NOTE(review): reshape uses the input channel count C, which only
        # works while out_features == in_features (true for all callers here).
        x = self.fc2_bn(x).reshape(T, B, C, N).contiguous()
        x = self.fc2_lif(x.flatten(0, 1)).reshape(T, B, C, N).contiguous()
        return x
class SSA(BaseModule):
    """Spiking Self-Attention with TIM (Temporal Interaction Module) on Q.

    Operates on (T, B, C, N) spike tensors.  Q/K/V are produced by 1x1 convs
    + BatchNorm + LIF neurons, attention is computed per head without softmax,
    and the product is scaled by a fixed factor of 0.25.
    """

    def __init__(self, dim, step=10, encode_type='direct', num_heads=16, TIM_alpha=0.5, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, sr_ratio=1):
        # Bug fix: the original called super().__init__(step=10,
        # encode_type='direct'), ignoring the caller-supplied values.
        super().__init__(step=step, encode_type=encode_type)
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
        self.dim = dim
        self.num_heads = num_heads
        self.in_channels = dim // num_heads
        self.scale = 0.25  # fixed attention scale (qk_scale is ignored)
        self.q_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)
        self.q_bn = nn.BatchNorm1d(dim)
        self.q_lif = MyNode(step=step, tau=2.0)
        self.k_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)
        self.k_bn = nn.BatchNorm1d(dim)
        self.k_lif = MyNode(step=step, tau=2.0)
        self.v_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)
        self.v_bn = nn.BatchNorm1d(dim)
        self.v_lif = MyNode(step=step, tau=2.0)
        self.attn_drop = nn.Dropout(0.2)  # NOTE(review): unused in forward
        self.res_lif = MyNode(step=step, tau=2.0)  # NOTE(review): unused in forward
        # NOTE(review): MyNode takes `threshold`, not `v_threshold`; this
        # keyword is swallowed by **kwargs, so attn_lif fires at the default
        # threshold of 1.0, not 0.5 -- confirm whether 0.5 was intended.
        self.attn_lif = MyNode(step=step, tau=2.0, v_threshold=0.5,)
        self.proj_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)
        self.proj_bn = nn.BatchNorm1d(dim)
        self.proj_lif = MyNode(step=step, tau=2.0,)
        self.TIM = TIM(TIM_alpha=TIM_alpha, in_channels=self.in_channels)

    def forward(self, x):
        self.reset()  # clear neuron state between sequences
        T, B, C, N = x.shape
        x_for_qkv = x.flatten(0, 1)  # fold time into batch for Conv1d
        # Q path: conv -> BN -> LIF, then split into heads.
        q_conv_out = self.q_conv(x_for_qkv)
        q_conv_out = self.q_bn(q_conv_out).reshape(T, B, C, N).contiguous()
        q_conv_out = self.q_lif(q_conv_out.flatten(0, 1)).reshape(T, B, C, N).transpose(-2, -1)
        q = q_conv_out.reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4).contiguous()
        # K path.
        k_conv_out = self.k_conv(x_for_qkv)
        k_conv_out = self.k_bn(k_conv_out).reshape(T, B, C, N).contiguous()
        k_conv_out = self.k_lif(k_conv_out.flatten(0, 1)).reshape(T, B, C, N).transpose(-2, -1)
        k = k_conv_out.reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4).contiguous()
        # V path.
        v_conv_out = self.v_conv(x_for_qkv)
        v_conv_out = self.v_bn(v_conv_out).reshape(T, B, C, N).contiguous()
        v_conv_out = self.v_lif(v_conv_out.flatten(0, 1)).reshape(T, B, C, N).transpose(-2, -1)
        v = v_conv_out.reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4).contiguous()
        # TIM on Q: temporal interaction across time steps.
        q = self.TIM(q)
        # SSA: spike-based attention; inputs are binary so no softmax is used.
        attn = (q @ k.transpose(-2, -1))
        x = (attn @ v) * self.scale
        x = x.transpose(3, 4).reshape(T, B, C, N).contiguous()
        x = self.attn_lif(x.flatten(0, 1))
        x = self.proj_lif(self.proj_bn(self.proj_conv(x))).reshape(T, B, C, N)
        return x
class Block(nn.Module):
    """Transformer encoder block: spiking self-attention followed by a
    spiking MLP, each wrapped in an additive residual connection.
    """

    def __init__(self, dim, num_heads, step=10, TIM_alpha=0.5, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, sr_ratio=1):
        super().__init__()
        # NOTE: norm1/norm2 are constructed but never applied in forward();
        # normalization happens inside SSA/MLP via BatchNorm.
        self.norm1 = norm_layer(dim)
        self.attn = SSA(dim, step=step, TIM_alpha=TIM_alpha, num_heads=num_heads,
                        qkv_bias=qkv_bias, qk_scale=qk_scale,
                        attn_drop=attn_drop, sr_ratio=sr_ratio)
        self.norm2 = norm_layer(dim)
        hidden_dim = int(dim * mlp_ratio)
        self.mlp = MLP(in_features=dim, step=step, hidden_features=hidden_dim, drop=drop)

    def forward(self, x):
        x = x + self.attn(x)
        return x + self.mlp(x)
class SPS(BaseModule):
    """Spiking Patch Splitting for SHD audio input.

    ``forward`` reshapes an SHD spike frame (T, B, 700) into a 2-channel
    16x16 pseudo-image, applies four conv/BN/LIF stages with two max-pool
    downsamplings, adds a relative-position-encoding conv branch, and returns
    (T, B, embed_dims, N) token embeddings.
    """

    def __init__(self, step=10, encode_type='direct', img_size_h=64, img_size_w=64, patch_size=4, in_channels=2,
                 embed_dims=256, if_UCF=False):
        # Bug fix: the original called super().__init__(step=10,
        # encode_type='direct'), ignoring the caller-supplied values.
        super().__init__(step=step, encode_type=encode_type)
        self.image_size = [img_size_h, img_size_w]
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size
        self.C = in_channels
        self.H, self.W = self.image_size[0] // patch_size[0], self.image_size[1] // patch_size[1]
        self.num_patches = self.H * self.W
        self.if_UCF = if_UCF  # NOTE(review): stored but unused in this SHD variant
        self.proj_conv = nn.Conv2d(in_channels, embed_dims // 8, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj_bn = nn.BatchNorm2d(embed_dims // 8)
        self.proj_lif = MyNode(step=step, tau=2.0)
        # maxpool and maxpool1 are registered but their calls are commented
        # out in forward (kept for checkpoint/layout parity with DVS variant).
        self.maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        self.proj_conv1 = nn.Conv2d(embed_dims // 8, embed_dims // 4, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj_bn1 = nn.BatchNorm2d(embed_dims // 4)
        self.proj_lif1 = MyNode(step=step, tau=2.0)
        self.maxpool1 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        self.proj_conv2 = nn.Conv2d(embed_dims // 4, embed_dims // 2, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj_bn2 = nn.BatchNorm2d(embed_dims // 2)
        self.proj_lif2 = MyNode(step=step, tau=2.0)
        self.maxpool2 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        self.proj_conv3 = nn.Conv2d(embed_dims // 2, embed_dims, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj_bn3 = nn.BatchNorm2d(embed_dims)
        self.proj_lif3 = MyNode(step=step, tau=2.0)
        self.maxpool3 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        self.rpe_conv = nn.Conv2d(embed_dims, embed_dims, kernel_size=3, stride=1, padding=1, bias=False)
        self.rpe_bn = nn.BatchNorm2d(embed_dims)
        self.rpe_lif = MyNode(step=step, tau=2.0)

    def forward(self, x):
        self.reset()  # clear neuron state between sequences
        # SHD input: (T, B, 700).  Split into two channels of 350 and
        # nearest-neighbor resize to 256 samples so it fits a 16x16 grid.
        T, B, _ = x.shape
        x = x.reshape(T, B, 2, -1)  # T B 2 350
        x = F.interpolate(x.flatten(0, 1), size=256, mode='nearest').reshape(T, B, 2, 16, 16)
        T, B, C, H, W = x.shape
        x = self.proj_conv(x.flatten(0, 1))  # have some fire value
        x = self.proj_bn(x).reshape(T, B, -1, H, W).contiguous()
        x = self.proj_lif(x.flatten(0, 1)).contiguous()
        # x = self.maxpool(x)
        x = self.proj_conv1(x)
        x = self.proj_bn1(x).reshape(T, B, -1, H, W).contiguous()
        x = self.proj_lif1(x.flatten(0, 1)).contiguous()
        # x = self.maxpool1(x)
        x = self.proj_conv2(x)
        x = self.proj_bn2(x).reshape(T, B, -1, H, W).contiguous()
        x = self.proj_lif2(x.flatten(0, 1)).contiguous()
        x = self.maxpool2(x)  # downsample to (H/2, W/2)
        x = self.proj_conv3(x)
        x = self.proj_bn3(x).reshape(T, B, -1, H // 2, W // 2).contiguous()
        x = self.proj_lif3(x.flatten(0, 1)).contiguous()
        x = self.maxpool3(x)  # downsample to (H/4, W/4)
        # Relative position encoding branch, added residually.
        x_rpe = self.rpe_bn(self.rpe_conv(x)).reshape(T, B, -1, H // 4, W // 4).contiguous()
        x_rpe = self.rpe_lif(x_rpe.flatten(0, 1)).contiguous()
        x = x + x_rpe
        x = x.reshape(T, B, -1, (H // 4) * (W // 4)).contiguous()
        return x  # T B C N
class Spikformer(nn.Module):
    """Spiking transformer for SHD audio input.

    SPS patch embedding followed by ``depths`` spiking self-attention blocks
    and a linear classification head.  ``forward`` takes (B, T, N) and
    returns (B, num_classes) logits averaged over time steps.
    """

    def __init__(self, step=10, TIM_alpha=0.5, if_UCF=False,
                 img_size_h=64, img_size_w=64, patch_size=16, in_channels=2, num_classes=10,
                 embed_dims=256, num_heads=16, mlp_ratios=4, qkv_bias=False, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
                 depths=2, sr_ratios=4,
                 ):
        super().__init__()
        self.T = step  # number of simulation time steps
        self.num_classes = num_classes
        self.depths = depths
        # Stochastic depth decay rule: drop-path rate grows linearly per block.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depths)]
        # Plain attribute assignment: the original setattr(self, f"...", ...)
        # calls were needless indirection over constant names; state_dict keys
        # ("patch_embed", "block") are unchanged.
        self.patch_embed = SPS(step=step,
                               if_UCF=if_UCF,
                               img_size_h=img_size_h,
                               img_size_w=img_size_w,
                               patch_size=patch_size,
                               in_channels=in_channels,
                               embed_dims=embed_dims)
        self.block = nn.ModuleList([Block(step=step, TIM_alpha=TIM_alpha,
                                          dim=embed_dims, num_heads=num_heads, mlp_ratio=mlp_ratios,
                                          qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate,
                                          attn_drop=attn_drop_rate, drop_path=dpr[j],
                                          norm_layer=norm_layer, sr_ratio=sr_ratios)
                                    for j in range(depths)])
        # classification head
        self.head = nn.Linear(embed_dims, num_classes) if num_classes > 0 else nn.Identity()
        self.apply(self._init_weights)

    @torch.jit.ignore
    def _get_pos_embed(self, pos_embed, patch_embed, H, W):
        """Resize a positional embedding to an (H, W) patch grid.

        NOTE(review): unused by forward(); kept for API parity with the
        PVT-style code this derives from.
        """
        # Bug fix: the original compared against self.patch_embed1, an
        # attribute that is never created; use the patch_embed argument.
        if H * W == patch_embed.num_patches:
            return pos_embed
        else:
            return F.interpolate(
                pos_embed.reshape(1, patch_embed.H, patch_embed.W, -1).permute(0, 3, 1, 2),
                size=(H, W), mode="bilinear").reshape(1, -1, H * W).permute(0, 2, 1)

    def _init_weights(self, m):
        # Truncated-normal init for linear layers, unit-affine for LayerNorm.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            # Redundant second isinstance check from the original removed.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_features(self, x):
        # (T, B, N) -> patch embedding -> (T, B, C, N) -> mean over N.
        x = self.patch_embed(x)
        for blk in self.block:
            x = blk(x)
        return x.mean(3)

    def forward(self, x):
        x = x.permute(1, 0, 2)  # (B, T, N) -> (T, B, N)
        x = self.forward_features(x)
        x = self.head(x.mean(0))  # average over time, then classify
        return x
# Hyperparams could be adjust here
@register_model
def spikformer_shd(pretrained=False, **kwargs):
    """Build the SHD Spikformer variant (20 classes by default).

    :param pretrained: accepted for timm API compatibility; unused.
    :param kwargs: forwarded to ``Spikformer``; values here override defaults.
    :return: configured ``Spikformer`` with ``default_cfg`` attached.
    """
    # Merge defaults UNDER the caller's kwargs.  The original passed explicit
    # keywords alongside **kwargs, which raised a duplicate-keyword TypeError
    # whenever a caller (e.g. timm create_model) supplied one of them.
    defaults = dict(TIM_alpha=0.5, step=10, if_UCF=False, num_classes=20)
    for key, value in defaults.items():
        kwargs.setdefault(key, value)
    model = Spikformer(**kwargs)
    model.default_cfg = _cfg()
    return model
================================================
FILE: examples/TIM/utils/MyGrad.py
================================================
from braincog.base.node.node import *
from braincog.base.connection.layer import *
from braincog.base.strategy.surrogate import *
class MyGrad(SurrogateFunctionBase):
    """Sigmoid-shaped surrogate gradient with a fixed sharpness ``alpha``.

    Used by MyNode as the spiking activation whose backward pass is braincog's
    sigmoid surrogate.
    """
    def __init__(self, alpha=4., requires_grad=False):
        super().__init__(alpha, requires_grad)
    @staticmethod
    def act_fun(x, alpha):
        # `sigmoid` is braincog's surrogate autograd Function (imported via
        # the star import above); alpha controls gradient sharpness.
        return sigmoid.apply(x, alpha)
================================================
FILE: examples/TIM/utils/MyNode.py
================================================
from braincog.base.node.node import *
from braincog.base.connection.layer import *
from braincog.base.strategy.surrogate import *
from utils.MyGrad import MyGrad
class MyBaseNode(BaseNode):
    """BaseNode extended with layout conversion for Transformer tensors.

    Adds a 3-D ``(t b) n c`` branch to the braincog layer-by-layer
    rearrangements so spiking nodes can sit inside attention blocks.
    """

    def __init__(self, threshold=0.5, step=10, layer_by_layer=False, mem_detach=False):
        super().__init__(threshold=threshold, step=step,
                         layer_by_layer=layer_by_layer, mem_detach=mem_detach)

    def rearrange2node(self, inputs):
        """Reshape an op-layout tensor into the (T, B, ...) node layout."""
        ndim = len(inputs.shape)
        if self.groups != 1:
            patterns = {
                4: 'b (c t) w h -> t b c w h',
                2: 'b (c t) -> t b c',
            }
            if ndim not in patterns:
                raise NotImplementedError
            return rearrange(inputs, patterns[ndim], t=self.step)
        if self.layer_by_layer:
            patterns = {
                4: '(t b) c w h -> t b c w h',
                # Transformer (T, B, N, C) support added to the braincog base.
                3: '(t b) n c -> t b n c',
                2: '(t b) c -> t b c',
            }
            if ndim not in patterns:
                raise NotImplementedError
            return rearrange(inputs, patterns[ndim], t=self.step)
        return inputs

    def rearrange2op(self, inputs):
        """Reshape a (T, B, ...) node tensor back into the op layout."""
        ndim = len(inputs.shape)
        if self.groups != 1:
            patterns = {
                5: 't b c w h -> b (c t) w h',
                3: ' t b c -> b (c t)',
            }
            if ndim not in patterns:
                raise NotImplementedError
            return rearrange(inputs, patterns[ndim])
        if self.layer_by_layer:
            patterns = {
                5: 't b c w h -> (t b) c w h',
                # Transformer (T, B, N, C) support added to the braincog base.
                4: ' t b n c -> (t b) n c',
                3: ' t b c -> (t b) c',
            }
            if ndim not in patterns:
                raise NotImplementedError
            return rearrange(inputs, patterns[ndim])
        return inputs
class MyNode(MyBaseNode):
    """LIF neuron running layer-by-layer with a sigmoid surrogate gradient.

    The membrane decays toward the input with time constant ``tau``; firing
    soft-resets the membrane to zero.
    """

    def __init__(self, threshold=1., step=10, layer_by_layer=True, tau=2., act_fun=MyGrad, mem_detach=True, *args, **kwargs):
        super().__init__(threshold=threshold, step=step,
                         layer_by_layer=layer_by_layer, mem_detach=mem_detach)
        self.tau = tau
        # HACK: string act_fun names are resolved with eval(); only safe for
        # trusted, hard-coded configuration values -- never user input.
        if isinstance(act_fun, str):
            act_fun = eval(act_fun)
        self.act_fun = act_fun(alpha=4., requires_grad=False)

    def integral(self, inputs):
        # Leaky integration: move the membrane toward the input by 1/tau.
        self.mem = self.mem + (inputs - self.mem) / self.tau

    def calc_spike(self):
        # Fire through the surrogate, then reset fired membranes to zero;
        # detach() keeps the reset term out of the backward graph.
        fired = self.act_fun(self.mem - self.threshold)
        self.spike = fired
        self.mem = self.mem * (1 - fired.detach())
================================================
FILE: examples/TIM/utils/datasets.py
================================================
import os
import warnings
import random
import torchvision.datasets
import braincog.datasets.ucf101_dvs
try:
import tonic
from tonic import DiskCachedDataset
except:
warnings.warn("tonic should be installed, 'pip install git+https://github.com/FloyedShen/tonic.git'")
import torch
import torch.nn.functional as F
import torch.utils
import torchvision.datasets as datasets
from timm.data import ImageDataset, create_loader, Mixup, FastCollateMixup, AugMixDataset
from timm.data import create_transform, distributed_sampler
from timm.data.loader import PrefetchLoader
from tonic import DiskCachedDataset
from torchvision import transforms
from typing import Any, Dict, Optional, Sequence, Tuple, Union
from braincog.datasets.NOmniglot.nomniglot_full import NOmniglotfull
from braincog.datasets.NOmniglot.nomniglot_nw_ks import NOmniglotNWayKShot
from braincog.datasets.NOmniglot.nomniglot_pair import NOmniglotTrainSet, NOmniglotTestSet
# from braincog.base.conversion.conversion import CIFAR10Policy, Cutout
# from .cut_mix import CutMix, EventMix, MixUp
# from .rand_aug import *
# from .event_drop import event_drop
# from .utils import dvs_channel_check_expend, rescale
# Per-channel mean/std of 16-step DVS-CIFAR10 frames (for Normalize).
DVSCIFAR10_MEAN_16 = [0.3290, 0.4507]
DVSCIFAR10_STD_16 = [1.8398, 1.6549]
# Root directory all dataset loaders below resolve their paths against.
DATA_DIR = '/data/datasets'
DEFAULT_CROP_PCT = 0.875
# Standard normalization statistics for the static-image datasets.
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255)
IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3)
CIFAR10_DEFAULT_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_DEFAULT_STD = (0.2023, 0.1994, 0.2010)
CIFAR100_DEFAULT_MEAN = (0.5071, 0.4867, 0.4408)
CIFAR100_DEFAULT_STD = (0.2675, 0.2565, 0.2761)
def unpack_mix_param(args):
    """Extract mixing-augmentation settings from a kwargs dict.

    :param args: dict of optional keyword arguments
    :return: (mix_up, cut_mix, event_mix, beta, prob, num, num_classes,
              noise, gaussian_n), with defaults substituted for missing keys
    """
    # dict.get replaces the original `args[k] if k in args else default`
    # chains: single lookup per key, identical behavior.
    mix_up = args.get('mix_up', False)
    cut_mix = args.get('cut_mix', False)
    event_mix = args.get('event_mix', False)
    beta = args.get('beta', 1.)
    prob = args.get('prob', .5)
    num = args.get('num', 1)
    num_classes = args.get('num_classes', 10)
    noise = args.get('noise', 0.)
    gaussian_n = args.get('gaussian_n', None)
    return mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n
def build_transform(is_train, img_size):
    """
    Build the augmentation pipeline for static-image data.
    :param is_train: whether this is the training split
    :param img_size: output image size
    :return: augmentation policy (a torchvision transform)
    """
    resize_im = img_size > 32
    if is_train:
        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=img_size,
            is_training=True,
            color_jitter=0.4,
            # auto_augment='rand-m9-mstd0.5-inc1',
            interpolation='bicubic',
            # re_prob=0.25,
            # re_mode='pixel',
            # re_count=1,
        )
        if not resize_im:
            # replace RandomResizedCropAndInterpolation with
            # RandomCrop
            transform.transforms[0] = transforms.RandomCrop(
                img_size, padding=4)
        return transform
    # Evaluation pipeline: resize (ImageNet-style crop ratio), center-crop,
    # tensor conversion, then dataset-appropriate normalization.
    t = []
    if resize_im:
        size = int((256 / 224) * img_size)
        t.append(
            # to maintain same ratio w.r.t. 224 images
            transforms.Resize(size, interpolation=3),  # 3 == PIL bicubic
        )
    t.append(transforms.CenterCrop(img_size))
    t.append(transforms.ToTensor())
    # Small (<=32 px) inputs are assumed to be CIFAR-like.
    if img_size > 32:
        t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
    else:
        t.append(transforms.Normalize(CIFAR10_DEFAULT_MEAN, CIFAR10_DEFAULT_STD))
    return transforms.Compose(t)
def build_dataset(is_train, img_size, dataset, path, same_da=False):
    """
    Build an augmented torchvision dataset.
    :param is_train: whether this is the training split
    :param img_size: output image size
    :param dataset: dataset name ('CIFAR10' or 'CIFAR100')
    :param path: dataset root directory
    :param same_da: use the test-split augmentation for the training set
    :return: (dataset, number of classes)
    """
    # same_da forces the evaluation-time transform even for training data.
    transform = build_transform(False if same_da else is_train, img_size)
    if dataset == 'CIFAR10':
        return datasets.CIFAR10(path, train=is_train, transform=transform, download=True), 10
    if dataset == 'CIFAR100':
        return datasets.CIFAR100(path, train=is_train, transform=transform, download=True), 100
    raise NotImplementedError
class MNISTData(object):
    """
    Load MNIST datesets.
    """

    def __init__(self,
                 data_path: str,
                 batch_size: int,
                 train_trans: Sequence[torch.nn.Module] = None,
                 test_trans: Sequence[torch.nn.Module] = None,
                 pin_memory: bool = True,
                 drop_last: bool = True,
                 shuffle: bool = True,
                 ) -> None:
        self._data_path = data_path
        self._batch_size = batch_size
        self._pin_memory = pin_memory
        self._drop_last = drop_last
        self._shuffle = shuffle
        # Compose only when transform lists were supplied; None otherwise.
        self._train_transform = transforms.Compose(train_trans) if train_trans else None
        self._test_transform = transforms.Compose(test_trans) if test_trans else None

    def get_data_loaders(self):
        """Build (train, test) DataLoaders using the configured transforms."""
        print('Batch size: ', self._batch_size)

        def make_split(train, transform):
            # Downloads on first use; afterwards reads from _data_path.
            return datasets.MNIST(root=self._data_path, train=train,
                                  transform=transform, download=True)

        train_loader = torch.utils.data.DataLoader(
            make_split(True, self._train_transform),
            batch_size=self._batch_size, pin_memory=self._pin_memory,
            drop_last=self._drop_last, shuffle=self._shuffle)
        test_loader = torch.utils.data.DataLoader(
            make_split(False, self._test_transform),
            batch_size=self._batch_size, pin_memory=self._pin_memory,
            drop_last=False)
        return train_loader, test_loader

    def get_standard_data(self):
        """Use canonical MNIST normalization plus random-crop training aug."""
        mean, std = (0.1307,), (0.3081,)
        self._train_transform = transforms.Compose([
            transforms.RandomCrop(28, padding=4),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        self._test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        return self.get_data_loaders()
def get_mnist_data(batch_size, num_workers=8, same_da=False, **kwargs):
    """
    Get MNIST data loaders.
    http://data.pymvpa.org/datasets/mnist/
    :param batch_size: batch size
    :param same_da: use the test-set augmentation for the training set
    :param kwargs: optional flags; only 'skip_norm' is read here
    :return: (train loader, test loader, mixup_active, mixup_fn)
    """
    MNIST_MEAN = 0.1307
    MNIST_STD = 0.3081
    if 'skip_norm' in kwargs and kwargs['skip_norm'] is True:
        # NOTE(review): `rescale` is not defined in this module (its import is
        # commented out at the top of the file), so this branch raises
        # NameError when skip_norm=True -- restore the import to use it.
        train_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(rescale)
        ])
        test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(rescale)
        ])
    else:
        train_transform = transforms.Compose([transforms.RandomCrop(28, padding=4),
                                              # transforms.RandomRotation(10),
                                              transforms.ToTensor(),
                                              transforms.Normalize((MNIST_MEAN,), (MNIST_STD,))])
        test_transform = transforms.Compose([transforms.ToTensor(),
                                             transforms.Normalize((MNIST_MEAN,), (MNIST_STD,))])
    train_datasets = datasets.MNIST(
        root=DATA_DIR, train=True, transform=test_transform if same_da else train_transform, download=True)
    test_datasets = datasets.MNIST(
        root=DATA_DIR, train=False, transform=test_transform, download=True)
    train_loader = torch.utils.data.DataLoader(
        train_datasets, batch_size=batch_size,
        pin_memory=True, drop_last=True, shuffle=True, num_workers=num_workers
    )
    test_loader = torch.utils.data.DataLoader(
        test_datasets, batch_size=batch_size,
        pin_memory=True, drop_last=False, num_workers=num_workers
    )
    return train_loader, test_loader, False, None
def get_fashion_data(batch_size, num_workers=8, same_da=False, **kwargs):
    """
    Get Fashion-MNIST data loaders.
    http://arxiv.org/abs/1708.07747
    :param batch_size: batch size
    :param same_da: use the test-set augmentation for the training set
    :param kwargs: unused
    :return: (train loader, test loader, mixup_active, mixup_fn)
    """
    test_transform = transforms.Compose([transforms.ToTensor()])
    train_transform = transforms.Compose([
        transforms.RandomCrop(28, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
    ])
    train_set = datasets.FashionMNIST(
        root=DATA_DIR, train=True, download=True,
        transform=test_transform if same_da else train_transform)
    test_set = datasets.FashionMNIST(
        root=DATA_DIR, train=False, download=True, transform=test_transform)
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, shuffle=True,
        pin_memory=True, drop_last=True, num_workers=num_workers)
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=batch_size,
        pin_memory=True, drop_last=False, num_workers=num_workers)
    return train_loader, test_loader, False, None
def get_cifar10_data(batch_size, num_workers=8, same_da=False, **kwargs):
    """
    Get CIFAR10 train/test loaders with AutoAugment + Cutout training aug.
    https://www.cs.toronto.edu/~kriz/cifar.html
    :param batch_size: batch size
    :param same_da: unused by the active code path below
    :param kwargs: unused
    :return: (train loader, test loader, None, None)
    """
    # """
    # 获取CIFAR10数据
    # https://www.cs.toronto.edu/~kriz/cifar.html
    # :param batch_size: batch size
    # :param kwargs:
    # :return: (train loader, test loader, mixup_active, mixup_fn)
    # """
    # train_datasets, _ = build_dataset(True, 32, 'CIFAR10', DATA_DIR, same_da)
    # test_datasets, _ = build_dataset(False, 32, 'CIFAR10', DATA_DIR, same_da)
    #
    # train_loader = torch.utils.data.DataLoader(
    #     train_datasets, batch_size=batch_size,
    #     pin_memory=True, drop_last=True, shuffle=True,
    #     num_workers=num_workers
    # )
    #
    # test_loader = torch.utils.data.DataLoader(
    #     test_datasets, batch_size=batch_size,
    #     pin_memory=True, drop_last=False,
    #     num_workers=num_workers
    # )
    normalize = transforms.Normalize(CIFAR10_DEFAULT_MEAN, CIFAR10_DEFAULT_STD)
    # NOTE(review): CIFAR10Policy and Cutout are referenced below but their
    # import is commented out at the top of this file -- calling this function
    # raises NameError until that import is restored.
    transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(),
                                          CIFAR10Policy(),
                                          transforms.ToTensor(),
                                          Cutout(n_holes=1, length=16),
                                          normalize])
    transform_test = transforms.Compose([transforms.ToTensor(), normalize])
    train_dataset = datasets.CIFAR10(root=DATA_DIR, train=True, download=True, transform=transform_train)
    test_dataset = datasets.CIFAR10(root=DATA_DIR, train=False, download=True, transform=transform_test)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size,
        shuffle=True, num_workers=num_workers,
        pin_memory=True
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size,
        shuffle=False, num_workers=num_workers,
        pin_memory=True
    )
    return train_loader, test_loader, None, None
def get_shd_data(batch_size, step, **kwargs):
    """
    Get SHD (Spiking Heidelberg Digits) data loaders.
    https://ieeexplore.ieee.org/abstract/document/9311226
    :param batch_size: batch size
    :param step: number of simulation time bins
    :param kwargs: unused
    :return: (train loader, test loader, mixup_active, mixup_fn)
    :format: (b,t,c,len) -- unlike vision data, audio has c == 1 and no h/w,
        only len == 700
    """
    sensor_size = tonic.datasets.SHD.sensor_size
    root = os.path.join(DATA_DIR, 'DVS/SHD')
    # Events are binned into `step` frames; denoise/drop augmentations from
    # the original template stay disabled.
    train_transform = transforms.Compose([
        # tonic.transforms.Denoise(filter_time=10000),
        # tonic.transforms.DropEvent(p=0.1),
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step),
    ])
    test_transform = transforms.Compose([
        # tonic.transforms.Denoise(filter_time=10000),
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step),
    ])
    train_dataset = tonic.datasets.SHD(root, transform=train_transform, train=True)
    test_dataset = tonic.datasets.SHD(root, transform=test_transform, train=False)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        pin_memory=True, drop_last=True, num_workers=8)
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False,
        pin_memory=True, drop_last=False, num_workers=2)
    return train_loader, test_loader, None, None
def get_cifar100_data(batch_size, num_workers=8, same_data=False, *args, **kwargs):
    """
    Get CIFAR100 train/test loaders with AutoAugment + Cutout training aug.
    https://www.cs.toronto.edu/~kriz/cifar.html
    :param batch_size: batch size
    :param same_data: unused by the active code path below
    :param kwargs: unused
    :return: (train loader, test loader, None, None)
    """
    # """
    # 获取CIFAR100数据
    # https://www.cs.toronto.edu/~kriz/cifar.html
    # :param batch_size: batch size
    # :param kwargs:
    # :return: (train loader, test loader, mixup_active, mixup_fn)
    # """
    # train_datasets, _ = build_dataset(True, 32, 'CIFAR100', DATA_DIR, same_data)
    # test_datasets, _ = build_dataset(False, 32, 'CIFAR100', DATA_DIR, same_data)
    #
    # train_loader = torch.utils.data.DataLoader(
    #     train_datasets, batch_size=batch_size,
    #     pin_memory=True, drop_last=True, shuffle=True, num_workers=num_workers
    # )
    #
    # test_loader = torch.utils.data.DataLoader(
    #     test_datasets, batch_size=batch_size,
    #     pin_memory=True, drop_last=False, num_workers=num_workers
    # )
    # return train_loader, test_loader, False, None
    normalize = transforms.Normalize(CIFAR100_DEFAULT_MEAN, CIFAR100_DEFAULT_STD)
    # NOTE(review): CIFAR10Policy and Cutout are referenced below but their
    # import is commented out at the top of this file -- calling this function
    # raises NameError until that import is restored.
    transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(),
                                          CIFAR10Policy(),
                                          transforms.ToTensor(),
                                          Cutout(n_holes=1, length=16),
                                          normalize])
    transform_test = transforms.Compose([transforms.ToTensor(), normalize])
    train_dataset = datasets.CIFAR100(root=DATA_DIR, train=True, download=True, transform=transform_train)
    test_dataset = datasets.CIFAR100(root=DATA_DIR, train=False, download=True, transform=transform_test)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size,
        shuffle=True, num_workers=num_workers,
        pin_memory=True
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size,
        shuffle=False, num_workers=num_workers,
        pin_memory=True
    )
    return train_loader, test_loader, None, None
def get_imnet_data(args, _logger, data_config, num_aug_splits, **kwargs):
    """
    Get ImageNet (ILSVRC2012) train/eval loaders via timm's create_loader.
    http://arxiv.org/abs/1409.0575
    :param args: parsed command-line arguments (batch size, aug flags, ...)
    :param _logger: logger used to report missing data directories
    :param data_config: timm data config (input size, mean/std, interpolation)
    :param num_aug_splits: number of augmentation splits
    :param kwargs: unused
    :return: (train loader, test loader, mixup_active, mixup_fn)
    """
    train_dir = os.path.join(DATA_DIR, 'ILSVRC2012/train')
    if not os.path.exists(train_dir):
        _logger.error(
            'Training folder does not exist at: {}'.format(train_dir))
        exit(1)
    dataset_train = ImageDataset(train_dir)
    # collate_fn = None
    # mixup_fn = None
    # mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    # if mixup_active:
    #     mixup_args = dict(
    #         mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
    #         prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
    #         label_smoothing=args.smoothing, num_classes=args.num_classes)
    #     if args.prefetcher:
    #         # collate conflict (need to support deinterleaving in collate mixup)
    #         assert not num_aug_splits
    #         collate_fn = FastCollateMixup(**mixup_args)
    #     else:
    #         mixup_fn = Mixup(**mixup_args)
    # if num_aug_splits > 1:
    #     dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
    # Fall back to the data-config interpolation when training aug is off.
    train_interpolation = args.train_interpolation
    if args.no_aug or not train_interpolation:
        train_interpolation = data_config['interpolation']
    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        no_aug=args.no_aug,
        # re_prob=args.reprob,
        # re_mode=args.remode,
        # re_count=args.recount,
        # re_split=args.resplit,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        # vflip=args.vflip,
        # color_jitter=args.color_jitter,
        # auto_augment=args.aa,
        # num_aug_splits=num_aug_splits,
        interpolation=train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        # collate_fn=collate_fn,
        pin_memory=args.pin_mem,
        # use_multi_epochs_loader=args.use_multi_epochs_loader
    )
    # The validation split may live under 'val' or 'validation'.
    eval_dir = os.path.join(DATA_DIR, 'ILSVRC2012/val')
    if not os.path.isdir(eval_dir):
        eval_dir = os.path.join(DATA_DIR, 'ILSVRC2012/validation')
        if not os.path.isdir(eval_dir):
            _logger.error(
                'Validation folder does not exist at: {}'.format(eval_dir))
            exit(1)
    dataset_eval = ImageDataset(eval_dir)
    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size_multiplier * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
    )
    return loader_train, loader_eval, None, None
def get_dvsg_data(batch_size, step, **kwargs):
    """
    Get DVS Gesture data loaders.
    DOI: 10.1109/CVPR.2017.781
    :param batch_size: batch size
    :param step: number of simulation time bins
    :param kwargs: optional 'size', 'rand_aug'/'randaug_n'/'randaug_m', plus
        the mixing params consumed by unpack_mix_param
    :return: (train loader, test loader, mixup_active, mixup_fn)

    NOTE(review): dvs_channel_check_expend, RandAugment, CutMix, EventMix and
    MixUp are referenced below but their imports are commented out at the top
    of this file -- those code paths raise NameError until restored.
    """
    sensor_size = tonic.datasets.DVSGesture.sensor_size
    size = kwargs['size'] if 'size' in kwargs else 48
    # Event-level transforms: bin events into `step` frames.
    train_transform = transforms.Compose([
        # tonic.transforms.Denoise(filter_time=10000),
        # tonic.transforms.DropEvent(p=0.1),
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step),
    ])
    test_transform = transforms.Compose([
        # tonic.transforms.Denoise(filter_time=10000),
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step),
    ])
    train_dataset = tonic.datasets.DVSGesture(os.path.join(DATA_DIR, 'DVS/DVSGesture'),
                                              transform=train_transform, train=True)
    test_dataset = tonic.datasets.DVSGesture(os.path.join(DATA_DIR, 'DVS/DVSGesture'),
                                             transform=test_transform, train=False)
    # Frame-level transforms applied after disk caching: resize + crop aug.
    train_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
        lambda x: dvs_channel_check_expend(x),
        transforms.RandomCrop(size, padding=size // 12),
        # lambda x: event_drop(x),
        # transforms.RandomHorizontalFlip(),
        # transforms.RandomRotation(15)
    ])
    test_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
        lambda x: dvs_channel_check_expend(x),
    ])
    if 'rand_aug' in kwargs.keys():
        if kwargs['rand_aug'] is True:
            n = kwargs['randaug_n']
            m = kwargs['randaug_m']
            train_transform.transforms.insert(2, RandAugment(m=m, n=n))
    # if 'temporal_flatten' in kwargs.keys():
    #     if kwargs['temporal_flatten'] is True:
    #         train_transform.transforms.insert(-1, lambda x: temporal_flatten(x))
    #         test_transform.transforms.insert(-1, lambda x: temporal_flatten(x))
    # Cache binned frames on disk so event binning only runs once per step.
    train_dataset = DiskCachedDataset(train_dataset,
                                      cache_path=os.path.join(DATA_DIR, 'DVS/DVSGesture/train_cache_{}'.format(step)),
                                      transform=train_transform, num_copies=3)
    test_dataset = DiskCachedDataset(test_dataset,
                                     cache_path=os.path.join(DATA_DIR, 'DVS/DVSGesture/test_cache_{}'.format(step)),
                                     transform=test_transform, num_copies=3)
    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)
    mixup_active = cut_mix | event_mix | mix_up
    if cut_mix:
        train_dataset = CutMix(train_dataset,
                               beta=beta,
                               prob=prob,
                               num_mix=num,
                               num_class=num_classes,
                               noise=noise)
    if event_mix:
        train_dataset = EventMix(train_dataset,
                                 beta=beta,
                                 prob=prob,
                                 num_mix=num,
                                 num_class=num_classes,
                                 noise=noise,
                                 gaussian_n=gaussian_n)
    if mix_up:
        train_dataset = MixUp(train_dataset,
                              beta=beta,
                              prob=prob,
                              num_mix=num,
                              num_class=num_classes,
                              noise=noise)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size,
        pin_memory=True, drop_last=True, num_workers=8,
        shuffle=True,
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size,
        pin_memory=True, drop_last=False, num_workers=2,
        shuffle=False,
    )
    return train_loader, test_loader, mixup_active, None
def get_dvsc10_data(batch_size, step, **kwargs):
    """
    Build train/test DataLoaders for the CIFAR10-DVS event dataset.
    http://journal.frontiersin.org/article/10.3389/fnins.2017.00309/full
    :param batch_size: batch size
    :param step: number of simulation time steps (event frames per sample)
    :param kwargs: optional keys: 'size' (frame side length, default 48),
        'portion' (train split ratio, default .9),
        'rand_aug'/'randaug_n'/'randaug_m', plus the mix-augmentation keys
        consumed by unpack_mix_param
    :return: (train loader, test loader, mixup_active, mixup_fn)
    """
    size = kwargs['size'] if 'size' in kwargs else 48
    sensor_size = tonic.datasets.CIFAR10DVS.sensor_size
    # Event-level transform: bin the raw event stream into `step` frames.
    train_transform = transforms.Compose([
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])
    test_transform = transforms.Compose([
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])
    train_dataset = tonic.datasets.CIFAR10DVS(os.path.join(DATA_DIR, 'DVS/DVS_Cifar10'), transform=train_transform)
    test_dataset = tonic.datasets.CIFAR10DVS(os.path.join(DATA_DIR, 'DVS/DVS_Cifar10'), transform=test_transform)
    # Frame-level transforms (applied after the disk cache): to-tensor,
    # resize, then spatial augmentation for the training split only.
    train_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
        transforms.RandomCrop(size, padding=size // 12),
        transforms.RandomHorizontalFlip(),
    ])
    test_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
    ])
    if 'rand_aug' in kwargs.keys():
        if kwargs['rand_aug'] is True:
            n = kwargs['randaug_n']
            m = kwargs['randaug_m']
            # Insert after the resize lambda so RandAugment sees resized frames.
            train_transform.transforms.insert(2, RandAugment(m=m, n=n))
    # Cache the binned frames on disk so event->frame conversion runs once.
    train_dataset = DiskCachedDataset(train_dataset,
                                      cache_path=os.path.join(DATA_DIR, 'DVS/DVS_Cifar10/train_cache_{}'.format(step)),
                                      transform=train_transform)
    test_dataset = DiskCachedDataset(test_dataset,
                                     cache_path=os.path.join(DATA_DIR, 'DVS/DVS_Cifar10/test_cache_{}'.format(step)),
                                     transform=test_transform)
    num_train = len(train_dataset)
    num_per_cls = num_train // 10
    indices_train, indices_test = [], []
    portion = kwargs['portion'] if 'portion' in kwargs else .9
    # CIFAR10-DVS ships without an official split: take the first `portion`
    # of each class for training and the remainder for testing.
    # NOTE(review): assumes samples are ordered by class with equal counts
    # per class -- confirm against the tonic dataset layout.
    for i in range(10):
        indices_train.extend(
            list(range(i * num_per_cls, round(i * num_per_cls + num_per_cls * portion))))
        indices_test.extend(
            list(range(round(i * num_per_cls + num_per_cls * portion), (i + 1) * num_per_cls)))
    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)
    mixup_active = cut_mix | event_mix | mix_up
    # Optional sample-mixing wrappers; each re-wraps the training dataset.
    if cut_mix:
        train_dataset = CutMix(train_dataset,
                               beta=beta,
                               prob=prob,
                               num_mix=num,
                               num_class=num_classes,
                               indices=indices_train,
                               noise=noise)
    if event_mix:
        train_dataset = EventMix(train_dataset,
                                 beta=beta,
                                 prob=prob,
                                 num_mix=num,
                                 num_class=num_classes,
                                 indices=indices_train,
                                 noise=noise,
                                 gaussian_n=gaussian_n)
    if mix_up:
        train_dataset = MixUp(train_dataset,
                              beta=beta,
                              prob=prob,
                              num_mix=num,
                              num_class=num_classes,
                              indices=indices_train,
                              noise=noise)
    # Samplers (not shuffle=) restrict each loader to its index subset.
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices_train),
        pin_memory=True, drop_last=True, num_workers=8
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices_test),
        pin_memory=True, drop_last=False, num_workers=2
    )
    return train_loader, test_loader, mixup_active, None
def get_nmnist_data(batch_size, step, **kwargs):
    """
    Build train/test DataLoaders for the N-MNIST event dataset.
    (Docstring previously said "DVS CIFAR10" -- copy-paste error.)
    :param batch_size: batch size
    :param step: number of simulation time steps (event frames per sample)
    :param kwargs: optional keys: 'size' (frame side length, default 48),
        'portion' (train split ratio, default .9),
        'rand_aug'/'randaug_n'/'randaug_m', plus the mix-augmentation keys
        consumed by unpack_mix_param
    :return: (train loader, test loader, mixup_active, mixup_fn)
    """
    size = kwargs['size'] if 'size' in kwargs else 48
    sensor_size = tonic.datasets.NMNIST.sensor_size
    # Event-level transform: bin the raw event stream into `step` frames.
    train_transform = transforms.Compose([
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])
    test_transform = transforms.Compose([
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])
    # NOTE(review): no train= flag is passed, so both wrap the same default
    # tonic split and the portion-based index split below is used instead of
    # the dataset's official train/test split -- confirm this is intended.
    train_dataset = tonic.datasets.NMNIST(os.path.join(DATA_DIR, 'DVS/NMNIST'), transform=train_transform)
    test_dataset = tonic.datasets.NMNIST(os.path.join(DATA_DIR, 'DVS/NMNIST'), transform=test_transform)
    # Frame-level transforms (applied after the disk cache).
    train_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
        transforms.RandomCrop(size, padding=size // 12),
        transforms.RandomHorizontalFlip(),
    ])
    test_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
    ])
    if 'rand_aug' in kwargs.keys():
        if kwargs['rand_aug'] is True:
            n = kwargs['randaug_n']
            m = kwargs['randaug_m']
            # Insert after the resize lambda so RandAugment sees resized frames.
            train_transform.transforms.insert(2, RandAugment(m=m, n=n))
    # Cache the binned frames on disk so event->frame conversion runs once.
    train_dataset = DiskCachedDataset(train_dataset,
                                      cache_path=os.path.join(DATA_DIR, 'DVS/NMNIST/train_cache_{}'.format(step)),
                                      transform=train_transform)
    test_dataset = DiskCachedDataset(test_dataset,
                                     cache_path=os.path.join(DATA_DIR, 'DVS/NMNIST/test_cache_{}'.format(step)),
                                     transform=test_transform)
    num_train = len(train_dataset)
    num_per_cls = num_train // 10
    indices_train, indices_test = [], []
    portion = kwargs['portion'] if 'portion' in kwargs else .9
    # Per-class split: first `portion` of each class trains, rest tests.
    # NOTE(review): assumes class-ordered samples with equal class counts.
    for i in range(10):
        indices_train.extend(
            list(range(i * num_per_cls, round(i * num_per_cls + num_per_cls * portion))))
        indices_test.extend(
            list(range(round(i * num_per_cls + num_per_cls * portion), (i + 1) * num_per_cls)))
    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)
    mixup_active = cut_mix | event_mix | mix_up
    # Optional sample-mixing wrappers; each re-wraps the training dataset.
    if cut_mix:
        train_dataset = CutMix(train_dataset,
                               beta=beta,
                               prob=prob,
                               num_mix=num,
                               num_class=num_classes,
                               indices=indices_train,
                               noise=noise)
    if event_mix:
        train_dataset = EventMix(train_dataset,
                                 beta=beta,
                                 prob=prob,
                                 num_mix=num,
                                 num_class=num_classes,
                                 indices=indices_train,
                                 noise=noise,
                                 gaussian_n=gaussian_n)
    if mix_up:
        train_dataset = MixUp(train_dataset,
                              beta=beta,
                              prob=prob,
                              num_mix=num,
                              num_class=num_classes,
                              indices=indices_train,
                              noise=noise)
    # Samplers (not shuffle=) restrict each loader to its index subset.
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices_train),
        pin_memory=True, drop_last=True, num_workers=8
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices_test),
        pin_memory=True, drop_last=False, num_workers=2
    )
    return train_loader, test_loader, mixup_active, None
def get_NCALTECH101_data(batch_size, step, **kwargs):
    """
    Build train/test DataLoaders for the N-Caltech101 event dataset.
    http://journal.frontiersin.org/Article/10.3389/fnins.2015.00437/abstract
    :param batch_size: batch size
    :param step: number of simulation time steps (event frames per sample)
    :param kwargs: optional keys: 'size' (frame side length, default 48),
        'portion' (train split ratio, default .9),
        'rand_aug'/'randaug_n'/'randaug_m', plus the mix-augmentation keys
        consumed by unpack_mix_param
    :return: (train loader, test loader, mixup_active, mixup_fn)
    """
    sensor_size = braincog.datasets.ncaltech101.NCALTECH101.sensor_size
    cls_count = braincog.datasets.ncaltech101.NCALTECH101.cls_count
    dataset_length = braincog.datasets.ncaltech101.NCALTECH101.length
    portion = kwargs['portion'] if 'portion' in kwargs else .9
    size = kwargs['size'] if 'size' in kwargs else 48
    # N-Caltech101 is class-imbalanced; build a per-sample weight list so
    # WeightedRandomSampler rebalances classes. Test samples get weight 0
    # so they are never drawn by the train sampler.
    train_sample_weight = []
    train_sample_index = []
    train_count = 0
    test_sample_index = []
    idx_begin = 0
    for count in cls_count:
        # Inverse-frequency weight: rarer classes are sampled more often.
        sample_weight = dataset_length / count
        train_sample = round(portion * count)
        test_sample = count - train_sample
        train_count += train_sample
        train_sample_weight.extend(
            [sample_weight] * train_sample
        )
        train_sample_weight.extend(
            [0.] * test_sample
        )
        train_sample_index.extend(
            list((range(idx_begin, idx_begin + train_sample)))
        )
        test_sample_index.extend(
            list(range(idx_begin + train_sample, idx_begin + train_sample + test_sample))
        )
        idx_begin += count
    train_sampler = torch.utils.data.sampler.WeightedRandomSampler(train_sample_weight, train_count)
    test_sampler = torch.utils.data.sampler.SubsetRandomSampler(test_sample_index)
    # Event-level transform: bin the raw event stream into `step` frames.
    train_transform = transforms.Compose([
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])
    test_transform = transforms.Compose([
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])
    train_dataset = braincog.datasets.ncaltech101.NCALTECH101(os.path.join(DATA_DIR, 'DVS/NCALTECH101'), transform=train_transform)
    test_dataset = braincog.datasets.ncaltech101.NCALTECH101(os.path.join(DATA_DIR, 'DVS/NCALTECH101'), transform=test_transform)
    # Frame-level transforms (applied after the disk cache).
    train_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
        transforms.RandomCrop(size, padding=size // 12),
    ])
    test_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
    ])
    if 'rand_aug' in kwargs.keys():
        if kwargs['rand_aug'] is True:
            n = kwargs['randaug_n']
            m = kwargs['randaug_m']
            # Insert after the resize lambda so RandAugment sees resized frames.
            train_transform.transforms.insert(2, RandAugment(m=m, n=n))
    # Cache the binned frames on disk; num_copies=3 stores several augmented
    # variants per sample.
    train_dataset = DiskCachedDataset(train_dataset,
                                      cache_path=os.path.join(DATA_DIR, 'DVS/NCALTECH101/train_cache_{}'.format(step)),
                                      transform=train_transform, num_copies=3)
    test_dataset = DiskCachedDataset(test_dataset,
                                     cache_path=os.path.join(DATA_DIR, 'DVS/NCALTECH101/test_cache_{}'.format(step)),
                                     transform=test_transform, num_copies=3)
    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)
    mixup_active = cut_mix | event_mix | mix_up
    # Optional sample-mixing wrappers, restricted to training indices.
    if cut_mix:
        train_dataset = CutMix(train_dataset,
                               beta=beta,
                               prob=prob,
                               num_mix=num,
                               num_class=num_classes,
                               indices=train_sample_index,
                               noise=noise)
    if event_mix:
        train_dataset = EventMix(train_dataset,
                                 beta=beta,
                                 prob=prob,
                                 num_mix=num,
                                 num_class=num_classes,
                                 indices=train_sample_index,
                                 noise=noise,
                                 gaussian_n=gaussian_n)
    if mix_up:
        train_dataset = MixUp(train_dataset,
                              beta=beta,
                              prob=prob,
                              num_mix=num,
                              num_class=num_classes,
                              indices=train_sample_index,
                              noise=noise)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size,
        sampler=train_sampler,
        pin_memory=True, drop_last=True, num_workers=8
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size,
        sampler=test_sampler,
        pin_memory=True, drop_last=False, num_workers=2
    )
    return train_loader, test_loader, mixup_active, None
def get_UCF101DVS_data(batch_size, step, **kwargs):
    """
    Build train/test DataLoaders for the UCF101-DVS event dataset.
    (Docstring previously said "DVS CIFAR10" -- copy-paste error.)
    :param batch_size: batch size
    :param step: number of simulation time steps (event frames per sample)
    :param kwargs: optional keys: 'size' (unused here -- resize is disabled),
        'rand_aug'/'randaug_n'/'randaug_m', plus the mix-augmentation keys
        consumed by unpack_mix_param
    :return: (train loader, test loader, mixup_active, mixup_fn)
    """
    size = kwargs['size'] if 'size' in kwargs else 48
    sensor_size = braincog.datasets.ucf101_dvs.UCF101DVS.sensor_size
    # Event-level transform: bin the raw event stream into `step` frames.
    train_transform = transforms.Compose([
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])
    test_transform = transforms.Compose([
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])
    # UCF101-DVS ships an official split, selected via train=True/False.
    train_dataset = braincog.datasets.ucf101_dvs.UCF101DVS(os.path.join(DATA_DIR, 'DVS/UCF101DVS'), train=True, transform=train_transform)
    test_dataset = braincog.datasets.ucf101_dvs.UCF101DVS(os.path.join(DATA_DIR, 'DVS/UCF101DVS'), train=False, transform=test_transform)
    # Frame-level transforms: no resize/crop here, only horizontal flip.
    train_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
        transforms.RandomHorizontalFlip(),
    ])
    test_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
    ])
    if 'rand_aug' in kwargs.keys():
        if kwargs['rand_aug'] is True:
            n = kwargs['randaug_n']
            m = kwargs['randaug_m']
            # The Compose list holds only two entries, so insert(2, ...)
            # appends RandAugment after RandomHorizontalFlip.
            train_transform.transforms.insert(2, RandAugment(m=m, n=n))
    # NOTE(review): cache path lacks the 'DVS/' prefix used by the dataset
    # path above ('UCF101DVS/...' vs 'DVS/UCF101DVS') -- confirm intended.
    train_dataset = DiskCachedDataset(train_dataset,
                                      cache_path=os.path.join(DATA_DIR, 'UCF101DVS/train_cache_{}'.format(step)),
                                      transform=train_transform)
    test_dataset = DiskCachedDataset(test_dataset,
                                     cache_path=os.path.join(DATA_DIR, 'UCF101DVS/test_cache_{}'.format(step)),
                                     transform=test_transform)
    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)
    mixup_active = cut_mix | event_mix | mix_up
    # Optional sample-mixing wrappers; no index restriction needed because
    # the official split already separates train from test.
    if cut_mix:
        train_dataset = CutMix(train_dataset,
                               beta=beta,
                               prob=prob,
                               num_mix=num,
                               num_class=num_classes,
                               noise=noise)
    if event_mix:
        train_dataset = EventMix(train_dataset,
                                 beta=beta,
                                 prob=prob,
                                 num_mix=num,
                                 num_class=num_classes,
                                 noise=noise,
                                 gaussian_n=gaussian_n)
    if mix_up:
        train_dataset = MixUp(train_dataset,
                              beta=beta,
                              prob=prob,
                              num_mix=num,
                              num_class=num_classes,
                              noise=noise)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        pin_memory=True, drop_last=True, num_workers=8
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False,
        pin_memory=True, drop_last=False, num_workers=2
    )
    return train_loader, test_loader, mixup_active, None
def get_HMDBDVS_data(batch_size, step, **kwargs):
    """
    Build train/test DataLoaders for the HMDB-DVS event dataset.
    :param batch_size: batch size
    :param step: number of simulation time steps (event frames per sample)
    :param kwargs: optional keys: 'size' (unused here -- resize is disabled),
        'rand_aug'/'randaug_n'/'randaug_m', plus the mix-augmentation keys
        consumed by unpack_mix_param
    :return: (train loader, test loader, mixup_active, mixup_fn)
    """
    sensor_size = braincog.datasets.hmdb_dvs.HMDBDVS.sensor_size
    # Event-level transform: bin the raw event stream into `step` frames.
    train_transform = transforms.Compose([
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])
    test_transform = transforms.Compose([
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])
    train_dataset = braincog.datasets.hmdb_dvs.HMDBDVS(os.path.join(DATA_DIR, 'HMDBDVS'), transform=train_transform)
    test_dataset = braincog.datasets.hmdb_dvs.HMDBDVS(os.path.join(DATA_DIR, 'HMDBDVS'), transform=test_transform)
    cls_count = train_dataset.cls_count
    dataset_length = train_dataset.length
    # Split is fixed at 50/50 (the kwargs 'portion' override is disabled).
    portion = .5
    size = kwargs['size'] if 'size' in kwargs else 48
    # Per-sample weights for class rebalancing; test samples get weight 0
    # so the train sampler never draws them.
    train_sample_weight = []
    train_sample_index = []
    train_count = 0
    test_sample_index = []
    idx_begin = 0
    for count in cls_count:
        # Inverse-frequency weight: rarer classes are sampled more often.
        sample_weight = dataset_length / count
        train_sample = round(portion * count)
        test_sample = count - train_sample
        train_count += train_sample
        train_sample_weight.extend(
            [sample_weight] * train_sample
        )
        train_sample_weight.extend(
            [0.] * test_sample
        )
        # Shuffle each class's indices with a fixed seed so the random
        # train/test assignment is reproducible across runs.
        lst = list(range(idx_begin, idx_begin + train_sample + test_sample))
        random.seed(0)
        random.shuffle(lst)
        train_sample_index.extend(
            lst[:train_sample]
        )
        test_sample_index.extend(
            lst[train_sample:train_sample + test_sample]
        )
        idx_begin += count
    train_sampler = torch.utils.data.sampler.WeightedRandomSampler(train_sample_weight, train_count)
    test_sampler = torch.utils.data.sampler.SubsetRandomSampler(test_sample_index)
    # Frame-level transforms: only tensor conversion (resize/crop disabled).
    train_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
    ])
    test_transform = transforms.Compose([
        lambda x: torch.tensor(x, dtype=torch.float),
    ])
    if 'rand_aug' in kwargs.keys():
        if kwargs['rand_aug'] is True:
            n = kwargs['randaug_n']
            m = kwargs['randaug_m']
            # The Compose list holds one entry, so insert(2, ...) appends.
            train_transform.transforms.insert(2, RandAugment(m=m, n=n))
    # Cache the binned frames on disk; num_copies=3 stores several augmented
    # variants per sample.
    train_dataset = DiskCachedDataset(train_dataset,
                                      cache_path=os.path.join(DATA_DIR, 'HMDBDVS/train_cache_{}'.format(step)),
                                      transform=train_transform, num_copies=3)
    test_dataset = DiskCachedDataset(test_dataset,
                                     cache_path=os.path.join(DATA_DIR, 'HMDBDVS/test_cache_{}'.format(step)),
                                     transform=test_transform, num_copies=3)
    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)
    mixup_active = cut_mix | event_mix | mix_up
    # Optional sample-mixing wrappers, restricted to training indices.
    if cut_mix:
        train_dataset = CutMix(train_dataset,
                               beta=beta,
                               prob=prob,
                               num_mix=num,
                               num_class=num_classes,
                               indices=train_sample_index,
                               noise=noise)
    if event_mix:
        train_dataset = EventMix(train_dataset,
                                 beta=beta,
                                 prob=prob,
                                 num_mix=num,
                                 num_class=num_classes,
                                 indices=train_sample_index,
                                 noise=noise,
                                 gaussian_n=gaussian_n)
    if mix_up:
        train_dataset = MixUp(train_dataset,
                              beta=beta,
                              prob=prob,
                              num_mix=num,
                              num_class=num_classes,
                              indices=train_sample_index,
                              noise=noise)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size,
        sampler=train_sampler,
        pin_memory=True, drop_last=True, num_workers=8
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size,
        sampler=test_sampler,
        pin_memory=True, drop_last=False, num_workers=2
    )
    return train_loader, test_loader, mixup_active, None
# def get_NCARS_data(batch_size, step, **kwargs):
# """
# 获取N-Cars数据
# https://ieeexplore.ieee.org/document/8578284/
# :param batch_size: batch size
# :param step: 仿真步长
# :param kwargs:
# :return: (train loader, test loader, mixup_active, mixup_fn)
# """
# sensor_size = tonic.datasets.NCARS.sensor_size
# size = kwargs['size'] if 'size' in kwargs else 48
#
# train_transform = transforms.Compose([
# # tonic.transforms.Denoise(filter_time=10000),
# # tonic.transforms.DropEvent(p=0.1),
# tonic.transforms.ToFrame(sensor_size=None, n_time_bins=step),
# ])
# test_transform = transforms.Compose([
# # tonic.transforms.Denoise(filter_time=10000),
# tonic.transforms.ToFrame(sensor_size=None, n_time_bins=step),
# ])
#
# train_dataset = tonic.datasets.NCARS(os.path.join(DATA_DIR, 'DVS/NCARS'), transform=train_transform, train=True)
# test_dataset = tonic.datasets.NCARS(os.path.join(DATA_DIR, 'DVS/NCARS'), transform=test_transform, train=False)
#
# train_transform = transforms.Compose([
# lambda x: torch.tensor(x, dtype=torch.float),
# lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
# lambda x: dvs_channel_check_expend(x),
# transforms.RandomCrop(size, padding=size // 12),
# transforms.RandomHorizontalFlip(),
# transforms.RandomRotation(15)
# ])
# test_transform = transforms.Compose([
# lambda x: torch.tensor(x, dtype=torch.float),
# lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
# lambda x: dvs_channel_check_expend(x),
# ])
# if 'rand_aug' in kwargs.keys():
# if kwargs['rand_aug'] is True:
# n = kwargs['randaug_n']
# m = kwargs['randaug_m']
# train_transform.transforms.insert(2, RandAugment(m=m, n=n))
#
# # if 'temporal_flatten' in kwargs.keys():
# # if kwargs['temporal_flatten'] is True:
# # train_transform.transforms.insert(-1, lambda x: temporal_flatten(x))
# # test_transform.transforms.insert(-1, lambda x: temporal_flatten(x))
#
# train_dataset = DiskCachedDataset(train_dataset,
# cache_path=os.path.join(DATA_DIR, 'DVS/NCARS/train_cache_{}'.format(step)),
# transform=train_transform, num_copies=3)
# test_dataset = DiskCachedDataset(test_dataset,
# cache_path=os.path.join(DATA_DIR, 'DVS/NCARS/test_cache_{}'.format(step)),
# transform=test_transform, num_copies=3)
#
# mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)
# mixup_active = cut_mix | event_mix | mix_up
#
# if cut_mix:
# train_dataset = CutMix(train_dataset,
# beta=beta,
# prob=prob,
# num_mix=num,
# num_class=num_classes,
# noise=noise)
#
# if event_mix:
# train_dataset = EventMix(train_dataset,
# beta=beta,
# prob=prob,
# num_mix=num,
# num_class=num_classes,
# noise=noise,
# gaussian_n=gaussian_n)
# if mix_up:
# train_dataset = MixUp(train_dataset,
# beta=beta,
# prob=prob,
# num_mix=num,
# num_class=num_classes,
# noise=noise)
#
# train_loader = torch.utils.data.DataLoader(
# train_dataset, batch_size=batch_size,
# pin_memory=True, drop_last=True, num_workers=8,
# shuffle=True,
# )
#
# test_loader = torch.utils.data.DataLoader(
# test_dataset, batch_size=batch_size,
# pin_memory=True, drop_last=False, num_workers=2,
# shuffle=False,
# )
#
# return train_loader, test_loader, mixup_active, None
def get_nomni_data(batch_size, train_portion=1., **kwargs):
    """
    Build train/test DataLoaders for the N-Omniglot dataset.

    :param batch_size: batch size
    :param train_portion: kept for interface compatibility; not used here
    :param kwargs:
        data_mode: one of 'full', 'nkks', 'pair' (default 'full')
        frames_num: frames per sample (default 10)
        data_type: 'event' or 'frequency' (default 'event')
        n_way / k_shot / k_query: few-shot settings ('nkks' and 'pair' modes)
    :return: (train loader, test loader, None, None)
    :raises ValueError: if data_mode is not a supported mode
    """
    data_mode = kwargs["data_mode"] if "data_mode" in kwargs else "full"
    frames_num = kwargs["frames_num"] if "frames_num" in kwargs else 10
    data_type = kwargs["data_type"] if "data_type" in kwargs else "event"
    # Resize frames to 64x64 (the 'pair' datasets use their own resize=105).
    train_transform = transforms.Compose([
        transforms.Resize((64, 64))])
    test_transform = transforms.Compose([
        transforms.Resize((64, 64))])
    if data_mode == "full":
        # Standard classification over the full alphabet set.
        train_datasets = NOmniglotfull(root=os.path.join(DATA_DIR, 'DVS/NOmniglot'), train=True, frames_num=frames_num,
                                       data_type=data_type,
                                       transform=train_transform)
        test_datasets = NOmniglotfull(root=os.path.join(DATA_DIR, 'DVS/NOmniglot'), train=False, frames_num=frames_num,
                                      data_type=data_type,
                                      transform=test_transform)
    elif data_mode == "nkks":
        # N-way K-shot episodic few-shot tasks.
        train_datasets = NOmniglotNWayKShot(os.path.join(DATA_DIR, 'DVS/NOmniglot'),
                                            n_way=kwargs["n_way"],
                                            k_shot=kwargs["k_shot"],
                                            k_query=kwargs["k_query"],
                                            train=True,
                                            frames_num=frames_num,
                                            data_type=data_type,
                                            transform=train_transform)
        test_datasets = NOmniglotNWayKShot(os.path.join(DATA_DIR, 'DVS/NOmniglot'),
                                           n_way=kwargs["n_way"],
                                           k_shot=kwargs["k_shot"],
                                           k_query=kwargs["k_query"],
                                           train=False,
                                           frames_num=frames_num,
                                           data_type=data_type,
                                           transform=test_transform)
    elif data_mode == "pair":
        # Pair/verification protocol; these datasets handle their own
        # resizing (resize=105), so the Compose transforms above are unused.
        train_datasets = NOmniglotTrainSet(root=os.path.join(DATA_DIR, 'DVS/NOmniglot'), use_frame=True,
                                           frames_num=frames_num, data_type=data_type,
                                           use_npz=False, resize=105)
        test_datasets = NOmniglotTestSet(root=os.path.join(DATA_DIR, 'DVS/NOmniglot'), time=2000, way=kwargs["n_way"],
                                         shot=kwargs["k_shot"], use_frame=True,
                                         frames_num=frames_num, data_type=data_type, use_npz=False, resize=105)
    else:
        # Bug fix: this was a silent `pass`, which later crashed with an
        # UnboundLocalError on `train_datasets`. Fail fast with a clear error.
        raise ValueError(
            "unknown data_mode: {!r} (expected 'full', 'nkks' or 'pair')".format(data_mode))
    train_loader = torch.utils.data.DataLoader(
        train_datasets, batch_size=batch_size, num_workers=12,
        pin_memory=True, drop_last=True, shuffle=True
    )
    test_loader = torch.utils.data.DataLoader(
        test_datasets, batch_size=batch_size, num_workers=12,
        pin_memory=True, drop_last=False
    )
    return train_loader, test_loader, None, None
================================================
FILE: examples/decision_making/BDM-SNN/BDM-SNN-UAV.py
================================================
import numpy as np
import torch,os,sys
from torch import nn
from torch.nn import Parameter
import abc
import math
from abc import ABC
import torch.nn.functional as F
import matplotlib.pyplot as plt
#from braincog.base.strategy.surrogate import *
from braincog.base.node.node import IFNode
from braincog.base.learningrule.STDP import STDP,MutliInputSTDP
from braincog.base.connection.CustomLinear import CustomLinear
from braincog.base.brainarea.basalganglia import basalganglia
from braincog.model_zoo.bdmsnn import BDMSNN
from robomaster import robot
import time
def chooseAct(Net, input, weight_trace_d1, weight_trace_d2):
    """
    Run the BDM-SNN on the encoded state until an action neuron fires.

    :param Net: BDM-SNN network; Net(input) returns (output spikes, dw list)
    :param input: input current encoding the current state
    :param weight_trace_d1: accumulated eligibility trace, direct (D1) pathway
    :param weight_trace_d2: accumulated eligibility trace, indirect (D2) pathway
    :return: (chosen action index, D1 trace, D2 trace, Net)

    NOTE: reads the module-level constant `trace_decay`.
    """
    out = None
    for i_train in range(500):
        out, dw = Net(input)
        # R-STDP eligibility traces: exponential decay plus the fresh STDP dw.
        weight_trace_d1 *= trace_decay
        weight_trace_d1 += dw[0][0]
        weight_trace_d2 *= trace_decay
        weight_trace_d2 += dw[1][0]
        # First output spike decides the action.
        if torch.max(out) > 0:
            return torch.argmax(out), weight_trace_d1, weight_trace_d2, Net
    # Bug fix: the original fell off the end and implicitly returned None
    # when no neuron fired within 500 steps, crashing the caller's 4-tuple
    # unpacking. Fall back to the argmax of the last output instead.
    return torch.argmax(out), weight_trace_d1, weight_trace_d2, Net
def updateNet(Net, reward, action, state, weight_trace_d1, weight_trace_d2):
    """
    Apply an R-STDP weight update to both basal-ganglia pathways.

    :param Net: BDM-SNN network
    :param reward: reward obtained after acting
    :param action: action that was executed
    :param state: state the agent was in before acting
    :param weight_trace_d1: accumulated eligibility trace, direct (D1) pathway
    :param weight_trace_d2: accumulated eligibility trace, indirect (D2) pathway
    :return: the updated network

    NOTE: reads the module-level constants `num_state` and `num_action`.
    """
    # Reward mask: 1 everywhere except `reward` at the (state, action) synapse.
    reward_mask = torch.ones((num_state, num_state * num_action), dtype=torch.float)
    reward_mask[state, state * num_action + action] = reward
    # Direct pathway is potentiated by the trace, indirect pathway depressed.
    Net.UpdateWeight(0, state, num_action, reward_mask * weight_trace_d1)
    Net.UpdateWeight(1, state, num_action, -1 * reward_mask * weight_trace_d2)
    return Net
if __name__=="__main__":
    """
    定义无人机 大疆Tello Talent
    定义BDM-SNN网络
    用户自定义状态空间、奖励函数,调用行为选择及网络更新
    """
    # Set up the DJI Tello Talent UAV, build the BDM-SNN decision network,
    # then loop: encode state -> choose action -> fly -> reward -> update.
    # define UAV and take off
    tl_drone = robot.Drone()
    tl_drone.initialize()
    tl_flight = tl_drone.flight
    tl_flight.takeoff().wait_for_completed()
    # define the decision network and R-STDP hyperparameters
    num_state=9
    num_action=2
    weight_exc=1
    weight_inh=-0.5
    trace_decay = 0.8
    DM=BDMSNN(num_state,num_action,weight_exc,weight_inh,"lif")
    # connectivity template: excitatory weight from each state to its actions
    con_matrix1 = torch.zeros((num_state, num_state * num_action), dtype=torch.float)
    for i in range(num_state):
        for j in range(num_action):
            con_matrix1[i, i * num_action + j] = weight_exc
    # eligibility traces for the direct (D1) and indirect (D2) pathways
    weight_trace_d1 = torch.zeros(con_matrix1.shape, dtype=torch.float)
    weight_trace_d2 = torch.zeros(con_matrix1.shape, dtype=torch.float)
    iteration=0
    while iteration < 200:
        input = torch.zeros((num_state), dtype=torch.float)
        # users define the judgestate function; state is hard-coded here
        state=1
        input[state]=2
        action,weight_trace_d1,weight_trace_d2,DM = chooseAct(DM,input,weight_trace_d1,weight_trace_d2)
        # UAV executes the chosen action.
        # NOTE(review): with num_action=2 the action==2 and action==3
        # branches look unreachable -- confirm against BDMSNN's output size.
        if action==0:
            tl_flight.forward(distance=20).wait_for_completed()
        if action == 1:
            # flying left
            tl_flight.rc(a=20, b=0, c=0, d=0)
            time.sleep(4)
        if action == 2:
            # flying right
            tl_flight.rc(a=-20, b=0, c=0, d=0)
            time.sleep(3)
        if action == 3:
            tl_flight.backward(distance=20).wait_for_completed()
        # users define the reward function; constant reward here
        reward =1
        DM=updateNet(DM,reward, action, state,weight_trace_d1,weight_trace_d2)
        # reset traces and network state for the next trial
        weight_trace_d1 = torch.zeros(con_matrix1.shape, dtype=torch.float)
        weight_trace_d2 = torch.zeros(con_matrix1.shape, dtype=torch.float)
        DM.reset()
        iteration += 1
================================================
FILE: examples/decision_making/BDM-SNN/BDM-SNN-hh.py
================================================
import torch,os
from random import randint
import torch
from torch import nn
from braincog.base.strategy.surrogate import *
from braincog.model_zoo.bdmsnn import BDMSNN
import pygame
from pygame.locals import *
from collections import deque
from random import randint
import numpy as np
import matplotlib.pyplot as plt
import random
#os.environ["SDL_VIDEODRIVER"] = "dummy"
def load_images():
    """
    Load all sprite images used by the Flappy Bird demo.

    :return: dict mapping sprite name to its loaded pygame surface
    """
    def _load(img_file_name):
        file_name = os.path.join('.', 'birdimages', img_file_name)
        surface = pygame.image.load(file_name)
        # converting all images before use speeds up blitting
        surface.convert()
        return surface

    # images for animating the flapping bird -- animated GIFs are not
    # supported in pygame, so the flap is two separate frames
    sprite_files = {
        'background': 'background.png',
        'pipe-end': 'pipe_end.png',
        'pipe-body': 'pipe_body.png',
        'bird-wingup': 'bird_wing_up.png',
        'bird-wingdown': 'bird_wing_down.png',
    }
    return {name: _load(fname) for name, fname in sprite_files.items()}
class Bird(pygame.sprite.Sprite):
    """
    The player-controlled bird sprite for the Flappy Bird demo.

    Holds position, climb state and the two wing images/masks used to
    animate flapping.
    """
    WIDTH = HEIGHT = 32
    SINK_SPEED = 0.2
    Fail_SINk_SPEED = 0.6
    CLIMB_SPEED = 0.25
    CLIMB_DURATION = 333.3
    REGION = CLIMB_DURATION / 3
    NEAR_COLLIDE = 30
    NEAR_PIPE = 0

    def __init__(self, x, y, msec_to_climb, images):
        super(Bird, self).__init__()
        self.x, self.y = x, y
        self.msec_to_climb = msec_to_climb
        self._img_wingup, self._img_wingdown = images
        # Pre-compute collision masks once per image.
        self._mask_wingup = pygame.mask.from_surface(self._img_wingup)
        self._mask_wingdown = pygame.mask.from_surface(self._img_wingdown)

    def update(self, action, state, delta_frames=1):
        """
        Advance the bird's vertical position by one tick.

        :param action: 1 means climb (while msec_to_climb > 0), else sink
        :param state: discrete game state; states 2-5 double the speed
        :param delta_frames: elapsed frames since the last update
        :return: None
        """
        frame_ms = 1000.0 * delta_frames / 60
        boosted = state in (2, 3, 4, 5)
        if self.msec_to_climb > 0 and action == 1:
            climb = Bird.CLIMB_SPEED * frame_ms
            self.y -= 2 * climb if boosted else climb
        else:
            sink = Bird.SINK_SPEED * frame_ms
            self.y += 2 * sink if boosted else sink

    def sink(self, delta_frames=1):
        # Faster fall used after the bird has failed.
        self.y += Bird.Fail_SINk_SPEED * (1000.0 * delta_frames / 60)

    @property
    def image(self):
        # Alternate wing frames every 250 ms for the flap animation.
        wing_up = pygame.time.get_ticks() % 500 >= 250
        return self._img_wingup if wing_up else self._img_wingdown

    @property
    def mask(self):
        # Mask matching whichever wing frame is currently shown.
        wing_up = pygame.time.get_ticks() % 500 >= 250
        return self._mask_wingup if wing_up else self._mask_wingdown

    @property
    def rect(self):
        return Rect(self.x, self.y, Bird.WIDTH, Bird.HEIGHT)
class PipePair(pygame.sprite.Sprite):
    """
    A top/bottom pipe pair obstacle for the Flappy Bird demo.

    The pair is drawn onto a single transparent surface the full height of
    the window, with a randomly positioned gap the bird must fly through.
    """
    WIDTH = 80
    PIECE_HEIGHT = 32
    ADD_INTERVAL = 2000
    ADD_EVENT = pygame.USEREVENT + 1
    ROOM_HIGHT = 2 * Bird.HEIGHT + 2 * PIECE_HEIGHT

    def __init__(self, pipe_end_img, pipe_body_img):
        # NOTE(review): pygame.sprite.Sprite.__init__ is never called here --
        # confirm this sprite is not used with pygame sprite Groups.
        # Spawn just inside the right edge of the window.
        self.x = float(WIN_WIDTH - 1)
        self.score_counted = False
        self.isNewPipe = True
        # One transparent surface holds both pipes for a single blit.
        self.image = pygame.Surface((PipePair.WIDTH, WIN_HEIGHT), SRCALPHA)
        self.image.convert()  # speeds up blitting
        self.image.fill((0, 0, 0, 0))
        # Number of body pieces that fit while leaving a bird-sized gap.
        total_pipe_body_pieces = int(
            (WIN_HEIGHT -  # fill window from top to bottom
             3 * Bird.HEIGHT -  # make room for bird to fit through
             3 * PipePair.PIECE_HEIGHT) /  # 2 end pieces + 1 body piece
            PipePair.PIECE_HEIGHT  # to get number of pipe pieces
        )
        # Randomly divide the body pieces between bottom and top pipes,
        # which fixes where the gap ends up.
        self.bottom_pieces = randint(1, total_pipe_body_pieces)
        self.top_pieces = total_pipe_body_pieces - self.bottom_pieces
        # bottom pipe: body pieces stacked up from the window floor
        for i in range(1, self.bottom_pieces + 1):
            piece_pos = (0, WIN_HEIGHT - i * PipePair.PIECE_HEIGHT)
            self.image.blit(pipe_body_img, piece_pos)
        bottom_pipe_end_y = WIN_HEIGHT - self.bottom_height_px
        bottom_end_piece_pos = (0, bottom_pipe_end_y - PipePair.PIECE_HEIGHT)
        self.image.blit(pipe_end_img, bottom_end_piece_pos)
        # top pipe: body pieces stacked down from the window ceiling
        for i in range(self.top_pieces):
            self.image.blit(pipe_body_img, (0, i * PipePair.PIECE_HEIGHT))
        top_pipe_end_y = self.top_height_px
        self.image.blit(pipe_end_img, (0, top_pipe_end_y))
        # Vertical center of the gap between the two pipes.
        self.center = (top_pipe_end_y + bottom_pipe_end_y) / 2
        # compensate for added end pieces
        self.top_pieces += 1
        self.bottom_pieces += 1
        # for collision detection
        self.mask = pygame.mask.from_surface(self.image)
        self.top_y = top_pipe_end_y
        self.bottom_y = bottom_pipe_end_y

    @property
    def top_height_px(self):
        # Pixel height of the top pipe (pieces * piece height).
        return self.top_pieces * PipePair.PIECE_HEIGHT

    @property
    def bottom_height_px(self):
        # Pixel height of the bottom pipe (pieces * piece height).
        return self.bottom_pieces * PipePair.PIECE_HEIGHT

    @property
    def visible(self):
        # True while any part of the pair is inside the window.
        return -PipePair.WIDTH < self.x < WIN_WIDTH

    @property
    def rect(self):
        return Rect(self.x, 0, PipePair.WIDTH, PipePair.PIECE_HEIGHT)

    def update(self, delta_frames=1):
        # Scroll left at a constant speed proportional to elapsed frames.
        self.x -= 0.18 * 1000.0 * delta_frames /60

    def collides_with(self, bird):
        # Pixel-perfect collision via the pre-built masks.
        return pygame.sprite.collide_mask(self, bird)
def chooseAct(Net,input,weight_trace_d1,weight_trace_d2):
    """Run the decision network until an action neuron spikes.

    Repeatedly stimulates the network with the state-encoding input
    current (at most 500 steps) while accumulating the R-STDP
    eligibility traces of the direct (D1) and indirect (D2) pathways.

    :param Net: the BDM-SNN network
    :param input: input current (spike encoding of the current state)
    :param weight_trace_d1: accumulated eligibility trace, direct pathway
    :param weight_trace_d2: accumulated eligibility trace, indirect pathway
    :return: (chosen action, trace d1, trace d2, network)
    """
    out = None
    for i_train in range(500):
        out, dw = Net(input)
        # rstdp: decay the traces, then accumulate the new weight deltas.
        weight_trace_d1 *= trace_decay
        weight_trace_d1 += dw[0][0]
        weight_trace_d2 *= trace_decay
        weight_trace_d2 += dw[1][0]
        if torch.max(out) > 0:
            return torch.argmax(out),weight_trace_d1,weight_trace_d2,Net
    # BUG FIX: the original implicitly returned None when no output neuron
    # spiked within 500 steps, which crashed the caller on tuple unpacking.
    # Fall back to the argmax of the last network output instead.
    return torch.argmax(out),weight_trace_d1,weight_trace_d2,Net
def judgeState(bird, pipes, collide):
    """Discretise the bird/pipe geometry into one of nine states.

    States 0/1: safely on either side of the target centre; 2/3 and 4/5:
    progressively further away; 6/7: about to hit the top/bottom pipe end;
    8: collided.

    :param bird: the Bird sprite
    :param pipes: iterable of PipePair sprites
    :param collide: whether a collision has already happened
    :return: (state, signed vertical distance to the target centre,
        whether the judged pipe is encountered for the first time)
    """
    # bird's x and y coordinate in the left top of the image
    dist = bird.y + Bird.HEIGHT / 2 - WIN_HEIGHT / 2
    isNew = False
    index = -1
    state = -1
    if collide:
        state = 8
        # BUG FIX: the original returned a bare int here although every
        # caller unpacks a 3-tuple; return the full tuple for consistency.
        return state, dist, isNew
    for p in pipes:
        # Skip pipes the bird has already passed (and scored on).
        if p.x + PipePair.WIDTH - Bird.HEIGHT / 4 < bird.x and not p.score_counted:
            continue
        if p.x - Bird.NEAR_PIPE <= bird.x + Bird.HEIGHT and \
                p.x + PipePair.WIDTH - Bird.HEIGHT / 4 >= bird.x:
            # Bird is horizontally inside/near this pipe pair.
            p_top_y = p.top_y + PipePair.PIECE_HEIGHT
            p_bottom_y = p.bottom_y - PipePair.PIECE_HEIGHT
            if p.center - bird.y - Bird.HEIGHT / 2 >= 0 and bird.y >= p_top_y + Bird.NEAR_COLLIDE / 2:
                state = 0
            elif bird.y - p.center + Bird.HEIGHT / 2 > 0 and bird.y + Bird.HEIGHT <= p_bottom_y - Bird.NEAR_COLLIDE / 2:
                state = 1
            elif bird.y < p_top_y + Bird.NEAR_COLLIDE / 2 and bird.y > p_top_y - 10:
                state = 6
            elif bird.y + Bird.HEIGHT > p_bottom_y - Bird.NEAR_COLLIDE / 2 and bird.y + Bird.HEIGHT < p_bottom_y + 10:
                state = 7
            if state > -0.5:
                index = 1
        elif p.x > bird.x + Bird.HEIGHT + Bird.NEAR_PIPE:
            # Pipe still ahead of the bird: judge relative to its gap centre.
            state = blankState(bird, p.center)
            if p.isNewPipe:
                isNew = True
                p.isNewPipe = False
            index = 1
        if index > 0:  # only judge the nearest and not passed pipe
            dist = bird.y + Bird.HEIGHT / 2 - p.center
            break
    if index < -0.5:  # no pipe left, keep the bird in the middle
        pos = WIN_HEIGHT / 2
        dist = bird.y + Bird.HEIGHT / 2 - pos
        state = blankState(bird, pos)
    return state, dist, isNew
def blankState(bird, center):
    """Classify the bird's position relative to ``center`` (no pipe nearby).

    Called from judgeState. Even states mean the bird is on one side of
    the centre line, odd states on the other (mirrored conditions):
    0/1 inside the safe band, 2/3 inside the warning band (width
    Bird.REGION), 4/5 far outside it.

    :param bird: the Bird sprite
    :param center: target vertical centre (pipe-gap centre or window middle)
    :return: state in {0, 1, 2, 3, 4, 5}
    """
    # Half of the usable gap height; NEAR_COLLIDE shrinks the safe band.
    realHeight = (PipePair.ROOM_HIGHT - Bird.HEIGHT) / 2
    if center - bird.y - Bird.HEIGHT / 2 >= 0 and \
            center - bird.y - Bird.HEIGHT / 2 < realHeight - Bird.NEAR_COLLIDE / 2:
        state = 0
    elif bird.y - center + Bird.HEIGHT / 2 >= 0 and \
            bird.y - center + Bird.HEIGHT / 2 < realHeight - Bird.NEAR_COLLIDE / 2:
        state = 1
    elif center - bird.y - Bird.HEIGHT / 2 >= realHeight - Bird.NEAR_COLLIDE / 2 and \
            center - bird.y - Bird.HEIGHT / 2 < realHeight - Bird.NEAR_COLLIDE / 2 + Bird.REGION:
        state = 2
    elif bird.y - center + Bird.HEIGHT / 2 >= realHeight - Bird.NEAR_COLLIDE / 2 and \
            bird.y - center + Bird.HEIGHT / 2 < realHeight - Bird.NEAR_COLLIDE / 2 + Bird.REGION:
        state = 3
    elif bird.y + Bird.HEIGHT / 2 <= center - (realHeight - Bird.NEAR_COLLIDE / 2 + Bird.REGION):
        state = 4
    elif bird.y + Bird.HEIGHT / 2 >= center + realHeight - Bird.NEAR_COLLIDE / 2 + Bird.REGION:
        state = 5
    return state
def getReward(state,lastState,smallerError,isNewPipe):
    """Compute the scalar reward for the transition into ``state``.

    :param state: state after executing the action
    :param lastState: state before executing the action
    :param smallerError: whether the distance error got smaller
    :param isNewPipe: whether a new pipe has just been encountered
    :return: reward value
    """
    if state == 0 or state == 1:
        # Safely inside the gap band.
        return 6
    if state == 8:  # collide
        return -100
    # Per-zone penalties: applied when staying in the zone without
    # reducing the error (stay) and when just entering the zone (enter).
    if state == 2 or state == 3:
        stay_penalty, enter_penalty = -5, -3
    elif state == 4 or state == 5:
        stay_penalty, enter_penalty = -8, -5
    elif state == 6 or state == 7:
        stay_penalty, enter_penalty = -3, -3
    else:
        # BUG FIX: the original left ``reward`` unbound for any state
        # outside 0..8, raising UnboundLocalError at the final return.
        # Treat unknown states as neutral instead.
        return 0
    if lastState == state and not isNewPipe:
        # Staying in the same zone: reward progress, punish drifting away.
        return 3 if smallerError else stay_penalty
    return enter_penalty
def updateNet(Net,reward, action, state,weight_trace_d1,weight_trace_d2):
    """Apply the reward-modulated weight update to the network.

    :param Net: the BDM-SNN network
    :param reward: reward obtained for the last transition
    :param action: the action that was executed
    :param state: the state in which the action was executed
    :param weight_trace_d1: accumulated eligibility trace, direct pathway
    :param weight_trace_d2: accumulated eligibility trace, indirect pathway
    :return: the updated network
    """
    # Modulation matrix: 1 everywhere except the (state, action) synapse
    # that was actually taken, which carries the reward.
    modulation = torch.ones((num_state, num_state * num_action), dtype=torch.float)
    modulation[state, state * num_action + action] = reward
    # Direct (D1) pathway is potentiated, indirect (D2) pathway depressed.
    Net.UpdateWeight(0, state,num_action,modulation * weight_trace_d1)
    Net.UpdateWeight(1, state,num_action,-1 * modulation * weight_trace_d2)
    return Net
if __name__=="__main__":
    """Run the BDM-SNN (simplified HH neurons) on the Flappy Bird game."""
    # Network hyper-parameters: 9 discretised states, 2 actions (flap / sink).
    num_state=9
    num_action=2
    weight_exc=50
    weight_inh=-60
    trace_decay = 0.8
    DM=BDMSNN(num_state,num_action,weight_exc,weight_inh,"hh")
    # State -> (state, action) connection template; only its shape is used
    # below to (re-)initialise the eligibility traces.
    con_matrix1 = torch.zeros((num_state, num_state * num_action), dtype=torch.float)
    for i in range(num_state):
        for j in range(num_action):
            con_matrix1[i, i * num_action + j] = weight_exc
    weight_trace_d1 = torch.zeros(con_matrix1.shape, dtype=torch.float)
    weight_trace_d2 = torch.zeros(con_matrix1.shape, dtype=torch.float)
    # --- pygame / game setup ---
    pygame.init()
    WIN_HEIGHT = 512
    WIN_WIDTH = 284 * 2
    heighest = 0
    display_frame=0
    display_surface = pygame.display.set_mode((WIN_WIDTH, WIN_HEIGHT))
    pygame.display.set_caption('Flappy Bird')
    images = load_images()
    bird = Bird(250, int(WIN_HEIGHT / 2 - Bird.HEIGHT / 2), 2,
                (images['bird-wingup'], images['bird-wingdown']))
    clock = pygame.time.Clock()
    score_font = pygame.font.SysFont(None, 25, bold=True)
    info_font = pygame.font.SysFont(None, 50, bold=True)
    collide = paused = False
    frame_clock = 0
    pipes = deque()
    score = 0
    lastDist = 0
    lastState = 0  # init
    state = lastState
    num=0
    num_reward=[]
    num_score=[]
    # --- main game / learning loop (capped at 30000 iterations) ---
    while not collide:
        num=num+1
        if num>30000:
            break
        input = torch.zeros((num_state), dtype=torch.float)
        clock.tick(60)
        # Decide on a new action every other frame.
        if frame_clock %2==0 or frame_clock==1:
            state, dist, isNewPipe = judgeState(bird, pipes, collide)
            lastState = state
            lastDist = dist
            input[state]=2
            print(input)
            action,weight_trace_d1,weight_trace_d2,DM = chooseAct(DM,input,weight_trace_d1,weight_trace_d2)
            print("state, dist:", state, dist)
            print("state, action:",state,action)
        # Periodically post the event that spawns a new pipe pair.
        if not (paused or frame_clock % (60 * PipePair.ADD_INTERVAL / 1000.0)):
            pygame.event.post(pygame.event.Event(PipePair.ADD_EVENT))
        for e in pygame.event.get():
            if e.type == QUIT or (e.type == KEYUP and e.key == K_ESCAPE):
                collide = True
            elif e.type == KEYUP and e.key in (K_PAUSE, K_p):
                paused = not paused
            elif e.type == PipePair.ADD_EVENT:
                pp = PipePair(images['pipe-end'], images['pipe-body'])
                pipes.append(pp)
        if paused:
            continue  # don't draw anything
        pipe_collision = any(p.collides_with(bird) for p in pipes)
        if pipe_collision or 0 >= bird.y or bird.y >= WIN_HEIGHT - Bird.HEIGHT:
            collide = True
        for x in (0, WIN_WIDTH / 2):
            display_surface.blit(images['background'], (x, 0))
        while pipes and not pipes[0].visible:
            pipes.popleft()
        for p in pipes:
            p.update()
            display_surface.blit(p.image, p.rect)
        bird.update(action,state)
        display_surface.blit(bird.image, bird.rect)
        # Evaluate the outcome of the action and update the network.
        if frame_clock %2==0 or frame_clock==1 or collide:
            dist = 0
            if collide:
                nextState = 8
                isNewPipe = False
            else:
                nextState, dist, isNewPipe = judgeState(bird, pipes, collide)  # judge the bird's state
            print("next state:", nextState)
            print("lastdist, dist:", lastDist,dist)
            isSmallerError = False
            if state == nextState:
                isSmallerError = False
                if lastDist <= 0:
                    if lastDist < dist:
                        isSmallerError = True
                else:
                    if lastDist > dist:
                        isSmallerError = True
            if frame_clock>0 and not collide:
                reward = getReward(nextState, state, isSmallerError, isNewPipe)
                print("reward:", reward)
                num_reward.append(reward)
                DM=updateNet(DM,reward, action, state,weight_trace_d1,weight_trace_d2)
            state = nextState  # going on the next state
            # Reset eligibility traces and network state for the next decision.
            weight_trace_d1 = torch.zeros(con_matrix1.shape, dtype=torch.float)
            weight_trace_d2 = torch.zeros(con_matrix1.shape, dtype=torch.float)
            DM.reset()
        display_frame += 1
        # Count the score when the bird passes a pipe pair.
        for p in pipes:
            if p.x + PipePair.WIDTH < bird.x and not p.score_counted:
                score += 1
                p.score_counted = True
                num_score.append(score)
        score_surface = score_font.render('Current score: ' + str(score), True, (0, 0, 0))  # current score
        score_x = WIN_WIDTH / 2 - 3 * score_surface.get_width() / 4
        display_surface.blit(score_surface, (score_x, PipePair.PIECE_HEIGHT))
        if heighest < score:
            heighest = score
        score_surface_h = score_font.render('Highest score: ' + str(heighest), True,
                                            (0, 0, 0))  # heighest score
        score_x_h = 4 * WIN_WIDTH / 5 - 1.2 * score_surface.get_width() / 3
        display_surface.blit(score_surface_h, (score_x_h, PipePair.PIECE_HEIGHT))
        score_surface_i = score_font.render('Attempts: 0', True, (0, 0, 0))  # heighest score
        score_x_i = 10
        display_surface.blit(score_surface_i, (score_x_i, PipePair.PIECE_HEIGHT))
        frame_clock += 1
        pygame.display.flip()
    # if collide, display the fail information, for 2 frames
    cct = 0
    # Game-over animation: let the bird sink while still drawing the scores.
    while (bird.y < WIN_HEIGHT - Bird.HEIGHT - 3):
        clock.tick(60)
        for x in (0, WIN_WIDTH / 2):
            display_surface.blit(images['background'], (x, 0))
        while pipes and not pipes[0].visible:
            pipes.popleft()
        for p in pipes:
            display_surface.blit(p.image, p.rect)
        if cct >= 6:
            bird.sink()
        display_surface.blit(bird.image, bird.rect)
        fail_infor = info_font.render('Game over !', True, (255, 60, 30))  # failure banner
        pos_x = WIN_WIDTH / 2 - fail_infor.get_width() / 2
        pos_y = WIN_HEIGHT / 2 - 100
        display_surface.blit(fail_infor, (pos_x, pos_y))
        # display the score
        score_surface = score_font.render('Current score: ' + str(score), True, (0, 0, 0))  # current score
        score_x = WIN_WIDTH / 2 - 3 * score_surface.get_width() / 4
        display_surface.blit(score_surface, (score_x, PipePair.PIECE_HEIGHT))
        if heighest < score:
            heighest = score
        score_surface_h = score_font.render('Highest score: ' + str(heighest), True,
                                            (0, 0, 0))  # heighest score
        score_x_h = 4 * WIN_WIDTH / 5 - 1.2 * score_surface.get_width() / 3
        display_surface.blit(score_surface_h, (score_x_h, PipePair.PIECE_HEIGHT))
        score_surface_i = score_font.render('Attempts: 0' , True, (0, 0, 0))  # heighest score
        score_x_i = 10
        display_surface.blit(score_surface_i, (score_x_i, PipePair.PIECE_HEIGHT))
        pygame.display.flip()
        cct += 1
    if heighest < score:
        heighest = score
    # Persist the reward/score traces for later analysis.
    num_reward_np=np.array(num_reward)
    num_score_np=np.array(num_score)
    print(num_reward_np,num_score_np)
    np.save('hh_reward_l.npy', num_reward_np)
    np.save('hh_score_l.npy', num_score_np)
    print(score)
================================================
FILE: examples/decision_making/BDM-SNN/BDM-SNN.py
================================================
import torch
import os
from braincog.model_zoo.bdmsnn import BDMSNN
import pygame
from pygame.locals import *
from collections import deque
from random import randint
import numpy as np
try:
    pygame.display.init()
except Exception:
    # BUG FIX: narrowed from a bare ``except:`` which also swallowed
    # SystemExit/KeyboardInterrupt. If no real display is available
    # (e.g. headless machine), fall back to SDL's dummy video driver.
    os.environ["SDL_VIDEODRIVER"] = "dummy"
def load_images():
    """Load every image the Flappy Bird game needs.

    :return: dict mapping logical image names to loaded pygame surfaces
    """
    def load_image(img_file_name):
        file_name = os.path.join('.', 'birdimages', img_file_name)
        img = pygame.image.load(file_name)
        # converting all images before use speeds up blitting
        img.convert()
        return img
    # The flapping bird is animated from two separate wing frames because
    # animated GIFs are not supported in pygame.
    file_names = {
        'background': 'background.png',
        'pipe-end': 'pipe_end.png',
        'pipe-body': 'pipe_body.png',
        'bird-wingup': 'bird_wing_up.png',
        'bird-wingdown': 'bird_wing_down.png',
    }
    return {name: load_image(fname) for name, fname in file_names.items()}
class Bird(pygame.sprite.Sprite):
    """Flappy Bird player sprite.

    Tracks the bird's position and exposes the current animation frame,
    per-pixel collision mask and bounding rect as properties.
    """
    # Sprite size in pixels (square image).
    WIDTH = HEIGHT = 32
    # Fall speed while not climbing (px per elapsed ms at 60 fps).
    SINK_SPEED = 0.2
    # Faster fall speed used for the game-over animation (see sink()).
    Fail_SINk_SPEED = 0.6
    # Rise speed while climbing.
    CLIMB_SPEED = 0.25
    # Duration (ms) of one full climb.
    CLIMB_DURATION = 333.3
    # Width of the "warning" band used by the state-discretisation logic.
    REGION = CLIMB_DURATION / 3
    # Vertical margin (px) considered "close to a collision" with a pipe end.
    NEAR_COLLIDE = 30
    # Horizontal margin (px) for deciding the bird is "near" a pipe.
    NEAR_PIPE = 0
    def __init__(self, x, y, msec_to_climb, images):
        """Create a bird.

        :param x: initial x coordinate (top-left corner of the image)
        :param y: initial y coordinate (top-left corner of the image)
        :param msec_to_climb: remaining climb budget in milliseconds
        :param images: tuple (wing-up surface, wing-down surface)
        """
        super(Bird, self).__init__()
        self.x, self.y = x, y
        self.msec_to_climb = msec_to_climb
        self._img_wingup, self._img_wingdown = images
        # Pre-compute per-pixel masks once for collision detection.
        self._mask_wingup = pygame.mask.from_surface(self._img_wingup)
        self._mask_wingdown = pygame.mask.from_surface(self._img_wingdown)
    def update(self, action, state, delta_frames=1):
        """Update the bird's vertical position for one step.

        :param action: chosen action (1 = climb, anything else = sink)
        :param state: current discretised game state; in states 2-5 the
            bird moves at double speed to recover faster
        :param delta_frames: frames elapsed since the last update
        :return: None
        """
        if self.msec_to_climb > 0 and action == 1:
            if state == 4 or state == 5 or state == 2 or state == 3:
                self.y -= (2 * Bird.CLIMB_SPEED * (1000.0 * delta_frames / 60))
            else:
                self.y -= (Bird.CLIMB_SPEED * (1000.0 * delta_frames / 60))
        else:
            if state == 4 or state == 5 or state == 2 or state == 3:
                self.y += 2 * Bird.SINK_SPEED * (1000.0 * delta_frames / 60)
            else:
                self.y += Bird.SINK_SPEED * (1000.0 * delta_frames / 60)
    def sink(self, delta_frames=1):
        # Fast fall used once the game is over.
        self.y += Bird.Fail_SINk_SPEED * (1000.0 * delta_frames / 60)
    @property
    def image(self):
        """Surface of the current animation frame (wing flips every 250 ms)."""
        if pygame.time.get_ticks() % 500 >= 250:
            return self._img_wingup
        else:
            return self._img_wingdown
    @property
    def mask(self):
        """Collision mask matching the current animation frame."""
        if pygame.time.get_ticks() % 500 >= 250:
            return self._mask_wingup
        else:
            return self._mask_wingdown
    @property
    def rect(self):
        """Bounding rectangle used for blitting and collision tests."""
        return Rect(self.x, self.y, Bird.WIDTH, Bird.HEIGHT)
class PipePair(pygame.sprite.Sprite):
    """A top/bottom pipe pair in the Flappy Bird game.

    Both pipes are drawn onto a single window-high surface with a
    randomly positioned gap for the bird to fly through.
    """
    # Width of a pipe piece in pixels.
    WIDTH = 80
    # Height of one pipe piece (body or end) in pixels.
    PIECE_HEIGHT = 32
    # Milliseconds between spawning two consecutive pipe pairs.
    ADD_INTERVAL = 2000
    # Custom pygame event posted when a new pipe pair should be created.
    ADD_EVENT = pygame.USEREVENT + 1
    # Nominal gap height between the two pipes.
    ROOM_HIGHT = 2 * Bird.HEIGHT + 2 * PIECE_HEIGHT
    def __init__(self, pipe_end_img, pipe_body_img):
        """Build the pipe-pair surface with a random gap position.

        :param pipe_end_img: image of a pipe end piece
        :param pipe_body_img: image of a pipe body piece
        """
        self.x = float(WIN_WIDTH - 1)  # spawn at the right edge of the window
        self.score_counted = False     # set once the bird has passed this pair
        self.isNewPipe = True          # cleared when the state logic first sees it
        self.image = pygame.Surface((PipePair.WIDTH, WIN_HEIGHT), SRCALPHA)
        self.image.convert()  # speeds up blitting
        self.image.fill((0, 0, 0, 0))
        total_pipe_body_pieces = int(
            (WIN_HEIGHT -  # fill window from top to bottom
             3 * Bird.HEIGHT -  # make room for bird to fit through
             3 * PipePair.PIECE_HEIGHT) /  # 2 end pieces + 1 body piece
            PipePair.PIECE_HEIGHT  # to get number of pipe pieces
        )
        self.bottom_pieces = randint(1, total_pipe_body_pieces)
        self.top_pieces = total_pipe_body_pieces - self.bottom_pieces
        # bottom pipe
        for i in range(1, self.bottom_pieces + 1):
            piece_pos = (0, WIN_HEIGHT - i * PipePair.PIECE_HEIGHT)
            self.image.blit(pipe_body_img, piece_pos)
        bottom_pipe_end_y = WIN_HEIGHT - self.bottom_height_px
        bottom_end_piece_pos = (0, bottom_pipe_end_y - PipePair.PIECE_HEIGHT)
        self.image.blit(pipe_end_img, bottom_end_piece_pos)
        # top pipe
        for i in range(self.top_pieces):
            self.image.blit(pipe_body_img, (0, i * PipePair.PIECE_HEIGHT))
        top_pipe_end_y = self.top_height_px
        self.image.blit(pipe_end_img, (0, top_pipe_end_y))
        # Vertical centre of the gap between the two pipes.
        self.center = (top_pipe_end_y + bottom_pipe_end_y) / 2
        # compensate for added end pieces
        self.top_pieces += 1
        self.bottom_pieces += 1
        # for collision detection
        self.mask = pygame.mask.from_surface(self.image)
        self.top_y = top_pipe_end_y
        self.bottom_y = bottom_pipe_end_y
    @property
    def top_height_px(self):
        """Height of the top pipe in pixels."""
        return self.top_pieces * PipePair.PIECE_HEIGHT
    @property
    def bottom_height_px(self):
        """Height of the bottom pipe in pixels."""
        return self.bottom_pieces * PipePair.PIECE_HEIGHT
    @property
    def visible(self):
        """Whether any part of the pipe pair is still on screen."""
        return -PipePair.WIDTH < self.x < WIN_WIDTH
    @property
    def rect(self):
        """Rectangle used to blit the pipe surface."""
        return Rect(self.x, 0, PipePair.WIDTH, PipePair.PIECE_HEIGHT)
    def update(self, delta_frames=1):
        """Scroll the pipe pair left by the per-frame speed."""
        self.x -= 0.18 * 1000.0 * delta_frames / 60
    def collides_with(self, bird):
        """Pixel-perfect collision test against the bird sprite."""
        return pygame.sprite.collide_mask(self, bird)
def chooseAct(Net, input, weight_trace_d1, weight_trace_d2):
    """Run the decision network until an action neuron spikes.

    Repeatedly stimulates the network with the state-encoding input
    current (at most 500 steps) while accumulating the R-STDP
    eligibility traces of the direct (D1) and indirect (D2) pathways.

    :param Net: the BDM-SNN network
    :param input: input current (spike encoding of the current state)
    :param weight_trace_d1: accumulated eligibility trace, direct pathway
    :param weight_trace_d2: accumulated eligibility trace, indirect pathway
    :return: (chosen action, trace d1, trace d2, network)
    """
    out = None
    for i_train in range(500):
        out, dw = Net(input)
        # rstdp: decay the traces, then accumulate the new weight deltas.
        weight_trace_d1 *= trace_decay
        weight_trace_d1 += dw[0][0]
        weight_trace_d2 *= trace_decay
        weight_trace_d2 += dw[1][0]
        if torch.max(out) > 0:
            return torch.argmax(out), weight_trace_d1, weight_trace_d2, Net
    # BUG FIX: the original implicitly returned None when no output neuron
    # spiked within 500 steps, which crashed the caller on tuple unpacking.
    # Fall back to the argmax of the last network output instead.
    return torch.argmax(out), weight_trace_d1, weight_trace_d2, Net
def judgeState(bird, pipes, collide):
    """Discretise the bird/pipe geometry into one of nine states.

    States 0/1: safely on either side of the target centre; 2/3 and 4/5:
    progressively further away; 6/7: about to hit the top/bottom pipe end;
    8: collided.

    :param bird: the Bird sprite
    :param pipes: iterable of PipePair sprites
    :param collide: whether a collision has already happened
    :return: (state, signed vertical distance to the target centre,
        whether the judged pipe is encountered for the first time)
    """
    # bird's x and y coordinate in the left top of the image
    dist = bird.y + Bird.HEIGHT / 2 - WIN_HEIGHT / 2
    isNew = False
    index = -1
    state = -1
    if collide:
        state = 8
        # BUG FIX: the original returned a bare int here although every
        # caller unpacks a 3-tuple; return the full tuple for consistency.
        return state, dist, isNew
    for p in pipes:
        # Skip pipes the bird has already passed (and scored on).
        if p.x + PipePair.WIDTH - Bird.HEIGHT / 4 < bird.x and not p.score_counted:
            continue
        if p.x - Bird.NEAR_PIPE <= bird.x + Bird.HEIGHT and \
                p.x + PipePair.WIDTH - Bird.HEIGHT / 4 >= bird.x:
            # Bird is horizontally inside/near this pipe pair.
            p_top_y = p.top_y + PipePair.PIECE_HEIGHT
            p_bottom_y = p.bottom_y - PipePair.PIECE_HEIGHT
            if p.center - bird.y - Bird.HEIGHT / 2 >= 0 and bird.y >= p_top_y + Bird.NEAR_COLLIDE / 2:
                state = 0
            elif bird.y - p.center + Bird.HEIGHT / 2 > 0 and bird.y + Bird.HEIGHT <= p_bottom_y - Bird.NEAR_COLLIDE / 2:
                state = 1
            elif bird.y < p_top_y + Bird.NEAR_COLLIDE / 2 and bird.y > p_top_y - 10:
                state = 6
            elif bird.y + Bird.HEIGHT > p_bottom_y - Bird.NEAR_COLLIDE / 2 and bird.y + Bird.HEIGHT < p_bottom_y + 10:
                state = 7
            if state > -0.5:
                index = 1
        elif p.x > bird.x + Bird.HEIGHT + Bird.NEAR_PIPE:
            # Pipe still ahead of the bird: judge relative to its gap centre.
            state = blankState(bird, p.center)
            if p.isNewPipe:
                isNew = True
                p.isNewPipe = False
            index = 1
        if index > 0:  # only judge the nearest and not passed pipe
            dist = bird.y + Bird.HEIGHT / 2 - p.center
            break
    if index < -0.5:  # no pipe left, keep the bird in the middle
        pos = WIN_HEIGHT / 2
        dist = bird.y + Bird.HEIGHT / 2 - pos
        state = blankState(bird, pos)
    return state, dist, isNew
def blankState(bird, center):
    """Classify the bird's position relative to ``center`` (no pipe nearby).

    Called from judgeState. Even states mean the bird is on one side of
    the centre line, odd states on the other (mirrored conditions):
    0/1 inside the safe band, 2/3 inside the warning band (width
    Bird.REGION), 4/5 far outside it.

    :param bird: the Bird sprite
    :param center: target vertical centre (pipe-gap centre or window middle)
    :return: state in {0, 1, 2, 3, 4, 5}
    """
    # Half of the usable gap height; NEAR_COLLIDE shrinks the safe band.
    realHeight = (PipePair.ROOM_HIGHT - Bird.HEIGHT) / 2
    if center - bird.y - Bird.HEIGHT / 2 >= 0 and \
            center - bird.y - Bird.HEIGHT / 2 < realHeight - Bird.NEAR_COLLIDE / 2:
        state = 0
    elif bird.y - center + Bird.HEIGHT / 2 >= 0 and \
            bird.y - center + Bird.HEIGHT / 2 < realHeight - Bird.NEAR_COLLIDE / 2:
        state = 1
    elif center - bird.y - Bird.HEIGHT / 2 >= realHeight - Bird.NEAR_COLLIDE / 2 and \
            center - bird.y - Bird.HEIGHT / 2 < realHeight - Bird.NEAR_COLLIDE / 2 + Bird.REGION:
        state = 2
    elif bird.y - center + Bird.HEIGHT / 2 >= realHeight - Bird.NEAR_COLLIDE / 2 and \
            bird.y - center + Bird.HEIGHT / 2 < realHeight - Bird.NEAR_COLLIDE / 2 + Bird.REGION:
        state = 3
    elif bird.y + Bird.HEIGHT / 2 <= center - (realHeight - Bird.NEAR_COLLIDE / 2 + Bird.REGION):
        state = 4
    elif bird.y + Bird.HEIGHT / 2 >= center + realHeight - Bird.NEAR_COLLIDE / 2 + Bird.REGION:
        state = 5
    return state
def getReward(state, lastState, smallerError, isNewPipe):
    """Compute the scalar reward for the transition into ``state``.

    :param state: state after executing the action
    :param lastState: state before executing the action
    :param smallerError: whether the distance error got smaller
    :param isNewPipe: whether a new pipe has just been encountered
    :return: reward value
    """
    if state == 0 or state == 1:
        # Safely inside the gap band.
        return 6
    if state == 8:  # collide
        return -100
    # Per-zone penalties: applied when staying in the zone without
    # reducing the error (stay) and when just entering the zone (enter).
    if state == 2 or state == 3:
        stay_penalty, enter_penalty = -5, -3
    elif state == 4 or state == 5:
        stay_penalty, enter_penalty = -8, -5
    elif state == 6 or state == 7:
        stay_penalty, enter_penalty = -3, -3
    else:
        # BUG FIX: the original left ``reward`` unbound for any state
        # outside 0..8, raising UnboundLocalError at the final return.
        # Treat unknown states as neutral instead.
        return 0
    if lastState == state and not isNewPipe:
        # Staying in the same zone: reward progress, punish drifting away.
        return 3 if smallerError else stay_penalty
    return enter_penalty
def updateNet(Net, reward, action, state, weight_trace_d1, weight_trace_d2):
    """Apply the reward-modulated weight update to the network.

    :param Net: the BDM-SNN network
    :param reward: reward obtained for the last transition
    :param action: the action that was executed
    :param state: the state in which the action was executed
    :param weight_trace_d1: accumulated eligibility trace, direct pathway
    :param weight_trace_d2: accumulated eligibility trace, indirect pathway
    :return: the updated network
    """
    # Modulation matrix: 1 everywhere except the (state, action) synapse
    # that was actually taken, which carries the reward.
    modulation = torch.ones((num_state, num_state * num_action), dtype=torch.float)
    modulation[state, state * num_action + action] = reward
    # Direct (D1) pathway is potentiated, indirect (D2) pathway depressed.
    Net.UpdateWeight(0, state, num_action, modulation * weight_trace_d1)
    Net.UpdateWeight(1, state, num_action, -1 * modulation * weight_trace_d2)
    return Net
if __name__ == "__main__":
    """Run the BDM-SNN (LIF neurons) on the Flappy Bird game."""
    # Network hyper-parameters: 9 discretised states, 2 actions (flap / sink).
    num_state = 9
    num_action = 2
    weight_exc = 1
    weight_inh = -0.5
    trace_decay = 0.8
    DM = BDMSNN(num_state, num_action, weight_exc, weight_inh, "lif")
    # State -> (state, action) connection template; only its shape is used
    # below to (re-)initialise the eligibility traces.
    con_matrix1 = torch.zeros((num_state, num_state * num_action), dtype=torch.float)
    for i in range(num_state):
        for j in range(num_action):
            con_matrix1[i, i * num_action + j] = weight_exc
    weight_trace_d1 = torch.zeros(con_matrix1.shape, dtype=torch.float)
    weight_trace_d2 = torch.zeros(con_matrix1.shape, dtype=torch.float)
    # --- pygame / game setup ---
    pygame.init()
    WIN_HEIGHT = 512
    WIN_WIDTH = 284 * 2
    heighest = 0
    contTime = 0
    display_frame = 0
    display_surface = pygame.display.set_mode((WIN_WIDTH, WIN_HEIGHT))
    pygame.display.set_caption('Flappy Bird')
    images = load_images()
    bird = Bird(250, int(WIN_HEIGHT / 2 - Bird.HEIGHT / 2), 2,
                (images['bird-wingup'], images['bird-wingdown']))
    clock = pygame.time.Clock()
    score_font = pygame.font.SysFont(None, 25, bold=True)
    info_font = pygame.font.SysFont(None, 50, bold=True)
    collide = paused = False
    frame_clock = 0
    pipes = deque()
    score = 0
    lastDist = 0
    lastState = 0  # init
    state = lastState
    num = 0
    num_reward = []
    num_score = []
    # --- main game / learning loop (capped at 30000 iterations) ---
    while not collide:
        num = num + 1
        if num > 30000:
            break
        input = torch.zeros((num_state), dtype=torch.float)
        clock.tick(60)
        # Decide on a new action every other frame.
        if frame_clock % 2 == 0 or frame_clock == 1:
            state, dist, isNewPipe = judgeState(bird, pipes, collide)
            lastState = state
            lastDist = dist
            input[state] = 2
            action, weight_trace_d1, weight_trace_d2, DM = chooseAct(DM, input, weight_trace_d1, weight_trace_d2)
            print("state, dist:", state, dist)
            print("state, action:", state, action)
        # Periodically post the event that spawns a new pipe pair.
        if not (paused or frame_clock % (60 * PipePair.ADD_INTERVAL / 1000.0)):
            pygame.event.post(pygame.event.Event(PipePair.ADD_EVENT))
        for e in pygame.event.get():
            if e.type == QUIT or (e.type == KEYUP and e.key == K_ESCAPE):
                collide = True
            elif e.type == KEYUP and e.key in (K_PAUSE, K_p):
                paused = not paused
            elif e.type == PipePair.ADD_EVENT:
                pp = PipePair(images['pipe-end'], images['pipe-body'])
                pipes.append(pp)
        if paused:
            continue  # don't draw anything
        pipe_collision = any(p.collides_with(bird) for p in pipes)
        if pipe_collision or 0 >= bird.y or bird.y >= WIN_HEIGHT - Bird.HEIGHT:
            collide = True
        for x in (0, WIN_WIDTH / 2):
            display_surface.blit(images['background'], (x, 0))
        while pipes and not pipes[0].visible:
            pipes.popleft()
        for p in pipes:
            p.update()
            display_surface.blit(p.image, p.rect)
        bird.update(action, state)
        display_surface.blit(bird.image, bird.rect)
        # Evaluate the outcome of the action and update the network.
        if frame_clock % 2 == 0 or frame_clock == 1 or collide:
            dist = 0
            if collide:
                nextState = 8
                isNewPipe = False
            else:
                nextState, dist, isNewPipe = judgeState(bird, pipes, collide)  # judge the bird's state
            print("next state:", nextState)
            print("lastdist, dist:", lastDist, dist)
            isSmallerError = False
            if state == nextState:
                isSmallerError = False
                if lastDist <= 0:
                    if lastDist < dist:
                        isSmallerError = True
                else:
                    if lastDist > dist:
                        isSmallerError = True
            if frame_clock > 0 and not collide:
                reward = getReward(nextState, state, isSmallerError, isNewPipe)
                print("reward:", reward)
                num_reward.append(reward)
                DM = updateNet(DM, reward, action, state, weight_trace_d1, weight_trace_d2)
            state = nextState  # going on the next state
            # Reset eligibility traces and network state for the next decision.
            weight_trace_d1 = torch.zeros(con_matrix1.shape, dtype=torch.float)
            weight_trace_d2 = torch.zeros(con_matrix1.shape, dtype=torch.float)
            DM.reset()
        display_frame += 1
        # Count the score when the bird passes a pipe pair.
        for p in pipes:
            if p.x + PipePair.WIDTH < bird.x and not p.score_counted:
                score += 1
                p.score_counted = True
                num_score.append(score)
        score_surface = score_font.render('Current score: ' + str(score), True, (0, 0, 0))  # current score
        score_x = WIN_WIDTH / 2 - 3 * score_surface.get_width() / 4
        display_surface.blit(score_surface, (score_x, PipePair.PIECE_HEIGHT))
        if heighest < score:
            heighest = score
        score_surface_h = score_font.render('Highest score: ' + str(heighest), True,
                                            (0, 0, 0))  # heighest score
        score_x_h = 4 * WIN_WIDTH / 5 - 1.2 * score_surface.get_width() / 3
        display_surface.blit(score_surface_h, (score_x_h, PipePair.PIECE_HEIGHT))
        score_surface_i = score_font.render('Attempts: 0', True, (0, 0, 0))  # heighest score
        score_x_i = 10
        display_surface.blit(score_surface_i, (score_x_i, PipePair.PIECE_HEIGHT))
        frame_clock += 1
        pygame.display.flip()
    # if collide, display the fail information, for 2 frames
    cct = 0
    # Game-over animation: let the bird sink while still drawing the scores.
    while (bird.y < WIN_HEIGHT - Bird.HEIGHT - 3):
        clock.tick(60)
        for x in (0, WIN_WIDTH / 2):
            display_surface.blit(images['background'], (x, 0))
        while pipes and not pipes[0].visible:
            pipes.popleft()
        for p in pipes:
            display_surface.blit(p.image, p.rect)
        if cct >= 6:
            bird.sink()
        display_surface.blit(bird.image, bird.rect)
        fail_infor = info_font.render('Game over !', True, (255, 60, 30))  # failure banner
        pos_x = WIN_WIDTH / 2 - fail_infor.get_width() / 2
        pos_y = WIN_HEIGHT / 2 - 100
        display_surface.blit(fail_infor, (pos_x, pos_y))
        # display the score
        score_surface = score_font.render('Current score: ' + str(score), True, (0, 0, 0))  # current score
        score_x = WIN_WIDTH / 2 - 3 * score_surface.get_width() / 4
        display_surface.blit(score_surface, (score_x, PipePair.PIECE_HEIGHT))
        if heighest < score:
            heighest = score
        score_surface_h = score_font.render('Highest score: ' + str(heighest), True,
                                            (0, 0, 0))  # heighest score
        score_x_h = 4 * WIN_WIDTH / 5 - 1.2 * score_surface.get_width() / 3
        display_surface.blit(score_surface_h, (score_x_h, PipePair.PIECE_HEIGHT))
        score_surface_i = score_font.render('Attempts: 0', True, (0, 0, 0))  # heighest score
        score_x_i = 10
        display_surface.blit(score_surface_i, (score_x_i, PipePair.PIECE_HEIGHT))
        pygame.display.flip()
        cct += 1
    if heighest < score:
        heighest = score
    contTime += 1
    # Persist the reward/score traces for later analysis.
    num_reward_np = np.array(num_reward)
    num_score_np = np.array(num_score)
    print(num_reward_np, num_score_np)
    np.save('lif_reward_l.npy', num_reward_np)
    np.save('lif_score_l.npy', num_score_np)
    print(score)
================================================
FILE: examples/decision_making/BDM-SNN/README.md
================================================
# Brain-inspired Decision-Making SNN
## Requirements
"decisionmaking.py", "BDM-SNN.py","BDM-SNN-hh.py":pygame
"BDM-SNN-UAV.py":robomaster
## Run
The decisionmaking.py and BDM-SNN.py scripts implement the core code of the brain-inspired decision-making spiking neural network from the paper entitled "A brain-inspired decision-making spiking neural network and its application in unmanned aerial vehicle".
"decisionmaking.py, BDM-SNN.py" includes the multi-brain regions coordinated decision-making spiking neural network with LIF neurons.
"BDM-SNN-hh.py" includes the BDM-SNN with simplified HH neurons.
"BDM-SNN-UAV.py" includes the BDM-SNN applied to the UAV (DJI Tello talent), users need to define the reinforcement learning task.
```shell
python decisionmaking.py
python BDM-SNN.py
python BDM-SNN-hh.py
python BDM-SNN-UAV.py
```
## Results
"decisionmaking.py", "BDM-SNN.py" and "BDM-SNN-hh.py" have been verified on the Flappy Bird game. BDM-SNN can stably pass the pipes on the first try.

Differences from the original article: an improved reward-modulated STDP learning rule.
## Citation
If you find this package helpful, please consider citing the following papers:
```BibTex
@article{zhao2018brain,
title={A brain-inspired decision-making spiking neural network and its application in unmanned aerial vehicle},
author={Zhao, Feifei and Zeng, Yi and Xu, Bo},
journal={Frontiers in neurorobotics},
volume={12},
pages={56},
year={2018},
publisher={Frontiers Media SA}
}
@misc{https://doi.org/10.48550/arxiv.2207.08533,
doi = {10.48550/ARXIV.2207.08533},
url = {https://arxiv.org/abs/2207.08533},
author = {Zeng, Yi and Zhao, Dongcheng and Zhao, Feifei and Shen, Guobin and Dong, Yiting and Lu, Enmeng and Zhang, Qian and Sun, Yinqian and Liang, Qian and Zhao, Yuxuan and Zhao, Zhuoya and Fang, Hongjian and Wang, Yuwei and Li, Yang and Liu, Xin and Du, Chengcheng and Kong, Qingqun and Ruan, Zizhe and Bi, Weida},
title = {BrainCog: A Spiking Neural Network based Brain-inspired Cognitive Intelligence Engine for Brain-inspired AI and Brain Simulation},
publisher = {arXiv},
year = {2022},
}
```
================================================
FILE: examples/decision_making/BDM-SNN/decisionmaking.py
================================================
import numpy as np
import torch,os,sys
from torch import nn
from torch.nn import Parameter
import abc
import math
from abc import ABC
import torch.nn.functional as F
import matplotlib.pyplot as plt
#from BrainCog.base.strategy.surrogate import *
from braincog.base.node.node import IFNode, SimHHNode
from braincog.base.learningrule.STDP import STDP, MutliInputSTDP
from braincog.base.connection.CustomLinear import CustomLinear
from braincog.base.brainarea.basalganglia import basalganglia
#from braincog.model_zoo.bdmsnn import BDMSNN
import pygame
from pygame.locals import *
from collections import deque
from random import randint
#os.environ["SDL_VIDEODRIVER"] = "dummy"
class BDMSNN(nn.Module):
    def __init__(self, num_state, num_action, weight_exc, weight_inh, node_type):
        """Build the BDM-SNN decision-making network.

        :param num_state: number of states
        :param num_action: number of actions
        :param weight_exc: excitatory connection weight
        :param weight_inh: inhibitory connection weight
        :param node_type: neuron model to use ("hh" or "lif")
        """
        super().__init__()
        # parameters
        BG = basalganglia(num_state, num_action, weight_exc, weight_inh, node_type)
        dm_connection = BG.getweight()
        dm_mask = BG.getmask()
        # input-dlpfc: one-to-one excitatory mapping of the state input
        con_matrix9 = torch.eye((num_state), dtype=torch.float)
        dm_connection.append(CustomLinear(weight_exc * con_matrix9, con_matrix9))
        dm_mask.append(con_matrix9)
        # gpi-th: one-to-one inhibitory connection
        con_matrix10 = torch.eye((num_action), dtype=torch.float)
        dm_mask.append(con_matrix10)
        dm_connection.append(CustomLinear(weight_inh * con_matrix10, con_matrix10))
        # th-pm: one-to-one excitatory connection
        dm_mask.append(con_matrix10)
        dm_connection.append(CustomLinear(weight_exc * con_matrix10, con_matrix10))
        # dlpfc-th: weak all-to-all excitatory connection
        con_matrix11 = torch.ones((num_state, num_action), dtype=torch.float)
        dm_mask.append(con_matrix11)
        dm_connection.append(CustomLinear(0.2 * weight_exc * con_matrix11, con_matrix11))
        # pm-pm: all-to-all lateral inhibition without self-connections
        con_matrix3 = torch.ones((num_action, num_action), dtype=torch.float)
        con_matrix4 = torch.eye((num_action), dtype=torch.float)
        con_matrix5 = con_matrix3 - con_matrix4
        con_matrix5 = con_matrix5
        dm_mask.append(con_matrix5)
        dm_connection.append(CustomLinear(5 * weight_inh * con_matrix5, con_matrix5))
        # dlpfc thalamus pm +bg
        self.weight_exc = weight_exc
        self.num_subDM = 8
        self.connection = dm_connection
        self.mask = dm_mask
        self.node = BG.node
        self.node_type = node_type
        if self.node_type == "hh":
            # Simplified HH neurons for the non-BG areas; PM (index 6) gets
            # custom channel conductances.
            self.node.extend([SimHHNode() for i in range(self.num_subDM - BG.num_subBG)])
            self.node[6].g_Na = torch.tensor(12)
            self.node[6].g_K = torch.tensor(3.6)
            self.node[6].g_L = torch.tensor(0.03)
        if self.node_type == "lif":
            self.node.extend([IFNode() for i in range(self.num_subDM - BG.num_subBG)])
        self.learning_rule = BG.learning_rule
        self.learning_rule.append(MutliInputSTDP(self.node[5], [self.connection[10], self.connection[12]]))  # gpi-thalamus
        self.learning_rule.append(MutliInputSTDP(self.node[6], [self.connection[11], self.connection[13]]))  # pm
        self.learning_rule.append(STDP(self.node[7], self.connection[9]))
        # Output size of each sub-area, used to pre-allocate output/dw buffers.
        out_shape=[self.connection[0].weight.shape[1],self.connection[1].weight.shape[1],self.connection[2].weight.shape[1],self.connection[4].weight.shape[1],self.connection[3].weight.shape[1],self.connection[10].weight.shape[1],self.connection[11].weight.shape[1],self.connection[9].weight.shape[1]]
        self.out = []
        self.dw = []
        for i in range(self.num_subDM):
            self.out.append(torch.zeros((out_shape[i]), dtype=torch.float))
            self.dw.append(torch.zeros((out_shape[i]), dtype=torch.float))
def forward(self, input):
    """
    Run one simulation step of the whole decision-making network.

    :param input: input vector driving the DLPFC (state) population
    :return: tuple of (PM population output ``self.out[6]``, list of the
        weight updates ``dw`` produced by every learning rule this step)
    """
    # DLPFC population is driven directly by the external state input.
    self.out[7] = self.node[7](self.connection[9](input))
    # Striatal D1 / D2 populations driven by DLPFC spikes (indices follow
    # the `br` name list below: 0=StrD1, 1=StrD2, ..., 7=DLPFC).
    self.out[0], self.dw[0] = self.learning_rule[0](self.out[7])
    self.out[1], self.dw[1] = self.learning_rule[1](self.out[7])
    # STN gets DLPFC + GPe feedback; GPe gets D2 + STN (previous-step values).
    self.out[2], self.dw[2] = self.learning_rule[2](self.out[7], self.out[3])
    self.out[3], self.dw[3] = self.learning_rule[3](self.out[1], self.out[2])
    # GPi integrates the direct (D1) and indirect (GPe/STN) pathways.
    self.out[4], self.dw[4] = self.learning_rule[4](self.out[0], self.out[3], self.out[2])
    # Thalamus gated by GPi and excited by DLPFC; PM driven by thalamus
    # plus its own recurrent output.
    self.out[5], self.dw[5] = self.learning_rule[5](self.out[4], self.out[7])
    self.out[6], self.dw[6] = self.learning_rule[6](self.out[5], self.out[6])
    br = ["StrD1", "StrD2", "STN", "Gpe", "Gpi", "thalamus", "PM", "DLPFC"]
    for i in range(self.num_subDM):
        # HH neurons are reset once they spike so the next step starts clean.
        if torch.max(self.out[i]) > 0 and self.node_type == "hh":
            self.node[i].n_reset()
        print("every areas:", br[i], self.out[i])
    return self.out[6], self.dw
def UpdateWeight(self, i, s, num_action, dw):
    """
    Update the weights of the i-th connection group of the network.

    :param i: index of the connection group to update
    :param s: current state (row of the weight matrix that is touched)
    :param num_action: number of actions per state
        (NOTE(review): the index lists below hard-code exactly two action
        synapses per state — confirm before using num_action != 2)
    :param dw: amount by which to update the weights
    :return: None
    """
    if self.node_type == "hh":
        self.connection[i].update(0.2 * self.weight_exc * dw)
        # Renormalize the two action synapses of state s so their maximum is 1,
        # then rescale the whole row back to the excitatory weight level.
        self.connection[i].weight.data[s, [s * num_action, s * num_action + 1]] /= (self.connection[i].weight.data[s, [s * num_action, s * num_action + 1]].float().max() + 1e-12)
        self.connection[i].weight.data[s, :] = self.connection[i].weight.data[s, :] * self.weight_exc
    if self.node_type == "lif":
        # Standardize the update over the two action synapses of state s.
        # NOTE(review): std() of two equal values is 0 (division yields inf/NaN)
        # — confirm dw always differs across the two synapses.
        dw_mean = dw[s, [s * num_action, s * num_action + 1]].mean()
        dw_std = dw[s, [s * num_action, s * num_action + 1]].std()
        dw[s, [s * num_action, s * num_action + 1]] = (dw[s, [s * num_action,s * num_action + 1]] - dw_mean) / dw_std
        # Only allow updates where the connection mask permits a synapse.
        dw[s, :] = dw[s, :] * self.mask[i][s, :]
        self.connection[i].update(dw)
        self.connection[i].weight.data[s, [s * num_action, s * num_action + 1]] /= (self.connection[i].weight.data[s, [s * num_action, s * num_action + 1]].float().max() + 1e-12)
    # Sign constraints (Dale-like): these groups stay non-negative...
    if i in [0, 1, 2, 6, 7, 11, 12]:
        self.connection[i].weight.data = torch.clamp(self.connection[i].weight.data, 0, None)
    # ...and these inhibitory groups stay non-positive.
    if i in [3, 4, 5, 8, 10]:
        self.connection[i].weight.data = torch.clamp(self.connection[i].weight.data, None, 0)
def reset(self):
    """Reset the internal state of every neuron group and learning rule."""
    for neuron in self.node[:self.num_subDM]:
        neuron.n_reset()
    for rule in self.learning_rule:
        rule.reset()
def getweight(self):
    """
    Get the network's connections (including their weights).

    :return: the list of connections of the network
    """
    return self.connection
def load_images():
    """Load all images required by the game and return a dict of them.

    The returned dict has the following keys:
    background: The game's background image.
    bird-wingup: An image of the bird with its wing pointing upward.
        Use this and bird-wingdown to create a flapping bird.
    bird-wingdown: An image of the bird with its wing pointing downward.
        Use this and bird-wingup to create a flapping bird.
    pipe-end: An image of a pipe's end piece (the slightly wider bit).
        Use this and pipe-body to make pipes.
    pipe-body: An image of a slice of a pipe's body. Use this and
        pipe-body to make pipes.
    """
    def load_image(img_file_name):
        """Return the loaded pygame image with the specified file name.

        This function looks for images in the game's images folder
        (./birdimages/). All images are converted before being returned to
        speed up blitting.

        Arguments:
        img_file_name: The file name (including its extension, e.g.
            '.png') of the required image, without a file path.
        """
        file_name = os.path.join('.', 'birdimages', img_file_name)
        img = pygame.image.load(file_name)
        # converting all images before use speeds up blitting
        img.convert()
        return img
    return {'background': load_image('background.png'),
            'pipe-end': load_image('pipe_end.png'),
            'pipe-body': load_image('pipe_body.png'),
            # images for animating the flapping bird -- animated GIFs are
            # not supported in pygame
            'bird-wingup': load_image('bird_wing_up.png'),
            'bird-wingdown': load_image('bird_wing_down.png'),}
class Bird(pygame.sprite.Sprite):
    """The bird sprite controlled by the learned policy (x, y = top-left)."""
    WIDTH = HEIGHT = 32        # sprite size in pixels
    SINK_SPEED = 0.2           # px per ms while sinking normally
    Fail_SINk_SPEED = 0.6      # px per ms while sinking after game over
    CLIMB_SPEED = 0.25         # px per ms while climbing
    CLIMB_DURATION = 333.3     # ms one climb lasts
    REGION = CLIMB_DURATION / 3  # when far from the pipe, the bird can fluctuate in a certain region
    NEAR_COLLIDE = 30  # when inside the pipe, near collide distance, this defines another state
    NEAR_PIPE = 0  # at what distance does the bird near the pipe
    def __init__(self, x, y, msec_to_climb, images):
        """Initialise the bird at (x, y) with its (wing-up, wing-down) images."""
        super(Bird, self).__init__()
        self.x, self.y = x, y
        self.msec_to_climb = msec_to_climb
        self._img_wingup, self._img_wingdown = images
        # pixel-perfect collision masks for both wing frames
        self._mask_wingup = pygame.mask.from_surface(self._img_wingup)
        self._mask_wingdown = pygame.mask.from_surface(self._img_wingdown)
    def update(self, action, state, delta_frames=1):
        """Move the bird: climb when action == 1, otherwise sink.

        The vertical speed is doubled in states 2-5 (states produced by
        judgeState/blankState when the bird is away from the gap's center).
        """
        if self.msec_to_climb > 0 and action == 1:
            if state==4 or state==5 or state == 2 or state == 3:
                self.y -= (2*Bird.CLIMB_SPEED * (1000.0 * delta_frames / 60))
            else:
                self.y -= (Bird.CLIMB_SPEED * (1000.0 * delta_frames / 60))
        else:
            if state == 4 or state == 5 or state == 2 or state == 3:
                self.y += 2*Bird.SINK_SPEED * (1000.0 * delta_frames / 60)
            else:
                self.y += Bird.SINK_SPEED * (1000.0 * delta_frames / 60)
    # if the bird fails, sink the bird till it hits the bottom
    def sink(self, delta_frames=1):
        """Sink quickly after a collision (game-over animation)."""
        self.y += Bird.Fail_SINk_SPEED * (1000.0 * delta_frames / 60)
    @property
    def image(self):
        """Current sprite image; the wing flaps every 250 ms."""
        if pygame.time.get_ticks() % 500 >= 250:
            return self._img_wingup
        else:
            return self._img_wingdown
    @property
    def mask(self):
        """Collision mask matching the currently displayed wing frame."""
        if pygame.time.get_ticks() % 500 >= 250:
            return self._mask_wingup
        else:
            return self._mask_wingdown
    @property
    def rect(self):
        """The bird's bounding rectangle."""
        return Rect(self.x, self.y, Bird.WIDTH, Bird.HEIGHT)
class PipePair(pygame.sprite.Sprite):
    """A top/bottom pipe pair with a gap ("room") the bird must fly through."""
    WIDTH = 80                        # pipe width in pixels
    PIECE_HEIGHT = 32                 # height of one pipe slice in pixels
    ADD_INTERVAL = 2000               # ms between spawning new pipe pairs
    ADD_EVENT = pygame.USEREVENT + 1  # custom event posted to spawn a pair
    ROOM_HIGHT = 2 * Bird.HEIGHT + 2 * PIECE_HEIGHT  # height of the gap
    def __init__(self, pipe_end_img, pipe_body_img):
        """Initialises a new random PipePair.

        The new PipePair will automatically be assigned an x attribute of
        float(WIN_WIDTH - 1).

        Arguments:
        pipe_end_img: The image to use to represent a pipe's end piece.
        pipe_body_img: The image to use to represent one horizontal slice
            of a pipe's body.
        """
        self.x = float(WIN_WIDTH - 1)
        self.score_counted = False
        self.isNewPipe = True  # consumed exactly once by judgeState()
        self.image = pygame.Surface((PipePair.WIDTH, WIN_HEIGHT), SRCALPHA)
        self.image.convert()  # speeds up blitting
        self.image.fill((0, 0, 0, 0))
        total_pipe_body_pieces = int(
            (WIN_HEIGHT -  # fill window from top to bottom
             3 * Bird.HEIGHT -  # make room for bird to fit through
             3 * PipePair.PIECE_HEIGHT) /  # 2 end pieces + 1 body piece
            PipePair.PIECE_HEIGHT  # to get number of pipe pieces
        )
        # randomly split the body pieces between the bottom and top pipes
        self.bottom_pieces = randint(1, total_pipe_body_pieces)
        self.top_pieces = total_pipe_body_pieces - self.bottom_pieces
        # bottom pipe
        for i in range(1, self.bottom_pieces + 1):
            piece_pos = (0, WIN_HEIGHT - i * PipePair.PIECE_HEIGHT)
            self.image.blit(pipe_body_img, piece_pos)
        bottom_pipe_end_y = WIN_HEIGHT - self.bottom_height_px
        bottom_end_piece_pos = (0, bottom_pipe_end_y - PipePair.PIECE_HEIGHT)
        self.image.blit(pipe_end_img, bottom_end_piece_pos)
        # top pipe
        for i in range(self.top_pieces):
            self.image.blit(pipe_body_img, (0, i * PipePair.PIECE_HEIGHT))
        top_pipe_end_y = self.top_height_px
        self.image.blit(pipe_end_img, (0, top_pipe_end_y))
        self.center = (top_pipe_end_y + bottom_pipe_end_y) / 2  # center of pipe-room
        # compensate for added end pieces
        self.top_pieces += 1
        self.bottom_pieces += 1
        # for collision detection
        self.mask = pygame.mask.from_surface(self.image)
        self.top_y = top_pipe_end_y        # y of the top pipe's lower edge
        self.bottom_y = bottom_pipe_end_y  # y of the bottom pipe's upper edge
    @property
    def top_height_px(self):
        """Get the top pipe's height, in pixels."""
        return self.top_pieces * PipePair.PIECE_HEIGHT
    @property
    def bottom_height_px(self):
        """Get the bottom pipe's height, in pixels."""
        return self.bottom_pieces * PipePair.PIECE_HEIGHT
    @property
    def visible(self):
        """Get whether this PipePair on screen, visible to the player."""
        return -PipePair.WIDTH < self.x < WIN_WIDTH
    @property
    def rect(self):
        """Get the Rect which contains this PipePair."""
        return Rect(self.x, 0, PipePair.WIDTH, PipePair.PIECE_HEIGHT)
    def update(self, delta_frames=1):
        """Update the PipePair's position (scrolls leftwards).

        Attributes:
        delta_frames: The number of frames elapsed since this method was
            last called.
        """
        self.x -= 0.18 * 1000.0 * delta_frames /60
    def collides_with(self, bird):
        """Get whether the bird collides with a pipe in this PipePair.

        Arguments:
        bird: The Bird which should be tested for collision with this
            PipePair.
        """
        return pygame.sprite.collide_mask(self, bird)
def chooseAct(Net, s, input, weight_trace_d1, weight_trace_d2, trace_decay=0.8):
    """Simulate the decision network until an action spike appears.

    :param Net: decision network; calling ``Net(input)`` returns ``(out, dw)``
    :param s: current state (kept for interface compatibility; unused here)
    :param input: input vector fed to the network at every simulation step
    :param weight_trace_d1: eligibility trace of the D1 (direct) pathway,
        updated in place and returned
    :param weight_trace_d2: eligibility trace of the D2 (indirect) pathway,
        updated in place and returned
    :param trace_decay: per-step decay of the eligibility traces (default 0.8,
        the value the original code read from the module-level global)
    :return: (chosen action index, d1 trace, d2 trace, Net)
    """
    out = None
    for i_train in range(500):
        out, dw = Net(input)
        # R-STDP style eligibility traces: decay, then accumulate the
        # striatal weight updates proposed this step.
        weight_trace_d1 *= trace_decay
        weight_trace_d1 += dw[0][0]
        weight_trace_d2 *= trace_decay
        weight_trace_d2 += dw[1][0]
        if torch.max(out) > 0:
            return torch.argmax(out), weight_trace_d1, weight_trace_d2, Net
    # Fix: the original implicitly returned None when no neuron spiked within
    # 500 steps, crashing the caller's 4-tuple unpacking. Fall back to the
    # highest-potential action instead.
    return torch.argmax(out), weight_trace_d1, weight_trace_d2, Net
def judgeState(bird, pipes, collide):
    """Classify the bird's situation into a discrete RL state.

    States: 0/1 safe inside the gap (above/below its midline), 2-5 away from
    the target line, 6/7 near-collision with the top/bottom pipe, 8 collided.

    :param bird: the Bird sprite (x, y are its top-left corner)
    :param pipes: deque of PipePair sprites, left-most first
    :param collide: whether a collision already happened
    :return: (state, dist, isNew) where dist is the signed vertical distance
        to the target line and isNew flags a freshly spawned pipe.
        NOTE(review): when ``collide`` is True, a bare int is returned instead
        of the 3-tuple; callers in this file only ever pass collide=False.
    """
    # bird's x and y coordinate in the left top of the image
    dist = bird.y + Bird.HEIGHT / 2 - WIN_HEIGHT / 2
    isNew = False
    index = -1
    state = -1
    if collide:
        state = 8
        return state
    for p in pipes:
        # skip pipes that are already fully passed
        if p.x + PipePair.WIDTH - Bird.HEIGHT / 4 < bird.x and not p.score_counted:
            continue
        if p.x - Bird.NEAR_PIPE <= bird.x + Bird.HEIGHT and \
                p.x + PipePair.WIDTH - Bird.HEIGHT / 4 >= bird.x:
            # the bird is horizontally inside this pipe pair's gap
            p_top_y = p.top_y + PipePair.PIECE_HEIGHT
            p_bottom_y = p.bottom_y - PipePair.PIECE_HEIGHT
            if p.center - bird.y - Bird.HEIGHT / 2 >= 0 and bird.y >= p_top_y + Bird.NEAR_COLLIDE / 2:
                state = 0  # safely in the upper half of the gap
            elif bird.y - p.center + Bird.HEIGHT / 2 > 0 and bird.y + Bird.HEIGHT <= p_bottom_y - Bird.NEAR_COLLIDE / 2:
                state = 1  # safely in the lower half of the gap
            elif bird.y < p_top_y + Bird.NEAR_COLLIDE / 2 and bird.y > p_top_y - 10:
                state = 6  # dangerously close to the top pipe
            elif bird.y + Bird.HEIGHT > p_bottom_y - Bird.NEAR_COLLIDE / 2 and bird.y + Bird.HEIGHT < p_bottom_y + 10:
                state = 7  # dangerously close to the bottom pipe
            if state > -0.5:
                index = 1
        elif p.x > bird.x + Bird.HEIGHT + Bird.NEAR_PIPE:
            # pipe still ahead: classify relative to the gap's center line
            state = blankState(bird, p.center)
            if p.isNewPipe:
                isNew = True
                p.isNewPipe = False
            index = 1
        if index > 0:  # only judge the nearest and not passed pipe
            dist = bird.y + Bird.HEIGHT / 2 - p.center
            break
    if index < -0.5:  # no pipe left, keep the bird in the middle
        pos = WIN_HEIGHT / 2
        dist = bird.y + Bird.HEIGHT / 2 - pos
        state = blankState(bird, pos)
    return state, dist, isNew
def blankState(bird, center):
    """Classify the bird's vertical position relative to ``center`` while no
    pipe surrounds it.

    0/1: within the gap's safe band (above/below the center line),
    2/3: just outside the safe band, 4/5: far outside it.
    """
    half_room = (PipePair.ROOM_HIGHT - Bird.HEIGHT) / 2
    inner = half_room - Bird.NEAR_COLLIDE / 2   # edge of the safe band
    outer = inner + Bird.REGION                 # edge of the fluctuation region
    above = center - bird.y - Bird.HEIGHT / 2   # > 0 when the bird is above center
    below = -above
    if 0 <= above < inner:
        state = 0
    elif 0 <= below < inner:
        state = 1
    elif inner <= above < outer:
        state = 2
    elif inner <= below < outer:
        state = 3
    elif above >= outer:
        state = 4
    elif below >= outer:
        state = 5
    return state
def getReward(state, lastState, smallerError, isNewPipe):
    """Return the scalar reward for the transition into ``state``.

    :param state: the new state (0-8) after the action was applied
    :param lastState: the state before the action
    :param smallerError: whether the distance-to-target error shrank
    :param isNewPipe: whether a new pipe just appeared (invalidates the
        lastState comparison)
    :return: the reward value
    :raises ValueError: if ``state`` is not one of the known states 0-8
        (the original code hit an UnboundLocalError instead)
    """
    if state == 0 or state == 1:
        # safely inside the gap
        return 6
    if state == 8:
        # collided
        return -100
    # Per-state penalties: (error grew while staying in the state,
    #                       state just changed or a new pipe appeared)
    penalties = {2: (-5, -3), 3: (-5, -3),
                 4: (-8, -5), 5: (-8, -5),
                 6: (-3, -3), 7: (-3, -3)}
    if state not in penalties:
        # Defensive fix: judgeState()/blankState() should only yield 0-8.
        raise ValueError(f"unknown state: {state}")
    grew, changed = penalties[state]
    if lastState == state and not isNewPipe:
        # small positive reward only when the error is shrinking
        return 3 if smallerError else grew
    return changed
def updateNet(Net, reward, action, state, weight_trace_d1, weight_trace_d2,
              num_state=9, num_action=2):
    """Apply a reward-modulated weight update to the D1/D2 pathways.

    The reward only scales the synapse of the (state, action) pair that was
    taken; all other entries of the reward matrix stay 1.

    :param Net: the decision network exposing ``UpdateWeight``
    :param reward: scalar reward for the taken action
    :param action: index of the action that was taken
    :param state: state in which the action was taken
    :param weight_trace_d1: eligibility trace of the D1 (direct) pathway
    :param weight_trace_d2: eligibility trace of the D2 (indirect) pathway
    :param num_state: number of states (default matches this script's global)
    :param num_action: number of actions (default matches this script's global)
    :return: the (mutated) Net, for call-chaining
    """
    r = torch.ones((num_state, num_state * num_action), dtype=torch.float)
    r[state, state * num_action + action] = reward
    # D1 is potentiated by reward, D2 depressed (opponent pathways).
    dw_d1 = r * weight_trace_d1
    dw_d2 = -1 * r * weight_trace_d2
    Net.UpdateWeight(0, state, num_action, dw_d1)
    Net.UpdateWeight(1, state, num_action, dw_d2)
    return Net
if __name__=="__main__":
    # ---- build the brain-inspired decision-making network ----
    num_state=9
    num_action=2
    weight_exc=1
    weight_inh=-0.5
    trace_decay = 0.8  # eligibility-trace decay read by chooseAct()
    DM = BDMSNN(num_state, num_action, weight_exc, weight_inh, "lif")
    # one excitatory entry per (state, action) synapse; also defines the
    # shape of the eligibility traces
    con_matrix1 = torch.zeros((num_state, num_state * num_action), dtype=torch.float)
    for i in range(num_state):
        for j in range(num_action):
            con_matrix1[i, i * num_action + j] = weight_exc
    weight_trace_d1 = torch.zeros(con_matrix1.shape, dtype=torch.float)
    weight_trace_d2 = torch.zeros(con_matrix1.shape, dtype=torch.float)
    # ---- set up the game scene ----
    pygame.init()
    WIN_HEIGHT = 512
    WIN_WIDTH = 284 * 2  # image size: 284x512 px; tiled twice
    heighest = 0
    iteration=0
    contTime = 0  # number of times to restart
    display_frame=0
    while iteration < 20:  # restart the game for reinforcement learning
        display_surface = pygame.display.set_mode((WIN_WIDTH, WIN_HEIGHT))
        pygame.display.set_caption('Flappy Bird')
        images = load_images()
        bird = Bird(250, int(WIN_HEIGHT / 2 - Bird.HEIGHT / 2), 2,
                    (images['bird-wingup'], images['bird-wingdown']))
        clock = pygame.time.Clock()
        score_font = pygame.font.SysFont(None, 25, bold=True)  # default font
        info_font = pygame.font.SysFont(None, 50, bold=True)
        collide = paused = False
        frame_clock = 0
        pipes = deque()
        score = 0
        lastDist = 0
        lastState = 0  # init
        state = lastState
    # NOTE(review): indentation below was reconstructed from control flow —
    # confirm against the original file.
        while not collide:
            # network input vector (a single state is driven with amplitude 2)
            input = torch.zeros((num_state), dtype=torch.float)
            clock.tick(60)
            # decide a new action every other frame
            if frame_clock %2==0 or frame_clock==1:
                state, dist, isNewPipe = judgeState(bird, pipes, collide)  # judge the bird's state
                lastState = state
                lastDist = dist
                input[state]=2
                action,weight_trace_d1,weight_trace_d2,DM = chooseAct(DM,state,input,weight_trace_d1,weight_trace_d2)
                print("state, dist:", state, dist)
                print("state, action:",state,action)
            if not (paused or frame_clock % (60 * PipePair.ADD_INTERVAL / 1000.0)):
                pygame.event.post(pygame.event.Event(PipePair.ADD_EVENT))
            for e in pygame.event.get():
                if e.type == QUIT or (e.type == KEYUP and e.key == K_ESCAPE):
                    collide = True
                elif e.type == KEYUP and e.key in (K_PAUSE, K_p):
                    paused = not paused
                elif e.type == PipePair.ADD_EVENT:
                    pp = PipePair(images['pipe-end'], images['pipe-body'])
                    pipes.append(pp)
            if paused:
                continue  # don't draw anything
            # check for collisions
            pipe_collision = any(p.collides_with(bird) for p in pipes)
            if pipe_collision or 0 >= bird.y or bird.y >= WIN_HEIGHT - Bird.HEIGHT:
                collide = True
            for x in (0, WIN_WIDTH / 2):
                display_surface.blit(images['background'], (x, 0))
            while pipes and not pipes[0].visible:
                pipes.popleft()
            for p in pipes:
                p.update()
                display_surface.blit(p.image, p.rect)
            bird.update(action,state)
            display_surface.blit(bird.image, bird.rect)
            if frame_clock %2==0 or frame_clock==1 or collide:
                # judge the state and update the value function
                dist = 0
                if collide:
                    nextState = 8
                    isNewPipe = False
                else:
                    nextState, dist, isNewPipe = judgeState(bird, pipes, collide)  # judge the bird's state
                print("next state:", nextState)
                print("lastdist, dist:", lastDist,dist)
                isSmallerError = False
                if state == nextState:
                    isSmallerError = False
                    # the error shrank if the signed distance moved towards zero
                    if lastDist <= 0:
                        if lastDist < dist:
                            isSmallerError = True
                    else:
                        if lastDist > dist:
                            isSmallerError = True
                if frame_clock>0 and not collide:
                    reward = getReward(nextState, state, isSmallerError, isNewPipe)
                    print("reward:", reward)
                    DM=updateNet(DM,reward, action, state,weight_trace_d1,weight_trace_d2)
                state = nextState  # going on the next state
                # clear traces and network state before the next decision
                weight_trace_d1 = torch.zeros(con_matrix1.shape, dtype=torch.float)
                weight_trace_d2 = torch.zeros(con_matrix1.shape, dtype=torch.float)
                DM.reset()
            display_frame += 1
            # update and display score
            for p in pipes:
                if p.x + PipePair.WIDTH < bird.x and not p.score_counted:
                    score += 1
                    p.score_counted = True
            score_surface = score_font.render('Current score: ' + str(score), True, (0, 0, 0))  # current score
            score_x = WIN_WIDTH / 2 - 3 * score_surface.get_width() / 4
            display_surface.blit(score_surface, (score_x, PipePair.PIECE_HEIGHT))
            if heighest < score:
                heighest = score
            score_surface_h = score_font.render('Highest score: ' + str(heighest), True,
                                                (0, 0, 0))  # highest score
            score_x_h = 4 * WIN_WIDTH / 5 - 1.2 * score_surface.get_width() / 3
            display_surface.blit(score_surface_h, (score_x_h, PipePair.PIECE_HEIGHT))
            score_surface_i = score_font.render('Attempts: ' + str(iteration), True, (0, 0, 0))
            score_x_i = 10
            display_surface.blit(score_surface_i, (score_x_i, PipePair.PIECE_HEIGHT))
            frame_clock += 1
            pygame.display.flip()
        # if collide, display the fail information while the bird sinks
        cct = 0
        while (bird.y < WIN_HEIGHT - Bird.HEIGHT - 3):
            clock.tick(60)
            for x in (0, WIN_WIDTH / 2):
                display_surface.blit(images['background'], (x, 0))
            while pipes and not pipes[0].visible:
                pipes.popleft()
            for p in pipes:
                display_surface.blit(p.image, p.rect)
            if cct >= 6:
                bird.sink()
            display_surface.blit(bird.image, bird.rect)
            fail_infor = info_font.render('Game over !', True, (255, 60, 30))
            pos_x = WIN_WIDTH / 2 - fail_infor.get_width() / 2
            pos_y = WIN_HEIGHT / 2 - 100
            display_surface.blit(fail_infor, (pos_x, pos_y))
            # display the score
            score_surface = score_font.render('Current score: ' + str(score), True, (0, 0, 0))  # current score
            score_x = WIN_WIDTH / 2 - 3 * score_surface.get_width() / 4
            display_surface.blit(score_surface, (score_x, PipePair.PIECE_HEIGHT))
            if heighest < score:
                heighest = score
            score_surface_h = score_font.render('Highest score: ' + str(heighest), True,
                                                (0, 0, 0))  # highest score
            score_x_h = 4 * WIN_WIDTH / 5 - 1.2 * score_surface.get_width() / 3
            display_surface.blit(score_surface_h, (score_x_h, PipePair.PIECE_HEIGHT))
            score_surface_i = score_font.render('Attempts: ' + str(iteration), True, (0, 0, 0))
            score_x_i = 10
            display_surface.blit(score_surface_i, (score_x_i, PipePair.PIECE_HEIGHT))
            pygame.display.flip()
            cct += 1
        if heighest < score:
            heighest = score
        contTime += 1
        iteration += 1
================================================
FILE: examples/decision_making/RL/README.md
================================================
# PL-SDQN
This repository contains code from our paper [**Solving the Spike Feature Information Vanishing Problem in Spiking Deep Q Network with Potential Based Normalization**]. If you use our code or refer to this project, please cite this paper.
To run the PL-SDQN model, please first install the 'tianshou' framework: https://github.com/thu-ml/tianshou
## Requirements
* numpy
* scipy
* pytorch >= 1.7.0
* torchvision
* gym
* atari-py
* opencv-python
* tianshou
## Train
```shell
python ./sdqn/main.py
```
## Citation
If you find this package helpful, please consider citing the following papers:
```BibTex
@ARTICLE{sun2022,
AUTHOR={Sun, Yinqian and Zeng, Yi and Li, Yang},
TITLE={Solving the spike feature information vanishing problem in spiking deep Q network with potential based normalization},
JOURNAL={Frontiers in Neuroscience},
VOLUME={16},
YEAR={2022},
URL={https://www.frontiersin.org/articles/10.3389/fnins.2022.953368},
DOI={10.3389/fnins.2022.953368},
ISSN={1662-453X},
}
@misc{https://doi.org/10.48550/arxiv.2207.08533,
doi = {10.48550/ARXIV.2207.08533},
url = {https://arxiv.org/abs/2207.08533},
author = {Zeng, Yi and Zhao, Dongcheng and Zhao, Feifei and Shen, Guobin and Dong, Yiting and Lu, Enmeng and Zhang, Qian and Sun, Yinqian and Liang, Qian and Zhao, Yuxuan and Zhao, Zhuoya and Fang, Hongjian and Wang, Yuwei and Li, Yang and Liu, Xin and Du, Chengcheng and Kong, Qingqun and Ruan, Zizhe and Bi, Weida},
title = {BrainCog: A Spiking Neural Network based Brain-inspired Cognitive Intelligence Engine for Brain-inspired AI and Brain Simulation},
publisher = {arXiv},
year = {2022},
}
```
================================================
FILE: examples/decision_making/RL/atari/__init__.py
================================================
================================================
FILE: examples/decision_making/RL/atari/atari_wrapper.py
================================================
# Borrow a lot from openai baselines:
# https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
from collections import deque
import cv2
import gym
import numpy as np
class NoopResetEnv(gym.Wrapper):
    """Sample initial states by taking random number of no-ops on reset.
    No-op is assumed to be action 0.

    :param gym.Env env: the environment to wrap.
    :param int noop_max: the maximum value of no-ops to run.
    """
    def __init__(self, env, noop_max=30):
        super().__init__(env)
        self.noop_max = noop_max
        self.noop_action = 0
        assert env.unwrapped.get_action_meanings()[0] == 'NOOP'
    def reset(self):
        """Reset and then take 1..noop_max no-op steps to randomize the start."""
        self.env.reset()
        noops = self.unwrapped.np_random.randint(1, self.noop_max + 1)
        for _ in range(noops):
            obs, _, done, _ = self.env.step(self.noop_action)
            # if the episode ends during the no-ops, start over
            if done:
                obs = self.env.reset()
        return obs
class MaxAndSkipEnv(gym.Wrapper):
    """Return only every `skip`-th frame (frameskipping) using most recent raw
    observations (for max pooling across time steps)

    :param gym.Env env: the environment to wrap.
    :param int skip: number of `skip`-th frame.
    """
    def __init__(self, env, skip=4):
        super().__init__(env)
        self._skip = skip
    def step(self, action):
        """Step the environment with the given action. Repeat action, sum
        reward, and max over last observations.
        """
        obs_list, total_reward, done = [], 0., False
        for _ in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            obs_list.append(obs)
            total_reward += reward
            if done:
                break
        # pixel-wise max over the last two raw frames — removes the flicker
        # of sprites that Atari renders only on alternating frames
        max_frame = np.max(obs_list[-2:], axis=0)
        return max_frame, total_reward, done, info
class EpisodicLifeEnv(gym.Wrapper):
    """Make end-of-life == end-of-episode, but only reset on true game over. It
    helps the value estimation.

    :param gym.Env env: the environment to wrap.
    """
    def __init__(self, env):
        super().__init__(env)
        self.lives = 0           # lives remaining as of the previous step
        self.was_real_done = True  # whether the last `done` was a true game over
    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        self.was_real_done = done
        # check current lives, make loss of life terminal, then update lives to
        # handle bonus lives
        lives = self.env.unwrapped.ale.lives()
        if 0 < lives < self.lives:
            # for Qbert sometimes we stay in lives == 0 condition for a few
            # frames, so its important to keep lives > 0, so that we only reset
            # once the environment is actually done.
            done = True
        self.lives = lives
        return obs, reward, done, info
    def reset(self):
        """Calls the Gym environment reset, only when lives are exhausted. This
        way all states are still reachable even though lives are episodic, and
        the learner need not know about any of this behind-the-scenes.
        """
        if self.was_real_done:
            obs = self.env.reset()
        else:
            # no-op step to advance from terminal/lost life state
            obs = self.env.step(0)[0]
        self.lives = self.env.unwrapped.ale.lives()
        return obs
class FireResetEnv(gym.Wrapper):
    """Press FIRE on reset, for environments that stay frozen until firing.
    Related discussion: https://github.com/openai/baselines/issues/240

    :param gym.Env env: the environment to wrap.
    """
    def __init__(self, env):
        super().__init__(env)
        meanings = env.unwrapped.get_action_meanings()
        assert meanings[1] == 'FIRE'
        assert len(meanings) >= 3
    def reset(self):
        """Reset, fire once, and return the resulting observation."""
        self.env.reset()
        obs, _, _, _ = self.env.step(1)
        return obs
class WarpFrame(gym.ObservationWrapper):
    """Warp frames to 84x84 as done in the Nature paper and later work.

    :param gym.Env env: the environment to wrap.
    """
    def __init__(self, env):
        super().__init__(env)
        self.size = 84  # target side length (DeepMind Nature-DQN convention)
        self.observation_space = gym.spaces.Box(
            low=np.min(env.observation_space.low),
            high=np.max(env.observation_space.high),
            shape=(self.size, self.size),
            dtype=env.observation_space.dtype
        )
    def observation(self, frame):
        """returns the current observation from a frame"""
        # RGB -> grayscale, then downscale to size x size
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        return cv2.resize(frame, (self.size, self.size), interpolation=cv2.INTER_AREA)
class ScaledFloatFrame(gym.ObservationWrapper):
    """Normalize observations to 0~1.

    :param gym.Env env: the environment to wrap.
    """
    def __init__(self, env):
        super().__init__(env)
        low = np.min(env.observation_space.low)
        high = np.max(env.observation_space.high)
        self.bias = low
        # NOTE(review): assumes high > low; a zero-range space would divide
        # by zero in observation() — confirm for the envs this wraps.
        self.scale = high - low
        self.observation_space = gym.spaces.Box(
            low=0., high=1., shape=env.observation_space.shape, dtype=np.float32
        )
    def observation(self, observation):
        # linear rescale of the raw observation into [0, 1]
        return (observation - self.bias) / self.scale
class ClipRewardEnv(gym.RewardWrapper):
    """clips the reward to {+1, 0, -1} by its sign.

    :param gym.Env env: the environment to wrap.
    """
    def __init__(self, env):
        super().__init__(env)
        self.reward_range = (-1, 1)
    def reward(self, reward):
        """Bin reward to {+1, 0, -1} by its sign. Note: np.sign(0) == 0."""
        return np.sign(reward)
class FrameStack(gym.Wrapper):
    """Stack n_frames last frames.

    :param gym.Env env: the environment to wrap.
    :param int n_frames: the number of frames to stack.
    """
    def __init__(self, env, n_frames):
        super().__init__(env)
        self.n_frames = n_frames
        # deque automatically drops the oldest frame once full
        self.frames = deque([], maxlen=n_frames)
        shape = (n_frames, ) + env.observation_space.shape
        self.observation_space = gym.spaces.Box(
            low=np.min(env.observation_space.low),
            high=np.max(env.observation_space.high),
            shape=shape,
            dtype=env.observation_space.dtype
        )
    def reset(self):
        # fill the stack with copies of the first observation
        obs = self.env.reset()
        for _ in range(self.n_frames):
            self.frames.append(obs)
        return self._get_ob()
    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        self.frames.append(obs)
        return self._get_ob(), reward, done, info
    def _get_ob(self):
        # the original wrapper use `LazyFrames` but since we use np buffer,
        # it has no effect
        return np.stack(self.frames, axis=0)
def wrap_deepmind(
    env_id,
    episode_life=True,
    clip_rewards=True,
    frame_stack=4,
    scale=False,
    warp_frame=True
):
    """Configure environment for DeepMind-style Atari. The observation is
    channel-first: (c, h, w) instead of (h, w, c).

    :param str env_id: the atari environment id.
    :param bool episode_life: wrap the episode life wrapper.
    :param bool clip_rewards: wrap the reward clipping wrapper.
    :param int frame_stack: wrap the frame stacking wrapper.
    :param bool scale: wrap the scaling observation wrapper.
    :param bool warp_frame: wrap the grayscale + resize observation wrapper.
    :return: the wrapped atari environment.
    """
    # raw frameskip must be disabled because MaxAndSkipEnv does its own
    assert 'NoFrameskip' in env_id
    env = gym.make(env_id)
    env = NoopResetEnv(env, noop_max=30)
    env = MaxAndSkipEnv(env, skip=4)
    if episode_life:
        env = EpisodicLifeEnv(env)
    if 'FIRE' in env.unwrapped.get_action_meanings():
        env = FireResetEnv(env)
    if warp_frame:
        env = WarpFrame(env)
    if scale:
        env = ScaledFloatFrame(env)
    if clip_rewards:
        env = ClipRewardEnv(env)
    if frame_stack:
        env = FrameStack(env, frame_stack)
    return env
================================================
FILE: examples/decision_making/RL/mcs-fqf/discrete.py
================================================
from audioop import bias
from time import time
from typing import Any, Optional, Sequence, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from tianshou.data import Batch
from braincog.base.node.node import LIFNode, ThreeCompNode
class SpikePopEncodingNetwork(nn.Module):
    """Population-coding encoder: converts scalars in [0, 1] to spike-train
    embeddings.

    Each scalar tau is projected onto ``num_cosines`` Gaussian tuning curves
    (centres ``mus``, width ``sigma``, peak rate ``r_max``), the resulting
    rates are Poisson-sampled for ``time_window`` steps, and each step is
    passed through a linear layer followed by a LIF node.

    :param num_cosines: the number of tuning curves per tau (name kept from
        the cosine-embedding interface this class replaces).
    :param embedding_dim: the dimension of the embedding/output.
    .. note::
        Interface adapted from https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/blob/master
        /fqf_iqn_qrdqn/network.py ; the cosine embedding is replaced by
        spiking population coding here.
    """
    def __init__(self, num_cosines: int, embedding_dim: int, device, time_window: int=8) -> None:
        # NOTE(review): the `time_window` constructor argument is unused —
        # forward() takes its own time_window parameter.
        super().__init__()
        self._threshold = 0.5
        # self._decay = 0.2
        self._decay = 0.5
        self.r_max = 0.5  # peak firing rate of a tuning curve
        # self.mus = torch.from_numpy(np.arange(num_cosines) / num_cosines)
        self.sigma = 0.05  # width of the Gaussian tuning curves
        # self.sigma = 0.01
        self._node = LIFNode
        self.net = nn.Sequential(
            nn.Linear(num_cosines, embedding_dim),
            self._node()
            # self._node(threshold=self._threshold, decay=self._decay)
        )
        self.num_cosines = num_cosines
        self.embedding_dim = embedding_dim
        # evenly spaced tuning-curve centres over [0, 1)
        self.mus = torch.arange(0, num_cosines, device=device).view(1, 1, self.num_cosines) / num_cosines
    def reset(self):
        """Reset the internal state of every spiking node in this module."""
        for mod in self.modules():
            if hasattr(mod, 'n_reset'):
                mod.n_reset()
    def forward(self, taus: torch.Tensor, time_window: int) -> torch.Tensor:
        """Encode ``taus`` (batch, N) into a list of ``time_window`` spike embeddings."""
        batch_size = taus.shape[0]
        N = taus.shape[1]
        self.reset()
        # Gaussian tuning-curve response of every (tau, centre) pair
        taus_lam = self.r_max * torch.exp(-(taus.unsqueeze(-1) - self.mus)**2/2/self.sigma**2).view(batch_size*N, self.num_cosines)
        taus_repeat = taus_lam.unsqueeze(0).repeat(time_window, 1, 1)
        # Poisson-sample spike counts from the rates, independently per step
        taus_emb = torch.poisson(taus_repeat)
        tau_embeddings = []
        for i in range(time_window):
            t_e = self.net(taus_emb[i])
            tau_embeddings.append(t_e)
        return tau_embeddings
class SpikeFractionProposalNetwork(nn.Module):
    """Fraction proposal network for FQF.

    :param num_fractions: the number of factions to propose.
    :param embedding_dim: the dimension of the embedding/input.
    .. note::
        Adapted from https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/blob/master
        /fqf_iqn_qrdqn/network.py .
    """
    def __init__(self, num_fractions: int, embedding_dim: int) -> None:
        super().__init__()
        self.net = nn.Linear(embedding_dim, num_fractions)
        # small initial gain so the proposed fractions start near-uniform
        torch.nn.init.xavier_uniform_(self.net.weight, gain=0.01)
        torch.nn.init.constant_(self.net.bias, 0)
        self.num_fractions = num_fractions
        self.embedding_dim = embedding_dim
    def forward(
        self, state_embeddings: list
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Propose quantile fractions from per-step state embeddings.

        :param state_embeddings: list of (batch, embedding_dim) tensors, one
            per simulation step; gradients are detached.
        :return: (taus, tau_hats, entropies) — taus are monotone fractions in
            [0, 1] with taus[:, 0] == 0, tau_hats are the interval midpoints.
        """
        state_embeddings = torch.stack(state_embeddings).detach()
        time_window = state_embeddings.shape[0]
        batch_size = state_embeddings.shape[1]
        logits = self.net(state_embeddings.view(time_window*batch_size, -1))
        logits = logits.view(time_window, batch_size, -1)
        # average logits over the time window before building the distribution
        m = torch.distributions.Categorical(logits=logits.mean(0))
        # cumulative probabilities give a monotone sequence ending at 1
        taus_1_N = torch.cumsum(m.probs, dim=1)
        taus = F.pad(taus_1_N, (1, 0))  # prepend tau_0 = 0
        tau_hats = (taus[:, :-1] + taus[:, 1:]).detach() / 2.0
        entropies = m.entropy()
        return taus, tau_hats, entropies
class MCQuantiles(nn.Module):
    """Multi-compartment quantile head.

    Fuses state embeddings (basal input) and fraction embeddings (apical
    input) through a three-compartment spiking node, then maps the fused
    activity to per-action quantile values averaged over the time window.
    """
    def __init__(self,
                 state_embedings_shape: int,
                 tau_embeddings_shape: int,
                 hidden_size: int,
                 last_size: int,
                 fusion_size: int=512,
                 tau_s: float = 2.0):
        # NOTE(review): `tau_s` is accepted but never used below — confirm
        # whether it was meant to parameterize the node's time constant.
        super().__init__()
        self.basal_w = nn.Linear(state_embedings_shape, fusion_size, bias=False)
        self.apical_w = nn.Linear(tau_embeddings_shape, fusion_size, bias=False)
        self._node = LIFNode
        self.mc_node = ThreeCompNode()
        self._last = nn.Sequential(
            nn.Linear(fusion_size, hidden_size),
            self._node(),
            nn.Linear(hidden_size, last_size),
        )
    def reset(self):
        """Reset the internal state of every spiking node in this module."""
        for mod in self.modules():
            if hasattr(mod, 'n_reset'):
                mod.n_reset()
    def forward(self, state_embedding, tau_embedding):
        """
        state_embedding: list
        tau_embedding: torch.Tensor

        :return: quantiles of shape (batch, last_size, sample_size)
        """
        self.reset()
        assert isinstance(state_embedding, type(tau_embedding))
        if isinstance(state_embedding, list):
            time_window = len(state_embedding)
        elif isinstance(state_embedding, torch.Tensor):
            time_window = state_embedding.shape[0]
        else:
            raise TypeError('Not support data type.')
        batch_size = state_embedding[0].shape[0]
        sample_size = tau_embedding[0].shape[0] // batch_size
        quantiles = []
        for step in range(time_window):
            # basal = state pathway, apical = fraction pathway; broadcast the
            # basal input over the sample dimension
            basal_psp = self.basal_w(state_embedding[step]).unsqueeze(1)
            apical_psp = self.apical_w(tau_embedding[step]).view(batch_size, sample_size, -1)
            embeddings = self.mc_node({'basal_inputs': basal_psp, 'apical_inputs': apical_psp}).view(batch_size*sample_size, -1)
            out = self._last(embeddings)
            quantiles.append(out)
        # rate decoding: average the per-step outputs over the time window
        quantiles = sum(quantiles) / time_window
        return quantiles.view(batch_size, sample_size, -1).transpose(1, 2)
class SpikeFullQuantileFunction(nn.Module):
    """Full(y parameterized) Quantile Function.

    :param preprocess_net: a self-defined preprocess_net which output a
        flattened hidden state.
    :param int action_dim: the dimension of action space.
    :param hidden_sizes: a sequence of int for constructing the MLP after
        preprocess_net. Default to empty sequence (where the MLP now contains
        only a single linear layer).
    :param int num_cosines: the number of cosines to use for cosine embedding.
        Default to 64.
    :param int preprocess_net_output_dim: the output dimension of
        preprocess_net.
    .. note::
        The first return value is a tuple of (quantiles, fractions, quantiles_tau),
        where fractions is a Batch(taus, tau_hats, entropies).
    """
    def __init__(
        self,
        preprocess_net: nn.Module,
        action_shape: Sequence[int],
        hidden_sizes: Sequence[int] = (),
        num_cosines: int = 64,
        preprocess_net_output_dim: Optional[int] = None,
        device: Union[str, int, torch.device] = "cpu",
    ) -> None:
        super().__init__()
        self.device = device
        self.last_size = np.prod(action_shape)
        self.preprocess = preprocess_net
        self.input_dim = getattr(
            self.preprocess, "output_dim", preprocess_net_output_dim
        )
        self.embed_model = SpikePopEncodingNetwork(num_cosines,
                                                   self.input_dim, device=device).to(device)
        # NOTE(review): MCQuantiles receives last_size=action_shape while
        # self.last_size uses np.prod(action_shape) — confirm action_shape is
        # scalar-like here, otherwise nn.Linear would reject it.
        self.mcquantiles = MCQuantiles(self.input_dim, self.input_dim, hidden_size=np.prod(hidden_sizes),
                                       last_size=action_shape).to(device)
    def forward(  # type: ignore
        self, s: Union[np.ndarray, torch.Tensor],
        propose_model: SpikeFractionProposalNetwork,
        fractions: Optional[Batch] = None,
        **kwargs: Any
    ) -> Tuple[Any, torch.Tensor]:
        r"""Mapping: s -> Q(s, \*)."""
        logits, h = self.preprocess(s, state=kwargs.get("state", None))
        # Propose fractions
        if fractions is None:
            taus, tau_hats, entropies = propose_model(logits)
            fractions = Batch(taus=taus, tau_hats=tau_hats, entropies=entropies)
        else:
            taus, tau_hats = fractions.taus, fractions.tau_hats
        time_window = len(logits)
        # quantile values at the interval midpoints
        tau_hats_emb = self.embed_model(tau_hats, time_window)
        quantiles = self.mcquantiles(logits, tau_hats_emb)
        quantiles_tau = None
        if self.training:
            # extra quantiles at the inner fractions, used for the fraction
            # loss during training only (no gradient through the embedding)
            with torch.no_grad():
                tau_emb = self.embed_model(taus[:, 1:-1], time_window)
                quantiles_tau = self.mcquantiles(logits, tau_emb)
        return (quantiles, fractions, quantiles_tau), h
================================================
FILE: examples/decision_making/RL/mcs-fqf/main.py
================================================
import argparse
import os
import pprint
import numpy as np
import torch
from network import SpikingDQN
from ..atari.atari_wrapper import wrap_deepmind
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import ShmemVectorEnv
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger, SequenceLogger
from discrete import SpikeFractionProposalNetwork, SpikeFullQuantileFunction
from policy import FQFPolicy
def get_args():
    """Parse command-line arguments for MCS-FQF Atari training.

    Defaults follow the FQF Atari setup: epsilon-greedy exploration with
    linear decay, 32 quantile fractions, and a separate (much smaller)
    learning rate for the fraction-proposal network.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='MsPacmanNoFrameskip-v4')
    parser.add_argument('--seed', type=int, default=3128)
    # Epsilon-greedy schedule: start at eps-train, decay to eps-train-final.
    parser.add_argument('--eps-test', type=float, default=0.005)
    parser.add_argument('--eps-train', type=float, default=1.)
    parser.add_argument('--eps-train-final', type=float, default=0.05)
    parser.add_argument('--buffer-size', type=int, default=100000)
    parser.add_argument('--lr', type=float, default=1e-4)
    # Learning rate for the fraction-proposal network (RMSprop).
    parser.add_argument('--fraction-lr', type=float, default=2.5e-9)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--num-fractions', type=int, default=32)
    parser.add_argument('--num-cosines', type=int, default=64)
    parser.add_argument('--ent-coef', type=float, default=10.)
    parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[512])
    parser.add_argument('--n-step', type=int, default=3)
    parser.add_argument('--target-update-freq', type=int, default=500)
    parser.add_argument('--epoch', type=int, default=200)
    parser.add_argument('--step-per-epoch', type=int, default=100000)
    parser.add_argument('--step-per-collect', type=int, default=10)
    parser.add_argument('--update-per-step', type=float, default=0.1)
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--training-num', type=int, default=10)
    parser.add_argument('--test-num', type=int, default=10)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.)
    parser.add_argument(
        '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
    )
    parser.add_argument('--frames-stack', type=int, default=4)
    parser.add_argument('--resume-path', type=str, default=None)
    parser.add_argument('--resume-id', type=str, default=None)
    parser.add_argument(
        '--watch',
        default=False,
        action='store_true',
        help='watch the play of pre-trained policy only'
    )
    parser.add_argument('--save-buffer-name', type=str, default=None)
    # Number of simulation timesteps of the spiking net per forward pass.
    parser.add_argument('--time-window', type=int, default=8)
    parser.add_argument('--prefix', type=str, default='')
    parser.add_argument('--save-interval', type=int, default=10)
    return parser.parse_args()
def make_atari_env(args):
    """Build a DeepMind-wrapped Atari environment for training."""
    env = wrap_deepmind(args.task, frame_stack=args.frames_stack)
    return env
def make_atari_env_watch(args):
    """Build an evaluation env: full episodes, raw (unclipped) rewards."""
    return wrap_deepmind(
        args.task,
        episode_life=False,
        clip_rewards=False,
        frame_stack=args.frames_stack,
    )
def main(args=get_args()):
    """Train (or evaluate) an MCS-FQF agent on an Atari task.

    Pipeline: build vectorized train/test envs, assemble the spiking
    feature extractor + fully-parameterized quantile function + fraction
    proposal network, wrap them in :class:`FQFPolicy`, then run tianshou's
    off-policy trainer (or just watch a pre-trained policy).
    """
    print('Setting: ', args)
    env = make_atari_env(args)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # should be N_FRAMES x H x W
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    print('update_per_step: ', args.update_per_step)
    print('lr: ', args.lr)
    # make environments
    train_envs = ShmemVectorEnv(
        [lambda: make_atari_env(args) for _ in range(args.training_num)]
    )
    test_envs = ShmemVectorEnv(
        [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
    )
    # Fix: `--seed` was parsed but never applied; seed every RNG reachable
    # through the modules already imported here (mirrors sdqn/main.py).
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # define model
    feature_net = SpikingDQN(
        *args.state_shape, args.action_shape, args.device, time_window=args.time_window, features_only=True
    )
    net = SpikeFullQuantileFunction(
        feature_net,
        args.action_shape,
        args.hidden_sizes,
        args.num_cosines,
        device=args.device,
    ).to(args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    fraction_net = SpikeFractionProposalNetwork(args.num_fractions, net.input_dim)
    fraction_optim = torch.optim.RMSprop(
        fraction_net.parameters(), lr=args.fraction_lr
    )
    # define policy
    policy = FQFPolicy(
        net,
        optim,
        fraction_net,
        fraction_optim,
        args.gamma,
        args.num_fractions,
        args.ent_coef,
        args.n_step,
        target_update_freq=args.target_update_freq
    ).to(args.device)
    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)
    # replay buffer: `save_last_obs` and `stack_num` can be removed together
    # when you have enough RAM
    buffer = VectorReplayBuffer(
        args.buffer_size,
        buffer_num=len(train_envs),
        ignore_obs_next=True,
        save_only_last_obs=True,
        stack_num=args.frames_stack
    )
    # collector
    train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
    test_collector = Collector(policy, test_envs, exploration_noise=True)
    # log
    log_path = os.path.join(args.logdir, args.task, 'spike_fqf', args.prefix)
    model_log_path = os.path.join(log_path, 'models')
    if not os.path.exists(model_log_path):
        os.makedirs(model_log_path)
    print('log_path: ', log_path)
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    logger = TensorboardLogger(writer, save_interval=args.save_interval)
    result_logger = SequenceLogger(log_path)

    def save_checkpoint_fn(epoch, env_step, gradient_step, epoch_round=True):
        # Persist a resumable checkpoint; per-epoch file when epoch_round.
        # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
        if epoch_round:
            ckpt_path = os.path.join(model_log_path, 'checkpoint_epoch{}.pth'.format(epoch))
        else:
            ckpt_path = os.path.join(model_log_path, 'checkpoint.pth')
        ckpt = {
            'epoch': epoch,
            'env_step': env_step,
            'gradient_step': gradient_step,
            'model': policy.state_dict()
        }
        torch.save(ckpt, ckpt_path)
        return ckpt_path

    # Dump the full run configuration next to the logs.
    setting_path = os.path.join(log_path, 'settings.txt')
    argsDict = args.__dict__
    with open(setting_path, 'w') as f:
        f.writelines('------------------ start ------------------' + '\n')
        for eachArg, value in argsDict.items():
            f.writelines(eachArg + ' : ' + str(value) + '\n')
        f.writelines('------------------- end -------------------')

    def save_fn(policy, is_best=False):
        # Save the current (or best-so-far) policy weights.
        if is_best:
            torch.save(policy.state_dict(), os.path.join(log_path, 'best_policy.pth'))
        else:
            torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        # Stop when the env's reward threshold (or 20 on Pong) is reached.
        if env.spec.reward_threshold:
            return mean_rewards >= env.spec.reward_threshold
        elif 'Pong' in args.task:
            return mean_rewards >= 20
        else:
            return False

    def train_fn(epoch, env_step):
        # nature DQN setting, linear decay in the first 1M steps
        if env_step <= 1e6:
            eps = args.eps_train - env_step / 1e6 * \
                (args.eps_train - args.eps_train_final)
        else:
            eps = args.eps_train_final
        policy.set_eps(eps)
        if env_step % 1000 == 0:
            logger.write("train/env_step", env_step, {"train/eps": eps})

    def test_fn(epoch, env_step):
        # Evaluate with a small fixed exploration rate.
        policy.set_eps(args.eps_test)

    # watch agent's performance
    def watch():
        print("Setup test envs ...")
        policy.eval()
        policy.set_eps(args.eps_test)
        test_envs.seed(args.seed)
        if args.save_buffer_name:
            print(f"Generate buffer with size {args.buffer_size}")
            buffer = VectorReplayBuffer(
                args.buffer_size,
                buffer_num=len(test_envs),
                ignore_obs_next=True,
                save_only_last_obs=True,
                stack_num=args.frames_stack
            )
            collector = Collector(policy, test_envs, buffer, exploration_noise=True)
            result = collector.collect(n_step=args.buffer_size)
            print(f"Save buffer into {args.save_buffer_name}")
            buffer.save_hdf5(args.save_buffer_name)
        else:
            print("Testing agent ...")
            test_collector.reset()
            result = test_collector.collect(
                n_episode=args.test_num, render=args.render
            )
        rew = result["rews"].mean()
        print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')

    if args.watch:
        watch()
        exit(0)
    # test train_collector and start filling replay buffer
    train_collector.collect(n_step=args.batch_size * args.training_num)
    # trainer
    result = offpolicy_trainer(
        policy,
        train_collector,
        test_collector,
        args.epoch,
        args.step_per_epoch,
        args.step_per_collect,
        args.test_num,
        args.batch_size,
        train_fn=train_fn,
        test_fn=test_fn,
        stop_fn=stop_fn,
        save_fn=save_fn,
        logger=logger,
        update_per_step=args.update_per_step,
        test_in_train=False,
        resume_from_log=args.resume_id is not None,
        save_checkpoint_fn=save_checkpoint_fn,
        result_logger=result_logger
    )
    pprint.pprint(result)
    watch()
# Script entry point: parse CLI args and launch MCS-FQF training.
if __name__ == '__main__':
    main(get_args())
================================================
FILE: examples/decision_making/RL/mcs-fqf/network.py
================================================
from typing import Any, Dict, Optional, Sequence, Tuple, Union
import numpy as np
import torch
from torch import nn
from braincog.base.node.node import LIFNode
from ..utils.normalization import PopNorm
class SpikingDQN(nn.Module):
    """Spiking feature/Q network with the Nature-DQN topology.

    Reference: Human-level control through deep reinforcement learning.

    Each convolution is followed by population normalization (``PopNorm``)
    and a LIF spiking node. The stateful network is simulated for
    ``time_window`` steps on the same input and the per-step outputs are
    rate-averaged into Q-values.

    :param c: number of input channels (stacked frames).
    :param h: input height.
    :param w: input width.
    :param action_shape: shape of the (discrete) action space.
    :param device: torch device inputs are moved to.
    :param time_window: number of simulation timesteps per forward pass.
    :param features_only: if True, ``forward`` returns the per-step feature
        list instead of averaged Q-values.
    """

    def __init__(
        self,
        c: int,
        h: int,
        w: int,
        action_shape: Sequence[int],
        device: Union[str, int, torch.device] = "cpu",
        time_window: int = 16,
        features_only: bool = False,
    ) -> None:
        super().__init__()
        self._node = LIFNode
        self.features_only = features_only
        self.device = device
        self._threshold = 1.0
        self.v_reset = 0.0
        # NOTE(review): `_decay` is kept for interface compatibility but is
        # not passed to the LIF nodes below.
        self._decay = 0.5
        self._time_window = time_window
        # Fix: removed the dead `init_layer` lambda, which referenced an
        # undefined name `init` (NameError if it were ever invoked).
        self.p_count = 0
        self.net = nn.Sequential(
            nn.Conv2d(c, 32, kernel_size=8, stride=4),
            PopNorm([32, 20, 20], threshold=self._threshold, v_reset=self.v_reset),
            self._node(threshold=self._threshold, v_reset=self.v_reset),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            PopNorm([64, 9, 9], threshold=self._threshold, v_reset=self.v_reset),
            self._node(threshold=self._threshold, v_reset=self.v_reset),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            PopNorm([64, 7, 7], threshold=self._threshold, v_reset=self.v_reset),
            self._node(threshold=self._threshold, v_reset=self.v_reset),
            nn.Flatten()
        )
        # Probe the conv stack once to determine the flattened feature size.
        with torch.no_grad():
            self.output_dim = np.prod(self.net(torch.zeros(1, c, h, w)).shape[1:])
        if not features_only:
            # Append the Q-value head when used as a standalone Q network.
            self.net = nn.Sequential(
                self.net, nn.Linear(self.output_dim, 512),
                self._node(threshold=self._threshold, v_reset=self.v_reset),
                nn.Linear(512, np.prod(action_shape), bias=False)
            )
            self.output_dim = np.prod(action_shape)

    def reset(self):
        """Reset membrane state of all spiking nodes via their ``n_reset`` hook."""
        for mod in self.modules():
            if hasattr(mod, 'n_reset'):
                mod.n_reset()

    def forward(
        self,
        x: Union[np.ndarray, torch.Tensor],
        state: Optional[Any] = None,
        info: Dict[str, Any] = {},
    ) -> Tuple[torch.Tensor, Any]:
        r"""Mapping: x -> Q(x, \*).

        ``info`` is accepted for tianshou API compatibility and unused.
        """
        self.reset()
        # Scale uint8 frames to [0, 1].
        x = torch.as_tensor(x, device=self.device, dtype=torch.float32) / 255.0
        qs = []
        # Same input every step; the LIF nodes carry membrane state across
        # iterations, so outputs differ per step.
        for _ in range(self._time_window):
            qs.append(self.net(x))
        if self.features_only:
            return qs, state
        # Rate decoding: average over the time window.
        q_values = sum(qs) / self._time_window
        return q_values, state
================================================
FILE: examples/decision_making/RL/mcs-fqf/policy.py
================================================
from typing import Any, Dict, Optional, Union
import numpy as np
import torch
import torch.nn.functional as F
from tianshou.data import Batch, ReplayBuffer, to_numpy
from tianshou.policy import DQNPolicy, QRDQNPolicy
from discrete import SpikeFractionProposalNetwork, SpikeFullQuantileFunction
class FQFPolicy(QRDQNPolicy):
    """Implementation of Fully-parameterized Quantile Function. arXiv:1911.02140.

    :param torch.nn.Module model: a model following the rules in
        :class:`~tianshou.policy.BasePolicy`. (s -> logits)
    :param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
    :param FractionProposalNetwork fraction_model: a FractionProposalNetwork for
        proposing fractions/quantiles given state.
    :param torch.optim.Optimizer fraction_optim: a torch.optim for optimizing
        the fraction model above.
    :param float discount_factor: in [0, 1].
    :param int num_fractions: the number of fractions to use. Default to 32.
    :param float ent_coef: the coefficient for entropy loss. Default to 0.
    :param int estimation_step: the number of steps to look ahead. Default to 1.
    :param int target_update_freq: the target network update frequency (0 if
        you do not use the target network).
    :param bool reward_normalization: normalize the reward to Normal(0, 1).
        Default to False.

    .. seealso::

        Please refer to :class:`~tianshou.policy.QRDQNPolicy` for more detailed
        explanation.
    """

    def __init__(
        self,
        model: SpikeFullQuantileFunction,
        optim: torch.optim.Optimizer,
        fraction_model: SpikeFractionProposalNetwork,
        fraction_optim: torch.optim.Optimizer,
        discount_factor: float = 0.99,
        num_fractions: int = 32,
        ent_coef: float = 0.0,
        estimation_step: int = 1,
        target_update_freq: int = 0,
        reward_normalization: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            model, optim, discount_factor, num_fractions, estimation_step,
            target_update_freq, reward_normalization, **kwargs
        )
        # Fraction-proposal network and its dedicated optimizer are trained
        # separately from the quantile network (see `learn`).
        self.propose_model = fraction_model
        self._ent_coef = ent_coef
        self._fraction_optim = fraction_optim

    def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
        """Compute the target quantile distribution for n-step returns."""
        batch = buffer[indices]  # batch.obs_next: s_{t+n}
        if self._target:
            # Select actions with the online net, evaluate their quantiles
            # with the target net, reusing the same proposed fractions.
            result = self(batch, input="obs_next")
            a, fractions = result.act, result.fractions
            next_dist = self(
                batch, model="model_old", input="obs_next", fractions=fractions
            ).logits
        else:
            next_b = self(batch, input="obs_next")
            a = next_b.act
            next_dist = next_b.logits
        next_dist = next_dist[np.arange(len(a)), a, :]
        return next_dist  # shape: [bsz, num_quantiles]

    def forward(
        self,
        batch: Batch,
        state: Optional[Union[dict, Batch, np.ndarray]] = None,
        model: str = "model",
        input: str = "obs",
        fractions: Optional[Batch] = None,
        **kwargs: Any,
    ) -> Batch:
        """Compute per-action quantiles and greedy actions for a batch."""
        model = getattr(self, model)
        obs = batch[input]
        obs_ = obs.obs if hasattr(obs, "obs") else obs
        if fractions is None:
            (logits, fractions, quantiles_tau), h = model(
                obs_, propose_model=self.propose_model, state=state, info=batch.info
            )
        else:
            (logits, _, quantiles_tau), h = model(
                obs_,
                propose_model=self.propose_model,
                fractions=fractions,
                state=state,
                info=batch.info
            )
        # Expected Q-value: integrate quantiles over the proposed fractions,
        # weighting each quantile by its fraction interval width.
        weighted_logits = (fractions.taus[:, 1:] -
                           fractions.taus[:, :-1]).unsqueeze(1) * logits
        q = DQNPolicy.compute_q_value(
            self, weighted_logits.sum(2), getattr(obs, "mask", None)
        )
        if not hasattr(self, "max_action_num"):
            self.max_action_num = q.shape[1]
        act = to_numpy(q.max(dim=1)[1])
        return Batch(
            logits=logits,
            act=act,
            state=h,
            fractions=fractions,
            quantiles_tau=quantiles_tau
        )

    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        """One gradient step on the quantile, fraction and entropy losses."""
        if self._target and self._iter % self._freq == 0:
            self.sync_weight()
        weight = batch.pop("weight", 1.0)
        out = self(batch)
        curr_dist_orig = out.logits
        taus, tau_hats = out.fractions.taus, out.fractions.tau_hats
        act = batch.act
        curr_dist = curr_dist_orig[np.arange(len(act)), act, :].unsqueeze(2)
        target_dist = batch.returns.unsqueeze(1)
        # calculate each element's difference between curr_dist and target_dist
        u = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
        # Quantile Huber loss, asymmetrically weighted by tau_hat.
        huber_loss = (
            u * (
                tau_hats.unsqueeze(2) -
                (target_dist - curr_dist).detach().le(0.).float()
            ).abs()
        ).sum(-1).mean(1)
        quantile_loss = (huber_loss * weight).mean()
        # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
        # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130
        batch.weight = u.detach().abs().sum(-1).mean(1)  # prio-buffer
        # calculate fraction loss
        with torch.no_grad():
            sa_quantile_hats = curr_dist_orig[np.arange(len(act)), act, :]
            sa_quantiles = out.quantiles_tau[np.arange(len(act)), act, :]
            # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
            # blob/master/fqf_iqn_qrdqn/agent/fqf_agent.py L169
            values_1 = sa_quantiles - sa_quantile_hats[:, :-1]
            signs_1 = sa_quantiles > torch.cat(
                [sa_quantile_hats[:, :1], sa_quantiles[:, :-1]], dim=1
            )
            values_2 = sa_quantiles - sa_quantile_hats[:, 1:]
            signs_2 = sa_quantiles < torch.cat(
                [sa_quantiles[:, 1:], sa_quantile_hats[:, -1:]], dim=1
            )
            # Analytic gradient of the 1-Wasserstein loss w.r.t. inner taus.
            gradient_of_taus = (
                torch.where(signs_1, values_1, -values_1) +
                torch.where(signs_2, values_2, -values_2)
            )
        fraction_loss = (gradient_of_taus * taus[:, 1:-1]).sum(1).mean()
        # calculate entropy loss
        entropy_loss = out.fractions.entropies.mean()
        fraction_entropy_loss = fraction_loss - self._ent_coef * entropy_loss
        # Update the fraction net first; retain_graph because quantile_loss
        # backpropagates through the same forward graph.
        self._fraction_optim.zero_grad()
        fraction_entropy_loss.backward(retain_graph=True)
        self._fraction_optim.step()
        self.optim.zero_grad()
        quantile_loss.backward()
        self.optim.step()
        self._iter += 1
        return {
            "loss": quantile_loss.item() + fraction_entropy_loss.item(),
            "loss/quantile": quantile_loss.item(),
            "loss/fraction": fraction_loss.item(),
            "loss/entropy": entropy_loss.item()
        }
================================================
FILE: examples/decision_making/RL/requirements.txt
================================================
gym
atari-py
opencv-python
tianshou
================================================
FILE: examples/decision_making/RL/sdqn/main.py
================================================
import argparse
import os
import pprint
import numpy as np
import torch
# Fail fast with a helpful message when the optional RL dependency is absent.
# Fix: narrowed the bare `except:` (which also swallowed SystemExit and
# KeyboardInterrupt) to ImportError, and chained the original cause.
try:
    import tianshou
except ImportError as exc:
    raise ImportError('Need install "tianshou" lib at https://github.com/thu-ml/tianshou !') from exc
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import ShmemVectorEnv
from tianshou.policy import DQNPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger, WandbLogger
import random
from network import SpikingDQN
from ..atari.atari_wrapper import wrap_deepmind
from torch.utils.tensorboard import SummaryWriter
def get_args():
    """Parse command-line arguments for spiking-DQN Atari training."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
    parser.add_argument('--seed', type=int, default=40)
    # Epsilon-greedy schedule: start at eps-train, decay to eps-train-final.
    parser.add_argument('--eps-test', type=float, default=0.005)
    parser.add_argument('--eps-train', type=float, default=1.)
    parser.add_argument('--eps-train-final', type=float, default=0.05)
    parser.add_argument('--buffer-size', type=int, default=100000)
    parser.add_argument('--lr', type=float, default=0.0001)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--n-step', type=int, default=3)
    parser.add_argument('--target-update-freq', type=int, default=500)
    parser.add_argument('--epoch', type=int, default=100)
    parser.add_argument('--step-per-epoch', type=int, default=100000)
    parser.add_argument('--step-per-collect', type=int, default=10)
    parser.add_argument('--update-per-step', type=float, default=0.1)
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--training-num', type=int, default=10)
    parser.add_argument('--test-num', type=int, default=10)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.)
    parser.add_argument(
        '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
    )
    parser.add_argument('--frames-stack', type=int, default=4)
    parser.add_argument('--resume-path', type=str, default=None)
    parser.add_argument('--resume-id', type=str, default=None)
    # Number of simulation timesteps of the spiking net per forward pass.
    parser.add_argument('--time-window', type=int, default=16)
    parser.add_argument(
        '--spike',
        default=False,
        action='store_true',
        help='execute spike dqn'
    )
    parser.add_argument(
        '--logger',
        type=str,
        default="tensorboard",
        choices=["tensorboard", "wandb"],
    )
    parser.add_argument(
        '--watch',
        default=False,
        action='store_true',
        help='watch the play of pre-trained policy only'
    )
    parser.add_argument('--save-buffer-name', type=str, default=None)
    return parser.parse_args()
def make_atari_env(args):
    """Return a DeepMind-preprocessed training environment for ``args.task``."""
    wrapped = wrap_deepmind(args.task, frame_stack=args.frames_stack)
    return wrapped
def make_atari_env_watch(args):
    """Return an evaluation env: full episodes, raw (unclipped) rewards."""
    return wrap_deepmind(
        args.task,
        episode_life=False,
        clip_rewards=False,
        frame_stack=args.frames_stack,
    )
def main(args=get_args()):
    """Train (or watch) a spiking DQN agent on an Atari task via tianshou."""
    print('n_step: ', args.n_step)
    env = make_atari_env(args)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # should be N_FRAMES x H x W
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    print('logdir', args.logdir)
    print('Spiking', args.spike)
    # make environments
    train_envs = ShmemVectorEnv(
        [lambda: make_atari_env(args) for _ in range(args.training_num)]
    )
    test_envs = ShmemVectorEnv(
        [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
    )
    # seed python/numpy/torch/envs and force deterministic cuDNN so the
    # run is reproducible
    os.environ['PYTHONHASHSEED'] = str(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    net = SpikingDQN(*args.state_shape, args.action_shape, args.device, args.time_window).to(args.device)
    # NOTE(review): AdamW (decoupled weight decay) is used here rather than
    # the plain Adam of the original DQN setup.
    optim = torch.optim.AdamW(net.parameters(), lr=args.lr)
    # define policy
    policy = DQNPolicy(
        net,
        optim,
        args.gamma,
        args.n_step,
        target_update_freq=args.target_update_freq
    )
    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)
    # replay buffer: `save_last_obs` and `stack_num` can be removed together
    # when you have enough RAM
    buffer = VectorReplayBuffer(
        args.buffer_size,
        buffer_num=len(train_envs),
        ignore_obs_next=True,
        save_only_last_obs=True,
        stack_num=args.frames_stack
    )
    # collector
    train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
    test_collector = Collector(policy, test_envs, exploration_noise=True)
    # log
    log_path = os.path.join(args.logdir, args.task, 'csdqn')
    if args.logger == "tensorboard":
        writer = SummaryWriter(log_path)
        writer.add_text("args", str(args))
        logger = TensorboardLogger(writer)
    else:
        logger = WandbLogger(
            save_interval=1,
            project=args.task,
            name='dqn',
            run_id=args.resume_id,
            config=args,
        )

    def save_fn(policy):
        # best-policy hook invoked by the trainer
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        # stop when the env's reward threshold (or 20 on Pong) is reached
        if env.spec.reward_threshold:
            return mean_rewards >= env.spec.reward_threshold
        elif 'Pong' in args.task:
            return mean_rewards >= 20
        else:
            return False

    def train_fn(epoch, env_step):
        # nature DQN setting, linear decay in the first 1M steps
        if env_step <= 1e6:
            eps = args.eps_train - env_step / 1e6 * \
                (args.eps_train - args.eps_train_final)
        else:
            eps = args.eps_train_final
        policy.set_eps(eps)
        if env_step % 1000 == 0:
            logger.write("train/env_step", env_step, {"train/eps": eps})

    def test_fn(epoch, env_step):
        # evaluate with a small fixed exploration rate
        policy.set_eps(args.eps_test)

    def save_checkpoint_fn(epoch, env_step, gradient_step):
        # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
        ckpt_path = os.path.join(log_path, 'checkpoint.pth')
        torch.save({'model': policy.state_dict()}, ckpt_path)
        return ckpt_path

    # watch agent's performance
    def watch():
        print("Setup test envs ...")
        policy.eval()
        policy.set_eps(args.eps_test)
        test_envs.seed(args.seed)
        if args.save_buffer_name:
            print(f"Generate buffer with size {args.buffer_size}")
            buffer = VectorReplayBuffer(
                args.buffer_size,
                buffer_num=len(test_envs),
                ignore_obs_next=True,
                save_only_last_obs=True,
                stack_num=args.frames_stack
            )
            collector = Collector(policy, test_envs, buffer, exploration_noise=True)
            result = collector.collect(n_step=args.buffer_size)
            print(f"Save buffer into {args.save_buffer_name}")
            # Unfortunately, pickle will cause oom with 1M buffer size
            buffer.save_hdf5(args.save_buffer_name)
        else:
            print("Testing agent ...")
            test_collector.reset()
            result = test_collector.collect(
                n_episode=args.test_num, render=args.render
            )
        rew = result["rews"].mean()
        print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')

    if args.watch:
        watch()
        exit(0)
    # test train_collector and start filling replay buffer
    train_collector.collect(n_step=args.batch_size * args.training_num)
    # trainer
    result = offpolicy_trainer(
        policy,
        train_collector,
        test_collector,
        args.epoch,
        args.step_per_epoch,
        args.step_per_collect,
        args.test_num,
        args.batch_size,
        train_fn=train_fn,
        test_fn=test_fn,
        stop_fn=stop_fn,
        save_fn=save_fn,
        logger=logger,
        update_per_step=args.update_per_step,
        test_in_train=False,
        resume_from_log=args.resume_id is not None,
        save_checkpoint_fn=save_checkpoint_fn,
    )
    pprint.pprint(result)
    watch()
# Script entry point: parse CLI args and launch spiking-DQN training.
if __name__ == '__main__':
    main(get_args())
================================================
FILE: examples/decision_making/RL/sdqn/network.py
================================================
from typing import Any, Dict, Optional, Sequence, Tuple, Union
import numpy as np
import torch
from torch import nn
from braincog.base.node.node import LIFNode
from ..utils.normalization import PopNorm
class SpikingDQN(nn.Module):
    """Spiking DQN with the Nature-DQN convolutional topology.

    Reference: Human-level control through deep reinforcement learning.

    Every convolution is followed by population normalization (``PopNorm``)
    and a LIF spiking node. The stateful network is simulated for
    ``time_window`` steps on the same input; per-step outputs are
    rate-averaged into Q-values (or returned as a list when
    ``features_only`` is set).
    """

    def __init__(
        self,
        c: int,
        h: int,
        w: int,
        action_shape: Sequence[int],
        device: Union[str, int, torch.device] = "cpu",
        time_window: int = 16,
        features_only: bool = False,
    ) -> None:
        super().__init__()
        self._node = LIFNode
        self.features_only = features_only
        self.device = device
        self._threshold = 1.0
        self.v_reset = 0.0
        self._decay = 0.5
        self._time_window = time_window
        self.p_count = 0
        # (in_ch, out_ch, kernel, stride, PopNorm shape) per conv stage.
        conv_cfg = [
            (c, 32, 8, 4, [32, 20, 20]),
            (32, 64, 4, 2, [64, 9, 9]),
            (64, 64, 3, 1, [64, 7, 7]),
        ]
        layers = []
        for in_ch, out_ch, kernel, stride, norm_shape in conv_cfg:
            layers.append(nn.Conv2d(in_ch, out_ch, kernel_size=kernel, stride=stride))
            layers.append(PopNorm(norm_shape, threshold=self._threshold, v_reset=self.v_reset))
            layers.append(self._node(threshold=self._threshold, v_reset=self.v_reset))
        layers.append(nn.Flatten())
        self.net = nn.Sequential(*layers)
        # Probe once with a dummy frame to learn the flattened feature size.
        with torch.no_grad():
            probe = self.net(torch.zeros(1, c, h, w))
        self.output_dim = np.prod(probe.shape[1:])
        if not features_only:
            # Append the Q-value head when used as a standalone Q network.
            self.net = nn.Sequential(
                self.net,
                nn.Linear(self.output_dim, 512),
                self._node(threshold=self._threshold, v_reset=self.v_reset),
                nn.Linear(512, np.prod(action_shape), bias=False),
            )
            self.output_dim = np.prod(action_shape)

    def reset(self):
        """Invoke the ``n_reset`` hook on every sub-module that defines it."""
        for sub in self.modules():
            if hasattr(sub, 'n_reset'):
                sub.n_reset()

    def forward(
        self,
        x: Union[np.ndarray, torch.Tensor],
        state: Optional[Any] = None,
        info: Dict[str, Any] = {},
    ) -> Tuple[torch.Tensor, Any]:
        r"""Mapping: x -> Q(x, \*)."""
        self.reset()
        # Scale uint8 frames to [0, 1].
        frames = torch.as_tensor(x, device=self.device, dtype=torch.float32) / 255.0
        # Same input each step; LIF membrane state evolves across iterations.
        step_outputs = [self.net(frames) for _ in range(self._time_window)]
        if self.features_only:
            return step_outputs, state
        # Rate decoding: mean over the time window.
        return sum(step_outputs) / self._time_window, state
================================================
FILE: examples/decision_making/RL/utils/__init__.py
================================================
# Public sub-modules of the RL utils package (PopNorm lives in
# `normalization`).
__all__ = ['normalization']

from . import (
    normalization,
)
================================================
FILE: examples/decision_making/RL/utils/normalization.py
================================================
from typing import Optional, Any
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.parameter import Parameter, UninitializedParameter, UninitializedBuffer
import torch.nn.functional as F
from torch import Tensor, Size
from typing import Union, List
import numbers
from torch.nn import Module
_shape_t = Union[int, List[int], Size]
class PopNorm(Module):
r"""Applies Layer Normalization over a mini-batch of inputs as described in
the paper `Layer Normalization `__
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
The mean and standard-deviation are calculated separately over the last
certain number dimensions which have to be of the shape specified by
:attr:`normalized_shape`.
:math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
:attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.
The standard-deviation is calculated via the biased estimator, equivalent to
`torch.var(input, unbiased=False)`.
.. note::
Unlike Batch Normalization and Instance Normalization, which applies
scalar scale and bias for each entire channel/plane with the
:attr:`affine` option, Layer Normalization applies per-element scale and
bias with :attr:`elementwise_affine`.
This layer uses statistics computed from input data in both training and
evaluation modes.
Args:
normalized_shape (int or list or torch.Size): input shape from an expected input
of size
.. math::
[* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1]
\times \ldots \times \text{normalized\_shape}[-1]]
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
eps: a value added to the denominator for numerical stability. Default: 1e-5
elementwise_affine: a boolean value that when set to ``True``, this module
has learnable per-element affine parameters initialized to ones (for weights)
and zeros (for biases). Default: ``True``.
Shape:
- Input: :math:`(N, *)`
- Output: :math:`(N, *)` (same shape as input)
Examples::
>>> input = torch.randn(20, 5, 10, 10)
>>> # With Learnable Parameters
>>> m = nn.LayerNorm(input.size()[1:])
>>> # Without Learnable Parameters
>>> m = nn.LayerNorm(input.size()[1:], elementwise_affine=False)
>>> # Normalize over last two dimensions
>>> m = nn.LayerNorm([10, 10])
>>> # Normalize over last dimension of size 10
>>> m = nn.LayerNorm(10)
>>> # Activating the module
>>> output = m(input)
"""
__constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
normalized_shape: _shape_t
eps: float
elementwise_affine: bool
def __init__(self, normalized_shape: _shape_t, threshold: float, v_reset: float, eps: float = 1e-5, affine: bool = True) -> None:
    """Build a layer-norm whose affine parameters encode neuron dynamics.

    Args:
        normalized_shape: trailing dimensions to normalise over; a bare int
            is treated as a one-element shape.
        threshold: firing threshold of the spiking neuron this norm feeds.
        v_reset: reset potential; together with ``threshold`` it determines
            the constant initialisation used in :meth:`reset_parameters`.
        eps: denominator epsilon for numerical stability. Default: 1e-5.
        affine: when ``True``, allocate learnable per-element ``weight`` and
            ``bias``; otherwise both are registered as ``None``.
    """
    super().__init__()
    # Promote a plain integer to a singleton shape so downstream code can
    # always treat normalized_shape as a tuple.
    if isinstance(normalized_shape, numbers.Integral):
        normalized_shape = (normalized_shape,)
    self.normalized_shape = tuple(normalized_shape)
    self.threshold = threshold
    self.v_reset = v_reset
    self.eps = eps
    self.affine = affine
    if not self.affine:
        self.register_parameter('weight', None)
        self.register_parameter('bias', None)
    else:
        # Uninitialised storage; reset_parameters() fills in the values.
        self.weight = Parameter(torch.Tensor(*self.normalized_shape))
        self.bias = Parameter(torch.Tensor(*self.normalized_shape))
    self.reset_parameters()
def reset_parameters(self) -> None:
    """Initialise the affine parameters from the neuron constants.

    weight <- threshold - v_reset, bias <- v_reset (instead of the stock
    ones/zeros init of ``nn.LayerNorm``). No-op when ``affine`` is False.
    """
    if not self.affine:
        return
    nn.init.constant_(self.weight, self.threshold - self.v_reset)
    nn.init.constant_(self.bias, self.v_reset)
def forward(self, input: Tensor) -> Tensor:
    """Apply layer normalisation over the trailing ``normalized_shape`` dims.

    Delegates to :func:`torch.nn.functional.layer_norm`, passing the module's
    (possibly ``None``) affine parameters and epsilon.
    """
    return F.layer_norm(
        input,
        self.normalized_shape,
        self.weight,
        self.bias,
        self.eps,
    )
def extra_repr(self) -> str:
    """Return the argument summary shown inside this module's ``repr()``.

    Fixes two defects in the original: the return type was annotated as
    ``Tensor`` although a ``str`` is returned, and the format string
    referenced ``{elementwise_affine}`` via ``**self.__dict__`` — but
    ``__init__`` stores the flag under the attribute name ``affine``, so the
    old code raised ``KeyError`` whenever the repr was printed. The displayed
    label is kept as ``elementwise_affine`` for parity with ``nn.LayerNorm``.
    """
    return '{normalized_shape}, eps={eps}, ' \
        'elementwise_affine={affine}'.format(**self.__dict__)
class _NormBase(nn.Module):
    """Common base of _InstanceNorm and _BatchNorm.

    Mirrors ``torch.nn.modules.batchnorm._NormBase`` with one modification:
    an extra ``mean`` constructor argument that :meth:`reset_parameters`
    uses as the constant initial value of the affine ``bias`` (stock PyTorch
    initialises the bias to zeros).
    """
    _version = 2
    __constants__ = ["track_running_stats", "momentum", "eps", "num_features", "affine"]
    num_features: int
    eps: float
    momentum: float
    affine: bool
    track_running_stats: bool
    # WARNING: weight and bias purposely not defined here.
    # See https://github.com/pytorch/pytorch/issues/39670

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        mean: float = 0.2,
        device=None,
        dtype=None
    ) -> None:
        # device/dtype are forwarded to every parameter/buffer allocation.
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(_NormBase, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        # Constant used by reset_parameters() to initialise the affine bias.
        self.mean = mean
        if self.affine:
            self.weight = Parameter(torch.empty(num_features, **factory_kwargs))
            self.bias = Parameter(torch.empty(num_features, **factory_kwargs))
            # self.bias = Parameter(torch.empty(num_features, **factory_kwargs), requires_grad=False)
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features, **factory_kwargs))
            self.register_buffer('running_var', torch.ones(num_features, **factory_kwargs))
            self.running_mean: Optional[Tensor]
            self.running_var: Optional[Tensor]
            # num_batches_tracked is always integral, so only the device
            # (never the floating dtype) is forwarded to its allocation.
            self.register_buffer('num_batches_tracked',
                                 torch.tensor(0, dtype=torch.long,
                                              **{k: v for k, v in factory_kwargs.items() if k != 'dtype'}))
        else:
            # Register the buffers as None so state_dict keys stay consistent
            # whether or not running statistics are tracked.
            self.register_buffer("running_mean", None)
            self.register_buffer("running_var", None)
            self.register_buffer("num_batches_tracked", None)
        self.reset_parameters()

    def reset_running_stats(self) -> None:
        # Reset running statistics to their "identity" values (mean 0, var 1).
        if self.track_running_stats:
            # running_mean/running_var/num_batches... are registered at runtime depending
            # if self.track_running_stats is on
            self.running_mean.zero_()  # type: ignore[union-attr]
            self.running_var.fill_(1)  # type: ignore[union-attr]
            self.num_batches_tracked.zero_()  # type: ignore[union-attr,operator]

    def reset_parameters(self) -> None:
        self.reset_running_stats()
        if self.affine:
            nn.init.ones_(self.weight)
            # Deviation from stock PyTorch: bias starts at self.mean, not zero.
            # nn.init.zeros_(self.bias)
            nn.init.constant_(self.bias, self.mean)

    def _check_input_dim(self, input):
        # Subclasses (1d/2d/3d variants) enforce their expected input rank.
        raise NotImplementedError

    def extra_repr(self):
        # Argument summary rendered inside repr(); formatted from __dict__.
        return (
            "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, "
            "track_running_stats={track_running_stats}".format(**self.__dict__)
        )

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        # Backward compatibility shim for checkpoints saved before _version 2,
        # which lacked the num_batches_tracked buffer.
        version = local_metadata.get("version", None)
        if (version is None or version < 2) and self.track_running_stats:
            # at version 2: added num_batches_tracked buffer
            # this should have a default value of 0
            num_batches_tracked_key = prefix + "num_batches_tracked"
            if num_batches_tracked_key not in state_dict:
                state_dict[num_batches_tracked_key] = torch.tensor(0, dtype=torch.long)
        super(_NormBase, self)._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )
class _BatchNorm(_NormBase):
    """Shared forward pass for the batch-norm variants.

    Mirrors ``torch.nn.modules.batchnorm._BatchNorm``; subclasses only
    implement ``_check_input_dim``. The extra ``mean`` argument is forwarded
    to ``_NormBase`` where it seeds the affine bias initialisation.
    """

    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.1,
        affine=True,
        track_running_stats=True,
        mean=0.2,
        device=None,
        dtype=None
    ):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, mean, **factory_kwargs
        )

    def forward(self, input: Tensor) -> Tensor:
        self._check_input_dim(input)
        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum
        if self.training and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting this when it is None
            if self.num_batches_tracked is not None:  # type: ignore[has-type]
                self.num_batches_tracked = self.num_batches_tracked + 1  # type: ignore[has-type]
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum
        r"""
        Decide whether the mini-batch stats should be used for normalization rather than the buffers.
        Mini-batch stats are used in training mode, and in eval mode when buffers are None.
        """
        if self.training:
            bn_training = True
        else:
            bn_training = (self.running_mean is None) and (self.running_var is None)
        r"""
        Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
        passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
        used for normalization (i.e. in eval mode when buffers are not None).
        """
        return F.batch_norm(
            input,
            # If buffers are not to be tracked, ensure that they won't be updated
            self.running_mean
            if not self.training or self.track_running_stats
            else None,
            self.running_var if not self.training or self.track_running_stats else None,
            self.weight,
            self.bias,
            bn_training,
            exponential_average_factor,
            self.eps,
        )
class PDBatchNorm2d(_BatchNorm):
    r"""Batch Normalization over a 4D input :math:`(N, C, H, W)`.

    Follows `Batch Normalization: Accelerating Deep Network Training by
    Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__:

    .. math::
        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    Statistics are computed per channel over the ``(N, H, W)`` slices
    ("Spatial Batch Normalization"); the standard deviation uses the biased
    estimator, equivalent to ``torch.var(input, unbiased=False)``. During
    training the layer maintains running estimates of mean and variance
    (default ``momentum`` 0.1) which are used at evaluation time; with
    ``track_running_stats=False`` no estimates are kept and batch statistics
    are used in both modes. Note that ``momentum`` here follows the
    batch-norm convention
    :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
    not the optimizer notion of momentum.

    Args:
        num_features: number of channels :math:`C` of the expected input.
        eps: value added to the denominator for numerical stability.
            Default: 1e-5.
        momentum: value for the running-statistics update; ``None`` selects a
            cumulative (simple) moving average. Default: 0.1.
        affine: if ``True``, this module has learnable affine parameters.
            Default: ``True``.
        track_running_stats: if ``True``, keep running mean/variance buffers;
            if ``False``, the buffers are ``None`` and batch statistics are
            always used. Default: ``True``.

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples::
        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm2d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        # Anything other than a 4D (N, C, H, W) batch is rejected up front.
        ndim = input.dim()
        if ndim == 4:
            return
        raise ValueError("expected 4D input (got {}D input)".format(ndim))
================================================
FILE: examples/decision_making/swarm/Collision-Avoidance.py
================================================
import torch,os
from braincog.model_zoo.rsnn import RSNN
from random import randint
import math
import random
import matplotlib
# matplotlib.use("TkAgg")
import numpy as np
import random
import matplotlib.pyplot as plt
import matplotlib.animation as animation
#os.environ["SDL_VIDEODRIVER"] = "dummy"
# ---- world / swarm parameters ----
N =10  # number of boids (drones) in the swarm
WORLD_WIDTH = 500  # side length of the square, bounded arena
COLLISION_THRE =25 #60 65 70  # distance below which two boids count as colliding
WALL_COLLISION_LIMIT=10  # distance scale for wall-collision detection
VISIBLE_THRE=75 #3=75/COLLISION_THRE #3*COLLISION_THRE  # perception radius
# eight velocity directions (unit grid steps); index into this list = action id
vel_space=[[0,1],[1,0],[0,-1],[-1,0],[1,1],[1,-1],[-1,-1],[-1,1]]
# directional subsets (candidate actions near each wall); currently unused here
vel_x_small=[[0,1],[1,0],[0,-1],[1,1],[1,-1]]
vel_x_large=[[0,1],[0,-1],[-1,0],[-1,-1],[-1,1]]
vel_y_small=[[0,1],[1,0],[-1,0],[1,1],[-1,1]]
vel_y_large=[[1,0],[0,-1],[-1,0],[1,-1],[-1,-1]]
N_action=len(vel_space)  # 8 discrete actions
col_robot=[i for i in range(N)]
# parameters for rl+snn
C = 50  # neurons per action population in the output layer
runtime = 100 # Runtime in ms for choosing action
# parameters for snn
tau = 10 # time constant of STDP
stdpwin = 10 # STDP windows in ms
Apos = 0.925  # STDP potentiation amplitude
Aneg = 0.1  # STDP depression amplitude
vr = 0 # Reset Potential
vt = 0.1 # Judge if the neurons fire or not
tau_m = 20  # membrane time constant
Rm = 0.5  # membrane resistance
tau_e = 5  # eligibility-trace time constant
# inhibition weight between output population
# (random inhibitory weights everywhere, zeroed within each action population)
s_in = np.random.rand(N_action * C, N_action * C)
for i in range(N_action):
    for j in range(C):
        for k in range(C):
            s_in[i * C + j][i * C + k] = 0
#init boids with no collision
global boids  # NOTE(review): no-op at module level; kept for fidelity
# structured array: position, velocity, and one RSNN controller per boid
boids = np.zeros(N, dtype=[('pos', int, 2), ('vel', int, 2),('nn',RSNN)])
# Place each boid in a distinct 125x125 cell of a 4x4 grid so that the
# initial configuration is collision-free.
list_rand=[i for i in range(16)]
rand_int=random.sample(list_rand,N)
for i in range(len(rand_int)):
    boids['pos'][i,0]=np.random.uniform(int(rand_int[i]%4)*125,(int(rand_int[i]%4)+1)*125+1,1)
    boids['pos'][i,1] = np.random.uniform(int(rand_int[i]/4) * 125, (int(rand_int[i]/4) + 1) * 125 + 1, 1)
# random unit-grid velocities in {-1, 0, 1}^2, rerolled until non-zero
boids['vel'] = np.random.uniform(-1, 2, (N, 2))
for i_vel in range(len(boids['vel'])):
    boids['nn'][i_vel] = RSNN(N_action*2,N_action*C).cuda()
    while(boids['vel'][i_vel][0]==0 and boids['vel'][i_vel][1]==0):
        boids['vel'][i_vel] = np.random.uniform(-1, 2, (1, 2))
# ---- per-boid bookkeeping used across frames by update_boids ----
do_update=np.zeros(N)  # whether boid i made a decision last frame (reward due)
distance_pre=np.zeros((N,N))  # pairwise distances from the previous frame
tmp_min_robot=[i for i in range(N)]  # indices of approaching neighbours per boid
tmp_input=[i for i in range(N)]  # input weights used for the last decision
sum_deta_tmp=np.zeros(N)  # weighted collision danger before the last action
sum_deta_new=np.zeros(N)  # weighted collision danger after the last action
trace_decay = 0.8  # decay of the R-STDP eligibility trace per time step
def chooseAct(Net,input,explore):
    """Run the RSNN for ``runtime`` steps and pick the winning action group.

    Args:
        Net: reward-modulated SNN (RSNN). Called once per time step with one
            input column; expected to return ``(out, dw)`` where ``out`` is
            the output spike vector (length ``N_action * C``) and ``dw`` the
            STDP weight delta accumulated into ``Net.weight_trace``.
        input: 2-D array of shape ``(N_action * 2, runtime)`` — one input
            column per simulated millisecond.
        explore: exploration flag (-1 exploitation, 1 exploration). Currently
            unused — both modes behave identically; the parameter is kept for
            caller compatibility. (The original dead ``if/else: pass`` on it
            has been removed.)

    Returns:
        ``(action, Net)`` — the index of the most active action population
        and the (trace-updated) network.

    Bug fix: the original returned ``None`` when no population exceeded
    ``C / 2`` spikes within ``runtime`` steps, which crashed the caller's
    two-value unpack (``action_index, boids['nn'][i] = chooseAct(...)``).
    We now fall back to the most active group, matching the commented-out
    intent at the end of the original function.
    """
    count_group = np.zeros(N_action)
    count_output = np.zeros(N_action * C)
    for i_train in range(runtime):
        out, dw = Net(input[:, i_train])
        # R-STDP eligibility trace: decay, then add this step's weight delta.
        Net.weight_trace *= trace_decay
        Net.weight_trace += dw[0][0]
        count_output = count_output + np.array(out)
        # Total spikes per action population (C neurons each).
        for i in range(N_action):
            count_group[i] = count_output[i * C:(i + 1) * C].sum()
        # Early exit as soon as one population fired more than C/2 spikes.
        if count_group.max() > C / 2:
            action = count_group.argmax()
            return action, Net
    # Fallback: no population reached the threshold within `runtime` steps —
    # still return the most active one so callers can always unpack a pair.
    return count_group.argmax(), Net
def update_boids(xs, ys, xvs, yvs,frame):
    """Advance the swarm by one step: reward, sense, decide, and move.

    Mutates ``xs``/``ys``/``xvs``/``yvs`` in place (they are views into the
    global ``boids`` array) and updates the module-level bookkeeping arrays
    (``do_update``, ``sum_deta_*``, ``tmp_*``, ``distance_pre``).

    Args:
        xs, ys: boid x/y positions (length-N int views, modified in place).
        xvs, yvs: boid x/y velocities in {-1, 0, 1} (modified in place).
        frame: animation frame counter; gates reward updates and, from frame
            10000 on, triggers a probe pass over all input directions.
    """
    global distance_pre,col_c
    # Matrix off position difference and distance
    xdiff = np.add.outer(xs, -xs)
    ydiff = np.add.outer(ys, -ys)
    distance = np.sqrt(xdiff ** 2 + ydiff ** 2)
    # Calculate the boids that are visible to every other boid -pi/2 to pi/2
    visible = np.zeros((N, N))
    dir = np.zeros((N, N))  # 1 = neighbour on the right, 2 = on the left
    # Wall bookkeeping: distance to each of the four walls (default "far"),
    # and which side the wall is approached from.
    col_c = WORLD_WIDTH * np.ones((N, 4))
    dir_c = np.zeros((N, 4))
    angle_towards = np.arctan2(-ydiff, -xdiff)  # heading from boid i to boid j
    angle_vel = np.arctan2(yvs, xvs)  # current heading of each boid
    # Case analysis over the 8 grid headings: for headings whose +-pi/2 field
    # of view does not cross the +-pi branch cut, a single interval test
    # suffices; the three -x headings need wrapped (two-interval) tests.
    for i in range(N):
        for j in range(N):
            if (xvs[i] == 1 and yvs[i] == 0) or (xvs[i] == 1 and yvs[i] == 1) or (xvs[i] == 0 and yvs[i] == 1) or (
                    xvs[i] == 0 and yvs[i] == -1) or (xvs[i] == 1 and yvs[i] == -1):
                if angle_towards[i][j] < angle_vel[i] + np.pi / 2 and angle_towards[i][j] > angle_vel[i] - np.pi / 2:
                    visible[i][j] = True
                if angle_towards[i][j] > angle_vel[i] - np.pi / 2 and angle_towards[i][j] < angle_vel[i]:
                    dir[i][j]=1#right
                if angle_towards[i][j] < angle_vel[i] + np.pi / 2 and angle_towards[i][j] >= angle_vel[i]:
                    dir[i][j] = 2#left
            if xvs[i] == -1 and yvs[i] == 1:
                # Heading 3pi/4: field of view wraps across +-pi.
                if (angle_towards[i][j] > angle_vel[i] - np.pi / 2 and angle_towards[i][j] < np.pi) or (
                        angle_towards[i][j] > -np.pi and angle_towards[i][j] < angle_vel[i] - 1.5 * np.pi):
                    visible[i][j] = True
                if angle_towards[i][j] > angle_vel[i] - np.pi / 2 and angle_towards[i][j] < angle_vel[i]:
                    dir[i][j] = 1
                if (angle_towards[i][j] < np.pi and angle_towards[i][j] >= angle_vel[i]) or (
                        angle_towards[i][j] > -np.pi and angle_towards[i][j] < angle_vel[i] - 1.5 * np.pi):
                    dir[i][j] = 2
            if xvs[i] == -1 and yvs[i] == 0:
                # Heading pi: field of view wraps across +-pi.
                if (angle_towards[i][j] > np.pi / 2 and angle_towards[i][j] < np.pi) or (
                        angle_towards[i][j] > -np.pi and angle_towards[i][j] < -np.pi / 2):
                    visible[i][j] = True
                if angle_towards[i][j] > np.pi / 2 and angle_towards[i][j] < np.pi:
                    dir[i][j] = 1
                if angle_towards[i][j] >= -np.pi and angle_towards[i][j] < -np.pi / 2:
                    dir[i][j] = 2
            if xvs[i] == -1 and yvs[i] == -1:
                # Heading -3pi/4: field of view wraps across +-pi.
                if (angle_towards[i][j] > -np.pi and angle_towards[i][j] < -np.pi / 4) or (
                        angle_towards[i][j] > 0.75 * np.pi and angle_towards[i][j] < np.pi):
                    visible[i][j] = True
                if (angle_towards[i][j] > 0.75 * np.pi and angle_towards[i][j] < np.pi) or (
                        angle_towards[i][j] > -np.pi and angle_towards[i][j] < angle_vel[i]):
                    dir[i][j] = 1
                if angle_towards[i][j] >= angle_vel[i] and angle_towards[i][j] < -np.pi / 4:
                    dir[i][j] = 2
    # A boid never "sees" itself: zero the diagonal.
    v_tmp = np.diag(np.diag(visible))
    visible = visible - v_tmp
    # the danger of collision, considering dis=6*collision
    collision = np.clip(VISIBLE_THRE/COLLISION_THRE - distance / COLLISION_THRE, 0,VISIBLE_THRE/COLLISION_THRE) * visible # visible and in some distance 3*collision_thre
    c_tmp = np.diag(np.diag(collision))
    collision = collision - c_tmp
    # ---- wall sensing: bottom (y=0), left (x=0), top (y=W), right (x=W) ----
    # A wall counts only when the boid is both close to it and moving into it.
    if len(np.where(yvs[np.where(ys < (VISIBLE_THRE/COLLISION_THRE)*WALL_COLLISION_LIMIT)] == -1)[0])>0:
        wall_tmp=np.where(ys < (VISIBLE_THRE/COLLISION_THRE)*WALL_COLLISION_LIMIT)[0]
        for i_wall in range(len(wall_tmp)):
            if yvs[wall_tmp[i_wall]] == -1:
                col_c[wall_tmp[i_wall], 0] = ys[wall_tmp[i_wall]]
                if xvs[wall_tmp[i_wall]] >= 0:
                    dir_c[wall_tmp[i_wall], 0] = 1
                else:
                    # NOTE(review): writes column 1 while the distance above
                    # went into column 0 — possibly an index bug; confirm
                    # against the other three wall cases before changing.
                    dir_c[wall_tmp[i_wall], 1] = 2
    if len(np.where(xvs[np.where(xs < (VISIBLE_THRE/COLLISION_THRE)*WALL_COLLISION_LIMIT)]==-1)[0])>0:
        wall_tmp = np.where(xs < (VISIBLE_THRE/COLLISION_THRE)*WALL_COLLISION_LIMIT)[0]
        for i_wall in range(len(wall_tmp)):
            if xvs[wall_tmp[i_wall]] == -1:
                col_c[wall_tmp[i_wall], 1] = xs[wall_tmp[i_wall]]
                if yvs[wall_tmp[i_wall]] >= 0:
                    dir_c[wall_tmp[i_wall], 1] = 2
                else:
                    dir_c[wall_tmp[i_wall], 1] = 1
    if len(np.where(yvs[np.where((WORLD_WIDTH - ys) < (VISIBLE_THRE/COLLISION_THRE) * WALL_COLLISION_LIMIT)] == 1)[0]) > 0:
        wall_tmp = np.where((WORLD_WIDTH - ys) < (VISIBLE_THRE/COLLISION_THRE) * WALL_COLLISION_LIMIT)[0]
        for i_wall in range(len(wall_tmp)):
            if yvs[wall_tmp[i_wall]]==1:
                col_c[wall_tmp[i_wall],2] =WORLD_WIDTH - ys[wall_tmp[i_wall]]
                if xvs[wall_tmp[i_wall]]>=0:
                    dir_c[wall_tmp[i_wall],2]=2
                else:
                    dir_c[wall_tmp[i_wall], 2] = 1
    if len(np.where(xvs[np.where((WORLD_WIDTH - xs) < (VISIBLE_THRE/COLLISION_THRE)*WALL_COLLISION_LIMIT)] ==1)[0])>0:
        wall_tmp=np.where((WORLD_WIDTH - xs) < (VISIBLE_THRE/COLLISION_THRE) * WALL_COLLISION_LIMIT)[0]
        for i_wall in range(len(wall_tmp)):
            if xvs[wall_tmp[i_wall]]==1:
                col_c[wall_tmp[i_wall],3] =WORLD_WIDTH - xs[wall_tmp[i_wall]]
                if yvs[wall_tmp[i_wall]]>=0:
                    dir_c[wall_tmp[i_wall],3]=1
                else:
                    dir_c[wall_tmp[i_wall], 3] = 2
    # print(col_c)
    # Convert wall distances to the same "danger" scale as boid collisions,
    # then append the 4 wall pseudo-neighbours to the per-boid matrices.
    col_c_tmp = np.clip(VISIBLE_THRE/COLLISION_THRE - col_c / WALL_COLLISION_LIMIT, 0, VISIBLE_THRE/COLLISION_THRE)
    deta_dis_tmp = distance - distance_pre
    deta_dis = deta_dis_tmp * collision # <0 and small is the obstacle
    collision=np.c_[collision, col_c_tmp]
    deta_dis=np.c_[deta_dis, -col_c_tmp]  # walls always count as approaching
    dir=np.c_[dir,dir_c]
    # print(collision,deta_dis)
    #for every agent, choose the approaching agent as input
    for i in range(N):
        if frame>1 and do_update[i]>0:
            # Reward phase: compare current weighted danger against the value
            # recorded when the previous action was taken, and reinforce.
            sum_deta_new[i] = (tmp_input[i] * collision[i][tmp_min_robot[i]]).sum()
            # print(sum_deta_new[i] ,sum_deta_tmp[i] )
            if sum_deta_new[i] < sum_deta_tmp[i] :
                r=10*(sum_deta_tmp[i]-sum_deta_new[i])  # danger decreased: positive reward
            else:
                r=-10*(sum_deta_new[i]-sum_deta_tmp[i])  # danger increased: punishment
            boids['nn'][i].UpdateWeight(r)
        if frame > 0:
            do_update[i] =0
        if len(np.where(deta_dis[i] < 0)[0]) > 0:
            # Decision phase: at least one neighbour/wall is approaching.
            do_update[i] += 1
            # then get the velocity direction of objects and the distance between them as the network input
            appro_index = np.where(deta_dis[i] < 0)[0] # the input is the approching directions and distances
            # print(appro_index)
            input = []
            for j in range(len(appro_index)):
                if appro_index[j]<=N-1:
                    # Real boid: encode its heading as an action index.
                    xvs_input = xvs[appro_index[j]]
                    yvs_input = yvs[appro_index[j]]
                    input.append(vel_space.index([xvs_input, yvs_input]))
                else:
                    # Wall pseudo-neighbour (columns N..N+3).
                    vel_tmp=int(appro_index[j]%N)
                    input.append(vel_tmp)
            # Closer obstacles get larger input weights, rescaled to [1, 4].
            dis_tmp=np.c_[distance,col_c]
            weight = -1 * dis_tmp[i][np.where(deta_dis[i] < 0)]
            # input=input[np.argmin(weight)]
            if weight.max() - weight.min() == 0:
                # Degenerate case (all equal distances): random weights.
                weight = np.random.randint(1, 5, weight.shape)
                weight[0] = 4
            else:
                k = (4 - 1) / (weight.max() - weight.min())
                weight = 1 + k * (weight - weight.min())
            # print(input,weight)
            # Build the constant input current: one row per (heading, side).
            I = np.zeros((N_action*2, runtime))
            for j in range(len(input)):
                # print(appro_index,input,appro_index[j],dir[i][appro_index[j]],input[j]*dir[i][appro_index[j]])
                I[int(input[j]+N_action*(dir[i][appro_index[j]]-1))][0:runtime] = max(I[int(input[j]+N_action*(dir[i][appro_index[j]]-1))][0], weight[j])
            if random.random()<0.7:
                action_index,boids['nn'][i] = chooseAct(boids['nn'][i],I,-1)#exploitation
            else:
                action_index,boids['nn'][i] = chooseAct(boids['nn'][i],I, 1) #exploration
            xvs[i] = vel_space[action_index][0]
            yvs[i] = vel_space[action_index][1]
            # Remember this decision's context for next frame's reward.
            tmp_min_robot[i] = np.where(deta_dis[i] < 0)[0]
            tmp_input[i] = weight
            sum_deta_tmp[i] = (tmp_input[i] * collision[i][tmp_min_robot[i]]).sum()
    # Integrate positions (in-place: xs/ys are views into `boids`) and clamp
    # to the arena. NOTE(review): np.clip rebinds the local names, so the
    # clamped values do not propagate back to `boids` — confirm intended.
    xs+=xvs
    ys+=yvs
    xs=np.clip(xs,0,WORLD_WIDTH)
    ys = np.clip(ys, 0, WORLD_WIDTH)
    distance_pre = distance
    if frame>=10000:
        # Probe pass: feed each pure input direction at full strength and
        # record the learned response (debug/inspection only).
        for i in range(N):
            for j in range(N_action*2):
                I = np.zeros((N_action * 2, runtime))
                I[j][0:runtime]=4
                a=chooseAct(boids['nn'][i],I,-1)
                # print(a)
                aaa=1  # no-op; breakpoint anchor
def animate(frame):
    """Per-frame FuncAnimation callback: step the swarm, then redraw markers."""
    positions = boids['pos']
    velocities = boids['vel']
    # update_boids mutates the position/velocity views in place.
    update_boids(positions[:, 0], positions[:, 1],
                 velocities[:, 0], velocities[:, 1], frame)
    for layer in (scatter, scatter1):
        layer.set_offsets(positions)
#build background
fig = plt.figure(figsize=(8, 8))
ax1 = fig.add_subplot(111)
ax1.set_title('Scatter Plot')
# Axis limits slightly larger than the arena so boundary boids stay visible.
plt.xlim(-20,520)
plt.ylim(-20,520)
plt.grid(ls='--',c='gray')
plt.xlabel('X')
plt.ylabel('Y')
# Use a scatter plot to visualize the boids
color_list=['r','b','g','y','m','c','deeppink','tomato','gold','crimson','cornsilk','darkred','greenyellow','lightcoral','mintcream',
            'rosybrown']
colors=color_list[0:N]  # one fixed colour per boid
#colors=random.sample(color_list,N)
lines=np.zeros(N)+5  # marker edge widths
# Two stacked layers: small core markers plus large translucent halos.
scatter = ax1.scatter(boids['pos'][:, 0], boids['pos'][:, 1],s=500,alpha=0.5,linewidths=lines)
scatter1 = ax1.scatter(boids['pos'][:, 0], boids['pos'][:, 1],s=2500,c=colors,alpha=0.5)
# Endpoints for (currently disabled) heading indicator lines.
boids_newp=boids['pos']+boids['vel']*10
for i in range(N):
    boids_linex=np.hstack((boids['pos'][i, 0],boids_newp[i,0]))
    boids_liney=np.hstack((boids['pos'][i, 1],boids_newp[i,1]))
    #line,=plt.plot(boids_linex,boids_liney,linewidth=5)
#lines = [ax1.plot(np.hstack((boids['pos'][i, 0],boids_newp[i,0])), np.hstack((boids['pos'][i, 1],boids_newp[i,1])),linewidth=5) for i in range(N)]
# NOTE(review): this rebinds `animation`, shadowing the imported
# matplotlib.animation module. Harmless because the module is not used again,
# but keeping a reference is also what prevents the animation from being
# garbage-collected.
animation = animation.FuncAnimation(fig, animate,interval=0.001)
plt.show()
================================================
FILE: examples/decision_making/swarm/README.md
================================================
# Reward-modulated Spiking Neural Network for Self-organizing Collision Avoidance of Drone Swarm
This repository contains code from our paper [*Nature-inspired Self-organizing Collision Avoidance for Drones Swarm Based on Reward-modulated Spiking Neural Network*] published in Cell Patterns.
https://www.cell.com/patterns/fulltext/S2666-3899(22)00236-7
We also provide the BrainCog-based version: https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/decision_making/swarm
If you use our code or refer to this project, please cite this paper:
Feifei Zhao,Yi Zeng, Bing Han, Hongjian Fang, and Zhuoya Zhao. Nature-inspired Self-organizing Collision Avoidance for Drones Swarm Based on Reward-modulated Spiking Neural Network. Patterns, DOI:https://doi.org/10.1016/j.patter.2022.100611
## Paper Introduction
The collaborative interaction mechanisms of biological swarms in nature are of great importance for inspiring the study of swarm intelligence. This paper proposes a self-organizing obstacle-avoidance model that draws on the decentralized, self-organizing properties of intelligent behavior in biological swarms. Each individual independently adopts a brain-inspired reward-modulated spiking neural network (RSNN) to achieve online learning and makes decentralized decisions based on local observations. The following picture shows the decision-making process of our model.
We validated the proposed model on swarm collision avoidance tasks (a swarm of unmanned aerial vehicles without central control) in a bounded space, carrying out simulation and real-world experiments. The drone swarm emerges with safe flight behavior, as shown in the following videos. Compared with artificial neural network-based online learning methods, our proposed method exhibits superior performance and better stability.


## Run
* "reward-modulated snn on swarm collision avoidance.py" includes the self-organized collision avoidance implemented by RSNN for simulation scenarios.
* "flytestfive.py" includes five UAVs swarm collision avoidance validation in real bounded scenario.
## Requirements
* "reward-modulated snn on swarm collision avoidance.py": python==3.7, numpy>=1.21.6
* "flytestfive.py": multi_robomaster
================================================
FILE: requirements.txt
================================================
numpy
scipy
h5py
torch
torchvision
torchaudio
timm == 0.6.13
scikit-learn
einops
thop
pyyaml
matplotlib
seaborn
pygame
dv
tensorboard
tonic
================================================
FILE: setup.py
================================================
from setuptools import find_packages
from setuptools import setup

# Runtime dependencies are listed one-per-line in requirements.txt; the
# README doubles as the PyPI long description.
with open("./requirements.txt", "r", encoding="utf-8") as req_file:
    requirements = req_file.read()
with open("README.md", "r", encoding="utf-8") as readme_file:
    readme_text = readme_file.read()

setup(
    name='braincog',
    version='0.2.7.19',
    author='braincog',
    description="BrainCog is an open source spiking neural network based brain-inspired cognitive intelligence engine for Brain-inspired Artificial Intelligence and brain simulation. More information on braincog can be found on its homepage http://www.brain-cog.network/",
    long_description=readme_text,
    long_description_content_type="text/markdown",
    packages=find_packages(),
    install_requires=requirements,
    python_requires='>=3.6',
    classifiers=[
        "Programming Language :: Python :: 3 :: Only",
        "Programming Language :: Python :: 3.6",
        "Programming Language :: Python :: 3.7",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "License :: Other/Proprietary License",
        "Operating System :: OS Independent",
    ],
)