Repository: talmo/leap
Branch: master
Commit: c39e07b647da
Files: 157
Total size: 36.8 MB

Directory structure:
gitextract_ix4wtaz2/

├── .gitignore
├── LICENSE
├── analysis/
│   └── gait_analysis/
│       ├── Cluster_Velocity_Distributions.mat
│       ├── GaitVectors3.mat
│       ├── Gait_Densities.mat
│       ├── Gait_Speed_Distributions.mat
│       ├── Swing_Velocity_Over_Time.mat
│       ├── Swing_and_Stance_versus_Velocity.mat
│       ├── TetrapodExample.mat
│       ├── TripodExample.mat
│       ├── compute_gait_densities.m
│       ├── gait_analysis_computation.m
│       ├── gait_analysis_plotting.m
│       └── plot_gait_densities.m
├── data/
│   └── readme.md
├── examples/
│   ├── batch_process_video.ipynb
│   ├── hdf5tovid.m
│   └── vidtohdf5.m
├── install_leap.m
├── leap/
│   ├── __init__.py
│   ├── compute_errors.m
│   ├── confmaps2pts.m
│   ├── generate_training_set.m
│   ├── graph2paf.m
│   ├── guis/
│   │   ├── cluster_sample.mlapp
│   │   ├── create_skeleton.mlapp
│   │   └── label_joints.m
│   ├── hpc/
│   │   └── python_gpu.sh
│   ├── image_augmentation.py
│   ├── layers.py
│   ├── models.py
│   ├── plot_joints_single.m
│   ├── predict_box.m
│   ├── predict_box.py
│   ├── pts2confmaps.m
│   ├── test_leap.m
│   ├── toolbox/
│   │   ├── aliases/
│   │   │   ├── alims.m
│   │   │   ├── ff.m
│   │   │   ├── h5file.m
│   │   │   ├── imgsc.m
│   │   │   └── repext.m
│   │   ├── graphics/
│   │   │   ├── FEX-settingsdlg/
│   │   │   │   └── settingsdlg.m
│   │   │   ├── GUI Layout Toolbox/
│   │   │   │   └── layout/
│   │   │   │       └── +uix/
│   │   │   │           ├── +mixin/
│   │   │   │           │   ├── Container.m
│   │   │   │           │   ├── Flex.m
│   │   │   │           │   └── Panel.m
│   │   │   │           ├── Box.m
│   │   │   │           ├── ChildEvent.m
│   │   │   │           ├── ChildObserver.m
│   │   │   │           ├── Container.m
│   │   │   │           ├── Divider.m
│   │   │   │           ├── FigureData.m
│   │   │   │           ├── FigureObserver.m
│   │   │   │           ├── Grid.m
│   │   │   │           ├── GridFlex.m
│   │   │   │           ├── HBox.m
│   │   │   │           ├── Node.m
│   │   │   │           ├── Panel.m
│   │   │   │           ├── PointerManager.m
│   │   │   │           ├── SelectionData.m
│   │   │   │           ├── VBox.m
│   │   │   │           ├── calcPixelSizes.m
│   │   │   │           ├── setPosition.m
│   │   │   │           └── tracking.m
│   │   │   ├── distributionPlot/
│   │   │   │   └── colorCode2rgb.m
│   │   │   ├── draggable/
│   │   │   │   └── draggable.m
│   │   │   ├── figclosekey.m
│   │   │   ├── figsize.m
│   │   │   ├── fontsize.m
│   │   │   ├── hline.m
│   │   │   ├── isax.m
│   │   │   ├── isfig.m
│   │   │   ├── noticks.m
│   │   │   ├── pareto2.m
│   │   │   ├── plotExplainedVar.m
│   │   │   ├── plotpts.m
│   │   │   ├── redblue.m
│   │   │   ├── sc/
│   │   │   │   ├── gray.m
│   │   │   │   ├── private/
│   │   │   │   │   ├── colormap_helper.m
│   │   │   │   │   └── rescale.m
│   │   │   │   └── real2rgb.m
│   │   │   └── shortticks.m
│   │   ├── hdf5/
│   │   │   ├── h5att2struct.m
│   │   │   ├── h5getdatasets.m
│   │   │   ├── h5readframes.m
│   │   │   ├── h5readgroup.m
│   │   │   ├── h5save.m
│   │   │   ├── h5savegroup.m
│   │   │   ├── h5size.m
│   │   │   ├── h5struct2att.m
│   │   │   └── hdf5prop/
│   │   │       ├── h5datacreate.m
│   │   │       └── hdf5prop.m
│   │   ├── imageproc/
│   │   │   └── ind2im.m
│   │   ├── inputParsing/
│   │   │   ├── get_caller_name.m
│   │   │   ├── nameval2struct.m
│   │   │   ├── parse_params.m
│   │   │   └── struct2nameval.m
│   │   ├── io/
│   │   │   ├── GetFullPath.m
│   │   │   ├── dir_ext.m
│   │   │   ├── dir_paths.m
│   │   │   ├── dir_regex.m
│   │   │   ├── exists.m
│   │   │   ├── ext2filter_spec.m
│   │   │   ├── extrep.m
│   │   │   ├── funpath.m
│   │   │   ├── get_ext.m
│   │   │   ├── get_filename.m
│   │   │   ├── get_filesize.m
│   │   │   ├── get_new_filename.m
│   │   │   ├── lastdir.m
│   │   │   ├── mkdirto.m
│   │   │   └── uibrowse.m
│   │   ├── ml/
│   │   │   └── ezpca.m
│   │   ├── strings/
│   │   │   ├── bytes2str.m
│   │   │   ├── instr.m
│   │   │   ├── printf.m
│   │   │   ├── secs2hms.m
│   │   │   └── secsf.m
│   │   ├── utilities/
│   │   │   ├── af.m
│   │   │   ├── arange.m
│   │   │   ├── areempty.m
│   │   │   ├── argmin.m
│   │   │   ├── arr2cell.m
│   │   │   ├── cell1.m
│   │   │   ├── cellcat.m
│   │   │   ├── cf.m
│   │   │   ├── clip.m
│   │   │   ├── functional_programming/
│   │   │   │   └── wrap.m
│   │   │   ├── get_new_string.m
│   │   │   ├── grp2cell.m
│   │   │   ├── horz.m
│   │   │   ├── iseven.m
│   │   │   ├── loadvar.m
│   │   │   ├── nunique.m
│   │   │   ├── rownorm.m
│   │   │   ├── stacks/
│   │   │   │   ├── imtile.m
│   │   │   │   ├── stack2cell.m
│   │   │   │   ├── stack2vecs.m
│   │   │   │   └── vecs2stack.m
│   │   │   ├── swap.m
│   │   │   ├── time/
│   │   │   │   ├── GetSystemTimePreciseAsFileTime.m
│   │   │   │   ├── GetSystemTimePreciseAsFileTime.mexw64
│   │   │   │   ├── stic.m
│   │   │   │   ├── stoc.m
│   │   │   │   ├── stocf.m
│   │   │   │   └── systime.m
│   │   │   ├── varsize.m
│   │   │   ├── varstruct.m
│   │   │   ├── vert.m
│   │   │   └── vplay.m
│   │   └── video/
│   │       ├── validate_stack.m
│   │       └── vplayer.m
│   ├── training.py
│   ├── utils.py
│   └── viz.py
├── readme.md
├── setup.py
└── uninstall_leap.m

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

# MATLAB
*.asv


# LEAP-specific
data/*
!data/readme.md
models/*
leap/toolbox/io/.lastdir


================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: analysis/gait_analysis/GaitVectors3.mat
================================================
[File too large to display: 25.2 MB]

================================================
FILE: analysis/gait_analysis/Gait_Densities.mat
================================================
[File too large to display: 11.2 MB]

================================================
FILE: analysis/gait_analysis/compute_gait_densities.m
================================================
%% Pathing
embed_path = 'Z:\code\2018-05-05_joints_tsne_FlyAging_talmo-labels\results\FlyAging-DiegoCNN_v1.0_filters=64_rot=15_lrfactor=0.1_lrmindelta=1e-05_03.mat';
density_path = 'Z:\code\2018-05-05_joints_tsne_FlyAging_talmo-labels\viz\FlyAging-DiegoCNN_v1.0_filters=64_rot=15_lrfactor=0.1_lrmindelta=1e-05_03\density.mat';

density = load(density_path);
embed = load(embed_path);
Y = embed.Y;
Ld = density.Ld;
gait_path = 'C:\code\murthylab\JointTracker\LabelPostProcessing\GaitVectors3.mat';
gait = load(gait_path);
%% Compute density and segment
sigma = 20/30;
numGridPoints = 500;
gridRange = [-20 20];

% Set up grid
gv = linspace(gridRange(1), gridRange(2), numGridPoints);
xv = gv; yv = gv;

D_tripod = getDensity(Y(gait.hmm.most_likely_seq == 3 & gait.moving_forward',:),sigma, numGridPoints, gridRange);
D_tetrapod = getDensity(Y(gait.hmm.most_likely_seq == 4 & gait.moving_forward',:),sigma, numGridPoints, gridRange);
D_NC = getDensity(Y(gait.hmm.most_likely_seq == 5 & gait.moving_forward',:),sigma, numGridPoints, gridRange);

gait_density = zeros([size(D_tripod),3]);
gait_density(:,:,1) = D_tripod;
gait_density(:,:,2) = D_tetrapod;
gait_density(:,:,3) = D_NC;

% Saving
save_path = 'Gait_Densities';
save(save_path,'gait_density','D_tripod','D_tetrapod','D_NC','Ld');
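
%% Sketch (illustrative only): Gaussian kernel density on a grid
% getDensity is not included in this repository. It is assumed here to
% return a smoothed 2-D density of the embedded points evaluated on a
% square grid, as in the calls above. This cell sums one isotropic Gaussian
% per point on toy data. All demo_* names are hypothetical, and the real
% helper may bin or normalize differently.
demo_pts = [0 0; 1 1; -1 0.5];           % toy 2-D embedding coordinates (N x 2)
demo_sigma = 0.5;                        % kernel width, same units as demo_pts
demo_gv = linspace(-3, 3, 61);           % grid vector shared by both axes
[demo_XX, demo_YY] = meshgrid(demo_gv, demo_gv);
demo_D = zeros(size(demo_XX));
for k = 1:size(demo_pts, 1)
    % Add a Gaussian centered on point k
    demo_D = demo_D + exp(-((demo_XX - demo_pts(k,1)).^2 + ...
        (demo_YY - demo_pts(k,2)).^2) ./ (2*demo_sigma^2));
end
demo_D = demo_D ./ sum(demo_D(:));       % normalize so the grid sums to 1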


================================================
FILE: analysis/gait_analysis/gait_analysis_computation.m
================================================
clear all;
%% Pathing
% addpath(genpath('deps'));
joints_dir = 'Z:\data\JointTracker\2018-02_FlyAging_boxes\expts\preds\talmo-labels\FlyAging-DiegoCNN_v1.0_filters=64_rot=15_lrfactor=0.1_lrmindelta=1e-05_03';
data_dir = 'Z:\data\Fly_Aging\male_data';

% Get the paths from the directories
data_fns = dir([data_dir,'/*_*']);
exptnames = {data_fns(:).name};

joint_fns = dir(strcat(joints_dir,'\*.h5'));
joint_exptnames = {joint_fns(:).name};
joint_exptnames = cf(@(x) x(1:end-3),joint_exptnames);

joint_expt_2_data = zeros(size(joint_exptnames));
for i = 1:numel(joint_exptnames)
    for j = 1:numel(exptnames)
        if strcmp(exptnames{j},joint_exptnames{i})
           joint_expt_2_data(i) = j;
        end
    end
end

joint_paths = cell(size(joint_fns));
for i = 1:numel(joint_fns)
    joint_paths{i} = fullfile(joint_fns(i).folder,joint_fns(i).name);
end

data_paths = cell(size(data_fns));
for i = 1:numel(data_fns)
    data_paths{i} = [fullfile(data_fns(i).folder,data_fns(i).name),'/Positions.dat'];
end

%% Load the predicted joint positions for all videos
pos = cell(size(joint_paths));
% Express each frame's joint positions relative to the thorax (node 5), so
% the thorax itself becomes all zeros, then flatten to one column per frame.
parfor i = 1:numel(joint_paths)
    joints = h5read(joint_paths{i},'/positions_pred');
    joints = joints - joints(5,:,:);
    joints = reshape(joints,[],size(joints,3));
    pos{i} = joints;
end
pos = cat(2,pos{:});
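
%% Sketch (illustrative only): layout of the flattened positions
% The code above assumes h5read returns positions_pred as
% nodes x 2 x frames, with the thorax at node 5. Each column of pos is then
% one frame stacked as [x_1..x_n, y_1..y_n]', and reshape(pos,[],2,numFrames)
% (used further below) restores nodes x 2 x frames. Toy data; demo_* names
% are hypothetical.
demo_joints = rand(32, 2, 4);                                % 32 nodes, (x,y), 4 frames
demo_joints = demo_joints - demo_joints(5,:,:);              % center on node 5 (thorax)
demo_flat = reshape(demo_joints, [], size(demo_joints,3));   % 64 x 4, one column per frame
demo_back = reshape(demo_flat, [], 2, size(demo_flat,2));    % back to 32 x 2 x 4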

%% Get the ids of all of the walking bouts according to the speed of centroids
moving_forward = cell(size(data_paths));
forward_velocity = cell(size(data_paths));
speed = cell(size(data_paths));
smoothing_window = 5;

% Velocity thresholds 
reconversion_constant = (1/(24.40/1088))./35; % This corrects a previous measurement error
conversion_to_mm = 31.0857/1088;
forward_motion_thresh = .02; % 0.02 mm/frame, i.e. 2 mm/s at 100 frames/s

% For each video, compute centroid speed and forward velocity
for i = 1:numel(data_paths)
    % Load centroid positions (Positions.dat), rescale, and smooth
    frames = h5read(['Z:\data\JointTracker\2018-02_FlyAging_boxes\expts\' exptnames{i} '.h5'],'/framesIdx');
%     ell = h5read(data_paths{i},'/ell');
    [X,Y,numLines] = positionReader(data_paths{i});
    X = X.*reconversion_constant;
    Y = Y.*reconversion_constant;
    ctr = [X(frames),Y(frames)];
    ctr = smoothdata(ctr,1,'movmean',smoothing_window);
    ell = h5read(['Z:\data\JointTracker\2018-02_FlyAging_boxes\expts\' exptnames{i} '.h5'],'/ell');
    
    % Get the velocity, direction of motion, and orientation of the fly
    vel_ctr = diffpad(ctr);
%     vel_ctr = smoothdata(vel_ctr,1,'movmean',smoothing_window*10);

    direction_of_motion = smoothdata(mod(unwrap(atan2(vel_ctr(:,2),vel_ctr(:,1))),2*pi),'movmean',smoothing_window);
    orientation = mod(unwrap((ell(frames,5)*2*pi/360)),2*pi);
    difference_dir = abs(direction_of_motion-orientation);
 
    % Get the component of the velocity in the forward direction
    speed_ctr = sqrt(sum(vel_ctr.^2,2));
    speed{i} = speed_ctr;
    forward_velocity{i} = cos(difference_dir).*speed{i};
    moving_forward{i} = forward_velocity{i} > forward_motion_thresh; 
end
lengths = cellfun(@(x) numel(x),moving_forward);
speed = cat(1,speed{:});
moving_forward = cat(1,moving_forward{:});
forward_velocity = cat(1,forward_velocity{:});
fv = forward_velocity;
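
%% Sketch (illustrative only): forward velocity as a projection onto heading
% forward_velocity above is speed times the cosine of the angle between the
% direction of motion and the body orientation. Because cos is even and
% 2*pi-periodic, the abs() and mod() above do not change the result, and
% for unsmoothed angles it equals the dot product of the velocity with the
% unit heading vector. Toy data; demo_* names are hypothetical.
demo_ctr_vel = [1 0; 0 2; -1 0];         % centroid velocity (x, y) per frame
demo_heading = [0; pi/2; 0];             % body orientation (radians) per frame
demo_dir = mod(atan2(demo_ctr_vel(:,2), demo_ctr_vel(:,1)), 2*pi);
demo_speed = sqrt(sum(demo_ctr_vel.^2, 2));
demo_fwd = cos(abs(demo_dir - demo_heading)) .* demo_speed;
% Dot-product form: v . [cos(theta), sin(theta)]
demo_fwd_dot = sum(demo_ctr_vel .* [cos(demo_heading) sin(demo_heading)], 2);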

%% Get stance rasters for the leg tips
% This is meant to replicate Figure 4 of Mendes et al., "Quantification of
% gait parameters in freely walking wild type and sensory deprived
% Drosophila melanogaster".
Fs = 100;

% We want to look at only the leg tips 
tips = [22 26 30 10 14 18]; % (The order matches the paper)
pos_tips = reshape(pos,[],2,size(pos,2));
pos_tips = pos_tips(tips,:,:);
dim = 1; % Only look in the x direction 
traj = squeeze(pos_tips(:,dim,:));

% Get the leg-tip velocity relative to the thorax (egocentric velocity)
vel = diffpad(traj,2);
vel = smoothdata(vel,2,'gauss',smoothing_window);

% Define stance to be when the legs move in the negative x direction. 
stance = vel<0;
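
%% Sketch (illustrative only): stance rasters and HMM symbols
% A leg is in stance when its tip moves backward (negative x) in the body
% frame. The HMM below observes the number of legs in stance per frame;
% hmmtrain/hmmviterbi expect symbols 1..M, so 0..6 legs become 1..7, as in
% seq = sum(stance)+1. Toy data (6 legs x 3 frames); demo_* names are
% hypothetical.
demo_tip_vel = [-1 -1 2; 2 -1 1; -1 -1 3; 1 -2 1; -3 2 1; 2 1 4];
demo_stance = demo_tip_vel < 0;          % logical stance raster
demo_nstance = sum(demo_stance, 1);      % legs in stance per frame: [3 4 0]
demo_symbols = demo_nstance + 1;         % HMM symbols: [4 5 1]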

%% Get the contiguous bouts that will be used for HMM fitting
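seq = sum(stance)+1; % HMM symbols 1..7 = 0..6 legs in stance
TRGuess = rand(3); % random initial transition guess for 3 hidden states
EMITGuess = rand(3,7); % random initial emission guess (3 states x 7 symbols)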
mf = find(moving_forward);

% Get the frames of contiguous bouts
endFrames = cumsum(lengths);
startFrames = [1 endFrames(1:end-1)'+1];
samples = cell1(numel(lengths));
samples_per_video = 3000; % maximum number of training frames per video
duration_thresh = 50; % minimum bout length in frames (0.5 s at 100 Hz)
for i = 1:numel(lengths)
    bw = bwconncomp(moving_forward(startFrames(i):endFrames(i)));
    duration = cellfun(@(x) numel(x),bw.PixelIdxList);
    bw.PixelIdxList(duration < duration_thresh) = [];
    ids = cat(1,bw.PixelIdxList{:});
    ids = ids(1:(min(samples_per_video,numel(ids))));
    samples{i} = startFrames(i) + ids - 1;
end
num_samples = cellfun(@(x) numel(x),samples);
samples = cat(1,samples{num_samples>1});
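
%% Sketch (illustrative only): contiguous bouts without bwconncomp
% bwconncomp (Image Processing Toolbox) finds the runs of moving_forward
% above. For a 1-D logical vector, the same runs can be found with diff.
% Runs shorter than a minimum length are dropped, mirroring duration_thresh.
% Toy data; demo_* names are hypothetical.
demo_mask = logical([0 1 1 1 0 0 1 1 0 1 1 1 1]);
demo_min_len = 3;                                    % plays the role of duration_thresh
demo_edges = diff([0 double(demo_mask) 0]);          % +1 at run starts, -1 just after run ends
demo_starts = find(demo_edges == 1);                 % first index of each run
demo_ends = find(demo_edges == -1) - 1;              % last index of each run
demo_keep = find(demo_ends - demo_starts + 1 >= demo_min_len);
demo_ids = [];
for k = demo_keep
    demo_ids = [demo_ids, demo_starts(k):demo_ends(k)]; %#ok<AGROW>
end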

%% Train HMM and get most likely sequence
tic;[ESTTR,ESTEMIT] = hmmtrain(seq(samples),TRGuess,EMITGuess);toc
mls = hmmviterbi(seq,ESTTR,ESTEMIT);

%% Inspect the emission probabilities to identify which state is which gait
figure;
imagesc(ESTEMIT);

%% Reorder the labels to 3 = tripod, 4 = tetrapod, 5 = non-canonical
% Note: this can be automated, but since initialization is random and
% training is fast, I prefer to check manually. The mapping below is
% specific to one fit.
mls(mls == 3) = 4;
mls(mls == 2) = 5;
mls(mls == 1) = 3;
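
%% Sketch (illustrative only): relabel states with a lookup vector
% Indexing a lookup vector applies every reassignment at once, so the
% result cannot depend on statement order (the three in-place assignments
% above are ordered so that no label is overwritten twice). The mapping
% mirrors the one above for this particular fit: 1 -> 3 (tripod),
% 2 -> 5 (non-canonical), 3 -> 4 (tetrapod). demo_* names are hypothetical.
demo_raw = [1 2 3 3 1 2];                % toy Viterbi output
demo_map = [3 5 4];                      % demo_map(old_state) = new_label
demo_relabeled = demo_map(demo_raw);     % [3 5 4 4 3 5]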

%% Save results into a structure
hmm.TRGUESS = TRGuess;
hmm.EMITGUESS = EMITGuess;
hmm.TrainingSamples = uint8(seq(samples));
hmm.ESTTR = ESTTR;
hmm.ESTEMIT = ESTEMIT;
hmm.most_likely_seq = uint8(mls);

% Saving
% save_path = 'GaitVectors3';
% save(save_path,'hmm','moving_forward');

%% Extract an example window
% Limit the window to a particular region
% Tripod = 17230751;
% Tetrapod = 4781582;
% Tetrapod = 8472800
start = 8472800;
ids = start:start+100;

% Save example gait vectors
example_vel = vel(:,ids);
example_stance = stance(:,ids);
example_fv = fv(ids);

% Saving
% save_path = 'TripodExample';
% save_path = 'TetrapodExample';
% save(save_path,'example_vel','example_stance','example_fv','Fs');

%% Calculate the distribution of speeds for each hidden state and save
mls = hmm.most_likely_seq;
speed_lim = [2 35]; % mm/s
tri_ids = moving_forward & (mls == 3)' & (fv.*Fs < speed_lim(2)) & (fv.*Fs > speed_lim(1));
tetra_ids = moving_forward & (mls == 4)' & (fv.*Fs < speed_lim(2)) & (fv.*Fs > speed_lim(1));
NC_ids = moving_forward & (mls == 5)' & (fv.*Fs < speed_lim(2)) & (fv.*Fs > speed_lim(1));


[N1,edges1] = histcounts(fv(tri_ids).*Fs,166);
[N2,edges2] = histcounts(fv(tetra_ids).*Fs,166);
[N3,edges3] = histcounts(fv(NC_ids).*Fs,166);

% Saving
% save_path = 'Gait_Speed_Distributions';
% save(save_path,'N1','N2','N3','edges1','edges2','edges3','speed_lim')

%% Look at the velocity statistics per t-SNE locomotor state
density_path = 'Z:\code\2018-05-05_joints_tsne_FlyAging_talmo-labels\viz\FlyAging-DiegoCNN_v1.0_filters=64_rot=15_lrfactor=0.1_lrmindelta=1e-05_03\density.mat';
density = load(density_path);
embedding_ordered_locomotor_states = [7 11 13 10 8 9];
num_states = numel(embedding_ordered_locomotor_states);

%% Calculate the velocity distributions.
speed_ids = fv > 0 & fv < .4; % 0 to 40 mm/s at 100 frames/s
h = cell1(num_states);
N = cell1(num_states);
edges = cell1(num_states);
for i = 1:num_states
    state_ids = density.YL == embedding_ordered_locomotor_states(i);
    [N{i},edges{i}] = histcounts(fv(state_ids & speed_ids)*Fs,'Normalization','pdf');
end

% Saving
% save_path = 'Cluster_Velocity_Distributions';
% save(save_path,'N','edges','num_states');

%% Leg-tip velocity around swing onset, for different body-speed bins
win = -1:5;
speed_levels = [2 5:5:45];
leg_vel_at_speed = zeros(numel(speed_levels)-1,numel(win));
leg_vel_std_at_speed = zeros(numel(speed_levels)-1,numel(win));

% For each body speed level get the velocity of the swings over the window
for s = 1:numel(speed_levels)-1
    swings = cell([size(vel,1) 1]);
    inSpeed = (speed'*Fs >= speed_levels(s) & speed'*Fs < speed_levels(s+1));
    % Get the swings in which the fly was at the speed level and moving
    % forward.
    parfor i = 1:size(vel,1)
        % First get the swing starts
        bw = bwconncomp(~stance(i,:));
        ids = cell2mat(cf(@(x) x(1),bw.PixelIdxList));
        swing_start = false(size(stance(i,:)));
        swing_start(ids) = true;
        % Then get the swing starts where the fly is moving forward within
        % the speed threshold. 
        bw = bwconncomp(swing_start  & inSpeed);
        ids = cell2mat(cf(@(x) x(1),bw.PixelIdxList));
        ids = ids' + win;
        vel_i = vel(i,:);
        swings{i} = indpad(vel_i,ids);
    end
    swings = cat(1,swings{:});
    leg_vel_at_speed(s,:) = nanmean(swings,1);
    leg_vel_std_at_speed(s,:) = nanstd(swings,0,1); % nanstd(X,flag,dim): the dimension is the third argument
end

% Saving
% save_path = 'Swing_Velocity_Over_Time';
% save(save_path,'leg_vel_at_speed','leg_vel_std_at_speed','speed_levels','win');
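
%% Sketch (illustrative only): NaN-padded windowed indexing
% indpad is not included in this repository. It is assumed here to index a
% vector with a matrix of indices (one row per swing onset, one column per
% offset in win) and to return NaN where an index falls outside the vector,
% so windows near the ends can still be averaged with nanmean. Toy data;
% demo_* names are hypothetical.
demo_x = [10 20 30 40 50];
demo_idx = [0 1 2; 4 5 6];               % two windows, partially out of range
demo_valid = demo_idx >= 1 & demo_idx <= numel(demo_x);
demo_win = NaN(size(demo_idx));
demo_win(demo_valid) = demo_x(demo_idx(demo_valid));
demo_mu = mean(demo_win, 1, 'omitnan');  % per-offset mean ignoring the padding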

%% Stance / Swing Duration
swing_durations = cell([1,size(stance,1)]);
stance_durations = cell([1,size(stance,1)]);
swing_body_velocities = cell([1,size(stance,1)]);
stance_body_velocities = cell([1,size(stance,1)]);
period = cell([1,size(stance,1)]);
period_velocities = cell([1,size(stance,1)]);

% Get the swing, stance, and period duration and velocities
parfor i = 1:size(stance,1)
    % Swing
    bw = bwconncomp(~stance(i,:));
    swing_durations{i} = cellfun(@(x) numel(x),bw.PixelIdxList);
    swing_body_velocities{i} = cellfun(@(x) mean(fv(x)),bw.PixelIdxList);
    
    % Stance
    bw = bwconncomp(stance(i,:));
    stance_durations{i} = cellfun(@(x) numel(x),bw.PixelIdxList);
    stance_body_velocities{i} = cellfun(@(x) mean(fv(x)),bw.PixelIdxList);
    
    % Period: from the onset of one stance to the onset of the next
    % (one stance plus the following swing), reusing the stance components
    period{i} = cellfun(@(x,y) numel(x(1):y(1)-1),{bw.PixelIdxList{1:end-1}},{bw.PixelIdxList{2:end}});
    period_velocities{i} = cellfun(@(x,y) mean(fv(x(1):y(1)-1)),{bw.PixelIdxList{1:end-1}},{bw.PixelIdxList{2:end}});
end

swing_durations = cat(2,swing_durations{:});
stance_durations = cat(2,stance_durations{:});
swing_body_velocities = cat(2,swing_body_velocities{:});
stance_body_velocities = cat(2,stance_body_velocities{:});
period = cat(2,period{:});
period_velocities = cat(2,period_velocities{:});

%% Bin stance and swing durations by body velocity
stance_vel_thresh = 7.2; % mm/s; threshold taken from Mendes et al.

ids = stance_body_velocities*Fs > stance_vel_thresh;
xranges = [stance_vel_thresh  50];
yranges = [0 prctile(stance_durations(ids),99)];
stance_edges = xranges(1):1:xranges(2);

% Stance Duration vs Velocity
X1 = stance_body_velocities(ids)*Fs;
Y1 = stance_durations(ids);
[X1ids] = discretize(X1,stance_edges);
[stance_dur_mu, stance_dur_std] = grpstats(Y1, categorical(X1ids),{'mean','std'});

swing_vel_thresh = 7.2; % mm/s; threshold taken from Mendes et al.
ids = swing_body_velocities*Fs > swing_vel_thresh;
xranges = [swing_vel_thresh  50];
swing_edges = xranges(1):1:xranges(2);

% Swing Duration vs Velocity
X2 = swing_body_velocities(ids)*Fs;
Y2 = swing_durations(ids);
[X2ids] = discretize(X2,swing_edges);
[swing_dur_mu, swing_dur_std] = grpstats(Y2, categorical(X2ids),{'mean','std'});

% Saving
% save_path = 'Swing_and_Stance_versus_Velocity';
% save(save_path,'stance_dur_mu','swing_dur_mu','stance_dur_std','swing_dur_std','stance_edges','swing_edges')
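
%% Sketch (illustrative only): per-bin means that keep empty bins
% grpstats(Y, categorical(Xids)) returns one row per non-empty bin, so the
% durations above may not line up one-to-one with stance_edges or
% swing_edges if some bins are empty. accumarray with a NaN fill value
% returns one row per bin instead. Toy data; demo_* names are hypothetical.
demo_v = [8 8.5 9.2 9.8 12.1];           % body velocity (mm/s)
demo_dur = [10 12 8 6 4];                % duration (frames)
demo_bin_edges = 8:1:13;                 % 5 bins: [8,9) ... [12,13]
demo_bin = discretize(demo_v, demo_bin_edges);
demo_mu = accumarray(demo_bin(:), demo_dur(:), ...
    [numel(demo_bin_edges)-1 1], @mean, NaN);   % [11; 7; NaN; NaN; 4]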


================================================
FILE: analysis/gait_analysis/gait_analysis_plotting.m
================================================
clear all;
%% Plot an example gait window
% Pick the gait example to observe
% load('TetrapodExample');
load('TripodExample');

% Plot the velocity of the leg tips
figure('pos',[153, 427, 560, 420]); hold on; axis tight; set(gcf,'color','w'); fontsize(16)
imagesc(example_vel);
xlabel('Time (seconds)')
xticklabels(xticks/Fs);
ylabel('Leg Tip')
yticks([1:6])
yticklabels({'RF','RM','RH','LF','LM','LH'})
ax1 = gca;
caxis([-10 10]);

% Plot the rasters
figure('pos',[850, 634, 848, 334]); hold on; axis tight; set(gcf,'color','w'); fontsize(16)
imagesc(example_stance); colormap('gray');
xlabel('Time (seconds)')
xticklabels(xticks/Fs);
ylabel('Leg Tip')
yticks([1:6])
yticklabels({'LF','LM','LH','RF','RM','RH'})
ax2 = gca;

% plot the forward velocity of the fly
figure('pos',[850, 359, 854, 186]); hold on; axis tight; set(gcf,'color','w'); fontsize(16)
plot(example_fv.*Fs)
xlabel('Time (seconds)')
xticklabels(xticks/Fs);
ylabel('Forward Velocity (mm/s)')
ax3 = gca;
ylim([0 40]);

% plot the raster of tripod, tetrapod, or non-canonical
figure('pos',[849, 114, 931, 150]); hold on; axis tight; set(gcf,'color','w'); fontsize(16)
example_gait = sum(example_stance,1);
example_gait(~(example_gait == 3 | example_gait == 4)) = 5; 
imagesc(example_gait);colormap('jet');h = colorbar; 
xlabel('Time (seconds)')
xticklabels(xticks/Fs);
yticks([])
ylabel(h,'Number of legs in stance')
ax4 = gca;
linkaxes([ax1,ax2,ax3,ax4],'x')

%% Plot the emission probabilities for each hidden state
load('GaitVectors3.mat');
emissions = hmm.ESTEMIT;
% Swap rows 2 and 3 so the state order matches the y-axis labels below
temp = emissions(2,:);
emissions(2,:) = emissions(3,:);
emissions(3,:) = temp;
figure; hold on;
imagesc(emissions);
for i = 1:size(emissions,1)
    for j = 1:size(emissions,2)
        caption = sprintf('%.2f',emissions(i,j));
        text(j,i,caption,'Fontsize',10,'FontWeight','bold','HorizontalAlignment','center','Color',[0 0 0]);
    end
end
axis ij;
axis tight
xlabel('Number of Legs in Stance');
yticks([1 2 3]);
yticklabels({'Tripod','Tetrapod','Non-canonical'})
xticklabels({'0','1','2','3','4','5','6'})
fontsize(16)

figure; hold on;
plot(emissions','LineWidth',3);
xlabel('Number of Legs in Stance');
xticklabels({'0','1','2','3','4','5','6'})
legend({'Tripod','Tetrapod','Non-canonical'})
ylabel('Emission Probability')

%% Plot the distribution of speeds
load('Gait_Speed_Distributions');
figure; hold on;
plot(edges1(1:end-1),N1,'LineWidth',3)
plot(edges2(1:end-1),N2,'LineWidth',3)
plot(edges3(1:end-1),N3,'LineWidth',3)
xlabel('Forward Velocity (mm/s)')
ylabel('Count')
xlim(speed_lim)
legend({'Tripod','Tetrapod','Non-canonical'})
fontsize(16)

%% Plot the velocity Distributions
load('Cluster_Velocity_Distributions');
figure; hold on;
cmap = spring(num_states);
for i = 1:num_states
    plot(edges{i}(1:end-1),N{i},'Color',cmap(num_states + 1 -i,:),'LineWidth',3);
end
grid on;
xlabel('Forward Velocity')
ylabel('Probability')
axis tight;
ylim([0 .2])

%% Plotting the mean and std with bounded lines
load('Swing_Velocity_Over_Time');
cmap = parula(numel(speed_levels));
p_lines = cell1(numel(speed_levels)-1);
figure('pos',[568, 186, 1036, 798]); figclosekey; set(gcf,'color','w'); hold on;
for i = 1:size(leg_vel_at_speed,1)
    yci = zeros(2,size(leg_vel_at_speed,2));
    yci(1,:) = leg_vel_std_at_speed(i,:);
    yci(2,:) = leg_vel_std_at_speed(i,:);
    [p_lines{i},~] = boundedline(win*10,leg_vel_at_speed(i,:),yci','alpha','cmap',cmap(i,:));
    p_lines{i}.LineWidth = 3;
end

% Legend
leg = cell([1 numel(speed_levels)-1]);
for i = 1:numel(speed_levels)-1
    leg{i} = sprintf('%d - %d mm/s',speed_levels(i),speed_levels(i+1));
end
l = legend([p_lines{:}],leg);
l.Position = [0.7503 0.6488 0.1573 0.3239];
fontsize(16)
xlabel('Time from swing onset (ms)')
ylabel('Swing velocity (mm/s)')
% export_fig('figs/Swing_Velocity_vs_Time_Confidences.png','-r300')

%% Plot the Swing_and_Stance_versus_Velocity
load('Swing_and_Stance_versus_Velocity')
figure; figclosekey, hold on;

yci = zeros(2,numel(stance_dur_std));
yci(1,:) = stance_dur_std;
yci(2,:) = stance_dur_std;
[bl1,~] = boundedline(stance_edges(2:end),stance_dur_mu',yci','alpha');

yci = zeros(2,numel(swing_dur_std));
yci(1,:) = swing_dur_std;
yci(2,:) = swing_dur_std;
[bl2,~] = boundedline(swing_edges(2:end),swing_dur_mu',yci','alpha','r');

yticklabels(round(yticks*10))
ylabel('Durations (ms)');
xlabel('Average Body Speed (mm/s)');
axis tight
legend([bl1,bl2],{'Stance','Swing'})
fontsize(16);

================================================
FILE: analysis/gait_analysis/plot_gait_densities.m
================================================
clear all;
%% Plot all three densities overlaid on top of one another
load('Gait_Densities');
figure; hold on;
h = imagesc(gait_density ./ max(gait_density(:)));
axis equal; axis xy; axis off

%% Plot the distributions for each mode in the same scale

Locomotor_states = [7 11 13 9 10 8];
[cropped,~,crop_mask] = bwcrop(ismember(Ld,Locomotor_states));

figure; hold on; axis xy; axis equal; axis off;
cropped_density = reshape(D_tripod(crop_mask),size(cropped,1),[]);
h = imagesc(cropped_density);
c1 = colorbar;
peak = max(c1.Limits);
colormap('viridis')

figure; hold on; axis xy; axis equal; axis off;
cropped_density = reshape(D_tetrapod(crop_mask),size(cropped,1),[]);
h = imagesc(cropped_density);
% h.AlphaData = 1 .* ~reshape(density.Lbnds(crop_mask),size(cropped,1),[]);
c2 = colorbar;
colormap('viridis')
caxis([0 peak]);

figure; hold on; axis xy; axis equal; axis off;
cropped_density = reshape(D_NC(crop_mask),size(cropped,1),[]);
h = imagesc(cropped_density);
% h.AlphaData = 1 .* ~reshape(density.Lbnds(crop_mask),size(cropped,1),[]);
colormap('viridis')
c3 = colorbar;
caxis([0 peak]);


================================================
FILE: data/readme.md
================================================
The full fly dataset and all trained networks used in our paper can be downloaded from: [`http://arks.princeton.edu/ark:/88435/dsp01pz50gz79z`](http://arks.princeton.edu/ark:/88435/dsp01pz50gz79z)

Additional datasets:

## BermanFlies

_We recommend using this for testing._

- [Cluster sampled images](https://1drv.ms/u/s!AnmpIqqfwz3zgbUg2M8Sa_0NcLhrMg) (**68.7 MiB**)
- [Labels for n = 1500 images](https://1drv.ms/u/s!AnmpIqqfwz3zgbUeIuYYDj_A1pCr9Q) (**291 KiB**)
- [Trained network model](https://1drv.ms/u/s!AnmpIqqfwz3zgcwUvbqPUn7mXIMLLg) (**53.0 MiB**)
- [Full dataset](http://arks.princeton.edu/ark:/88435/dsp01pz50gz79z) (**~168 GiB**)

## CatNect

This is the dataset used for the tutorial. It contains a clip of a cat chasing a laser pointer recorded with a Kinect.

- [Full clip](https://1drv.ms/u/s!AnmpIqqfwz3zgcwS_3gAJFU0sANBcA) (**105 MiB**)
- [Cluster sampled images](https://1drv.ms/u/s!AnmpIqqfwz3zgcwR_9mmNJz8ALW3Hw) (**42.2 MiB**)
- [Labels for n = 40 images](https://1drv.ms/u/s!AnmpIqqfwz3zgcwQlKSXXDy9KvIPVg) (**82.3 KiB**)
- [Trained network model](https://1drv.ms/u/s!AnmpIqqfwz3zgc0H-v6qbnc3vak2Lw) (**64.7 MiB**)



================================================
FILE: examples/batch_process_video.ipynb
================================================
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Example: Predict body part positions from an MP4 file\n",
    "This notebook presents an example pipeline for applying a trained LEAP network to ~360k frames read from an MP4 video.\n",
    "\n",
    "You can download the data to reproduce the benchmarking results below.\n",
    "\n",
    "**Input data:** [072212_163153.mp4](https://1drv.ms/v/s!AnmpIqqfwz3zgcgekCxNp-MN76p1UQ) (254 MiB)\n",
    "\n",
    "**Output data:** [072212_163153.preds.h5](https://1drv.ms/u/s!AnmpIqqfwz3zgcgdDhQrKRsBaxvCXQ) (46.9 MiB)\n",
    "\n",
    "The trained network can be [downloaded here](https://1drv.ms/u/s!AnmpIqqfwz3zgdpIOzsqojhEmr0J0w) or from the links in the [repository data folder](https://github.com/talmo/leap/tree/master/data)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Anaconda3\\lib\\site-packages\\h5py\\__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n",
      "Using TensorFlow backend.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Platform: Windows-10-10.0.16299-SP0\n",
      "h5py:\n",
      "Summary of the h5py configuration\n",
      "---------------------------------\n",
      "\n",
      "h5py    2.7.1\n",
      "HDF5    1.10.1\n",
      "Python  3.6.4 |Anaconda, Inc.| (default, Jan 16 2018, 10:22:32) [MSC v.1900 64 bit (AMD64)]\n",
      "sys.platform    win32\n",
      "sys.maxsize     9223372036854775807\n",
      "numpy   1.14.1\n",
      "\n",
      "Keras: 2.2.0\n",
      "Tensorflow: 1.5.0\n",
      "Devices:\n",
      "[name: \"/device:CPU:0\"\n",
      "device_type: \"CPU\"\n",
      "memory_limit: 268435456\n",
      "locality {\n",
      "}\n",
      "incarnation: 12947954769568633288\n",
      ", name: \"/device:GPU:0\"\n",
      "device_type: \"GPU\"\n",
      "memory_limit: 9143884186\n",
      "locality {\n",
      "  bus_id: 1\n",
      "}\n",
      "incarnation: 16757352010506659625\n",
      "physical_device_desc: \"device: 0, name: GeForce GTX 1080 Ti, pci bus id: 0000:01:00.0, compute capability: 6.1\"\n",
      ", name: \"/device:GPU:1\"\n",
      "device_type: \"GPU\"\n",
      "memory_limit: 9143884186\n",
      "locality {\n",
      "  bus_id: 1\n",
      "}\n",
      "incarnation: 16188733693954377295\n",
      "physical_device_desc: \"device: 1, name: GeForce GTX 1080 Ti, pci bus id: 0000:02:00.0, compute capability: 6.1\"\n",
      "]\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import cv2\n",
    "import h5py\n",
    "from time import time\n",
    "\n",
    "import keras\n",
    "import keras.models\n",
    "from leap.predict_box import convert_to_peak_outputs\n",
    "from leap.utils import versions\n",
    "\n",
    "versions(list_devices=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Media file path\n",
    "video_path = \"D:/tmp/072212_163153.mp4\"\n",
    "\n",
    "# Trained network path\n",
    "model_path = \"D:/OneDrive/code/leap/data/BermanFlies/models/180615_025354-n=1500/final_model.h5\"\n",
    "\n",
    "# Predictions output path\n",
    "save_path = \"D:/tmp/072212_163153.preds.h5\"\n",
    "\n",
    "# Number of frames to read before predicting (higher = faster, but limited by RAM)\n",
    "chunk_size = 10000\n",
    "\n",
    "# Number of frames to evaluate at once on the GPU (higher = faster, but limited by GPU memory)\n",
    "batch_size = 64"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Processing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: D:/OneDrive/code/leap/data/BermanFlies/models/180615_025354-n=1500/final_model.h5\n",
      "    Input: (None, 192, 192, 1)\n",
      "    Output: (None, 3, 32)\n",
      "Predicted: 10000/361000 frames | Elapsed: 0.4 min / 391.1 FPS / ETA: 15.0 min\n",
      "Predicted: 20000/361000 frames | Elapsed: 0.8 min / 435.0 FPS / ETA: 13.1 min\n",
      "Predicted: 30000/361000 frames | Elapsed: 1.1 min / 451.0 FPS / ETA: 12.2 min\n",
      "Predicted: 40000/361000 frames | Elapsed: 1.5 min / 459.0 FPS / ETA: 11.7 min\n",
      "Predicted: 50000/361000 frames | Elapsed: 1.8 min / 464.7 FPS / ETA: 11.2 min\n",
      "Predicted: 60000/361000 frames | Elapsed: 2.1 min / 468.7 FPS / ETA: 10.7 min\n",
      "Predicted: 70000/361000 frames | Elapsed: 2.5 min / 471.4 FPS / ETA: 10.3 min\n",
      "Predicted: 80000/361000 frames | Elapsed: 2.8 min / 473.5 FPS / ETA: 9.9 min\n",
      "Predicted: 90000/361000 frames | Elapsed: 3.2 min / 474.8 FPS / ETA: 9.5 min\n",
      "Predicted: 100000/361000 frames | Elapsed: 3.5 min / 476.0 FPS / ETA: 9.1 min\n",
      "Predicted: 110000/361000 frames | Elapsed: 3.8 min / 477.4 FPS / ETA: 8.8 min\n",
      "Predicted: 120000/361000 frames | Elapsed: 4.2 min / 478.4 FPS / ETA: 8.4 min\n",
      "Predicted: 130000/361000 frames | Elapsed: 4.5 min / 479.5 FPS / ETA: 8.0 min\n",
      "Predicted: 140000/361000 frames | Elapsed: 4.8 min / 481.4 FPS / ETA: 7.7 min\n",
      "Predicted: 150000/361000 frames | Elapsed: 5.2 min / 483.0 FPS / ETA: 7.3 min\n",
      "Predicted: 160000/361000 frames | Elapsed: 5.5 min / 483.7 FPS / ETA: 6.9 min\n",
      "Predicted: 170000/361000 frames | Elapsed: 5.9 min / 484.0 FPS / ETA: 6.6 min\n",
      "Predicted: 180000/361000 frames | Elapsed: 6.2 min / 484.1 FPS / ETA: 6.2 min\n",
      "Predicted: 190000/361000 frames | Elapsed: 6.5 min / 484.3 FPS / ETA: 5.9 min\n",
      "Predicted: 200000/361000 frames | Elapsed: 6.9 min / 484.6 FPS / ETA: 5.5 min\n",
      "Predicted: 210000/361000 frames | Elapsed: 7.2 min / 484.7 FPS / ETA: 5.2 min\n",
      "Predicted: 220000/361000 frames | Elapsed: 7.6 min / 484.6 FPS / ETA: 4.8 min\n",
      "Predicted: 230000/361000 frames | Elapsed: 7.9 min / 484.7 FPS / ETA: 4.5 min\n",
      "Predicted: 240000/361000 frames | Elapsed: 8.3 min / 484.7 FPS / ETA: 4.2 min\n",
      "Predicted: 250000/361000 frames | Elapsed: 8.6 min / 484.9 FPS / ETA: 3.8 min\n",
      "Predicted: 260000/361000 frames | Elapsed: 8.9 min / 485.0 FPS / ETA: 3.5 min\n",
      "Predicted: 270000/361000 frames | Elapsed: 9.3 min / 485.2 FPS / ETA: 3.1 min\n",
      "Predicted: 280000/361000 frames | Elapsed: 9.6 min / 485.3 FPS / ETA: 2.8 min\n",
      "Predicted: 290000/361000 frames | Elapsed: 10.0 min / 485.1 FPS / ETA: 2.4 min\n",
      "Predicted: 300000/361000 frames | Elapsed: 10.3 min / 485.1 FPS / ETA: 2.1 min\n",
      "Predicted: 310000/361000 frames | Elapsed: 10.7 min / 485.1 FPS / ETA: 1.8 min\n",
      "Predicted: 320000/361000 frames | Elapsed: 11.0 min / 485.3 FPS / ETA: 1.4 min\n",
      "Predicted: 330000/361000 frames | Elapsed: 11.3 min / 485.1 FPS / ETA: 1.1 min\n",
      "Predicted: 340000/361000 frames | Elapsed: 11.7 min / 485.2 FPS / ETA: 0.7 min\n",
      "Predicted: 350000/361000 frames | Elapsed: 12.0 min / 485.4 FPS / ETA: 0.4 min\n",
      "Predicted: 360000/361000 frames | Elapsed: 12.4 min / 485.6 FPS / ETA: 0.0 min\n",
      "Predicted: 361000/361000 frames | Elapsed: 12.4 min / 485.4 FPS / ETA: 0.0 min\n",
      "Finished predicting 361000 frames.\n",
      "    Prediction | Runtime: 11.55 min / 520.921 FPS\n",
      "    Reading    | Runtime: 0.78 min / 7726.516 FPS\n",
      "Saved: D:/tmp/072212_163153.preds.h5\n",
      "Total runtime: 12.4 mins\n",
      "Total performance: 483.585 FPS\n"
     ]
    }
   ],
   "source": [
    "t0_all = time()\n",
    "\n",
    "# Load model and convert to peak-coordinate output\n",
    "model = convert_to_peak_outputs(keras.models.load_model(model_path))\n",
    "print(\"Model:\", model_path)\n",
    "print(\"    Input:\", str(model.input_shape))\n",
    "print(\"    Output:\", str(model.output_shape))\n",
    "\n",
    "# model = keras.utils.multi_gpu_model(model, gpus=2)\n",
    "\n",
    "# Open video for reading\n",
    "reader = cv2.VideoCapture(video_path)\n",
    "num_samples = int(reader.get(cv2.CAP_PROP_FRAME_COUNT))\n",
    "\n",
    "# Initialize\n",
    "positions_pred = []\n",
    "conf_pred = []\n",
    "buffer = []\n",
    "samples_predicted = 0\n",
    "reading_runtime = 0\n",
    "prediction_runtime = 0\n",
    "done = False\n",
    "\n",
    "# Process video chunk-by-chunk\n",
    "while not done:\n",
    "    t0_reading = time()\n",
    "    # Read and finish if no frame was retrieved\n",
    "    returned_frame, I = reader.read()\n",
    "    done = not returned_frame\n",
    "    reading_runtime += time() - t0_reading\n",
    "    \n",
    "    # Add current frame to buffer (first channel only; assumes a grayscale video decoded as 3 identical channels)\n",
    "    if not done:\n",
    "        buffer.append(I[...,0])\n",
    "    \n",
    "    # Do we have anything to predict?\n",
    "    if len(buffer) >= chunk_size or (done and len(buffer) > 0):\n",
    "        t0_prediction = time()\n",
    "        \n",
    "        # Predict on buffer\n",
    "        Y = model.predict(np.stack(buffer, axis=0)[...,None], batch_size=batch_size)\n",
    "        \n",
    "        # Save\n",
    "        positions_pred.append(Y[:,:2,:].astype(\"int32\"))\n",
    "        conf_pred.append(Y[:,2,:])\n",
    "        \n",
    "        # Empty out buffer container\n",
    "        buffer = []\n",
    "        \n",
    "        # Performance stats\n",
    "        samples_predicted += len(Y)\n",
    "        prediction_runtime += time() - t0_prediction\n",
    "        elapsed = time() - t0_all\n",
    "        fps = samples_predicted / elapsed\n",
    "        print(\"Predicted: %d/%d frames | Elapsed: %.1f min / %.1f FPS / ETA: %.1f min\" %\n",
    "              (samples_predicted, num_samples, elapsed / 60, fps, (num_samples - samples_predicted) / fps / 60))\n",
    "        \n",
    "# Close video reader\n",
    "reader.release()\n",
    "\n",
    "# Merge arrays\n",
    "positions_pred = np.concatenate(positions_pred, axis=0)\n",
    "conf_pred = np.concatenate(conf_pred, axis=0)\n",
    "\n",
    "# Report performance stats\n",
    "print(\"Finished predicting %d frames.\" % samples_predicted)\n",
    "print(\"    Prediction | Runtime: %.2f min / %.3f FPS\" % (prediction_runtime / 60, samples_predicted / prediction_runtime))\n",
    "print(\"    Reading    | Runtime: %.2f min / %.3f FPS\" % (reading_runtime / 60, samples_predicted / reading_runtime))\n",
    "\n",
    "# Save\n",
    "if os.path.exists(save_path):\n",
    "    os.remove(save_path)\n",
    "with h5py.File(save_path, \"w\") as f:\n",
    "        f.attrs[\"num_samples\"] = num_samples\n",
    "        f.attrs[\"video_path\"] = video_path\n",
    "        f.attrs[\"model_path\"] = model_path\n",
    "\n",
    "        ds_pos = f.create_dataset(\"positions_pred\", data=positions_pred, compression=\"gzip\", compression_opts=1)\n",
    "        ds_pos.attrs[\"description\"] = \"coordinate of peak at each sample\"\n",
    "        ds_pos.attrs[\"dims\"] = \"(sample, [x, y], joint) === (sample, [column, row], joint)\"\n",
    "\n",
    "        ds_conf = f.create_dataset(\"conf_pred\", data=conf_pred, compression=\"gzip\", compression_opts=1)\n",
    "        ds_conf.attrs[\"description\"] = \"confidence map value in [0, 1.0] at peak\"\n",
    "        ds_conf.attrs[\"dims\"] = \"(sample, joint)\"\n",
    "\n",
    "        total_runtime = time() - t0_all\n",
    "        f.attrs[\"reading_runtime_secs\"] = reading_runtime\n",
    "        f.attrs[\"prediction_runtime_secs\"] = prediction_runtime\n",
    "        f.attrs[\"total_runtime_secs\"] = total_runtime\n",
    "        \n",
    "    \n",
    "print(\"Saved:\", save_path)\n",
    "\n",
    "print(\"Total runtime: %.1f mins\" % (total_runtime / 60))\n",
    "print(\"Total performance: %.3f FPS\" % (samples_predicted / total_runtime))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  },
  "toc": {
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": "block",
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}


================================================
FILE: examples/hdf5tovid.m
================================================
% Clean start!
clear all, clc

%% Parameters
% Path to input file
dataPath = '../data/examples/072212_163153.clip.h5';

% Dataset name
dset = '/box';

% Path to output file
savePath = 'C:\tmp\072212_163153.clip.mp4';

% Frames to convert at a time (lower this if your memory is limited)
chunkSize = 1000;

% Framerate for playback of the video file
fps = 25;

%% Initialize
% Get dataset info
info = h5info(dataPath, dset);
shape = info.Dataspace.Size;
numFrames = shape(end);

% Check if file already exists
if exist(savePath,'file') > 0
    warning(['Overwriting existing video file: ' savePath])
    delete(savePath)
end

% Open video for writing
writer = VideoWriter(savePath,'MPEG-4'); % use this for MP4s
% writer = VideoWriter(savePath,'Motion JPEG AVI'); % use this for AVIs

% Set compression quality (higher = bigger file, better quality)
writer.Quality = 100;

% Set playback speed in frames/second
writer.FrameRate = fps;

% Open file for writing
writer.open();

%% Save
framesWritten = 0;
done = false;
t0 = tic;
while ~done
    % Check how many frames to read
    chunkFrames = min(chunkSize, numFrames - framesWritten);
    
    % Read chunk
    chunk = h5read(dataPath,dset,[1 1 1 framesWritten+1], [inf inf inf chunkFrames]);
    
    % Check for datatype/range concordance (floats must be in [0,1])
    if isfloat(chunk) && max(chunk(:)) > 1
        chunk = chunk / 255;
    end
    
    % Write frames
    writer.writeVideo(chunk);

    % Increment frames written counter
    framesWritten = framesWritten + size(chunk,4);
    
    % Check if we're done
    done = framesWritten >= numFrames;
end

elapsed = toc(t0);
fprintf('Finished writing %d frames in %.2f mins:\n\t%s\n', framesWritten, elapsed/60, savePath)

% Close file
writer.close();


================================================
FILE: examples/vidtohdf5.m
================================================
% Clean start!
clear all, clc

%% Parameters
% Path to input file
videoPath = '..\..\leap\data\examples\072212_163153.clip.mp4';

% Path to output file
% savePath = '..\..\leap\data\examples\072212_163153.clip.h5';
savePath = 'C:\tmp\072212_163153.clip.h5';

% Frames to convert at a time (lower this if your memory is limited)
chunkSize = 1000;

% Convert frames to single channel grayscale images (instead of 3 channel RGB)
grayscale = true;

%% Initialize
% Open video for reading
vr = VideoReader(videoPath);

% Check size from first frame
I0 = vr.readFrame();
if grayscale; I0 = rgb2gray(I0); end
frameSize = size(I0);
if numel(frameSize) == 2; frameSize = [frameSize 1]; end

% Reset VideoReader
delete(vr);
vr = VideoReader(videoPath);

% Check if file already exists
if exist(savePath,'file') > 0
    warning(['Overwriting existing HDF5 file: ' savePath])
    delete(savePath)
end

% Create HDF5 file with infinite number of frames and GZIP compression
h5create(savePath,'/box',[frameSize inf],'ChunkSize',[frameSize 1],'Deflate',1,'Datatype','uint8')

%% Save
buffer = cell(chunkSize,1);
done = false;
framesRead = 0;
framesWritten = 0;
t0 = tic;
while ~done
    % Read next frame
    I = vr.readFrame();
    if grayscale; I = rgb2gray(I); end
    
    % Check if there are any frames left
    done = ~vr.hasFrame();
    
    % Increment frames read counter and add to the write buffer
    framesRead = framesRead + 1;
    buffer{mod(framesRead-1, chunkSize)+1} = I;
    
    % Have we filled the buffer or are there no frames left?
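    if mod(framesRead, chunkSize) == 0 || done
        % Concatenate only the frames filled during this pass; the final
        % chunk may be partial and the rest of the buffer holds stale frames
        numBuffered = mod(framesRead-1, chunkSize) + 1;
        chunk = cat(4, buffer{1:numBuffered});
        
        % Extend the dataset and save to disk (explicit 4-D count so a
        % single-frame grayscale chunk is not collapsed to 2-D by size())
        chunkDims = [size(chunk,1) size(chunk,2) size(chunk,3) size(chunk,4)];
        h5write(savePath, '/box', chunk, [1 1 1 framesWritten+1], chunkDims)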
        
        % Increment frames written counter
        framesWritten = framesWritten + size(chunk,4);
    end
end
elapsed = toc(t0);
fprintf('Finished writing %d frames in %.2f mins.\n', framesWritten, elapsed/60)

h5disp(savePath)


================================================
FILE: install_leap.m
================================================
function install_leap()
%INSTALL_LEAP Installs the Python package and adds MATLAB scripts to path.
% Usage:
%   install_leap
% 
% See also: uninstall_leap, test_leap

% Find base repository path (where this file is contained)
basePath = fileparts(which('install_leap'));

% Add to MATLAB path
addpath(genpath(fullfile(basePath,'leap')));

% Check if Python package is importable
canImportPython = test_leap();
if ~canImportPython
    [status,msg] = system(['pip install -e "' basePath '"']);
    disp(msg)
end

% Check again
test_leap;
end


================================================
FILE: leap/__init__.py
================================================
from . import image_augmentation
from . import layers
from . import models
from . import predict_box
from . import training
from . import utils
from . import viz

================================================
FILE: leap/compute_errors.m
================================================
function err = compute_errors(pos_pred, pos_gt)
%COMPUTE_ERRORS Computes error rates given predicted and ground truth positions.
% Usage:
%   err = compute_errors(pos_pred, pos_gt)
% 
% Args:
%   pos_pred: predicted positions    (J x 2 x N)
%   pos_gt: ground truth positions   (J x 2 x N)
%
% Returns:
%   err: struct with error metrics
% 
% See also: 

if isstruct(pos_pred)
    pos_pred = pos_pred.positions_pred;
end
if isstruct(pos_gt)
    pos_gt = pos_gt.positions_pred;
end

% Find the difference between predicted and ground truth
delta = pos_pred - pos_gt;

% Find Euclidean distance between each pair of points
euclidean = squeeze(sqrt(sum(delta .^ 2, 2)))';

% Compute metrics overall
mae_all = mean(abs(delta(:)));
mse_all = mean(delta(:) .^ 2);
rmse_all = sqrt(mse_all);

% Compute metrics per joint
delta_rows = reshape(permute(delta,[2 3 1]),[],size(delta,1));
mae = mean(abs(delta_rows));
mse = mean(delta_rows .^ 2);
rmse = sqrt(mse);

% Return everything
err = varstruct(delta, euclidean, ...
    mae_all, mse_all, rmse_all, ...
    delta_rows, mae, mse, rmse);

end


================================================
FILE: leap/confmaps2pts.m
================================================
function [pts, confvals] = confmaps2pts(C)
%CONFMAPS2PTS Convert a set of confidence maps into a set of points.
% Usage:
%   [pts, confvals] = confmaps2pts(C)
%
% See also: pts2confmaps

numChannels = size(C,3);
pts = zeros(numChannels,2,'single');
confvals = zeros(numChannels,1,'like',C);
for i = 1:numChannels
    [confvals(i), ind] = max(vert(C(:,:,i)));
    [r,c] = ind2sub(size(C(:,:,i)),ind);
    pts(i,:) = [c r];
end

end



================================================
FILE: leap/generate_training_set.m
================================================
function savePath = generate_training_set(boxPath, varargin)
%GENERATE_TRAINING_SET Creates a dataset for training.
% Usage: generate_training_set(boxPath, ...)

t0_all = stic;
%% Setup
defaults = struct();
defaults.savePath = [];
defaults.scale = 1;
defaults.mirroring = true; % flip images and adjust confidence maps to augment dataset
defaults.horizontalOrientation = true; % animal is facing right/left if true (for mirroring)
defaults.sigma = 5; % standard deviation (px) of the Gaussian peaks in the confidence maps
defaults.normalizeConfmaps = true; % scale maps to [0,1] range
defaults.postShuffle = true; % shuffle data before saving (useful for reproducible dataset order)
defaults.testFraction = 0; % separate these data from training and validation sets
defaults.compress = false; % use GZIP compression to save the outputs

params = parse_params(varargin,defaults);

% Paths
labelsPath = repext(boxPath,'labels.mat');

% Output
savePath = params.savePath;
if isempty(savePath)
    savePath = ff(fileparts(boxPath), 'training', [get_filename(boxPath,true) '.h5']);
    savePath = get_new_filename(savePath,true);
end
mkdirto(savePath)

%% Labels
labels = load(labelsPath);

% Check for complete frames
labeledIdx = find(squeeze(all(all(~isnan(labels.positions),2),1)));
numFrames = numel(labeledIdx);
printf('Found %d/%d labeled frames.', numFrames, size(labels.positions,3))

% Pull out label data
joints = labels.positions(:,:,labeledIdx);
joints = joints * params.scale;
numJoints = size(joints,1);

% Pull out other info
jointNames = labels.skeleton.nodes;
skeleton = struct();
skeleton.edges = labels.skeleton.edges;
skeleton.pos = labels.skeleton.pos;

%% Load images
stic;
box = h5readframes(boxPath,'/box',labeledIdx);
if params.scale ~= 1; box = imresize(box,params.scale); end
boxSize = size(box(:,:,:,1));
stocf('Loaded %d images', size(box,4))

% Load metadata
try exptID = h5read(boxPath, '/exptID'); exptID = exptID(labeledIdx); catch; end
try framesIdx = h5read(boxPath, '/framesIdx'); framesIdx = framesIdx(labeledIdx); catch; end
try idxs = h5read(boxPath, '/idxs'); idxs = idxs(labeledIdx); catch; end

try L = h5read(boxPath, '/L'); L = L(labeledIdx); catch; end
try box_no_seg = imresize(h5readframes(boxPath,'/box_no_seg',labeledIdx),params.scale); catch; end
try box_raw = imresize(h5readframes(boxPath,'/box_raw',labeledIdx),params.scale); catch; end
attrs = h5att2struct(boxPath);

%% Generate confidence maps
stic;
confmaps = NaN([boxSize(1:2), numJoints, numFrames],'single');
parfor i = 1:numFrames
    pts = joints(:,:,i);
    confmaps(:,:,:,i) = pts2confmaps(pts,boxSize(1:2),params.sigma,params.normalizeConfmaps);
end
stocf('Generated confidence maps') % 15 sec for 192x192x32x500
varsize(confmaps)

%% Augment by mirroring
if params.mirroring
    % Flip images
    if params.horizontalOrientation
        box_flip = flipud(box);
        try box_no_seg_flip = flipud(box_no_seg); catch; end
        try box_raw_flip = flipud(box_raw); catch; end
        confmaps_flip = flipud(confmaps);
        joints_flip = joints; joints_flip(:,2,:) = size(box,1) - joints_flip(:,2,:);
    else
        box_flip = fliplr(box);
        try box_no_seg_flip = fliplr(box_no_seg); catch; end
        try box_raw_flip = fliplr(box_raw); catch; end
        confmaps_flip = fliplr(confmaps);
        joints_flip = joints; joints_flip(:,1,:) = size(box,2) - joints_flip(:,1,:);
    end

    % Check for *L/*R naming pattern (e.g., {{'wingL','wingR'}, {'legR1','legL1'}})
    swap_names = {};
    baseNames = regexp(jointNames,'(.*)L([0-9]*)$','tokens');
    isSymmetric = ~cellfun(@isempty,baseNames);
    for i = horz(find(isSymmetric))
        nameR = [baseNames{i}{1}{1} 'R' baseNames{i}{1}{2}];
        if ismember(nameR,jointNames)
            swap_names{end+1} = {jointNames{i}, nameR};
        end
    end

    % Swap channels accordingly
    printf('Symmetric channels:')
    for i = 1:numel(swap_names)
        [~,swap_idx] = ismember(swap_names{i}, jointNames);
        if any(swap_idx == 0); continue; end
        printf('    %s (%d) <-> %s (%d)', jointNames{swap_idx(1)}, swap_idx(1), ...
            jointNames{swap_idx(2)}, swap_idx(2))

        joints_flip(swap_idx,:,:) = joints_flip(fliplr(horz(swap_idx)),:,:);
        confmaps_flip(:,:,swap_idx,:) = confmaps_flip(:,:,fliplr(horz(swap_idx)),:);
    end

    % Merge
    [box,flipped] = cellcat({box,box_flip},4);
    joints = cat(3, joints, joints_flip);
    try box_raw = cat(4,box_raw,box_raw_flip); catch; end
    try box_no_seg = cat(4,box_no_seg,box_no_seg_flip); catch; end
    confmaps = cat(4, confmaps, confmaps_flip);

    labeledIdx = [labeledIdx(:); labeledIdx(:)];
    try exptID = [exptID(:); exptID(:)]; catch; end
    try framesIdx = [framesIdx(:); framesIdx(:)]; catch; end
    try idxs = [idxs(:); idxs(:)]; catch; end
    
    % Update frame count
    numFrames = size(box,4);
end

%% Post-shuffle
shuffleIdx = vert(1:numFrames);
if params.postShuffle
    shuffleIdx = randperm(numFrames);
    box = box(:,:,:,shuffleIdx);
    labeledIdx = labeledIdx(shuffleIdx);
    try box_no_seg = box_no_seg(:,:,:,shuffleIdx); catch; end
    try box_raw = box_raw(:,:,:,shuffleIdx); catch; end
    try exptID = exptID(shuffleIdx); catch; end
    try framesIdx = framesIdx(shuffleIdx); catch; end
    joints = joints(:,:,shuffleIdx);
    confmaps = confmaps(:,:,:,shuffleIdx);
end

%% Separate testing set
numTestFrames = round(numel(shuffleIdx) * params.testFraction);
if numTestFrames > 0
    testIdx = randperm(numel(shuffleIdx),numTestFrames);
    trainIdx = setdiff(shuffleIdx, testIdx);

    % Test set
    testing = struct();
    testing.shuffleIdx = shuffleIdx(testIdx);
    testing.box = box(:,:,:,testIdx);
    testing.labeledIdx = labeledIdx(testIdx);
    try testing.box_no_seg = box_no_seg(:,:,:,testIdx); catch; end
    try testing.box_raw = box_raw(:,:,:,testIdx); catch; end
    try testing.exptID = exptID(testIdx); catch; end
    try testing.framesIdx = framesIdx(testIdx); catch; end
    testing.joints = joints(:,:,testIdx);
    testing.confmaps = confmaps(:,:,:,testIdx);
    testing.testIdx = testIdx;

    % Training set
    shuffleIdx = shuffleIdx(trainIdx);
    box = box(:,:,:,trainIdx);
    labeledIdx = labeledIdx(trainIdx);
    try box_no_seg = box_no_seg(:,:,:,trainIdx); catch; end
    try box_raw = box_raw(:,:,:,trainIdx); catch; end
    try exptID = exptID(trainIdx); catch; end
    try framesIdx = framesIdx(trainIdx); catch; end
    joints = joints(:,:,trainIdx);
    confmaps = confmaps(:,:,:,trainIdx);
end

%% Save
% Augment metadata
attrs.createdOn = datestr(now);
attrs.boxPath = boxPath;
attrs.labelsPath = labelsPath;
attrs.scale = params.scale;
attrs.postShuffle = uint8(params.postShuffle);
attrs.horizontalOrientation = uint8(params.horizontalOrientation);

% Write
stic;
if exists(savePath); delete(savePath); end

% Training data
h5save(savePath,box,[],'compress',params.compress)
h5save(savePath,labeledIdx)
h5save(savePath,shuffleIdx)
try h5save(savePath,box_no_seg,[],'compress',params.compress); catch; end
try h5save(savePath,box_raw,[],'compress',params.compress); catch; end
try h5save(savePath,exptID); catch; end
try h5save(savePath,framesIdx); catch; end
h5save(savePath,joints,[],'compress',params.compress)
h5save(savePath,confmaps,[],'compress',params.compress)

% Testing data
if numTestFrames > 0
    h5save(savePath,trainIdx)
    h5savegroup(savePath,testing,'/testing','compress',params.compress)
end

% Metadata
h5writeatt(savePath,'/confmaps','sigma',params.sigma)
h5writeatt(savePath,'/confmaps','normalize',uint8(params.normalizeConfmaps))
h5struct2att(savePath,'/',attrs)
h5savegroup(savePath,skeleton,'/skeleton')
h5writeatt(savePath,'/skeleton','jointNames',strjoin(jointNames,'\n'))

stocf('Saved:\n%s', savePath)
get_filesize(savePath)


stocf(t0_all, 'Finished generating training set.');
end

================================================
FILE: leap/graph2paf.m
================================================
function paf = graph2paf(nodes, edges, sz, channelsOnly, sigma)
%GRAPH2PAF Converts a set of edges into part affinity fields.
% Usage:
%   paf = graph2paf(nodes, edges, sz, channelsOnly, sigma)
% 
% Args:
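%   nodes: set of points as [x y] coordinates (N x 2)
%   edges: indices into nodes defining directed edges (E x 2); each field
%          points from nodes(edges(i,2),:) toward nodes(edges(i,1),:)
%   sz: grid/image size as [rows cols] (1 x 2)
%   channelsOnly: stack all PAFs along channels (dim 3) instead of dim 4 (default: true)
%   sigma: maximum perpendicular distance (px) from the edge within which the field is nonzero (default: 5)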
% 
% Returns:
%   paf: part affinity fields (sz(1) x sz(2) x 2E) or (sz(1) x sz(2) x 2 x E)
% 
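% Example (illustrative values, not from the original docs):
%   nodes = [10 20; 30 20];                 % two points as [x y]
%   edges = [1 2];                          % a single edge
%   paf = graph2paf(nodes, edges, [40 40]); % size(paf) is [40 40 2]
%   squeeze(paf(20,20,:))'                  % [-1 0]: unit vector from node 2 toward node 1
%   paf(20,35,1)                            % 0: x = 35 lies beyond the edge segment
% 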
% See also: pts2confmaps

if nargin < 4 || isempty(channelsOnly); channelsOnly = true; end
if nargin < 5 || isempty(sigma); sigma = 5; end

% Create image coordinate grid
[XX,YY] = meshgrid(1:sz(2), 1:sz(1));

% Create PAFs for each edge
E = size(edges,1);
paf = cell(E,1);
for i = 1:E
    % Pull out edge points
    src = nodes(edges(i,2),:);
    dst = nodes(edges(i,1),:);
    
    % Edge length
    L = norm(dst - src, 2);

    % Unit vectors
    V = (dst - src) ./ L; % pointing along edge
    Vp = [-V(:,2), V(:,1)]; % perpendicular

    % Signed distance along edge
    D1 = sum(V .* ([XX(:) YY(:)] - src),2);

    % Absolute distance orthogonal to edge
    D2 = abs(sum(Vp .* ([XX(:) YY(:)] - src),2));

    % Vector field mask
    paf_mask = reshape(D1 >= 0 & D1 <= L & D2 <= sigma, sz);

    % Create vector field along channels (X and Y)
    paf{i} = paf_mask .* permute(V, [1 3 2]);
end

% Merge all edge PAFs
if channelsOnly
    paf = cat(3, paf{:});
else
    paf = cat(4, paf{:});
end

end


================================================
FILE: leap/guis/label_joints.m
================================================
function label_joints(boxPath, skeletonPath)
%LABEL_JOINTS GUI for labeling joint positions by clicking on images.
% Usage:
%   label_joints(boxPath)
%   label_joints(boxPath, skeletonPath)
%
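%
% Hotkeys (summarized from the keyPress callback below; see it for exact
% modifier handling):
%   q                 close all windows
%   s                 save labels
%   r / shift+r       reset selected node / all nodes
%   d / shift+d       set selected node / all nodes to default positions
%   tab / shift+tab   select next / previous node
%   arrow keys        with "Alt + arrow keys move markers" on (default),
%                     alt+arrows nudge the selected node and left/right
%                     change frames; with it off, these roles are swapped.
%                     Shift = steps of 5, ctrl = nudge the whole segment.
%   space             go to the first unlabeled frame (shift: consider the
%                     selected node only; ctrl: random unlabeled frame)
%   g                 go to frame by index
%   f                 mark all nodes correct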
% See also: make_template_skeleton



%% Startup
% addpath(genpath('deps'))

% Ask for path to data file
if nargin < 1 || isempty(boxPath); boxPath = uibrowse('*.h5',[],'Select box HDF5 file'); end

% Params
if nargin < 2; skeletonPath = []; end
recreate_labels = nargin > 1; % force recreate labels file

% Settings (saved in *.labels.mat file)
global config;
config = struct();
config.dsetName = '/box';
config.nodeSize = 10; % size of draggable markers
config.defaultNodeColor = [1 0 0]; % default color of movable nodes
config.initializedNodeColor = [1 1 0]; % color of initialized nodes
config.labeledNodeColor = [0 1 0]; % color of movable nodes with user input
config.initialFrame = 1; % first frame displayed
config.shuffleFrames = false; % shuffle frame order
config.autoSave = true; % save before going to a new frame
config.clickNearest = false; % true = click moves nearest node; false = selected node
config.draggable = true; % false = cannot drag joint markers
config.altArrowsToMoveNodes = true; % false = arrow keys move nodes, alt+arrows change frames
config.zoomBoxFrames = [-250, 250]; % frame offsets (pre, post) spanned by the zoomed status bar
config.imgFigPos = [835 341 709 709]; % main labeling figure window
config.ctrlFigPos = [1545 342 374 708]; % control/reference window
config.statusFigPos = [836 33 1081 277]; % status bars and settings window

%%
% Initialize labeling session
box = [];
numNodes = [];
numFrames = [];
numLabeled = [];
global labels;

% Loads or creates *.labels.mat and populates config
initializeLabels();

% Pre-shuffle frames for shuffle mode
shuffleIdx = randperm(numFrames);

% Set status colormap colors
statusCmap = {
    config.defaultNodeColor
    config.initializedNodeColor
    config.labeledNodeColor
    };
for k = 1:numel(statusCmap)
    if ischar(statusCmap{k}); statusCmap{k} = colorCode2rgb(statusCmap{k}); end
end
statusCmap = cellcat(statusCmap,1);

% Zoom box convenience: a scalar size becomes symmetric (pre, post) offsets, e.g. 500 -> [-250 250]
if isscalar(config.zoomBoxFrames); config.zoomBoxFrames = round([-0.5 0.5] .* config.zoomBoxFrames); end
zoomBoxWindow = config.zoomBoxFrames(1):config.zoomBoxFrames(2);

    function initializeLabels()
        labels = struct();

        % Metadata
        labels.boxPath = boxPath;
        labels.savePath = repext(boxPath, '.labels.mat');

        % Ask for path to skeleton file if one is needed but was not provided
        if isequal(skeletonPath,true) || ...
                (isempty(skeletonPath) && (~exists(labels.savePath) || recreate_labels))
            skeletonPath = uibrowse('*.mat',[],'Select skeleton MAT file');
        end

        % Open box file
        box = h5file(boxPath, config.dsetName);
        numFrames = size(box,4);

        stic;
        if ~exists(labels.savePath) || recreate_labels
            % Load template skeleton
            labels.skeletonPath = skeletonPath;
            labels.skeleton = load(skeletonPath);

            % Initialize custom defaults container
            labels.initialization = NaN(numel(labels.skeleton.nodes), 2, numFrames, 'single');

            % Try using initialization built into the HDF5 file
            try
                labels.initialization = h5read(boxPath, '/initialization');
                labels.initialization_metadata = h5att2struct(boxPath, '/initialization');

                printf('Using pre-initialized joint predictions.')
            catch
            end

            % Initialize user labels
            labels.positions = NaN(numel(labels.skeleton.nodes), 2, numFrames, 'single');

            % Settings
            labels.config = config;

            % Timestamps
            labels.createdOn = datestr(now);
            labels.lastModified = datestr(now);

            % Initialize history
            labels.session = 1;
            addToHistory("Created labels file.");

            % Create labels file
            save(labels.savePath, '-struct', 'labels', '-v7.3')
            stocf('Created labels file: %s', labels.savePath)
        else
            % Load
            labels = load(labels.savePath);

            % Update paths
            labels.boxPath = boxPath;
            labels.savePath = repext(boxPath, '.labels.mat');

            % Update config
            if isfield(labels,'config')
                config = parse_params(labels.config,config);
            else
                labels.config = config;
            end

            if ~isfield(labels,'session')
                labels.session = 1;
            else
                labels.session = labels.session + 1;
            end

            stocf('Loaded existing labels file: %s', labels.savePath)
        end
        addToHistory('Started session.')

        % Convenience
        numNodes = numel(labels.skeleton.nodes);
    end

    function addToHistory(message)
    % Utility for adding a timestamped message to the history log

        session = labels.session;
        timestamp = datetime();
        message = string(message);
        historyItem = table(session, timestamp, message);
        disp(historyItem)

        if ~isfield(labels,'history') || isempty(labels.history)
            labels.history = historyItem;
        else
            labels.history = [labels.history; historyItem];
        end
    end

%% GUI
% Build GUI
global ui;
initializeGUI();
    function initializeGUI()
        ui = struct();

        % %%%% Controls figure %%%%
        ui.ctrl = struct();
        ui.ctrl.fig = figure('NumberTitle','off','MenuBar','none', ...
            'Name','LEAP Label GUI', 'WindowKeyPressFcn', @keyPress, 'DeleteFcn', @quit, ...
            'Position', config.ctrlFigPos);
        ui.ctrl.hbox = uix.HBox('Parent', ui.ctrl.fig);

        % Joints panel
        ui.ctrl.jointsPanel = uix.Panel('Parent',ui.ctrl.hbox, 'Title', 'Joints', 'Padding',5);
        ui.ctrl.jointsList = uicontrol(ui.ctrl.jointsPanel, 'Style', 'listbox', 'String', labels.skeleton.joints.name, ...
            'Callback',@(h,~,~)selectNode(h.Value));

        % Reference image
        ui.ctrl.refPanel = uix.Panel('Parent',ui.ctrl.hbox, 'Title', 'Reference', 'Padding',5);
        ui.ctrl.refAx = axes(uicontainer('Parent',ui.ctrl.refPanel));
        ui.ctrl.refImg = imagesc(labels.skeleton.refI);

        % Style
        ui.ctrl.refAx.Units = 'normalized';
        ui.ctrl.refAx.Position = [0 0 1 1];
        axis(ui.ctrl.refAx,'equal','tight','ij')
        colormap(ui.ctrl.refAx,'gray')
        noticks(ui.ctrl.refAx)
        hold(ui.ctrl.refAx,'on')

        % Plot reference skeleton
        for i = 1:size(labels.skeleton.segments,1)
            % Find default position of each node in the segment
            pos = labels.skeleton.pos(labels.skeleton.segments.joints_idx{i},:);

            % Plot
            plot(ui.ctrl.refAx, pos(:,1), pos(:,2), '.-', 'Color',labels.skeleton.segments.color{i}, 'LineWidth', 1);
        end

        % Draw each joint node
        ui.ctrl.refNodes = gobjects(height(labels.skeleton.joints),1);
        for i = 1:numel(ui.ctrl.refNodes)
            pos = labels.skeleton.joints.pos(i,:);
            ui.ctrl.refNodes(i) = plot(ui.ctrl.refAx, pos(1),pos(2),'o', 'Color','r');
        end

        % Set box widths
        ui.ctrl.hbox.Widths = [-1 -3];
        %%%%

        % %%%% Image figure %%%%
        ui.img = struct();
        ui.img.fig = figure('NumberTitle','off','MenuBar','none','ToolBar','none', ...
            'Name',sprintf('Frame %d/%d', config.initialFrame, numFrames), 'WindowKeyPressFcn', @keyPress, 'DeleteFcn', @quit, ...
            'Position', config.imgFigPos);
        ui.img.ax = axes(ui.img.fig);
        ui.img.img = imagesc(ui.img.ax, box(:,:,:,1));
        ui.img.img.ButtonDownFcn = @(~,~) clickImage();

        % Full figure image axes
        ui.img.ax.Units = 'normalized';
        ui.img.ax.Position = [0 0 1 1];

        % Style
        axis(ui.img.ax,'equal','tight','ij')
        colormap(ui.img.ax,'gray')
        noticks(ui.img.ax)
        hold(ui.img.ax,'on')

        % Initialize skeleton drawing container
        ui.skel = struct();
        ui.skel.segs = [];
        ui.skel.nodes = [];
        %%%%


        % %%%% Status figure %%%%
        % Initialize status container
        ui.status = struct();
        ui.status.selectedNode = [];
        ui.status.movedNodes = false(numNodes,1);
        ui.status.currentFrame = config.initialFrame;
        ui.status.unsavedChanges = false(numFrames,1);
        ui.status.initialPos = [];

        % Get full status indicators for all frames
        status = getStatus();
        numInitialized = sum(all(status == 1,1));
        numLabeled = sum(all(status == 2,1));

        % Create figure window
        ui.status.fig = figure('NumberTitle','off','MenuBar','none','ToolBar','none', ...
            'Name',sprintf('Status: %d/%d (%.2f%%) labeled', numLabeled, numFrames, numLabeled/numFrames*100), ...
            'WindowKeyPressFcn', @keyPress, 'DeleteFcn', @quit, ...
            'Position', config.statusFigPos);
        ui.status.hbox = uix.HBox('Parent', ui.status.fig, 'Padding',3);

        % Status panel (left)
        ui.status.statusPanel = uix.Panel('Parent',ui.status.hbox, 'Title','Status', 'Padding',5);
        ui.status.statusBoxes = uix.VBox('Parent', ui.status.statusPanel);

        % Status text
        ui.status.stats = uix.VBox('Parent',ui.status.statusBoxes);

        ui.status.framesInitialized = uicontrol(ui.status.stats,'Style','text','HorizontalAlignment','left',...
            'String',sprintf('Initialized: %d/%d (%.2f%%)', numInitialized, numFrames, numInitialized/numFrames*100));
        ui.status.framesLabeled = uicontrol(ui.status.stats,'Style','text','HorizontalAlignment','left',...
            'String',sprintf('Labeled: %d/%d (%.2f%%)', numLabeled, numFrames, numLabeled/numFrames*100));
        ui.status.stats.Heights = ones(1, numel(ui.status.stats.Children)) * 15;

        % Status bars
        ui.status.fullAx = axes(uicontainer('Parent',ui.status.statusBoxes));
        ui.status.fullImg = imagesc(ui.status.fullAx, 1:numFrames, 1:numNodes, status, 'ButtonDownFcn', @clickStatusbar);
        axis(ui.status.fullAx,'tight','ij')
        hold(ui.status.fullAx,'on');
        zoomBoxIdx = zoomBoxWindow + ui.status.currentFrame;
        zoomBoxPts = [
            zoomBoxIdx(1) 0
            zoomBoxIdx(end) 0
            zoomBoxIdx(end) numNodes
            zoomBoxIdx(1) numNodes
            zoomBoxIdx(1) 0
            ];
        ui.status.fullZoomBox = patch(ui.status.fullAx, zoomBoxPts(:,1),zoomBoxPts(:,2),'w','PickableParts','none');
        ui.status.fullZoomBox.FaceAlpha = 0.25;
        ui.status.fullZoomBox.EdgeColor = 'w';
        colormap(ui.status.fullAx, statusCmap)
        caxis(ui.status.fullAx,[0 2])
        ui.status.fullAx.XLim = [-0.5 0.5] + [1 numFrames];
        ui.status.fullAx.YLim = [-0.5 0.5] + [1 numNodes];
%         ui.status.fullAx.YTick = 1:numNodes;
%         ui.status.fullAx.YTickLabel = labels.skeleton.nodes;
%         ui.status.fullAx.YAxis.TickLabelInterpreter = 'none';

        % Status bars (zoomed)
        ui.status.zoomAx = axes(uicontainer('Parent',ui.status.statusBoxes));
        ui.status.zoomImg = imagesc(ui.status.zoomAx, zoomBoxIdx, 1:numNodes, zeros(numNodes,numel(zoomBoxIdx)), 'ButtonDownFcn', @clickStatusbar);
        axis(ui.status.zoomAx,'tight','ij')
        colormap(ui.status.zoomAx, statusCmap)
        caxis(ui.status.zoomAx,[0 2])
        ui.status.zoomAx.YLim = [-0.5 0.5] + [1 numNodes];
%         ui.status.zoomAx.YTick = 1:numNodes;
%         ui.status.zoomAx.YTickLabel = labels.skeleton.nodes;
%         ui.status.zoomAx.YAxis.TickLabelInterpreter = 'none';

        % Set UI heights
        ui.status.statusBoxes.Heights = [sum(ui.status.stats.Heights)+5 -1 -1];

        % Settings panel (right)
        ui.status.configPanel = uix.Panel('Parent',ui.status.hbox, 'Title','Settings','Padding',5);
        ui.status.configButtons = uix.VBox('Parent',ui.status.configPanel);

        % Auto-save
        uicontrol(ui.status.configButtons,'Style','checkbox','Value',config.autoSave, ...
            'Callback',@(h,~)setConfig('autoSave',h.Value), ...
            'String','Autosave labels','TooltipString','Automatically saves changes to disk when changing frames or exiting.');

        % Shuffle frame order
        uicontrol(ui.status.configButtons,'Style','checkbox','Value',config.shuffleFrames, ...
            'Callback',@(h,~)setConfig('shuffleFrames',h.Value), ...
            'String','Shuffle frame order','TooltipString','Shuffled order is fixed within this session. Uncheck to use file ordering.');

        % Click nearest
        uicontrol(ui.status.configButtons,'Style','checkbox','Value',config.clickNearest, ...
            'Callback',@(h,~)setConfig('clickNearest',h.Value), ...
            'String','Click to move nearest joint','TooltipString','If unchecked, clicking on the image moves the currently selected joint.');

        % Draggable markers
        uicontrol(ui.status.configButtons,'Style','checkbox','Value',config.draggable, ...
            'Callback', @(h,~)toggleDraggableMarkers(h.Value), ...
            'String','Draggable markers','TooltipString','If unchecked, joint markers can only be moved by clicking or keyboard.');

        % Alt + arrows to move nodes
        uicontrol(ui.status.configButtons,'Style','checkbox','Value',config.altArrowsToMoveNodes, ...
            'Callback', @(h,~)setConfig('altArrowsToMoveNodes',h.Value), ...
            'String','Alt + arrow keys move markers','TooltipString','If unchecked, arrow keys move markers and Alt + arrow keys change frames.');

        % Export training set
        uicontrol(ui.status.configButtons,'Style','pushbutton', ...
            'Callback', @(h,~)generateTrainingSet(), ...
            'String','Generate training set','TooltipString','Exports the current labels and confidence maps as a training set for a network.');

        % Fast training
        uicontrol(ui.status.configButtons,'Style','pushbutton', ...
            'Callback', @(h,~)fastTrain(), ...
            'String','Fast train network','TooltipString','Trains a network for initialization using fast presets.');

        % Initialization from predictions
        uicontrol(ui.status.configButtons,'Style','pushbutton', ...
            'Callback', @(h,~)predictInitializations(), ...
            'String','Initialize with trained model','TooltipString','Generates predictions for all frames and uses them as initialization.');


        % Set UI sizes
        ui.status.configButtons.Heights = ones(1, numel(ui.status.configButtons.Children)) * 25;
        ui.status.hbox.Widths = [-1 175];

        % Give focus back to main image window
        figure(ui.img.fig);
    end
    function toggleDraggableMarkers(TF)
    % Sets whether joint markers are draggable using the mouse

        if TF
            set(ui.skel.nodes,'PickableParts','visible');
            draggable(ui.skel.nodes, @nodesMoved, 'endFcn', @nodesMoveEnd);
        else
            draggable(ui.skel.nodes, 'off');
            set(ui.skel.nodes,'PickableParts','none');
        end
        setConfig('draggable',TF);
    end
    function setConfig(configField, value)
    % Helper to set config fields to specified value
        config.(configField) = value;
    end

    function quit(h,~)
    % Quit callback to close all windows simultaneously
        % Log to history
        addToHistory("Finished session.")

        % Save
        if config.autoSave && isequal(h, ui.img.fig)
            saveLabels();
        end

        % Delete figs
        delete(ui.img.fig)
        delete(ui.ctrl.fig)
        delete(ui.status.fig)
    end

    function keyPress(~,evt)
    % Hotkeys
        % exclusive modifier flags:
        noModifier = isempty(evt.Modifier);
        shiftOnly = isequal(evt.Modifier, {'shift'});
        ctrlOnly = isequal(evt.Modifier, {'control'});
        altOnly = isequal(evt.Modifier, {'alt'});

        % non-exclusive:
        altPressed = ismember({'alt'}, evt.Modifier);
        ctrlPressed = ismember({'control'}, evt.Modifier);
        shiftPressed = ismember({'shift'}, evt.Modifier);

        switch evt.Key
            case 'q'
                delete(ui.img.fig)
            case 's'
                saveLabels()
            case 'r'
                if noModifier % current node
                    resetNodes(ui.status.selectedNode);
                elseif shiftOnly % all nodes
                    resetNodes();
                end
            case 'd'
                if noModifier % current node
                    setNodesToDefault(ui.status.selectedNode);
                elseif shiftOnly % all nodes
                    setNodesToDefault();
                end
            case 'tab'
                if noModifier
                    selectNode(mod(ui.status.selectedNode-1+1, numNodes) + 1);
                elseif shiftOnly
                    selectNode(mod(ui.status.selectedNode-1-1, numNodes) + 1);
                end
            case 'downarrow'
                if (config.altArrowsToMoveNodes && altPressed) || ~config.altArrowsToMoveNodes
                    % Use non-exclusive flags so this also works while Alt is held
                    dXY = [0 1] * (1 + shiftPressed * 4); % shift = 5 px steps
                    if ctrlPressed; nudgeSegment(dXY);
                    else; nudgeNode(dXY); end
                end
            case 'uparrow'
                if (config.altArrowsToMoveNodes && altPressed) || ~config.altArrowsToMoveNodes
                    % Use non-exclusive flags so this also works while Alt is held
                    dXY = [0 -1] * (1 + shiftPressed * 4); % shift = 5 px steps
                    if ctrlPressed; nudgeSegment(dXY);
                    else; nudgeNode(dXY); end
                end
            case 'leftarrow'
                if (config.altArrowsToMoveNodes && altPressed) || (~config.altArrowsToMoveNodes && ~altPressed)
                    dXY = [-1 0] * (1 + shiftPressed * 4); % shift = 5 px steps
                    if ctrlPressed; nudgeSegment(dXY);
                    else; nudgeNode(dXY); end
                else
                    dt = -1 - (shiftPressed * 4);
                    if config.shuffleFrames
                        idx = find(shuffleIdx == ui.status.currentFrame);
                        goToFrame(shuffleIdx(mod(idx-1+dt, numFrames) + 1))
                    else
                        goToFrame(mod(ui.status.currentFrame-1+dt, numFrames) + 1)
                    end
                end
            case 'rightarrow'
                if (config.altArrowsToMoveNodes && altPressed) || (~config.altArrowsToMoveNodes && ~altPressed)
                    dXY = [1 0] * (1 + shiftPressed * 4); % shift = 5 px steps
                    if ctrlPressed; nudgeSegment(dXY);
                    else; nudgeNode(dXY); end
                else
                    dt = 1 + (shiftPressed * 4);
                    if config.shuffleFrames
                        idx = find(shuffleIdx == ui.status.currentFrame);
                        goToFrame(shuffleIdx(mod(idx-1+dt, numFrames) + 1))
                    else
                        goToFrame(mod(ui.status.currentFrame-1+dt, numFrames) + 1)
                    end
                end
            case 'space'
                % Get labeling status for all frames
                labeled = getStatus() == 2;

                % Consider current joint only if shift is pressed
                if shiftPressed; labeled = labeled(ui.status.selectedNode,:);
                else; labeled = all(labeled,1); end

                % Find unlabeled frames excluding current frame
                unlabeledIdxs = setdiff(find(~labeled), ui.status.currentFrame);

                if ~isempty(unlabeledIdxs)
                    if ctrlPressed
                        % Go to random unlabeled frame
                        goToFrame(datasample(unlabeledIdxs,1));
                    else
                        % Go to first unlabeled frame
                        goToFrame(unlabeledIdxs(1))
                    end

                end
            case 'g'
                % go to frame dialog
%                 if ctrlOnly
                % TODO:
                %   - custom dialog box that starts focused on the textbox
                %     and returns after pressing Enter/Esc
                answer = inputdlg('Skip to frame index:','Skip to frame',1,{num2str(ui.status.currentFrame)});
                try
                    idx = round(str2double(answer));
                    if idx >= 1 && idx <= numFrames
                        goToFrame(idx);
                    end
                catch
                end
%                 end
            case 'f'
                markAllCorrect();
            otherwise
%                 evt
        end
    end

    function clickImage()
    % Callback for clicks on the image (but not on nodes)
        % Pull out clicked point coordinate
        pt = ui.img.ax.CurrentPoint(1,1:2);

        % Get current node positions
        pos = getNodePositions();

        if config.clickNearest
            % Find nearest node location
            i = argmin(rownorm(pos - pt));
        else
            % Use current selection
            i = ui.status.selectedNode;
        end

        % Update node position
        pos(i,:) = pt;
        updateSkeleton(pos);

    end

    function clickStatusbar(h,evt)
    % Callback for seeking via mouse-click on the status bars
        if evt.Button == 1
            idx = clip(round(evt.IntersectionPoint(1)),[1 numFrames]);
            goToFrame(idx);
        end
    end

    function status = getStatus(idx)
    % Utility function that checks labels for completeness status
    % Returns [numJoints x numel(idx)] matrix with values:
    %   0: default
    %   1: initialized
    %   2: labeled

        % Get status for all frames by default
        if nargin < 1; idx = 1:numFrames; end

        % Initialize as default (0)
        status = zeros(numNodes, numel(idx));

        % Check for initialization
        isInitialized = squeeze(all(~isnan(labels.initialization(:,:,idx)),2));
        status(isInitialized) = 1;

        % Check for user labels
        isLabeled = squeeze(all(~isnan(labels.positions(:,:,idx)),2));
        status(isLabeled) = 2;
    end

%% Training and dataset generation
    function predictInitializations(modelPath)
    % Generates predictions for the entire dataset and uses those for
    % initialization of unlabeled frames.

        if nargin < 1 || isempty(modelPath)
            modelPath = uibrowse([],[],'Select model folder...', 'dir');
            if isempty(modelPath) || ~exists(modelPath); return; end
        end

        % TODO: better system for choosing final vs best validation model
%         if exists(ff(modelPath, 'final_model.h5'))
%             numValidationSamples = numel(loadvar(ff(modelPath,'training_info.mat'),'val_idx'));
% %             numWeights = numel(dir_files(ff(modelPath,'weights')));
%             if numValidationSamples < 500
%                 modelPath = ff(modelPath,'final_model.h5');
%             end
%         end
        numValidationSamples = numel(loadvar(ff(modelPath,'training_info.mat'),'val_idx'));
        if exists(ff(modelPath, 'best_model.h5')) && numValidationSamples > 500
            modelPath = ff(modelPath, 'best_model.h5');
        else
            modelPath = ff(modelPath, 'final_model.h5');
        end

        % Predict
        preds = predict_box(boxPath, modelPath, false);

        % Save
        labels.initialization = preds.positions_pred;
        saveLabels();

        % Update status
        isInitialized = squeeze(all(~isnan(labels.initialization),2));
        numInitialized = sum(all(isInitialized,1));
        ui.status.framesInitialized.String = sprintf('Initialized: %d/%d (%.2f%%)', numInitialized, numFrames, numInitialized/numFrames*100);

        % Update status bars
        status = getStatus();
        ui.status.fullImg.CData = status;
        zoom_idx = ui.status.zoomImg.XData > 0 & ui.status.zoomImg.XData <= size(status,2);
        ui.status.zoomImg.CData(:,zoom_idx) = status(:,ui.status.zoomImg.XData(zoom_idx));

        % Log event
        addToHistory(['Initialized with model: ' modelPath])

        % Calculate error rate on labels
        labeled = all(getStatus() == 2,1);
        pos_gt = labels.positions(:,:,labeled);
        pos_pred = labels.initialization(:,:,labeled);
        pred_metrics = compute_errors(pos_pred,pos_gt);

        % Display errors
        printf('Error: mean = %.2f, s.d. = %.2f', mean(pred_metrics.euclidean(:)), std(pred_metrics.euclidean(:)))
        prcs = [50 75 90];
        prc_errs = prctile(pred_metrics.euclidean(:), prcs);
        for i = 1:numel(prcs)
            printf('       %d%% = %.3f', prcs(i), prc_errs(i))
        end

        % Replot
        goToFrame(ui.status.currentFrame);
    end
    function generateTrainingSet()
        
        % Default save path
        defaultSavePath = ff(fileparts(boxPath), 'training', [get_filename(boxPath,true) '.h5']);
        defaultSavePath = get_new_filename(defaultSavePath,true);
        
        % Create dialog with parameters
        [params, buttonPressed] = settingsdlg(...
            'WindowWidth', 400,...
            'title','Generate a training set', ...
            'Description','Export a dataset for training based on the current labels.',...
            'separator','General options',...
            {'Save path';'savePath'}, defaultSavePath, ...
            {'Scale - for resizing images';'scale'}, 1, ...
            {'Sigma - kernel size for confidence maps';'sigma'}, 5, ...
            {'Test set fraction - held out frames';'testFraction'},0.1,...
            {'Shuffle - randomize saved dataset order';'postShuffle'}, true, ...
            {'Compress - reduce file size, but slower to load';'compress'}, true, ...
            'separator','Data mirroring',...
            {'Mirror images - augment by flipping along the body axis';'mirroring'}, [true, false], ...
            {'Animal orientation';'animalOrientation'},  {'left/right','top/bottom'} ...
            );
        
        % Cancel if OK was not pressed (cancel or window closed)
        if ~strcmpi(buttonPressed,'ok'); return; end
        
        % Convert listbox input to boolean for orientation
        params.horizontalOrientation = strcmpi(params.animalOrientation,'left/right');
        
        % Check for existing save path
        if exists(params.savePath)
            answer = questdlg('Save path already exists, overwrite existing file?', 'Overwrite file', 'Overwrite', 'Cancel', 'Overwrite');
            if ~strcmpi(answer, 'Overwrite'); return; end
        end
        
        % Run!
        generate_training_set(boxPath,params);
        
        % Log action
        addToHistory('Generated training set.')
    end
    function fastTrain()
        % Generate a training set for fast training from current labels
        
        % Build default output path
        runName = sprintf('%s-n=%d', datestr(now,'yymmdd_HHMMSS'), numLabeled);
        defaultModelsFolder = ff(fileparts(boxPath), 'models');
        
        % Create dialog with parameters
        [params, buttonPressed] = settingsdlg(...
            'WindowWidth', 500,...
            'title','Fast training', ...
            'Description','Quickly train a model using current labels and predict on remaining frames as initialization.',...
            'separator','Dataset',...
            {'Scale - for resizing images';'scale'}, 1, ...
            {'Sigma - kernel size for confidence maps';'sigma'}, 5, ...
            'separator','Data mirroring',...
            {'Mirror images - augment by flipping along the body axis';'mirroring'}, [true, false], ...
            {'Animal orientation';'animalOrientation'}, {'left/right','top/bottom'}, ...
            'separator','Model',...
            {'Network architecture';'netName'},{'leap_cnn','hourglass','stacked_hourglass'},...
            {'Filters - base number of filters for model';'filters'},32,...
            {'Upsampling layers - use bilinear upsampling instead of transposed conv';'upsamplingLayers'},true,...
            'separator','Training',...
            {'Model path - folder to save run data to';'modelsFolder'},defaultModelsFolder,...
            {'Rotate angle - augment data via random rotations';'rotateAngle'},5,...
            {'Validation set fraction - frames used for validation';'valSize'},0.1,...
            {'Epochs - number of rounds of training';'epochs'},15,...
            {'Batch size - number of samples per batch';'batchSize'},50,...
            {'Batches per epoch - number of batches of samples per round';'batchesPerEpoch'},50,...
            {'Validation batches per epoch - number of batches to use for validation';'valBatchesPerEpoch'},10,...
            {'Save every epoch - save weights from every epoch instead of just best+final';'saveEveryEpoch'},false,...
            'separator','Training (advanced)',...
            {'Reduce LR factor - drop learning rate when loss plateaus';'reduceLRFactor'},0.1,...
            {'Reduce LR patience - wait after loss plateaus before reducing LR';'reduceLRPatience'},2,...
            {'Reduce LR cooldown - wait after reducing LR before detecting plateau';'reduceLRCooldown'},0,...
            {'Reduce LR min delta - minimum change in loss to not plateau';'reduceLRMinDelta'},1e-5,...
            {'Reduce LR min LR - minimum LR to not drop below';'reduceLRMinLR'},1e-10,...
            {'AMSGrad - optimizer variant for more emphasis on rare data';'amsgrad'},true ...
            );
        
        % Cancel if OK was not pressed (cancel or window closed)
        if ~strcmpi(buttonPressed,'ok'); return; end
        
        % Convert listbox input to boolean for orientation
        params.horizontalOrientation = strcmpi(params.animalOrientation,'left/right');
        
        % Generate temporary training set file
        dataPath = [tempname '.h5'];
        dataPath = generate_training_set(boxPath,'savePath',dataPath,...
            'scale',params.scale,...
            'mirroring',params.mirroring,...
            'horizontalOrientation',params.horizontalOrientation,...
            'sigma',params.sigma, ...
            'normalizeConfmaps',true,...
            'postShuffle',true, ...
            'testFraction',0);
        
        % Log action
        addToHistory(sprintf('Fast training (n = %d)', numLabeled))
        
        % Create CLI command for training
        basePath = fileparts(funpath(true));
        cmd = {
            'python'
            ['"' ff(basePath, 'training.py') '"']
            ['"' dataPath '"']
            ['--base-output-path="' params.modelsFolder '"']
            ['--run-name="' runName '"']
            ['--net-name="' params.netName '"']
            sprintf('--filters=%d',params.filters)
            sprintf('--rotate-angle=%d', params.rotateAngle)
            sprintf('--val-size=%.5f', params.valSize)
            sprintf('--epochs=%d', params.epochs)
            sprintf('--batch-size=%d', params.batchSize)
            sprintf('--batches-per-epoch=%d', params.batchesPerEpoch)
            sprintf('--val-batches-per-epoch=%d', params.valBatchesPerEpoch)
            sprintf('--reduce-lr-factor=%.10f', params.reduceLRFactor)
            sprintf('--reduce-lr-patience=%d', params.reduceLRPatience)
            sprintf('--reduce-lr-cooldown=%d', params.reduceLRCooldown)
            sprintf('--reduce-lr-min-delta=%.10f', params.reduceLRMinDelta)
            sprintf('--reduce-lr-min-lr=%.10f', params.reduceLRMinLR)
            };
        
        if params.upsamplingLayers; cmd{end+1} = '--upsampling-layers'; end
        if params.saveEveryEpoch; cmd{end+1} = '--save-every-epoch'; end
        if params.amsgrad; cmd{end+1} = '--amsgrad'; end
        
        cmd = strjoin(cmd);
        disp(cmd)

        % Train!
        try
            exit_code = system(cmd);
%             [exit_code,cmd_output] = system(cmd);
        catch ME
            delete(dataPath)
            rethrow(ME)
        end
        delete(dataPath)

        % TODO: parse this out from python output?
        modelPath = ff(params.modelsFolder, runName);

        % Run trained model on data to initialize labels
        if exists(ff(modelPath, 'final_model.h5'))
            predictInitializations(modelPath)
        end
    end

%% Plotting
initializeSkeleton();

    function initializeSkeleton()
    % Creates graphics objects representing the interactive skeleton

        % Draw each line segment
        if ~isempty(ui.skel.segs); delete(ui.skel.segs); end
        ui.skel.segs = gobjects(size(labels.skeleton.segments,1),1);
        for i = 1:numel(ui.skel.segs)
            % Find default position of each node in the segment
            pos = labels.skeleton.pos(labels.skeleton.segments.joints_idx{i},:);

            % Plot
            ui.skel.segs(i) = plot(ui.img.ax, pos(:,1), pos(:,2), '.-', ...
                'Color',labels.skeleton.segments.color{i});

            % Add metadata
            ui.skel.segs(i).UserData.seg_idx = i;
            ui.skel.segs(i).UserData.seg_joints_idx = labels.skeleton.segments.joints_idx{i};
        end

        % Clicks on the skeleton edges should pass through to the image
        set(ui.skel.segs, 'PickableParts', 'none');

        % Draw each joint node
        if ~isempty(ui.skel.nodes); delete(ui.skel.nodes); end
%         status = getStatus(ui.status.currentFrame); statusCmap(status(i)+1,:)
        ui.skel.nodes = gobjects(height(labels.skeleton.joints),1);
        for i = 1:numel(ui.skel.nodes)
            ui.skel.nodes(i) = plot(ui.img.ax,labels.skeleton.joints.pos(i,1),labels.skeleton.joints.pos(i,2),'o',...
                'Color','w', 'LineWidth', 1, 'PickableParts','none');
            ui.skel.nodes(i).UserData.node_idx = i;
        end

        % Make movable and add callbacks
        if config.draggable
            set(ui.skel.nodes, 'PickableParts','visible');
            draggable(ui.skel.nodes, @nodesMoved, 'endFcn', @nodesMoveEnd);
        end
    end
    function pos = getNodePositions()
    % Utility function that returns node positions from the corresponding graphics objects
        pos = NaN(numel(ui.skel.nodes),2);
        for i = 1:numel(ui.skel.nodes)
            pos(i,:) = [ui.skel.nodes(i).XData ui.skel.nodes(i).YData];
        end
    end
    function updateSkeleton(pos)
    % Updates pre-initialized skeleton graphics objects

        if nargin < 1
            % Get current node positions from graphics
            pos = getNodePositions();
        else
            % Update node positions
            for i = 1:size(pos,1)
                % Check for modification to graphics positions
                old_pos = [ui.skel.nodes(i).XData ui.skel.nodes(i).YData];
                if ~isequal(pos(i,:), old_pos)
                    % Update graphics
                    ui.skel.nodes(i).XData = pos(i,1);
                    ui.skel.nodes(i).YData = pos(i,2);
                end
            end
        end

        % Check for changes
        for i = 1:numNodes
            if ~isequal(pos(i,:), ui.status.initialPos(i,:))
                % Mark node as moved
                ui.status.movedNodes(i) = true;

                % Denote unsaved changes
                ui.status.unsavedChanges(ui.status.currentFrame) = true;
            end
        end

        % Set defaults
        set(ui.skel.nodes, 'Marker', 'o'); % Default marker (no changes)
        set(ui.skel.nodes, 'MarkerSize', config.nodeSize); % Default size (unselected)
        set(ui.ctrl.refNodes, 'MarkerSize', config.nodeSize); % Default size (unselected)

        % Update node colors based on status
        status = getStatus(ui.status.currentFrame);
        for i = 1:numNodes
            % Set status color
            ui.skel.nodes(i).Color = statusCmap(status(i)+1,:);

            % Uncommitted changes
            if ui.status.movedNodes(i); ui.skel.nodes(i).Marker = 's'; end

            % Selected node
            if ui.status.selectedNode == i
                ui.skel.nodes(i).MarkerSize = 9;
                ui.ctrl.refNodes(i).MarkerSize = 9;
            end
        end

        % Update edges
        for i = 1:numel(ui.skel.segs)
            ui.skel.segs(i).XData(:) = pos(ui.skel.segs(i).UserData.seg_joints_idx,1);
            ui.skel.segs(i).YData(:) = pos(ui.skel.segs(i).UserData.seg_joints_idx,2);
        end

        drawnow;
    end

    function nodesMoved(h)
    % Called while node is being moved to update skeleton

        % Get node index
        node_idx = h.UserData.node_idx;

        % Set selected node
        if ui.status.selectedNode ~= node_idx
            selectNode(node_idx)
        end

        % Update
        updateSkeleton()

    end

    function nodesMoveEnd(h)
    % Called when the node is released after moving
        % Get node index
        node_idx = h.UserData.node_idx;

        % Set selected node
        if ui.status.selectedNode ~= node_idx
            selectNode(node_idx)
        end

        % Update
        updateSkeleton()

    end

    function selectNode(i)
    % Utility function that sets the selected node across the entire GUI

        % Check for changes
        previousSelection = ui.status.selectedNode;

        if ~isequal(previousSelection, i)
            % Set selected node
            ui.status.selectedNode = i;

            % Update listbox
            ui.ctrl.jointsList.Value = i;

            % Update graphics
            updateSkeleton();
        end
    end

    function nudgeNode(dXY, i)
    % Utility function for moving a node by a delta amount
        if nargin < 2; i = ui.status.selectedNode; end

        % Get and update node position
        pos = getNodePositions();
        pos(i,:) = pos(i,:) + dXY;

        % Update
        updateSkeleton(pos);
    end

    function nudgeSegment(dXY, i)
    % Utility function for moving all segments with a node by a delta amount
        if nargin < 2; i = ui.status.selectedNode; end

        % Find each segment with the current node and pull out all nodes
        seg_nodes = {};
        for j = 1:height(labels.skeleton.segments)
            idx = labels.skeleton.segments.joints_idx{j};
            if any(idx == i)
                seg_nodes{end+1} = idx;
            end
        end

        % Get the union of the set to make sure we don't double move any nodes
        seg_nodes = unique(cellcat(seg_nodes));

        % Get current positions
        pos = getNodePositions();

        % Move all nodes
        for j = 1:numel(seg_nodes)
            pos(j,:) = pos(j,:) + dXY;
        end

        % Update edges
         updateSkeleton(pos);
    end

    function setNodesToDefault(node_idx)
    % Utility function to reset nodes to default position from the skeleton template
        if nargin < 1; node_idx = 1:numNodes; end

        % Get current positions
        pos = getNodePositions();

        % Get default positions
        default_pos = labels.skeleton.joints.pos;

        % Clear user labels for these nodes so they fall back to the initialization or skeleton defaults
%         pos(node_idx,:) = default_pos(node_idx,:);
        labels.positions(node_idx,:,ui.status.currentFrame) = NaN;
        ui.status.movedNodes(node_idx) = false;

        % Update
        updateSkeleton();
    end

    function pos = getInitialPos(idx)
    % Utility to compute the initial node positions for a single frame
        if nargin < 1; idx = ui.status.currentFrame; end

        % Start off with defaults
        pos = labels.skeleton.joints.pos;

        % Update with initialized positions
        init_pos = labels.initialization(:,:,idx);
        init_nodes = find(all(~isnan(init_pos),2));
        pos(init_nodes,:) = init_pos(init_nodes,:);

        % Update with user-labeled positions
        label_pos = labels.positions(:,:,idx);
        label_nodes = find(all(~isnan(label_pos),2));
        pos(label_nodes,:) = label_pos(label_nodes,:);
    end

    function resetNodes(node_idx)
    % Utility function to reset nodes to their initial positions when the frame was drawn
        if nargin < 1; node_idx = 1:numNodes; end

        % Start off with what we have now
        pos = getNodePositions();

        % Get initial positions
        init_pos = getInitialPos(ui.status.currentFrame);

        % Set positions for specified nodes
        pos(node_idx,:) = init_pos(node_idx,:);

        % Update
        updateSkeleton(pos);
    end

%% Frame update and saving
    function markAllCorrect()
    % Helper for setting all nodes in the current frame as correct
        ui.status.movedNodes(:) = true;
        commitChanges();
    end
    function commitChanges()
    % Utility function for committing changes to node positions in the
    % current frame to the labels structure (but does not save to disk)

        % Get current positions
        pos = getNodePositions();

        % Check current status
        status = getStatus(ui.status.currentFrame);
        isLabeled = all(status == 2);

        for i = horz(find(ui.status.movedNodes))
            % Commit to labels
            labels.positions(i,:,ui.status.currentFrame) = pos(i,:);

            % Reset moved state
            ui.status.movedNodes(i) = false;

            % Mark unsaved changes
            ui.status.unsavedChanges(ui.status.currentFrame) = true;
        end

        % Update status
        status = getStatus(ui.status.currentFrame);

        % Update skeleton display
        %updateSkeleton();

        % Update stats if changed
        if ~all(isLabeled) && all(status == 2)
            addToHistory(sprintf('Labeled frame %d', ui.status.currentFrame));
            numLabeled = numLabeled + 1;
            ui.status.framesLabeled.String = sprintf('Labeled: %d/%d (%.2f%%)', numLabeled, numFrames, numLabeled/numFrames*100);
        end

        % Update full status data
        ui.status.fullImg.CData(:,ui.status.currentFrame) = status;

        % Update zoomed status bar data
        zoomBoxIdx = ui.status.zoomImg.XData;
        if any(zoomBoxIdx == ui.status.currentFrame)
            ui.status.zoomImg.CData(:,zoomBoxIdx == ui.status.currentFrame) = status;
        end

        % Update status fig title
        savedStatus = '';
        if any(ui.status.unsavedChanges); savedStatus = ' [unsaved]'; end
        ui.status.fig.Name = sprintf('Status: %d/%d (%.2f%%) labeled%s', numLabeled, numFrames, numLabeled/numFrames*100, savedStatus);
    end
    function goToFrame(idx)
    % Utility function for seeking to another frame

        % Commit changes to labels
        commitChanges();

        % Autosave before anything
        if config.autoSave && ~isempty(ui.status.currentFrame) && ui.status.unsavedChanges(ui.status.currentFrame)
            saveLabels();
        end

        % Update image
        ui.img.img.CData = box(:,:,:,idx);

        % Update status
        ui.status.currentFrame = idx;
        ui.img.fig.Name = sprintf('%d/%d', ui.status.currentFrame, numFrames);

        % Get initial positions
        ui.status.initialPos = getInitialPos(idx);

        % Update with initial positions
        updateSkeleton(ui.status.initialPos);

        % Update status zoom box position
        zoomBoxIdx = zoomBoxWindow + ui.status.currentFrame;
        zoomBoxPts = [
            zoomBoxIdx(1) 0
            zoomBoxIdx(end) 0
            zoomBoxIdx(end) numNodes
            zoomBoxIdx(1) numNodes
            zoomBoxIdx(1) 0
            ];
        ui.status.fullZoomBox.XData = zoomBoxPts(:,1);
        ui.status.fullZoomBox.YData = zoomBoxPts(:,2);

        % Update zoomed status bar data
        ui.status.zoomImg.XData = zoomBoxIdx;
        ui.status.zoomImg.CData(:) = 0; % reset
        isValidIdx = zoomBoxIdx > 0 & zoomBoxIdx <= numFrames;
        ui.status.zoomImg.CData(:,isValidIdx) = ui.status.fullImg.CData(:,zoomBoxIdx(isValidIdx));
        ui.status.zoomAx.XLim = zoomBoxIdx([1 end]) + [-0.5 0.5];
    end

    function saveLabels()
    % Saves everything in the labels structure to disk

        stic;
        % Commit unsaved changes to labels
        commitChanges();

        % Update if there were any changes
        if any(ui.status.unsavedChanges)

            % Update last modified timestamp
            labels.lastModified = datestr(now);
        end

        % Save current frame so we pick up where we left off
        config.initialFrame = ui.status.currentFrame;

        % Save figure positions
        config.imgFigPos = ui.img.fig.Position;
        config.ctrlFigPos = ui.ctrl.fig.Position;
        config.statusFigPos = ui.status.fig.Position;


        % Save config
        labels.config = config;

        % Save to labels file
        save(labels.savePath, '-struct', 'labels')

        % Clear modified flags
        ui.status.unsavedChanges(:) = false;
        commitChanges();

        stocf('Saved labels: %s', labels.savePath)
    end


%% Start!
goToFrame(config.initialFrame);
selectNode(1);
updateSkeleton();

end


================================================
FILE: leap/hpc/python_gpu.sh
================================================
#!/bin/bash
#SBATCH --time=4:00:00
#SBATCH --mem=128000
#SBATCH -N 1
#SBATCH --cpus-per-task=4
#SBATCH --ntasks-per-node=1
#SBATCH --ntasks-per-socket=1
#SBATCH --gres=gpu:1

echo "args: $*"

# Quote "$@" so arguments containing spaces are passed through intact
python "$@"


================================================
FILE: leap/image_augmentation.py
================================================
import cv2
import numpy as np
import keras
from keras.utils import Sequence

def transform_imgs(X, theta=(-180,180), scale=1.0):
    """ Transforms sets of images with the same random transformation across each channel. """
    
    # Sample random rotation and scale if range specified
    if not np.isscalar(theta):
        theta = np.ptp(theta) * np.random.rand() + np.min(theta)
    if not np.isscalar(scale):
        scale = np.ptp(scale) * np.random.rand() + np.min(scale)
    
    # Standardize X input to a list
    single_img = isinstance(X, np.ndarray)
    if single_img:
        X = [X,]
    
    # Find image parameters
    img_size = X[0].shape[:2]
    ctr = (img_size[1] / 2, img_size[0] / 2)  # OpenCV expects the center as (x, y) = (cols/2, rows/2)
    
    # Compute affine transformation matrix
    T = cv2.getRotationMatrix2D(ctr, theta, scale)
    
    # Make sure we don't overwrite the inputs
    X = [x.copy() for x in X]
    
    # Apply to each image
    for i in range(len(X)):
        if X[i].ndim == 2:
            # Single channel image
            X[i] = cv2.warpAffine(X[i], T, img_size[::-1])
        else:
            # Multi-channel image
            for c in range(X[i].shape[-1]):
                X[i][...,c] = cv2.warpAffine(X[i][...,c], T, img_size[::-1])
    
    # Pull the single image back out of the list
    if single_img:
        X = X[0]
    
    return X


class PairedImageAugmenter(Sequence):
    def __init__(self, X, Y, batch_size=32, shuffle=False, theta=(-180,180), scale=1.0):
        self.X = X
        self.Y = Y
        self.batch_size = batch_size
        self.theta = theta
        self.scale = scale
        
        self.num_samples = len(X)
        all_idx = np.arange(self.num_samples)
        if shuffle:
            np.random.shuffle(all_idx)
        
        self.batches = np.array_split(all_idx, np.ceil(self.num_samples / self.batch_size))
        
    def __len__(self):
        return len(self.batches)
    
    def __getitem__(self, batch_idx):
        idx = self.batches[batch_idx]
        X = self.X[idx]
        Y = self.Y[idx]
        
        for i in range(len(X)):
            X[i], Y[i] = transform_imgs((X[i],Y[i]), theta=self.theta, scale=self.scale)
        return X, Y

    
class MultiInputOutputPairedImageAugmenter(PairedImageAugmenter):
    def __init__(self, input_names, output_names, *args, **kwargs):
        if type(input_names) != list:
            input_names = [input_names,]
        if type(output_names) != list:
            output_names = [output_names,]
        self.input_names = input_names
        self.output_names = output_names
        super().__init__(*args, **kwargs)
        
    def __getitem__(self, batch_idx):
        X,Y = super().__getitem__(batch_idx)
        return ({k: X for k in self.input_names}, {k: Y for k in self.output_names})
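
`transform_imgs` builds its warp with `cv2.getRotationMatrix2D`, which takes the rotation center as `(x, y)`, i.e. `(cols/2, rows/2)` for an image of shape `(rows, cols)`; the angle is in degrees, and positive values rotate counter-clockwise with the origin at the top-left corner. The NumPy sketch below reimplements the matrix formula from the OpenCV documentation so the role of the center can be checked without OpenCV. `rotation_matrix_2d` and `apply_affine` are hypothetical helpers, not part of LEAP.

```python
import numpy as np

def rotation_matrix_2d(center, angle_deg, scale=1.0):
    """2x3 affine matrix following the formula documented for cv2.getRotationMatrix2D."""
    cx, cy = center
    a = scale * np.cos(np.deg2rad(angle_deg))
    b = scale * np.sin(np.deg2rad(angle_deg))
    return np.array([
        [a, b, (1 - a) * cx - b * cy],
        [-b, a, b * cx + (1 - a) * cy],
    ])

def apply_affine(T, pts):
    """Map an (N, 2) array of (x, y) points through a 2x3 affine matrix."""
    pts = np.asarray(pts, dtype=float)
    return pts @ T[:, :2].T + T[:, 2]

# Non-square frame: 100 rows (height) x 200 columns (width).
rows, cols = 100, 200
ctr = (cols / 2, rows / 2)  # (x, y) as OpenCV expects
T = rotation_matrix_2d(ctr, 90)

# The rotation center is a fixed point of the transform.
assert np.allclose(apply_affine(T, [ctr]), [ctr])
# A point 10 px right of the center ends up 10 px above it
# (counter-clockwise on screen, since the image y axis points down).
assert np.allclose(apply_affine(T, [(110, 50)]), [(100, 40)])

# Using (rows/2, rows/2) instead moves the true image center off its spot.
T_sq = rotation_matrix_2d((rows / 2, rows / 2), 90)
assert not np.allclose(apply_affine(T_sq, [ctr]), [ctr])
```

For square images the two center choices coincide, so the difference only matters for non-square inputs.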
    

================================================
FILE: leap/layers.py
================================================
# -*- coding: utf-8 -*-
"""
   Copyright 2018 Jacob M. Graving <jgraving@gmail.com>

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

   Modified from code written by François Chollet 
   All contributions by François Chollet:
   Copyright (c) 2015 - 2018, François Chollet.
   All rights reserved.
"""

import numpy as np

import keras
import keras.backend as K
from keras.legacy import interfaces
from keras.engine import Layer
from keras.engine import InputSpec

from keras.utils import conv_utils
from keras.backend import int_shape, permute_dimensions

from keras.backend import tf

from keras.layers import Conv2D, Add

from packaging.version import parse as parse_version

__all__ = ['UpSampling2D', 'Maxima2D']


def resize_images(x, height_factor, width_factor, interpolation, data_format):
    """Resizes the images contained in a 4D tensor.
    # Arguments
        x: Tensor or variable to resize.
        height_factor: Positive integer.
        width_factor: Positive integer.
        interpolation: string, "nearest", "bilinear" or "bicubic"
        data_format: string, `"channels_last"` or `"channels_first"`.
    # Returns
        A tensor.
    # Raises
        ValueError: if `data_format` is neither `"channels_last"` nor `"channels_first"`.
    """
    if interpolation == 'nearest':
        tf_resize = tf.image.resize_nearest_neighbor
    elif interpolation == 'bilinear':
        tf_resize = tf.image.resize_bilinear
    elif interpolation == 'bicubic':
        tf_resize = tf.image.resize_bicubic
    else:
        raise ValueError('Invalid interpolation method:', interpolation)
    if data_format == 'channels_first':
        original_shape = int_shape(x)
        new_shape = tf.shape(x)[2:]
        new_shape *= tf.constant(np.array([height_factor, width_factor]).astype('int32'))
        x = permute_dimensions(x, [0, 2, 3, 1])
        x = tf_resize(x, new_shape, align_corners=True)
        x = permute_dimensions(x, [0, 3, 1, 2])
        x.set_shape((None, None, original_shape[2] * height_factor if original_shape[2] is not None else None,
                     original_shape[3] * width_factor if original_shape[3] is not None else None))
        return x
    elif data_format == 'channels_last':
        original_shape = int_shape(x)
        new_shape = tf.shape(x)[1:3]
        new_shape *= tf.constant(np.array([height_factor, width_factor]).astype('int32'))
        x = tf_resize(x, new_shape, align_corners=True)
        x.set_shape((None, original_shape[1] * height_factor if original_shape[1] is not None else None,
                     original_shape[2] * width_factor if original_shape[2] is not None else None, None))
        return x
    else:
        raise ValueError('Invalid data_format:', data_format)


class UpSampling2D(Layer):
    """Upsampling layer for 2D inputs.
    Upsamples the rows and columns of the data
    by size[0] and size[1] respectively, using the method given by `interpolation`.
    # Arguments
        size: int, or tuple of 2 integers.
            The upsampling factors for rows and columns.
        data_format: A string,
            one of `channels_last` (default) or `channels_first`.
            The ordering of the dimensions in the inputs.
            `channels_last` corresponds to inputs with shape
            `(batch, height, width, channels)` while `channels_first`
            corresponds to inputs with shape
            `(batch, channels, height, width)`.
            It defaults to the `image_data_format` value found in your
            Keras config file at `~/.keras/keras.json`.
            If you never set it, then it will be "channels_last".
        interpolation: A string,
            one of 'nearest' (default), 'bilinear', or 'bicubic'
    # Input shape
        4D tensor with shape:
        - If `data_format` is `"channels_last"`:
            `(batch, rows, cols, channels)`
        - If `data_format` is `"channels_first"`:
            `(batch, channels, rows, cols)`
    # Output shape
        4D tensor with shape:
        - If `data_format` is `"channels_last"`:
            `(batch, upsampled_rows, upsampled_cols, channels)`
        - If `data_format` is `"channels_first"`:
            `(batch, channels, upsampled_rows, upsampled_cols)`
    """

    @interfaces.legacy_upsampling2d_support
    def __init__(self, size=(2, 2), data_format=None, interpolation='nearest', **kwargs):
        super(UpSampling2D, self).__init__(**kwargs)
        # Update to K.normalize_data_format after keras 2.2.0
        if parse_version(keras.__version__) > parse_version("2.2.0"):
            self.data_format = K.normalize_data_format(data_format)
        else:
            self.data_format = conv_utils.normalize_data_format(data_format)

        self.interpolation = interpolation
        self.size = conv_utils.normalize_tuple(size, 2, 'size')
        self.input_spec = InputSpec(ndim=4)

    def compute_output_shape(self, input_shape):
        if self.data_format == 'channels_first':
            height = self.size[0] * input_shape[2] if input_shape[2] is not None else None
            width = self.size[1] * input_shape[3] if input_shape[3] is not None else None
            return (input_shape[0],
                    input_shape[1],
                    height,
                    width)
        elif self.data_format == 'channels_last':
            height = self.size[0] * input_shape[1] if input_shape[1] is not None else None
            width = self.size[1] * input_shape[2] if input_shape[2] is not None else None
            return (input_shape[0],
                    height,
                    width,
                    input_shape[3])

    def call(self, inputs):
        return resize_images(inputs, self.size[0], self.size[1],
                             self.interpolation, self.data_format)

    def get_config(self):
        config = {'size': self.size,
                  'data_format': self.data_format,
                  'interpolation': self.interpolation}
        base_config = super(UpSampling2D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


def _find_maxima(x):

    x = K.cast(x, K.floatx())

    col_max = K.max(x, axis=1)
    row_max = K.max(x, axis=2)

    maxima = K.max(col_max, 1)
    maxima = K.expand_dims(maxima, -2)

    cols = K.cast(K.argmax(col_max, -2), K.floatx())
    rows = K.cast(K.argmax(row_max, -2), K.floatx())
    cols = K.expand_dims(cols, -2)
    rows = K.expand_dims(rows, -2)

    # maxima = K.concatenate([rows, cols, maxima], -2) # y, x, val
    maxima = K.concatenate([cols, rows, maxima], -2) # x, y, val

    return maxima


def find_maxima(x, data_format):
    """Finds the 2D maxima contained in a 4D tensor.
    # Arguments
        x: Tensor or variable.
        data_format: string, `"channels_last"` or `"channels_first"`.
    # Returns
        A tensor.
    # Raises
        ValueError: if `data_format` is neither `"channels_last"` nor `"channels_first"`.
    """
    if data_format == 'channels_first':
        x = permute_dimensions(x, [0, 2, 3, 1])
        x = _find_maxima(x)
        x = permute_dimensions(x, [0, 2, 1])
        return x
    elif data_format == 'channels_last':
        x = _find_maxima(x)
        return x
    else:
        raise ValueError('Invalid data_format:', data_format)


class Maxima2D(Layer):
    """Maxima layer for 2D inputs.
    Finds the maxima and 2D indices
    for the channels in the input.
    The output is ordered as [col, row, maximum], i.e. (x, y, value).
    # Arguments
        data_format: A string,
            one of `channels_last` (default) or `channels_first`.
            The ordering of the dimensions in the inputs.
            `channels_last` corresponds to inputs with shape
            `(batch, height, width, channels)` while `channels_first`
            corresponds to inputs with shape
            `(batch, channels, height, width)`.
            It defaults to the `image_data_format` value found in your
            Keras config file at `~/.keras/keras.json`.
            If you never set it, then it will be "channels_last".
    # Input shape
        4D tensor with shape:
        - If `data_format` is `"channels_last"`:
            `(batch, rows, cols, channels)`
        - If `data_format` is `"channels_first"`:
            `(batch, channels, rows, cols)`
    # Output shape
        3D tensor with shape:
        - If `data_format` is `"channels_last"`:
            `(batch, 3, channels)`
        - If `data_format` is `"channels_first"`:
            `(batch, channels, 3)`
    """

    def __init__(self, data_format=None, **kwargs):
        super(Maxima2D, self).__init__(**kwargs)
        # Update to K.normalize_data_format after keras 2.2.0
        if parse_version(keras.__version__) > parse_version("2.2.0"):
            self.data_format = K.normalize_data_format(data_format)
        else:
            self.data_format = conv_utils.normalize_data_format(data_format)
        self.input_spec = InputSpec(ndim=4)

    def compute_output_shape(self, input_shape):
        if self.data_format == 'channels_first':
            return (input_shape[0],
                    input_shape[1],
                    3)
        elif self.data_format == 'channels_last':
            return (input_shape[0],
                    3,
                    input_shape[3])

    def call(self, inputs):
        return find_maxima(inputs, self.data_format)

    def get_config(self):
        config = {'data_format': self.data_format}
        base_config = super(Maxima2D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


def residual_bottleneck_module(x_in, output_filters=32, bottleneck_factor=2, prefix="res", activation="relu", initializer="glorot_normal"):
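    """Residual bottleneck block (1x1 -> 3x3 -> 1x1 convolutions) with an additive skip connection.

    The skip path is projected with a 1x1 convolution when the number of input channels
    differs from `output_filters`. Assumes `channels_last` inputs.
    """
    # Get input shape and channels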
    in_shape = K.int_shape(x_in)
    input_filters = in_shape[3]
    
    # Bottleneck filters are proportional to the output filters
    bottleneck_filters = output_filters // bottleneck_factor
    
    # Bottleneck block
    x = Conv2D(filters=bottleneck_filters, kernel_size=1, padding="same", activation=activation, kernel_initializer=initializer, name=prefix + "_Conv1")(x_in)
    x = Conv2D(filters=bottleneck_filters, kernel_size=3, padding="same", activation=activation, kernel_initializer=initializer, name=prefix + "_Conv2")(x)
    x = Conv2D(filters=output_filters, kernel_size=1, padding="same", activation=activation, kernel_initializer=initializer, name=prefix + "_Conv3")(x)
    
    # 1x1 conv if input channels are different from output channels
    if output_filters != input_filters:
        x_in = Conv2D(filters=output_filters, kernel_size=1, padding="same", activation=activation, kernel_initializer=initializer, name=prefix + "_ConvSkip")(x_in)
    
    # Residual connection
    x = Add(name=prefix + "_AddRes")([x_in, x])
    
    return x
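The cost of each `residual_bottleneck_module` call follows from its convolutions: three on the bottleneck path plus an optional 1x1 projection on the skip path. The plain-Python sketch below counts trainable parameters, assuming every `Conv2D` keeps the Keras default `use_bias=True` (the module does not override it). `conv2d_params` and `bottleneck_param_count` are hypothetical helpers for illustration, not part of LEAP.

```python
def conv2d_params(kernel_size, in_channels, out_channels):
    """Weights plus biases of a square-kernel Conv2D with use_bias=True."""
    return kernel_size * kernel_size * in_channels * out_channels + out_channels


def bottleneck_param_count(input_filters, output_filters=32, bottleneck_factor=2):
    """Trainable parameters of one residual_bottleneck_module call."""
    bottleneck_filters = output_filters // bottleneck_factor
    total = (conv2d_params(1, input_filters, bottleneck_filters)         # _Conv1
             + conv2d_params(3, bottleneck_filters, bottleneck_filters)  # _Conv2
             + conv2d_params(1, bottleneck_filters, output_filters))     # _Conv3
    if output_filters != input_filters:
        total += conv2d_params(1, input_filters, output_filters)         # _ConvSkip
    return total


print(bottleneck_param_count(1, 64))   # 11552: grayscale input, 1x1 skip projection
print(bottleneck_param_count(64, 64))  # 13440: identity skip
```

With a grayscale input, the first module of `hourglass` pays for the 1x1 skip projection. Later modules, whose input already has `filters` channels, use an identity skip.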

================================================
FILE: leap/models.py
================================================
import keras
from keras import backend as K
from keras.models import Model
from keras.layers import Input, Conv2D, Conv2DTranspose, Add, MaxPooling2D
from keras.optimizers import Adam

from leap.layers import residual_bottleneck_module, UpSampling2D

def leap_cnn(img_size, output_channels, filters=64, upsampling_layers=False, amsgrad=False, summary=False):
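    """
    Creates and compiles the LEAP CNN model.

    :param img_size: shape of a single image as a tuple, optionally including channels; height and width should be divisible by 4 so the output matches the input size
    :param output_channels: number of output channels (joints being predicted)
    :param filters: number of baseline filters to use (more filters will be used in intermediate layers)
    :param upsampling_layers: if True, uses bilinear upsampling instead of transposed convolutions in the decoder
    :param amsgrad: if True, uses the AMSGrad variant of Adam
    :param summary: prints network summary after compiling
    """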
    if len(img_size) == 2:
        img_size = img_size + (1,)

    x_in = Input(img_size)

    x1 = Conv2D(filters, kernel_size=3, padding="same", activation="relu")(x_in)
    x1 = Conv2D(filters, kernel_size=3, padding="same", activation="relu")(x1)
    x1 = Conv2D(filters, kernel_size=3, padding="same", activation="relu")(x1)
    x1_pool = MaxPooling2D(pool_size=2, strides=2, padding="same")(x1)

    x2 = Conv2D(filters*2, kernel_size=3, padding="same", activation="relu")(x1_pool)
    x2 = Conv2D(filters*2, kernel_size=3, padding="same", activation="relu")(x2)
    x2 = Conv2D(filters*2, kernel_size=3, padding="same", activation="relu")(x2)
    x2_pool = MaxPooling2D(pool_size=2, strides=2, padding="same")(x2)

    x3 = Conv2D(filters*4, kernel_size=3, padding="same", activation="relu")(x2_pool)
    x3 = Conv2D(filters*4, kernel_size=3, padding="same", activation="relu")(x3)
    x3 = Conv2D(filters*4, kernel_size=3, padding="same", activation="relu")(x3)

    if upsampling_layers:
        x4 = UpSampling2D(interpolation="bilinear")(x3)
    else:
        x4 = Conv2DTranspose(filters*2, kernel_size=3, strides=2, padding="same", activation="relu", kernel_initializer="glorot_normal")(x3)
    x4 = Conv2D(filters*2, kernel_size=3, padding="same", activation="relu")(x4)
    x4 = Conv2D(filters*2, kernel_size=3, padding="same", activation="relu")(x4)

    if upsampling_layers:
        x_out = UpSampling2D(interpolation="bilinear")(x4)
        x_out = Conv2D(output_channels, kernel_size=3, padding="same", activation="linear")(x_out)
    else:
        x_out = Conv2DTranspose(output_channels, kernel_size=3, strides=2, padding="same", activation="linear", kernel_initializer="glorot_normal")(x4)

    # Compile
    net = Model(inputs=x_in, outputs=x_out, name="LeapCNN")
    net.compile(optimizer=Adam(amsgrad=amsgrad), loss="mean_squared_error")

    if summary:
        net.summary()

    return net

def hourglass(img_size, output_channels, filters=64, upsampling_layers=False, amsgrad=False, summary=False):
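    """
    Creates and compiles a single-stack hourglass model.

    :param img_size: shape of a single image as a tuple, optionally including channels; height and width should be divisible by 16 so the skip connections line up
    :param output_channels: number of output channels (joints being predicted)
    :param filters: number of filters used in every residual module
    :param upsampling_layers: if True, uses bilinear upsampling instead of transposed convolutions in the decoder
    :param amsgrad: if True, uses the AMSGrad variant of Adam
    :param summary: prints network summary after compiling
    """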

    if len(img_size) == 2:
        img_size = img_size + (1,)

    x_in = Input(img_size, name="x_in")

    x1_pre = residual_bottleneck_module(x_in, prefix="x1", output_filters=filters)
    x1 = MaxPooling2D(pool_size=2, strides=2, padding="same", name="x1_pool")(x1_pre)

    x2_pre = residual_bottleneck_module(x1, prefix="x2", output_filters=filters)
    x2 = MaxPooling2D(pool_size=2, strides=2, padding="same", name="x2_pool")(x2_pre)

    x3_pre = residual_bottleneck_module(x2, prefix="x3", output_filters=filters)
    x3 = MaxPooling2D(pool_size=2, strides=2, padding="same", name="x3_pool")(x3_pre)

    x4_pre = residual_bottleneck_module(x3, prefix="x4", output_filters=filters)
    x4 = MaxPooling2D(pool_size=2, strides=2, padding="same", name="x4_pool")(x4_pre)


    x5 = residual_bottleneck_module(x4, prefix="x5", output_filters=filters)


    if upsampling_layers:
        x6_pre = UpSampling2D(interpolation="bilinear", name="x6_Upsample")(x5)
    else:
        x6_pre = Conv2DTranspose(filters=filters, kernel_size=3, strides=2, padding="same", activation="relu", kernel_initializer="glorot_normal", name="x6_ConvT")(x5)
    x6_add = Add(name="x6_Add")([x4_pre, x6_pre])
    x6 = residual_bottleneck_module(x6_add, prefix="x6", output_filters=filters)

    if upsampling_layers:
        x7_pre = UpSampling2D(interpolation="bilinear", name="x7_Upsample")(x6)
    else:
        x7_pre = Conv2DTranspose(filters=filters, kernel_size=3, strides=2, padding="same", activation="relu", kernel_initializer="glorot_normal", name="x7_ConvT")(x6)
    x7_add = Add(name="x7_Add")([x3_pre, x7_pre])
    x7 = residual_bottleneck_module(x7_add, prefix="x7", output_filters=filters)

    if upsampling_layers:
        x8_pre = UpSampling2D(interpolation="bilinear", name="x8_Upsample")(x7)
    else:
        x8_pre = Conv2DTranspose(filters=filters, kernel_size=3, strides=2, padding="same", activation="relu", kernel_initializer="glorot_normal", name="x8_ConvT")(x7)
    x8_add = Add(name="x8_Add")([x2_pre, x8_pre])
    x8 = residual_bottleneck_module(x8_add, prefix="x8", output_filters=filters)

    if upsampling_layers:
        x9_pre = UpSampling2D(interpolation="bilinear", name="x9_Upsample")(x8)
    else:
        x9_pre = Conv2DTranspose(filters=filters, kernel_size=3, strides=2, padding="same", activation="relu", kernel_initializer="glorot_normal", name="x9_ConvT")(x8)
    x9_add = Add(name="x9_Add")([x1_pre, x9_pre])
    x9 = residual_bottleneck_module(x9_add, prefix="x9", output_filters=filters)

    x_out = Conv2D(filters=output_channels, kernel_size=3, strides=1, padding="same", activation="linear", name="x_out")(x9)

    # Compile
    model = Model(inputs=x_in, outputs=x_out, name="hourglass")
    model.compile(optimizer=Adam(amsgrad=amsgrad), loss="mean_squared_error")

    if summary:
        model.summary()

    return model



def stacked_hourglass(img_size, output_channels, filters=64, upsampling_layers=False, amsgrad=False, summary=False):
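    """
    Creates and compiles a stacked hourglass model with two hourglass stacks.

    The model has two outputs, one per stack, both compiled with the same loss
    (intermediate supervision); the last output is the final prediction.

    :param img_size: shape of a single image as a tuple, optionally including channels; height and width should be divisible by 16 so the skip connections line up
    :param output_channels: number of output channels (joints being predicted)
    :param filters: number of filters used in every encoder/decoder residual module
    :param upsampling_layers: if True, uses bilinear upsampling instead of transposed convolutions in the decoders
    :param amsgrad: if True, uses the AMSGrad variant of Adam
    :param summary: prints network summary after compiling
    """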

    if len(img_size) == 2:
        img_size = img_size + (1,)

    x_in = Input(img_size, name="x_in")

    x1_1_pre = residual_bottleneck_module(x_in, prefix="x1_1", output_filters=filters)
    x1_1 = MaxPooling2D(pool_size=2, strides=2, padding="same", name="x1_1_pool")(x1_1_pre)

    x1_2_pre = residual_bottleneck_module(x1_1, prefix="x1_2", output_filters=filters)
    x1_2 = MaxPooling2D(pool_size=2, strides=2, padding="same", name="x1_2_pool")(x1_2_pre)

    x1_3_pre = residual_bottleneck_module(x1_2, prefix="x1_3", output_filters=filters)
    x1_3 = MaxPooling2D(pool_size=2, strides=2, padding="same", name="x1_3_pool")(x1_3_pre)

    x1_4_pre = residual_bottleneck_module(x1_3, prefix="x1_4", output_filters=filters)
    x1_4 = MaxPooling2D(pool_size=2, strides=2, padding="same", name="x1_4_pool")(x1_4_pre)


    x1_5 = residual_bottleneck_module(x1_4, prefix="x1_5", output_filters=filters)


    if upsampling_layers:
        x1_6_pre = UpSampling2D(interpolation="bilinear", name="x1_6_Upsample")(x1_5)
    else:
        x1_6_pre = Conv2DTranspose(filters=filters, kernel_size=3, strides=2, padding="same", activation="relu", kernel_initializer="glorot_normal", name="x1_6_ConvT")(x1_5)
    x1_6_add = Add(name="x1_6_Add")([x1_4_pre, x1_6_pre])
    x1_6 = residual_bottleneck_module(x1_6_add, prefix="x1_6", output_filters=filters)

    if upsampling_layers:
        x1_7_pre = UpSampling2D(interpolation="bilinear", name="x1_7_Upsample")(x1_6)
    else:
        x1_7_pre = Conv2DTranspose(filters=filters, kernel_size=3, strides=2, padding="same", activation="relu", kernel_initializer="glorot_normal", name="x1_7_ConvT")(x1_6)
    x1_7_add = Add(name="x1_7_Add")([x1_3_pre, x1_7_pre])
    x1_7 = residual_bottleneck_module(x1_7_add, prefix="x1_7", output_filters=filters)

    if upsampling_layers:
        x1_8_pre = UpSampling2D(interpolation="bilinear", name="x1_8_Upsample")(x1_7)
    else:
        x1_8_pre = Conv2DTranspose(filters=filters, kernel_size=3, strides=2, padding="same", activation="relu", kernel_initializer="glorot_normal", name="x1_8_ConvT")(x1_7)
    x1_8_add = Add(name="x1_8_Add")([x1_2_pre, x1_8_pre])
    x1_8 = residual_bottleneck_module(x1_8_add, prefix="x1_8", output_filters=filters)

    if upsampling_layers:
        x1_9_pre = UpSampling2D(interpolation="bilinear", name="x1_9_Upsample")(x1_8)
    else:
        x1_9_pre = Conv2DTranspose(filters=filters, kernel_size=3, strides=2, padding="same", activation="relu", kernel_initializer="glorot_normal", name="x1_9_ConvT")(x1_8)
    x1_9_add = Add(name="x1_9_Add")([x1_1_pre, x1_9_pre])
    x1_9 = residual_bottleneck_module(x1_9_add, prefix="x1_9", output_filters=filters)

    #############

    x2_1_pre = residual_bottleneck_module(x1_9, prefix="x2_1", output_filters=filters)
    x2_1 = MaxPooling2D(pool_size=2, strides=2, padding="same", name="x2_1_pool")(x2_1_pre)

    x2_2_pre = residual_bottleneck_module(x2_1, prefix="x2_2", output_filters=filters)
    x2_2 = MaxPooling2D(pool_size=2, strides=2, padding="same", name="x2_2_pool")(x2_2_pre)

    x2_3_pre = residual_bottleneck_module(x2_2, prefix="x2_3", output_filters=filters)
    x2_3 = MaxPooling2D(pool_size=2, strides=2, padding="same", name="x2_3_pool")(x2_3_pre)

    x2_4_pre = residual_bottleneck_module(x2_3, prefix="x2_4", output_filters=filters)
    x2_4 = MaxPooling2D(pool_size=2, strides=2, padding="same", name="x2_4_pool")(x2_4_pre)


    x2_5 = residual_bottleneck_module(x2_4, prefix="x2_5", output_filters=filters)


    if upsampling_layers:
        x2_6_pre = UpSampling2D(interpolation="bilinear", name="x2_6_Upsample")(x2_5)
    else:
        x2_6_pre = Conv2DTranspose(filters=filters, kernel_size=3, strides=2, padding="same", activation="relu", kernel_initializer="glorot_normal", name="x2_6_ConvT")(x2_5)
    x2_6_add = Add(name="x2_6_Add")([x2_4_pre, x2_6_pre])
    x2_6 = residual_bottleneck_module(x2_6_add, prefix="x2_6", output_filters=filters)

    if upsampling_layers:
        x2_7_pre = UpSampling2D(interpolation="bilinear", name="x2_7_Upsample")(x2_6)
    else:
        x2_7_pre = Conv2DTranspose(filters=filters, kernel_size=3, strides=2, padding="same", activation="relu", kernel_initializer="glorot_normal", name="x2_7_ConvT")(x2_6)
    x2_7_add = Add(name="x2_7_Add")([x2_3_pre, x2_7_pre])
    x2_7 = residual_bottleneck_module(x2_7_add, prefix="x2_7", output_filters=filters)

    if upsampling_layers:
        x2_8_pre = UpSampling2D(interpolation="bilinear", name="x2_8_Upsample")(x2_7)
    else:
        x2_8_pre = Conv2DTranspose(filters=filters, kernel_size=3, strides=2, padding="same", activation="relu", kernel_initializer="glorot_normal", name="x2_8_ConvT")(x2_7)
    x2_8_add = Add(name="x2_8_Add")([x2_2_pre, x2_8_pre])
    x2_8 = residual_bottleneck_module(x2_8_add, prefix="x2_8", output_filters=filters)

    if upsampling_layers:
        x2_9_pre = UpSampling2D(interpolation="bilinear", name="x2_9_Upsample")(x2_8)
    else:
        x2_9_pre = Conv2DTranspose(filters=filters, kernel_size=3, strides=2, padding="same", activation="relu", kernel_initializer="glorot_normal", name="x2_9_ConvT")(x2_8)
    x2_9_add = Add(name="x2_9_Add")([x2_1_pre, x2_9_pre])
    x2_9 = residual_bottleneck_module(x2_9_add, prefix="x2_9", output_filters=filters)

    #############

    x_out1 = residual_bottleneck_module(x1_9, output_filters=output_channels, bottleneck_factor=1, prefix="x_out1", activation="linear")
    x_out2 = residual_bottleneck_module(x2_9, output_filters=output_channels, bottleneck_factor=1, prefix="x_out2", activation="linear")

    # Compile
    model = Model(inputs=x_in, outputs=[x_out1, x_out2], name="StackedHourglass")
    model.compile(optimizer=Adam(amsgrad=amsgrad), loss="mean_squared_error")

    if summary:
        model.summary()

    return model
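These architectures constrain the input size. `leap_cnn` has no skip connections, so it builds for any size. However, its two stride-2 poolings (with `padding="same"`, i.e. ceil division) followed by two 2x upsamplings reproduce the input size, and therefore the size of the confidence-map targets, only when each side is divisible by 4. `hourglass` and both stacks of `stacked_hourglass` add encoder features to decoder features at four scales, so a side that is not divisible by 16 produces mismatched `Add` inputs. The sketch below traces one spatial axis in plain Python. The helper names are hypothetical, and it assumes `Conv2DTranspose(strides=2, padding="same")` and `UpSampling2D` both exactly double the size.

```python
import math


def pool_same(size):
    """Length after MaxPooling2D(pool_size=2, strides=2, padding="same")."""
    return math.ceil(size / 2)


def leap_cnn_output_size(size):
    """Output length of leap_cnn along one axis: two poolings, then two 2x upsamplings."""
    return 2 * 2 * pool_same(pool_same(size))


def hourglass_compatible(size, depth=4):
    """True if every skip Add in one hourglass stack receives matching lengths."""
    skip_sizes = []
    for _ in range(depth):
        skip_sizes.append(size)  # length of the xK_pre feature map kept for the skip
        size = pool_same(size)
    # Decoder: each 2x upsampling must reproduce the stored skip length.
    for skip in reversed(skip_sizes):
        size *= 2
        if size != skip:
            return False
    return True


for side in (192, 190, 200):
    print(side, leap_cnn_output_size(side), hourglass_compatible(side))
# 192 192 True
# 190 192 False
# 200 200 False
```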


================================================
FILE: leap/plot_joints_single.m
================================================
function h = plot_joints_single(pts, segments, markerSize, lineWidth)
%PLOT_JOINTS_SINGLE Plot joints on a single frame.
% Usage:
%   plot_joints_single(pts, segments)
%   plot_joints_single(pts, segments, markerSize, lineWidth)
% 
% Args:
%   pts: numJoints x 2 array of joint positions ([x y] per row)
%   segments (or skeleton): table with line segments or struct containing it
%   markerSize: default: 20
%   lineWidth: default: 2
% 
% See also: 

if isfield(segments,'segments'); segments = segments.segments; end
if nargin < 3 || isempty(markerSize); markerSize = 20; end
if nargin < 4 || isempty(lineWidth); lineWidth = 2; end

h = gobjects(numel(segments.joints_idx),1);
for i = 1:numel(segments.joints_idx)
    h(i) = plotpts(pts(segments.joints_idx{i},:),'.-', ...
        'Color',segments.color{i},'MarkerSize',markerSize,'LineWidth',lineWidth);
end

if nargout < 1; clear h; end

end
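The drawing loop in `plot_joints_single` treats each entry of `segments.joints_idx` as a list of row indices into `pts` and draws one polyline through those rows. A minimal Python sketch of that indexing follows. It assumes 1-based indices, as MATLAB uses, and leaves out the plotting itself; `segment_polylines` is a hypothetical name.

```python
def segment_polylines(pts, joints_idx):
    """Vertices of each skeleton segment, in drawing order.

    pts: list of [x, y] rows, one per joint.
    joints_idx: list of segments, each a list of 1-based joint indices.
    """
    # Subtract 1 to turn MATLAB's 1-based joint indices into Python list indices.
    return [[tuple(pts[j - 1]) for j in idx] for idx in joints_idx]


pts = [[10, 20], [15, 25], [30, 5]]
print(segment_polylines(pts, [[1, 2], [2, 3]]))
# [[(10, 20), (15, 25)], [(15, 25), (30, 5)]]
```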


================================================
FILE: leap/predict_box.m
================================================
function preds = predict_box(box, modelPath, saveConfmaps)
%PREDICT_BOX Evaluates model predictions on a stack of frames. Wrapper for predict_box.py.
% Usage:
%   preds = predict_box(box, modelPath)
%   preds = predict_box(box, modelPath, saveConfmaps)
%
% Args:
%   box: 4-D array or path to HDF5 file with '/box' dataset
%   modelPath: path to model weights
%   saveConfmaps: if true, returns full confidence maps in addition to
%                 peaks (default: false). Very slow and memory intensive!
%   
% Returns:
%   preds: struct containing results from model prediction
%        .positions_pred: 3-D array of (parts x [X Y] x frames) indicating
%                         peak positions for each confidence map in image coordinates
%        .conf_pred: 2-D array of (parts x frames) with the confidence map
%                    value at the peak pixel for detecting bad predictions
%        .confmaps: 4-D array of confidence maps returned if saveConfmaps
%                   is set true

if nargin < 3 || isempty(saveConfmaps); saveConfmaps = false; end

% Process args
delete_box = false;
is_singleton = false;
if ischar(box)
    boxPath = box;
else
    boxPath = [tempname '.h5'];
    
    if numel(size(box)) < 4
        box = repmat(box,[1 1 1 2]);
        is_singleton = true;
    end
    
    h5save(boxPath, box)
    delete_box = true;
end

% Generate temporary output filename
outPath = [tempname '.h5'];

% Build command line args
cmd = {
    'python'
    ['"' ff(funpath(true), 'predict_box.py') '"']
    ['"' boxPath '"']
    ['"' modelPath '"']
    ['"' outPath '"']
    };
if saveConfmaps
    cmd{end+1} = '--save-confmaps';
end
disp(strjoin(cmd))

% Predict
try
    exit_code = system(strjoin(cmd));
catch ME
    if delete_box && exists(boxPath); delete(boxPath); end
    rethrow(ME)
end
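if delete_box && exists(boxPath); delete(boxPath); end
if exit_code ~= 0
    error('predict_box:pythonFailed', 'predict_box.py exited with status %d.', exit_code)
end
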
% Read data back in
try
    preds = h5readgroup(outPath);
catch ME
    if exists(outPath); delete(outPath); end
    rethrow(ME)
end
if exists(outPath); delete(outPath); end

% Adjust for 0-based indexing
preds.positions_pred = single(preds.positions_pred) + 1;

% Undo the uint8 quantization applied in predict_box.py, restoring the original value range
if saveConfmaps && isfield(preds.Attributes, 'confmaps')
    c_min = preds.Attributes.confmaps.range_min;
    c_max = preds.Attributes.confmaps.range_max;
    preds.confmaps = rescale(single(preds.confmaps) / 255, c_min, c_max);
end

% Adjust for singleton input
if is_singleton
    preds.positions_pred = preds.positions_pred(:,:,1);
    preds.conf_pred = preds.conf_pred(:,1);
    if isfield(preds,'confmaps')
        preds.confmaps = preds.confmaps(:,:,:,1);
    end
end

end
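`predict_box.m` relies on two conventions when reading the file written by `predict_box.py`. First, h5py writes arrays in row-major order and MATLAB's HDF5 readers report the dimensions reversed. As a result, `(sample, [x, y], joint)` arrives as parts x [X Y] x frames, and the transposed confidence maps arrive as (height, width, channel, sample). Second, confidence maps are min-max quantized to `uint8` before saving and linearly mapped back using the stored `range_min`/`range_max` attributes. The plain-Python sketch below models both; the function names are hypothetical. The `squeeze_shape` helper also suggests why single frames are duplicated before prediction: `Ypk[:,2,:].squeeze()` would otherwise drop the sample axis of `conf_pred`. This is an inference from the code, not a documented rationale.

```python
def matlab_shape(h5py_shape):
    """Dimensions MATLAB reports for a dataset that h5py reports as h5py_shape."""
    return tuple(reversed(h5py_shape))


def squeeze_shape(shape):
    """Shape after numpy's squeeze(): every length-1 axis is dropped."""
    return tuple(d for d in shape if d != 1)


def quantize(values):
    """Min-max scale to [0, 255] and truncate, like astype('uint8') in predict_box.py.

    Assumes max(values) > min(values).
    """
    lo, hi = min(values), max(values)
    return [int((v - lo) / (hi - lo) * 255) for v in values], lo, hi


def dequantize(codes, lo, hi):
    """Linear map back to [lo, hi], as rescale(single(c) / 255, lo, hi) does when
    the codes span 0..255 (which quantize guarantees)."""
    return [lo + (c / 255) * (hi - lo) for c in codes]


print(matlab_shape((100, 2, 32)))         # (32, 2, 100): parts x [X Y] x frames
print(matlab_shape((100, 32, 120, 160)))  # (160, 120, 32, 100): height x width x channel x frame
print(squeeze_shape((1, 32)))             # (32,): one frame loses its sample axis

values = [-0.1, 0.0, 0.5, 1.2]
codes, lo, hi = quantize(values)
restored = dequantize(codes, lo, hi)
# Truncation error stays below one quantization step, (hi - lo) / 255.
print(max(abs(a - b) for a, b in zip(values, restored)) < (hi - lo) / 255)  # True
```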


================================================
FILE: leap/predict_box.py
================================================
import h5py
import numpy as np
import os
from time import time
import keras
import keras.models
from keras.layers import Lambda
import tensorflow as tf
import re
from clize import run

from leap.utils import find_weights, find_best_weights, preprocess
from leap.layers import Maxima2D

def tf_find_peaks(x):
    """ Finds the maximum value in each channel and returns the location and value.
    Args:
        x: rank-4 tensor (samples, height, width, channels)

    Returns:
        peaks: rank-3 tensor (samples, [x, y, val], channels)
    """

    # Store input shape
    in_shape = tf.shape(x)

    # Flatten height/width dims
    flattened = tf.reshape(x, [in_shape[0], -1, in_shape[-1]])

    # Find peaks in linear indices
    idx = tf.argmax(flattened, axis=1)

    # Convert linear indices to subscripts (row-major over height x width,
    # so divide by the width, in_shape[2], not the height)
    rows = tf.floor_div(tf.cast(idx,tf.int32), in_shape[2])
    cols = tf.floormod(tf.cast(idx,tf.int32), in_shape[2])

    # Peak values: a second reduction is simpler than gathering at idx
    vals = tf.reduce_max(flattened, axis=1)

    # Return N x 3 x C tensor
    return tf.stack([
        tf.cast(cols, tf.float32),
        tf.cast(rows, tf.float32),
        vals
    ], axis=1)


def convert_to_peak_outputs(model, include_confmaps=False):
    """ Creates a new Keras model with a wrapper to yield channel peaks from rank-4 tensors. """
    if isinstance(model.output, list):
        confmaps = model.output[-1]
    else:
        confmaps = model.output

    if include_confmaps:
        return keras.Model(model.input, [Lambda(tf_find_peaks)(confmaps), confmaps])
    else:
        # return keras.Model(model.input, Lambda(tf_find_peaks)(confmaps))
        return keras.Model(model.input, Maxima2D()(confmaps))


def predict_box(box_path, model_path, out_path, *, box_dset="/box", epoch=None, verbose=True, overwrite=False, save_confmaps=False, batch_size=32):
    """
    Predict and save peak coordinates for a box.

    :param box_path: path to HDF5 file with box dataset
    :param model_path: path to Keras weights file or run folder with weights subfolder
    :param out_path: path to HDF5 file to save results to
    :param box_dset: name of HDF5 dataset containing box images
    :param epoch: epoch to use if run folder provided instead of Keras weights file
    :param verbose: if True, prints some info and statistics during processing
    :param overwrite: if True and out_path exists, file will be overwritten
    :param save_confmaps: if True, saves the full confidence maps as additional datasets in the output file (very slow)
    :param batch_size: number of samples to evaluate at once per batch (see keras.Model API)
    """

    if verbose:
        print("model_path:", model_path)

    # Find model weights
    model_name = None
    weights_path = model_path
    if os.path.isdir(model_path):
        model_name = os.path.basename(model_path)

        weights_paths, epochs, val_losses = find_weights(model_path)

        if epoch is None and len(val_losses) > 0:
            weights_path = weights_paths[np.argmin(val_losses)]
        elif epoch == "final" or (epoch is None and len(val_losses) == 0):
            weights_path = os.path.join(model_path, "final_model.h5")
        else:
            weights_path = weights_paths[epoch]

    # Input data
    box = h5py.File(box_path,"r")[box_dset]
    num_samples = box.shape[0]
    if verbose:
        print("Input:", box_path)
        print("box.shape:", box.shape)

    # Create output path
    if not out_path.endswith(".h5"):
        if model_name is None:
            out_path = os.path.join(out_path, os.path.basename(box_path))
        else:
            out_path = os.path.join(out_path, model_name, os.path.basename(box_path))
        os.makedirs(os.path.dirname(out_path), exist_ok=True)

    model_name = os.path.basename(model_path)

    if verbose:
        print("Output:", out_path)

    t0_all = time()
    if os.path.exists(out_path):
        if overwrite:
            os.remove(out_path)
            print("Deleted existing output.")
        else:
            print("Error: Output path already exists.")
            return

    # Load and prepare model
    model = keras.models.load_model(weights_path)
    model_peaks = convert_to_peak_outputs(model, include_confmaps=save_confmaps)
    if verbose:
        print("weights_path:", weights_path)
        print("Loaded model: %d layers, %d params" % (len(model.layers), model.count_params()))

    # Load data and preprocess (normalize)
    t0 = time()
    X = preprocess(box[:])
    if verbose:
        print("Loaded [%.1fs]" % (time() - t0))

    # Evaluate
    t0 = time()
    if save_confmaps:
        Ypk, confmaps = model_peaks.predict(X, batch_size=batch_size)

        # Quantize
        confmaps_min = confmaps.min()
        confmaps_max = confmaps.max()
        confmaps = (confmaps - confmaps_min) / (confmaps_max - confmaps_min)
        confmaps = (confmaps * 255).astype('uint8')

        # Transpose to (sample, channel, width, height); MATLAB reads this as (height, width, channel, sample)
        confmaps = np.transpose(confmaps, (0, 3, 2, 1))
    else:
        Ypk = model_peaks.predict(X, batch_size=batch_size)
    prediction_runtime = time() - t0
    if verbose:
        print("Predicted [%.1fs]" % prediction_runtime)
        print("Prediction performance: %.3f FPS" % (num_samples / prediction_runtime))

    # Save
    t0 = time()
    with h5py.File(out_path, "w") as f:
        f.attrs["num_samples"] = num_samples
        f.attrs["img_size"] = X.shape[1:]
        f.attrs["box_path"] = box_path
        f.attrs["box_dset"] = box_dset
        f.attrs["model_path"] = model_path
        f.attrs["weights_path"] = weights_path
        f.attrs["model_name"] = model_name

        ds_pos = f.create_dataset("positions_pred", data=Ypk[:,:2,:].astype("int32"), compression="gzip", compression_opts=1)
        ds_pos.attrs["description"] = "coordinate of peak at each sample"
        ds_pos.attrs["dims"] = "(sample, [x, y], joint) === (sample, [column, row], joint)"

        ds_conf = f.create_dataset("conf_pred", data=Ypk[:,2,:].squeeze(), compression="gzip", compression_opts=1)
        ds_conf.attrs["description"] = "confidence map value in [0, 1.0] at peak"
        ds_conf.attrs["dims"] = "(sample, joint)"

        if save_confmaps:
            ds_confmaps = f.create_dataset("confmaps", data=confmaps, compression="gzip", compression_opts=1)
            ds_confmaps.attrs["description"] = "confidence maps"
            ds_confmaps.attrs["dims"] = "(sample, channel, width, height)"
            ds_confmaps.attrs["range_min"] = confmaps_min
            ds_confmaps.attrs["range_max"] = confmaps_max

        total_runtime = time() - t0_all
        f.attrs["total_runtime_secs"] = total_runtime
        f.attrs["prediction_runtime_secs"] = prediction_runtime

    if verbose:
        print("Saved [%.1fs]" % (time() - t0))

        print("Total runtime: %.1f mins" % (total_runtime / 60))
        print("Total performance: %.3f FPS" % (num_samples / total_runtime))


if __name__ == "__main__":
    run(predict_box)
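`tf_find_peaks` flattens height and width in row-major order, so a linear index must be split using the image width. A plain-Python reference like the one below can serve as a unit-test oracle for the TensorFlow version on non-square inputs, where dividing by the height goes wrong. It assumes a single `height x width x channels` stack given as nested lists and returns `(x, y, value)` per channel, matching the `[x, y, val]` layout of `tf_find_peaks`. `find_peaks_reference` is a hypothetical name. Ties may resolve differently, because `tf.argmax` does not guarantee which index it returns.

```python
def find_peaks_reference(confmaps):
    """Per-channel argmax of a single confidence-map stack.

    confmaps: nested lists indexed [row][col][channel] (height x width x channels).
    Returns one (x, y, value) tuple per channel, with 0-based x = column, y = row.
    """
    height, width = len(confmaps), len(confmaps[0])
    channels = len(confmaps[0][0])
    peaks = []
    for c in range(channels):
        # Flatten height and width row-major, as tf.reshape does.
        flat = [confmaps[r][k][c] for r in range(height) for k in range(width)]
        idx = max(range(len(flat)), key=flat.__getitem__)
        row, col = divmod(idx, width)  # split by the width, not the height
        peaks.append((col, row, flat[idx]))
    return peaks


# 2 x 3 single-channel map with its peak in the last row and column
example = [[[0.0], [0.1], [0.2]],
           [[0.3], [0.4], [0.9]]]
print(find_peaks_reference(example))  # [(2, 1, 0.9)]
```

Here the peak sits at linear index 5. Splitting by the width (3) gives row 1, column 2; splitting by the height (2) would give row 2, which lies outside the image.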


================================================
FILE: leap/pts2confmaps.m
================================================
function confmaps = pts2confmaps(pts, sz, sigma, normalize)
%PTS2CONFMAPS Generate confidence maps centered at specified points.
% Usage:
%   confmaps = pts2confmaps(pts, sz, sigma)
%
% Args:
%   pts: N x 2 or cell array of {N1 x 2, N2 x 2, ...}, where each cell will
%       correspond to a single channel to create multipoint confidence maps
%   sz: [rows cols]
%   sigma: filter size (default: 5)
%   normalize: outputs maps in [0, 1] rather than PDF (default: true)
%
% See also: label_joints

if ~iscell(pts); pts = arr2cell(pts,1); end
if nargin < 3 || isempty(sigma); sigma = 5; end
if nargin < 4 || isempty(normalize); normalize = true; end

confmaps = NaN(sz(1), sz(2), numel(pts));
xv = 1:sz(2); yv = 1:sz(1);
[XX,YY] = meshgrid(xv,yv);

for i = 1:numel(pts)
    x = permute(pts{i}(:,1),[2 3 1]);
    y = permute(pts{i}(:,2),[2 3 1]);
%     confmaps(:,:,i) = sum(exp(-((YY-y).^2 + (XX-x).^2)./(2*sigma^2)),3);
    confmaps(:,:,i) = max(exp(-((YY-y).^2 + (XX-x).^2)./(2*sigma^2)),[],3);
end

if ~normalize
    confmaps = confmaps ./ (sigma * sqrt(2*pi));
end

end
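A Python port can help check the MATLAB training targets or generate them without MATLAB. The sketch below follows `pts2confmaps` with `normalize` left at `true`. Each pixel gets the maximum, not the sum, of the Gaussians centred on that channel's points, and coordinates stay 1-based as in the `meshgrid(1:cols, 1:rows)` call. `gaussian_confmaps` is a hypothetical name. It uses nested lists rather than arrays to stay self-contained, and it assumes every channel has at least one point.

```python
import math


def gaussian_confmaps(pts, size, sigma=5.0):
    """Confidence maps with peak value 1, one channel per entry of pts.

    pts: list of channels; each channel is a non-empty list of (x, y) points in
         1-based pixel coordinates.
    size: (rows, cols).
    Returns maps[channel][row][col].
    """
    rows, cols = size
    maps = []
    for channel in pts:
        grid = []
        for r in range(rows):
            yy = r + 1  # 1-based row coordinate, as in meshgrid(1:cols, 1:rows)
            row_vals = []
            for c in range(cols):
                xx = c + 1  # 1-based column coordinate
                # Keep the maximum over this channel's points (not the sum).
                row_vals.append(max(
                    math.exp(-((yy - y) ** 2 + (xx - x) ** 2) / (2 * sigma ** 2))
                    for x, y in channel))
            grid.append(row_vals)
        maps.append(grid)
    return maps


maps = gaussian_confmaps([[(2, 3)]], size=(4, 5), sigma=1.0)
print(maps[0][2][1])  # 1.0: row 3, column 2 in 1-based coordinates
```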


================================================
FILE: leap/test_leap.m
================================================
function works = test_leap()
%TEST_LEAP Checks whether LEAP is properly installed.
% Usage:
%   test_leap
%   works = test_leap % returns true/false
% 
% See also: install_leap

works = true;

% Check if we can import the LEAP package from anywhere
cdCmd = ['cd "' matlabroot '"'];
if ispc(); cdCmd = ['cd /D "' matlabroot '"']; end
[status,msg] = system([cdCmd ' && python -c "import leap"']);

if status ~= 0
    works = false;
end

if nargout == 0
    if works
        printf('Test LEAP successful!')
    else
        disp('Unable to import LEAP python package. Make sure LEAP and its dependencies are installed.')
        disp('Go to the base LEAP directory containing setup.py and run from MATLAB:')
        disp('    !pip install -e .')
        disp('Or try the MATLAB installer:')
        disp('    install_leap')
    end
    clear works
end

end
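`test_leap` runs `python -c "import leap"` after changing to `matlabroot`. Because `python -c` puts the current directory on `sys.path`, this step keeps a `leap` folder in the working directory from shadowing the installed package. The same check can be sketched from Python with `subprocess`. `can_import` is a hypothetical helper. It defaults to the current interpreter (`sys.executable`) rather than whichever `python` is first on `PATH`, so pass `python="python"` to match the MATLAB behavior.

```python
import subprocess
import sys
import tempfile


def can_import(module_name="leap", python=sys.executable):
    """Try `import module_name` in a fresh interpreter started from a neutral
    directory, mirroring test_leap.m. Returns True on exit status 0."""
    with tempfile.TemporaryDirectory() as neutral_dir:
        result = subprocess.run(
            [python, "-c", "import " + module_name],
            cwd=neutral_dir,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
    return result.returncode == 0
```

Unlike `importlib.util.find_spec`, this actually executes the import, so missing dependencies of the package also cause the check to fail, as they do in `test_leap`.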


================================================
FILE: leap/toolbox/aliases/alims.m
================================================
function varargout = alims(X)
%ALIMS Alias for arange.
% Usage:
%   R = alims(X)
%   [min_val, max_val] = alims(X)
%
% See also: arange

varargout = wrap(@() arange(X), 1:max(1, nargout));

end



================================================
FILE: leap/toolbox/aliases/ff.m
================================================
function varargout = ff(varargin)
    N = max(nargout,1);
	varargout{1:N} = fullfile(varargin{:});
end


================================================
FILE: leap/toolbox/aliases/h5file.m
================================================
function varargout = h5file(varargin)
	varargout{1:max(nargout,1)} = hdf5prop(varargin{:});
end


================================================
FILE: leap/toolbox/aliases/imgsc.m
================================================
function h = imgsc(I, varargin)
%IMGSC Alias for imagesc for images.
% Usage:
%   imgsc(I)
%   imgsc(I, ...)
%
% See also: imagesc, sc

% figure('KeyPressFcn',@KeyPressFcn_cb)
figclosekey
h = imagesc(I, varargin{:});
H = size(I,1); W = size(I,2);
if  max([H W]) / min([H W]) < 2 || all([H W] < 25)
    axis image
end
colorbar

if nargout < 1; clear h; end

end

================================================
FILE: leap/toolbox/aliases/repext.m
================================================
function varargout = repext(varargin)
	varargout{1:max(nargout,1)} = extrep(varargin{:});
end


================================================
FILE: leap/toolbox/graphics/FEX-settingsdlg/settingsdlg.m
================================================
function [settings, button] = settingsdlg(varargin)
% SETTINGSDLG             Default dialog to produce a settings-structure
%
% settings = SETTINGSDLG('fieldname', default_value, ...) creates a modal
% dialog box that returns a structure formed according to user input. The
% input should be given in the form of 'fieldname', default_value - pairs,
% where 'fieldname' is the fieldname in the structure [settings], and
% default_value the initial value displayed in the dialog box.
%
% SETTINGSDLG uses UIWAIT to suspend execution until the user responds.
%
% settings = SETTINGSDLG(settings) uses the structure [settings] to form
% the various input fields. This is the most basic (and limited) usage of
% SETTINGSDLG.
%
% [settings, button] = SETTINGSDLG(settings) returns which button was
% pressed, in addition to the (modified) structure [settings]. Either 'ok',
% 'cancel' or [] are possible values. The empty output means that the
% dialog was closed before either Cancel or OK were pressed.
%
% SETTINGSDLG('title', 'window_title') uses 'window_title' as the dialog's
% title. The default is 'Adjust settings'.
%
% SETTINGSDLG('description', 'brief_description',...) starts the dialog box
% with 'brief_description', followed by the input fields.
%
% SETTINGSDLG('windowposition', P, ...) positions the dialog box according to
% the string or vector [P]; see movegui() for valid values.
%
% SETTINGSDLG( {'display_string', 'fieldname'}, default_value,...) uses the
% 'display_string' in the dialog box, while assigning the corresponding
% user-input to fieldname 'fieldname'.
%
% SETTINGSDLG(..., 'checkbox_string', true, ...) displays a checkbox
% instead of the default edit box, and SETTINGSDLG('fieldname', {'string1',
% 'string2'},... ) displays a popup box with the strings given in
% the second cell-array.
%
% Additionally, you can put [..., 'separator', 'separator_string',...]
% anywhere in the argument list, which will divide all the arguments into
% sections, with section headings 'separator_string'.
%
% You can also modify the display behavior in the case of checkboxes. When
% defining checkboxes with a 2-element logical array, the second boolean
% determines whether all fields below that checkbox are initially disabled
% (true) or not (false).
%
% Example:
%
% [settings, button] = settingsdlg(...
%     'Description', 'This dialog will set the parameters used by FMINCON()',...
%     'title'      , 'FMINCON() options',...
%     'separator'  , 'Unconstrained/General',...
%     {'This is a checkbox'; 'Check'}, [true true],...
%     {'Tolerance X';'TolX'}, 1e-6,...
%     {'Tolerance on Function';'TolFun'}, 1e-6,...
%     'Algorithm'  , {'active-set','interior-point'},...
%     'separator'  , 'Constrained',...
%     {'Tolerance on Constraints';'TolCon'}, 1e-6)
%
% See also inputdlg, dialog, errordlg, helpdlg, listdlg, msgbox, questdlg, textwrap,
% uiwait, warndlg.


% Please report bugs and inquiries to:
%
% Name   : Rody P.S. Oldenhuis
% E-mail : oldenhuis@gmail.com
% Licence: 2-clause BSD (See Licence.txt)


% If you find this work useful, please consider a donation:
% https://www.paypal.me/RodyO/3.5

    %% Initialize

    % errortraps
    narg = nargin;
    if verLessThan('MATLAB', '8.6')
        error(nargchk(1, inf, narg, 'struct')); %#ok<NCHKN>
    else
        narginchk(1, inf);
    end

    % parse input (+errortrap)
    have_settings = 0;
    if isstruct(varargin{1})
        settings = varargin{1}; have_settings = 1; end
    if (narg == 1)
        if isstruct(varargin{1})
            parameters = fieldnames(settings);
            values = cellfun(@(x)settings.(x), parameters, 'UniformOutput', false);
        else
            error('settingsdlg:incorrect_input',...
                'When passing a single argument, that argument must be a structure.')
        end
    else
        parameters = varargin(1+have_settings : 2 : end);
        values     = varargin(2+have_settings : 2 : end);
    end

    % Initialize data
    button = [];
    fields = cell(numel(parameters),1);
    tags   = fields;

    % Fill [settings] with default values & collect data
    for ii = 1:numel(parameters)

        % Extract fields & tags
        if iscell(parameters{ii})
            tags{ii}   = parameters{ii}{1};
            fields{ii} = parameters{ii}{2};
        else
            % More errortraps
            if ~ischar(parameters{ii})
                error('settingsdlg:nonstring_parameter',...
                'Arguments should be given as [''parameter'', value,...] pairs.')
            end
            tags{ii}   = parameters{ii};
            fields{ii} = parameters{ii};
        end

        % More errortraps
        if ~ischar(fields{ii})
            error('settingsdlg:fieldname_not_char',...
                'Fieldname should be a string.')
        end
        if ~ischar(tags{ii})
            error('settingsdlg:tag_not_char',...
                'Display name should be a string.')
        end

        % NOTE: 'Separator' is now in 'fields' even though
        % it will not be used as a fieldname

        % Make sure all fieldnames are valid: capitalize every character
        % that follows whitespace, then strip the whitespace (camelCase)
        if ~any(strcmpi(fields{ii}, {'Separator';'Title';'Description'}))
            whitespace = isspace(fields{ii});
            capitalize = circshift(whitespace,[0,1]);
            fields{ii}(capitalize) = upper(fields{ii}(capitalize));
            fields{ii} = fields{ii}(~whitespace);
            % insert associated value in output
            if iscell(values{ii})
                settings.(fields{ii}) = values{ii}{1};
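            elseif ~ischar(values{ii}) && (length(values{ii}) > 1)
                % non-string arrays, e.g. a [value, disable-next] logical
                % pair for a checkbox: store only the first element
                settings.(fields{ii}) = values{ii}(1);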
            else
                settings.(fields{ii}) = values{ii};
            end
        end
    end
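
    % Worked example of the fieldname normalization above (illustrative
    % only): the display name 'window width' becomes the fieldname
    % 'windowWidth'. isspace marks position 7, circshift moves that mark to
    % position 8 ('w' -> 'W'), and the space itself is then removed. A cell
    % parameter such as {'Window width', 'myWidth'} supplies the fieldname
    % 'myWidth' directly; it has no whitespace, so it passes through unchanged.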

    % [parameters] is no longer needed; clear it to avoid confusing it with [fields]
    clear parameters

    % Use default colorscheme from the OS
    bgcolor = get(0, 'defaultUicontrolBackgroundColor');
    % Default fontsize
    fontsize = get(0, 'defaultuicontrolfontsize');
    % Edit-bgcolor is platform-dependent:
    % MS/Windows: white.
    % UNIX: same as figure bgcolor.
    %     if ispc, edit_bgcolor = 'White';
    %     else     edit_bgcolor = bgcolor;
    %     end
    %
    % TODO: the platform check above is disabled because
    % defaultUicontrolBackgroundColor does not seem to work reliably on
    % Unix, so edit boxes are always white for now.
    edit_bgcolor = 'White';

    % Get basic window properties
    title         = getValue('Adjust settings', 'Title');
    description   = getValue( [], 'Description');
    total_width   = getValue(325, 'WindowWidth');
    control_width = getValue(100, 'ControlWidth');

    % Window positioning:
    % Put the window in the center of the screen by default.
    % This usually works fine, except on some multi-monitor setups.
    scz  = get(0, 'ScreenSize');
    scxy = round(scz(3:4)/2-control_width/2);
    scx  = min(scz(3),max(1,scxy(1)));
    scy  = min(scz(4),max(1,scxy(2)));

    % String to pass on to movegui
    window_position = getValue('center', 'WindowPosition');


    % Calculate best height for all uicontrol()
    control_height = max(18, (fontsize+6));

    % Calculate figure height (will be adjusted later according to description)
    total_height = numel(fields)*1.25*control_height + ... % to fit all controls
                     1.5*control_height + 20; % to fit "OK" and "Cancel" buttons
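
    % Worked example (illustrative; assumes fontsize = 8):
    % control_height = max(18, 8+6) = 18, so a dialog with 4 entries in
    % [fields] starts at total_height = 4*1.25*18 + 1.5*18 + 20
    %                                 = 90 + 27 + 20 = 137 pixels.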

    % Total number of separators
    num_separators = nnz(strcmpi(fields,'Separator'));

    % Draw figure in background
    fighandle = figure(...
         'integerhandle'   , 'off',...         % use non-integers for the handle (prevents accidental plots from going to the dialog)
         'Handlevisibility', 'off',...         % only visible from within this function
         'position'        , [scx, scy, total_width, total_height],...% figure position
         'visible'         , 'off',...         % hide the dialog while it is being constructed
         'backingstore'    , 'off',...         % DON'T save a copy in the background
         'resize'          , 'off', ...        % don't let the user resize the dialog
         'renderer'        , 'zbuffer', ...    % best choice for speed vs. compatibility
         'WindowStyle'     ,'modal',...        % window is modal
         'units'           , 'pixels',...      % better for drawing
         'DockControls'    , 'off',...         % force it to be non-dockable
         'name'            , title,...         % dialog title
         'menubar'         ,'none', ...        % no menubar of course
         'toolbar'         ,'none', ...        % no toolbar
         'NumberTitle'     , 'off',...         % "Figure 1.4728...:" just looks corny
         'color'           , bgcolor);         % use default colorscheme

    %% Draw all required uicontrols(), and unhide window

    % Define X-offsets (different when separators are used)
    separator_offset_X = 2;
    if num_separators > 0
        text_offset_X = 20;
        text_width = (total_width-control_width-text_offset_X);
    else
        text_offset_X = separator_offset_X;
        text_width = (total_width-control_width);
    end

    % Handle description
    description_offset = 0;
    if ~isempty(description)

        % create textfield (negligible height initially)
        description_panel = uicontrol(...
            'parent'  , fighandle,...
            'style'   , 'text',...
            'Horizontalalignment', 'left',...
            'position', [separator_offset_X,...
                         total_height,total_width,1]);

        % wrap the description
        description = textwrap(description_panel, {description});

        % adjust the height of the figure
        textheight = size(description,1)*(fontsize+6);
        description_offset = textheight + 20;
        total_height = total_height + description_offset;
        set(fighandle,...
            'position', [scx, scy, total_width, total_height])

        % adjust the position of the textfield and insert the description
        set(description_panel, ...
            'string'  , description,...
            'position', [separator_offset_X, total_height-textheight, ...
                         total_width, textheight]);
    end

    % Define Y-offsets (different when descriptions are used)
    control_offset_Y = total_height-control_height-description_offset;

    % initialize loop
    controls = zeros(numel(tags)-num_separators,1);
    ii = 1;             sep_ind = 1;
    enable = 'on';      separators = zeros(num_separators,1);

    % loop through the controls
    if numel(tags) > 0
        while true

            % Should we draw a separator?
            if strcmpi(tags{ii}, 'Separator')

                % Print separator
                uicontrol(...
                    'style'   , 'text',...
                    'parent'  , fighandle,...
                    'string'  , values{ii},...
                    'horizontalalignment', 'left',...
                    'fontweight', 'bold',...
                    'position', [separator_offset_X,control_offset_Y-4, ...
                    total_width, control_height]);

                % remove separator, but save its position
                fields(ii) = [];
                tags(ii)   = [];  separators(sep_ind) = ii;
                values(ii) = [];  sep_ind = sep_ind + 1;

                % reset enable (when necessary)
                if strcmpi(enable, 'off')
                    enable = 'on'; end

                % NOTE: DON'T increase loop index

            % ... or a setting?
            else

                % logicals: use checkbox
                if islogical(values{ii})

                    % First draw control
                    controls(ii) = uicontrol(...
                        'style'   , 'checkbox',...
                        'parent'  , fighandle,...
                        'enable'  , enable,...
                        'string'  , tags{ii},...
                        'value'   , values{ii}(1),...
                        'position', [text_offset_X,control_offset_Y-4, ...
                        total_width, control_height]);

                    % Should everything below here be OFF?
                    if (length(values{ii})>1)
                        % turn next controls off when asked for
                        if values{ii}(2)
                            enable = 'off'; end
                        % Turn on callback function
                        set(controls(ii),...
                            'Callback', @(varargin) EnableDisable(ii,varargin{:}));
                    end

                % doubles      : use edit box
                % cells        : use popup
                % cell-of-cells: use table
                else
                    % First print parameter
                    uicontrol(...
                        'style'   , 'text',...
                        'parent'  , fighandle,...
                        'string'  , [tags{ii}, ':'],...
                        'horizontalalignment', 'left',...
                        'position', [text_offset_X,control_offset_Y-4, ...
                        text_width, control_height]);

                    % Popup, edit box or table?
                    style = 'edit';
                    draw_table = false;
                    if iscell(values{ii})
                        style = 'popup';
                        if all(cellfun('isclass', values{ii}, 'cell'))
                            draw_table = true; end
                    end

                    % Draw appropriate control
                    if ~draw_table
                        controls(ii) = uicontrol(...
                            'enable'  , enable,...
                            'style'   , style,...
                            'Background', edit_bgcolor,...
                            'parent'  , fighandle,...
                            'string'  , values{ii},...
                            'position', [text_width,control_offset_Y,...
                            control_width, control_height]);
                    else
                        % TODO
                        % ...table? ...radio buttons? How to do this?
                        warning(...
                            'settingsdlg:not_yet_implemented',...
                            'Treatment of cells is not yet implemented.');

                    end
                end

                % increase loop index
                ii = ii + 1;
            end

            % end loop?
            if ii > numel(tags)
                break, end

            % Decrease offset
            control_offset_Y = control_offset_Y - 1.25*control_height;
        end
    end

    % Draw cancel button
    uicontrol(...
        'style'   , 'pushbutton',...
        'parent'  , fighandle,...
        'string'  , 'Cancel',...
        'position', [separator_offset_X,2, total_width/2.5,control_height*1.5],...
        'Callback', @Cancel)

    % Draw OK button
    uicontrol(...
        'style'   , 'pushbutton',...
        'parent'  , fighandle,...
        'string'  , 'OK',...
        'position', [total_width*(1-1/2.5)-separator_offset_X,2, ...
                     total_width/2.5,control_height*1.5],...
        'Callback', @OK)

    % move to center of screen and make visible
    movegui(fighandle, window_position);
    set(fighandle, 'Visible', 'on');

    % WAIT until OK/Cancel is pressed
    uiwait(fighandle);



    %% Helper functions

    % Get a value from the values array:
    % - if it does not exist, return the default value
    % - if it exists, assign it and delete the appropriate entries from the
    %   data arrays
    function val = getValue(default, tag)
        index = strcmpi(fields, tag);
        if any(index)
            val = values{index};
            values(index) = [];
            fields(index) = [];
            tags(index)   = [];
        else
            val = default;
        end
    end
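
    % Example (illustrative): if the caller passed 'title', 'My options',
    % then getValue('Adjust settings', 'Title') returns 'My options' (the
    % match is case-insensitive) and removes that pair from fields, tags
    % and values, so it is not drawn as a control. Without such a pair,
    % the default 'Adjust settings' is returned.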

    %% callback functions

    % Enable/disable controls associated with (some) checkboxes
    function EnableDisable(which, varargin) %#ok<VANUS>

        % find proper range of controls to switch: everything after this
        % checkbox, up to the next separator (if any)
        if (num_separators > 0)
            last_control = separators(find(separators > which,1)) - 1;
            if isempty(last_control); last_control = numel(controls); end

            range = (which+1):(last_control);
        else
            range = (which+1):numel(controls);
        end

        % nothing to switch when no controls follow this checkbox
        if isempty(range), return, end

        % enable/disable these controls
        if strcmpi(get(controls(range(1)), 'enable'), 'off')
            set(controls(range), 'enable', 'on')
        else
            set(controls(range), 'enable', 'off')
        end
    end
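
    % Example (illustrative): with controls c1..c5 and two separators, one
    % before c1 and one before c4, the drawing loop stores separators = [1; 4].
    % Checkbox c2, created with [true true], then toggles range 3:3 (only c3),
    % because the first separator index greater than 2 is 4.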

    % OK button:
    % - update fields in [settings]
    % - assign [button] output argument ('OK')
    % - kill window
    function OK(varargin) %#ok<VANUS>

        % button pressed
        button = 'OK';

        % fill settings
        for i = 1:numel(controls)

            % extract current control's string, value & type
            str   = get(controls(i), 'string');
            val   = get(controls(i), 'value');
            style = get(controls(i), 'style');

            % popups/edits
            if ~strcmpi(style, 'checkbox')
                % extract correct string (popups only)
                if strcmpi(style, 'popupmenu'), str = str{val}; end
                % try to convert string to double
                val = str2double(str);
                % insert this double in [settings]. If it was not a
                % double, insert string instead
                if ~isnan(val), settings.(fields{i}) = val;
                else            settings.(fields{i}) = str;
                end

            % checkboxes
            else
                % we can insert value immediately
                settings.(fields{i}) = val;
            end
        end

        %  kill window
        delete(fighandle);
    end

    % Cancel button:
    % - assign [button] output argument ('cancel')
    % - delete figure (the default values in [settings] are returned)
    function Cancel(varargin) %#ok<VANUS>
        button = 'cancel';
        delete(fighandle);
    end

end


================================================
FILE: leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/+mixin/Container.m
================================================
classdef Container < handle
    %uix.mixin.Container  Container mixin
    %
    %  uix.mixin.Container is a mixin class used by containers to provide
    %  various properties and template methods.
    
    %  Copyright 2009-2016 The MathWorks, Inc.
    %  $Revision: 1358 $ $Date: 2016-09-14 11:34:17 +0100 (Wed, 14 Sep 2016) $
    
    properties( Dependent, Access = public )
        Contents % contents in layout order
    end
    
    properties( Access = public, Dependent, AbortSet )
        Padding % space around contents, in pixels
    end
    
    properties( Access = protected )
        Contents_ = gobjects( [0 1] ) % backing for Contents
        Padding_ = 0 % backing for Padding
    end
    
    properties( Dependent, Access = protected )
        Dirty % needs redraw
    end
    
    properties( Access = private )
        Dirty_ = false % backing for Dirty
        FigureObserver % observer
        FigureListener % listener
        ChildObserver % observer
        ChildAddedListener % listener
        ChildRemovedListener % listener
        SizeChangedListener % listener
        ActivePositionPropertyListeners = cell( [0 1] ) % listeners
    end
    
    methods
        
        function obj = Container()
            %uix.mixin.Container  Initialize
            %
            %  c@uix.mixin.Container() initializes the container c.
            
            % Create observers and listeners
            figureObserver = uix.FigureObserver( obj );
            figureListener = event.listener( figureObserver, ...
                'FigureChanged', @obj.onFigureChanged );
            childObserver = uix.ChildObserver( obj );
            childAddedListener = event.listener( ...
                childObserver, 'ChildAdded', @obj.onChildAdded );
            childRemovedListener = event.listener( ...
                childObserver, 'ChildRemoved', @obj.onChildRemoved );
            sizeChangedListener = event.listener( ...
                obj, 'SizeChanged', @obj.onSizeChanged );
            
            % Store observers and listeners
            obj.FigureObserver = figureObserver;
            obj.FigureListener = figureListener;
            obj.ChildObserver = childObserver;
            obj.ChildAddedListener = childAddedListener;
            obj.ChildRemovedListener = childRemovedListener;
            obj.SizeChangedListener = sizeChangedListener;
            
            % Track usage
            obj.track()
            
        end % constructor
        
    end % structors
    
    methods
        
        function value = get.Contents( obj )
            
            value = obj.Contents_;
            
        end % get.Contents
        
        function set.Contents( obj, value )
            
            % For those who can't tell a column from a row...
            if isrow( value )
                value = transpose( value );
            end
            
            % Check
            [tf, indices] = ismember( value, obj.Contents_ );
            assert( isequal( size( obj.Contents_ ), size( value ) ) && ...
                numel( value ) == numel( unique( value ) ) && all( tf ), ...
                'uix:InvalidOperation', ...
                'Property ''Contents'' may only be set to a permutation of itself.' )
            
            % Call reorder
            obj.reorder( indices )
            
        end % set.Contents
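        
        % Usage sketch (illustrative; c is any container using this mixin):
        %   c.Contents = flipud( c.Contents );  % reverse the layout order
        % Assigning anything other than a permutation of the current
        % contents (e.g. dropping a child) fails the assertion above.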
        
        function value = get.Padding( obj )
            
            value = obj.Padding_;
            
        end % get.Padding
        
        function set.Padding( obj, value )
            
            % Check
            assert( isa( value, 'double' ) && isscalar( value ) && ...
                isreal( value ) && ~isinf( value ) && ...
                ~isnan( value ) && value >= 0, ...
                'uix:InvalidPropertyValue', ...
                'Property ''Padding'' must be a non-negative scalar.' )
            
            % Set
            obj.Padding_ = value;
            
            % Mark as dirty
            obj.Dirty = true;
            
        end % set.Padding
        
        function value = get.Dirty( obj )
            
            value = obj.Dirty_;
            
        end % get.Dirty
        
        function set.Dirty( obj, value )
            
            if value
                if obj.isDrawable() % drawable
                    obj.redraw() % redraw now
                else % not drawable
                    obj.Dirty_ = true; % flag for future redraw
                end
            end
            
        end % set.Dirty
        
    end % accessors
    
    methods( Access = private, Sealed )
        
        function onFigureChanged( obj, ~, eventData )
            %onFigureChanged  Event handler
            
            % Call template method
            obj.reparent( eventData.OldFigure, eventData.NewFigure )
            
            % Redraw if possible and if dirty
            if obj.Dirty_ && obj.isDrawable()
                obj.redraw()
                obj.Dirty_ = false;
            end
            
        end % onFigureChanged
        
        function onChildAdded( obj, ~, eventData )
            %onChildAdded  Event handler
            
            % Call template method
            obj.addChild( eventData.Child )
            
        end % onChildAdded
        
        function onChildRemoved( obj, ~, eventData )
            %onChildRemoved  Event handler
            
            % Do nothing if container is being deleted
            if strcmp( obj.BeingDeleted, 'on' ), return, end
            
            % Call template method
            obj.removeChild( eventData.Child )
            
        end % onChildRemoved
        
        function onSizeChanged( obj, ~, ~ )
            %onSizeChanged  Event handler
            
            % Mark as dirty
            obj.Dirty = true;
            
        end % onSizeChanged
        
        function onActivePositionPropertyChanged( obj, ~, ~ )
            %onActivePositionPropertyChanged  Event handler
            
            % Mark as dirty
            obj.Dirty = true;
            
        end % onActivePositionPropertyChanged
        
    end % event handlers
    
    methods( Abstract, Access = protected )
        
        redraw( obj )
        
    end % abstract template methods
    
    methods( Access = protected )
        
        function addChild( obj, child )
            %addChild  Add child
            %
            %  c.addChild(d) adds the child d to the container c.
            
            % Add to contents
            obj.Contents_(end+1,:) = child;
            
            % Add listeners
            if isa( child, 'matlab.graphics.axis.Axes' )
                obj.ActivePositionPropertyListeners{end+1,:} = ...
                    event.proplistener( child, ...
                    findprop( child, 'ActivePositionProperty' ), ...
                    'PostSet', @obj.onActivePositionPropertyChanged );
            else
                obj.ActivePositionPropertyListeners{end+1,:} = [];
            end
            
            % Mark as dirty
            obj.Dirty = true;
            
        end % addChild
        
        function removeChild( obj, child )
            %removeChild  Remove child
            %
            %  c.removeChild(d) removes the child d from the container c.
            
            % Remove from contents
            contents = obj.Contents_;
            tf = contents == child;
            obj.Contents_(tf,:) = [];
            
            % Remove listeners
            obj.ActivePositionPropertyListeners(tf,:) = [];
            
            % Mark as dirty
            obj.Dirty = true;
            
        end % removeChild
        
        function reparent( obj, oldFigure, newFigure ) %#ok<INUSD>
            %reparent  Reparent container
            %
            %  c.reparent(a,b) reparents the container c from the figure a
            %  to the figure b.
            
        end % reparent
        
        function reorder( obj, indices )
            %reorder  Reorder contents
            %
            %  c.reorder(i) reorders the container contents using indices
            %  i, c.Contents = c.Contents(i).
            
            % Reorder contents
            obj.Contents_ = obj.Contents_(indices,:);
            
            % Reorder listeners
            obj.ActivePositionPropertyListeners = ...
                obj.ActivePositionPropertyListeners(indices,:);
            
            % Mark as dirty
            obj.Dirty = true;
            
        end % reorder
        
        function tf = isDrawable( obj )
            %isDrawable  Test for drawability
            %
            %  c.isDrawable() is true if the container c is drawable, and
            %  false otherwise.  To be drawable, a container must be
            %  rooted.
            
            tf = ~isempty( obj.FigureObserver.Figure );
            
        end % isDrawable
        
        function track( obj )
            %track  Track usage
            
            persistent TRACKED % single shot
            if isempty( TRACKED )
                v = ver( 'layout' );
                try %#ok<TRYNC>
                    uix.tracking( 'UA-82270656-2', v(1).Version, class( obj ) )
                end
                TRACKED = true;
            end
            
        end % track
        
    end % template methods
    
end % classdef

================================================
FILE: leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/+mixin/Flex.m
================================================
classdef Flex < handle
    %uix.mixin.Flex  Flex mixin
    %
    %  uix.mixin.Flex is a mixin class used by flex containers to provide
    %  various properties and helper methods.
    
    %  Copyright 2016 The MathWorks, Inc.
    %  $Revision: 1435 $ $Date: 2016-11-17 17:50:34 +0000 (Thu, 17 Nov 2016) $
    
    properties( GetAccess = protected, SetAccess = private )
        Pointer = 'unset' % mouse pointer
    end
    
    properties( Access = private )
        Figure = gobjects( 0 ); % mouse pointer figure
        Token = -1 % mouse pointer token
    end
    
    methods
        
        function delete( obj )
            %delete  Destructor
            
            % Clean up
            if ~strcmp( obj.Pointer, 'unset' )
                obj.unsetPointer()
            end
            
        end % destructor
        
    end % structors
    
    methods( Access = protected )
        
        function setPointer( obj, figure, pointer )
            %setPointer  Set pointer
            %
            %  c.setPointer(f,p) sets the pointer for the figure f to p.
            
            % If set, unset
            if obj.Token ~= -1
                obj.unsetPointer()
            end
            
            % Set
            obj.Token = uix.PointerManager.setPointer( figure, pointer );
            obj.Figure = figure;
            obj.Pointer = pointer;
            
        end % setPointer
        
        function unsetPointer( obj )
            %unsetPointer  Unset pointer
            %
            %  c.unsetPointer() undoes the previous pointer set.
            
            % Check
            assert( obj.Token ~= -1, 'uix:InvalidOperation', ...
                'Pointer is already unset.' )
            
            % Unset
            uix.PointerManager.unsetPointer( obj.Figure, obj.Token );
            obj.Figure = gobjects( 0 );
            obj.Pointer = 'unset';
            obj.Token = -1;
            
        end % unsetPointer
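        
        % Usage sketch (illustrative): a flex container dragging a divider
        % might call obj.setPointer( f, 'left' ) on mouse-down and
        % obj.unsetPointer() on mouse-up. The token returned by
        % uix.PointerManager.setPointer identifies this request when it is
        % unset, and a second setPointer call first unsets the previous one.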
        
    end % helper methods
    
end % classdef

================================================
FILE: leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/+mixin/Panel.m
================================================
classdef Panel < uix.mixin.Container
    %uix.mixin.Panel  Panel mixin
    %
    %  uix.mixin.Panel is a mixin class used by panels to provide various
    %  properties and template methods.
    
    %  Copyright 2009-2015 The MathWorks, Inc.
    %  $Revision: 1435 $ $Date: 2016-11-17 17:50:34 +0000 (Thu, 17 Nov 2016) $
    
    properties( Access = public, Dependent, AbortSet )
        Selection % selected contents
    end
    
    properties( Access = protected )
        Selection_ = 0 % backing for Selection
    end
    
    properties( Access = protected )
        G1218142 = false % bug flag
    end
    
    events( NotifyAccess = protected )
        SelectionChanged % selection changed
    end
    
    methods
        
        function value = get.Selection( obj )
            
            value = obj.Selection_;
            
        end % get.Selection
        
        function set.Selection( obj, value )
            
            % Check
            assert( isa( value, 'double' ), 'uix:InvalidPropertyValue', ...
                'Property ''Selection'' must be of type double.' )
            assert( isequal( size( value ), [1 1] ), ...
                'uix:InvalidPropertyValue', ...
                'Property ''Selection'' must be scalar.' )
            assert( isreal( value ) && rem( value, 1 ) == 0, ...
                'uix:InvalidPropertyValue', ...
                'Property ''Selection'' must be an integer.' )
            n = numel( obj.Contents_ );
            if n == 0
                assert( value == 0, 'uix:InvalidPropertyValue', ...
                    'Property ''Selection'' must be 0 for a container with no children.' )
            else
                assert( value >= 1 && value <= n, 'uix:InvalidPropertyValue', ...
                    'Property ''Selection'' must be between 1 and the number of children.' )
            end
            
            % Set
            oldSelection = obj.Selection_;
            newSelection = value;
            obj.Selection_ = newSelection;
            
            % Show selected child
            obj.showSelection()
            
            % Mark as dirty
            obj.Dirty = true;
            
            % Raise event
            notify( obj, 'SelectionChanged', ...
                uix.SelectionData( oldSelection, newSelection ) )
            
        end % set.Selection
        
    end % accessors
    
    methods( Access = protected )
        
        function addChild( obj, child )
            
            % Check for bug
            if verLessThan( 'MATLAB', '8.5' ) && strcmp( child.Visible, 'off' )
                obj.G1218142 = true;
            end
            
            % Select new content
            oldSelection = obj.Selection_;
            newSelection = numel( obj.Contents_ ) + 1;
            obj.Selection_ = newSelection;
            
            % Call superclass method
            addChild@uix.mixin.Container( obj, child )
            
            % Show selected child
            obj.showSelection()
            
            % Notify selection change
            obj.notify( 'SelectionChanged', ...
                uix.SelectionData( oldSelection, newSelection ) )
            
        end % addChild
        
        function removeChild( obj, child )
            
            % Adjust selection if required
            contents = obj.Contents_;
            index = find( contents == child );
            oldSelection = obj.Selection_;
            if index < oldSelection
                newSelection = oldSelection - 1;
            elseif index == oldSelection
                newSelection = min( oldSelection, numel( contents ) - 1 );
            else % index > oldSelection
                newSelection = oldSelection;
            end
            obj.Selection_ = newSelection;
            
            % Call superclass method
            removeChild@uix.mixin.Container( obj, child )
            
            % Show selected child
            obj.showSelection()
            
            % Notify selection change
            if oldSelection ~= newSelection
                obj.notify( 'SelectionChanged', ...
                    uix.SelectionData( oldSelection, newSelection ) )
            end
            
        end % removeChild
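        
        % Worked example of the selection bookkeeping above (illustrative):
        % with 4 children and Selection 3, removing child 1 gives 2 (the
        % same child, shifted up); removing child 3 gives min(3, 3) = 3
        % (the former 4th child); removing child 4 leaves 3. With
        % Selection 4, removing child 4 gives min(4, 3) = 3.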
        
        function reorder( obj, indices )
            %reorder  Reorder contents
            %
            %  c.reorder(i) reorders the container contents using indices
            %  i, c.Contents = c.Contents(i).
            
            % Reorder
            selection = obj.Selection_;
            if selection ~= 0
                obj.Selection_ = find( indices == selection );
            end
            
            % Call superclass method
            reorder@uix.mixin.Container( obj, indices )
            
        end % reorder
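        
        % Example (illustrative): with indices = [3 1 2] and Selection 1,
        % find( indices == 1 ) = 2, so the selection follows the selected
        % child to its new position.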
        
        function showSelection( obj )
            %showSelection  Show selected child, hide the others
            %
            %  c.showSelection() shows the selected child of the container
            %  c, and hides the others.
            
            % Set positions and visibility
            selection = obj.Selection_;
            children = obj.Contents_;
            for ii = 1:numel( children )
                child = children(ii);
                if ii == selection
                    if obj.G1218142
                        warning( 'uix:G1218142', ...
                            'Selected child of %s is not visible due to bug G1218142.  The child will become visible at the next redraw.', ...
                            class( obj ) )
                        obj.G1218142 = false;
                    else
                        child.Visible = 'on';
                    end
                    if isa( child, 'matlab.graphics.axis.Axes' )
                        child.ContentsVisible = 'on';
                    end
                else
                    child.Visible = 'off';
                    if isa( child, 'matlab.graphics.axis.Axes' )
                        child.ContentsVisible = 'off';
                    end
                    % As a remedy for g1100294, move off-screen too
                    margin = 1000;
                    if isa( child, 'matlab.graphics.axis.Axes' ) ...
                            && strcmp(child.ActivePositionProperty, 'outerposition' )
                        child.OuterPosition(1) = -child.OuterPosition(3)-margin;
                    else
                        child.Position(1) = -child.Position(3)-margin;
                    end
                end
            end
            
        end % showSelection
        
    end % template methods
    
end % classdef

================================================
FILE: leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/Box.m
================================================
classdef Box < uix.Container & uix.mixin.Container
    %uix.Box  Box and grid base class
    %
    %  uix.Box is a base class for containers with spacing between
    %  contents.
    
    %  Copyright 2009-2015 The MathWorks, Inc.
    %  $Revision: 1165 $ $Date: 2015-12-06 03:09:17 -0500 (Sun, 06 Dec 2015) $
    
    properties( Access = public, Dependent, AbortSet )
        Spacing = 0 % space between contents, in pixels
    end
    
    properties( Access = protected )
        Spacing_ = 0 % backing for Spacing
    end
    
    methods
        
        function value = get.Spacing( obj )
            
            value = obj.Spacing_;
            
        end % get.Spacing
        
        function set.Spacing( obj, value )
            
            % Check
            assert( isa( value, 'double' ) && isscalar( value ) && ...
                isreal( value ) && ~isinf( value ) && ...
                ~isnan( value ) && value >= 0, ...
                'uix:InvalidPropertyValue', ...
                'Property ''Spacing'' must be a non-negative scalar.' )
            
            % Set
            obj.Spacing_ = value;
            
            % Mark as dirty
            obj.Dirty = true;
            
        end % set.Spacing
        
    end % accessors
    
end % classdef

================================================
FILE: leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/ChildEvent.m
================================================
classdef( Hidden, Sealed ) ChildEvent < event.EventData
    %uix.ChildEvent  Event data for child event
    %
    %  e = uix.ChildEvent(c) creates event data including the child c.
    
    %  Copyright 2009-2015 The MathWorks, Inc.
    %  $Revision: 1165 $ $Date: 2015-12-06 03:09:17 -0500 (Sun, 06 Dec 2015) $
    
    properties( SetAccess = private )
        Child % child
    end
    
    methods
        
        function obj = ChildEvent( child )
            %uix.ChildEvent  Event data for child event
            %
            %  e = uix.ChildEvent(c) creates event data including the child
            %  c.
            
            % Set properties
            obj.Child = child;
            
        end % constructor
        
    end % structors
    
end % classdef

================================================
FILE: leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/ChildObserver.m
================================================
classdef ( Hidden, Sealed ) ChildObserver < handle
    %uix.ChildObserver  Child observer
    %
    %  co = uix.ChildObserver(o) creates a child observer for the graphics
    %  object o.  A child observer raises events when objects are added to
    %  and removed from the property Children of o.
    %
    %  See also: uix.Node
    
    %  Copyright 2009-2016 The MathWorks, Inc.
    %  $Revision: 1436 $ $Date: 2016-11-17 17:53:29 +0000 (Thu, 17 Nov 2016) $
    
    properties( Access = private )
        Root % root node
    end
    
    events( NotifyAccess = private )
        ChildAdded % child added
        ChildRemoved % child removed
    end
    
    methods
        
        function obj = ChildObserver( oRoot )
            %uix.ChildObserver  Child observer
            %
            %  co = uix.ChildObserver(o) creates a child observer for the
            %  graphics object o.  A child observer raises events when
            %  objects are added to and removed from the property Children
            %  of o.
            
            % Check
            assert( iscontent( oRoot ) && ...
                isequal( size( oRoot ), [1 1] ), 'uix:InvalidArgument', ...
                'Object must be a graphics object.' )
            
            % Create root node
            nRoot = uix.Node( oRoot );
            childAddedListener = event.listener( oRoot, ...
                'ObjectChildAdded', ...
                @(~,e)obj.addChild(nRoot,e.Child) );
            childAddedListener.Recursive = true;
            nRoot.addprop( 'ChildAddedListener' );
            nRoot.ChildAddedListener = childAddedListener;
            childRemovedListener = event.listener( oRoot, ...
                'ObjectChildRemoved', ...
                @(~,e)obj.removeChild(nRoot,e.Child) );
            childRemovedListener.Recursive = true;
            nRoot.addprop( 'ChildRemovedListener' );
            nRoot.ChildRemovedListener = childRemovedListener;
            
            % Add children
            oChildren = hgGetTrueChildren( oRoot );
            for ii = 1:numel( oChildren )
                obj.addChild( nRoot, oChildren(ii) )
            end
            
            % Store properties
            obj.Root = nRoot;
            
        end % constructor
        
    end % structors
    
    methods( Access = private )
        
        function addChild( obj, nParent, oChild )
            %addChild  Add child object to parent node
            %
            %  co.addChild(np,oc) adds the child object oc to the parent
            %  node np, either as part of construction of the child
            %  observer co, or in response to an ObjectChildAdded event on
            %  an object of interest to co.  This may lead to ChildAdded
            %  events being raised on co.
            
            % Create child node
            nChild = uix.Node( oChild );
            nParent.addChild( nChild )
            if iscontent( oChild )
                % Add Internal PreSet property listener
                internalPreSetListener = event.proplistener( oChild, ...
                    findprop( oChild, 'Internal' ), 'PreSet', ...
                    @(~,~)obj.preSetInternal(nChild) );
                nChild.addprop( 'InternalPreSetListener' );
                nChild.InternalPreSetListener = internalPreSetListener;
                % Add Internal PostSet property listener
                internalPostSetListener = event.proplistener( oChild, ...
                    findprop( oChild, 'Internal' ), 'PostSet', ...
                    @(~,~)obj.postSetInternal(nChild) );
                nChild.addprop( 'InternalPostSetListener' );
                nChild.InternalPostSetListener = internalPostSetListener;
            else
                % Add ObjectChildAdded listener
                childAddedListener = event.listener( oChild, ...
                    'ObjectChildAdded', ...
                    @(~,e)obj.addChild(nChild,e.Child) );
                nChild.addprop( 'ChildAddedListener' );
                nChild.ChildAddedListener = childAddedListener;
                % Add ObjectChildRemoved listener
                childRemovedListener = event.listener( oChild, ...
                    'ObjectChildRemoved', ...
                    @(~,e)obj.removeChild(nChild,e.Child) );
                nChild.addprop( 'ChildRemovedListener' );
                nChild.ChildRemovedListener = childRemovedListener;
            end
            
            % Raise ChildAdded event
            if iscontent( oChild ) && oChild.Internal == false
                notify( obj, 'ChildAdded', uix.ChildEvent( oChild ) )
            end
            
            % Add grandchildren
            if ~iscontent( oChild )
                oGrandchildren = hgGetTrueChildren( oChild );
                for ii = 1:numel( oGrandchildren )
                    obj.addChild( nChild, oGrandchildren(ii) )
                end
            end
            
        end % addChild
        
        function removeChild( obj, nParent, oChild )
            %removeChild  Remove child object from parent node
            %
            %  co.removeChild(np,oc) removes the child object oc from the
            %  parent node np, in response to an ObjectChildRemoved event
            %  on an object of interest to co.  This may lead to
            %  ChildRemoved events being raised on co.
            
            % Get child node
            nChildren = nParent.Children;
            tf = oChild == [nChildren.Object];
            nChild = nChildren(tf);
            
            % Raise ChildRemoved event(s)
            notifyChildRemoved( nChild )
            
            % Delete child node
            delete( nChild )
            
            function notifyChildRemoved( nc )
                
                % Process child nodes
                ngc = nc.Children;
                for ii = 1:numel( ngc )
                    notifyChildRemoved( ngc(ii) )
                end
                
                % Process this node
                oc = nc.Object;
                if iscontent( oc ) && oc.Internal == false
                    notify( obj, 'ChildRemoved', uix.ChildEvent( oc ) )
                end
                
            end % notifyChildRemoved
            
        end % removeChild
        
        function preSetInternal( ~, nChild )
            %preSetInternal  Perform property PreSet tasks
            %
            %  co.preSetInternal(n) caches the previous value of the
            %  property Internal of the object referenced by the node n, to
            %  enable PostSet tasks to identify whether the value changed.
            %  This is necessary since Internal AbortSet is false.
            
            oldInternal = nChild.Object.Internal;
            nChild.addprop( 'OldInternal' );
            nChild.OldInternal = oldInternal;
            
        end % preSetInternal
        
        function postSetInternal( obj, nChild )
            %postSetInternal  Perform property PostSet tasks
            %
            %  co.postSetInternal(n) raises a ChildAdded or ChildRemoved
            %  event on the child observer co in response to a change of
            %  the value of the property Internal of the object referenced
            %  by the node n.
            
            % Retrieve old and new values
            oChild = nChild.Object;
            newInternal = oChild.Internal;
            oldInternal = nChild.OldInternal;
            
            % Clean up node
            delete( findprop( nChild, 'OldInternal' ) )
            
            % Raise event
            switch newInternal
                case oldInternal % no change
                    % no event
                case true % false to true
                    notify( obj, 'ChildRemoved', uix.ChildEvent( oChild ) )
                case false % true to false
                    notify( obj, 'ChildAdded', uix.ChildEvent( oChild ) )
            end
            
        end % postSetInternal
        
    end % event handlers
    
end % classdef

function tf = iscontent( o )
%iscontent  True for graphics that can be Contents (and can be Children)
%
%  uix.ChildObserver needs to determine which objects can be Contents,
%  which is equivalent to can be Children if HandleVisibility is 'on' and
%  Internal is false.  Prior to R2016a, this condition could be checked
%  using isgraphics.  From R2016a, isgraphics returns true for a wider
%  range of objects, including some that can never be Contents, e.g.,
%  JavaCanvas.  Therefore this function checks whether an object is of type
%  matlab.graphics.internal.GraphicsBaseFunctions, which is what isgraphics
%  did prior to R2016a.

tf = isa( o, 'matlab.graphics.internal.GraphicsBaseFunctions' ) &&...
     isprop( o, 'Position' );

end % iscontent

================================================
FILE: leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/Container.m
================================================
classdef Container < matlab.ui.container.internal.UIContainer
    %uix.Container  Container base class
    %
    %  uix.Container is the base class for containers that extend uicontainer.
    
    %  Copyright 2009-2015 The MathWorks, Inc.
    %  $Revision: 1137 $ $Date: 2015-05-29 21:48:21 +0100 (Fri, 29 May 2015) $
    
end % classdef

================================================
FILE: leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/Divider.m
================================================
classdef Divider < matlab.mixin.SetGet
    %uix.Divider  Draggable divider
    %
    %  d = uix.Divider() creates a divider.
    %
    %  d = uix.Divider(p1,v1,p2,v2,...) creates a divider and sets
    %  specified property p1 to value v1, etc.
    
    %  Copyright 2009-2016 The MathWorks, Inc.
    %  $Revision: 1436 $ $Date: 2016-11-17 17:53:29 +0000 (Thu, 17 Nov 2016) $
    
    properties( Dependent )
        Parent % parent
        Units % units [inches|centimeters|characters|normalized|points|pixels]
        Position % position
        Visible % visible [on|off]
        BackgroundColor % background color [RGB]
        HighlightColor % border highlight color [RGB]
        ShadowColor % border shadow color [RGB]
        Orientation % orientation [vertical|horizontal]
        Markings % markings [pixels]
    end
    
    properties( Access = private )
        Control % uicontrol
        BackgroundColor_ = get( 0, 'DefaultUicontrolBackgroundColor' ) % backing for BackgroundColor
        HighlightColor_ = [1 1 1] % backing for HighlightColor
        ShadowColor_ = [0.7 0.7 0.7] % backing for ShadowColor
        Orientation_ = 'vertical' % backing for Orientation
        Markings_ = zeros( [0 1] ) % backing for Markings
        SizeChangedListener % listener
    end
    
    methods
        
        function obj = Divider( varargin )
            %uix.Divider  Draggable divider
         
SYMBOL INDEX (45 symbols across 7 files)

FILE: leap/image_augmentation.py
  function transform_imgs (line 6) | def transform_imgs(X, theta=(-180,180), scale=1.0):
  class PairedImageAugmenter (line 47) | class PairedImageAugmenter(Sequence):
    method __init__ (line 48) | def __init__(self, X, Y, batch_size=32, shuffle=False, theta=(-180,180...
    method __len__ (line 62) | def __len__(self):
    method __getitem__ (line 65) | def __getitem__(self, batch_idx):
  class MultiInputOutputPairedImageAugmenter (line 75) | class MultiInputOutputPairedImageAugmenter(PairedImageAugmenter):
    method __init__ (line 76) | def __init__(self, input_names, output_names, *args, **kwargs):
    method __getitem__ (line 85) | def __getitem__(self, batch_idx):

FILE: leap/layers.py
  function resize_images (line 43) | def resize_images(x, height_factor, width_factor, interpolation, data_fo...
  class UpSampling2D (line 86) | class UpSampling2D(Layer):
    method __init__ (line 120) | def __init__(self, size=(2, 2), data_format=None, interpolation='neare...
    method compute_output_shape (line 132) | def compute_output_shape(self, input_shape):
    method call (line 148) | def call(self, inputs):
    method get_config (line 152) | def get_config(self):
  function _find_maxima (line 159) | def _find_maxima(x):
  function find_maxima (line 180) | def find_maxima(x, data_format):
  class Maxima2D (line 202) | class Maxima2D(Layer):
    method __init__ (line 232) | def __init__(self, data_format=None, **kwargs):
    method compute_output_shape (line 241) | def compute_output_shape(self, input_shape):
    method call (line 251) | def call(self, inputs):
    method get_config (line 254) | def get_config(self):
  function residual_bottleneck_module (line 260) | def residual_bottleneck_module(x_in, output_filters=32, bottleneck_facto...

FILE: leap/models.py
  function leap_cnn (line 9) | def leap_cnn(img_size, output_channels, filters=64, upsampling_layers=Fa...
  function hourglass (line 59) | def hourglass(img_size, output_channels, filters=64, upsampling_layers=F...
  function stacked_hourglass (line 131) | def stacked_hourglass(img_size, output_channels, filters=64, upsampling_...

FILE: leap/predict_box.py
  function tf_find_peaks (line 15) | def tf_find_peaks(x):
  function convert_to_peak_outputs (line 48) | def convert_to_peak_outputs(model, include_confmaps=False):
  function predict_box (line 62) | def predict_box(box_path, model_path, out_path, *, box_dset="/box", epoc...

FILE: leap/training.py
  function train_val_split (line 19) | def train_val_split(X, Y, val_size=0.15, shuffle=True):
  function create_run_folders (line 35) | def create_run_folders(run_name, base_path="models", clean=False):
  class LossHistory (line 65) | class LossHistory(keras.callbacks.Callback):
    method __init__ (line 66) | def __init__(self, run_path):
    method on_train_begin (line 70) | def on_train_begin(self, logs={}):
    method on_epoch_end (line 73) | def on_epoch_end(self, batch, logs={}):
  function create_model (line 85) | def create_model(net_name, img_size, output_channels, **kwargs):
  function train (line 99) | def train(data_path, *,

FILE: leap/utils.py
  function versions (line 7) | def versions(list_devices=False):
  function find_weights (line 25) | def find_weights(model_path):
  function find_best_weights (line 38) | def find_best_weights(model_path):
  function load_dataset (line 48) | def load_dataset(data_path, X_dset="box", Y_dset="confmaps", permute=(0,...
  function preprocess (line 66) | def preprocess(X, permute=(0,3,2,1)):

FILE: leap/viz.py
  function show_pred (line 6) | def show_pred(net, X, Y, joint_idx=0, alpha_pred=0.7, save_path=None, sh...
  function gallery (line 78) | def gallery(array, ncols=4):
  function show_confmap_grid (line 94) | def show_confmap_grid(net, X, Y, plot=True, save_path=None, show_figure=...
  function plot_history (line 133) | def plot_history(history, save_path=None, show_figure=False):
Condensed preview: 157 files (475K chars total), each showing path, character count, and a content snippet.
[
  {
    "path": ".gitignore",
    "chars": 1294,
    "preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
  },
  {
    "path": "LICENSE",
    "chars": 11357,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "analysis/gait_analysis/compute_gait_densities.m",
    "chars": 1286,
    "preview": "%% Pathing\nembed_path = 'Z:\\code\\2018-05-05_joints_tsne_FlyAging_talmo-labels\\results\\FlyAging-DiegoCNN_v1.0_filters=64_"
  },
  {
    "path": "analysis/gait_analysis/gait_analysis_computation.m",
    "chars": 11182,
    "preview": "clear all;\n%% Pathing\n% addpath(genpath('deps'));\njoints_dir = 'Z:\\data\\JointTracker\\2018-02_FlyAging_boxes\\expts\\preds\\"
  },
  {
    "path": "analysis/gait_analysis/gait_analysis_plotting.m",
    "chars": 4418,
    "preview": "clear all;\n%% Look at particular section\n% Pick the gait example to observe\n% load('TetrapodExample');\nload('TripodExamp"
  },
  {
    "path": "analysis/gait_analysis/plot_gait_densities.m",
    "chars": 1109,
    "preview": "clear all;\n%% Plot all three densities overlayed ontop of one another\nload('Gait_Densities');\nfigure; hold on;\nh = image"
  },
  {
    "path": "data/readme.md",
    "chars": 1142,
    "preview": "The full fly dataset and all trained networks used in our paper can be downloaded from: [`http://arks.princeton.edu/ark:"
  },
  {
    "path": "examples/batch_process_video.ipynb",
    "chars": 12990,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Example: Predict body part positi"
  },
  {
    "path": "examples/hdf5tovid.m",
    "chars": 1763,
    "preview": "% Clean start!\nclear all, clc\n\n%% Parameters\n% Path to input file\ndataPath = '../data/examples/072212_163153.clip.h5';\n\n"
  },
  {
    "path": "examples/vidtohdf5.m",
    "chars": 2042,
    "preview": "% Clean start!\nclear all, clc\n\n%% Parameters\n% Path to input file\nvideoPath = '..\\..\\leap\\data\\examples\\072212_163153.cl"
  },
  {
    "path": "install_leap.m",
    "chars": 540,
    "preview": "function install_leap()\n%INSTALL_LEAP Installs the Python package and adds MATLAB scripts to path.\n% Usage:\n%   install_"
  },
  {
    "path": "leap/__init__.py",
    "chars": 161,
    "preview": "from . import image_augmentation\nfrom . import layers\nfrom . import models\nfrom . import predict_box\nfrom . import train"
  },
  {
    "path": "leap/compute_errors.m",
    "chars": 1086,
    "preview": "function err = compute_errors(pos_pred, pos_gt)\n%COMPUTE_ERRORS Computes error rates given predicted and ground truth po"
  },
  {
    "path": "leap/confmaps2pts.m",
    "chars": 432,
    "preview": "function [pts, confvals] = confmaps2pts(C)\n%CONFMAPS2PTS Convert a set of confidence maps into a set of points.\n% Usage:"
  },
  {
    "path": "leap/generate_training_set.m",
    "chars": 7827,
    "preview": "function savePath = generate_training_set(boxPath, varargin)\n%GENERATE_TRAINING_SET Creates a dataset for training.\n% Us"
  },
  {
    "path": "leap/graph2paf.m",
    "chars": 1558,
    "preview": "function paf = graph2paf(nodes, edges, sz, channelsOnly, sigma)\n%GRAPH2PAF Converts a set of edges into part affinity fi"
  },
  {
    "path": "leap/guis/label_joints.m",
    "chars": 45151,
    "preview": "function label_joints(boxPath, skeletonPath)\n%LABEL_JOINTS GUI to click on images to yield a graph.\n% Usage:\n%   label_j"
  },
  {
    "path": "leap/hpc/python_gpu.sh",
    "chars": 210,
    "preview": "#!/bin/bash\n#SBATCH --time=4:00:00\n#SBATCH --mem=128000\n#SBATCH -N 1\n#SBATCH --cpus-per-task=4\n#SBATCH --ntasks-per-node"
  },
  {
    "path": "leap/image_augmentation.py",
    "chars": 2795,
    "preview": "import cv2\nimport numpy as np\nimport keras\nfrom keras.utils import Sequence\n\ndef transform_imgs(X, theta=(-180,180), sca"
  },
  {
    "path": "leap/layers.py",
    "chars": 11274,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"\n   Copyright 2018 Jacob M. Graving <jgraving@gmail.com>\n\n   Licensed under the Apache Licens"
  },
  {
    "path": "leap/models.py",
    "chars": 12299,
    "preview": "import keras\nfrom keras import backend as K\nfrom keras.models import Model\nfrom keras.layers import Input, Conv2D, Conv2"
  },
  {
    "path": "leap/plot_joints_single.m",
    "chars": 842,
    "preview": "function h = plot_joints_single(pts, segments, markerSize, lineWidth)\n%PLOT_JOINTS_SINGLE Plot joints on a single frame."
  },
  {
    "path": "leap/predict_box.m",
    "chars": 2675,
    "preview": "function preds = predict_box(box, modelPath, saveConfmaps)\n%PREDICT_BOX Evaluates model predictions on a stack of frames"
  },
  {
    "path": "leap/predict_box.py",
    "chars": 6978,
    "preview": "import h5py\nimport numpy as np\nimport os\nfrom time import time\nimport keras\nimport keras.models\nfrom keras.layers import"
  },
  {
    "path": "leap/pts2confmaps.m",
    "chars": 1073,
    "preview": "function confmaps = pts2confmaps(pts, sz, sigma, normalize)\n%PTS2CONFMAPS Generate confidence maps centered at specified"
  },
  {
    "path": "leap/test_leap.m",
    "chars": 854,
    "preview": "function works = test_leap()\n%TEST_LEAP Checks whether LEAP is properly installed.\n% Usage:\n%   test_leap\n%   works = te"
  },
  {
    "path": "leap/toolbox/aliases/alims.m",
    "chars": 195,
    "preview": "function varargout = alims(X)\n%ALIMS Alias for arange.\n% Usage:\n%   R = alims(X)\n%   [min_val, max_val] = alims(X)\n%\n% S"
  },
  {
    "path": "leap/toolbox/aliases/ff.m",
    "chars": 103,
    "preview": "function varargout = ff(varargin)\n    N = max(nargout,1);\n\tvarargout{1:N} = fullfile(varargin{:});\nend\n"
  },
  {
    "path": "leap/toolbox/aliases/h5file.m",
    "chars": 89,
    "preview": "function varargout = h5file(varargin)\n\tvarargout{1:nargout} = hdf5prop(varargin{:});\nend\n"
  },
  {
    "path": "leap/toolbox/aliases/imgsc.m",
    "chars": 360,
    "preview": "function h = imgsc(I, varargin)\n%IMGSC Alias for imagesc for images.\n% Usage:\n%   imgsc(I)\n%   imgsc(I, ...)\n%\n% See als"
  },
  {
    "path": "leap/toolbox/aliases/repext.m",
    "chars": 94,
    "preview": "function varargout = repext(varargin)\n\tvarargout{1:max(nargout,1)} = extrep(varargin{:});\nend\n"
  },
  {
    "path": "leap/toolbox/graphics/FEX-settingsdlg/settingsdlg.m",
    "chars": 18089,
    "preview": "function [settings, button] = settingsdlg(varargin)\n% SETTINGSDLG             Default dialog to produce a settings-struc"
  },
  {
    "path": "leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/+mixin/Container.m",
    "chars": 9443,
    "preview": "classdef Container < handle\n    %uix.mixin.Container  Container mixin\n    %\n    %  uix.mixin.Container is a mixin class "
  },
  {
    "path": "leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/+mixin/Flex.m",
    "chars": 2024,
    "preview": "classdef Flex < handle\n    %uix.mixin.Flex  Flex mixin\n    %\n    %  uix.mixin.Flex is a mixin class used by flex contain"
  },
  {
    "path": "leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/+mixin/Panel.m",
    "chars": 6663,
    "preview": "classdef Panel < uix.mixin.Container\n    %uix.mixin.Panel  Panel mixin\n    %\n    %  uix.mixin.Panel is a mixin class use"
  },
  {
    "path": "leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/Box.m",
    "chars": 1291,
    "preview": "classdef Box < uix.Container & uix.mixin.Container\n    %uix.Box  Box and grid base class\n    %\n    %  uix.Box is a base "
  },
  {
    "path": "leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/ChildEvent.m",
    "chars": 778,
    "preview": "classdef( Hidden, Sealed ) ChildEvent < event.EventData\n    %uix.ChildEvent  Event data for child event\n    %\n    %  e ="
  },
  {
    "path": "leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/ChildObserver.m",
    "chars": 8831,
    "preview": "classdef ( Hidden, Sealed ) ChildObserver < handle\n    %uix.ChildObserver  Child observer\n    %\n    %  co = uix.ChildObs"
  },
  {
    "path": "leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/Container.m",
    "chars": 334,
    "preview": "classdef Container < matlab.ui.container.internal.UIContainer\n    %uix.Container  Container base class\n    %\n    %  uix."
  },
  {
    "path": "leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/Divider.m",
    "chars": 10925,
    "preview": "classdef Divider < matlab.mixin.SetGet\n    %uix.Divider  Draggable divider\n    %\n    %  d = uix.Divider() creates a divi"
  },
  {
    "path": "leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/FigureData.m",
    "chars": 781,
    "preview": "classdef ( Hidden, Sealed ) FigureData < event.EventData\n    %uix.FigureData  Event data for FigureChanged on uix.Figure"
  },
  {
    "path": "leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/FigureObserver.m",
    "chars": 3091,
    "preview": "classdef ( Hidden, Sealed ) FigureObserver < handle\n    %uix.FigureObserver  Figure observer\n    %\n    %  A figure obser"
  },
  {
    "path": "leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/Grid.m",
    "chars": 11948,
    "preview": "classdef Grid < uix.Box\n    %uix.Grid  Grid\n    %\n    %  b = uix.Grid(p1,v1,p2,v2,...) constructs a grid and sets parame"
  },
  {
    "path": "leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/GridFlex.m",
    "chars": 19877,
    "preview": "classdef GridFlex < uix.Grid & uix.mixin.Flex\n    %uix.GridFlex  Flexible grid\n    %\n    %  b = uix.GridFlex(p1,v1,p2,v2"
  },
  {
    "path": "leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/HBox.m",
    "chars": 6649,
    "preview": "classdef HBox < uix.Box\n    %uix.HBox  Horizontal box\n    %\n    %  b = uix.HBox(p1,v1,p2,v2,...) constructs a horizontal"
  },
  {
    "path": "leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/Node.m",
    "chars": 2881,
    "preview": "classdef ( Hidden ) Node < dynamicprops\n    %uix.Node  Node\n    %\n    %  n = uix.Node(o) creates a node for the handle o"
  },
  {
    "path": "leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/Panel.m",
    "chars": 2095,
    "preview": "classdef Panel < matlab.ui.container.Panel & uix.mixin.Panel\n    %uix.Panel  Standard panel\n    %\n    %  b = uix.Panel(p"
  },
  {
    "path": "leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/PointerManager.m",
    "chars": 4743,
    "preview": "classdef ( Hidden, Sealed ) PointerManager < handle\n    %uix.PointerManager  Pointer manager\n    \n    %  Copyright 2016 "
  },
  {
    "path": "leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/SelectionData.m",
    "chars": 954,
    "preview": "classdef( Hidden, Sealed ) SelectionData < event.EventData\n    %uix.SelectionData  Event data for selection event\n    %\n"
  },
  {
    "path": "leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/VBox.m",
    "chars": 6681,
    "preview": "classdef VBox < uix.Box\n    %uix.VBox  Vertical box\n    %\n    %  b = uix.VBox(p1,v1,p2,v2,...) constructs a vertical box"
  },
  {
    "path": "leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/calcPixelSizes.m",
    "chars": 1574,
    "preview": "function pSizes = calcPixelSizes( pTotal, mSizes, pMinima, pPadding, pSpacing )\n%calcPixelSizes  Calculate child sizes i"
  },
  {
    "path": "leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/setPosition.m",
    "chars": 877,
    "preview": "function setPosition( o, p, u )\n%setPosition  Set position of graphics object\n%\n%  setPosition(o,p,u) sets the position "
  },
  {
    "path": "leap/toolbox/graphics/GUI Layout Toolbox/layout/+uix/tracking.m",
    "chars": 5760,
    "preview": "function varargout = tracking( varargin )\n%tracking  Track anonymized usage data\n%\n%  tracking(p,v,id) tracks usage to t"
  },
  {
    "path": "leap/toolbox/graphics/distributionPlot/colorCode2rgb.m",
    "chars": 939,
    "preview": "function rgbVec = colorCode2rgb(c)\n%COLORCODE2RGB converts a color code to an rgb vector\n%\n\n% SYNOPSIS rgbVec = colorCod"
  },
  {
    "path": "leap/toolbox/graphics/draggable/draggable.m",
    "chars": 21955,
    "preview": "function draggable(h,varargin)\n% DRAGGABLE - Make it so that a graphics object can be dragged in a figure.\n%   This func"
  },
  {
    "path": "leap/toolbox/graphics/figclosekey.m",
    "chars": 624,
    "preview": "function figclosekey(h, key)\n%FIGCLOSEKEY Add a hotkey for closing the figure.\n% Usage:\n%   figclosekey(h, key)\n% \n% Arg"
  },
  {
    "path": "leap/toolbox/graphics/figsize.m",
    "chars": 1679,
    "preview": "function sz = figsize(h, width, height)\n%FIGSIZE Resizes the specified figure while keeping it on screen.\n% Usage:\n%   f"
  },
  {
    "path": "leap/toolbox/graphics/fontsize.m",
    "chars": 285,
    "preview": "function fontsize(size, h)\n%FONTSIZE Sets the fontsize across a graphics object.\n% Usage:\n%   fontsize(size) % default: "
  },
  {
    "path": "leap/toolbox/graphics/hline.m",
    "chars": 807,
    "preview": "function h = hline(y, varargin)\n%HLINE Easy plotting of a horizonal line.\n% Usage:\n%   hline\n%   hline(y)\n%   hline(ax, "
  },
  {
    "path": "leap/toolbox/graphics/isax.m",
    "chars": 187,
    "preview": "function TF = isax(h)\n%ISAX Check if input is an axes handle.\n% Usage:\n%   TF = isax(h)\n% \n% Args:\n%   h: \n% \n% See also"
  },
  {
    "path": "leap/toolbox/graphics/isfig.m",
    "chars": 320,
    "preview": "function TF = isfig(h)\n%ISFIG Checks whether the handle(s) specified are existing figures.\n% Usage:\n%   TF = isfig(h)\n%\n"
  },
  {
    "path": "leap/toolbox/graphics/noticks.m",
    "chars": 852,
    "preview": "function noticks(ax, whichAxes)\n%NOTICKS Hides ticks in the axes.\n% Usage:\n%   noticks\n%   noticks('x') % hides only x-t"
  },
  {
    "path": "leap/toolbox/graphics/pareto2.m",
    "chars": 598,
    "preview": "function [ax, b, p] = pareto2(X, leftYLabel, rightYLabel)\n%PARETO2 Prettier pareto function.\n% Usage:\n%   pareto2(X)\n%  "
  },
  {
    "path": "leap/toolbox/graphics/plotExplainedVar.m",
    "chars": 677,
    "preview": "function [ax, b, p] = plotExplainedVar(explained)\n%PLOTEXPLAINEDVAR Pretty plot PCA explained variance.\n% Usage:\n%   plo"
  },
  {
    "path": "leap/toolbox/graphics/plotpts.m",
    "chars": 482,
    "preview": "function h = plotpts(pts, varargin)\n%PLOTPTS Convenience wrapper for plotting scatters. Same syntax as plot().\n% Usage:\n"
  },
  {
    "path": "leap/toolbox/graphics/redblue.m",
    "chars": 1044,
    "preview": "function map = redblue(N, dark)\n%REDBLUE Red and blue colormap going from blue to white to red.\n% Usage:\n%   map = redbl"
  },
  {
    "path": "leap/toolbox/graphics/sc/gray.m",
    "chars": 1016,
    "preview": "%GRAY  Black-white colormap\n%\n% Examples:\n%   map = gray;\n%   map = gray(len);\n%   B = gray(A);\n%   B = gray(A, lims);\n%"
  },
  {
    "path": "leap/toolbox/graphics/sc/private/colormap_helper.m",
    "chars": 1693,
    "preview": "function map = colormap_helper(map, len, lims)\n%COLORMAP_HELPER  Helper function for colormaps\n%\n% Examples:\n%   map = c"
  },
  {
    "path": "leap/toolbox/graphics/sc/private/rescale.m",
    "chars": 1632,
    "preview": "function [B, lims] = rescale(A, lims, out_lims)\n%RESCALE  Linearly rescale values in an array\n%\n% Examples:\n%   B = resc"
  },
  {
    "path": "leap/toolbox/graphics/sc/real2rgb.m",
    "chars": 4936,
    "preview": "%REAL2RGB  Converts a real-valued matrix into a truecolor image\n%\n% Examples:\n%   B = real2rgb(A, cmap);\n%   B = real2rg"
  },
  {
    "path": "leap/toolbox/graphics/shortticks.m",
    "chars": 226,
    "preview": "function shortticks(ax)\n%SHORTTICKS Make the axis ticks short.\n% Usage:\n%   shortticks\n%   shortticks(ax)\n% \n% See also:"
  },
  {
    "path": "leap/toolbox/hdf5/h5att2struct.m",
    "chars": 806,
    "preview": "function S = h5att2struct(filename, location)\n%H5ATT2STRUCT Reads a set of HDF5 attributes into a named structure.\n% Usa"
  },
  {
    "path": "leap/toolbox/hdf5/h5getdatasets.m",
    "chars": 1019,
    "preview": "function datasets = h5getdatasets(filepath, grp, recurse)\n%H5GETDATASETS Returns a list of all datasets in an HDF5 file."
  },
  {
    "path": "leap/toolbox/hdf5/h5readframes.m",
    "chars": 1954,
    "preview": "function [frames, numFrames] = h5readframes(filepath, dataset, idx)\n%H5READFRAMES Reads video frames from an HDF5 file.\n"
  },
  {
    "path": "leap/toolbox/hdf5/h5readgroup.m",
    "chars": 1248,
    "preview": "function data = h5readgroup(filepath, group)\n%H5READGROUP Reads a group into a structure containing all the data in the "
  },
  {
    "path": "leap/toolbox/hdf5/h5save.m",
    "chars": 2428,
    "preview": "function h5save(filepath, X, dset, varargin)\n%H5SAVE Create and save a variable to an HDF5 file.\n% Usage:\n%   h5save(fil"
  },
  {
    "path": "leap/toolbox/hdf5/h5savegroup.m",
    "chars": 1067,
    "preview": "function h5savegroup(filepath, S, grp, varargin)\n%H5SAVEGROUP Saves a struct as a HDF5 group.\n% Usage:\n%   h5savegroup(f"
  },
  {
    "path": "leap/toolbox/hdf5/h5size.m",
    "chars": 561,
    "preview": "function [sz, maxSize] = h5size(filepath, dataset, dim)\n%H5SIZE Returns the size of the specified dataset.\n% Usage:\n%   "
  },
  {
    "path": "leap/toolbox/hdf5/h5struct2att.m",
    "chars": 726,
    "preview": "function h5struct2att(filepath, location, S)\n%H5STRUCT2ATT Writes attributes to an HDF5 file from a scalar structure.\n% "
  },
  {
    "path": "leap/toolbox/hdf5/hdf5prop/h5datacreate.m",
    "chars": 7610,
    "preview": "function datasetId = h5datacreate(h5file,varname,varargin)\n%H5datacreate  Create HDF5 dataset.\n%\n%   H5DATACREATE(HFILE,"
  },
  {
    "path": "leap/toolbox/hdf5/hdf5prop/hdf5prop.m",
    "chars": 15122,
    "preview": "classdef hdf5prop < handle\n% HDF5PROP Class for transparent file data access.\n% Matlab class to create and access HDF5 d"
  },
  {
    "path": "leap/toolbox/imageproc/ind2im.m",
    "chars": 1307,
    "preview": "function I = ind2im(ind, sz, vals, fillval)\n%IND2IM Create image from a set of linear indices.\n% Usage:\n%   I = ind2im(i"
  },
  {
    "path": "leap/toolbox/inputParsing/get_caller_name.m",
    "chars": 1045,
    "preview": "function caller = get_caller_name(varargin)\n%GET_CALLER_NAME Returns the name of the caller function.\n% Usage:\n%   calle"
  },
  {
    "path": "leap/toolbox/inputParsing/nameval2struct.m",
    "chars": 709,
    "preview": "function S = nameval2struct(C)\n%NAMEVAL2STRUCT Converts a cell array of name-value pairs to a struct.\n% Usage:\n%   struc"
  },
  {
    "path": "leap/toolbox/inputParsing/parse_params.m",
    "chars": 1890,
    "preview": "function [results, unmatched] = parse_params(args, defaults, varargin)\n%PARSE_PARAMS Parses a set of name-value pairs.\n%"
  },
  {
    "path": "leap/toolbox/inputParsing/struct2nameval.m",
    "chars": 392,
    "preview": "function C = struct2nameval(S)\n%STRUCT2NAMEVAL Converts a structure to a cell array of name-value pairs.\n% Usage:\n%   na"
  },
  {
    "path": "leap/toolbox/io/GetFullPath.m",
    "chars": 11855,
    "preview": "function File = GetFullPath(File, Style)\n% GetFullPath - Get absolute canonical path of a file or folder\n% Absolute path"
  },
  {
    "path": "leap/toolbox/io/dir_ext.m",
    "chars": 866,
    "preview": "function matches = dir_ext(path, extensions, return_paths)\n%DIR_EXT Returns files in a directory with the matching exten"
  },
  {
    "path": "leap/toolbox/io/dir_paths.m",
    "chars": 1318,
    "preview": "function [paths, base_path] = dir_paths(path, type)\n%DIR_PATHS Returns the full paths of a directory listing.\n% Usage:\n%"
  },
  {
    "path": "leap/toolbox/io/dir_regex.m",
    "chars": 678,
    "preview": "function matches = dir_regex(path, expression, return_paths)\n%DIR_REGEX Returns contents in the path matching a regular "
  },
  {
    "path": "leap/toolbox/io/exists.m",
    "chars": 1497,
    "preview": "function TF = exists(path, dir_only)\n%EXISTS Returns true if the specified path exists in the filesystem.\n% Usage:\n%   T"
  },
  {
    "path": "leap/toolbox/io/ext2filter_spec.m",
    "chars": 563,
    "preview": "function filter_spec = ext2filter_spec(exts)\n%EXT2FILTER_SPEC Generates a filter specification from a list of file exten"
  },
  {
    "path": "leap/toolbox/io/extrep.m",
    "chars": 469,
    "preview": "function new_path = extrep(filepath, new_ext)\n%EXTREP Replace the extension of a file path.\n% Usage:\n%   newpath = extre"
  },
  {
    "path": "leap/toolbox/io/funpath.m",
    "chars": 402,
    "preview": "function path = funpath(~)\n%FUNPATH Returns the path to the calling function.\n% Usage:\n%   path = funpath()\n%   path = f"
  },
  {
    "path": "leap/toolbox/io/get_ext.m",
    "chars": 637,
    "preview": "function ext = get_ext(path, no_dot)\n%GET_EXT Returns the extension in a path. The path need not exist.\n% Usage:\n%   ext"
  },
  {
    "path": "leap/toolbox/io/get_filename.m",
    "chars": 576,
    "preview": "function filename = get_filename(path, no_ext)\n%GET_FILENAME Returns the filename in a path. The path need not exist.\n% "
  },
  {
    "path": "leap/toolbox/io/get_filesize.m",
    "chars": 387,
    "preview": "function bytes = get_filesize(file_path)\n%GET_FILESIZE Returns the size of the specified file in bytes.\n% This is a wrap"
  },
  {
    "path": "leap/toolbox/io/get_new_filename.m",
    "chars": 701,
    "preview": "function new_filename = get_new_filename(filename, noSpaces)\n%GET_NEW_FILENAME Returns a new filename based on the one s"
  },
  {
    "path": "leap/toolbox/io/lastdir.m",
    "chars": 1764,
    "preview": "function dir_path = lastdir(new_path)\n%LASTDIR Remembers the last directory used for use in UI browsing dialogs.\n% Usage"
  },
  {
    "path": "leap/toolbox/io/mkdirto.m",
    "chars": 414,
    "preview": "function TF = mkdirto(path)\n%MKDIRTO Quietly makes all directories to path that do not currently exist.\n% Usage:\n%   mkd"
  },
  {
    "path": "leap/toolbox/io/uibrowse.m",
    "chars": 2503,
    "preview": "function [path, filter_idx]  = uibrowse(filter_spec, start_path, dialog_title, type)\n%UIBROWSE Displays a file or folder"
  },
  {
    "path": "leap/toolbox/ml/ezpca.m",
    "chars": 972,
    "preview": "function pcs = ezpca(X, varargin)\n%EZPCA PCA -- quick and easy!\n% Usage:\n%   pcs = ezpca(X)\n% \n% Args:\n%   X: N x D data"
  },
  {
    "path": "leap/toolbox/strings/bytes2str.m",
    "chars": 659,
    "preview": "function [str, x_bytes, unit] = bytes2str(bytes, precision)\n%BYTES2STR Returns the number of bytes in a more readable fo"
  },
  {
    "path": "leap/toolbox/strings/instr.m",
    "chars": 2975,
    "preview": "function TF = instr(needle, haystack, flags)\n%INSTR Returns true if (any) needle is in (any) haystack.\n% Usage:\n%   TF ="
  },
  {
    "path": "leap/toolbox/strings/printf.m",
    "chars": 1800,
    "preview": "function formatted = printf(str, varargin)\n%PRINTF Prints formatted output.\n% Usage:\n%   printf(str, ...)\n%   formatted "
  },
  {
    "path": "leap/toolbox/strings/secs2hms.m",
    "chars": 293,
    "preview": "function [h, m, s] = secs2hms(numSecs)\n%SECS2HMS Converts a number of seconds to hours, minutes and fractional seconds.\n"
  },
  {
    "path": "leap/toolbox/strings/secsf.m",
    "chars": 958,
    "preview": "function str = secsf(format, numSecs)\n%SECSF Yet another seconds formatting function.\n% Usage:\n%   str = secsf(format, n"
  },
  {
    "path": "leap/toolbox/utilities/af.m",
    "chars": 256,
    "preview": "function out = af(func,varargin)\n%AF Convenience wrapper for arrayfun with non-uniform output.\n% Usage:\n%   out = af(fun"
  },
  {
    "path": "leap/toolbox/utilities/arange.m",
    "chars": 296,
    "preview": "function [min_val, max_val] = arange(X)\n%ARANGE Returns the range (min and max) of an entire array.\n% Usage:\n%   R = ara"
  },
  {
    "path": "leap/toolbox/utilities/areempty.m",
    "chars": 247,
    "preview": "function empties = areempty(cellarr)\n%AREEMPTY Returns a logical array of the size of the cell array indicating whether "
  },
  {
    "path": "leap/toolbox/utilities/argmin.m",
    "chars": 255,
    "preview": "function idx = argmin(X, dim)\n%ARGMAX Returns the index at which the min is found.\n% Usage:\n%   idx = argmin(X)\n%   idx "
  },
  {
    "path": "leap/toolbox/utilities/arr2cell.m",
    "chars": 384,
    "preview": "function C = arr2cell(X, dim)\n%ARR2CELL Splits an array into a cell across the specified dimension.\n% Usage:\n%   C = arr"
  },
  {
    "path": "leap/toolbox/utilities/cell1.m",
    "chars": 569,
    "preview": "function varargout = cell1(sz, dim)\n%CELL1 Creates a 1d empty cell array. Convenience for cell(sz, 1).\n% Usage:\n%   C = "
  },
  {
    "path": "leap/toolbox/utilities/cellcat.m",
    "chars": 865,
    "preview": "function [X,idx] = cellcat(C, dim)\n%CELLCAT Unpacks and concatenates a cell array. Shorthand for: cat(dim, C{:})\n% Usage"
  },
  {
    "path": "leap/toolbox/utilities/cf.m",
    "chars": 253,
    "preview": "function out = cf(func,varargin)\n%CF Convenience wrapper for cellfun with non-uniform output.\n% Usage:\n%   out = cf(func"
  },
  {
    "path": "leap/toolbox/utilities/clip.m",
    "chars": 415,
    "preview": "function X2 = clip(X, bounds, varargin)\n%CLIP Clip values in an array to lower and upper bounds.\n% Usage:\n%   X2 = clip("
  },
  {
    "path": "leap/toolbox/utilities/functional_programming/wrap.m",
    "chars": 874,
    "preview": "function out = wrap(f, out_indices)\n\n% out = wrap(f, out_indices)\n% \n% If you've ever needed multiple outputs from a fun"
  },
  {
    "path": "leap/toolbox/utilities/get_new_string.m",
    "chars": 629,
    "preview": "function new_str = get_new_string(str, strings, format)\n%GET_NEW_STRING Returns a new string that does not exist in a pr"
  },
  {
    "path": "leap/toolbox/utilities/grp2cell.m",
    "chars": 1206,
    "preview": "function [C, grps, G] = grp2cell(X, G, dim)\n%GRP2CELL Splits an array into a cell using a grouping variable.\n% Usage:\n% "
  },
  {
    "path": "leap/toolbox/utilities/horz.m",
    "chars": 138,
    "preview": "function V = horz(X)\n%HORZ Returns the array as a horizontal vector.\n% Usage:\n%   V = horz(X) % size(V) = [1, numel(X)]\n"
  },
  {
    "path": "leap/toolbox/utilities/iseven.m",
    "chars": 182,
    "preview": "function TF = iseven(X)\n%ISEVEN Returns true if the input is divisible by 2, false otherwise.\n% Usage:\n%   TF = iseven(X"
  },
  {
    "path": "leap/toolbox/utilities/loadvar.m",
    "chars": 1109,
    "preview": "function varargout = loadvar(mat_file, var_name, varargin)\n%LOADVAR Loads one or more variables from a MAT file.\n% Usage"
  },
  {
    "path": "leap/toolbox/utilities/nunique.m",
    "chars": 161,
    "preview": "function N = nunique(X)\n%NUNIQUE Returns the number of unique elements in an array.\n% Usage:\n%   N = nunique(X)\n%\n% See "
  },
  {
    "path": "leap/toolbox/utilities/rownorm.m",
    "chars": 336,
    "preview": "function n = rownorm(X, p)\n%ROWNORM Returns the row-wise norm of a matrix X.\n% Usage:\n%   n = rownorm(X)\n%   n = rownorm"
  },
  {
    "path": "leap/toolbox/utilities/stacks/imtile.m",
    "chars": 936,
    "preview": "function T = imtile(I, varargin)\n%IMTILE Tiles a stack or set of images.\n% Usage:\n%   T = imtile(stack)\n%   T = imtile(I"
  },
  {
    "path": "leap/toolbox/utilities/stacks/stack2cell.m",
    "chars": 244,
    "preview": "function C = stack2cell(S)\n%STACK2CELL Returns a cell array with one slice of the stack in each cell.\n%   C = stack2cell"
  },
  {
    "path": "leap/toolbox/utilities/stacks/stack2vecs.m",
    "chars": 335,
    "preview": "function X = stack2vecs(S)\n%STACK2VECS Convert stack to observation-by-features matrix.\n% Usage:\n%   X = stack2vecs(S)\n%"
  },
  {
    "path": "leap/toolbox/utilities/stacks/vecs2stack.m",
    "chars": 384,
    "preview": "function S = vecs2stack(X,sz)\n%VECS2STACK Convert observation-by-features matrix to stack.\n% Usage:\n%   X = stack2vecs(S"
  },
  {
    "path": "leap/toolbox/utilities/swap.m",
    "chars": 121,
    "preview": "function [B, A] = swap(A, B)\n%SWAP Swap two variables without an intermediate copy.\n% Usage:\n%   [B, A] = swap(A, B)\n\nen"
  },
  {
    "path": "leap/toolbox/utilities/time/GetSystemTimePreciseAsFileTime.m",
    "chars": 557,
    "preview": "%GETSYSTEMTIMEPRECISEASFILETIME Returns system time with high precision.\n% Usage:\n%   t = GetSystemTimePreciseAsFileTime"
  },
  {
    "path": "leap/toolbox/utilities/time/stic.m",
    "chars": 273,
    "preview": "function timer_id = stic\n%STIC TIC equivalent using precise system time.\n% Usage:\n%   stic\n%   timer_id = tic\n%\n% See al"
  },
  {
    "path": "leap/toolbox/utilities/time/stoc.m",
    "chars": 709,
    "preview": "function [dt, timer_id] = stoc(timer_id)\n%STIC TOC equivalent using precise system time.\n% Usage:\n%   stoc\n%   stoc(time"
  },
  {
    "path": "leap/toolbox/utilities/time/stocf.m",
    "chars": 1044,
    "preview": "function [dt, timer_id] = stocf(timer_id, str, varargin)\n%STOCF Report elapsed time with print formatting.\n% Usage:\n%   "
  },
  {
    "path": "leap/toolbox/utilities/time/systime.m",
    "chars": 488,
    "preview": "function secs = systime\n%SYSTIME Returns precise system time in seconds.\n% Usage:\n%   secs = systime\n%\n% Note: This is j"
  },
  {
    "path": "leap/toolbox/utilities/varsize.m",
    "chars": 833,
    "preview": "function size = varsize(X, units)\n%VARSIZE Returns the size of a variable in bytes.\n% Usage:\n%   size = varsize(X)\n%   s"
  },
  {
    "path": "leap/toolbox/utilities/varstruct.m",
    "chars": 2089,
    "preview": "function S = varstruct(var1, var2, varargin)\n%VARSTRUCT Creates a structure out of a set of variables. Fieldnames are in"
  },
  {
    "path": "leap/toolbox/utilities/vert.m",
    "chars": 141,
    "preview": "function V = vert(X)\n%VERT Returns the matrix as a vertical vector.\n% Usage:\n%   V = vert(X)\n%\n% See also: mat2vec, resh"
  },
  {
    "path": "leap/toolbox/utilities/vplay.m",
    "chars": 328,
    "preview": "function vp = vplay(mov, varargin)\n%VPLAY Play a movie using the vplayer class.\n% Usage:\n%   vplay(mov)\n%   vplay(mov, ."
  },
  {
    "path": "leap/toolbox/video/validate_stack.m",
    "chars": 838,
    "preview": "function [images, numFrames] = validate_stack(images, no_error)\n%VALIDATE_STACK Validates a stack of images and converts"
  },
  {
    "path": "leap/toolbox/video/vplayer.m",
    "chars": 25560,
    "preview": "classdef vplayer < matlab.mixin.SetGet\n    %VPLAYER Image video/stack player class.\n    % Usage:\n    %   vplayer(S)\n    "
  },
  {
    "path": "leap/training.py",
    "chars": 11480,
    "preview": "import numpy as np\nimport h5py\nimport os\nfrom time import time\nfrom scipy.io import loadmat, savemat\nimport re\nimport sh"
  },
  {
    "path": "leap/utils.py",
    "chars": 2425,
    "preview": "import os\nimport numpy as np\nimport re\nfrom time import time\nimport h5py\n\ndef versions(list_devices=False):\n    \"\"\" Prin"
  },
  {
    "path": "leap/viz.py",
    "chars": 4553,
    "preview": "import numpy as np\nimport matplotlib.pyplot as plt\nplt.switch_backend('agg')\n\n\ndef show_pred(net, X, Y, joint_idx=0, alp"
  },
  {
    "path": "readme.md",
    "chars": 6130,
    "preview": "# SLEAP: Social LEAP Estimates Animal Poses\n![Social LEAP Estimates Animal Poses](https://raw.githubusercontent.com/talm"
  },
  {
    "path": "setup.py",
    "chars": 319,
    "preview": "from setuptools import setup\n\nsetup(\n    name=\"leap\",\n    version=\"0.0.1\",\n    author=\"Talmo Pereira\",\n    author_email="
  },
  {
    "path": "uninstall_leap.m",
    "chars": 517,
    "preview": "function uninstall_leap()\n%UNINSTALL_LEAP Removes LEAP code from the MATLAB path and Python environment.\n% Usage:\n%   un"
  }
]

// ... and 11 more files (not shown)