Showing preview only (479K chars total). Download the full file or copy to clipboard to get everything.
Repository: talmo/leap
Branch: master
Commit: c39e07b647da
Files: 157
Total size: 36.8 MB
Directory structure:
gitextract_ix4wtaz2/
├── .gitignore
├── LICENSE
├── analysis/
│ └── gait_analysis/
│ ├── Cluster_Velocity_Distributions.mat
│ ├── GaitVectors3.mat
│ ├── Gait_Densities.mat
│ ├── Gait_Speed_Distributions.mat
│ ├── Swing_Velocity_Over_Time.mat
│ ├── Swing_and_Stance_versus_Velocity.mat
│ ├── TetrapodExample.mat
│ ├── TripodExample.mat
│ ├── compute_gait_densities.m
│ ├── gait_analysis_computation.m
│ ├── gait_analysis_plotting.m
│ └── plot_gait_densities.m
├── data/
│ └── readme.md
├── examples/
│ ├── batch_process_video.ipynb
│ ├── hdf5tovid.m
│ └── vidtohdf5.m
├── install_leap.m
├── leap/
│ ├── __init__.py
│ ├── compute_errors.m
│ ├── confmaps2pts.m
│ ├── generate_training_set.m
│ ├── graph2paf.m
│ ├── guis/
│ │ ├── cluster_sample.mlapp
│ │ ├── create_skeleton.mlapp
│ │ └── label_joints.m
│ ├── hpc/
│ │ └── python_gpu.sh
│ ├── image_augmentation.py
│ ├── layers.py
│ ├── models.py
│ ├── plot_joints_single.m
│ ├── predict_box.m
│ ├── predict_box.py
│ ├── pts2confmaps.m
│ ├── test_leap.m
│ ├── toolbox/
│ │ ├── aliases/
│ │ │ ├── alims.m
│ │ │ ├── ff.m
│ │ │ ├── h5file.m
│ │ │ ├── imgsc.m
│ │ │ └── repext.m
│ │ ├── graphics/
│ │ │ ├── FEX-settingsdlg/
│ │ │ │ └── settingsdlg.m
│ │ │ ├── GUI Layout Toolbox/
│ │ │ │ └── layout/
│ │ │ │ └── +uix/
│ │ │ │ ├── +mixin/
│ │ │ │ │ ├── Container.m
│ │ │ │ │ ├── Flex.m
│ │ │ │ │ └── Panel.m
│ │ │ │ ├── Box.m
│ │ │ │ ├── ChildEvent.m
│ │ │ │ ├── ChildObserver.m
│ │ │ │ ├── Container.m
│ │ │ │ ├── Divider.m
│ │ │ │ ├── FigureData.m
│ │ │ │ ├── FigureObserver.m
│ │ │ │ ├── Grid.m
│ │ │ │ ├── GridFlex.m
│ │ │ │ ├── HBox.m
│ │ │ │ ├── Node.m
│ │ │ │ ├── Panel.m
│ │ │ │ ├── PointerManager.m
│ │ │ │ ├── SelectionData.m
│ │ │ │ ├── VBox.m
│ │ │ │ ├── calcPixelSizes.m
│ │ │ │ ├── setPosition.m
│ │ │ │ └── tracking.m
│ │ │ ├── distributionPlot/
│ │ │ │ └── colorCode2rgb.m
│ │ │ ├── draggable/
│ │ │ │ └── draggable.m
│ │ │ ├── figclosekey.m
│ │ │ ├── figsize.m
│ │ │ ├── fontsize.m
│ │ │ ├── hline.m
│ │ │ ├── isax.m
│ │ │ ├── isfig.m
│ │ │ ├── noticks.m
│ │ │ ├── pareto2.m
│ │ │ ├── plotExplainedVar.m
│ │ │ ├── plotpts.m
│ │ │ ├── redblue.m
│ │ │ ├── sc/
│ │ │ │ ├── gray.m
│ │ │ │ ├── private/
│ │ │ │ │ ├── colormap_helper.m
│ │ │ │ │ └── rescale.m
│ │ │ │ └── real2rgb.m
│ │ │ └── shortticks.m
│ │ ├── hdf5/
│ │ │ ├── h5att2struct.m
│ │ │ ├── h5getdatasets.m
│ │ │ ├── h5readframes.m
│ │ │ ├── h5readgroup.m
│ │ │ ├── h5save.m
│ │ │ ├── h5savegroup.m
│ │ │ ├── h5size.m
│ │ │ ├── h5struct2att.m
│ │ │ └── hdf5prop/
│ │ │ ├── h5datacreate.m
│ │ │ └── hdf5prop.m
│ │ ├── imageproc/
│ │ │ └── ind2im.m
│ │ ├── inputParsing/
│ │ │ ├── get_caller_name.m
│ │ │ ├── nameval2struct.m
│ │ │ ├── parse_params.m
│ │ │ └── struct2nameval.m
│ │ ├── io/
│ │ │ ├── GetFullPath.m
│ │ │ ├── dir_ext.m
│ │ │ ├── dir_paths.m
│ │ │ ├── dir_regex.m
│ │ │ ├── exists.m
│ │ │ ├── ext2filter_spec.m
│ │ │ ├── extrep.m
│ │ │ ├── funpath.m
│ │ │ ├── get_ext.m
│ │ │ ├── get_filename.m
│ │ │ ├── get_filesize.m
│ │ │ ├── get_new_filename.m
│ │ │ ├── lastdir.m
│ │ │ ├── mkdirto.m
│ │ │ └── uibrowse.m
│ │ ├── ml/
│ │ │ └── ezpca.m
│ │ ├── strings/
│ │ │ ├── bytes2str.m
│ │ │ ├── instr.m
│ │ │ ├── printf.m
│ │ │ ├── secs2hms.m
│ │ │ └── secsf.m
│ │ ├── utilities/
│ │ │ ├── af.m
│ │ │ ├── arange.m
│ │ │ ├── areempty.m
│ │ │ ├── argmin.m
│ │ │ ├── arr2cell.m
│ │ │ ├── cell1.m
│ │ │ ├── cellcat.m
│ │ │ ├── cf.m
│ │ │ ├── clip.m
│ │ │ ├── functional_programming/
│ │ │ │ └── wrap.m
│ │ │ ├── get_new_string.m
│ │ │ ├── grp2cell.m
│ │ │ ├── horz.m
│ │ │ ├── iseven.m
│ │ │ ├── loadvar.m
│ │ │ ├── nunique.m
│ │ │ ├── rownorm.m
│ │ │ ├── stacks/
│ │ │ │ ├── imtile.m
│ │ │ │ ├── stack2cell.m
│ │ │ │ ├── stack2vecs.m
│ │ │ │ └── vecs2stack.m
│ │ │ ├── swap.m
│ │ │ ├── time/
│ │ │ │ ├── GetSystemTimePreciseAsFileTime.m
│ │ │ │ ├── GetSystemTimePreciseAsFileTime.mexw64
│ │ │ │ ├── stic.m
│ │ │ │ ├── stoc.m
│ │ │ │ ├── stocf.m
│ │ │ │ └── systime.m
│ │ │ ├── varsize.m
│ │ │ ├── varstruct.m
│ │ │ ├── vert.m
│ │ │ └── vplay.m
│ │ └── video/
│ │ ├── validate_stack.m
│ │ └── vplayer.m
│ ├── training.py
│ ├── utils.py
│ └── viz.py
├── readme.md
├── setup.py
└── uninstall_leap.m
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
# MATLAB
*.asv
# LEAP-specific
data/*
!data/readme.md
models/*
leap/toolbox/io/.lastdir
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: analysis/gait_analysis/GaitVectors3.mat
================================================
[File too large to display: 25.2 MB]
================================================
FILE: analysis/gait_analysis/Gait_Densities.mat
================================================
[File too large to display: 11.2 MB]
================================================
FILE: analysis/gait_analysis/compute_gait_densities.m
================================================
%% Pathing
% Compute one 2-D density of the embedded frames per HMM gait state
% (tripod, tetrapod, non-canonical), restricted to forward-walking frames,
% and save them to Gait_Densities.mat together with the embedding map Ld.
embed_path = 'Z:\code\2018-05-05_joints_tsne_FlyAging_talmo-labels\results\FlyAging-DiegoCNN_v1.0_filters=64_rot=15_lrfactor=0.1_lrmindelta=1e-05_03.mat';
density_path = 'Z:\code\2018-05-05_joints_tsne_FlyAging_talmo-labels\viz\FlyAging-DiegoCNN_v1.0_filters=64_rot=15_lrfactor=0.1_lrmindelta=1e-05_03\density.mat';
density = load(density_path);
embed = load(embed_path);
% Embedded coordinates; rows are selected below with per-frame masks
% (presumably one row per frame).
Y = embed.Y;
Ld = density.Ld;
gait_path = 'C:\code\murthylab\JointTracker\LabelPostProcessing\GaitVectors3.mat';
gait = load(gait_path);
%% Compute a density for each gait state
% Parameters forwarded to getDensity (defined elsewhere; presumably a
% kernel density of width sigma on a numGridPoints grid spanning gridRange).
sigma = 20/30;
numGridPoints = 500;
gridRange = [-20 20];
% HMM state labels as assigned in gait_analysis_computation.m.
gait_labels = [3 4 5]; % tripod, tetrapod, non-canonical
% moving_forward is transposed to match the orientation of most_likely_seq.
is_forward = gait.moving_forward';
D = cell(1,numel(gait_labels));
for k = 1:numel(gait_labels)
    in_gait = gait.hmm.most_likely_seq == gait_labels(k) & is_forward;
    D{k} = getDensity(Y(in_gait,:),sigma,numGridPoints,gridRange);
end
[D_tripod,D_tetrapod,D_NC] = D{:};
% Stack the densities as pages 1-3 (tripod, tetrapod, non-canonical);
% double() keeps the class that a zeros()-preallocated array would have.
gait_density = double(cat(3,D{:}));
% Saving
save_path = 'Gait_Densities';
save(save_path,'gait_density','D_tripod','D_tetrapod','D_NC','Ld');
================================================
FILE: analysis/gait_analysis/gait_analysis_computation.m
================================================
clear all;
%% Pathing
% addpath(genpath('deps'));
joints_dir = 'Z:\data\JointTracker\2018-02_FlyAging_boxes\expts\preds\talmo-labels\FlyAging-DiegoCNN_v1.0_filters=64_rot=15_lrfactor=0.1_lrmindelta=1e-05_03';
data_dir = 'Z:\data\Fly_Aging\male_data';
% Get the paths from the directories
data_fns = dir([data_dir,'/*_*']);
joint_fns = dir(strcat(joints_dir,'\*.h5'));
joint_exptnames = {joint_fns(:).name};
% Strip the '.h5' extension to recover the experiment name.
joint_exptnames = cf(@(x) x(1:end-3),joint_exptnames);
% Map every joints file to the experiment folder of the same name.
% Joint positions are concatenated in joint_fns order and centroid
% velocities in data_fns order, and the two are later combined frame by
% frame. Nothing guarantees that the two dir() listings sort identically
% or contain the same experiments, so reorder the experiment folders to
% follow the joint files and fail loudly on any joints file without data.
[has_data, joint_expt_2_data] = ismember(joint_exptnames, {data_fns(:).name});
if ~all(has_data)
    error('No experiment folder in %s matches joint file(s): %s', ...
        data_dir, strjoin(joint_exptnames(~has_data), ', '));
end
data_fns = data_fns(joint_expt_2_data);
exptnames = {data_fns(:).name};
joint_paths = cell(size(joint_fns));
for i = 1:numel(joint_fns)
    joint_paths{i} = fullfile(joint_fns(i).folder,joint_fns(i).name);
end
data_paths = cell(size(data_fns));
for i = 1:numel(data_fns)
    data_paths{i} = [fullfile(data_fns(i).folder,data_fns(i).name),'/Positions.dat'];
end
%% Get all of the positions for all of the videos.
pos = cell(size(joint_paths));
% Loads the predicted joint positions of each video and makes them
% egocentric by subtracting node 5 (the thorax) from every node, which
% leaves node 5 as all zeros.
% NOTE(review): assumes positions_pred is read as (nodes x 2 x frames),
% consistent with the reshape of pos further below; confirm.
parfor i = 1:numel(joint_paths)
joints = h5read(joint_paths{i},'/positions_pred');
joints = joints - joints(5,:,:);
% Flatten nodes and coordinates into one feature axis: one column per frame.
joints = reshape(joints,[],size(joints,3));
pos{i} = joints;
end
% Concatenate all videos along the frame (column) axis, in joint_paths order.
pos = cat(2,pos{:});
%% Get the ids of all of the walking bouts according to the speed of centroids
moving_forward = cell(size(data_paths));
forward_velocity = cell(size(data_paths));
speed = cell(size(data_paths));
smoothing_window = 5;
% Velocity thresholds
reconversion_constant = (1/(24.40/1088))./35; % This corrects a previous measurement error
conversion_to_mm = 31.0857/1088; % NOTE(review): currently unused in this script
forward_motion_thresh = .02; % per frame; 2 mm/s at Fs = 100
% For each video, get all velocity and speed stats;
for i = 1:numel(data_paths)
    % Box file of this experiment: tracked frame indices and ellipse fits.
    box_path = ['Z:\data\JointTracker\2018-02_FlyAging_boxes\expts\' exptnames{i} '.h5'];
    frames = h5read(box_path,'/framesIdx');
    ell = h5read(box_path,'/ell');
    % Load centers from the position data, rescale them, and smooth
    [X,Y,numLines] = positionReader(data_paths{i});
    X = X.*reconversion_constant;
    Y = Y.*reconversion_constant;% reconversion constants.
    ctr = [X(frames),Y(frames)];
    ctr = smoothdata(ctr,1,'movmean',smoothing_window);
    % Get the velocity, direction of motion, and orientation of the fly
    vel_ctr = diffpad(ctr);
    % Smooth the heading while it is still unwrapped (continuous) and wrap
    % it only afterwards: a moving mean taken across the 0/2*pi seam
    % averages e.g. 6.2 and 0.1 rad to ~3.15 rad, the opposite direction.
    direction_of_motion = mod(smoothdata(unwrap(atan2(vel_ctr(:,2),vel_ctr(:,1))),'movmean',smoothing_window),2*pi);
    % Ellipse orientation (column 5 of /ell, converted from degrees).
    orientation = mod(unwrap((ell(frames,5)*2*pi/360)),2*pi);
    difference_dir = abs(direction_of_motion-orientation);
    % Get the component of the velocity in the forward direction
    speed_ctr = sqrt(sum(vel_ctr.^2,2));
    speed{i} = speed_ctr;
    forward_velocity{i} = cos(difference_dir).*speed{i};
    moving_forward{i} = forward_velocity{i} > forward_motion_thresh;
end
lengths = cellfun(@(x) numel(x),moving_forward);
speed = cat(1,speed{:});
moving_forward = cat(1,moving_forward{:});
forward_velocity = cat(1,forward_velocity{:});
fv = forward_velocity;
%% Get rasters of when legs are moving in the forward direction
% This is meant to replicate "Quantification of
% gait parameters in freely walking wild type and sensory deprived
% Drosophila melanogaster" Figure 4
Fs = 100; % frames per second; per-frame quantities are multiplied by Fs below
% We want to look at only the leg tips
tips = [22 26 30 10 14 18]; % (The order matches the paper)
% Undo the flattening from the loading step: (nodes x 2 coords x frames).
pos_tips = reshape(pos,[],2,size(pos,2));
pos_tips = pos_tips(tips,:,:);
dim = 1; % Only look in the x direction
% traj: one row per leg tip (in tips order), one column per frame.
traj = squeeze(pos_tips(:,dim,:));
% Get the velocity relative to the center of the fly (egocentric vel)
% (diffpad is defined elsewhere; presumably a length-preserving diff along dim 2.)
vel = diffpad(traj,2);
vel = smoothdata(vel,2,'gauss',smoothing_window);
% Define stance to be when the legs move in the negative x direction.
stance = vel<0;
%% Get the contiguous bouts that will be used for HMM fitting
% Observation symbol per frame: number of legs in stance (0-6) plus 1,
% giving symbols 1-7 as hmmtrain/hmmviterbi require.
seq = sum(stance)+1;
% Random initial guesses for a 3-state HMM with 7 emission symbols.
% Baum-Welch starts from these as probability matrices, so normalize each
% row to sum to 1 (raw rand rows do not).
TRGuess = rand(3);
TRGuess = TRGuess./sum(TRGuess,2);
EMITGuess = rand(3,7);
EMITGuess = EMITGuess./sum(EMITGuess,2);
% Get the frames of contiguous bouts
% (first/last index of every video within the concatenated frame axis)
endFrames = cumsum(lengths);
startFrames = [1 endFrames(1:end-1)'+1];
samples = cell1(numel(lengths));
samples_per_video = 3000;
duration_thresh = 50;
for i = 1:numel(lengths)
    % Forward-walking bouts of this video; drop those shorter than duration_thresh frames.
    bw = bwconncomp(moving_forward(startFrames(i):endFrames(i)));
    duration = cellfun(@(x) numel(x),bw.PixelIdxList);
    bw.PixelIdxList(duration < duration_thresh) = [];
    ids = cat(1,bw.PixelIdxList{:});
    % Keep at most samples_per_video frames per video.
    ids = ids(1:(min(samples_per_video,numel(ids))));
    % Convert video-local indices to indices into the concatenated frames.
    samples{i} = startFrames(i) + ids - 1;
end
num_samples = cellfun(@(x) numel(x),samples);
samples = cat(1,samples{num_samples>1});
%% Train HMM and get most likely sequence
% Fit the HMM on the sampled forward-walking frames only.
% NOTE(review): samples from different bouts/videos are concatenated into a
% single sequence, so transitions across bout boundaries are also fitted.
tic;[ESTTR,ESTEMIT] = hmmtrain(seq(samples),TRGuess,EMITGuess);toc
% Decode the most likely hidden state for every frame of every video.
mls = hmmviterbi(seq,ESTTR,ESTEMIT);
%% Here you should look at the emission probabilities
% Rows = hidden states, columns = emission symbols (legs in stance + 1);
% use this to choose the relabeling in the next section.
figure;
imagesc(ESTEMIT);
%% Reorder the labels to 3 = tripod, 4 = tetrapod, 5 = non-canonical
% Note: this can be automated, but since initialization is random
% and it doesn't take that long to train, I prefer checking manually.
% Lookup table indexed by the raw Viterbi state (1, 2, 3):
% state 1 -> 3 (tripod), state 2 -> 5 (non-canonical), state 3 -> 4 (tetrapod).
gait_of_state = [3 5 4];
mls = gait_of_state(mls);
%% Save results into a structure
% Initial guesses, training symbols, fitted parameters, and decoded labels.
hmm = struct( ...
    'TRGUESS', TRGuess, ...
    'EMITGUESS', EMITGuess, ...
    'TrainingSamples', uint8(seq(samples)), ...
    'ESTTR', ESTTR, ...
    'ESTEMIT', ESTEMIT, ...
    'most_likely_seq', uint8(mls));
% Saving
% save_path = 'GaitVector3';
% save(save_path,'hmm');
%% Look at particular section
% Limit the window to a particular region
% Candidate start frames (indices into the concatenated frame axis):
% Tripod = 17230751;
% Tetrapod = 4781582;
% Tetrapod = 8472800
start = 8472800;
% 101 consecutive frames beginning at start.
ids = start:start+100;
% Save example gait vectors
example_vel = vel(:,ids);
example_stance = stance(:,ids);
example_fv = fv(ids);
% Saving
% save_path = 'TripodExample';
% save_path = 'TetrapodExample';
% save(save_path,'example_vel','example_stance','example_fv','Fs');
%% Calculate the distribution of speeds for each hidden state and save
mls = hmm.most_likely_seq;
speed_lim = [2 35]; % mm/s
% Forward speed per second, and the forward-walking frames whose speed lies
% strictly inside speed_lim; each gait mask below refines this shared mask.
fv_mms = fv.*Fs;
walking_in_range = moving_forward & fv_mms > speed_lim(1) & fv_mms < speed_lim(2);
state_of_frame = mls';
tri_ids = walking_in_range & state_of_frame == 3;
tetra_ids = walking_in_range & state_of_frame == 4;
NC_ids = walking_in_range & state_of_frame == 5;
% 166-bin histograms of forward speed for tripod, tetrapod and non-canonical frames.
[N1,edges1] = histcounts(fv_mms(tri_ids),166);
[N2,edges2] = histcounts(fv_mms(tetra_ids),166);
[N3,edges3] = histcounts(fv_mms(NC_ids),166);
% Saving
% save_path = 'Gait_Speed_Distributions';
% save(save_path,'N1','N2','N3','edges1','edges2','edges3','speed_lim')
%% Look at the Velocity statistics per Tsne locomotor state
density_path = 'Z:\code\2018-05-05_joints_tsne_FlyAging_talmo-labels\viz\FlyAging-DiegoCNN_v1.0_filters=64_rot=15_lrfactor=0.1_lrmindelta=1e-05_03\density.mat';
density = load(density_path);
% Values of density.YL treated as locomotor states, in embedding order.
embedding_ordered_locomotor_states = [7 11 13 10 8 9];
num_states = numel(embedding_ordered_locomotor_states);
%% Calculate the velocity distributions.
% Frames with positive forward velocity below 0.4 per frame (40 mm/s at Fs = 100).
speed_ids = fv > 0 & fv < .4;
N = cell1(num_states);
edges = cell1(num_states);
for i = 1:num_states
    state_ids = density.YL == embedding_ordered_locomotor_states(i);
    % Scale by the frame rate Fs (not a hard-coded 100) to get per-second
    % velocities, consistent with the other sections of this script.
    [N{i},edges{i}] = histcounts(fv(state_ids & speed_ids)*Fs,'Normalization','pdf');
end
% Saving
% save_path = 'Cluster_Velocity_Distributions';
% save(save_path,'N','edges','num_states');
%% Calculate the velocity of legs during swing as you bin velocities differently
% Frames around each swing onset, relative to the onset frame.
win = -1:5;
% Body-speed bin edges (mm/s); bin s covers [speed_levels(s), speed_levels(s+1)).
speed_levels = [2 5:5:45];
leg_vel_at_speed = zeros(numel(speed_levels)-1,numel(win));
leg_vel_std_at_speed = zeros(numel(speed_levels)-1,numel(win));
% Swing onsets (first frame of every run of non-stance frames) depend only
% on the stance raster, not on the speed bin, so find them once per leg.
swing_start = false(size(stance));
for i = 1:size(stance,1)
    bw = bwconncomp(~stance(i,:));
    swing_start(i,cellfun(@(x) x(1),bw.PixelIdxList)) = true;
end
% Body speed per second, oriented like the rows of the stance raster.
body_speed = speed'*Fs;
% For each body speed level get the velocity of the swings over the window
for s = 1:numel(speed_levels)-1
    swings = cell([size(vel,1) 1]);
    inSpeed = (body_speed >= speed_levels(s) & body_speed < speed_levels(s+1));
    parfor i = 1:size(vel,1)
        % Swing onsets at which the body speed lies in the current bin.
        % Onsets are never adjacent, so find() yields one index per onset.
        onsets = find(swing_start(i,:) & inSpeed);
        if isempty(onsets)
            % No swings for this leg in this bin: contribute no rows
            % (instead of failing on an empty index matrix).
            swings{i} = zeros(0,numel(win));
        else
            % One row per onset, one column per offset in win.
            swings{i} = indpad(vel(i,:),onsets(:) + win);
        end
    end
    swings = cat(1,swings{:});
    leg_vel_at_speed(s,:) = nanmean(swings,1);
    % nanstd's second argument is the normalization flag, not the
    % dimension. Pass the dimension explicitly so the std is always taken
    % across swings (dim 1), like nanmean above, even when only one swing
    % exists; flag 0 is the default (N-1) normalization.
    leg_vel_std_at_speed(s,:) = nanstd(swings,0,1);
end
% Saving
% save_path = 'Swing_Velocity_Over_Time';
% save(save_path,'leg_vel_at_speed','leg_vel_std_at_speed','speed_levels','win');
%% Stance / Swing Duration
% For every leg (row of stance), measure the length of each swing bout
% (run of non-stance frames) and stance bout, and the mean forward body
% velocity fv over that bout.
% NOTE(review): stance spans the concatenation of all videos, so a bout can
% cross a video boundary.
swing_durations = cell([1,size(stance,1)]);
stance_durations = cell([1,size(stance,1)]);
swing_body_velocities = cell([1,size(stance,1)]);
stance_body_velocities = cell([1,size(stance,1)]);
period = cell([1,size(stance,1)]);
period_velocities = cell([1,size(stance,1)]);
% Get the swing, stance, and period duration and velocities
parfor i = 1:size(stance,1)
% Swing
bw = bwconncomp(~stance(i,:));
swing_durations{i} = cellfun(@(x) numel(x),bw.PixelIdxList);
swing_body_velocities{i} = cellfun(@(x) mean(fv(x)),bw.PixelIdxList);
% Stance
bw = bwconncomp(stance(i,:));
stance_durations{i} = cellfun(@(x) numel(x),bw.PixelIdxList);
stance_body_velocities{i} = cellfun(@(x) mean(fv(x)),bw.PixelIdxList);
% Period
% bw = bwconncomp(stance(i,:));
% bw still holds the stance bouts; each period runs from the first frame
% of stance bout k to the last frame of stance bout k+1.
% NOTE(review): that spans stance k + swing k + stance k+1; a conventional
% step period (stance onset to next stance onset) would be x(1):y(1)-1.
% Confirm intent. period is not used later in this script.
period{i} = cellfun(@(x,y) numel(x(1):y(end)),{bw.PixelIdxList{1:end-1}},{bw.PixelIdxList{2:end}});
period_velocities{i} = cellfun(@(x,y) mean(fv(x(1):y(end))),{bw.PixelIdxList{1:end-1}},{bw.PixelIdxList{2:end}});
end
% Pool the bouts of all legs into single row vectors.
swing_durations = cat(2,swing_durations{:});
stance_durations = cat(2,stance_durations{:});
swing_body_velocities = cat(2,swing_body_velocities{:});
stance_body_velocities = cat(2,stance_body_velocities{:});
period = cat(2,period{:});
period_velocities = cat(2,period_velocities{:});
%% Bin stance and swing durations by average body speed
% Per 1 mm/s speed bin, compute mean and std of bout duration (in frames).
% accumarray returns exactly one entry per bin (NaN for empty bins), so
% entry k always belongs to bin k and lines up with *_edges(k+1) when
% plotted. grpstats silently drops empty bins, which shifts every later bin.
% Speeds beyond the last edge get NaN from discretize and are excluded.
stance_vel_thresh = 7.2; % This number is taken from Mendes et al
ids = stance_body_velocities*Fs > stance_vel_thresh;
xranges = [stance_vel_thresh 50];
stance_edges = xranges(1):1:xranges(2);
% Stance Duration vs Velocity
X1 = stance_body_velocities(ids)*Fs;
Y1 = stance_durations(ids);
[X1ids] = discretize(X1,stance_edges);
valid = ~isnan(X1ids);
num_bins = numel(stance_edges)-1;
stance_dur_mu = accumarray(X1ids(valid)',Y1(valid)',[num_bins 1],@mean,NaN);
stance_dur_std = accumarray(X1ids(valid)',Y1(valid)',[num_bins 1],@std,NaN);
swing_vel_thresh = 7.2; % This number is taken from Mendes et al
ids = swing_body_velocities*Fs > swing_vel_thresh;
xranges = [swing_vel_thresh 50];
swing_edges = xranges(1):1:xranges(2);
% Swing Duration vs Velocity
X2 = swing_body_velocities(ids)*Fs;
Y2 = swing_durations(ids);
[X2ids] = discretize(X2,swing_edges);
valid = ~isnan(X2ids);
num_bins = numel(swing_edges)-1;
swing_dur_mu = accumarray(X2ids(valid)',Y2(valid)',[num_bins 1],@mean,NaN);
swing_dur_std = accumarray(X2ids(valid)',Y2(valid)',[num_bins 1],@std,NaN);
% Saving
% save_path = 'Swing_and_Stance_versus_Velocity';
% save(save_path,'stance_dur_mu','swing_dur_mu','stance_dur_std','swing_dur_std','stance_edges','swing_edges')
================================================
FILE: analysis/gait_analysis/gait_analysis_plotting.m
================================================
clear all;
%% Look at particular section
% Pick the gait example to observe
% load('TetrapodExample');
% Provides example_vel, example_stance, example_fv and Fs
% (see the commented-out save in gait_analysis_computation.m).
load('TripodExample');
% Plot the velocity of the leg tips
figure('pos',[153, 427, 560, 420]); hold on; axis tight; set(gcf,'color','w'); fontsize(16)
imagesc(example_vel);
xlabel('Time (seconds)')
% Relabel frame-index ticks as seconds. The labels are fixed at call time.
% NOTE(review): they will not follow later zooming or linked panning.
xticklabels(xticks/Fs);
ylabel('Leg Tip')
yticks([1:6])
yticklabels({'RF','RM','RH','LF','LM','LH'})
% NOTE(review): the raster below labels the same six rows (same leg-tip
% order) as {'LF','LM','LH','RF','RM','RH'}; one of the two label sets is
% wrong - confirm which.
ax1 = gca;
caxis([-10 10]);
% Plot the rasters
figure('pos',[850, 634, 848, 334]); hold on; axis tight; set(gcf,'color','w'); fontsize(16)
imagesc(example_stance); colormap('gray');
xlabel('Time (seconds)')
xticklabels(xticks/Fs);
ylabel('Leg Tip')
yticks([1:6])
yticklabels({'LF','LM','LH','RF','RM','RH'})
ax2 = gca;
% plot the forward velocity of the fly
figure('pos',[850, 359, 854, 186]); hold on; axis tight; set(gcf,'color','w'); fontsize(16)
plot(example_fv.*Fs)
xlabel('Time (seconds)')
xticklabels(xticks/Fs);
ylabel('Forward Velocity (mm/s)')
ax3 = gca;
ylim([0 40]);
% plot the raster of tripod, tetrapod, or non-canonical
figure('pos',[849, 114, 931, 150]); hold on; axis tight; set(gcf,'color','w'); fontsize(16)
% Number of legs in stance per frame; any count other than 3 or 4 is
% lumped into the value 5 (non-canonical), so the colorbar label below is
% literal only for 3 and 4.
example_gait = sum(example_stance,1);
example_gait(~(example_gait == 3 | example_gait == 4)) = 5;
imagesc(example_gait);colormap('jet');h = colorbar;
xlabel('Time (seconds)')
xticklabels(xticks/Fs);
yticks([])
ylabel(h,'Number of legs in stance')
ax4 = gca;
% Pan/zoom all four panels together along time.
linkaxes([ax1,ax2,ax3,ax4],'x')
%% Plot the emission probabilities for each hidden states
load('GaitVectors3.mat');
% Swap rows 2 and 3 so the rows follow the relabeled gait order chosen in
% gait_analysis_computation.m (raw states 1, 2, 3 -> tripod, non-canonical,
% tetrapod): tripod, tetrapod, non-canonical.
emissions = hmm.ESTEMIT;
emissions([2 3],:) = emissions([3 2],:);
% Emission symbol k means k-1 legs in stance.
num_symbols = size(emissions,2);
stance_labels = arrayfun(@num2str,0:num_symbols-1,'UniformOutput',false);
figure; hold on;
imagesc(emissions);
% Print each probability on top of its cell.
for i = 1:size(emissions,1)
    for j = 1:num_symbols
        caption = sprintf('%.2f',emissions(i,j));
        text(j,i,caption,'Fontsize',10,'FontWeight','bold','HorizontalAlignment','center','Color',[0 0 0]);
    end
end
axis ij;
axis tight
xlabel('Number of Legs in Stance');
% Pin one tick per symbol before labeling, so the labels cannot land on
% automatically chosen tick positions.
xticks(1:num_symbols);
yticks([1 2 3]);
yticklabels({'Tripod','Tetrapod','Non-canonical'})
xticklabels(stance_labels)
fontsize(16)
figure; hold on;
plot(emissions','LineWidth',3);
xlabel('Number of Legs in Stance');
xticks(1:num_symbols);
xticklabels(stance_labels)
legend({'Tripod','Tetrapod','Non-canonical'})
ylabel('Emission Probability')
%% Plot the distribution of speeds
% Histogram counts of forward speed per gait state, plotted at the left
% edge of each bin.
load('Gait_Speed_Distributions');
figure; hold on;
plot(edges1(1:end-1),N1,'LineWidth',3)
plot(edges2(1:end-1),N2,'LineWidth',3)
plot(edges3(1:end-1),N3,'LineWidth',3)
% The speeds were computed as fv.*Fs, i.e. mm/s (not ms).
xlabel('Forward Velocity (mm/s)')
ylabel('Count')
xlim(speed_lim)
legend({'Tripod','Tetrapod','Non-canonical'})
fontsize(16)
%% Plot the velocity Distributions
% One forward-velocity distribution per embedding locomotor state.
load('Cluster_Velocity_Distributions');
figure; hold on;
cmap = spring(num_states);
for i = 1:num_states
    plot(edges{i}(1:end-1),N{i},'Color',cmap(num_states + 1 -i,:),'LineWidth',3);
end
grid on;
xlabel('Forward Velocity (mm/s)')
% The histograms were normalized with 'pdf', so the y values are
% probability densities (per mm/s), not probabilities.
ylabel('Probability density')
axis tight;
ylim([0 .2])
%% Plotting the mean and std with bounded lines
% One curve per body-speed bin: mean leg-tip velocity around swing onset,
% with a symmetric +/- 1 std band.
load('Swing_Velocity_Over_Time');
num_bins = numel(speed_levels)-1;
cmap = parula(numel(speed_levels));
p_lines = cell1(num_bins);
figure('pos',[568, 186, 1036, 798]); figclosekey; set(gcf,'color','w'); hold on;
for b = 1:size(leg_vel_at_speed,1)
    % boundedline takes the lower/upper widths as an N x 2 matrix.
    band = repmat(leg_vel_std_at_speed(b,:),2,1)';
    % win is in frames; *10 converts to ms.
    [p_lines{b},~] = boundedline(win*10,leg_vel_at_speed(b,:),band,'alpha','cmap',cmap(b,:));
    p_lines{b}.LineWidth = 3;
end
% Legend: one 'lo - hi mm/s' entry per speed bin
leg = arrayfun(@(lo,hi) sprintf('%d - %d mm/s',lo,hi), ...
    speed_levels(1:end-1),speed_levels(2:end),'UniformOutput',false);
l = legend([p_lines{:}],leg);
l.Position = [0.7503 0.6488 0.1573 0.3239];
fontsize(16)
xlabel('Time from swing onset (ms)')
ylabel('Swing velocity (mm/s)')
% export_fig('figs/Swing_Velocity_vs_Time_Confidences.png','-r300')
%% Plot the Swing_and_Stance_versus_Velocity
% Mean stance and swing durations, each with a +/- std band, versus average
% body speed (plotted at the upper edge of each speed bin).
load('Swing_and_Stance_versus_Velocity')
figure; figclosekey, hold on;
yci = zeros(2,numel(stance_dur_std));
yci(1,:) = stance_dur_std;
yci(2,:) = stance_dur_std;
[bl1,~] = boundedline(stance_edges(2:end),stance_dur_mu',yci','alpha');
yci = zeros(2,numel(swing_dur_std));
yci(1,:) = swing_dur_std;
yci(2,:) = swing_dur_std;
[bl2,~] = boundedline(swing_edges(2:end),swing_dur_mu',yci','alpha','r');
% Settle the axis limits BEFORE relabeling: yticklabels only freezes the label
% text, so a later limit change (axis tight) would move the ticks while the
% old labels stayed put.
axis tight
% Durations are shown in ms by scaling by 10 (presumably frames at 10 ms each).
% Pin the tick positions so the relabeled values stay attached to them.
yt = yticks;
yticks(yt);
yticklabels(round(yt*10))
ylabel('Durations (ms)');
xlabel('Average Body Speed (mm/s)');
legend([bl1,bl2],{'Stance','Swing'})
fontsize(16);
================================================
FILE: analysis/gait_analysis/plot_gait_densities.m
================================================
clear all;
%% Plot all three densities overlayed on top of one another
% gait_density is displayed normalized by its global maximum.
load('Gait_Densities');
figure; hold on;
h = imagesc(gait_density./max(gait_density(:)));
axis equal; axis xy; axis off
%% Plot the distributions for each mode in the same scale
% Crop each density to the region covered by the locomotor labels in Ld
% (bwcrop is a project helper; crop_mask presumably marks the cropped
% rectangle). Tripod, tetrapod and non-canonical densities then share the
% color range [0 peak], where peak is the maximum of the cropped tripod density.
Locomotor_states = [7 11 13 9 10 8];
[cropped,~,crop_mask] = bwcrop(ismember(Ld,Locomotor_states));
% Tripod (defines the shared color scale)
figure; hold on; axis xy; axis equal; axis off;
cropped_density = reshape(D_tripod(crop_mask),size(cropped,1),[]);
h = imagesc(cropped_density);
c1 = colorbar;
peak = max(c1.Limits);
colormap('viridis')
% Use the same [0 peak] range as the other two modes so all plots are comparable
caxis([0 peak]);
% Tetrapod
figure; hold on; axis xy; axis equal; axis off;
cropped_density = reshape(D_tetrapod(crop_mask),size(cropped,1),[]);
h = imagesc(cropped_density);
% h.AlphaData = 1 .* ~reshape(density.Lbnds(crop_mask),size(cropped,1),[]);
c2 = colorbar;
colormap('viridis')
caxis([0 peak]);
% Non-canonical
figure; hold on; axis xy; axis equal; axis off;
cropped_density = reshape(D_NC(crop_mask),size(cropped,1),[]);
h = imagesc(cropped_density);
% h.AlphaData = 1 .* ~reshape(density.Lbnds(crop_mask),size(cropped,1),[]);
colormap('viridis')
c3 = colorbar;
caxis([0 peak]);
================================================
FILE: data/readme.md
================================================
The full fly dataset and all trained networks used in our paper can be downloaded from: [`http://arks.princeton.edu/ark:/88435/dsp01pz50gz79z`](http://arks.princeton.edu/ark:/88435/dsp01pz50gz79z)
Additional datasets:
## BermanFlies
_We recommend using this for testing._
- [Cluster sampled images](https://1drv.ms/u/s!AnmpIqqfwz3zgbUg2M8Sa_0NcLhrMg) (**68.7 MiB**)
- [Labels for n = 1500 images](https://1drv.ms/u/s!AnmpIqqfwz3zgbUeIuYYDj_A1pCr9Q) (**291 KiB**)
- [Trained network model](https://1drv.ms/u/s!AnmpIqqfwz3zgcwUvbqPUn7mXIMLLg) (**53.0 MiB**)
- [Full dataset](http://arks.princeton.edu/ark:/88435/dsp01pz50gz79z) (**~168 GiB**)
## CatNect
This is the dataset used for the tutorial. It contains a clip of a cat chasing a laser pointer recorded with a Kinect.
- [Full clip](https://1drv.ms/u/s!AnmpIqqfwz3zgcwS_3gAJFU0sANBcA) (**105 MiB**)
- [Cluster sampled images](https://1drv.ms/u/s!AnmpIqqfwz3zgcwR_9mmNJz8ALW3Hw) (**42.2 MiB**)
- [Labels for n = 40 images](https://1drv.ms/u/s!AnmpIqqfwz3zgcwQlKSXXDy9KvIPVg) (**82.3 KiB**)
- [Trained network model](https://1drv.ms/u/s!AnmpIqqfwz3zgc0H-v6qbnc3vak2Lw) (**64.7 MiB**)
================================================
FILE: examples/batch_process_video.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Example: Predict body part positions from an MP4 file\n",
"This notebook presents an example pipeline for applying a trained LEAP network to ~360k frames read from an MP4 video.\n",
"\n",
"You can download the data to reproduce the benchmarking results below.\n",
"\n",
"**Input data:** [072212_163153.mp4](https://1drv.ms/v/s!AnmpIqqfwz3zgcgekCxNp-MN76p1UQ) (254 MiB)\n",
"\n",
"**Output data:** [072212_163153.preds.h5](https://1drv.ms/u/s!AnmpIqqfwz3zgcgdDhQrKRsBaxvCXQ) (46.9 MiB)\n",
"\n",
"The trained network can be [downloaded here](https://1drv.ms/u/s!AnmpIqqfwz3zgdpIOzsqojhEmr0J0w) or from the links in the [repository data folder](https://github.com/talmo/leap/tree/master/data)."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Anaconda3\\lib\\site-packages\\h5py\\__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
" from ._conv import register_converters as _register_converters\n",
"Using TensorFlow backend.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Platform: Windows-10-10.0.16299-SP0\n",
"h5py:\n",
"Summary of the h5py configuration\n",
"---------------------------------\n",
"\n",
"h5py 2.7.1\n",
"HDF5 1.10.1\n",
"Python 3.6.4 |Anaconda, Inc.| (default, Jan 16 2018, 10:22:32) [MSC v.1900 64 bit (AMD64)]\n",
"sys.platform win32\n",
"sys.maxsize 9223372036854775807\n",
"numpy 1.14.1\n",
"\n",
"Keras: 2.2.0\n",
"Tensorflow: 1.5.0\n",
"Devices:\n",
"[name: \"/device:CPU:0\"\n",
"device_type: \"CPU\"\n",
"memory_limit: 268435456\n",
"locality {\n",
"}\n",
"incarnation: 12947954769568633288\n",
", name: \"/device:GPU:0\"\n",
"device_type: \"GPU\"\n",
"memory_limit: 9143884186\n",
"locality {\n",
" bus_id: 1\n",
"}\n",
"incarnation: 16757352010506659625\n",
"physical_device_desc: \"device: 0, name: GeForce GTX 1080 Ti, pci bus id: 0000:01:00.0, compute capability: 6.1\"\n",
", name: \"/device:GPU:1\"\n",
"device_type: \"GPU\"\n",
"memory_limit: 9143884186\n",
"locality {\n",
" bus_id: 1\n",
"}\n",
"incarnation: 16188733693954377295\n",
"physical_device_desc: \"device: 1, name: GeForce GTX 1080 Ti, pci bus id: 0000:02:00.0, compute capability: 6.1\"\n",
"]\n"
]
}
],
"source": [
"import os\n",
"import numpy as np\n",
"import cv2\n",
"import h5py\n",
"from time import time\n",
"\n",
"import keras\n",
"import keras.models\n",
"from leap.predict_box import convert_to_peak_outputs\n",
"from leap.utils import versions\n",
"\n",
"versions(list_devices=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Parameters"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Media file path\n",
"video_path = \"D:/tmp/072212_163153.mp4\"\n",
"\n",
"# Trained network path\n",
"model_path = \"D:/OneDrive/code/leap/data/BermanFlies/models/180615_025354-n=1500/final_model.h5\"\n",
"\n",
"# Predictions output path\n",
"save_path = \"D:/tmp/072212_163153.preds.h5\"\n",
"\n",
"# Number of frames to read before predicting (higher = faster, but limited by RAM)\n",
"chunk_size = 10000\n",
"\n",
"# Number of frames to evaluate at once on the GPU (higher = faster, but limited by GPU memory)\n",
"batch_size = 64"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Processing"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: D:/OneDrive/code/leap/data/BermanFlies/models/180615_025354-n=1500/final_model.h5\n",
" Input: (None, 192, 192, 1)\n",
" Output: (None, 3, 32)\n",
"Predicted: 10000/361000 frames | Elapsed: 0.4 min / 391.1 FPS / ETA: 15.0 min\n",
"Predicted: 20000/361000 frames | Elapsed: 0.8 min / 435.0 FPS / ETA: 13.1 min\n",
"Predicted: 30000/361000 frames | Elapsed: 1.1 min / 451.0 FPS / ETA: 12.2 min\n",
"Predicted: 40000/361000 frames | Elapsed: 1.5 min / 459.0 FPS / ETA: 11.7 min\n",
"Predicted: 50000/361000 frames | Elapsed: 1.8 min / 464.7 FPS / ETA: 11.2 min\n",
"Predicted: 60000/361000 frames | Elapsed: 2.1 min / 468.7 FPS / ETA: 10.7 min\n",
"Predicted: 70000/361000 frames | Elapsed: 2.5 min / 471.4 FPS / ETA: 10.3 min\n",
"Predicted: 80000/361000 frames | Elapsed: 2.8 min / 473.5 FPS / ETA: 9.9 min\n",
"Predicted: 90000/361000 frames | Elapsed: 3.2 min / 474.8 FPS / ETA: 9.5 min\n",
"Predicted: 100000/361000 frames | Elapsed: 3.5 min / 476.0 FPS / ETA: 9.1 min\n",
"Predicted: 110000/361000 frames | Elapsed: 3.8 min / 477.4 FPS / ETA: 8.8 min\n",
"Predicted: 120000/361000 frames | Elapsed: 4.2 min / 478.4 FPS / ETA: 8.4 min\n",
"Predicted: 130000/361000 frames | Elapsed: 4.5 min / 479.5 FPS / ETA: 8.0 min\n",
"Predicted: 140000/361000 frames | Elapsed: 4.8 min / 481.4 FPS / ETA: 7.7 min\n",
"Predicted: 150000/361000 frames | Elapsed: 5.2 min / 483.0 FPS / ETA: 7.3 min\n",
"Predicted: 160000/361000 frames | Elapsed: 5.5 min / 483.7 FPS / ETA: 6.9 min\n",
"Predicted: 170000/361000 frames | Elapsed: 5.9 min / 484.0 FPS / ETA: 6.6 min\n",
"Predicted: 180000/361000 frames | Elapsed: 6.2 min / 484.1 FPS / ETA: 6.2 min\n",
"Predicted: 190000/361000 frames | Elapsed: 6.5 min / 484.3 FPS / ETA: 5.9 min\n",
"Predicted: 200000/361000 frames | Elapsed: 6.9 min / 484.6 FPS / ETA: 5.5 min\n",
"Predicted: 210000/361000 frames | Elapsed: 7.2 min / 484.7 FPS / ETA: 5.2 min\n",
"Predicted: 220000/361000 frames | Elapsed: 7.6 min / 484.6 FPS / ETA: 4.8 min\n",
"Predicted: 230000/361000 frames | Elapsed: 7.9 min / 484.7 FPS / ETA: 4.5 min\n",
"Predicted: 240000/361000 frames | Elapsed: 8.3 min / 484.7 FPS / ETA: 4.2 min\n",
"Predicted: 250000/361000 frames | Elapsed: 8.6 min / 484.9 FPS / ETA: 3.8 min\n",
"Predicted: 260000/361000 frames | Elapsed: 8.9 min / 485.0 FPS / ETA: 3.5 min\n",
"Predicted: 270000/361000 frames | Elapsed: 9.3 min / 485.2 FPS / ETA: 3.1 min\n",
"Predicted: 280000/361000 frames | Elapsed: 9.6 min / 485.3 FPS / ETA: 2.8 min\n",
"Predicted: 290000/361000 frames | Elapsed: 10.0 min / 485.1 FPS / ETA: 2.4 min\n",
"Predicted: 300000/361000 frames | Elapsed: 10.3 min / 485.1 FPS / ETA: 2.1 min\n",
"Predicted: 310000/361000 frames | Elapsed: 10.7 min / 485.1 FPS / ETA: 1.8 min\n",
"Predicted: 320000/361000 frames | Elapsed: 11.0 min / 485.3 FPS / ETA: 1.4 min\n",
"Predicted: 330000/361000 frames | Elapsed: 11.3 min / 485.1 FPS / ETA: 1.1 min\n",
"Predicted: 340000/361000 frames | Elapsed: 11.7 min / 485.2 FPS / ETA: 0.7 min\n",
"Predicted: 350000/361000 frames | Elapsed: 12.0 min / 485.4 FPS / ETA: 0.4 min\n",
"Predicted: 360000/361000 frames | Elapsed: 12.4 min / 485.6 FPS / ETA: 0.0 min\n",
"Predicted: 361000/361000 frames | Elapsed: 12.4 min / 485.4 FPS / ETA: 0.0 min\n",
"Finished predicting 361000 frames.\n",
" Prediction | Runtime: 11.55 min / 520.921 FPS\n",
" Reading | Runtime: 0.78 min / 7726.516 FPS\n",
"Saved: D:/tmp/072212_163153.preds.h5\n",
"Total runtime: 12.4 mins\n",
"Total performance: 483.585 FPS\n"
]
}
],
"source": [
"t0_all = time()\n",
"\n",
"# Load model and convert to peak-coordinate output\n",
"model = convert_to_peak_outputs(keras.models.load_model(model_path))\n",
"print(\"Model:\", model_path)\n",
"print(\" Input:\", str(model.input_shape))\n",
"print(\" Output:\", str(model.output_shape))\n",
"\n",
"# model = keras.utils.multi_gpu_model(model, gpus=2)\n",
"\n",
"# Open video for reading\n",
"reader = cv2.VideoCapture(video_path)\n",
"num_samples = int(reader.get(cv2.CAP_PROP_FRAME_COUNT))\n",
"\n",
"# Initialize\n",
"positions_pred = []\n",
"conf_pred = []\n",
"buffer = []\n",
"samples_predicted = 0\n",
"reading_runtime = 0\n",
"prediction_runtime = 0\n",
"done = False\n",
"\n",
"# Process video chunk-by-chunk\n",
"while not done:\n",
" t0_reading = time()\n",
" # Read and finish if no frame was retrieved\n",
" returned_frame, I = reader.read()\n",
" done = not returned_frame\n",
" reading_runtime += time() - t0_reading\n",
" \n",
" # Add current frame to buffer\n",
" if not done:\n",
" buffer.append(I[...,0])\n",
" \n",
" # Do we have anything to predict?\n",
" if len(buffer) >= chunk_size or (done and len(buffer) > 0):\n",
" t0_prediction = time()\n",
" \n",
" # Predict on buffer\n",
" Y = model.predict(np.stack(buffer, axis=0)[...,None], batch_size=batch_size)\n",
" \n",
" # Save\n",
" positions_pred.append(Y[:,:2,:].astype(\"int32\"))\n",
" conf_pred.append(Y[:,2,:].squeeze())\n",
" \n",
" # Empty out buffer container\n",
" buffer = []\n",
" \n",
" # Performance stats\n",
" samples_predicted += len(Y)\n",
" prediction_runtime += time() - t0_prediction\n",
" elapsed = time() - t0_all\n",
" fps = samples_predicted / elapsed\n",
" print(\"Predicted: %d/%d frames | Elapsed: %.1f min / %.1f FPS / ETA: %.1f min\" %\n",
" (samples_predicted, num_samples, elapsed / 60, fps, (num_samples - samples_predicted) / fps / 60))\n",
" \n",
"# Close video reader\n",
"reader.release()\n",
"\n",
"# Merge arrays\n",
"positions_pred = np.concatenate(positions_pred, axis=0)\n",
"conf_pred = np.concatenate(conf_pred, axis=0)\n",
"\n",
"# Report performance stats\n",
"print(\"Finished predicting %d frames.\" % samples_predicted)\n",
"print(\" Prediction | Runtime: %.2f min / %.3f FPS\" % (prediction_runtime / 60, samples_predicted / prediction_runtime))\n",
"print(\" Reading | Runtime: %.2f min / %.3f FPS\" % (reading_runtime / 60, samples_predicted / reading_runtime))\n",
"\n",
"# Save\n",
"if os.path.exists(save_path):\n",
" os.remove(save_path)\n",
"with h5py.File(save_path, \"w\") as f:\n",
" f.attrs[\"num_samples\"] = num_samples\n",
" f.attrs[\"video_path\"] = video_path\n",
" f.attrs[\"model_path\"] = model_path\n",
"\n",
" ds_pos = f.create_dataset(\"positions_pred\", data=positions_pred, compression=\"gzip\", compression_opts=1)\n",
" ds_pos.attrs[\"description\"] = \"coordinate of peak at each sample\"\n",
" ds_pos.attrs[\"dims\"] = \"(sample, [x, y], joint) === (sample, [column, row], joint)\"\n",
"\n",
" ds_conf = f.create_dataset(\"conf_pred\", data=conf_pred, compression=\"gzip\", compression_opts=1)\n",
" ds_conf.attrs[\"description\"] = \"confidence map value in [0, 1.0] at peak\"\n",
" ds_conf.attrs[\"dims\"] = \"(sample, joint)\"\n",
"\n",
" total_runtime = time() - t0_all\n",
" f.attrs[\"reading_runtime_secs\"] = reading_runtime\n",
" f.attrs[\"prediction_runtime_secs\"] = prediction_runtime\n",
" f.attrs[\"total_runtime_secs\"] = total_runtime\n",
" \n",
" \n",
"print(\"Saved:\", save_path)\n",
"\n",
"print(\"Total runtime: %.1f mins\" % (total_runtime / 60))\n",
"print(\"Total performance: %.3f FPS\" % (samples_predicted / total_runtime))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
},
"toc": {
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"toc_cell": false,
"toc_position": {},
"toc_section_display": "block",
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 2
}
================================================
FILE: examples/hdf5tovid.m
================================================
% Clean start!
clear all, clc
%% Parameters
% Path to input file
dataPath = '../data/examples/072212_163153.clip.h5';
% Dataset name
dset = '/box';
% Path to output file
savePath = 'C:\tmp\072212_163153.clip.mp4';
% Frames to convert at a time (lower this if your memory is limited)
chunkSize = 1000;
% Framerate for playback of the video file
fps = 25;
%% Initialize
% Get dataset info; frames are stacked along the last (4th) dimension
info = h5info(dataPath, dset);
shape = info.Dataspace.Size;
numFrames = shape(end);
% Check if file already exists
if exist(savePath,'file') > 0
    warning(['Overwriting existing video file: ' savePath])
    delete(savePath)
end
% Open video for writing
writer = VideoWriter(savePath,'MPEG-4'); % use this for MP4s
% writer = VideoWriter(savePath,'Motion JPEG AVI'); % use this for AVIs
% Set compression quality (higher = bigger file, better quality)
writer.Quality = 100;
% Set playback speed in frames/second
writer.FrameRate = fps;
% Open file for writing
writer.open();
%% Save
framesWritten = 0;
% Nothing to read from an empty dataset: skip the loop instead of issuing a
% zero-length read
done = numFrames == 0;
t0 = tic;
while ~done
    % Check how many frames to read
    chunkFrames = min(chunkSize, numFrames - framesWritten);
    % Read chunk
    chunk = h5read(dataPath,dset,[1 1 1 framesWritten+1], [inf inf inf chunkFrames]);
    % Check for datatype/range concordance (floats must be in [0,1]).
    % NOTE: decided per chunk from the chunk's own maximum.
    if isfloat(chunk) && max(chunk(:)) > 1
        chunk = chunk / 255;
    end
    % Write frames
    writer.writeVideo(chunk);
    % Increment frames written counter
    framesWritten = framesWritten + size(chunk,4);
    % Check if we're done
    done = framesWritten >= numFrames;
end
% Close (finalize) the file before reporting it as written
writer.close();
elapsed = toc(t0);
fprintf('Finished writing %d frames in %.2f mins:\n\t%s\n', framesWritten, elapsed/60, savePath)
================================================
FILE: examples/vidtohdf5.m
================================================
% Clean start!
clear all, clc
%% Parameters
% Path to input file
videoPath = '..\..\leap\data\examples\072212_163153.clip.mp4';
% Path to output file
% savePath = '..\..\leap\data\examples\072212_163153.clip.h5';
savePath = 'C:\tmp\072212_163153.clip.h5';
% Frames to convert at a time (lower this if your memory is limited)
chunkSize = 1000;
% Convert frames to single channel grayscale images (instead of 3 channel RGB)
grayscale = true;
%% Initialize
% Open video for reading
vr = VideoReader(videoPath);
% Check size from first frame
I0 = vr.readFrame();
if grayscale; I0 = rgb2gray(I0); end
frameSize = size(I0);
if numel(frameSize) == 2; frameSize = [frameSize 1]; end
% Reset VideoReader
delete(vr);
vr = VideoReader(videoPath);
% Check if file already exists
if exist(savePath,'file') > 0
    warning(['Overwriting existing HDF5 file: ' savePath])
    delete(savePath)
end
% Create HDF5 file with infinite number of frames and GZIP compression
h5create(savePath,'/box',[frameSize inf],'ChunkSize',[frameSize 1],'Deflate',1,'Datatype','uint8')
%% Save
buffer = cell(chunkSize,1);
done = false;
framesRead = 0;
framesWritten = 0;
t0 = tic;
while ~done
    % Read next frame
    I = vr.readFrame();
    if grayscale; I = rgb2gray(I); end
    % Check if there are any frames left
    done = ~vr.hasFrame();
    % Increment frames read counter and add to the write buffer
    framesRead = framesRead + 1;
    bufferPos = mod(framesRead-1, chunkSize)+1;
    buffer{bufferPos} = I;
    % Have we filled the buffer or are there no frames left?
    if bufferPos == chunkSize || done
        % Concatenate only the frames read into the current chunk. The buffer
        % is reused between chunks, so on a final partial chunk the cells past
        % bufferPos still hold frames from the previous chunk and must not be
        % written again.
        chunk = cat(4, buffer{1:bufferPos});
        % Extend the dataset and save to disk
        h5write(savePath, '/box', chunk, [1 1 1 framesWritten+1], size(chunk))
        % Increment frames written counter
        framesWritten = framesWritten + size(chunk,4);
    end
end
elapsed = toc(t0);
fprintf('Finished writing %d frames in %.2f mins.\n', framesWritten, elapsed/60)
h5disp(savePath)
================================================
FILE: install_leap.m
================================================
function install_leap()
%INSTALL_LEAP Installs the Python package and adds MATLAB scripts to path.
% Usage:
%   install_leap
%
% Adds the leap/ folder (recursively) to the MATLAB path. If test_leap reports
% that the Python package cannot be imported, runs "pip install -e" on this
% repository through the system shell, then checks again with test_leap.
%
% See also: uninstall_leap, test_leap

% Base repository path: the folder containing this file. mfilename refers to
% the file actually running, whereas which() returns whichever install_leap
% happens to be first on the path.
basePath = fileparts(mfilename('fullpath'));
% Add to MATLAB path
addpath(genpath(fullfile(basePath,'leap')));
% Check if Python package is importable
canImportPython = test_leap();
if ~canImportPython
    % NOTE(review): uses the first "pip" on the system PATH, which may not
    % belong to the Python interpreter MATLAB is configured to use.
    [status,msg] = system(['pip install -e "' basePath '"']);
    disp(msg)
    if status ~= 0
        warning('install_leap:pipFailed', ...
            'pip install exited with status %d; see output above.', status)
    end
end
% Check again
test_leap;
end
================================================
FILE: leap/__init__.py
================================================
from . import image_augmentation
from . import layers
from . import models
from . import predict_box
from . import training
from . import utils
from . import viz
================================================
FILE: leap/compute_errors.m
================================================
function err = compute_errors(pos_pred, pos_gt)
%COMPUTE_ERRORS Computes error rates given predicted and ground truth positions.
% Usage:
%   err = compute_errors(pos_pred, pos_gt)
%
% Args:
%   pos_pred: predicted positions (J x 2 x N), or a struct whose
%             positions_pred field holds them
%   pos_gt: ground truth positions (J x 2 x N), or a struct whose
%           positions_pred field holds them
%
% Returns:
%   err: struct (built with varstruct) with error metrics:
%     delta      - pos_pred - pos_gt (J x 2 x N)
%     euclidean  - Euclidean error per sample and joint (N x J)
%     mae_all, mse_all, rmse_all - errors pooled over all coordinates
%     delta_rows - x/y errors of each joint as columns (2N x J)
%     mae, mse, rmse - per-joint errors over both coordinates (1 x J)
%
% See also:
if isstruct(pos_pred)
    pos_pred = pos_pred.positions_pred;
end
if isstruct(pos_gt)
    pos_gt = pos_gt.positions_pred;
end
% Find the difference between predicted and ground truth
delta = pos_pred - pos_gt;
% Euclidean distance per joint and sample: J x 1 x N -> N x J. permute keeps
% this orientation for any J and N (squeeze+transpose gave 1 x N when J == 1).
euclidean = permute(sqrt(sum(delta .^ 2, 2)), [3 1 2]);
% Compute metrics overall
mae_all = mean(abs(delta(:)));
mse_all = mean(delta(:) .^ 2);
rmse_all = sqrt(mse_all);
% Compute metrics per joint: (2 x N x J) -> (2N x J), one column per joint
delta_rows = reshape(permute(delta,[2 3 1]),[],size(delta,1));
mae = mean(abs(delta_rows));
mse = mean(delta_rows .^ 2);
rmse = sqrt(mse);
% Return everything
err = varstruct(delta, euclidean, ...
    mae_all, mse_all, rmse_all, ...
    delta_rows, mae, mse, rmse);
end
================================================
FILE: leap/confmaps2pts.m
================================================
function [pts, confvals] = confmaps2pts(C)
%CONFMAPS2PTS Convert a set of confidence maps into a set of points.
% Usage:
%   [pts, confvals] = confmaps2pts(C)
%
% For each channel (3rd dimension) of C, finds the location of the maximum
% value and reports it as [column row], i.e. [x y].
%
% Returns:
%   pts: numChannels x 2 single array of [x y] peak coordinates
%   confvals: numChannels x 1 array (same class as C) of the peak values
%
% See also: pts2confmaps
numChannels = size(C,3);
mapSize = [size(C,1), size(C,2)];
pts = zeros(numChannels,2,'single');
confvals = zeros(numChannels,1,'like',C);
for ch = 1:numChannels
    channelMap = C(:,:,ch);
    % Peak value and its linear index within this channel's map
    [peakVal, peakIdx] = max(vert(channelMap));
    confvals(ch) = peakVal;
    [row, col] = ind2sub(mapSize, peakIdx);
    pts(ch,:) = [col row];
end
end
================================================
FILE: leap/generate_training_set.m
================================================
function savePath = generate_training_set(boxPath, varargin)
%GENERATE_TRAINING_SET Creates a dataset for training.
% Usage: generate_training_set(boxPath, ...)
%
% Loads the labels file next to boxPath (repext(boxPath,'labels.mat')), keeps
% the frames in which every joint is labeled, renders one confidence map per
% joint, optionally augments by mirroring (swapping *L/*R joint pairs),
% optionally shuffles and splits off a test set, and writes the result to HDF5.
%
% Params (name/value pairs; see defaults below): savePath, scale, mirroring,
%   horizontalOrientation, sigma, normalizeConfmaps, postShuffle,
%   testFraction, compress
%
% Returns:
%   savePath: path of the HDF5 file that was written
t0_all = stic;
%% Setup
defaults = struct();
defaults.savePath = [];
defaults.scale = 1;
defaults.mirroring = true; % flip images and adjust confidence maps to augment dataset
defaults.horizontalOrientation = true; % animal is facing right/left if true (for mirroring)
defaults.sigma = 5; % kernel size for confidence maps
defaults.normalizeConfmaps = true; % scale maps to [0,1] range
defaults.postShuffle = true; % shuffle data before saving (useful for reproducible dataset order)
defaults.testFraction = 0; % separate these data from training and validation sets
defaults.compress = false; % use GZIP compression to save the outputs
params = parse_params(varargin,defaults);
% Paths
labelsPath = repext(boxPath,'labels.mat');
% Output
savePath = params.savePath;
if isempty(savePath)
    savePath = ff(fileparts(boxPath), 'training', [get_filename(boxPath,true) '.h5']);
    savePath = get_new_filename(savePath,true);
end
mkdirto(savePath)
%% Labels
labels = load(labelsPath);
% Check for complete frames (no NaN in any joint coordinate)
labeledIdx = find(squeeze(all(all(~isnan(labels.positions),2),1)));
numFrames = numel(labeledIdx);
printf('Found %d/%d labeled frames.', numFrames, size(labels.positions,3))
% Pull out label data
joints = labels.positions(:,:,labeledIdx);
joints = joints * params.scale;
numJoints = size(joints,1);
% Pull out other info
jointNames = labels.skeleton.nodes;
skeleton = struct();
skeleton.edges = labels.skeleton.edges;
skeleton.pos = labels.skeleton.pos;
%% Load images
stic;
box = h5readframes(boxPath,'/box',labeledIdx);
if params.scale ~= 1; box = imresize(box,params.scale); end
boxSize = size(box(:,:,:,1));
stocf('Loaded %d images', size(box,4))
% Load metadata (each item is optional and silently skipped if absent)
try exptID = h5read(boxPath, '/exptID'); exptID = exptID(labeledIdx); catch; end
try framesIdx = h5read(boxPath, '/framesIdx'); framesIdx = framesIdx(labeledIdx); catch; end
try idxs = h5read(boxPath, '/idxs'); idxs = idxs(labeledIdx); catch; end
try L = h5read(boxPath, '/L'); L = L(labeledIdx); catch; end
try box_no_seg = imresize(h5readframes(boxPath,'/box_no_seg',labeledIdx),params.scale); catch; end
try box_raw = imresize(h5readframes(boxPath,'/box_raw',labeledIdx),params.scale); catch; end
attrs = h5att2struct(boxPath);
%% Generate confidence maps
stic;
confmaps = NaN([boxSize(1:2), numJoints, numFrames],'single');
parfor i = 1:numFrames
    pts = joints(:,:,i);
    confmaps(:,:,:,i) = pts2confmaps(pts,boxSize(1:2),params.sigma,params.normalizeConfmaps);
end
stocf('Generated confidence maps') % 15 sec for 192x192x32x500
varsize(confmaps)
%% Augment by mirroring
if params.mirroring
    % Flip images. flipud/fliplr map 1-based pixel index p to N+1-p, so the
    % flipped joint coordinates use the same mapping to stay aligned with the
    % flipped confidence maps.
    if params.horizontalOrientation
        box_flip = flipud(box);
        try box_no_seg_flip = flipud(box_no_seg); catch; end
        try box_raw_flip = flipud(box_raw); catch; end
        confmaps_flip = flipud(confmaps);
        joints_flip = joints; joints_flip(:,2,:) = size(box,1) + 1 - joints_flip(:,2,:);
    else
        box_flip = fliplr(box);
        try box_no_seg_flip = fliplr(box_no_seg); catch; end
        try box_raw_flip = fliplr(box_raw); catch; end
        confmaps_flip = fliplr(confmaps);
        joints_flip = joints; joints_flip(:,1,:) = size(box,2) + 1 - joints_flip(:,1,:);
    end
    % Check for *L/*R naming pattern (e.g., {{'wingL','wingR'}, {'legR1','legL1'}})
    swap_names = {};
    baseNames = regexp(jointNames,'(.*)L([0-9]*)$','tokens');
    isSymmetric = ~cellfun(@isempty,baseNames);
    for i = horz(find(isSymmetric))
        nameR = [baseNames{i}{1}{1} 'R' baseNames{i}{1}{2}];
        if ismember(nameR,jointNames)
            swap_names{end+1} = {jointNames{i}, nameR};
        end
    end
    % Swap channels accordingly (a mirrored left joint becomes the right one)
    printf('Symmetric channels:')
    for i = 1:numel(swap_names)
        [~,swap_idx] = ismember(swap_names{i}, jointNames);
        if any(swap_idx == 0); continue; end
        printf('    %s (%d) <-> %s (%d)', jointNames{swap_idx(1)}, swap_idx(1), ...
            jointNames{swap_idx(2)}, swap_idx(2))
        joints_flip(swap_idx,:,:) = joints_flip(fliplr(horz(swap_idx)),:,:);
        confmaps_flip(:,:,swap_idx,:) = confmaps_flip(:,:,fliplr(horz(swap_idx)),:);
    end
    % Merge
    [box,flipped] = cellcat({box,box_flip},4);
    joints = cat(3, joints, joints_flip);
    try box_raw = cat(4,box_raw,box_raw_flip); catch; end
    try box_no_seg = cat(4,box_no_seg,box_no_seg_flip); catch; end
    confmaps = cat(4, confmaps, confmaps_flip);
    labeledIdx = [labeledIdx(:); labeledIdx(:)];
    try exptID = [exptID(:); exptID(:)]; catch; end
    try framesIdx = [framesIdx(:); framesIdx(:)]; catch; end
    try idxs = [idxs(:); idxs(:)]; catch; end
    % Update frame count
    numFrames = size(box,4);
end
%% Post-shuffle
shuffleIdx = vert(1:numFrames);
if params.postShuffle
    shuffleIdx = randperm(numFrames);
    box = box(:,:,:,shuffleIdx);
    labeledIdx = labeledIdx(shuffleIdx);
    try box_no_seg = box_no_seg(:,:,:,shuffleIdx); catch; end
    try box_raw = box_raw(:,:,:,shuffleIdx); catch; end
    try exptID = exptID(shuffleIdx); catch; end
    try framesIdx = framesIdx(shuffleIdx); catch; end
    joints = joints(:,:,shuffleIdx);
    confmaps = confmaps(:,:,:,shuffleIdx);
end
%% Separate testing set
numTestFrames = round(numel(shuffleIdx) * params.testFraction);
if numTestFrames > 0
    % testIdx are positions into the (already shuffled) arrays. shuffleIdx is a
    % permutation of 1:numFrames, so setdiff yields the remaining positions.
    testIdx = randperm(numel(shuffleIdx),numTestFrames);
    trainIdx = setdiff(shuffleIdx, testIdx);
    % Test set
    testing = struct();
    testing.shuffleIdx = shuffleIdx(testIdx);
    testing.box = box(:,:,:,testIdx);
    testing.labeledIdx = labeledIdx(testIdx);
    try testing.box_no_seg = box_no_seg(:,:,:,testIdx); catch; end
    try testing.box_raw = box_raw(:,:,:,testIdx); catch; end
    try testing.exptID = exptID(testIdx); catch; end
    try testing.framesIdx = framesIdx(testIdx); catch; end
    testing.joints = joints(:,:,testIdx);
    testing.confmaps = confmaps(:,:,:,testIdx);
    testing.testIdx = testIdx;
    % Training set
    shuffleIdx = shuffleIdx(trainIdx);
    box = box(:,:,:,trainIdx);
    labeledIdx = labeledIdx(trainIdx);
    try box_no_seg = box_no_seg(:,:,:,trainIdx); catch; end
    try box_raw = box_raw(:,:,:,trainIdx); catch; end
    try exptID = exptID(trainIdx); catch; end
    try framesIdx = framesIdx(trainIdx); catch; end
    joints = joints(:,:,trainIdx);
    confmaps = confmaps(:,:,:,trainIdx);
end
%% Save
% Augment metadata
attrs.createdOn = datestr(now);
attrs.boxPath = boxPath;
attrs.labelsPath = labelsPath;
attrs.scale = params.scale;
attrs.postShuffle = uint8(params.postShuffle);
attrs.horizontalOrientation = uint8(params.horizontalOrientation);
% Write
stic;
if exists(savePath); delete(savePath); end
% Training data
h5save(savePath,box,[],'compress',params.compress)
h5save(savePath,labeledIdx)
h5save(savePath,shuffleIdx)
try h5save(savePath,box_no_seg,[],'compress',params.compress); catch; end
try h5save(savePath,box_raw,[],'compress',params.compress); catch; end
try h5save(savePath,exptID); catch; end
try h5save(savePath,framesIdx); catch; end
h5save(savePath,joints,[],'compress',params.compress)
h5save(savePath,confmaps,[],'compress',params.compress)
% Testing data
if numTestFrames > 0
    h5save(savePath,trainIdx)
    h5savegroup(savePath,testing,'/testing','compress',params.compress)
end
% Metadata
h5writeatt(savePath,'/confmaps','sigma',params.sigma)
h5writeatt(savePath,'/confmaps','normalize',uint8(params.normalizeConfmaps))
h5struct2att(savePath,'/',attrs)
h5savegroup(savePath,skeleton,'/skeleton')
h5writeatt(savePath,'/skeleton','jointNames',strjoin(jointNames,'\n'))
stocf('Saved:\n%s', savePath)
get_filesize(savePath)
stocf(t0_all, 'Finished generating training set.');
end
================================================
FILE: leap/graph2paf.m
================================================
function paf = graph2paf(nodes, edges, sz, channelsOnly, sigma)
%GRAPH2PAF Converts a set of edges into part affinity fields.
% Usage:
%   paf = graph2paf(nodes, edges, sz)
%   paf = graph2paf(nodes, edges, sz, channelsOnly, sigma)
%
% Args:
%   nodes: set of points (N x 2) as [x y] in 1-based pixel coordinates
%   edges: indices into nodes defining directed edges (E x 2); the field for
%          edge i points from nodes(edges(i,2),:) to nodes(edges(i,1),:)
%   sz: grid/image size (1 x 2)
%   channelsOnly: stack all PAFs along channels (dim 3) instead of dim 4 (default: true)
%   sigma: maximum distance from edge to keep (default: 5)
%
% Returns:
%   paf: part affinity fields (sz(1) x sz(2) x 2E) or (sz(1) x sz(2) x 2 x E).
%        Within distance sigma of an edge segment, the two channels hold the
%        x and y components of the unit vector along the edge; elsewhere 0.
%        Zero-length edges (coincident endpoints) give all-zero fields.
%
% See also: pts2confmaps
if nargin < 4 || isempty(channelsOnly); channelsOnly = true; end
if nargin < 5 || isempty(sigma); sigma = 5; end
% Image coordinate grid: XX = column (x), YY = row (y); flattened once
[XX,YY] = meshgrid(1:sz(2), 1:sz(1));
XY = [XX(:) YY(:)];
% Create PAFs for each edge
E = size(edges,1);
paf = cell(E,1);
for i = 1:E
    % Pull out edge points
    src = nodes(edges(i,2),:);
    dst = nodes(edges(i,1),:);
    % Edge length
    L = norm(dst - src, 2);
    if L == 0
        % Direction is undefined for coincident endpoints; dividing by L would
        % fill the field with NaN, so emit an all-zero field instead.
        paf{i} = zeros(sz(1), sz(2), 2, 'like', nodes);
        continue
    end
    % Unit vectors
    V = (dst - src) ./ L; % pointing along edge
    Vp = [-V(:,2), V(:,1)]; % perpendicular
    % Signed distance along edge
    D1 = sum(V .* (XY - src),2);
    % Absolute distance orthogonal to edge
    D2 = abs(sum(Vp .* (XY - src),2));
    % Vector field mask: points projecting onto the segment, within sigma of it
    paf_mask = reshape(D1 >= 0 & D1 <= L & D2 <= sigma, sz);
    % Create vector field along channels (X and Y)
    paf{i} = paf_mask .* permute(V, [1 3 2]);
end
% Merge all edge PAFs
if channelsOnly
    paf = cat(3, paf{:});
else
    paf = cat(4, paf{:});
end
end
================================================
FILE: leap/guis/label_joints.m
================================================
function label_joints(boxPath, skeletonPath)
%LABEL_JOINTS GUI to click on images to yield a graph.
% Usage:
% label_joints(boxPath)
% label_joints(boxPath, skeletonPath)
%
% Args:
% boxPath: path to the HDF5 file with the images to label (dataset '/box',
% frames along dimension 4). Prompts for a file if empty or omitted.
% skeletonPath: skeleton MAT file to use. Passing any second argument forces
% the labels file to be recreated; pass true to be prompted for the
% skeleton file.
%
% Labels, config and a history log are kept in a *.labels.mat file whose
% path is derived from boxPath (see repext).
%
% See also: make_template_skeleton
%% Startup
% addpath(genpath('deps'))
% Ask for path to data file
if nargin < 1 || isempty(boxPath); boxPath = uibrowse('*.h5',[],'Select box HDF5 file'); end
% Params
if nargin < 2; skeletonPath = []; end
recreate_labels = nargin > 1; % force recreate labels file
% Settings (saved in *.labels.mat file)
% NOTE(review): declared global, so these settings are visible outside this
% GUI (and shared between simultaneously open instances) - confirm intended.
global config;
config = struct();
config.dsetName = '/box';
config.nodeSize = 10; % size of draggable markers
config.defaultNodeColor = [1 0 0]; % default color of movable nodes
config.initializedNodeColor = [1 1 0]; % color of initialized nodes
config.labeledNodeColor = [0 1 0]; % color of movable nodes with user input
config.initialFrame = 1; % first frame displayed
config.shuffleFrames = false; % shuffle frame order
config.autoSave = true; % save before going to a new frame
config.clickNearest = false; % true = click moves nearest node; false = selected node
config.draggable = true; % false = cannot drag joint markers
config.altArrowsToMoveNodes = true; % false = arrow keys move nodes, alt+arrows changes frames
config.zoomBoxFrames = [-250, 250]; % number of frames in the status zoomed in box (pre, post)
config.imgFigPos = [835 341 709 709]; % main labeling figure window
config.ctrlFigPos = [1545 342 374 708]; % control/reference window
config.statusFigPos = [836 33 1081 277]; % status bars and settings window
%%
% Initialize labeling session
% These variables are declared here so that the nested functions below share
% them with this workspace (initializeLabels fills them in).
box = [];
numNodes = [];
numFrames = [];
numLabeled = [];
global labels;
% Loads or creates *.labels.mat and populates config
initializeLabels();
% Pre-shuffle frames for shuffle mode
% (a fixed permutation for the whole session, used by the arrow-key navigation)
shuffleIdx = randperm(numFrames);
% Set status colormap colors
% Row k+1 is the color for status value k: 0 = default, 1 = initialized,
% 2 = labeled (see getStatus); char color codes are converted to RGB.
statusCmap = {
config.defaultNodeColor
config.initializedNodeColor
config.labeledNodeColor
};
for k = 1:numel(statusCmap)
if ischar(statusCmap{k}); statusCmap{k} = colorCode2rgb(statusCmap{k}); end
end
statusCmap = cellcat(statusCmap,1);
% Zoom box convenience (compute window)
% A scalar zoomBoxFrames is turned into a window centered on the current frame.
if isscalar(config.zoomBoxFrames); config.zoomBoxFrames = round([-0.5 0.5] .* config.zoomBoxFrames); end
zoomBoxWindow = config.zoomBoxFrames(1):config.zoomBoxFrames(2);
function initializeLabels()
    % Loads the *.labels.mat file that belongs to boxPath, or creates it.
    % Fills in the shared variables labels, box, numFrames and numNodes.
    % When loading, the config stored in the labels file is merged into
    % config.
    %
    % A skeleton file is only needed when a new labels file is created.
    % The user is prompted for it in two cases:
    %   - skeletonPath is true, or
    %   - a skeleton is needed and none was provided.
    % Previously the prompt also appeared, ignoring the caller's path,
    % whenever no labels file existed.
    labels = struct();
    % Metadata
    labels.boxPath = boxPath;
    labels.savePath = repext(boxPath, '.labels.mat');
    % A new labels file is created if none exists or recreation was requested
    createNew = ~exists(labels.savePath) || recreate_labels;
    % Ask for path to skeleton file only when needed and not provided
    if isequal(skeletonPath,true) || (createNew && isempty(skeletonPath))
        skeletonPath = uibrowse('*.mat',[],'Select skeleton MAT file');
    end
    % Open box file
    box = h5file(boxPath, config.dsetName);
    numFrames = size(box,4);
    stic;
    if createNew
        % Load template skeleton
        labels.skeletonPath = skeletonPath;
        labels.skeleton = load(skeletonPath);
        % Initialize custom defaults container
        labels.initialization = NaN(numel(labels.skeleton.nodes), 2, numFrames, 'single');
        % Try using initialization built into the HDF5 file
        try
            labels.initialization = h5read(boxPath, '/initialization');
            labels.initialization_metadata = h5att2struct(boxPath, '/initialization');
            printf('Using pre-initialized joint predictions.')
        catch
            % Best effort: keep the NaN initialization when the box file
            % has no readable '/initialization' dataset.
        end
        % Initialize user labels
        labels.positions = NaN(numel(labels.skeleton.nodes), 2, numFrames, 'single');
        % Settings
        labels.config = config;
        % Timestamps
        labels.createdOn = datestr(now);
        labels.lastModified = datestr(now);
        % Initialize history
        labels.session = 1;
        addToHistory("Created labels file.");
        % Create labels file
        save(labels.savePath, '-struct', 'labels', '-v7.3')
        stocf('Created labels file: %s', labels.savePath)
    else
        % Load
        labels = load(labels.savePath);
        % Update paths (the box file may have been moved)
        labels.boxPath = boxPath;
        labels.savePath = repext(boxPath, '.labels.mat');
        % Update config
        if isfield(labels,'config')
            config = parse_params(labels.config,config);
        else
            labels.config = config;
        end
        % Each run of the GUI on an existing file is a new session
        if ~isfield(labels,'session')
            labels.session = 1;
        else
            labels.session = labels.session + 1;
        end
        stocf('Loaded existing labels file: %s', labels.savePath)
    end
    addToHistory('Started session.')
    % Convenience
    numNodes = numel(labels.skeleton.nodes);
end
function addToHistory(message)
    % Appends one timestamped row (session, timestamp, message) to the
    % labels.history table and prints that row to the command window.
    % The first call creates the table.
    historyItem = table(labels.session, datetime(), string(message), ...
        'VariableNames', {'session', 'timestamp', 'message'});
    disp(historyItem)
    hasHistory = isfield(labels,'history') && ~isempty(labels.history);
    if hasHistory
        labels.history = [labels.history; historyItem];
    else
        labels.history = historyItem;
    end
end
%% GUI
% Build GUI
% ui holds every graphics handle plus the per-session UI state (selected node,
% moved nodes, current frame, unsaved frames); initializeGUI fills it in.
global ui;
initializeGUI();
function initializeGUI()
% Builds the three GUI windows and stores every handle in the shared ui struct:
%   ui.ctrl   - control figure: joint list (selecting an entry selects that
%               node) and the reference image with the template skeleton
%   ui.img    - labeling figure showing the current frame; clicks on the
%               image are handled by clickImage
%   ui.status - status figure: initialized/labeled counters, full and zoomed
%               per-node status bars (click to seek), and settings buttons
% It also initializes the per-session state kept in ui.status. All windows
% use keyPress as WindowKeyPressFcn and quit as DeleteFcn.
% NOTE: the uix boxes lay out their children in creation order, and the
% Heights/Widths vectors below assume that order, so do not reorder.
ui = struct();
% %%%% Controls figure %%%%
ui.ctrl = struct();
ui.ctrl.fig = figure('NumberTitle','off','MenuBar','none', ...
'Name','LEAP Label GUI', 'WindowKeyPressFcn', @keyPress, 'DeleteFcn', @quit, ...
'Position', config.ctrlFigPos);
ui.ctrl.hbox = uix.HBox('Parent', ui.ctrl.fig);
% Joints panel
ui.ctrl.jointsPanel = uix.Panel('Parent',ui.ctrl.hbox, 'Title', 'Joints', 'Padding',5);
ui.ctrl.jointsList = uicontrol(ui.ctrl.jointsPanel, 'Style', 'listbox', 'String', labels.skeleton.joints.name, ...
'Callback',@(h,~,~)selectNode(h.Value));
% Reference image
ui.ctrl.refPanel = uix.Panel('Parent',ui.ctrl.hbox, 'Title', 'Reference', 'Padding',5);
ui.ctrl.refAx = axes(uicontainer('Parent',ui.ctrl.refPanel));
ui.ctrl.refImg = imagesc(labels.skeleton.refI);
% Style
ui.ctrl.refAx.Units = 'normalized';
ui.ctrl.refAx.Position = [0 0 1 1];
axis(ui.ctrl.refAx,'equal','tight','ij')
colormap(ui.ctrl.refAx,'gray')
noticks(ui.ctrl.refAx)
hold(ui.ctrl.refAx,'on')
% Plot reference skeleton
for i = 1:size(labels.skeleton.segments,1)
% Find default position of each nodes in the segment
pos = labels.skeleton.pos(labels.skeleton.segments.joints_idx{i},:);
% Plot
plot(ui.ctrl.refAx, pos(:,1), pos(:,2), '.-', 'Color',labels.skeleton.segments.color{i}, 'LineWidth', 1);
end
% Draw each joint node
% (these markers are enlarged by updateSkeleton when their node is selected)
ui.ctrl.refNodes = gobjects(height(labels.skeleton.joints),1);
for i = 1:numel(ui.ctrl.refNodes)
pos = labels.skeleton.joints.pos(i,:);
ui.ctrl.refNodes(i) = plot(ui.ctrl.refAx, pos(1),pos(2),'o', 'Color','r');
end
% Set box widths
% (negative values are relative weights: joint list 1 : reference image 3)
ui.ctrl.hbox.Widths = [-1 -3];
%%%%
% %%%% Image figure %%%%
ui.img = struct();
ui.img.fig = figure('NumberTitle','off','MenuBar','none','ToolBar','none', ...
'Name',sprintf('Frame %d/%d', config.initialFrame, numFrames), 'WindowKeyPressFcn', @keyPress, 'DeleteFcn', @quit, ...
'Position', config.imgFigPos);
ui.img.ax = axes(ui.img.fig);
% Shows the first frame; the displayed frame is replaced via CData (see goToFrame)
ui.img.img = imagesc(ui.img.ax, box(:,:,:,1));
ui.img.img.ButtonDownFcn = @(~,~) clickImage();
% Full figure image axes
ui.img.ax.Units = 'normalized';
ui.img.ax.Position = [0 0 1 1];
% Style
axis(ui.img.ax,'equal','tight','ij')
colormap(ui.img.ax,'gray')
noticks(ui.img.ax)
hold(ui.img.ax,'on')
% Initialize skeleton drawing container
% (populated by initializeSkeleton)
ui.skel = struct();
ui.skel.segs = [];
ui.skel.nodes = [];
%%%%
% %%%% Status figure %%%%
% Initialize status container
ui.status = struct();
ui.status.selectedNode = [];
ui.status.movedNodes = false(numNodes,1);
ui.status.currentFrame = config.initialFrame;
ui.status.unsavedChanges = false(numFrames,1);
ui.status.initialPos = [];
% Get full status indicators for all frames
% (frames count as initialized only when every node is initialized and
% none is labeled, since labeled nodes have status 2)
status = getStatus();
numInitialized = sum(all(status == 1,1));
numLabeled = sum(all(status == 2,1));
% Create figure window
ui.status.fig = figure('NumberTitle','off','MenuBar','none','ToolBar','none', ...
'Name',sprintf('Status: %d/%d (%.2f%%) labeled', numLabeled, numFrames, numLabeled/numFrames*100), ...
'WindowKeyPressFcn', @keyPress, 'DeleteFcn', @quit, ...
'Position', config.statusFigPos);
ui.status.hbox = uix.HBox('Parent', ui.status.fig, 'Padding',3);
% Status panel (left)
ui.status.statusPanel = uix.Panel('Parent',ui.status.hbox, 'Title','Status', 'Padding',5);
ui.status.statusBoxes = uix.VBox('Parent', ui.status.statusPanel);
% Status text
% NOTE(review): these initial strings use %.3f while later updates
% (predictInitializations, commitChanges) use %.2f.
ui.status.stats = uix.VBox('Parent',ui.status.statusBoxes);
ui.status.framesInitialized = uicontrol(ui.status.stats,'Style','text','HorizontalAlignment','left',...
'String',sprintf('Initialized: %d/%d (%.3f%%)', numInitialized, numFrames, numInitialized/numFrames*100));
ui.status.framesLabeled = uicontrol(ui.status.stats,'Style','text','HorizontalAlignment','left',...
'String',sprintf('Labeled: %d/%d (%.3f%%)', numLabeled, numFrames, numLabeled/numFrames*100));
ui.status.stats.Heights = ones(1, numel(ui.status.stats.Children)) * 15;
% Status bars
% Full status bar: one column per frame, one row per node, colored by status
ui.status.fullAx = axes(uicontainer('Parent',ui.status.statusBoxes));
ui.status.fullImg = imagesc(ui.status.fullAx, 1:numFrames, 1:numNodes, status, 'ButtonDownFcn', @clickStatusbar);
axis(ui.status.fullAx,'tight','ij')
hold(ui.status.fullAx,'on');
% Translucent box marking the frames shown in the zoomed status bar
zoomBoxIdx = zoomBoxWindow + ui.status.currentFrame;
zoomBoxPts = [
zoomBoxIdx(1) 0
zoomBoxIdx(end) 0
zoomBoxIdx(end) numNodes
zoomBoxIdx(1) numNodes
zoomBoxIdx(1) 0
];
ui.status.fullZoomBox = patch(ui.status.fullAx, zoomBoxPts(:,1),zoomBoxPts(:,2),'w','PickableParts','none');
ui.status.fullZoomBox.FaceAlpha = 0.25;
ui.status.fullZoomBox.EdgeColor = 'w';
% Fixed color limits so status values 0/1/2 map to the statusCmap rows
colormap(ui.status.fullAx, statusCmap)
caxis(ui.status.fullAx,[0 2])
ui.status.fullAx.XLim = [-0.5 0.5] + [1 numFrames];
ui.status.fullAx.YLim = [-0.5 0.5] + [1 numNodes];
% ui.status.fullAx.YTick = 1:numNodes;
% ui.status.fullAx.YTickLabel = labels.skeleton.nodes;
% ui.status.fullAx.YAxis.TickLabelInterpreter = 'none';
% Status bars (zoomed)
% Starts as all-zero data; its CData is filled in elsewhere
ui.status.zoomAx = axes(uicontainer('Parent',ui.status.statusBoxes));
ui.status.zoomImg = imagesc(ui.status.zoomAx, zoomBoxIdx, 1:numNodes, zeros(numNodes,numel(zoomBoxIdx)), 'ButtonDownFcn', @clickStatusbar);
axis(ui.status.zoomAx,'tight','ij')
colormap(ui.status.zoomAx, statusCmap)
caxis(ui.status.zoomAx,[0 2])
ui.status.zoomAx.YLim = [-0.5 0.5] + [1 numNodes];
% ui.status.zoomAx.YTick = 1:numNodes;
% ui.status.zoomAx.YTickLabel = labels.skeleton.nodes;
% ui.status.zoomAx.YAxis.TickLabelInterpreter = 'none';
% Set UI heights
% (fixed height for the text rows, remaining space split between the two bars)
ui.status.statusBoxes.Heights = [sum(ui.status.stats.Heights)+5 -1 -1];
% Settings panel (right)
ui.status.configPanel = uix.Panel('Parent',ui.status.hbox, 'Title','Settings','Padding',5);
ui.status.configButtons = uix.VBox('Parent',ui.status.configPanel);
% Auto-save
uicontrol(ui.status.configButtons,'Style','checkbox','Value',config.autoSave, ...
'Callback',@(h,~)setConfig('autoSave',h.Value), ...
'String','Autosave labels','TooltipString','Automatically saves changes to disk when changing frames or exiting.');
% Shuffle frame order
uicontrol(ui.status.configButtons,'Style','checkbox','Value',config.shuffleFrames, ...
'Callback',@(h,~)setConfig('shuffleFrames',h.Value), ...
'String','Shuffle frame order','TooltipString','Shuffled order is fixed within this session. Uncheck to use file ordering.');
% Click nearest
uicontrol(ui.status.configButtons,'Style','checkbox','Value',config.clickNearest, ...
'Callback',@(h,~)setConfig('clickNearest',h.Value), ...
'String','Click to move nearest joint','TooltipString','If unchecked, clicking on the image moves the currently selected joint.');
% Draggable markers
uicontrol(ui.status.configButtons,'Style','checkbox','Value',config.draggable, ...
'Callback', @(h,~)toggleDraggableMarkers(h.Value), ...
'String','Draggable markers','TooltipString','If unchecked, joint markers can only be moved by clicking or keyboard.');
% Alt + arrows to move nodes
uicontrol(ui.status.configButtons,'Style','checkbox','Value',config.altArrowsToMoveNodes, ...
'Callback', @(h,~)setConfig('altArrowsToMoveNodes',h.Value), ...
'String','Alt + arrow keys move markers','TooltipString','If unchecked, move markers with Alt + arrow keys, change frames with arrow keys.');
% Export confidence maps
uicontrol(ui.status.configButtons,'Style','pushbutton', ...
'Callback', @(h,~)generateTrainingSet(), ...
'String','Generate training set','TooltipString','Creates a test set with confidence maps for training a network.');
% Fast training
uicontrol(ui.status.configButtons,'Style','pushbutton', ...
'Callback', @(h,~)fastTrain(), ...
'String','Fast train network','TooltipString','Trains a network for initialization using fast presets.');
% Initialization from predictions
uicontrol(ui.status.configButtons,'Style','pushbutton', ...
'Callback', @(h,~)predictInitializations(), ...
'String','Initialize with trained model','TooltipString','Generates predictions for all frames and uses it as initialization.');
% Set UI sizes
ui.status.configButtons.Heights = ones(1, numel(ui.status.configButtons.Children)) * 25;
ui.status.hbox.Widths = [-1 175];
% Give focus back to main image window
figure(ui.img.fig);
end
function toggleDraggableMarkers(TF)
    % Turns mouse dragging of the skeleton node markers on (TF true) or off,
    % and records the choice in config.draggable.
    if TF
        % Markers must be pickable to receive the drag callbacks
        set(ui.skel.nodes, 'PickableParts', 'visible');
        draggable(ui.skel.nodes, @nodesMoved, 'endFcn', @nodesMoveEnd);
    else
        % Detach drag handling, then let clicks pass through to the image
        draggable(ui.skel.nodes, 'off');
        set(ui.skel.nodes, 'PickableParts', 'none');
    end
    config.draggable = TF;
end
function setConfig(configField, value)
    % Stores value in the shared config struct under the field named by
    % configField (used by the settings checkboxes).
    config = setfield(config, configField, value);
end
function quit(~,~)
    % DeleteFcn shared by all three GUI figures. Closing any one window
    % logs the end of the session, saves (if autosave is on) and closes the
    % remaining windows.
    %
    % Deleting the other figures re-triggers this callback. The
    % ui.isQuitting flag makes the logging and saving happen exactly once.
    % Previously, closing e.g. the status window logged "Finished session."
    % twice, and the saving depended on which window triggered the callback.
    if isfield(ui,'isQuitting') && ui.isQuitting; return; end
    ui.isQuitting = true;
    % Log to history
    addToHistory("Finished session.")
    % Save
    if config.autoSave
        saveLabels();
    end
    % Delete whichever figures still exist
    figs = [ui.img.fig, ui.ctrl.fig, ui.status.fig];
    delete(figs(isvalid(figs)))
end
function keyPress(~,evt)
    % WindowKeyPressFcn shared by all GUI figures. Hotkeys:
    %   q                close the GUI
    %   s                save labels
    %   r / shift+r      reset the selected / all nodes to the frame's initial positions
    %   d / shift+d      clear the labels of the selected / all nodes
    %   tab / shift+tab  select the next / previous node
    %   arrow keys       move the selected node; with ctrl, also every node
    %                    sharing a segment with it; with shift, 5 px steps.
    %                    If config.altArrowsToMoveNodes is set, nodes move
    %                    only while alt is held, and left/right without alt
    %                    change frames (5 frames with shift). Otherwise the
    %                    role of alt is swapped.
    %   space            go to the first unlabeled frame (shift: unlabeled for
    %                    the selected node only; ctrl: a random unlabeled frame)
    %   g                go-to-frame dialog
    %   f                mark all nodes of the current frame as labeled
    % Exclusive modifier flags
    noModifier = isempty(evt.Modifier);
    shiftOnly = isequal(evt.Modifier, {'shift'});
    % Non-exclusive modifier flags
    altPressed = ismember({'alt'}, evt.Modifier);
    ctrlPressed = ismember({'control'}, evt.Modifier);
    shiftPressed = ismember({'shift'}, evt.Modifier);
    % Arrow keys step by 1 (pixel or frame), or by 5 with shift held
    stepSize = 1 + 4 * shiftPressed;
    % Whether arrow keys move nodes (true) or change frames (false)
    arrowsMoveNodes = (config.altArrowsToMoveNodes && altPressed) || ...
        (~config.altArrowsToMoveNodes && ~altPressed);
    switch evt.Key
        case 'q'
            delete(ui.img.fig)
        case 's'
            saveLabels()
        case 'r'
            if noModifier % current node
                resetNodes(ui.status.selectedNode);
            elseif shiftOnly % all nodes
                resetNodes();
            end
        case 'd'
            if noModifier % current node
                setNodesToDefault(ui.status.selectedNode);
            elseif shiftOnly % all nodes
                setNodesToDefault();
            end
        case 'tab'
            if noModifier
                selectNode(mod(ui.status.selectedNode-1+1, numNodes) + 1);
            elseif shiftOnly
                selectNode(mod(ui.status.selectedNode-1-1, numNodes) + 1);
            end
        case {'leftarrow', 'rightarrow', 'uparrow', 'downarrow'}
            switch evt.Key
                case 'leftarrow'
                    dXY = [-1 0];
                case 'rightarrow'
                    dXY = [1 0];
                case 'uparrow'
                    dXY = [0 -1];
                otherwise
                    dXY = [0 1];
            end
            if arrowsMoveNodes
                % Scale the unit direction. Previously shift added 4 to
                % both coordinates, and up/down ignored alt-held presses.
                if ctrlPressed
                    nudgeSegment(dXY * stepSize);
                else
                    nudgeNode(dXY * stepSize);
                end
            elseif dXY(1) ~= 0
                % Left/right change frames; up/down do nothing in this mode
                dt = dXY(1) * stepSize;
                if config.shuffleFrames
                    idx = find(shuffleIdx == ui.status.currentFrame);
                    goToFrame(shuffleIdx(mod(idx-1+dt, numFrames) + 1))
                else
                    goToFrame(mod(ui.status.currentFrame-1+dt, numFrames) + 1)
                end
            end
        case 'space'
            % Get labeling status for all frames
            labeled = getStatus() == 2;
            % Consider current joint only if shift is pressed
            if shiftPressed
                labeled = labeled(ui.status.selectedNode,:);
            else
                labeled = all(labeled,1);
            end
            % Find unlabeled frames excluding current frame
            % (this previously searched the labeled frames by mistake)
            unlabeledIdxs = setdiff(find(~labeled), ui.status.currentFrame);
            if ~isempty(unlabeledIdxs)
                if ctrlPressed
                    % Go to random unlabeled frame
                    goToFrame(datasample(unlabeledIdxs,1));
                else
                    % Go to first unlabeled frame
                    goToFrame(unlabeledIdxs(1))
                end
            end
        case 'g'
            % Go to frame dialog
            % TODO: custom dialog box that starts focused on the textbox
            % and returns after pressing Enter/Esc
            answer = inputdlg('Skip to frame index:','Skip to frame',1,{num2str(ui.status.currentFrame)});
            % answer is empty when the dialog was cancelled
            if ~isempty(answer)
                idx = round(str2double(answer{1}));
                if isfinite(idx) && idx >= 1 && idx <= numFrames
                    goToFrame(idx);
                end
            end
        case 'f'
            markAllCorrect();
        otherwise
            % Unbound key: ignore
    end
end
function clickImage()
    % ButtonDownFcn of the frame image (clicks on node markers do not get
    % here). Moves a node to the clicked point: the node nearest the click
    % when config.clickNearest is set, otherwise the selected node.
    pos = getNodePositions();
    clicked = ui.img.ax.CurrentPoint(1,1:2);
    if ~config.clickNearest
        nodeIdx = ui.status.selectedNode;
    else
        nodeIdx = argmin(rownorm(pos - clicked));
    end
    pos(nodeIdx,:) = clicked;
    updateSkeleton(pos);
end
function clickStatusbar(~,evt)
    % ButtonDownFcn of both status bar images. A left click (Button 1) seeks
    % to the frame under the cursor, clipped to the range [1, numFrames].
    if evt.Button ~= 1
        return
    end
    frameIdx = clip(round(evt.IntersectionPoint(1)), [1 numFrames]);
    goToFrame(frameIdx);
end
function status = getStatus(idx)
    % Labeling status of every node in the frames idx (default: all frames).
    % Returns a [numNodes x numel(idx)] matrix with values:
    %   0: default (no initialization or label)
    %   1: initialized (labels.initialization has both coordinates)
    %   2: labeled (labels.positions has both coordinates; wins over 1)
    if nargin < 1
        idx = 1:numFrames;
    end
    hasInit = squeeze(all(~isnan(labels.initialization(:,:,idx)),2));
    hasLabel = squeeze(all(~isnan(labels.positions(:,:,idx)),2));
    status = zeros(numNodes, numel(idx));
    status(hasInit) = 1;
    status(hasLabel) = 2;
end
%% Training and dataset generation
function predictInitializations(modelPath)
    % Runs a trained model on every frame of the box and stores the
    % predictions in labels.initialization, which supplies starting
    % positions for frames without user labels. It then:
    %   - saves the labels and refreshes the status displays;
    %   - prints the model's error on the frames that are fully labeled.
    %
    % modelPath: model run folder (containing best_model.h5 and/or
    %            final_model.h5, and training_info.mat). The user is
    %            prompted for the folder when it is omitted.
    if nargin < 1 || isempty(modelPath)
        modelPath = uibrowse([],[],'Select model folder...', 'dir');
        if isempty(modelPath) || ~exists(modelPath); return; end
    end
    % Use the best-validation model only when the validation set was large
    % enough for "best" to be meaningful; otherwise use the final model.
    % A missing training_info.mat counts as having no validation samples.
    % TODO: better system for choosing final vs best validation model
    infoPath = ff(modelPath,'training_info.mat');
    numValidationSamples = 0;
    if exists(infoPath)
        numValidationSamples = numel(loadvar(infoPath,'val_idx'));
    end
    if exists(ff(modelPath, 'best_model.h5')) && numValidationSamples > 500
        modelPath = ff(modelPath, 'best_model.h5');
    else
        modelPath = ff(modelPath, 'final_model.h5');
    end
    % Predict
    preds = predict_box(boxPath, modelPath, false);
    % Save
    labels.initialization = preds.positions_pred;
    saveLabels();
    % Update status
    isInitialized = squeeze(all(~isnan(labels.initialization),2));
    numInitialized = sum(all(isInitialized,1));
    ui.status.framesInitialized.String = sprintf('Initialized: %d/%d (%.2f%%)', numInitialized, numFrames, numInitialized/numFrames*100);
    % Update status bars
    status = getStatus();
    ui.status.fullImg.CData = status;
    zoom_idx = ui.status.zoomImg.XData > 0 & ui.status.zoomImg.XData <= size(status,2);
    ui.status.zoomImg.CData(:,zoom_idx) = status(:,ui.status.zoomImg.XData(zoom_idx));
    % Log event
    addToHistory(['Initialized with model: ' modelPath])
    % Calculate error rate on fully labeled frames (status is unchanged since
    % it was computed above, so no need to query it again)
    labeled = all(status == 2,1);
    if any(labeled)
        pos_gt = labels.positions(:,:,labeled);
        pos_pred = labels.initialization(:,:,labeled);
        pred_metrics = compute_errors(pos_pred,pos_gt);
        % Display errors
        printf('Error: mean = %.2f, s.d. = %.2f', mean(pred_metrics.euclidean(:)), std(pred_metrics.euclidean(:)))
        prcs = [50 75 90];
        prc_errs = prctile(pred_metrics.euclidean(:), prcs);
        for i = 1:numel(prcs)
            printf(' %d%% = %.3f', prcs(i), prc_errs(i))
        end
    else
        % Avoid reporting NaN statistics over an empty set
        printf('No fully labeled frames to compute errors on.')
    end
    % Replot
    goToFrame(ui.status.currentFrame);
end
function generateTrainingSet()
% Asks for export parameters in a dialog and writes a training set from the
% current labels by calling generate_training_set(boxPath, params).
% The default save path is <box folder>/training/<box file name>.h5,
% passed through get_new_filename (presumably to avoid overwriting an
% existing file - confirm). If the chosen path already exists, the user
% must confirm the overwrite. The export is logged to the session history.
% Default save path
defaultSavePath = ff(fileparts(boxPath), 'training', [get_filename(boxPath,true) '.h5']);
defaultSavePath = get_new_filename(defaultSavePath,true);
% Create dialog with parameters
% (each {'Label';'fieldName'}, default pair becomes a field of params)
[params, buttonPressed] = settingsdlg(...
'WindowWidth', 400,...
'title','Generate a training set', ...
'Description','Export a dataset for training based on the current labels.',...
'separator','General options',...
{'Save path';'savePath'}, defaultSavePath, ...
{'Scale - for resizing images';'scale'}, 1, ...
{'Sigma - kernel size for confidence maps';'sigma'}, 5, ...
{'Test set fraction - held out frames';'testFraction'},0.1,...
{'Shuffle - randomize saved dataset order';'postShuffle'}, true, ...
{'Compress - reduce file size, but slower to load';'compress'}, true, ...
'separator','Data mirroring',...
{'Mirror images - augment by flipping along the body axis';'mirroring'}, [true, false], ...
{'Animal orientation';'animalOrientation'}, {'left/right','top/bottom'} ...
);
% Cancel if OK was not pressed (cancel or window closed)
if ~strcmpi(buttonPressed,'ok'); return; end
% Convert listbox input to boolean for orientation
params.horizontalOrientation = strcmpi(params.animalOrientation,'left/right');
% Check for existing save path
if exists(params.savePath)
answer = questdlg('Save path already exists, overwrite existing file?', 'Overwrite file', 'Overwrite', 'Cancel', 'Overwrite');
if ~strcmpi(answer, 'Overwrite'); return; end
end
% Run!
generate_training_set(boxPath,params);
% Log action
addToHistory('Generated training set.')
end
function fastTrain()
    % Trains a model on the current labels with quick presets. If the run
    % produces final_model.h5, that model initializes all frames (see
    % predictInitializations).
    %
    % Steps:
    %   1. Ask for dataset, model and training parameters.
    %   2. Write a temporary training set with no test split.
    %   3. Run training.py (located via funpath) through the system shell,
    %      writing to <modelsFolder>/<runName>.
    %   4. Delete the temporary dataset.
    %   5. Initialize from the new model, if one was produced.
    % Build default output path (run name: timestamp + number of labeled frames)
    runName = sprintf('%s-n=%d', datestr(now,'yymmdd_HHMMSS'), numLabeled);
    defaultModelsFolder = ff(fileparts(boxPath), 'models');
    % Create dialog with parameters
    [params, buttonPressed] = settingsdlg(...
        'WindowWidth', 500,...
        'title','Fast training', ...
        'Description','Quickly train a model using current labels and predict on remaining frames as initialization.',...
        'separator','Dataset',...
        {'Scale - for resizing images';'scale'}, 1, ...
        {'Sigma - kernel size for confidence maps';'sigma'}, 5, ...
        'separator','Data mirroring',...
        {'Mirror images - augment by flipping along the body axis';'mirroring'}, [true, false], ...
        {'Animal orientation';'animalOrientation'}, {'left/right','top/bottom'}, ...
        'separator','Model',...
        {'Network architecture';'netName'},{'leap_cnn','hourglass','stacked_hourglass'},...
        {'Filters - base number of filters for model';'filters'},32,...
        {'Upsampling layers - use bilinear upsampling instead of transposed conv';'upsamplingLayers'},true,...
        'separator','Training',...
        {'Model path - folder to save run data to';'modelsFolder'},defaultModelsFolder,...
        {'Rotate angle - augment data via random rotations';'rotateAngle'},5,...
        {'Validation set fraction - frames used for validation';'valSize'},0.1,...
        {'Epochs - number of rounds of training';'epochs'},15,...
        {'Batch size - number of samples per batch';'batchSize'},50,...
        {'Batches per epoch - number of batches of samples per round';'batchesPerEpoch'},50,...
        {'Validation batches per epoch - number of batches to use for validation';'valBatchesPerEpoch'},10,...
        {'Save every epoch - save weights from every epoch instead of just best+final';'saveEveryEpoch'},false,...
        'separator','Training (advanced)',...
        {'Reduce LR factor - drop learning rate when loss plateaus';'reduceLRFactor'},0.1,...
        {'Reduce LR patience - wait after loss plateaus before reducing LR';'reduceLRPatience'},2,...
        {'Reduce LR cooldown - wait after reducing LR before detecting plateau';'reduceLRCooldown'},0,...
        {'Reduce LR min delta - minimum change in loss to not plateau';'reduceLRMinDelta'},1e-5,...
        {'Reduce LR min LR - minimum LR to not drop below';'reduceLRMinLR'},1e-10,...
        {'AMSGrad - optimizer variant for more emphasis on rare data';'amsgrad'},true ...
        );
    % Cancel if OK was not pressed (cancel or window closed)
    if ~strcmpi(buttonPressed,'ok'); return; end
    % Convert listbox input to boolean for orientation
    params.horizontalOrientation = strcmpi(params.animalOrientation,'left/right');
    % Generate temporary training set file
    dataPath = [tempname '.h5'];
    dataPath = generate_training_set(boxPath,'savePath',dataPath,...
        'scale',params.scale,...
        'mirroring',params.mirroring,...
        'horizontalOrientation',params.horizontalOrientation,...
        'sigma',params.sigma, ...
        'normalizeConfmaps',true,...
        'postShuffle',true, ...
        'testFraction',0);
    % Log action
    addToHistory(sprintf('Fast training (n = %d)', numLabeled))
    % Create CLI command for training
    % (--rotate-angle uses %g: %d prints non-integer values in exponent form)
    basePath = fileparts(funpath(true));
    cmd = {
        'python'
        ['"' ff(basePath, 'training.py') '"']
        ['"' dataPath '"']
        ['--base-output-path="' params.modelsFolder '"']
        ['--run-name="' runName '"']
        ['--net-name="' params.netName '"']
        sprintf('--filters=%d',params.filters)
        sprintf('--rotate-angle=%g', params.rotateAngle)
        sprintf('--val-size=%.5f', params.valSize)
        sprintf('--epochs=%d', params.epochs)
        sprintf('--batch-size=%d', params.batchSize)
        sprintf('--batches-per-epoch=%d', params.batchesPerEpoch)
        sprintf('--val-batches-per-epoch=%d', params.valBatchesPerEpoch)
        sprintf('--reduce-lr-factor=%.10f', params.reduceLRFactor)
        sprintf('--reduce-lr-patience=%d', params.reduceLRPatience)
        sprintf('--reduce-lr-cooldown=%d', params.reduceLRCooldown)
        sprintf('--reduce-lr-min-delta=%.10f', params.reduceLRMinDelta)
        sprintf('--reduce-lr-min-lr=%.10f', params.reduceLRMinLR)
        };
    if params.upsamplingLayers; cmd{end+1} = '--upsampling-layers'; end
    if params.saveEveryEpoch; cmd{end+1} = '--save-every-epoch'; end
    if params.amsgrad; cmd{end+1} = '--amsgrad'; end
    cmd = strjoin(cmd);
    disp(cmd)
    % Train! The temporary dataset is removed whether or not training succeeds.
    try
        exit_code = system(cmd);
    catch ME
        delete(dataPath)
        rethrow(ME)
    end
    delete(dataPath)
    % Report a failed training run instead of silently ignoring it
    if exit_code ~= 0
        warning('leap:label_joints:trainingFailed', ...
            'Training command exited with status %d.', exit_code);
    end
    % TODO: parse this out from python output?
    modelPath = ff(params.modelsFolder, runName);
    % Run trained model on data to initialize labels
    if exists(ff(modelPath, 'final_model.h5'))
        predictInitializations(modelPath)
    end
end
%% Plotting
% Draw the interactive skeleton on the frame image now that the GUI exists
initializeSkeleton();
function initializeSkeleton()
    % (Re)creates the interactive skeleton graphics on the frame image, at the
    % skeleton template positions:
    %   - one non-pickable line per segment, so clicks reach the image;
    %   - one circle marker per joint.
    % Any previous skeleton graphics are deleted first. Markers become
    % draggable when config.draggable is set.
    segments = labels.skeleton.segments;
    % Segment lines
    if ~isempty(ui.skel.segs)
        delete(ui.skel.segs);
    end
    numSegs = size(segments,1);
    ui.skel.segs = gobjects(numSegs,1);
    for s = 1:numSegs
        jointIdx = segments.joints_idx{s};
        segPos = labels.skeleton.pos(jointIdx,:);
        seg = plot(ui.img.ax, segPos(:,1), segPos(:,2), '.-', ...
            'Color', segments.color{s}, 'PickableParts', 'none');
        % Metadata used by updateSkeleton to redraw the line
        seg.UserData.seg_idx = s;
        seg.UserData.seg_joints_idx = jointIdx;
        ui.skel.segs(s) = seg;
    end
    % Joint markers
    if ~isempty(ui.skel.nodes)
        delete(ui.skel.nodes);
    end
    jointPos = labels.skeleton.joints.pos;
    ui.skel.nodes = gobjects(height(labels.skeleton.joints),1);
    for n = 1:numel(ui.skel.nodes)
        marker = plot(ui.img.ax, jointPos(n,1), jointPos(n,2), 'o', ...
            'Color', 'w', 'LineWidth', 1, 'PickableParts', 'none');
        % Node index, read back by the drag callbacks
        marker.UserData.node_idx = n;
        ui.skel.nodes(n) = marker;
    end
    % Make movable and add callbacks
    if config.draggable
        set(ui.skel.nodes, 'PickableParts', 'visible');
        draggable(ui.skel.nodes, @nodesMoved, 'endFcn', @nodesMoveEnd);
    end
end
function pos = getNodePositions()
    % Reads the current [x y] position of each skeleton node marker.
    % Returns a [numel(ui.skel.nodes) x 2] matrix.
    markers = ui.skel.nodes;
    numMarkers = numel(markers);
    pos = NaN(numMarkers,2);
    for n = 1:numMarkers
        pos(n,:) = [markers(n).XData, markers(n).YData];
    end
end
function updateSkeleton(pos)
    % Redraws the pre-initialized skeleton graphics.
    %   updateSkeleton()     uses the current marker positions
    %   updateSkeleton(pos)  first moves the markers to pos ([numNodes x 2])
    % Nodes whose position differs from ui.status.initialPos are flagged as
    % moved, are drawn as squares, and mark the current frame as having
    % unsaved changes. Marker colors follow getStatus through statusCmap.
    % The selected node gets marker size 9 in both the image and the
    % reference axes.
    if nargin >= 1
        % Push new positions into the markers that actually changed
        for n = 1:size(pos,1)
            marker = ui.skel.nodes(n);
            if ~isequal(pos(n,:), [marker.XData marker.YData])
                marker.XData = pos(n,1);
                marker.YData = pos(n,2);
            end
        end
    else
        pos = getNodePositions();
    end
    % Flag nodes that are away from where the frame started
    movedNow = false(numNodes,1);
    for n = 1:numNodes
        movedNow(n) = ~isequal(pos(n,:), ui.status.initialPos(n,:));
    end
    ui.status.movedNodes(movedNow) = true;
    if any(movedNow)
        ui.status.unsavedChanges(ui.status.currentFrame) = true;
    end
    % Reset marker appearance (unmoved, unselected)
    set(ui.skel.nodes, 'Marker', 'o', 'MarkerSize', config.nodeSize);
    set(ui.ctrl.refNodes, 'MarkerSize', config.nodeSize);
    % Per-node status color, moved marker and selection size
    status = getStatus(ui.status.currentFrame);
    for n = 1:numNodes
        marker = ui.skel.nodes(n);
        marker.Color = statusCmap(status(n)+1,:);
        if ui.status.movedNodes(n)
            marker.Marker = 's';
        end
        if ui.status.selectedNode == n
            marker.MarkerSize = 9;
            ui.ctrl.refNodes(n).MarkerSize = 9;
        end
    end
    % Redraw segment lines through their joints' positions
    for s = 1:numel(ui.skel.segs)
        seg = ui.skel.segs(s);
        jointIdx = seg.UserData.seg_joints_idx;
        seg.XData(:) = pos(jointIdx,1);
        seg.YData(:) = pos(jointIdx,2);
    end
    drawnow;
end
function nodesMoved(h)
    % draggable() callback, called while a node marker is being moved.
    % Selects the dragged node and redraws the skeleton.
    %
    % selectNode does nothing if the node is already selected, so it is
    % called unconditionally. This also covers the initial state with no
    % selection (selectedNode == []), where the previous "~=" test gave an
    % empty, i.e. false, condition and never selected the node.
    node_idx = h.UserData.node_idx;
    selectNode(node_idx)
    % Update
    updateSkeleton()
end
function nodesMoveEnd(h)
    % draggable() end callback, called when a dragged node marker is released.
    % Selects the node and redraws the skeleton.
    %
    % selectNode does nothing if the node is already selected, so it is
    % called unconditionally. This also covers the initial state with no
    % selection (selectedNode == []), where the previous "~=" test gave an
    % empty, i.e. false, condition and never selected the node.
    node_idx = h.UserData.node_idx;
    selectNode(node_idx)
    % Update
    updateSkeleton()
end
function selectNode(i)
    % Makes node i the selected node across the GUI (joint list entry and
    % highlighted marker). Does nothing when i is already selected.
    if isequal(ui.status.selectedNode, i)
        return
    end
    ui.status.selectedNode = i;
    ui.ctrl.jointsList.Value = i;
    updateSkeleton();
end
function nudgeNode(dXY, i)
    % Shifts node i (default: the selected node) by the offset dXY = [dx dy]
    % and redraws the skeleton.
    if nargin < 2
        i = ui.status.selectedNode;
    end
    newPos = getNodePositions();
    newPos(i,:) = newPos(i,:) + dXY;
    updateSkeleton(newPos);
end
function nudgeSegment(dXY, i)
    % Shifts node i (default: the selected node) by dXY = [dx dy], together
    % with every node that shares a skeleton segment with it. A node that
    % belongs to several such segments is moved only once.
    if nargin < 2; i = ui.status.selectedNode; end
    % Nothing selected: nothing to move
    if isempty(i); return; end
    % Collect the nodes of every segment containing node i. Node i itself is
    % always included, so it still moves when it belongs to no segment.
    seg_nodes = i;
    for j = 1:height(labels.skeleton.segments)
        idx = labels.skeleton.segments.joints_idx{j};
        if any(idx == i)
            seg_nodes = [seg_nodes; idx(:)]; %#ok<AGROW>
        end
    end
    % Get the union of the set to make sure we don't double move any nodes
    seg_nodes = unique(seg_nodes);
    % Move exactly those nodes. The previous loop moved rows
    % 1..numel(seg_nodes) instead of the segment's nodes.
    pos = getNodePositions();
    pos(seg_nodes,:) = pos(seg_nodes,:) + dXY;
    % Update edges
    updateSkeleton(pos);
end
function setNodesToDefault(node_idx)
    % Clears the user labels of the given nodes (default: all nodes) in the
    % current frame, so their status falls back to initialized or default.
    % It also resets their moved flags; updateSkeleton may set them again if
    % a marker differs from the frame's initial position. The markers
    % themselves are not moved; the skeleton is redrawn to show the new
    % status.
    if nargin < 1; node_idx = 1:numNodes; end
    if isempty(node_idx); return; end
    % Clear user labels
    labels.positions(node_idx,:,ui.status.currentFrame) = NaN;
    ui.status.movedNodes(node_idx) = false;
    % labels changed in memory: flag the frame so autosave persists it
    ui.status.unsavedChanges(ui.status.currentFrame) = true;
    % Update
    updateSkeleton();
end
function pos = getInitialPos(idx)
    % Starting node positions for frame idx (default: the current frame).
    % Begins with the skeleton template positions. Nodes with a complete
    % (non-NaN) initialization are overridden by it, and nodes with a
    % complete user label are overridden by that in turn.
    if nargin < 1
        idx = ui.status.currentFrame;
    end
    pos = labels.skeleton.joints.pos;
    % Later sources take precedence over earlier ones
    sources = {labels.initialization(:,:,idx), labels.positions(:,:,idx)};
    for s = 1:numel(sources)
        srcPos = sources{s};
        valid = find(all(~isnan(srcPos),2));
        pos(valid,:) = srcPos(valid,:);
    end
end
function resetNodes(node_idx)
    % Moves the given nodes (default: all nodes) back to the positions the
    % current frame starts from (see getInitialPos). Their moved flags are
    % cleared so the reset positions are not committed as new labels.
    if nargin < 1; node_idx = 1:numNodes; end
    % Start off with what we have now
    pos = getNodePositions();
    % Get initial positions
    init_pos = getInitialPos(ui.status.currentFrame);
    % Copy whole rows (x and y). The previous init_pos(node_idx) linearly
    % indexed only the x column, which gave wrong coordinates for a single
    % node and a size-mismatch error for several nodes.
    pos(node_idx,:) = init_pos(node_idx,:);
    ui.status.movedNodes(node_idx) = false;
    % Update
    updateSkeleton(pos);
end
%% Frame update and saving
function markAllCorrect()
% Helper for setting all nodes in the current frame as correct
% Flags every node as moved so that commitChanges() copies all currently
% displayed node positions (getNodePositions) into labels.positions for the
% current frame, which marks the frame as fully labeled.
ui.status.movedNodes(:) = true;
commitChanges();
end
function commitChanges()
% Utility function for committing changes to node positions in the
% current frame to the labels structure (but does not save to disk)
% Only nodes flagged in ui.status.movedNodes are committed; their flags are
% then cleared and the frame is marked in ui.status.unsavedChanges. The
% labeled-frame counter, status images, and status window title are then
% refreshed.
% Get current positions
pos = getNodePositions();
% Check current status (a frame is fully labeled when every node has status 2)
status = getStatus(ui.status.currentFrame);
isLabeled = all(status == 2);
for i = horz(find(ui.status.movedNodes))
% Commit to labels
labels.positions(i,:,ui.status.currentFrame) = pos(i,:);
% Reset moved state
ui.status.movedNodes(i) = false;
% Mark unsaved changes
ui.status.unsavedChanges(ui.status.currentFrame) = true;
end
% Update status
status = getStatus(ui.status.currentFrame);
% Update skeleton display
%updateSkeleton();
% Update stats if changed (the frame went from not fully labeled to fully labeled)
% NOTE(review): numLabeled is only ever incremented here; frames that lose
% labels elsewhere (e.g. setNodesToDefault) are not subtracted.
if ~all(isLabeled) && all(status == 2)
addToHistory(sprintf('Labeled frame %d', ui.status.currentFrame));
numLabeled = numLabeled + 1;
ui.status.framesLabeled.String = sprintf('Labeled: %d/%d (%.2f%%)', numLabeled, numFrames, numLabeled/numFrames*100);
end
% Update full status data (one column per frame)
ui.status.fullImg.CData(:,ui.status.currentFrame) = status;
% Update zoomed status bar data, only if the current frame is inside the zoom window
zoomBoxIdx = ui.status.zoomImg.XData;
if any(zoomBoxIdx == ui.status.currentFrame)
ui.status.zoomImg.CData(:,zoomBoxIdx == ui.status.currentFrame) = status;
end
% Update status fig title (appends ' [unsaved]' while any frame has unsaved changes)
savedStatus = '';
if any(ui.status.unsavedChanges); savedStatus = ' [unsaved]'; end
ui.status.fig.Name = sprintf('Status: %d/%d (%.2f%%) labeled%s', numLabeled, numFrames, numLabeled/numFrames*100, savedStatus);
end
function goToFrame(idx)
% Utility function for seeking to another frame
%   idx: index of the frame to display (4th dimension of box)
% Pending node edits for the frame being left are committed first. If
% config.autoSave is set and that frame has unsaved changes, the labels are
% then saved. Finally frame idx is drawn with its initial node positions and
% the status zoom window is recentered on it.
% NOTE(review): idx is not range-checked here; an out-of-range idx errors at
% box(:,:,:,idx).
% Commit changes to labels
commitChanges();
% Autosave before anything (i.e. before switching away from the current frame)
if config.autoSave && ~isempty(ui.status.currentFrame) && ui.status.unsavedChanges(ui.status.currentFrame)
saveLabels();
end
% Update image
ui.img.img.CData = box(:,:,:,idx);
% Update status
ui.status.currentFrame = idx;
ui.img.fig.Name = sprintf('%d/%d', ui.status.currentFrame, numFrames);
% Get initial positions
ui.status.initialPos = getInitialPos(idx);
% Update with initial positions
updateSkeleton(ui.status.initialPos);
% Update status zoom box position (zoomBoxWindow presumably holds frame
% offsets relative to the current frame; confirm where it is defined)
zoomBoxIdx = zoomBoxWindow + ui.status.currentFrame;
% Closed rectangle outline (x = frame, y = 0..numNodes) drawn on the full status image
zoomBoxPts = [
zoomBoxIdx(1) 0
zoomBoxIdx(end) 0
zoomBoxIdx(end) numNodes
zoomBoxIdx(1) numNodes
zoomBoxIdx(1) 0
];
ui.status.fullZoomBox.XData = zoomBoxPts(:,1);
ui.status.fullZoomBox.YData = zoomBoxPts(:,2);
% Update zoomed status bar data
ui.status.zoomImg.XData = zoomBoxIdx;
ui.status.zoomImg.CData(:) = 0; % reset
% Only copy columns for frames that exist; out-of-range ones stay 0
isValidIdx = zoomBoxIdx > 0 & zoomBoxIdx <= numFrames;
ui.status.zoomImg.CData(:,isValidIdx) = ui.status.fullImg.CData(:,zoomBoxIdx(isValidIdx));
ui.status.zoomAx.XLim = zoomBoxIdx([1 end]) + [-0.5 0.5];
end
function saveLabels()
% Saves everything in the labels structure to disk
% The fields of labels (including the current config, the current frame, and
% the figure positions) are written to labels.savePath via save -struct, and
% all unsaved-change flags are cleared.
stic;
% Commit unsaved changes to labels
commitChanges();
% Update if there were any changes
if any(ui.status.unsavedChanges)
% Update last modified timestamp
labels.lastModified = datestr(now);
end
% Save current frame so we pick up where we left off
config.initialFrame = ui.status.currentFrame;
% Save figure positions
config.imgFigPos = ui.img.fig.Position;
config.ctrlFigPos = ui.ctrl.fig.Position;
config.statusFigPos = ui.status.fig.Position;
% Save config
labels.config = config;
% Save to labels file
save(labels.savePath, '-struct', 'labels')
% Clear modified flags
ui.status.unsavedChanges(:) = false;
% Called again so the status window title is refreshed without ' [unsaved]'
commitChanges();
stocf('Saved labels: %s', labels.savePath)
end
%% Start!
goToFrame(config.initialFrame);
selectNode(1);
updateSkeleton();
end
================================================
FILE: leap/hpc/python_gpu.sh
================================================
#!/bin/bash
#SBATCH --time=4:00:00
#SBATCH --mem=128000
#SBATCH -N 1
#SBATCH --cpus-per-task=4
#SBATCH --ntasks-per-node=1
#SBATCH --ntasks-per-socket=1
#SBATCH --gres=gpu:1

# SLURM job wrapper: runs `python` with all arguments given to this script, e.g.
#   sbatch python_gpu.sh some_script.py --flag "a value"
# "$@" must be quoted so each argument is passed as exactly one word. The previous
# unquoted ${@:1} re-split arguments containing spaces and glob-expanded wildcards.
echo "args: ${@:1}"
python "$@"
================================================
FILE: leap/image_augmentation.py
================================================
import cv2
import numpy as np
import keras
from keras.utils import Sequence
def transform_imgs(X, theta=(-180,180), scale=1.0):
    """ Transforms sets of images with the same random transformation across each channel.

    One rotation angle and one scale factor are drawn per call (when ranges are given). The
    same affine warp, about the image center, is then applied to every image and every channel.

    Args:
        X: a single image as an ndarray of shape (height, width) or (height, width, channels),
            or a list/tuple of such images. The output size and rotation center are taken from
            the first image, so all images are expected to share its height and width.
        theta: rotation angle in degrees, or a (min, max) range to sample uniformly from.
        scale: scale factor, or a (min, max) range to sample uniformly from.

    Returns:
        The transformed image if X was a single ndarray, otherwise a list of transformed images
        in the same order. The inputs themselves are not modified.
    """
    # Sample random rotation and scale if range specified
    if not np.isscalar(theta):
        theta = np.ptp(theta) * np.random.rand() + np.min(theta)
    if not np.isscalar(scale):
        scale = np.ptp(scale) * np.random.rand() + np.min(scale)

    # Standardize X input to a list. isinstance() (rather than an exact type check) makes
    # ndarray subclasses such as np.memmap count as one image instead of being iterated over.
    single_img = isinstance(X, np.ndarray)
    if single_img:
        X = [X,]

    # numpy shapes are (rows, cols); OpenCV points and sizes are (x, y) = (cols, rows)
    img_size = X[0].shape[:2]
    ctr = (img_size[1] / 2, img_size[0] / 2)
    dsize = img_size[::-1]  # (width, height) for cv2.warpAffine

    # Compute affine transformation matrix
    T = cv2.getRotationMatrix2D(ctr, theta, scale)

    # Make sure we don't overwrite the inputs
    X = [x.copy() for x in X]

    # Apply to each image
    for i, img in enumerate(X):
        if img.ndim == 2:
            # Single channel image
            X[i] = cv2.warpAffine(img, T, dsize)
        else:
            # Multi-channel image: warp channel by channel (works for any channel count)
            for c in range(img.shape[-1]):
                img[..., c] = cv2.warpAffine(img[..., c], T, dsize)

    # Pull the single image back out of the list
    if single_img:
        X = X[0]
    return X
class PairedImageAugmenter(Sequence):
    """Keras Sequence yielding (X, Y) batches where each pair shares one random warp.

    For every sample, the input image and its target are passed together to transform_imgs,
    so both receive the same randomly sampled rotation/scale.

    Args:
        X: array of input images, indexed with integer index arrays per batch.
        Y: array of targets aligned sample-by-sample with X.
        batch_size: maximum number of samples per batch (the last batches may be smaller).
        shuffle: if True, the sample order is shuffled once, when the sequence is built.
        theta: rotation angle or (min, max) range in degrees, forwarded to transform_imgs.
        scale: scale factor or (min, max) range, forwarded to transform_imgs.
    """

    def __init__(self, X, Y, batch_size=32, shuffle=False, theta=(-180,180), scale=1.0):
        self.X = X
        self.Y = Y
        self.batch_size = batch_size
        self.theta = theta
        self.scale = scale
        self.num_samples = len(X)

        order = np.arange(self.num_samples)
        if shuffle:
            np.random.shuffle(order)
        num_batches = np.ceil(self.num_samples / self.batch_size)
        # array_split tolerates sample counts that are not a multiple of the batch count
        self.batches = np.array_split(order, num_batches)

    def __len__(self):
        return len(self.batches)

    def __getitem__(self, batch_idx):
        sample_idx = self.batches[batch_idx]
        # For numpy arrays, integer-array indexing returns copies, so the in-place
        # writes below do not touch self.X / self.Y
        X_batch = self.X[sample_idx]
        Y_batch = self.Y[sample_idx]
        for i, (x, y) in enumerate(zip(X_batch, Y_batch)):
            X_batch[i], Y_batch[i] = transform_imgs((x, y), theta=self.theta, scale=self.scale)
        return X_batch, Y_batch
class MultiInputOutputPairedImageAugmenter(PairedImageAugmenter):
    """PairedImageAugmenter variant for models with named inputs and outputs.

    Batches are returned as ({input_name: X, ...}, {output_name: Y, ...}). Every named input
    receives the same augmented X batch, and every named output the same augmented Y batch.

    Args:
        input_names: a model input name, or a list of names.
        output_names: a model output name, or a list of names.
        *args, **kwargs: forwarded unchanged to PairedImageAugmenter.
    """

    def __init__(self, input_names, output_names, *args, **kwargs):
        # Anything that is not exactly a list is treated as a single name
        self.input_names = input_names if type(input_names) == list else [input_names,]
        self.output_names = output_names if type(output_names) == list else [output_names,]
        super().__init__(*args, **kwargs)

    def __getitem__(self, batch_idx):
        X_batch, Y_batch = super().__getitem__(batch_idx)
        inputs = dict.fromkeys(self.input_names, X_batch)
        outputs = dict.fromkeys(self.output_names, Y_batch)
        return (inputs, outputs)
================================================
FILE: leap/layers.py
================================================
# -*- coding: utf-8 -*-
"""
Copyright 2018 Jacob M. Graving <jgraving@gmail.com>
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Modified from code written by François Chollet
All contributions by François Chollet:
Copyright (c) 2015 - 2018, François Chollet.
All rights reserved.
"""
import numpy as np
import keras
import keras.backend as K
from keras.legacy import interfaces
from keras.engine import Layer
from keras.engine import InputSpec
from keras.utils import conv_utils
from keras.backend import int_shape, permute_dimensions
from keras.backend import tf
from keras.layers import Conv2D, Add
from packaging.version import parse as parse_version
__all__ = ['UpSampling2D', 'Maxima2D']
def resize_images(x, height_factor, width_factor, interpolation, data_format):
    """Resizes the images contained in a 4D tensor.

    # Arguments
        x: Tensor or variable to resize.
        height_factor: Positive integer.
        width_factor: Positive integer.
        interpolation: string, "nearest", "bilinear" or "bicubic"
        data_format: string, `"channels_last"` or `"channels_first"`.

    # Returns
        A tensor. The static shape is updated with the scaled spatial sizes
        when they are known.

    # Raises
        ValueError: if `interpolation` is not one of the supported methods, or if
            `data_format` is neither `"channels_last"` or `"channels_first"`.
    """
    # Resolve the tf.image function lazily by name, so only the selected one must exist
    for method, fn_name in (('nearest', 'resize_nearest_neighbor'),
                            ('bilinear', 'resize_bilinear'),
                            ('bicubic', 'resize_bicubic')):
        if interpolation == method:
            tf_resize = getattr(tf.image, fn_name)
            break
    else:
        raise ValueError('Invalid interpolation method:', interpolation)

    channels_first = data_format == 'channels_first'
    if not channels_first and data_format != 'channels_last':
        raise ValueError('Invalid data_format:', data_format)

    def _scaled(dim, factor):
        # Keep unknown (None) static dimensions unknown
        return dim * factor if dim is not None else None

    original_shape = int_shape(x)
    spatial = tf.shape(x)[2:] if channels_first else tf.shape(x)[1:3]
    new_shape = spatial * tf.constant(np.array([height_factor, width_factor]).astype('int32'))

    if channels_first:
        # tf.image resize ops expect channels-last input
        x = permute_dimensions(x, [0, 2, 3, 1])
        x = tf_resize(x, new_shape, align_corners=True)
        x = permute_dimensions(x, [0, 3, 1, 2])
        x.set_shape((None, None,
                     _scaled(original_shape[2], height_factor),
                     _scaled(original_shape[3], width_factor)))
    else:
        x = tf_resize(x, new_shape, align_corners=True)
        x.set_shape((None,
                     _scaled(original_shape[1], height_factor),
                     _scaled(original_shape[2], width_factor),
                     None))
    return x
class UpSampling2D(Layer):
    """Upsampling layer for 2D inputs.

    Resizes the rows and columns of the data by size[0] and size[1]
    respectively, using the configured interpolation method.

    # Arguments
        size: int, or tuple of 2 integers.
            The upsampling factors for rows and columns.
        data_format: A string,
            one of `channels_last` (default) or `channels_first`.
            The ordering of the dimensions in the inputs.
            `channels_last` corresponds to inputs with shape
            `(batch, height, width, channels)` while `channels_first`
            corresponds to inputs with shape
            `(batch, channels, height, width)`.
            It defaults to the `image_data_format` value found in your
            Keras config file at `~/.keras/keras.json`.
            If you never set it, then it will be "channels_last".
        interpolation: A string,
            one of 'nearest' (default), 'bilinear', or 'bicubic'

    # Input shape
        4D tensor with shape:
        - If `data_format` is `"channels_last"`:
            `(batch, rows, cols, channels)`
        - If `data_format` is `"channels_first"`:
            `(batch, channels, rows, cols)`

    # Output shape
        4D tensor with shape:
        - If `data_format` is `"channels_last"`:
            `(batch, upsampled_rows, upsampled_cols, channels)`
        - If `data_format` is `"channels_first"`:
            `(batch, channels, upsampled_rows, upsampled_cols)`
    """

    @interfaces.legacy_upsampling2d_support
    def __init__(self, size=(2, 2), data_format=None, interpolation='nearest', **kwargs):
        super(UpSampling2D, self).__init__(**kwargs)
        # Update to K.normalize_data_format after keras 2.2.0
        if parse_version(keras.__version__) > parse_version("2.2.0"):
            self.data_format = K.normalize_data_format(data_format)
        else:
            self.data_format = conv_utils.normalize_data_format(data_format)
        self.interpolation = interpolation
        self.size = conv_utils.normalize_tuple(size, 2, 'size')
        self.input_spec = InputSpec(ndim=4)

    def compute_output_shape(self, input_shape):
        if self.data_format == 'channels_first':
            height = self.size[0] * input_shape[2] if input_shape[2] is not None else None
            width = self.size[1] * input_shape[3] if input_shape[3] is not None else None
            return (input_shape[0],
                    input_shape[1],
                    height,
                    width)
        elif self.data_format == 'channels_last':
            height = self.size[0] * input_shape[1] if input_shape[1] is not None else None
            width = self.size[1] * input_shape[2] if input_shape[2] is not None else None
            return (input_shape[0],
                    height,
                    width,
                    input_shape[3])

    def call(self, inputs):
        return resize_images(inputs, self.size[0], self.size[1],
                             self.interpolation, self.data_format)

    def get_config(self):
        # 'interpolation' must be serialized as well. Without it, a model saved with
        # interpolation='bilinear' is rebuilt with the 'nearest' default on load.
        # Configs saved without the key still load with the 'nearest' default.
        config = {'size': self.size,
                  'data_format': self.data_format,
                  'interpolation': self.interpolation}
        base_config = super(UpSampling2D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
def _find_maxima(x):
    """Locate the peak of every channel of a channels-last 4D tensor.

    Assumes x is laid out as (batch, rows, cols, channels); find_maxima permutes
    channels-first input into this layout before calling. Returns a tensor of
    shape (batch, 3, channels) whose middle axis holds [x (column index),
    y (row index), peak value], all cast to K.floatx().
    """
    x = K.cast(x, K.floatx())
    # Reduce over rows -> per-column maxima; reduce over cols -> per-row maxima
    per_col_max = K.max(x, axis=1)
    per_row_max = K.max(x, axis=2)
    peak_vals = K.expand_dims(K.max(per_col_max, 1), -2)
    col_idx = K.expand_dims(K.cast(K.argmax(per_col_max, -2), K.floatx()), -2)
    row_idx = K.expand_dims(K.cast(K.argmax(per_row_max, -2), K.floatx()), -2)
    # Stack along the new axis in (x, y, val) order
    return K.concatenate([col_idx, row_idx, peak_vals], -2)
def find_maxima(x, data_format):
    """Finds the 2D maxima contained in a 4D tensor.

    # Arguments
        x: Tensor or variable.
        data_format: string, `"channels_last"` or `"channels_first"`.

    # Returns
        A tensor holding [x, y, value] per channel: shape (batch, 3, channels)
        for channels_last and (batch, channels, 3) for channels_first.

    # Raises
        ValueError: if `data_format` is neither `"channels_last"` or `"channels_first"`.
    """
    if data_format not in ('channels_first', 'channels_last'):
        raise ValueError('Invalid data_format:', data_format)
    if data_format == 'channels_last':
        return _find_maxima(x)
    # channels_first: move channels last, locate the peaks, then put channels before [x, y, val]
    peaks = _find_maxima(permute_dimensions(x, [0, 2, 3, 1]))
    return permute_dimensions(peaks, [0, 2, 1])
class Maxima2D(Layer):
    """Maxima layer for 2D inputs.

    Finds the maximum value and its 2D location for every channel of the
    input. For each channel the output holds [x, y, maximum], where x is the
    column index and y the row index of the peak (as floats).

    # Arguments
        data_format: A string,
            one of `channels_last` (default) or `channels_first`.
            The ordering of the dimensions in the inputs.
            `channels_last` corresponds to inputs with shape
            `(batch, height, width, channels)` while `channels_first`
            corresponds to inputs with shape
            `(batch, channels, height, width)`.
            It defaults to the `image_data_format` value found in your
            Keras config file at `~/.keras/keras.json`.
            If you never set it, then it will be "channels_last".

    # Input shape
        4D tensor with shape:
        - If `data_format` is `"channels_last"`:
            `(batch, rows, cols, channels)`
        - If `data_format` is `"channels_first"`:
            `(batch, channels, rows, cols)`

    # Output shape
        3D tensor with shape:
        - If `data_format` is `"channels_last"`:
            `(batch, 3, channels)`
        - If `data_format` is `"channels_first"`:
            `(batch, channels, 3)`
    """

    def __init__(self, data_format=None, **kwargs):
        super(Maxima2D, self).__init__(**kwargs)
        # keras > 2.2.0 exposes normalize_data_format on the backend; older versions in conv_utils
        normalize = (K.normalize_data_format
                     if parse_version(keras.__version__) > parse_version("2.2.0")
                     else conv_utils.normalize_data_format)
        self.data_format = normalize(data_format)
        self.input_spec = InputSpec(ndim=4)

    def compute_output_shape(self, input_shape):
        if self.data_format == 'channels_first':
            return (input_shape[0], input_shape[1], 3)
        if self.data_format == 'channels_last':
            return (input_shape[0], 3, input_shape[3])
        return None

    def call(self, inputs):
        return find_maxima(inputs, self.data_format)

    def get_config(self):
        base_config = super(Maxima2D, self).get_config()
        config = {'data_format': self.data_format}
        return dict(list(base_config.items()) + list(config.items()))
def residual_bottleneck_module(x_in, output_filters=32, bottleneck_factor=2, prefix="res", activation="relu", initializer="glorot_normal"):
    """Residual bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions added to a skip path.

    Args:
        x_in: input tensor; its 4th dimension (index 3) is read as the channel count,
            which assumes channels-last layout.
        output_filters: number of output channels.
        bottleneck_factor: the two inner convolutions use output_filters // bottleneck_factor filters.
        prefix: prefix for all layer names ("_Conv1", "_Conv2", "_Conv3", "_ConvSkip", "_AddRes").
        activation: activation applied by every convolution, including the skip projection.
        initializer: kernel initializer for every convolution.

    Returns:
        Output tensor of the final Add layer.
    """
    input_filters = K.int_shape(x_in)[3]
    bottleneck_filters = output_filters // bottleneck_factor

    def conv(filters, kernel_size, suffix):
        # All convolutions share padding, activation and initializer
        return Conv2D(filters=filters, kernel_size=kernel_size, padding="same", activation=activation,
                      kernel_initializer=initializer, name=prefix + suffix)

    x = conv(bottleneck_filters, 1, "_Conv1")(x_in)
    x = conv(bottleneck_filters, 3, "_Conv2")(x)
    x = conv(output_filters, 1, "_Conv3")(x)

    # Project the skip path with a 1x1 conv only when the channel counts differ
    skip = x_in
    if output_filters != input_filters:
        skip = conv(output_filters, 1, "_ConvSkip")(x_in)

    return Add(name=prefix + "_AddRes")([skip, x])
================================================
FILE: leap/models.py
================================================
import keras
from keras import backend as K
from keras.models import Model
from keras.layers import Input, Conv2D, Conv2DTranspose, Add, MaxPooling2D
from keras.optimizers import Adam
from leap.layers import residual_bottleneck_module, UpSampling2D
def leap_cnn(img_size, output_channels, filters=64, upsampling_layers=False, amsgrad=False, summary=False):
    """
    Creates and compiles the LEAP encoder-decoder network model.

    :param img_size: shape of a single image, optionally including channels. A 2-element shape
        gets a single channel appended. Any sequence (tuple, list, numpy array) is accepted.
    :param output_channels: number of output channels (joints being predicted)
    :param filters: number of baseline filters to use (more filters will be used in intermediate layers)
    :param upsampling_layers: if True, decode with bilinear UpSampling2D layers instead of
        strided transposed convolutions
    :param amsgrad: passed to the Adam optimizer to enable the AMSGrad variant
    :param summary: prints network summary after compiling
    :return: keras Model named "LeapCNN", compiled with Adam and mean squared error loss
    """
    # Normalize to a tuple: "list + (1,)" raises TypeError, and "ndarray + (1,)" would
    # broadcast-add 1 to every dimension instead of appending a channel axis.
    img_size = tuple(img_size)
    if len(img_size) == 2:
        img_size = img_size + (1,)

    def _conv_stack(x, n_filters, n_layers):
        # n_layers 3x3 ReLU convolutions at constant resolution
        for _ in range(n_layers):
            x = Conv2D(n_filters, kernel_size=3, padding="same", activation="relu")(x)
        return x

    x_in = Input(img_size)

    # Encoder: full, 1/2 and 1/4 resolution with 1x, 2x and 4x filters
    x1 = _conv_stack(x_in, filters, 3)
    x1_pool = MaxPooling2D(pool_size=2, strides=2, padding="same")(x1)
    x2 = _conv_stack(x1_pool, filters*2, 3)
    x2_pool = MaxPooling2D(pool_size=2, strides=2, padding="same")(x2)
    x3 = _conv_stack(x2_pool, filters*4, 3)

    # Decoder: back to 1/2 resolution
    if upsampling_layers:
        x4 = UpSampling2D(interpolation="bilinear")(x3)
    else:
        x4 = Conv2DTranspose(filters*2, kernel_size=3, strides=2, padding="same", activation="relu", kernel_initializer="glorot_normal")(x3)
    x4 = _conv_stack(x4, filters*2, 2)

    # Back to full resolution with one linear output map per channel
    if upsampling_layers:
        x_out = UpSampling2D(interpolation="bilinear")(x4)
        x_out = Conv2D(output_channels, kernel_size=3, padding="same", activation="linear")(x_out)
    else:
        x_out = Conv2DTranspose(output_channels, kernel_size=3, strides=2, padding="same", activation="linear", kernel_initializer="glorot_normal")(x4)

    # Compile
    net = Model(inputs=x_in, outputs=x_out, name="LeapCNN")
    net.compile(optimizer=Adam(amsgrad=amsgrad), loss="mean_squared_error")
    if summary:
        net.summary()
    return net
def _hourglass_module(x, name_fmt, filters, upsampling_layers):
    """Build one 4-level hourglass (encoder-decoder with additive skips) on top of x.

    Layer names come from name_fmt % level for levels 1-9. For example, "x%d" yields
    "x1", "x1_pool", ..., "x6_Upsample" / "x6_ConvT", "x6_Add", ..., "x9".
    Layers are created in a fixed order (encoder levels 1-4, bottleneck 5, decoder 6-9),
    so names and creation order match the previous hand-unrolled definitions.

    :param x: input tensor
    :param name_fmt: printf-style pattern with one %d for the level number
    :param filters: number of filters used by every residual block and transposed conv
    :param upsampling_layers: if True, upsample with bilinear UpSampling2D, else Conv2DTranspose
    :return: output tensor of the level-9 residual block
    """
    skips = []
    # Encoder: residual block, keep its output for the skip, then 2x max pooling
    for level in range(1, 5):
        name = name_fmt % level
        x_pre = residual_bottleneck_module(x, prefix=name, output_filters=filters)
        skips.append(x_pre)
        x = MaxPooling2D(pool_size=2, strides=2, padding="same", name=name + "_pool")(x_pre)

    # Lowest resolution
    x = residual_bottleneck_module(x, prefix=name_fmt % 5, output_filters=filters)

    # Decoder: upsample 2x and add the encoder output of matching resolution (level 4 -> 1)
    for level, skip in zip(range(6, 10), reversed(skips)):
        name = name_fmt % level
        if upsampling_layers:
            x_up = UpSampling2D(interpolation="bilinear", name=name + "_Upsample")(x)
        else:
            x_up = Conv2DTranspose(filters=filters, kernel_size=3, strides=2, padding="same", activation="relu", kernel_initializer="glorot_normal", name=name + "_ConvT")(x)
        x_add = Add(name=name + "_Add")([skip, x_up])
        x = residual_bottleneck_module(x_add, prefix=name, output_filters=filters)
    return x


def hourglass(img_size, output_channels, filters=64, upsampling_layers=False, amsgrad=False, summary=False):
    """
    Creates and compiles a single-hourglass network model.

    :param img_size: shape of a single image, optionally including channels. A 2-element shape
        gets a single channel appended. Any sequence (tuple, list, numpy array) is accepted.
    :param output_channels: number of output channels (joints being predicted)
    :param filters: number of filters used by every block of the hourglass
    :param upsampling_layers: if True, decode with bilinear UpSampling2D layers instead of
        strided transposed convolutions
    :param amsgrad: passed to the Adam optimizer to enable the AMSGrad variant
    :param summary: prints network summary after compiling
    :return: keras Model named "hourglass", compiled with Adam and mean squared error loss
    """
    # Normalize to a tuple so lists/arrays work and "+ (1,)" appends a channel axis
    img_size = tuple(img_size)
    if len(img_size) == 2:
        img_size = img_size + (1,)

    x_in = Input(img_size, name="x_in")
    x = _hourglass_module(x_in, "x%d", filters, upsampling_layers)
    x_out = Conv2D(filters=output_channels, kernel_size=3, strides=1, padding="same", activation="linear", name="x_out")(x)

    # Compile
    model = Model(inputs=x_in, outputs=x_out, name="hourglass")
    model.compile(optimizer=Adam(amsgrad=amsgrad), loss="mean_squared_error")
    if summary:
        model.summary()
    return model


def stacked_hourglass(img_size, output_channels, filters=64, upsampling_layers=False, amsgrad=False, summary=False):
    """
    Creates and compiles a two-stack hourglass network model.

    The second hourglass takes the first one's final features as input. Each stack gets a
    linear residual output head, and the model returns [stack 1 output, stack 2 output]. The
    single loss string passed to compile is applied to both outputs.

    :param img_size: shape of a single image, optionally including channels. A 2-element shape
        gets a single channel appended. Any sequence (tuple, list, numpy array) is accepted.
    :param output_channels: number of output channels (joints being predicted)
    :param filters: number of filters used by every block of both hourglasses
    :param upsampling_layers: if True, decode with bilinear UpSampling2D layers instead of
        strided transposed convolutions
    :param amsgrad: passed to the Adam optimizer to enable the AMSGrad variant
    :param summary: prints network summary after compiling
    :return: keras Model named "StackedHourglass", compiled with Adam and mean squared error loss
    """
    # Normalize to a tuple so lists/arrays work and "+ (1,)" appends a channel axis
    img_size = tuple(img_size)
    if len(img_size) == 2:
        img_size = img_size + (1,)

    x_in = Input(img_size, name="x_in")
    x1_9 = _hourglass_module(x_in, "x1_%d", filters, upsampling_layers)
    x2_9 = _hourglass_module(x1_9, "x2_%d", filters, upsampling_layers)

    # Output heads (intermediate supervision on stack 1, final prediction from stack 2)
    x_out1 = residual_bottleneck_module(x1_9, output_filters=output_channels, bottleneck_factor=1, prefix="x_out1", activation="linear")
    x_out2 = residual_bottleneck_module(x2_9, output_filters=output_channels, bottleneck_factor=1, prefix="x_out2", activation="linear")

    # Compile
    model = Model(inputs=x_in, outputs=[x_out1, x_out2], name="StackedHourglass")
    model.compile(optimizer=Adam(amsgrad=amsgrad), loss="mean_squared_error")
    if summary:
        model.summary()
    return model
================================================
FILE: leap/plot_joints_single.m
================================================
function h = plot_joints_single(pts, segments, markerSize, lineWidth)
%PLOT_JOINTS_SINGLE Plot joints on a single frame.
% Usage:
%   plot_joints_single(pts, segments)
%   plot_joints_single(pts, segments, markerSize, lineWidth)
%
% Args:
%   pts: joint coordinates, one row per joint; rows are selected by
%        segments.joints_idx and passed to plotpts
%   segments (or skeleton): table with line segments (joints_idx and color
%        columns), or a struct with a .segments field containing it
%   markerSize: default: 20
%   lineWidth: default: 2
%
% Returns:
%   h: one graphics handle per segment (cleared when no output is requested)
%
% See also: plotpts
if nargin < 4 || isempty(lineWidth); lineWidth = 2; end
if nargin < 3 || isempty(markerSize); markerSize = 20; end
if isfield(segments,'segments'); segments = segments.segments; end

numSegs = numel(segments.joints_idx);
h = gobjects(numSegs,1);
for s = 1:numSegs
    jointIdx = segments.joints_idx{s};
    h(s) = plotpts(pts(jointIdx,:), '.-', 'Color', segments.color{s}, ...
        'MarkerSize', markerSize, 'LineWidth', lineWidth);
end

if nargout == 0
    clear h
end
end
================================================
FILE: leap/predict_box.m
================================================
function preds = predict_box(box, modelPath, saveConfmaps)
%PREDICT_BOX Evaluates model predictions on a stack of frames. Wrapper for predict_box.py.
% Usage:
%   preds = predict_box(box, modelPath)
%   preds = predict_box(box, modelPath, saveConfmaps)
%
% Args:
%   box: 4-D array or path to HDF5 file with '/box' dataset
%   modelPath: path to model weights
%   saveConfmaps: if true, returns full confidence maps in addition to
%                 peaks (default: false). Very slow and memory intensive!
%
% Returns:
%   preds: struct containing results from model prediction
%       .positions_pred: 3-D array of (parts x [X Y] x frames) indicating
%           peak positions for each confidence map in image coordinates
%       .conf_pred: 2-D array of (parts x frames) with the confidence map
%           value at the peak pixel for detecting bad predictions
%       .confmaps: 4-D array of confidence maps returned if saveConfmaps
%           is set true
%
% Errors with id 'predict_box:pythonFailed' if predict_box.py exits with a
% non-zero status. Temporary files are cleaned up first.
if nargin < 3 || isempty(saveConfmaps); saveConfmaps = false; end

% Process args
delete_box = false;
is_singleton = false;
if ischar(box)
    boxPath = box;
else
    % Write array input to a temporary HDF5 file for the Python script
    boxPath = [tempname '.h5'];
    if numel(size(box)) < 4
        % Single frame: duplicate it so the data is stored as a 4-D stack
        box = repmat(box,[1 1 1 2]);
        is_singleton = true;
    end
    h5save(boxPath, box)
    delete_box = true;
end

% Generate temporary output filename
outPath = [tempname '.h5'];

% Build command line args
cmd = {
    'python'
    ['"' ff(funpath(true), 'predict_box.py') '"']
    ['"' boxPath '"']
    ['"' modelPath '"']
    ['"' outPath '"']
    };
if saveConfmaps
    cmd{end+1} = '--save-confmaps';
end
disp(strjoin(cmd))

% Predict
try
    exit_code = system(strjoin(cmd));
catch ME
    if delete_box && exists(boxPath); delete(boxPath); end
    rethrow(ME)
end
if delete_box && exists(boxPath); delete(boxPath); end

% Fail with a clear message instead of a confusing HDF5 read error below
if exit_code ~= 0
    if exists(outPath); delete(outPath); end
    error('predict_box:pythonFailed', ...
        'predict_box.py exited with status %d. Command: %s', exit_code, strjoin(cmd));
end

% Read data back in
try
    preds = h5readgroup(outPath);
catch ME
    if exists(outPath); delete(outPath); end
    rethrow(ME)
end
if exists(outPath); delete(outPath); end

% Adjust for 0-based indexing
preds.positions_pred = single(preds.positions_pred) + 1;

% Rescale confidence maps to correct range prior to quantization
if saveConfmaps && isfield(preds.Attributes, 'confmaps')
    c_min = preds.Attributes.confmaps.range_min;
    c_max = preds.Attributes.confmaps.range_max;
    preds.confmaps = rescale(single(preds.confmaps) / 255, c_min, c_max);
end

% Adjust for singleton input (drop the duplicated frame)
if is_singleton
    preds.positions_pred = preds.positions_pred(:,:,1);
    preds.conf_pred = preds.conf_pred(:,1);
    if isfield(preds,'confmaps')
        preds.confmaps = preds.confmaps(:,:,:,1);
    end
end

end
================================================
FILE: leap/predict_box.py
================================================
import h5py
import numpy as np
import os
from time import time
import keras
import keras.models
from keras.layers import Lambda
import tensorflow as tf
import re
from clize import run
from leap.utils import find_weights, find_best_weights, preprocess
from leap.layers import Maxima2D
def tf_find_peaks(x):
    """ Finds the maximum value in each channel and returns the location and value.

    Args:
        x: rank-4 tensor (samples, height, width, channels)

    Returns:
        peaks: rank-3 tensor (samples, [x, y, val], channels), where x is the column index and
            y the row index of the peak (cast to float32) and val is the peak value.
    """
    # Store input shape
    in_shape = tf.shape(x)

    # Flatten height/width dims (row-major, so linear index = row * width + col)
    flattened = tf.reshape(x, [in_shape[0], -1, in_shape[-1]])

    # Find peaks in linear indices
    idx = tf.cast(tf.argmax(flattened, axis=1), tf.int32)

    # Convert linear indices to subscripts by dividing by the WIDTH (axis 2). Dividing by the
    # height gave wrong rows/cols for non-square inputs.
    width = in_shape[2]
    rows = tf.floor_div(idx, width)
    cols = tf.floormod(idx, width)

    # Peak values (max equals the value at the argmax position)
    vals = tf.reduce_max(flattened, axis=1)

    # Return N x 3 x C tensor
    return tf.stack([
        tf.cast(cols, tf.float32),
        tf.cast(rows, tf.float32),
        vals
        ], axis=1)
def convert_to_peak_outputs(model, include_confmaps=False):
    """ Creates a new Keras model with a wrapper to yield channel peaks from rank-4 tensors.

    Args:
        model: Keras model whose output (or last output, if it has several) is
            a rank-4 confidence map tensor.
        include_confmaps: if True, the returned model outputs
            [peaks, confmaps] instead of only the peaks.

    Returns:
        A keras.Model sharing `model`'s input.
    """
    # Models with several outputs expose them as a list; the last one is
    # taken as the confidence maps.
    if isinstance(model.output, (list, tuple)):
        confmaps = model.output[-1]
    else:
        confmaps = model.output

    # NOTE(review): the two branches use different peak finders (tf_find_peaks
    # vs. the Maxima2D layer). predict_box reads both results the same way
    # (peaks[:, :2, :] and peaks[:, 2, :]), so confirm they share the
    # (samples, [x, y, val], channels) layout.
    if include_confmaps:
        return keras.Model(model.input, [Lambda(tf_find_peaks)(confmaps), confmaps])
    return keras.Model(model.input, Maxima2D()(confmaps))
def predict_box(box_path, model_path, out_path, *, box_dset="/box", epoch=None, verbose=True, overwrite=False, save_confmaps=False, batch_size=32):
    """
    Predict and save peak coordinates for a box.

    :param box_path: path to HDF5 file with box dataset
    :param model_path: path to Keras weights file or run folder with weights subfolder
    :param out_path: path to HDF5 file to save results to; any path not ending in ".h5" is treated as a directory in which an output file named after box_path is created
    :param box_dset: name of HDF5 dataset containing box images
    :param epoch: epoch to use if run folder provided instead of Keras weights file
    :param verbose: if True, prints some info and statistics during processing
    :param overwrite: if True and out_path exists, file will be overwritten
    :param save_confmaps: if True, saves the full confidence maps as additional datasets in the output file (very slow)
    :param batch_size: number of samples to evaluate at once per batch (see keras.Model API)
    """
    if verbose:
        print("model_path:", model_path)

    # Find model weights
    model_name = None
    weights_path = model_path
    if os.path.isdir(model_path):
        model_name = os.path.basename(model_path)
        weights_paths, epochs, val_losses = find_weights(model_path)

        if epoch is None and len(val_losses) > 0:
            # Default: the checkpoint with the lowest validation loss
            weights_path = weights_paths[np.argmin(val_losses)]
        elif epoch == "final" or (epoch is None and len(val_losses) == 0):
            weights_path = os.path.join(model_path, "final_model.h5")
        else:
            # NOTE(review): `epoch` is used as a position in the list returned
            # by find_weights, not matched against `epochs`; confirm intended.
            weights_path = weights_paths[epoch]

    # Input data: only inspect the shape here. The file is closed immediately
    # and reopened when the images are loaded, so no handle is leaked
    # (including on the early return below).
    with h5py.File(box_path, "r") as box_file:
        box_shape = box_file[box_dset].shape
    num_samples = box_shape[0]
    if verbose:
        print("Input:", box_path)
        print("box.shape:", box_shape)

    # Create output path
    if not out_path.endswith(".h5"):
        if model_name is None:
            out_path = os.path.join(out_path, os.path.basename(box_path))
        else:
            out_path = os.path.join(out_path, model_name, os.path.basename(box_path))
        os.makedirs(os.path.dirname(out_path), exist_ok=True)

    # The name stored in the output attributes is always the basename of model_path
    model_name = os.path.basename(model_path)
    if verbose:
        print("Output:", out_path)

    t0_all = time()

    if os.path.exists(out_path):
        if overwrite:
            os.remove(out_path)
            if verbose:
                print("Deleted existing output.")
        else:
            print("Error: Output path already exists.")
            return

    # Load and prepare model
    model = keras.models.load_model(weights_path)
    model_peaks = convert_to_peak_outputs(model, include_confmaps=save_confmaps)
    if verbose:
        print("weights_path:", weights_path)
        print("Loaded model: %d layers, %d params" % (len(model.layers), model.count_params()))

    # Load data and preprocess (normalize)
    t0 = time()
    with h5py.File(box_path, "r") as box_file:
        X = preprocess(box_file[box_dset][:])
    if verbose:
        print("Loaded [%.1fs]" % (time() - t0))

    # Evaluate
    t0 = time()
    if save_confmaps:
        Ypk, confmaps = model_peaks.predict(X, batch_size=batch_size)

        # Quantize to uint8 over the observed range. The range is saved as
        # attributes so readers can map the values back.
        confmaps_min = confmaps.min()
        confmaps_max = confmaps.max()
        confmaps_range = confmaps_max - confmaps_min
        if confmaps_range > 0:
            confmaps = (confmaps - confmaps_min) / confmaps_range
        else:
            # Constant maps: avoid 0/0 -> NaN, whose uint8 cast is undefined
            confmaps = np.zeros_like(confmaps)
        confmaps = (confmaps * 255).astype('uint8')

        # Reshape (sample, height, width, channel) -> (sample, channel, width, height)
        confmaps = np.transpose(confmaps, (0, 3, 2, 1))
    else:
        Ypk = model_peaks.predict(X, batch_size=batch_size)

    prediction_runtime = time() - t0
    if verbose:
        print("Predicted [%.1fs]" % prediction_runtime)
        print("Prediction performance: %.3f FPS" % (num_samples / prediction_runtime))

    # Save
    t0 = time()
    with h5py.File(out_path, "w") as f:
        f.attrs["num_samples"] = num_samples
        f.attrs["img_size"] = X.shape[1:]
        f.attrs["box_path"] = box_path
        f.attrs["box_dset"] = box_dset
        f.attrs["model_path"] = model_path
        f.attrs["weights_path"] = weights_path
        f.attrs["model_name"] = model_name

        ds_pos = f.create_dataset("positions_pred", data=Ypk[:, :2, :].astype("int32"), compression="gzip", compression_opts=1)
        ds_pos.attrs["description"] = "coordinate of peak at each sample"
        ds_pos.attrs["dims"] = "(sample, [x, y], joint) === (sample, [column, row], joint)"

        # Ypk[:, 2, :] is already (sample, joint). No squeeze, so a single
        # sample or a single joint keeps the documented 2-D layout.
        ds_conf = f.create_dataset("conf_pred", data=Ypk[:, 2, :], compression="gzip", compression_opts=1)
        ds_conf.attrs["description"] = "confidence map value in [0, 1.0] at peak"
        ds_conf.attrs["dims"] = "(sample, joint)"

        if save_confmaps:
            ds_confmaps = f.create_dataset("confmaps", data=confmaps, compression="gzip", compression_opts=1)
            ds_confmaps.attrs["description"] = "confidence maps"
            ds_confmaps.attrs["dims"] = "(sample, channel, width, height)"
            ds_confmaps.attrs["range_min"] = confmaps_min
            ds_confmaps.attrs["range_max"] = confmaps_max

        total_runtime = time() - t0_all
        f.attrs["total_runtime_secs"] = total_runtime
        f.attrs["prediction_runtime_secs"] = prediction_runtime

    if verbose:
        print("Saved [%.1fs]" % (time() - t0))
        print("Total runtime: %.1f mins" % (total_runtime / 60))
        print("Total performance: %.3f FPS" % (num_samples / total_runtime))
if __name__ == "__main__":
    # Command-line entry point: clize's run() exposes predict_box as a CLI
    # (positional paths plus the keyword-only options).
    run(predict_box)
================================================
FILE: leap/pts2confmaps.m
================================================
function confmaps = pts2confmaps(pts, sz, sigma, normalize)
%PTS2CONFMAPS Generate confidence maps centered at specified points.
% Usage:
%   confmaps = pts2confmaps(pts, sz)
%   confmaps = pts2confmaps(pts, sz, sigma)
%   confmaps = pts2confmaps(pts, sz, sigma, normalize)
%
% Args:
%   pts: N x 2 or cell array of {N1 x 2, N2 x 2, ...}, where each cell will
%        correspond to a single channel to create multipoint confidence maps.
%        Points are [x y] = [column row]. An empty cell yields an all-zero
%        channel.
%   sz: [rows cols]
%   sigma: filter size (default: 5)
%   normalize: outputs maps in [0, 1] rather than PDF (default: true)
%
% Returns:
%   confmaps: rows x cols x numel(pts) array, one channel per point set
%
% See also: label_joints

if ~iscell(pts); pts = arr2cell(pts,1); end
if nargin < 3 || isempty(sigma); sigma = 5; end
if nargin < 4 || isempty(normalize); normalize = true; end

% Zero-initialized so channels without points stay empty (all zeros)
confmaps = zeros(sz(1), sz(2), numel(pts));
[XX,YY] = meshgrid(1:sz(2), 1:sz(1));

for i = 1:numel(pts)
    % A channel with no points would make the max below return a
    % rows x cols x 0 array, which cannot be assigned into one channel
    if isempty(pts{i}); continue; end

    % Put each point along dim 3 so every point gets its own Gaussian
    x = permute(pts{i}(:,1),[2 3 1]);
    y = permute(pts{i}(:,2),[2 3 1]);

    % Combine multiple points with a pointwise max (not a sum), so
    % overlapping Gaussians do not exceed the single-peak height
    confmaps(:,:,i) = max(exp(-((YY-y).^2 + (XX-x).^2)./(2*sigma^2)),[],3);
end

if ~normalize
    % Scale by the 2-D isotropic Gaussian normalization constant
    % 1 / (2*pi*sigma^2) so a single peak integrates to 1 over the plane
    confmaps = confmaps ./ (2*pi*sigma^2);
end

end
================================================
FILE: leap/test_leap.m
================================================
function works = test_leap()
%TEST_LEAP Checks whether LEAP is properly installed.
% Usage:
%   test_leap
%   works = test_leap % returns true/false
%
% Runs `python -c "import leap"` after changing into matlabroot. The import
% therefore has to succeed from a folder unrelated to the LEAP sources.
% Without an output argument, the result is printed instead of returned.
%
% See also: install_leap

% Check if we can import the LEAP package from anywhere
cdCmd = ['cd "' matlabroot '"'];
if ispc(); cdCmd = ['cd /D "' matlabroot '"']; end
% Request the command output (and discard it) so it is not echoed to the
% command window
[status, ~] = system([cdCmd ' && python -c "import leap"']);
works = status == 0;

if nargout == 0
    if works
        % disp (builtin) rather than printf, which is not a MATLAB builtin
        disp('Test LEAP successful!')
    else
        disp('Unable to import LEAP python package. Make sure LEAP and its dependencies are installed.')
        disp('Go to the base LEAP directory containing setup.py and run from MATLAB:')
        disp(' !pip install -e .')
        disp('Or try the MATLAB installer:')
        disp(' install_leap')
    end
    % Avoid printing ans when called as a command
    clear works
end
end
================================================
FILE: leap/toolbox/aliases/alims.m
================================================
function varargout = alims(X)
%ALIMS Alias for arange.
% Usage:
%   R = alims(X)
%   [min_val, max_val] = alims(X)
%
% Forwards to arange through wrap, passing output indices 1..max(1, nargout).
%
% See also: arange
rangeFcn = @() arange(X);
outIdx = 1:max(1, nargout);
varargout = wrap(rangeFcn, outIdx);
end
================================================
FILE: leap/toolbox/aliases/ff.m
================================================
function varargout = ff(varargin)
%FF Alias for fullfile.
% Usage:
%   p = ff(part1, part2, ...)
%
% See also: fullfile
numOut = max(nargout, 1);
[varargout{1:numOut}] = fullfile(varargin{:});
end
================================================
FILE: leap/toolbox/aliases/h5file.m
================================================
function varargout = h5file(varargin)
%H5FILE Alias for hdf5prop.
% Usage:
%   out = h5file(...)
%
% Always requests at least one output from hdf5prop. Called without an
% output argument, the result is therefore still returned (as ans) instead
% of assigning into an empty output list.
%
% See also: hdf5prop
[varargout{1:max(nargout,1)}] = hdf5prop(varargin{:});
end
================================================
FILE: leap/toolbox/aliases/imgsc.m
================================================
function h = imgsc(I, varargin)
%IMGSC Alias for imagesc for images.
% Usage:
%   imgsc(I)
%   imgsc(I, ...)
%   h = imgsc(...)   % also returns the image handle
%
% Uses equal axis scaling ("axis image") when the image is roughly square
% (aspect ratio below 2) or small (both sides below 25 pixels), and always
% adds a colorbar.
%
% See also: imagesc, sc
figclosekey
h = imagesc(I, varargin{:});

imgSize = [size(I,1) size(I,2)];
aspectRatio = max(imgSize) / min(imgSize);
if aspectRatio < 2 || all(imgSize < 25)
    axis image
end

colorbar

% Do not return (and print) the handle when called as a command
if nargout == 0
    clear h
end
end
================================================
FILE: leap/toolbox/aliases/repext.m
================================================
function varargout = repext(varargin)
%REPEXT Alias for extrep.
% Forwards all inputs to extrep and returns at least one output.
%
% See also: extrep
numOut = max(nargout, 1);
[varargout{1:numOut}] = extrep(varargin{:});
end
================================================
FILE: leap/toolbox/graphics/FEX-settingsdlg/settingsdlg.m
================================================
function [settings, button] = settingsdlg(varargin)
% SETTINGSDLG Default dialog to produce a settings-structure
%
% settings = SETTINGSDLG('fieldname', default_value, ...) creates a modal
% dialog box that returns a structure formed according to user input. The
% input should be given in the form of 'fieldname', default_value - pairs,
% where 'fieldname' is the fieldname in the structure [settings], and
% default_value the initial value displayed in the dialog box.
%
% SETTINGSDLG uses UIWAIT to suspend execution until the user responds.
%
% settings = SETTINGSDLG(settings) uses the structure [settings] to form
% the various input fields. This is the most basic (and limited) usage of
% SETTINGSDLG.
%
% [settings, button] = SETTINGSDLG(settings) returns which button was
% pressed, in addition to the (modified) structure [settings]. Either 'ok',
% 'cancel' or [] are possible values. The empty output means that the
% dialog was closed before either Cancel or OK were pressed.
%
% SETTINGSDLG('title', 'window_title') uses 'window_title' as the dialog's
% title. The default is 'Adjust settings'.
%
% SETTINGSDLG('description', 'brief_description',...) starts the dialog box
% with 'brief_description', followed by the input fields.
%
% SETTINGSDLG('windowposition', P, ...) positions the dialog box according to
% the string or vector [P]; see movegui() for valid values.
%
% SETTINGSDLG( {'display_string', 'fieldname'}, default_value,...) uses the
% 'display_string' in the dialog box, while assigning the corresponding
% user-input to fieldname 'fieldname'.
%
% SETTINGSDLG(..., 'checkbox_string', true, ...) displays a checkbox
% instead of the default edit box, and SETTINGSDLG('fieldname', {'string1',
% 'string2'},... ) displays a popup box with the strings given in
% the second cell-array.
%
% Additionally, you can put [..., 'separator', 'separator_string',...]
% anywhere in the argument list, which will divide all the arguments into
% sections, with section headings 'separator_string'.
%
% You can also modify the display behavior in the case of checkboxes. When
% defining checkboxes with a 2-element logical array, the second boolean
% determines whether all fields below that checkbox are initially disabled
% (true) or not (false).
%
% Example:
%
% [settings, button] = settingsdlg(...
% 'Description', 'This dialog will set the parameters used by FMINCON()',...
% 'title' , 'FMINCON() options',...
% 'separator' , 'Unconstrained/General',...
% {'This is a checkbox'; 'Check'}, [true true],...
% {'Tolerance X';'TolX'}, 1e-6,...
% {'Tolerance on Function';'TolFun'}, 1e-6,...
% 'Algorithm' , {'active-set','interior-point'},...
% 'separator' , 'Constrained',...
% {'Tolerance on Constraints';'TolCon'}, 1e-6)
%
% See also inputdlg, dialog, errordlg, helpdlg, listdlg, msgbox, questdlg, textwrap,
% uiwait, warndlg.
% Please report bugs and inquiries to:
%
% Name : Rody P.S. Oldenhuis
% E-mail : oldenhuis@gmail.com
% Licence: 2-clause BSD (See Licence.txt)
% If you find this work useful, please consider a donation:
% https://www.paypal.me/RodyO/3.5
%% Initialize
% errortraps
narg = nargin;
if verLessThan('MATLAB', '8.6')
    error(nargchk(1, inf, narg, 'struct')); %#ok<NCHKN>
else
    narginchk(1, inf);
end
% parse input (+errortrap)
% A leading struct supplies defaults; the remaining arguments are
% 'fieldname', value pairs.
have_settings = 0;
if isstruct(varargin{1})
    settings = varargin{1}; have_settings = 1; end
if (narg == 1)
    if isstruct(varargin{1})
        parameters = fieldnames(settings);
        values = cellfun(@(x)settings.(x), parameters, 'UniformOutput', false);
    else
        error('settingsdlg:incorrect_input',...
            'When passing a single argument, that argument must be a structure.')
    end
else
    parameters = varargin(1+have_settings : 2 : end);
    values = varargin(2+have_settings : 2 : end);
end
% Initialize data
button = [];
fields = cell(numel(parameters),1);
tags = fields;
% Fill [settings] with default values & collect data
for ii = 1:numel(parameters)
    % Extract fields & tags
    % ({display_string, fieldname} pairs give separate tag and field)
    if iscell(parameters{ii})
        tags{ii} = parameters{ii}{1};
        fields{ii} = parameters{ii}{2};
    else
        % More errortraps
        if ~ischar(parameters{ii})
            error('settingsdlg:nonstring_parameter',...
                'Arguments should be given as [''parameter'', value,...] pairs.')
        end
        tags{ii} = parameters{ii};
        fields{ii} = parameters{ii};
    end
    % More errortraps
    if ~ischar(fields{ii})
        error('settingsdlg:fieldname_not_char',...
            'Fieldname should be a string.')
    end
    if ~ischar(tags{ii})
        error('settingsdlg:tag_not_char',...
            'Display name should be a string.')
    end
    % NOTE: 'Separator' is now in 'fields' even though
    % it will not be used as a fieldname
    % Make sure all fieldnames are properly formatted
    % (alternating capitals, no whitespace)
    % strcmpi against the cell gives one logical per entry; "if ~[...]" is
    % true only when NONE of the three special names matches.
    % NOTE(review): 'WindowWidth', 'ControlWidth' and 'WindowPosition' are not
    % excluded here, so they end up as fields of [settings] as well.
    if ~strcmpi(fields{ii}, {'Separator';'Title';'Description'})
        whitespace = isspace(fields{ii});
        capitalize = circshift(whitespace,[0,1]);
        fields{ii}(capitalize) = upper(fields{ii}(capitalize));
        fields{ii} = fields{ii}(~whitespace);
        % insert associated value in output
        % (popups default to their first entry, checkbox arrays to their first flag)
        if iscell(values{ii})
            settings.(fields{ii}) = values{ii}{1};
        elseif (length(values{ii}) > 1)
            settings.(fields{ii}) = values{ii}(1);
        else
            settings.(fields{ii}) = values{ii};
        end
    end
end
% Avoid (some) confusion
clear parameters
% Use default colorscheme from the OS
bgcolor = get(0, 'defaultUicontrolBackgroundColor');
% Default fontsize
fontsize = get(0, 'defaultuicontrolfontsize');
% Edit-bgcolor is platform-dependent.
% MS/Windows: white.
% UNIX: same as figure bgcolor
% if ispc, edit_bgcolor = 'White';
% else edit_bgcolor = bgcolor;
% end
% TODO: not really applicable since defaultUicontrolBackgroundColor
% doesn't really seem to work on Unix...
edit_bgcolor = 'White';
% Get basic window properties
% (getValue also removes these entries from fields/tags/values)
title = getValue('Adjust settings', 'Title');
description = getValue( [], 'Description');
total_width = getValue(325, 'WindowWidth');
control_width = getValue(100, 'ControlWidth');
% Window positioning:
% Put the window in the center of the screen by default.
% This will usually work fine, except on some multi-monitor setups.
scz = get(0, 'ScreenSize');
scxy = round(scz(3:4)/2-control_width/2);
scx = min(scz(3),max(1,scxy(1)));
scy = min(scz(4),max(1,scxy(2)));
% String to pass on to movegui
window_position = getValue('center', 'WindowPosition');
% Calculate best height for all uicontrol()
control_height = max(18, (fontsize+6));
% Calculate figure height (will be adjusted later according to description)
total_height = numel(fields)*1.25*control_height + ... % to fit all controls
    1.5*control_height + 20; % to fit "OK" and "Cancel" buttons
% Total number of separators
num_separators = nnz(strcmpi(fields,'Separator'));
% Draw figure in background
% NOTE(review): 'backingstore' and the 'zbuffer' renderer are legacy graphics
% options; confirm the targeted MATLAB release still accepts them.
fighandle = figure(...
    'integerhandle' , 'off',... % use non-integers for the handle (prevents accidental plots from going to the dialog)
    'Handlevisibility', 'off',... % only visible from within this function
    'position' , [scx, scy, total_width, total_height],...% figure position
    'visible' , 'off',... % hide the dialog while it is being constructed
    'backingstore' , 'off',... % DON'T save a copy in the background
    'resize' , 'off', ... % dialog is not resizable
    'renderer' , 'zbuffer', ... % best choice for speed vs. compatibility
    'WindowStyle' ,'modal',... % window is modal
    'units' , 'pixels',... % better for drawing
    'DockControls' , 'off',... % force it to be non-dockable
    'name' , title,... % dialog title
    'menubar' ,'none', ... % no menubar of course
    'toolbar' ,'none', ... % no toolbar
    'NumberTitle' , 'off',... % "Figure 1.4728...:" just looks corny
    'color' , bgcolor); % use default colorscheme
%% Draw all required uicontrols(), and unhide window
% Define X-offsets (different when separators are used)
separator_offset_X = 2;
if num_separators > 0
    text_offset_X = 20;
    text_width = (total_width-control_width-text_offset_X);
else
    text_offset_X = separator_offset_X;
    text_width = (total_width-control_width);
end
% Handle description
description_offset = 0;
if ~isempty(description)
    % create textfield (negligible height initially)
    description_panel = uicontrol(...
        'parent' , fighandle,...
        'style' , 'text',...
        'Horizontalalignment', 'left',...
        'position', [separator_offset_X,...
        total_height,total_width,1]);
    % wrap the description
    description = textwrap(description_panel, {description});
    % adjust the height of the figure
    textheight = size(description,1)*(fontsize+6);
    description_offset = textheight + 20;
    total_height = total_height + description_offset;
    set(fighandle,...
        'position', [scx, scy, total_width, total_height])
    % adjust the position of the textfield and insert the description
    set(description_panel, ...
        'string' , description,...
        'position', [separator_offset_X, total_height-textheight, ...
        total_width, textheight]);
end
% Define Y-offsets (different when descriptions are used)
control_offset_Y = total_height-control_height-description_offset;
% initialize loop
controls = zeros(numel(tags)-num_separators,1);
ii = 1; sep_ind = 1;
enable = 'on'; separators = zeros(num_separators,1);
% loop through the controls
% Separators are deleted from fields/tags/values as they are drawn, so ii
% only advances for real settings and afterwards indexes [controls] directly.
if numel(tags) > 0
    while true
        % Should we draw a separator?
        if strcmpi(tags{ii}, 'Separator')
            % Print separator
            uicontrol(...
                'style' , 'text',...
                'parent' , fighandle,...
                'string' , values{ii},...
                'horizontalalignment', 'left',...
                'fontweight', 'bold',...
                'position', [separator_offset_X,control_offset_Y-4, ...
                total_width, control_height]);
            % remove separator, but save its position
            fields(ii) = [];
            tags(ii) = []; separators(sep_ind) = ii;
            values(ii) = []; sep_ind = sep_ind + 1;
            % reset enable (when necessary)
            if strcmpi(enable, 'off')
                enable = 'on'; end
            % NOTE: DON'T increase loop index
            % ... or a setting?
        else
            % logicals: use checkbox
            if islogical(values{ii})
                % First draw control
                controls(ii) = uicontrol(...
                    'style' , 'checkbox',...
                    'parent' , fighandle,...
                    'enable' , enable,...
                    'string' , tags{ii},...
                    'value' , values{ii}(1),...
                    'position', [text_offset_X,control_offset_Y-4, ...
                    total_width, control_height]);
                % Should everything below here be OFF?
                if (length(values{ii})>1)
                    % turn next controls off when asked for
                    if values{ii}(2)
                        enable = 'off'; end
                    % Turn on callback function
                    set(controls(ii),...
                        'Callback', @(varargin) EnableDisable(ii,varargin{:}));
                end
                % doubles : use edit box
                % cells : use popup
                % cell-of-cells: use table
            else
                % First print parameter
                uicontrol(...
                    'style' , 'text',...
                    'parent' , fighandle,...
                    'string' , [tags{ii}, ':'],...
                    'horizontalalignment', 'left',...
                    'position', [text_offset_X,control_offset_Y-4, ...
                    text_width, control_height]);
                % Popup, edit box or table?
                style = 'edit';
                draw_table = false;
                if iscell(values{ii})
                    style = 'popup';
                    if all(cellfun('isclass', values{ii}, 'cell'))
                        draw_table = true; end
                end
                % Draw appropriate control
                if ~draw_table
                    controls(ii) = uicontrol(...
                        'enable' , enable,...
                        'style' , style,...
                        'Background', edit_bgcolor,...
                        'parent' , fighandle,...
                        'string' , values{ii},...
                        'position', [text_width,control_offset_Y,...
                        control_width, control_height]);
                else
                    % TODO
                    % ...table? ...radio buttons? How to do this?
                    warning(...
                        'settingsdlg:not_yet_implemented',...
                        'Treatment of cells is not yet implemented.');
                end
            end
            % increase loop index
            ii = ii + 1;
        end
        % end loop?
        if ii > numel(tags)
            break, end
        % Decrease offset
        control_offset_Y = control_offset_Y - 1.25*control_height;
    end
end
% Draw cancel button
uicontrol(...
    'style' , 'pushbutton',...
    'parent' , fighandle,...
    'string' , 'Cancel',...
    'position', [separator_offset_X,2, total_width/2.5,control_height*1.5],...
    'Callback', @Cancel)
% Draw OK button
uicontrol(...
    'style' , 'pushbutton',...
    'parent' , fighandle,...
    'string' , 'OK',...
    'position', [total_width*(1-1/2.5)-separator_offset_X,2, ...
    total_width/2.5,control_height*1.5],...
    'Callback', @OK)
% move to center of screen and make visible
movegui(fighandle, window_position);
set(fighandle, 'Visible', 'on');
% WAIT until OK/Cancel is pressed
% (both callbacks delete the figure, which releases uiwait)
uiwait(fighandle);
%% Helper functions
% Get a value from the values array:
% - if it does not exist, return the default value
% - if it exists, assign it and delete the appropriate entries from the
% data arrays
    function val = getValue(default, tag)
        index = strcmpi(fields, tag);
        if any(index)
            val = values{index};
            values(index) = [];
            fields(index) = [];
            tags(index) = [];
        else
            val = default;
        end
    end
%% callback functions
% Enable/disable controls associated with (some) checkboxes
% Toggles every control after checkbox [which], up to the next separator
% when there is more than one separator, otherwise up to the last control.
    function EnableDisable(which, varargin) %#ok<VANUS>
        % find proper range of controls to switch
        if (num_separators > 1)
            last_control = separators(find(separators > which,1)) - 1;
            if isempty(last_control); last_control = numel(controls); end
            range = (which+1):(last_control);
        else
            range = (which+1):numel(controls);
        end
        % enable/disable these controls
        if strcmpi(get(controls(range(1)), 'enable'), 'off')
            set(controls(range), 'enable', 'on')
        else
            set(controls(range), 'enable', 'off')
        end
    end
% OK button:
% - update fields in [settings]
% - assign [button] output argument ('ok')
% - kill window
    function OK(varargin) %#ok<VANUS>
        % button pressed
        button = 'OK';
        % fill settings
        for i = 1:numel(controls)
            % extract current control's string, value & type
            str = get(controls(i), 'string');
            val = get(controls(i), 'value');
            style = get(controls(i), 'style');
            % popups/edits
            if ~strcmpi(style, 'checkbox')
                % extract correct string (popups only)
                if strcmpi(style, 'popupmenu'), str = str{val}; end
                % try to convert string to double
                val = str2double(str);
                % insert this double in [settings]. If it was not a
                % double, insert string instead
                if ~isnan(val), settings.(fields{i}) = val;
                else settings.(fields{i}) = str;
                end
                % checkboxes
            else
                % we can insert value immediately
                settings.(fields{i}) = val;
            end
        end
        % kill window
        delete(fighandle);
    end
% Cancel button:
% - assign [button] output argument ('cancel')
% - delete figure (so: return default settings)
    function Cancel(varargin) %#ok<VANUS>
        button = 'cancel';
        delete(fighandle);
    end
end
================================================
FILE: leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/+mixin/Container.m
================================================
classdef Container < handle
    %uix.mixin.Container Container mixin
    %
    % uix.mixin.Container is a mixin class used by containers to provide
    % various properties and template methods.
    % Copyright 2009-2016 The MathWorks, Inc.
    % $Revision: 1358 $ $Date: 2016-09-14 11:34:17 +0100 (Wed, 14 Sep 2016) $
    properties( Dependent, Access = public )
        Contents % contents in layout order
    end
    properties( Access = public, Dependent, AbortSet )
        Padding % space around contents, in pixels
    end
    properties( Access = protected )
        Contents_ = gobjects( [0 1] ) % backing for Contents
        Padding_ = 0 % backing for Padding
    end
    properties( Dependent, Access = protected )
        Dirty % needs redraw
    end
    properties( Access = private )
        Dirty_ = false % backing for Dirty
        FigureObserver % observer
        FigureListener % listener
        ChildObserver % observer
        ChildAddedListener % listener
        ChildRemovedListener % listener
        SizeChangedListener % listener
        % One entry per element of Contents_, kept in the same order
        % (empty for children that are not axes)
        ActivePositionPropertyListeners = cell( [0 1] ) % listeners
    end
    methods
        function obj = Container()
            %uix.mixin.Container Initialize
            %
            % c@uix.mixin.Container() initializes the container c.
            % Create observers and listeners
            figureObserver = uix.FigureObserver( obj );
            figureListener = event.listener( figureObserver, ...
                'FigureChanged', @obj.onFigureChanged );
            childObserver = uix.ChildObserver( obj );
            childAddedListener = event.listener( ...
                childObserver, 'ChildAdded', @obj.onChildAdded );
            childRemovedListener = event.listener( ...
                childObserver, 'ChildRemoved', @obj.onChildRemoved );
            sizeChangedListener = event.listener( ...
                obj, 'SizeChanged', @obj.onSizeChanged );
            % Store observers and listeners
            obj.FigureObserver = figureObserver;
            obj.FigureListener = figureListener;
            obj.ChildObserver = childObserver;
            obj.ChildAddedListener = childAddedListener;
            obj.ChildRemovedListener = childRemovedListener;
            obj.SizeChangedListener = sizeChangedListener;
            % Track usage
            obj.track()
        end % constructor
    end % structors
    methods
        function value = get.Contents( obj )
            value = obj.Contents_;
        end % get.Contents
        function set.Contents( obj, value )
            % Only a reordering of the existing contents is allowed; the
            % permutation is applied through the reorder template method.
            % For those who can't tell a column from a row...
            if isrow( value )
                value = transpose( value );
            end
            % Check
            [tf, indices] = ismember( value, obj.Contents_ );
            assert( isequal( size( obj.Contents_ ), size( value ) ) && ...
                numel( value ) == numel( unique( value ) ) && all( tf ), ...
                'uix:InvalidOperation', ...
                'Property ''Contents'' may only be set to a permutation of itself.' )
            % Call reorder
            obj.reorder( indices )
        end % set.Contents
        function value = get.Padding( obj )
            value = obj.Padding_;
        end % get.Padding
        function set.Padding( obj, value )
            % Check
            assert( isa( value, 'double' ) && isscalar( value ) && ...
                isreal( value ) && ~isinf( value ) && ...
                ~isnan( value ) && value >= 0, ...
                'uix:InvalidPropertyValue', ...
                'Property ''Padding'' must be a non-negative scalar.' )
            % Set
            obj.Padding_ = value;
            % Mark as dirty
            obj.Dirty = true;
        end % set.Padding
        function value = get.Dirty( obj )
            value = obj.Dirty_;
        end % get.Dirty
        function set.Dirty( obj, value )
            % Setting true redraws immediately when drawable, otherwise it
            % records a pending redraw in Dirty_ (picked up by
            % onFigureChanged). Setting false has no effect.
            if value
                if obj.isDrawable() % drawable
                    obj.redraw() % redraw now
                else % not drawable
                    obj.Dirty_ = true; % flag for future redraw
                end
            end
        end % set.Dirty
    end % accessors
    methods( Access = private, Sealed )
        function onFigureChanged( obj, ~, eventData )
            %onFigureChanged Event handler
            % Call template method
            obj.reparent( eventData.OldFigure, eventData.NewFigure )
            % Redraw if possible and if dirty
            % (performs a redraw that was deferred while not drawable)
            if obj.Dirty_ && obj.isDrawable()
                obj.redraw()
                obj.Dirty_ = false;
            end
        end % onFigureChanged
        function onChildAdded( obj, ~, eventData )
            %onChildAdded Event handler
            % Call template method
            obj.addChild( eventData.Child )
        end % onChildAdded
        function onChildRemoved( obj, ~, eventData )
            %onChildRemoved Event handler
            % Do nothing if container is being deleted
            if strcmp( obj.BeingDeleted, 'on' ), return, end
            % Call template method
            obj.removeChild( eventData.Child )
        end % onChildRemoved
        function onSizeChanged( obj, ~, ~ )
            %onSizeChanged Event handler
            % Mark as dirty
            obj.Dirty = true;
        end % onSizeChanged
        function onActivePositionPropertyChanged( obj, ~, ~ )
            %onActivePositionPropertyChanged Event handler
            % Mark as dirty
            obj.Dirty = true;
        end % onActivePositionPropertyChanged
    end % event handlers
    methods( Abstract, Access = protected )
        redraw( obj )
    end % abstract template methods
    methods( Access = protected )
        function addChild( obj, child )
            %addChild Add child
            %
            % c.addChild(d) adds the child d to the container c.
            % Add to contents
            obj.Contents_(end+1,:) = child;
            % Add listeners
            % (axes get an ActivePositionProperty listener; every child gets
            % an entry so the listener list stays aligned with Contents_)
            if isa( child, 'matlab.graphics.axis.Axes' )
                obj.ActivePositionPropertyListeners{end+1,:} = ...
                    event.proplistener( child, ...
                    findprop( child, 'ActivePositionProperty' ), ...
                    'PostSet', @obj.onActivePositionPropertyChanged );
            else
                obj.ActivePositionPropertyListeners{end+1,:} = [];
            end
            % Mark as dirty
            obj.Dirty = true;
        end % addChild
        function removeChild( obj, child )
            %removeChild Remove child
            %
            % c.removeChild(d) removes the child d from the container c.
            % Remove from contents
            contents = obj.Contents_;
            tf = contents == child;
            obj.Contents_(tf,:) = [];
            % Remove listeners
            % (same logical mask, keeping listeners aligned with Contents_)
            obj.ActivePositionPropertyListeners(tf,:) = [];
            % Mark as dirty
            obj.Dirty = true;
        end % removeChild
        function reparent( obj, oldFigure, newFigure ) %#ok<INUSD>
            %reparent Reparent container
            %
            % c.reparent(a,b) reparents the container c from the figure a
            % to the figure b.
            % Default implementation does nothing; subclasses may override.
        end % reparent
        function reorder( obj, indices )
            %reorder Reorder contents
            %
            % c.reorder(i) reorders the container contents using indices
            % i, c.Contents = c.Contents(i).
            % Reorder contents
            obj.Contents_ = obj.Contents_(indices,:);
            % Reorder listeners
            obj.ActivePositionPropertyListeners = ...
                obj.ActivePositionPropertyListeners(indices,:);
            % Mark as dirty
            obj.Dirty = true;
        end % reorder
        function tf = isDrawable( obj )
            %isDrawable Test for drawability
            %
            % c.isDrawable() is true if the container c is drawable, and
            % false otherwise. To be drawable, a container must be
            % rooted.
            tf = ~isempty( obj.FigureObserver.Figure );
        end % isDrawable
        function track( obj )
            %track Track usage
            % Runs at most once per MATLAB session (persistent flag);
            % failures in uix.tracking are deliberately ignored.
            persistent TRACKED % single shot
            if isempty( TRACKED )
                v = ver( 'layout' );
                try %#ok<TRYNC>
                    uix.tracking( 'UA-82270656-2', v(1).Version, class( obj ) )
                end
                TRACKED = true;
            end
        end % track
    end % template methods
end % classdef
================================================
FILE: leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/+mixin/Flex.m
================================================
classdef Flex < handle
    %uix.mixin.Flex Flex mixin
    %
    % uix.mixin.Flex is a mixin class used by flex containers to provide
    % various properties and helper methods.
    % Copyright 2016 The MathWorks, Inc.
    % $Revision: 1435 $ $Date: 2016-11-17 17:50:34 +0000 (Thu, 17 Nov 2016) $
    properties( GetAccess = protected, SetAccess = private )
        Pointer = 'unset' % mouse pointer
    end
    properties( Access = private )
        Figure = gobjects( 0 ); % mouse pointer figure
        % Token returned by uix.PointerManager.setPointer; -1 when no
        % pointer is currently set
        Token = -1 % mouse pointer token
    end
    methods
        function delete( obj )
            %delete Destructor
            % Clean up
            % Restores any pointer this object set before it goes away
            if ~strcmp( obj.Pointer, 'unset' )
                obj.unsetPointer()
            end
        end % destructor
    end % structors
    methods( Access = protected )
        function setPointer( obj, figure, pointer )
            %setPointer Set pointer
            %
            % c.setPointer(f,p) sets the pointer for the figure f to p.
            % If set, unset
            % (release the previous token first so only one is held at a time)
            if obj.Token ~= -1
                obj.unsetPointer()
            end
            % Set
            obj.Token = uix.PointerManager.setPointer( figure, pointer );
            obj.Figure = figure;
            obj.Pointer = pointer;
        end % setPointer
        function unsetPointer( obj )
            %unsetPointer Unset pointer
            %
            % c.unsetPointer() undoes the previous pointer set.
            % Check
            assert( obj.Token ~= -1, 'uix:InvalidOperation', ...
                'Pointer is already unset.' )
            % Unset
            uix.PointerManager.unsetPointer( obj.Figure, obj.Token );
            obj.Figure = gobjects( 0 );
            obj.Pointer = 'unset';
            obj.Token = -1;
        end % unsetPointer
    end % helper methods
end % classdef
================================================
FILE: leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/+mixin/Panel.m
================================================
classdef Panel < uix.mixin.Container
%uix.mixin.Panel Panel mixin
%
% uix.mixin.Panel is a mixin class used by panels to provide various
% properties and template methods.
% Copyright 2009-2015 The MathWorks, Inc.
% $Revision: 1435 $ $Date: 2016-11-17 17:50:34 +0000 (Thu, 17 Nov 2016) $
properties( Access = public, Dependent, AbortSet )
Selection % selected contents
end
properties( Access = protected )
Selection_ = 0 % backing for Selection
end
properties( Access = protected )
G1218142 = false % bug flag
end
events( NotifyAccess = protected )
SelectionChanged % selection changed
end
methods
    function value = get.Selection( obj )
        value = obj.Selection_;
    end % get.Selection
    function set.Selection( obj, value )
        % Validates the index (0 only when there are no children, otherwise
        % 1..numel(Contents_)), then shows the selected child, marks the
        % panel dirty and raises SelectionChanged with the old/new indices.
        % Check
        assert( isa( value, 'double' ), 'uix:InvalidPropertyValue', ...
            'Property ''Selection'' must be of type double.' )
        assert( isequal( size( value ), [1 1] ), ...
            'uix:InvalidPropertyValue', ...
            'Property ''Selection'' must be scalar.' )
        assert( isreal( value ) && rem( value, 1 ) == 0, ...
            'uix:InvalidPropertyValue', ...
            'Property ''Selection'' must be an integer.' )
        n = numel( obj.Contents_ );
        if n == 0
            assert( value == 0, 'uix:InvalidPropertyValue', ...
                'Property ''Selection'' must be 0 for a container with no children.' )
        else
            assert( value >= 1 && value <= n, 'uix:InvalidPropertyValue', ...
                'Property ''Selection'' must be between 1 and the number of children.' )
        end
        % Set
        oldSelection = obj.Selection_;
        newSelection = value;
        obj.Selection_ = newSelection;
        % Show selected child
        obj.showSelection()
        % Mark as dirty
        obj.Dirty = true;
        % Raise event
        notify( obj, 'SelectionChanged', ...
            uix.SelectionData( oldSelection, newSelection ) )
    end % set.Selection
end % accessors
methods( Access = protected )
function addChild( obj, child )
% Check for bug
if verLessThan( 'MATLAB', '8.5' ) && strcmp( child.Visible, 'off' )
obj.G1218142 = true;
end
% Select new content
oldSelection = obj.Selection_;
newSelection = numel( obj.Contents_ ) + 1;
obj.Selection_ = newSelection;
% Call superclass method
addChild@uix.mixin.Container( obj, child )
% Show selected child
obj.showSelection()
% Notify selection change
obj.notify( 'SelectionChanged', ...
uix.SelectionData( oldSelection, newSelection ) )
end % addChild
function removeChild( obj, child )
% Adjust selection if required
contents = obj.Contents_;
index = find( contents == child );
oldSelection = obj.Selection_;
if index < oldSelection
newSelection = oldSelection - 1;
elseif index == oldSelection
newSelection = min( oldSelection, numel( contents ) - 1 );
else % index > oldSelection
newSelection = oldSelection;
end
obj.Selection_ = newSelection;
% Call superclass method
removeChild@uix.mixin.Container( obj, child )
% Show selected child
obj.showSelection()
% Notify selection change
if oldSelection ~= newSelection
obj.notify( 'SelectionChanged', ...
uix.SelectionData( oldSelection, newSelection ) )
end
end % removeChild
function reorder( obj, indices )
%reorder Reorder contents
%
% c.reorder(i) reorders the container contents using indices
% i, c.Contents = c.Contents(i).
% Reorder
selection = obj.Selection_;
if selection ~= 0
obj.Selection_ = find( indices == selection );
end
% Call superclass method
reorder@uix.mixin.Container( obj, indices )
end % reorder
function showSelection( obj )
%showSelection Show selected child, hide the others
%
% c.showSelection() shows the selected child of the container
% c, and hides the others.
% Set positions and visibility
selection = obj.Selection_;
children = obj.Contents_;
for ii = 1:numel( children )
child = children(ii);
if ii == selection
if obj.G1218142
warning( 'uix:G1218142', ...
'Selected child of %s is not visible due to bug G1218142. The child will become visible at the next redraw.', ...
class( obj ) )
obj.G1218142 = false;
else
child.Visible = 'on';
end
if isa( child, 'matlab.graphics.axis.Axes' )
child.ContentsVisible = 'on';
end
else
child.Visible = 'off';
if isa( child, 'matlab.graphics.axis.Axes' )
child.ContentsVisible = 'off';
end
% As a remedy for g1100294, move off-screen too
margin = 1000;
if isa( child, 'matlab.graphics.axis.Axes' ) ...
&& strcmp(child.ActivePositionProperty, 'outerposition' )
child.OuterPosition(1) = -child.OuterPosition(3)-margin;
else
child.Position(1) = -child.Position(3)-margin;
end
end
end
end % showSelection
end % template methods
end % classdef
================================================
FILE: leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/Box.m
================================================
classdef Box < uix.Container & uix.mixin.Container
    %uix.Box Box and grid base class
    %
    % uix.Box is the common base for containers that leave a gap of
    % Spacing pixels between their contents. Assigning Spacing validates
    % the value, stores it, and marks the container dirty so it is laid
    % out again.
    % Copyright 2009-2015 The MathWorks, Inc.
    % $Revision: 1165 $ $Date: 2015-12-06 03:09:17 -0500 (Sun, 06 Dec 2015) $

    properties( Access = public, Dependent, AbortSet )
        Spacing = 0 % space between contents, in pixels
    end

    properties( Access = protected )
        Spacing_ = 0 % stored value behind the Spacing property
    end

    methods

        function value = get.Spacing( obj )
            % Return the stored spacing.
            value = obj.Spacing_;
        end % get.Spacing

        function set.Spacing( obj, value )
            % Accept only a real, finite, non-negative double scalar.
            % isfinite after isreal is equivalent to ~isinf && ~isnan.
            isValidSpacing = isa( value, 'double' ) && isscalar( value ) && ...
                isreal( value ) && isfinite( value ) && value >= 0;
            assert( isValidSpacing, 'uix:InvalidPropertyValue', ...
                'Property ''Spacing'' must be a non-negative scalar.' )

            % Store the new value and request a fresh layout.
            obj.Spacing_ = value;
            obj.Dirty = true;
        end % set.Spacing

    end % accessors

end % classdef
================================================
FILE: leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/ChildEvent.m
================================================
classdef( Sealed = true, Hidden = true ) ChildEvent < event.EventData
    %uix.ChildEvent Event payload that names the affected child
    %
    % e = uix.ChildEvent(c) builds event data whose read-only Child
    % property holds c.
    % Copyright 2009-2015 The MathWorks, Inc.
    % $Revision: 1165 $ $Date: 2015-12-06 03:09:17 -0500 (Sun, 06 Dec 2015) $

    properties( SetAccess = private )
        Child % the child object this event is about
    end

    methods

        function obj = ChildEvent( child )
            %uix.ChildEvent Build the event data
            %
            % e = uix.ChildEvent(c) stores c in e.Child.
            obj.Child = child;
        end % constructor

    end % structors

end % classdef
================================================
FILE: leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/ChildObserver.m
================================================
classdef ( Hidden, Sealed ) ChildObserver < handle
    %uix.ChildObserver Child observer
    %
    % co = uix.ChildObserver(o) creates a child observer for the graphics
    % object o. A child observer raises events when objects are added to
    % and removed from the property Children of o.
    %
    % The observer mirrors the object tree with uix.Node objects.
    % - Objects that can be Contents (see iscontent) raise ChildAdded or
    %   ChildRemoved when they appear or disappear with Internal false, or
    %   when their Internal property flips.
    % - Objects that cannot be Contents are traversed instead, so that
    %   their descendants are observed as well.
    %
    % See also: uix.Node
    % Copyright 2009-2016 The MathWorks, Inc.
    % $Revision: 1436 $ $Date: 2016-11-17 17:53:29 +0000 (Thu, 17 Nov 2016) $

    properties( Access = private )
        Root % root node (uix.Node wrapping the observed object)
    end

    events( NotifyAccess = private )
        ChildAdded % child added
        ChildRemoved % child removed
    end

    methods

        function obj = ChildObserver( oRoot )
            %uix.ChildObserver Child observer
            %
            % co = uix.ChildObserver(o) creates a child observer for the
            % graphics object o. A child observer raises events when
            % objects are added to and removed from the property Children
            % of o.
            %
            % Throws an error with identifier 'uix:InvalidArgument' unless
            % o is a scalar object for which iscontent is true.

            % Check. The size test runs first so that non-scalar or empty
            % arrays fail this assertion cleanly. iscontent applies && to
            % isprop, which yields a scalar only for a scalar o, so calling
            % it first could throw a less helpful error.
            % The identifier uses a colon: 'uix.InvalidArgument' was not a
            % valid message ID, so assert used it as the message text.
            assert( isequal( size( oRoot ), [1 1] ) && ...
                iscontent( oRoot ), 'uix:InvalidArgument', ...
                'Object must be a graphics object.' )

            % Create root node
            nRoot = uix.Node( oRoot );

            % Track children added to the root. Recursive = true allows
            % the callback to run even when it is triggered re-entrantly
            % (see event.listener).
            childAddedListener = event.listener( oRoot, ...
                'ObjectChildAdded', ...
                @(~,e)obj.addChild(nRoot,e.Child) );
            childAddedListener.Recursive = true;
            nRoot.addprop( 'ChildAddedListener' );
            nRoot.ChildAddedListener = childAddedListener;

            % Track children removed from the root
            childRemovedListener = event.listener( oRoot, ...
                'ObjectChildRemoved', ...
                @(~,e)obj.removeChild(nRoot,e.Child) );
            childRemovedListener.Recursive = true;
            nRoot.addprop( 'ChildRemovedListener' );
            nRoot.ChildRemovedListener = childRemovedListener;

            % Add the children that already exist
            oChildren = hgGetTrueChildren( oRoot );
            for ii = 1:numel( oChildren )
                obj.addChild( nRoot, oChildren(ii) )
            end

            % Store properties
            obj.Root = nRoot;
        end % constructor

    end % structors

    methods( Access = private )

        function addChild( obj, nParent, oChild )
            %addChild Add child object to parent node
            %
            % co.addChild(np,oc) adds the child object oc to the parent
            % node np, either as part of construction of the child
            % observer co, or in response to an ObjectChildAdded event on
            % an object of interest to co. This may lead to ChildAdded
            % events being raised on co.

            % Create child node
            nChild = uix.Node( oChild );
            nParent.addChild( nChild )
            if iscontent( oChild )
                % Contents: watch Internal, whose changes act as add/remove
                % Add Internal PreSet property listener
                internalPreSetListener = event.proplistener( oChild, ...
                    findprop( oChild, 'Internal' ), 'PreSet', ...
                    @(~,~)obj.preSetInternal(nChild) );
                nChild.addprop( 'InternalPreSetListener' );
                nChild.InternalPreSetListener = internalPreSetListener;
                % Add Internal PostSet property listener
                internalPostSetListener = event.proplistener( oChild, ...
                    findprop( oChild, 'Internal' ), 'PostSet', ...
                    @(~,~)obj.postSetInternal(nChild) );
                nChild.addprop( 'InternalPostSetListener' );
                nChild.InternalPostSetListener = internalPostSetListener;
            else
                % Non-contents: watch its children instead
                % Add ObjectChildAdded listener
                childAddedListener = event.listener( oChild, ...
                    'ObjectChildAdded', ...
                    @(~,e)obj.addChild(nChild,e.Child) );
                nChild.addprop( 'ChildAddedListener' );
                nChild.ChildAddedListener = childAddedListener;
                % Add ObjectChildRemoved listener
                childRemovedListener = event.listener( oChild, ...
                    'ObjectChildRemoved', ...
                    @(~,e)obj.removeChild(nChild,e.Child) );
                nChild.addprop( 'ChildRemovedListener' );
                nChild.ChildRemovedListener = childRemovedListener;
            end

            % Raise ChildAdded event
            if iscontent( oChild ) && oChild.Internal == false
                notify( obj, 'ChildAdded', uix.ChildEvent( oChild ) )
            end

            % Add grandchildren (only below non-contents)
            if ~iscontent( oChild )
                oGrandchildren = hgGetTrueChildren( oChild );
                for ii = 1:numel( oGrandchildren )
                    obj.addChild( nChild, oGrandchildren(ii) )
                end
            end
        end % addChild

        function removeChild( obj, nParent, oChild )
            %removeChild Remove child object from parent node
            %
            % co.removeChild(np,oc) removes the child object oc from the
            % parent node np, in response to an ObjectChildRemoved event
            % on an object of interest to co. This may lead to
            % ChildRemoved events being raised on co.

            % Get child node (the node of np whose Object is oc)
            nChildren = nParent.Children;
            tf = oChild == [nChildren.Object];
            nChild = nChildren(tf);

            % Raise ChildRemoved event(s), descendants first
            notifyChildRemoved( nChild )

            % Delete child node
            delete( nChild )

            function notifyChildRemoved( nc )
                % Process child nodes
                ngc = nc.Children;
                for ii = 1:numel( ngc )
                    notifyChildRemoved( ngc(ii) )
                end
                % Process this node
                oc = nc.Object;
                if iscontent( oc ) && oc.Internal == false
                    notify( obj, 'ChildRemoved', uix.ChildEvent( oc ) )
                end
            end % notifyChildRemoved

        end % removeChild

        function preSetInternal( ~, nChild )
            %preSetInternal Perform property PreSet tasks
            %
            % co.preSetInternal(n) caches the previous value of the
            % property Internal of the object referenced by the node n, to
            % enable PostSet tasks to identify whether the value changed.
            % This is necessary since Internal AbortSet is false.
            oldInternal = nChild.Object.Internal;
            nChild.addprop( 'OldInternal' );
            nChild.OldInternal = oldInternal;
        end % preSetInternal

        function postSetInternal( obj, nChild )
            %postSetInternal Perform property PostSet tasks
            %
            % co.postSetInternal(n) raises a ChildAdded or ChildRemoved
            % event on the child observer co in response to a change of
            % the value of the property Internal of the object referenced
            % by the node n.

            % Retrieve old and new values
            oChild = nChild.Object;
            newInternal = oChild.Internal;
            oldInternal = nChild.OldInternal;

            % Clean up node (remove the dynamic property preSetInternal added)
            delete( findprop( nChild, 'OldInternal' ) )

            % Raise event
            switch newInternal
                case oldInternal % no change
                    % no event
                case true % false to true
                    notify( obj, 'ChildRemoved', uix.ChildEvent( oChild ) )
                case false % true to false
                    notify( obj, 'ChildAdded', uix.ChildEvent( oChild ) )
            end
        end % postSetInternal

    end % event handlers

end % classdef
function tf = iscontent( o )
%iscontent True for graphics that can be Contents (and can be Children)
%
% tf = iscontent(o) is true when o is an instance of
% matlab.graphics.internal.GraphicsBaseFunctions and also has a Position
% property.
%
% Background: uix.ChildObserver has to decide which objects can be
% Contents. That is equivalent to asking which objects can be Children
% when HandleVisibility is 'on' and Internal is false. Before R2016a,
% isgraphics answered this. From R2016a, isgraphics also accepts objects
% that can never be Contents (JavaCanvas, for example). This function
% therefore repeats the class test that isgraphics used before R2016a,
% and in addition requires a Position property.
isGraphicsBase = isa( o, 'matlab.graphics.internal.GraphicsBaseFunctions' );
tf = isGraphicsBase && isprop( o, 'Position' );
end % iscontent
================================================
FILE: leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/Container.m
================================================
classdef Container < matlab.ui.container.internal.UIContainer
%uix.Container Container base class
%
% uix.Container is the base class for containers that extend uicontainer.
% It declares no properties, events or methods of its own. Presumably it
% exists so that uix containers share a common uicontainer-derived
% superclass; uix.Box, for example, derives from it.
% Copyright 2009-2015 The MathWorks, Inc.
% $Revision: 1137 $ $Date: 2015-05-29 21:48:21 +0100 (Fri, 29 May 2015) $
end % classdef
================================================
FILE: leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/Divider.m
================================================
classdef Divider < matlab.mixin.SetGet
%uix.Divider Draggable divider
%
% d = uix.Divider() creates a divider.
%
% d = uix.Divider(p1,v1,p2,v2,...) creates a divider and sets
% specified property p1 to value v1, etc.
% Copyright 2009-2016 The MathWorks, Inc.
% $Revision: 1436 $ $Date: 2016-11-17 17:53:29 +0000 (Thu, 17 Nov 2016) $
properties( Dependent )
Parent % parent
Units % units [inches|centimeters|characters|normalized|points|pixels]
Position % position
Visible % visible [on|off]
BackgroundColor % background color [RGB]
HighlightColor % border highlight color [RGB]
ShadowColor % border shadow color [RGB]
Orientation % orientation [vertical|horizontal]
Markings % markings [pixels]
end
properties( Access = private )
Control % uicontrol
BackgroundColor_ = get( 0, 'DefaultUicontrolBackgroundColor' ) % backing for BackgroundColor
HighlightColor_ = [1 1 1] % backing for HighlightColor
ShadowColor_ = [0.7 0.7 0.7] % backing for ShadowColor
Orientation_ = 'vertical' % backing for Orientation
Markings_ = zeros( [0 1] ) % backing for Markings
SizeChangedListener % listener
end
methods
function obj = Divider( varargin )
%uix.Divider Draggable divider
gitextract_ix4wtaz2/ ├── .gitignore ├── LICENSE ├── analysis/ │ └── gait_analysis/ │ ├── Cluster_Velocity_Distributions.mat │ ├── GaitVectors3.mat │ ├── Gait_Densities.mat │ ├── Gait_Speed_Distributions.mat │ ├── Swing_Velocity_Over_Time.mat │ ├── Swing_and_Stance_versus_Velocity.mat │ ├── TetrapodExample.mat │ ├── TripodExample.mat │ ├── compute_gait_densities.m │ ├── gait_analysis_computation.m │ ├── gait_analysis_plotting.m │ └── plot_gait_densities.m ├── data/ │ └── readme.md ├── examples/ │ ├── batch_process_video.ipynb │ ├── hdf5tovid.m │ └── vidtohdf5.m ├── install_leap.m ├── leap/ │ ├── __init__.py │ ├── compute_errors.m │ ├── confmaps2pts.m │ ├── generate_training_set.m │ ├── graph2paf.m │ ├── guis/ │ │ ├── cluster_sample.mlapp │ │ ├── create_skeleton.mlapp │ │ └── label_joints.m │ ├── hpc/ │ │ └── python_gpu.sh │ ├── image_augmentation.py │ ├── layers.py │ ├── models.py │ ├── plot_joints_single.m │ ├── predict_box.m │ ├── predict_box.py │ ├── pts2confmaps.m │ ├── test_leap.m │ ├── toolbox/ │ │ ├── aliases/ │ │ │ ├── alims.m │ │ │ ├── ff.m │ │ │ ├── h5file.m │ │ │ ├── imgsc.m │ │ │ └── repext.m │ │ ├── graphics/ │ │ │ ├── FEX-settingsdlg/ │ │ │ │ └── settingsdlg.m │ │ │ ├── GUI Layout Toolbox/ │ │ │ │ └── layout/ │ │ │ │ └── +uix/ │ │ │ │ ├── +mixin/ │ │ │ │ │ ├── Container.m │ │ │ │ │ ├── Flex.m │ │ │ │ │ └── Panel.m │ │ │ │ ├── Box.m │ │ │ │ ├── ChildEvent.m │ │ │ │ ├── ChildObserver.m │ │ │ │ ├── Container.m │ │ │ │ ├── Divider.m │ │ │ │ ├── FigureData.m │ │ │ │ ├── FigureObserver.m │ │ │ │ ├── Grid.m │ │ │ │ ├── GridFlex.m │ │ │ │ ├── HBox.m │ │ │ │ ├── Node.m │ │ │ │ ├── Panel.m │ │ │ │ ├── PointerManager.m │ │ │ │ ├── SelectionData.m │ │ │ │ ├── VBox.m │ │ │ │ ├── calcPixelSizes.m │ │ │ │ ├── setPosition.m │ │ │ │ └── tracking.m │ │ │ ├── distributionPlot/ │ │ │ │ └── colorCode2rgb.m │ │ │ ├── draggable/ │ │ │ │ └── draggable.m │ │ │ ├── figclosekey.m │ │ │ ├── figsize.m │ │ │ ├── fontsize.m │ │ │ ├── hline.m │ │ │ ├── isax.m │ │ │ ├── isfig.m │ │ │ 
├── noticks.m │ │ │ ├── pareto2.m │ │ │ ├── plotExplainedVar.m │ │ │ ├── plotpts.m │ │ │ ├── redblue.m │ │ │ ├── sc/ │ │ │ │ ├── gray.m │ │ │ │ ├── private/ │ │ │ │ │ ├── colormap_helper.m │ │ │ │ │ └── rescale.m │ │ │ │ └── real2rgb.m │ │ │ └── shortticks.m │ │ ├── hdf5/ │ │ │ ├── h5att2struct.m │ │ │ ├── h5getdatasets.m │ │ │ ├── h5readframes.m │ │ │ ├── h5readgroup.m │ │ │ ├── h5save.m │ │ │ ├── h5savegroup.m │ │ │ ├── h5size.m │ │ │ ├── h5struct2att.m │ │ │ └── hdf5prop/ │ │ │ ├── h5datacreate.m │ │ │ └── hdf5prop.m │ │ ├── imageproc/ │ │ │ └── ind2im.m │ │ ├── inputParsing/ │ │ │ ├── get_caller_name.m │ │ │ ├── nameval2struct.m │ │ │ ├── parse_params.m │ │ │ └── struct2nameval.m │ │ ├── io/ │ │ │ ├── GetFullPath.m │ │ │ ├── dir_ext.m │ │ │ ├── dir_paths.m │ │ │ ├── dir_regex.m │ │ │ ├── exists.m │ │ │ ├── ext2filter_spec.m │ │ │ ├── extrep.m │ │ │ ├── funpath.m │ │ │ ├── get_ext.m │ │ │ ├── get_filename.m │ │ │ ├── get_filesize.m │ │ │ ├── get_new_filename.m │ │ │ ├── lastdir.m │ │ │ ├── mkdirto.m │ │ │ └── uibrowse.m │ │ ├── ml/ │ │ │ └── ezpca.m │ │ ├── strings/ │ │ │ ├── bytes2str.m │ │ │ ├── instr.m │ │ │ ├── printf.m │ │ │ ├── secs2hms.m │ │ │ └── secsf.m │ │ ├── utilities/ │ │ │ ├── af.m │ │ │ ├── arange.m │ │ │ ├── areempty.m │ │ │ ├── argmin.m │ │ │ ├── arr2cell.m │ │ │ ├── cell1.m │ │ │ ├── cellcat.m │ │ │ ├── cf.m │ │ │ ├── clip.m │ │ │ ├── functional_programming/ │ │ │ │ └── wrap.m │ │ │ ├── get_new_string.m │ │ │ ├── grp2cell.m │ │ │ ├── horz.m │ │ │ ├── iseven.m │ │ │ ├── loadvar.m │ │ │ ├── nunique.m │ │ │ ├── rownorm.m │ │ │ ├── stacks/ │ │ │ │ ├── imtile.m │ │ │ │ ├── stack2cell.m │ │ │ │ ├── stack2vecs.m │ │ │ │ └── vecs2stack.m │ │ │ ├── swap.m │ │ │ ├── time/ │ │ │ │ ├── GetSystemTimePreciseAsFileTime.m │ │ │ │ ├── GetSystemTimePreciseAsFileTime.mexw64 │ │ │ │ ├── stic.m │ │ │ │ ├── stoc.m │ │ │ │ ├── stocf.m │ │ │ │ └── systime.m │ │ │ ├── varsize.m │ │ │ ├── varstruct.m │ │ │ ├── vert.m │ │ │ └── vplay.m │ │ └── video/ │ │ ├── 
validate_stack.m │ │ └── vplayer.m │ ├── training.py │ ├── utils.py │ └── viz.py ├── readme.md ├── setup.py └── uninstall_leap.m
SYMBOL INDEX (45 symbols across 7 files)
FILE: leap/image_augmentation.py
function transform_imgs (line 6) | def transform_imgs(X, theta=(-180,180), scale=1.0):
class PairedImageAugmenter (line 47) | class PairedImageAugmenter(Sequence):
method __init__ (line 48) | def __init__(self, X, Y, batch_size=32, shuffle=False, theta=(-180,180...
method __len__ (line 62) | def __len__(self):
method __getitem__ (line 65) | def __getitem__(self, batch_idx):
class MultiInputOutputPairedImageAugmenter (line 75) | class MultiInputOutputPairedImageAugmenter(PairedImageAugmenter):
method __init__ (line 76) | def __init__(self, input_names, output_names, *args, **kwargs):
method __getitem__ (line 85) | def __getitem__(self, batch_idx):
FILE: leap/layers.py
function resize_images (line 43) | def resize_images(x, height_factor, width_factor, interpolation, data_fo...
class UpSampling2D (line 86) | class UpSampling2D(Layer):
method __init__ (line 120) | def __init__(self, size=(2, 2), data_format=None, interpolation='neare...
method compute_output_shape (line 132) | def compute_output_shape(self, input_shape):
method call (line 148) | def call(self, inputs):
method get_config (line 152) | def get_config(self):
function _find_maxima (line 159) | def _find_maxima(x):
function find_maxima (line 180) | def find_maxima(x, data_format):
class Maxima2D (line 202) | class Maxima2D(Layer):
method __init__ (line 232) | def __init__(self, data_format=None, **kwargs):
method compute_output_shape (line 241) | def compute_output_shape(self, input_shape):
method call (line 251) | def call(self, inputs):
method get_config (line 254) | def get_config(self):
function residual_bottleneck_module (line 260) | def residual_bottleneck_module(x_in, output_filters=32, bottleneck_facto...
FILE: leap/models.py
function leap_cnn (line 9) | def leap_cnn(img_size, output_channels, filters=64, upsampling_layers=Fa...
function hourglass (line 59) | def hourglass(img_size, output_channels, filters=64, upsampling_layers=F...
function stacked_hourglass (line 131) | def stacked_hourglass(img_size, output_channels, filters=64, upsampling_...
FILE: leap/predict_box.py
function tf_find_peaks (line 15) | def tf_find_peaks(x):
function convert_to_peak_outputs (line 48) | def convert_to_peak_outputs(model, include_confmaps=False):
function predict_box (line 62) | def predict_box(box_path, model_path, out_path, *, box_dset="/box", epoc...
FILE: leap/training.py
function train_val_split (line 19) | def train_val_split(X, Y, val_size=0.15, shuffle=True):
function create_run_folders (line 35) | def create_run_folders(run_name, base_path="models", clean=False):
class LossHistory (line 65) | class LossHistory(keras.callbacks.Callback):
method __init__ (line 66) | def __init__(self, run_path):
method on_train_begin (line 70) | def on_train_begin(self, logs={}):
method on_epoch_end (line 73) | def on_epoch_end(self, batch, logs={}):
function create_model (line 85) | def create_model(net_name, img_size, output_channels, **kwargs):
function train (line 99) | def train(data_path, *,
FILE: leap/utils.py
function versions (line 7) | def versions(list_devices=False):
function find_weights (line 25) | def find_weights(model_path):
function find_best_weights (line 38) | def find_best_weights(model_path):
function load_dataset (line 48) | def load_dataset(data_path, X_dset="box", Y_dset="confmaps", permute=(0,...
function preprocess (line 66) | def preprocess(X, permute=(0,3,2,1)):
FILE: leap/viz.py
function show_pred (line 6) | def show_pred(net, X, Y, joint_idx=0, alpha_pred=0.7, save_path=None, sh...
function gallery (line 78) | def gallery(array, ncols=4):
function show_confmap_grid (line 94) | def show_confmap_grid(net, X, Y, plot=True, save_path=None, show_figure=...
function plot_history (line 133) | def plot_history(history, save_path=None, show_figure=False):
Condensed preview — 157 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (475K chars).
[
{
"path": ".gitignore",
"chars": 1294,
"preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
},
{
"path": "LICENSE",
"chars": 11357,
"preview": " Apache License\n Version 2.0, January 2004\n "
},
{
"path": "analysis/gait_analysis/compute_gait_densities.m",
"chars": 1286,
"preview": "%% Pathing\nembed_path = 'Z:\\code\\2018-05-05_joints_tsne_FlyAging_talmo-labels\\results\\FlyAging-DiegoCNN_v1.0_filters=64_"
},
{
"path": "analysis/gait_analysis/gait_analysis_computation.m",
"chars": 11182,
"preview": "clear all;\n%% Pathing\n% addpath(genpath('deps'));\njoints_dir = 'Z:\\data\\JointTracker\\2018-02_FlyAging_boxes\\expts\\preds\\"
},
{
"path": "analysis/gait_analysis/gait_analysis_plotting.m",
"chars": 4418,
"preview": "clear all;\n%% Look at particular section\n% Pick the gait example to observe\n% load('TetrapodExample');\nload('TripodExamp"
},
{
"path": "analysis/gait_analysis/plot_gait_densities.m",
"chars": 1109,
"preview": "clear all;\n%% Plot all three densities overlayed ontop of one another\nload('Gait_Densities');\nfigure; hold on;\nh = image"
},
{
"path": "data/readme.md",
"chars": 1142,
"preview": "The full fly dataset and all trained networks used in our paper can be downloaded from: [`http://arks.princeton.edu/ark:"
},
{
"path": "examples/batch_process_video.ipynb",
"chars": 12990,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"metadata\": {},\n \"source\": [\n \"# Example: Predict body part positi"
},
{
"path": "examples/hdf5tovid.m",
"chars": 1763,
"preview": "% Clean start!\nclear all, clc\n\n%% Parameters\n% Path to input file\ndataPath = '../data/examples/072212_163153.clip.h5';\n\n"
},
{
"path": "examples/vidtohdf5.m",
"chars": 2042,
"preview": "% Clean start!\nclear all, clc\n\n%% Parameters\n% Path to input file\nvideoPath = '..\\..\\leap\\data\\examples\\072212_163153.cl"
},
{
"path": "install_leap.m",
"chars": 540,
"preview": "function install_leap()\n%INSTALL_LEAP Installs the Python package and adds MATLAB scripts to path.\n% Usage:\n% install_"
},
{
"path": "leap/__init__.py",
"chars": 161,
"preview": "from . import image_augmentation\nfrom . import layers\nfrom . import models\nfrom . import predict_box\nfrom . import train"
},
{
"path": "leap/compute_errors.m",
"chars": 1086,
"preview": "function err = compute_errors(pos_pred, pos_gt)\n%COMPUTE_ERRORS Computes error rates given predicted and ground truth po"
},
{
"path": "leap/confmaps2pts.m",
"chars": 432,
"preview": "function [pts, confvals] = confmaps2pts(C)\n%CONFMAPS2PTS Convert a set of confidence maps into a set of points.\n% Usage:"
},
{
"path": "leap/generate_training_set.m",
"chars": 7827,
"preview": "function savePath = generate_training_set(boxPath, varargin)\n%GENERATE_TRAINING_SET Creates a dataset for training.\n% Us"
},
{
"path": "leap/graph2paf.m",
"chars": 1558,
"preview": "function paf = graph2paf(nodes, edges, sz, channelsOnly, sigma)\n%GRAPH2PAF Converts a set of edges into part affinity fi"
},
{
"path": "leap/guis/label_joints.m",
"chars": 45151,
"preview": "function label_joints(boxPath, skeletonPath)\n%LABEL_JOINTS GUI to click on images to yield a graph.\n% Usage:\n% label_j"
},
{
"path": "leap/hpc/python_gpu.sh",
"chars": 210,
"preview": "#!/bin/bash\n#SBATCH --time=4:00:00\n#SBATCH --mem=128000\n#SBATCH -N 1\n#SBATCH --cpus-per-task=4\n#SBATCH --ntasks-per-node"
},
{
"path": "leap/image_augmentation.py",
"chars": 2795,
"preview": "import cv2\nimport numpy as np\nimport keras\nfrom keras.utils import Sequence\n\ndef transform_imgs(X, theta=(-180,180), sca"
},
{
"path": "leap/layers.py",
"chars": 11274,
"preview": "# -*- coding: utf-8 -*-\n\"\"\"\n Copyright 2018 Jacob M. Graving <jgraving@gmail.com>\n\n Licensed under the Apache Licens"
},
{
"path": "leap/models.py",
"chars": 12299,
"preview": "import keras\nfrom keras import backend as K\nfrom keras.models import Model\nfrom keras.layers import Input, Conv2D, Conv2"
},
{
"path": "leap/plot_joints_single.m",
"chars": 842,
"preview": "function h = plot_joints_single(pts, segments, markerSize, lineWidth)\n%PLOT_JOINTS_SINGLE Plot joints on a single frame."
},
{
"path": "leap/predict_box.m",
"chars": 2675,
"preview": "function preds = predict_box(box, modelPath, saveConfmaps)\n%PREDICT_BOX Evaluates model predictions on a stack of frames"
},
{
"path": "leap/predict_box.py",
"chars": 6978,
"preview": "import h5py\nimport numpy as np\nimport os\nfrom time import time\nimport keras\nimport keras.models\nfrom keras.layers import"
},
{
"path": "leap/pts2confmaps.m",
"chars": 1073,
"preview": "function confmaps = pts2confmaps(pts, sz, sigma, normalize)\n%PTS2CONFMAPS Generate confidence maps centered at specified"
},
{
"path": "leap/test_leap.m",
"chars": 854,
"preview": "function works = test_leap()\n%TEST_LEAP Checks whether LEAP is properly installed.\n% Usage:\n% test_leap\n% works = te"
},
{
"path": "leap/toolbox/aliases/alims.m",
"chars": 195,
"preview": "function varargout = alims(X)\n%ALIMS Alias for arange.\n% Usage:\n% R = alims(X)\n% [min_val, max_val] = alims(X)\n%\n% S"
},
{
"path": "leap/toolbox/aliases/ff.m",
"chars": 103,
"preview": "function varargout = ff(varargin)\n N = max(nargout,1);\n\tvarargout{1:N} = fullfile(varargin{:});\nend\n"
},
{
"path": "leap/toolbox/aliases/h5file.m",
"chars": 89,
"preview": "function varargout = h5file(varargin)\n\tvarargout{1:nargout} = hdf5prop(varargin{:});\nend\n"
},
{
"path": "leap/toolbox/aliases/imgsc.m",
"chars": 360,
"preview": "function h = imgsc(I, varargin)\n%IMGSC Alias for imagesc for images.\n% Usage:\n% imgsc(I)\n% imgsc(I, ...)\n%\n% See als"
},
{
"path": "leap/toolbox/aliases/repext.m",
"chars": 94,
"preview": "function varargout = repext(varargin)\n\tvarargout{1:max(nargout,1)} = extrep(varargin{:});\nend\n"
},
{
"path": "leap/toolbox/graphics/FEX-settingsdlg/settingsdlg.m",
"chars": 18089,
"preview": "function [settings, button] = settingsdlg(varargin)\n% SETTINGSDLG Default dialog to produce a settings-struc"
},
{
"path": "leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/+mixin/Container.m",
"chars": 9443,
"preview": "classdef Container < handle\n %uix.mixin.Container Container mixin\n %\n % uix.mixin.Container is a mixin class "
},
{
"path": "leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/+mixin/Flex.m",
"chars": 2024,
"preview": "classdef Flex < handle\n %uix.mixin.Flex Flex mixin\n %\n % uix.mixin.Flex is a mixin class used by flex contain"
},
{
"path": "leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/+mixin/Panel.m",
"chars": 6663,
"preview": "classdef Panel < uix.mixin.Container\n %uix.mixin.Panel Panel mixin\n %\n % uix.mixin.Panel is a mixin class use"
},
{
"path": "leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/Box.m",
"chars": 1291,
"preview": "classdef Box < uix.Container & uix.mixin.Container\n %uix.Box Box and grid base class\n %\n % uix.Box is a base "
},
{
"path": "leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/ChildEvent.m",
"chars": 778,
"preview": "classdef( Hidden, Sealed ) ChildEvent < event.EventData\n %uix.ChildEvent Event data for child event\n %\n % e ="
},
{
"path": "leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/ChildObserver.m",
"chars": 8831,
"preview": "classdef ( Hidden, Sealed ) ChildObserver < handle\n %uix.ChildObserver Child observer\n %\n % co = uix.ChildObs"
},
{
"path": "leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/Container.m",
"chars": 334,
"preview": "classdef Container < matlab.ui.container.internal.UIContainer\n %uix.Container Container base class\n %\n % uix."
},
{
"path": "leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/Divider.m",
"chars": 10925,
"preview": "classdef Divider < matlab.mixin.SetGet\n %uix.Divider Draggable divider\n %\n % d = uix.Divider() creates a divi"
},
{
"path": "leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/FigureData.m",
"chars": 781,
"preview": "classdef ( Hidden, Sealed ) FigureData < event.EventData\n %uix.FigureData Event data for FigureChanged on uix.Figure"
},
{
"path": "leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/FigureObserver.m",
"chars": 3091,
"preview": "classdef ( Hidden, Sealed ) FigureObserver < handle\n %uix.FigureObserver Figure observer\n %\n % A figure obser"
},
{
"path": "leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/Grid.m",
"chars": 11948,
"preview": "classdef Grid < uix.Box\n %uix.Grid Grid\n %\n % b = uix.Grid(p1,v1,p2,v2,...) constructs a grid and sets parame"
},
{
"path": "leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/GridFlex.m",
"chars": 19877,
"preview": "classdef GridFlex < uix.Grid & uix.mixin.Flex\n %uix.GridFlex Flexible grid\n %\n % b = uix.GridFlex(p1,v1,p2,v2"
},
{
"path": "leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/HBox.m",
"chars": 6649,
"preview": "classdef HBox < uix.Box\n %uix.HBox Horizontal box\n %\n % b = uix.HBox(p1,v1,p2,v2,...) constructs a horizontal"
},
{
"path": "leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/Node.m",
"chars": 2881,
"preview": "classdef ( Hidden ) Node < dynamicprops\n %uix.Node Node\n %\n % n = uix.Node(o) creates a node for the handle o"
},
{
"path": "leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/Panel.m",
"chars": 2095,
"preview": "classdef Panel < matlab.ui.container.Panel & uix.mixin.Panel\n %uix.Panel Standard panel\n %\n % b = uix.Panel(p"
},
{
"path": "leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/PointerManager.m",
"chars": 4743,
"preview": "classdef ( Hidden, Sealed ) PointerManager < handle\n %uix.PointerManager Pointer manager\n \n % Copyright 2016 "
},
{
"path": "leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/SelectionData.m",
"chars": 954,
"preview": "classdef( Hidden, Sealed ) SelectionData < event.EventData\n %uix.SelectionData Event data for selection event\n %\n"
},
{
"path": "leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/VBox.m",
"chars": 6681,
"preview": "classdef VBox < uix.Box\n %uix.VBox Vertical box\n %\n % b = uix.VBox(p1,v1,p2,v2,...) constructs a vertical box"
},
{
"path": "leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/calcPixelSizes.m",
"chars": 1574,
"preview": "function pSizes = calcPixelSizes( pTotal, mSizes, pMinima, pPadding, pSpacing )\n%calcPixelSizes Calculate child sizes i"
},
{
"path": "leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/setPosition.m",
"chars": 877,
"preview": "function setPosition( o, p, u )\n%setPosition Set position of graphics object\n%\n% setPosition(o,p,u) sets the position "
},
{
"path": "leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/tracking.m",
"chars": 5760,
"preview": "function varargout = tracking( varargin )\n%tracking Track anonymized usage data\n%\n% tracking(p,v,id) tracks usage to t"
},
{
"path": "leap/toolbox/graphics/distributionPlot/colorCode2rgb.m",
"chars": 939,
"preview": "function rgbVec = colorCode2rgb(c)\n%COLORCODE2RGB converts a color code to an rgb vector\n%\n\n% SYNOPSIS rgbVec = colorCod"
},
{
"path": "leap/toolbox/graphics/draggable/draggable.m",
"chars": 21955,
"preview": "function draggable(h,varargin)\n% DRAGGABLE - Make it so that a graphics object can be dragged in a figure.\n% This func"
},
{
"path": "leap/toolbox/graphics/figclosekey.m",
"chars": 624,
"preview": "function figclosekey(h, key)\n%FIGCLOSEKEY Add a hotkey for closing the figure.\n% Usage:\n% figclosekey(h, key)\n% \n% Arg"
},
{
"path": "leap/toolbox/graphics/figsize.m",
"chars": 1679,
"preview": "function sz = figsize(h, width, height)\n%FIGSIZE Resizes the specified figure while keeping it on screen.\n% Usage:\n% f"
},
{
"path": "leap/toolbox/graphics/fontsize.m",
"chars": 285,
"preview": "function fontsize(size, h)\n%FONTSIZE Sets the fontsize across a graphics object.\n% Usage:\n% fontsize(size) % default: "
},
{
"path": "leap/toolbox/graphics/hline.m",
"chars": 807,
"preview": "function h = hline(y, varargin)\n%HLINE Easy plotting of a horizonal line.\n% Usage:\n% hline\n% hline(y)\n% hline(ax, "
},
{
"path": "leap/toolbox/graphics/isax.m",
"chars": 187,
"preview": "function TF = isax(h)\n%ISAX Check if input is an axes handle.\n% Usage:\n% TF = isax(h)\n% \n% Args:\n% h: \n% \n% See also"
},
{
"path": "leap/toolbox/graphics/isfig.m",
"chars": 320,
"preview": "function TF = isfig(h)\n%ISFIG Checks whether the handle(s) specified are existing figures.\n% Usage:\n% TF = isfig(h)\n%\n"
},
{
"path": "leap/toolbox/graphics/noticks.m",
"chars": 852,
"preview": "function noticks(ax, whichAxes)\n%NOTICKS Hides ticks in the axes.\n% Usage:\n% noticks\n% noticks('x') % hides only x-t"
},
{
"path": "leap/toolbox/graphics/pareto2.m",
"chars": 598,
"preview": "function [ax, b, p] = pareto2(X, leftYLabel, rightYLabel)\n%PARETO2 Prettier pareto function.\n% Usage:\n% pareto2(X)\n% "
},
{
"path": "leap/toolbox/graphics/plotExplainedVar.m",
"chars": 677,
"preview": "function [ax, b, p] = plotExplainedVar(explained)\n%PLOTEXPLAINEDVAR Pretty plot PCA explained variance.\n% Usage:\n% plo"
},
{
"path": "leap/toolbox/graphics/plotpts.m",
"chars": 482,
"preview": "function h = plotpts(pts, varargin)\n%PLOTPTS Convenience wrapper for plotting scatters. Same syntax as plot().\n% Usage:\n"
},
{
"path": "leap/toolbox/graphics/redblue.m",
"chars": 1044,
"preview": "function map = redblue(N, dark)\n%REDBLUE Red and blue colormap going from blue to white to red.\n% Usage:\n% map = redbl"
},
{
"path": "leap/toolbox/graphics/sc/gray.m",
"chars": 1016,
"preview": "%GRAY Black-white colormap\n%\n% Examples:\n% map = gray;\n% map = gray(len);\n% B = gray(A);\n% B = gray(A, lims);\n%"
},
{
"path": "leap/toolbox/graphics/sc/private/colormap_helper.m",
"chars": 1693,
"preview": "function map = colormap_helper(map, len, lims)\n%COLORMAP_HELPER Helper function for colormaps\n%\n% Examples:\n% map = c"
},
{
"path": "leap/toolbox/graphics/sc/private/rescale.m",
"chars": 1632,
"preview": "function [B, lims] = rescale(A, lims, out_lims)\n%RESCALE Linearly rescale values in an array\n%\n% Examples:\n% B = resc"
},
{
"path": "leap/toolbox/graphics/sc/real2rgb.m",
"chars": 4936,
"preview": "%REAL2RGB Converts a real-valued matrix into a truecolor image\n%\n% Examples:\n% B = real2rgb(A, cmap);\n% B = real2rg"
},
{
"path": "leap/toolbox/graphics/shortticks.m",
"chars": 226,
"preview": "function shortticks(ax)\n%SHORTTICKS Make the axis ticks short.\n% Usage:\n% shortticks\n% shortticks(ax)\n% \n% See also:"
},
{
"path": "leap/toolbox/hdf5/h5att2struct.m",
"chars": 806,
"preview": "function S = h5att2struct(filename, location)\n%H5ATT2STRUCT Reads a set of HDF5 attributes into a named structure.\n% Usa"
},
{
"path": "leap/toolbox/hdf5/h5getdatasets.m",
"chars": 1019,
"preview": "function datasets = h5getdatasets(filepath, grp, recurse)\n%H5GETDATASETS Returns a list of all datasets in an HDF5 file."
},
{
"path": "leap/toolbox/hdf5/h5readframes.m",
"chars": 1954,
"preview": "function [frames, numFrames] = h5readframes(filepath, dataset, idx)\n%H5READFRAMES Reads video frames from an HDF5 file.\n"
},
{
"path": "leap/toolbox/hdf5/h5readgroup.m",
"chars": 1248,
"preview": "function data = h5readgroup(filepath, group)\n%H5READGROUP Reads a group into a structure containing all the data in the "
},
{
"path": "leap/toolbox/hdf5/h5save.m",
"chars": 2428,
"preview": "function h5save(filepath, X, dset, varargin)\n%H5SAVE Create and save a variable to an HDF5 file.\n% Usage:\n% h5save(fil"
},
{
"path": "leap/toolbox/hdf5/h5savegroup.m",
"chars": 1067,
"preview": "function h5savegroup(filepath, S, grp, varargin)\n%H5SAVEGROUP Saves a struct as a HDF5 group.\n% Usage:\n% h5savegroup(f"
},
{
"path": "leap/toolbox/hdf5/h5size.m",
"chars": 561,
"preview": "function [sz, maxSize] = h5size(filepath, dataset, dim)\n%H5SIZE Returns the size of the specified dataset.\n% Usage:\n% "
},
{
"path": "leap/toolbox/hdf5/h5struct2att.m",
"chars": 726,
"preview": "function h5struct2att(filepath, location, S)\n%H5STRUCT2ATT Writes attributes to an HDF5 file from a scalar structure.\n% "
},
{
"path": "leap/toolbox/hdf5/hdf5prop/h5datacreate.m",
"chars": 7610,
"preview": "function datasetId = h5datacreate(h5file,varname,varargin)\n%H5datacreate Create HDF5 dataset.\n%\n% H5DATACREATE(HFILE,"
},
{
"path": "leap/toolbox/hdf5/hdf5prop/hdf5prop.m",
"chars": 15122,
"preview": "classdef hdf5prop < handle\n% HDF5PROP Class for transparent file data access.\n% Matlab class to create and access HDF5 d"
},
{
"path": "leap/toolbox/imageproc/ind2im.m",
"chars": 1307,
"preview": "function I = ind2im(ind, sz, vals, fillval)\n%IND2IM Create image from a set of linear indices.\n% Usage:\n% I = ind2im(i"
},
{
"path": "leap/toolbox/inputParsing/get_caller_name.m",
"chars": 1045,
"preview": "function caller = get_caller_name(varargin)\n%GET_CALLER_NAME Returns the name of the caller function.\n% Usage:\n% calle"
},
{
"path": "leap/toolbox/inputParsing/nameval2struct.m",
"chars": 709,
"preview": "function S = nameval2struct(C)\n%NAMEVAL2STRUCT Converts a cell array of name-value pairs to a struct.\n% Usage:\n% struc"
},
{
"path": "leap/toolbox/inputParsing/parse_params.m",
"chars": 1890,
"preview": "function [results, unmatched] = parse_params(args, defaults, varargin)\n%PARSE_PARAMS Parses a set of name-value pairs.\n%"
},
{
"path": "leap/toolbox/inputParsing/struct2nameval.m",
"chars": 392,
"preview": "function C = struct2nameval(S)\n%STRUCT2NAMEVAL Converts a structure to a cell array of name-value pairs.\n% Usage:\n% na"
},
{
"path": "leap/toolbox/io/GetFullPath.m",
"chars": 11855,
"preview": "function File = GetFullPath(File, Style)\n% GetFullPath - Get absolute canonical path of a file or folder\n% Absolute path"
},
{
"path": "leap/toolbox/io/dir_ext.m",
"chars": 866,
"preview": "function matches = dir_ext(path, extensions, return_paths)\n%DIR_EXT Returns files in a directory with the matching exten"
},
{
"path": "leap/toolbox/io/dir_paths.m",
"chars": 1318,
"preview": "function [paths, base_path] = dir_paths(path, type)\n%DIR_PATHS Returns the full paths of a directory listing.\n% Usage:\n%"
},
{
"path": "leap/toolbox/io/dir_regex.m",
"chars": 678,
"preview": "function matches = dir_regex(path, expression, return_paths)\n%DIR_REGEX Returns contents in the path matching a regular "
},
{
"path": "leap/toolbox/io/exists.m",
"chars": 1497,
"preview": "function TF = exists(path, dir_only)\n%EXISTS Returns true if the specified path exists in the filesystem.\n% Usage:\n% T"
},
{
"path": "leap/toolbox/io/ext2filter_spec.m",
"chars": 563,
"preview": "function filter_spec = ext2filter_spec(exts)\n%EXT2FILTER_SPEC Generates a filter specification from a list of file exten"
},
{
"path": "leap/toolbox/io/extrep.m",
"chars": 469,
"preview": "function new_path = extrep(filepath, new_ext)\n%EXTREP Replace the extension of a file path.\n% Usage:\n% newpath = extre"
},
{
"path": "leap/toolbox/io/funpath.m",
"chars": 402,
"preview": "function path = funpath(~)\n%FUNPATH Returns the path to the calling function.\n% Usage:\n% path = funpath()\n% path = f"
},
{
"path": "leap/toolbox/io/get_ext.m",
"chars": 637,
"preview": "function ext = get_ext(path, no_dot)\n%GET_EXT Returns the extension in a path. The path need not exist.\n% Usage:\n% ext"
},
{
"path": "leap/toolbox/io/get_filename.m",
"chars": 576,
"preview": "function filename = get_filename(path, no_ext)\n%GET_FILENAME Returns the filename in a path. The path need not exist.\n% "
},
{
"path": "leap/toolbox/io/get_filesize.m",
"chars": 387,
"preview": "function bytes = get_filesize(file_path)\n%GET_FILESIZE Returns the size of the specified file in bytes.\n% This is a wrap"
},
{
"path": "leap/toolbox/io/get_new_filename.m",
"chars": 701,
"preview": "function new_filename = get_new_filename(filename, noSpaces)\n%GET_NEW_FILENAME Returns a new filename based on the one s"
},
{
"path": "leap/toolbox/io/lastdir.m",
"chars": 1764,
"preview": "function dir_path = lastdir(new_path)\n%LASTDIR Remembers the last directory used for use in UI browsing dialogs.\n% Usage"
},
{
"path": "leap/toolbox/io/mkdirto.m",
"chars": 414,
"preview": "function TF = mkdirto(path)\n%MKDIRTO Quietly makes all directories to path that do not currently exist.\n% Usage:\n% mkd"
},
{
"path": "leap/toolbox/io/uibrowse.m",
"chars": 2503,
"preview": "function [path, filter_idx] = uibrowse(filter_spec, start_path, dialog_title, type)\n%UIBROWSE Displays a file or folder"
},
{
"path": "leap/toolbox/ml/ezpca.m",
"chars": 972,
"preview": "function pcs = ezpca(X, varargin)\n%EZPCA PCA -- quick and easy!\n% Usage:\n% pcs = ezpca(X)\n% \n% Args:\n% X: N x D data"
},
{
"path": "leap/toolbox/strings/bytes2str.m",
"chars": 659,
"preview": "function [str, x_bytes, unit] = bytes2str(bytes, precision)\n%BYTES2STR Returns the number of bytes in a more readable fo"
},
{
"path": "leap/toolbox/strings/instr.m",
"chars": 2975,
"preview": "function TF = instr(needle, haystack, flags)\n%INSTR Returns true if (any) needle is in (any) haystack.\n% Usage:\n% TF ="
},
{
"path": "leap/toolbox/strings/printf.m",
"chars": 1800,
"preview": "function formatted = printf(str, varargin)\n%PRINTF Prints formatted output.\n% Usage:\n% printf(str, ...)\n% formatted "
},
{
"path": "leap/toolbox/strings/secs2hms.m",
"chars": 293,
"preview": "function [h, m, s] = secs2hms(numSecs)\n%SECS2HMS Converts a number of seconds to hours, minutes and fractional seconds.\n"
},
{
"path": "leap/toolbox/strings/secsf.m",
"chars": 958,
"preview": "function str = secsf(format, numSecs)\n%SECSF Yet another seconds formatting function.\n% Usage:\n% str = secsf(format, n"
},
{
"path": "leap/toolbox/utilities/af.m",
"chars": 256,
"preview": "function out = af(func,varargin)\n%AF Convenience wrapper for arrayfun with non-uniform output.\n% Usage:\n% out = af(fun"
},
{
"path": "leap/toolbox/utilities/arange.m",
"chars": 296,
"preview": "function [min_val, max_val] = arange(X)\n%ARANGE Returns the range (min and max) of an entire array.\n% Usage:\n% R = ara"
},
{
"path": "leap/toolbox/utilities/areempty.m",
"chars": 247,
"preview": "function empties = areempty(cellarr)\n%AREEMPTY Returns a logical array of the size of the cell array indicating whether "
},
{
"path": "leap/toolbox/utilities/argmin.m",
"chars": 255,
"preview": "function idx = argmin(X, dim)\n%ARGMAX Returns the index at which the min is found.\n% Usage:\n% idx = argmin(X)\n% idx "
},
{
"path": "leap/toolbox/utilities/arr2cell.m",
"chars": 384,
"preview": "function C = arr2cell(X, dim)\n%ARR2CELL Splits an array into a cell across the specified dimension.\n% Usage:\n% C = arr"
},
{
"path": "leap/toolbox/utilities/cell1.m",
"chars": 569,
"preview": "function varargout = cell1(sz, dim)\n%CELL1 Creates a 1d empty cell array. Convenience for cell(sz, 1).\n% Usage:\n% C = "
},
{
"path": "leap/toolbox/utilities/cellcat.m",
"chars": 865,
"preview": "function [X,idx] = cellcat(C, dim)\n%CELLCAT Unpacks and concatenates a cell array. Shorthand for: cat(dim, C{:})\n% Usage"
},
{
"path": "leap/toolbox/utilities/cf.m",
"chars": 253,
"preview": "function out = cf(func,varargin)\n%CF Convenience wrapper for cellfun with non-uniform output.\n% Usage:\n% out = cf(func"
},
{
"path": "leap/toolbox/utilities/clip.m",
"chars": 415,
"preview": "function X2 = clip(X, bounds, varargin)\n%CLIP Clip values in an array to lower and upper bounds.\n% Usage:\n% X2 = clip("
},
{
"path": "leap/toolbox/utilities/functional_programming/wrap.m",
"chars": 874,
"preview": "function out = wrap(f, out_indices)\n\n% out = wrap(f, out_indices)\n% \n% If you've ever needed multiple outputs from a fun"
},
{
"path": "leap/toolbox/utilities/get_new_string.m",
"chars": 629,
"preview": "function new_str = get_new_string(str, strings, format)\n%GET_NEW_STRING Returns a new string that does not exist in a pr"
},
{
"path": "leap/toolbox/utilities/grp2cell.m",
"chars": 1206,
"preview": "function [C, grps, G] = grp2cell(X, G, dim)\n%GRP2CELL Splits an array into a cell using a grouping variable.\n% Usage:\n% "
},
{
"path": "leap/toolbox/utilities/horz.m",
"chars": 138,
"preview": "function V = horz(X)\n%HORZ Returns the array as a horizontal vector.\n% Usage:\n% V = horz(X) % size(V) = [1, numel(X)]\n"
},
{
"path": "leap/toolbox/utilities/iseven.m",
"chars": 182,
"preview": "function TF = iseven(X)\n%ISEVEN Returns true if the input is divisible by 2, false otherwise.\n% Usage:\n% TF = iseven(X"
},
{
"path": "leap/toolbox/utilities/loadvar.m",
"chars": 1109,
"preview": "function varargout = loadvar(mat_file, var_name, varargin)\n%LOADVAR Loads one or more variables from a MAT file.\n% Usage"
},
{
"path": "leap/toolbox/utilities/nunique.m",
"chars": 161,
"preview": "function N = nunique(X)\n%NUNIQUE Returns the number of unique elements in an array.\n% Usage:\n% N = nunique(X)\n%\n% See "
},
{
"path": "leap/toolbox/utilities/rownorm.m",
"chars": 336,
"preview": "function n = rownorm(X, p)\n%ROWNORM Returns the row-wise norm of a matrix X.\n% Usage:\n% n = rownorm(X)\n% n = rownorm"
},
{
"path": "leap/toolbox/utilities/stacks/imtile.m",
"chars": 936,
"preview": "function T = imtile(I, varargin)\n%IMTILE Tiles a stack or set of images.\n% Usage:\n% T = imtile(stack)\n% T = imtile(I"
},
{
"path": "leap/toolbox/utilities/stacks/stack2cell.m",
"chars": 244,
"preview": "function C = stack2cell(S)\n%STACK2CELL Returns a cell array with one slice of the stack in each cell.\n% C = stack2cell"
},
{
"path": "leap/toolbox/utilities/stacks/stack2vecs.m",
"chars": 335,
"preview": "function X = stack2vecs(S)\n%STACK2VECS Convert stack to observation-by-features matrix.\n% Usage:\n% X = stack2vecs(S)\n%"
},
{
"path": "leap/toolbox/utilities/stacks/vecs2stack.m",
"chars": 384,
"preview": "function S = vecs2stack(X,sz)\n%VECS2STACK Convert observation-by-features matrix to stack.\n% Usage:\n% X = stack2vecs(S"
},
{
"path": "leap/toolbox/utilities/swap.m",
"chars": 121,
"preview": "function [B, A] = swap(A, B)\n%SWAP Swap two variables without an intermediate copy.\n% Usage:\n% [B, A] = swap(A, B)\n\nen"
},
{
"path": "leap/toolbox/utilities/time/GetSystemTimePreciseAsFileTime.m",
"chars": 557,
"preview": "%GETSYSTEMTIMEPRECISEASFILETIME Returns system time with high precision.\n% Usage:\n% t = GetSystemTimePreciseAsFileTime"
},
{
"path": "leap/toolbox/utilities/time/stic.m",
"chars": 273,
"preview": "function timer_id = stic\n%STIC TIC equivalent using precise system time.\n% Usage:\n% stic\n% timer_id = tic\n%\n% See al"
},
{
"path": "leap/toolbox/utilities/time/stoc.m",
"chars": 709,
"preview": "function [dt, timer_id] = stoc(timer_id)\n%STIC TOC equivalent using precise system time.\n% Usage:\n% stoc\n% stoc(time"
},
{
"path": "leap/toolbox/utilities/time/stocf.m",
"chars": 1044,
"preview": "function [dt, timer_id] = stocf(timer_id, str, varargin)\n%STOCF Report elapsed time with print formatting.\n% Usage:\n% "
},
{
"path": "leap/toolbox/utilities/time/systime.m",
"chars": 488,
"preview": "function secs = systime\n%SYSTIME Returns precise system time in seconds.\n% Usage:\n% secs = systime\n%\n% Note: This is j"
},
{
"path": "leap/toolbox/utilities/varsize.m",
"chars": 833,
"preview": "function size = varsize(X, units)\n%VARSIZE Returns the size of a variable in bytes.\n% Usage:\n% size = varsize(X)\n% s"
},
{
"path": "leap/toolbox/utilities/varstruct.m",
"chars": 2089,
"preview": "function S = varstruct(var1, var2, varargin)\n%VARSTRUCT Creates a structure out of a set of variables. Fieldnames are in"
},
{
"path": "leap/toolbox/utilities/vert.m",
"chars": 141,
"preview": "function V = vert(X)\n%VERT Returns the matrix as a vertical vector.\n% Usage:\n% V = vert(X)\n%\n% See also: mat2vec, resh"
},
{
"path": "leap/toolbox/utilities/vplay.m",
"chars": 328,
"preview": "function vp = vplay(mov, varargin)\n%VPLAY Play a movie using the vplayer class.\n% Usage:\n% vplay(mov)\n% vplay(mov, ."
},
{
"path": "leap/toolbox/video/validate_stack.m",
"chars": 838,
"preview": "function [images, numFrames] = validate_stack(images, no_error)\n%VALIDATE_STACK Validates a stack of images and converts"
},
{
"path": "leap/toolbox/video/vplayer.m",
"chars": 25560,
"preview": "classdef vplayer < matlab.mixin.SetGet\n %VPLAYER Image video/stack player class.\n % Usage:\n % vplayer(S)\n "
},
{
"path": "leap/training.py",
"chars": 11480,
"preview": "import numpy as np\nimport h5py\nimport os\nfrom time import time\nfrom scipy.io import loadmat, savemat\nimport re\nimport sh"
},
{
"path": "leap/utils.py",
"chars": 2425,
"preview": "import os\nimport numpy as np\nimport re\nfrom time import time\nimport h5py\n\ndef versions(list_devices=False):\n \"\"\" Prin"
},
{
"path": "leap/viz.py",
"chars": 4553,
"preview": "import numpy as np\nimport matplotlib.pyplot as plt\nplt.switch_backend('agg')\n\n\ndef show_pred(net, X, Y, joint_idx=0, alp"
},
{
"path": "readme.md",
"chars": 6130,
"preview": "# SLEAP: Social LEAP Estimates Animal Poses\n\n%UNINSTALL_LEAP Removes LEAP code from the MATLAB path and Python environment.\n% Usage:\n% un"
}
]
// ... and 11 more files (download for full content)
About this extraction
This page contains the full source code of the talmo/leap GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 157 files (36.8 MB), approximately 121.1k tokens, and a symbol index with 45 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.