Full Code of bioFAM/MOFA2 for AI

master 207dcb69c267 cached
143 files
721.6 KB
220.9k tokens
1 symbols
1 requests
Download .txt
Showing preview only (763K chars total). Download the full file or copy to clipboard to get everything.
Repository: bioFAM/MOFA2
Branch: master
Commit: 207dcb69c267
Files: 143
Total size: 721.6 KB

Directory structure:
gitextract_72sduoul/

├── .Rbuildignore
├── .gitattributes
├── .gitignore
├── .gitmodules
├── DESCRIPTION
├── Dockerfile
├── LICENSE
├── NAMESPACE
├── R/
│   ├── AllClasses.R
│   ├── AllGenerics.R
│   ├── QC.R
│   ├── basilisk.R
│   ├── calculate_variance_explained.R
│   ├── cluster_samples.R
│   ├── compare_models.R
│   ├── contribution_scores.R
│   ├── correlate_covariates.R
│   ├── create_mofa.R
│   ├── dimensionality_reduction.R
│   ├── enrichment.R
│   ├── get_methods.R
│   ├── imports.R
│   ├── impute.R
│   ├── load_model.R
│   ├── make_example_data.R
│   ├── mefisto.R
│   ├── plot_data.R
│   ├── plot_factors.R
│   ├── plot_weights.R
│   ├── predict.R
│   ├── prepare_mofa.R
│   ├── run_mofa.R
│   ├── set_methods.R
│   ├── subset.R
│   └── utils.R
├── README.md
├── configure
├── configure.win
├── inst/
│   ├── CITATION
│   ├── extdata/
│   │   └── test_data.RData
│   └── scripts/
│       ├── template_script.R
│       ├── template_script.py
│       ├── template_script_dataframe.py
│       └── template_script_matrix.py
├── man/
│   ├── .Rapp.history
│   ├── MOFA.Rd
│   ├── add_mofa_factors_to_seurat.Rd
│   ├── calculate_contribution_scores.Rd
│   ├── calculate_variance_explained.Rd
│   ├── calculate_variance_explained_per_sample.Rd
│   ├── cluster_samples.Rd
│   ├── compare_elbo.Rd
│   ├── compare_factors.Rd
│   ├── correlate_factors_with_covariates.Rd
│   ├── covariates_names.Rd
│   ├── create_mofa.Rd
│   ├── create_mofa_from_MultiAssayExperiment.Rd
│   ├── create_mofa_from_Seurat.Rd
│   ├── create_mofa_from_SingleCellExperiment.Rd
│   ├── create_mofa_from_df.Rd
│   ├── create_mofa_from_matrix.Rd
│   ├── factors_names.Rd
│   ├── features_metadata.Rd
│   ├── features_names.Rd
│   ├── get_covariates.Rd
│   ├── get_data.Rd
│   ├── get_default_data_options.Rd
│   ├── get_default_mefisto_options.Rd
│   ├── get_default_model_options.Rd
│   ├── get_default_stochastic_options.Rd
│   ├── get_default_training_options.Rd
│   ├── get_dimensions.Rd
│   ├── get_elbo.Rd
│   ├── get_expectations.Rd
│   ├── get_factors.Rd
│   ├── get_group_kernel.Rd
│   ├── get_imputed_data.Rd
│   ├── get_interpolated_factors.Rd
│   ├── get_lengthscales.Rd
│   ├── get_scales.Rd
│   ├── get_variance_explained.Rd
│   ├── get_weights.Rd
│   ├── groups_names.Rd
│   ├── impute.Rd
│   ├── interpolate_factors.Rd
│   ├── load_model.Rd
│   ├── make_example_data.Rd
│   ├── pipe.Rd
│   ├── plot_alignment.Rd
│   ├── plot_ascii_data.Rd
│   ├── plot_data_heatmap.Rd
│   ├── plot_data_overview.Rd
│   ├── plot_data_scatter.Rd
│   ├── plot_data_vs_cov.Rd
│   ├── plot_dimred.Rd
│   ├── plot_enrichment.Rd
│   ├── plot_enrichment_detailed.Rd
│   ├── plot_enrichment_heatmap.Rd
│   ├── plot_factor.Rd
│   ├── plot_factor_cor.Rd
│   ├── plot_factors.Rd
│   ├── plot_factors_vs_cov.Rd
│   ├── plot_group_kernel.Rd
│   ├── plot_interpolation_vs_covariate.Rd
│   ├── plot_sharedness.Rd
│   ├── plot_smoothness.Rd
│   ├── plot_top_weights.Rd
│   ├── plot_variance_explained.Rd
│   ├── plot_variance_explained_by_covariates.Rd
│   ├── plot_variance_explained_per_feature.Rd
│   ├── plot_weights.Rd
│   ├── plot_weights_heatmap.Rd
│   ├── plot_weights_scatter.Rd
│   ├── predict.Rd
│   ├── prepare_mofa.Rd
│   ├── run_enrichment.Rd
│   ├── run_mofa.Rd
│   ├── run_tsne.Rd
│   ├── run_umap.Rd
│   ├── samples_metadata.Rd
│   ├── samples_names.Rd
│   ├── select_model.Rd
│   ├── set_covariates.Rd
│   ├── subset_factors.Rd
│   ├── subset_features.Rd
│   ├── subset_groups.Rd
│   ├── subset_samples.Rd
│   ├── subset_views.Rd
│   ├── summarise_factors.Rd
│   └── views_names.Rd
├── setup.py
├── tests/
│   ├── testthat/
│   │   ├── barcodes.tsv
│   │   ├── genes.tsv
│   │   ├── matrix.csv
│   │   ├── matrix.mtx
│   │   ├── test_create_model.R
│   │   ├── test_load_model.R
│   │   ├── test_plot.R
│   │   └── test_prepare_model.R
│   └── testthat.R
└── vignettes/
    ├── MEFISTO_temporal.Rmd
    ├── downstream_analysis.Rmd
    └── getting_started_R.Rmd

================================================
FILE CONTENTS
================================================

================================================
FILE: .Rbuildignore
================================================
^.*\.Rproj$
^\.Rproj\.user$
mofapy2
Dockerfile
setup.py


================================================
FILE: .gitattributes
================================================
*.sh text eol=lf


================================================
FILE: .gitignore
================================================
# Resilio Sync
.sync

# MAC
*Icon*
.DS_Store

# Rstudio projects
*.Rproj
.Rhistory

*_site/
# Pycharm
.idea

# HTML
# *.html

# Models outputs
*.hdf5

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.pyc
*.egg-info/
.installed.cfg
*.egg
.Rproj.user
*.Rcheck
.Rproj.user
.Rhistory
.RData
*.ipynb_checkpoints

*_cache/
*_files/

*.tar.gz

================================================
FILE: .gitmodules
================================================
[submodule "mofapy2"]
	path = mofapy2
	url = git@github.com:bioFAM/mofapy2


================================================
FILE: DESCRIPTION
================================================
Package: MOFA2
Type: Package
Title: Multi-Omics Factor Analysis v2
Version: 1.21.3
Maintainer: Ricard Argelaguet <ricard.argelaguet@gmail.com>
Authors@R: c(person("Ricard", "Argelaguet", role = c("aut", "cre"),
                     email = "ricard.argelaguet@gmail.com",
                     comment = c(ORCID = "http://orcid.org/0000-0003-3199-3722")),
              person("Damien", "Arnol", role = "aut",
                     email = "damien.arnol@gmail.com",
                     comment = c(ORCID = "http://orcid.org/0000-0003-2462-534X")),
              person("Danila", "Bredikhin", role = "aut",
                     email = "danila.bredikhin@embl.de",
                     comment = c(ORCID = "https://orcid.org/0000-0001-8089-6983")),                     
              person("Britta", "Velten", role = "aut",
              		email = "britta.velten@gmail.com",
              		comment = c(ORCID = "http://orcid.org/0000-0002-8397-3515"))
              )
Date: 2023-01-12
License: file LICENSE
Description: The MOFA2 package contains a collection of tools for training and analysing multi-omic factor analysis (MOFA) models. MOFA is a probabilistic factor model that aims to identify principal axes of variation from data sets that can comprise multiple omic layers and/or groups of samples. Additional time or space information on the samples can be incorporated using the MEFISTO framework, which is part of MOFA2. Downstream analysis functions to inspect molecular features underlying each factor, visualisation, imputation, etc. are available.
Encoding: UTF-8
Depends: R (>= 4.0)
Imports: rhdf5, dplyr, tidyr, reshape2, pheatmap, ggplot2, methods, RColorBrewer, cowplot, ggrepel, reticulate, HDF5Array, grDevices, stats, magrittr, forcats, utils, corrplot, DelayedArray, Rtsne, uwot, basilisk, stringi
Suggests: knitr, testthat, Seurat, SeuratObject, ggpubr, foreach, psych, MultiAssayExperiment, SummarizedExperiment, SingleCellExperiment, ggrastr, mvtnorm, GGally, rmarkdown, data.table, tidyverse, BiocStyle, Matrix, markdown
biocViews: DimensionReduction, Bayesian, Visualization
URL: https://biofam.github.io/MOFA2/index.html
BugReports: https://github.com/bioFAM/MOFA2
VignetteBuilder: knitr
LazyData: false
StagedInstall: no
NeedsCompilation: yes
RoxygenNote: 7.3.3
SystemRequirements: Python (>=3), numpy, pandas, h5py, scipy, argparse, sklearn, mofapy2


================================================
FILE: Dockerfile
================================================
FROM r-base:4.0.2

WORKDIR /mofa2
ADD . /mofa2

# System dependencies: Python 3 toolchain (needed for the mofapy2 backend) and
# the development headers required to compile the R package dependencies
# (curl, cairo, freetype, image formats, X11, text shaping).
RUN apt-get update && apt-get install -f && apt-get install -y python3 python3-setuptools python3-dev python3-pip
RUN apt-get install -y libcurl4-openssl-dev
RUN apt-get install -y libcairo2-dev libfreetype6-dev libpng-dev libtiff5-dev libjpeg-dev libxt-dev libharfbuzz-dev libfribidi-dev

# Install mofapy2 (the Python training backend invoked by MOFA2)
RUN python3 -m pip install 'https://github.com/bioFAM/mofapy2/tarball/master'

# Install Bioconductor/CRAN dependencies.
# NOTE: each line-continuation backslash must be the LAST character on its
# line -- a trailing space after the backslash (as previously present on the
# 'mvtnorm' line) breaks the Docker parser's line continuation.
RUN R --vanilla -e "\
  if (!requireNamespace('BiocManager', quietly = TRUE)) install.packages('BiocManager', repos = 'https://cran.r-project.org'); \
  sapply(c('rhdf5', 'dplyr', 'tidyr', 'reshape2', 'pheatmap', 'corrplot', \
           'ggplot2', 'ggbeeswarm', 'scales', 'GGally', 'doParallel', 'RColorBrewer', \
           'cowplot', 'ggrepel', 'foreach', 'reticulate', 'HDF5Array', 'DelayedArray', \
           'ggpubr', 'forcats', 'Rtsne', 'uwot', \
           'systemfonts', 'ragg', 'Cairo', 'ggrastr', 'basilisk', 'mvtnorm'), \
         BiocManager::install)"
RUN R CMD INSTALL --build .

CMD []


================================================
FILE: LICENSE
================================================
                   GNU LESSER GENERAL PUBLIC LICENSE
                       Version 3, 29 June 2007

 Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
 Everyone is permitted to copy and distribute verbatim copies
 of this license document, but changing it is not allowed.


  This version of the GNU Lesser General Public License incorporates
the terms and conditions of version 3 of the GNU General Public
License, supplemented by the additional permissions listed below.

  0. Additional Definitions.

  As used herein, "this License" refers to version 3 of the GNU Lesser
General Public License, and the "GNU GPL" refers to version 3 of the GNU
General Public License.

  "The Library" refers to a covered work governed by this License,
other than an Application or a Combined Work as defined below.

  An "Application" is any work that makes use of an interface provided
by the Library, but which is not otherwise based on the Library.
Defining a subclass of a class defined by the Library is deemed a mode
of using an interface provided by the Library.

  A "Combined Work" is a work produced by combining or linking an
Application with the Library.  The particular version of the Library
with which the Combined Work was made is also called the "Linked
Version".

  The "Minimal Corresponding Source" for a Combined Work means the
Corresponding Source for the Combined Work, excluding any source code
for portions of the Combined Work that, considered in isolation, are
based on the Application, and not on the Linked Version.

  The "Corresponding Application Code" for a Combined Work means the
object code and/or source code for the Application, including any data
and utility programs needed for reproducing the Combined Work from the
Application, but excluding the System Libraries of the Combined Work.

  1. Exception to Section 3 of the GNU GPL.

  You may convey a covered work under sections 3 and 4 of this License
without being bound by section 3 of the GNU GPL.

  2. Conveying Modified Versions.

  If you modify a copy of the Library, and, in your modifications, a
facility refers to a function or data to be supplied by an Application
that uses the facility (other than as an argument passed when the
facility is invoked), then you may convey a copy of the modified
version:

   a) under this License, provided that you make a good faith effort to
   ensure that, in the event an Application does not supply the
   function or data, the facility still operates, and performs
   whatever part of its purpose remains meaningful, or

   b) under the GNU GPL, with none of the additional permissions of
   this License applicable to that copy.

  3. Object Code Incorporating Material from Library Header Files.

  The object code form of an Application may incorporate material from
a header file that is part of the Library.  You may convey such object
code under terms of your choice, provided that, if the incorporated
material is not limited to numerical parameters, data structure
layouts and accessors, or small macros, inline functions and templates
(ten or fewer lines in length), you do both of the following:

   a) Give prominent notice with each copy of the object code that the
   Library is used in it and that the Library and its use are
   covered by this License.

   b) Accompany the object code with a copy of the GNU GPL and this license
   document.

  4. Combined Works.

  You may convey a Combined Work under terms of your choice that,
taken together, effectively do not restrict modification of the
portions of the Library contained in the Combined Work and reverse
engineering for debugging such modifications, if you also do each of
the following:

   a) Give prominent notice with each copy of the Combined Work that
   the Library is used in it and that the Library and its use are
   covered by this License.

   b) Accompany the Combined Work with a copy of the GNU GPL and this license
   document.

   c) For a Combined Work that displays copyright notices during
   execution, include the copyright notice for the Library among
   these notices, as well as a reference directing the user to the
   copies of the GNU GPL and this license document.

   d) Do one of the following:

       0) Convey the Minimal Corresponding Source under the terms of this
       License, and the Corresponding Application Code in a form
       suitable for, and under terms that permit, the user to
       recombine or relink the Application with a modified version of
       the Linked Version to produce a modified Combined Work, in the
       manner specified by section 6 of the GNU GPL for conveying
       Corresponding Source.

       1) Use a suitable shared library mechanism for linking with the
       Library.  A suitable mechanism is one that (a) uses at run time
       a copy of the Library already present on the user's computer
       system, and (b) will operate properly with a modified version
       of the Library that is interface-compatible with the Linked
       Version.

   e) Provide Installation Information, but only if you would otherwise
   be required to provide such information under section 6 of the
   GNU GPL, and only to the extent that such information is
   necessary to install and execute a modified version of the
   Combined Work produced by recombining or relinking the
   Application with a modified version of the Linked Version. (If
   you use option 4d0, the Installation Information must accompany
   the Minimal Corresponding Source and Corresponding Application
   Code. If you use option 4d1, you must provide the Installation
   Information in the manner specified by section 6 of the GNU GPL
   for conveying Corresponding Source.)

  5. Combined Libraries.

  You may place library facilities that are a work based on the
Library side by side in a single library together with other library
facilities that are not Applications and are not covered by this
License, and convey such a combined library under terms of your
choice, if you do both of the following:

   a) Accompany the combined library with a copy of the same work based
   on the Library, uncombined with any other library facilities,
   conveyed under the terms of this License.

   b) Give prominent notice with the combined library that part of it
   is a work based on the Library, and explaining where to find the
   accompanying uncombined form of the same work.

  6. Revised Versions of the GNU Lesser General Public License.

  The Free Software Foundation may publish revised and/or new versions
of the GNU Lesser General Public License from time to time. Such new
versions will be similar in spirit to the present version, but may
differ in detail to address new problems or concerns.

  Each version is given a distinguishing version number. If the
Library as you received it specifies that a certain numbered version
of the GNU Lesser General Public License "or any later version"
applies to it, you have the option of following the terms and
conditions either of that published version or of any later version
published by the Free Software Foundation. If the Library as you
received it does not specify a version number of the GNU Lesser
General Public License, you may choose any version of the GNU Lesser
General Public License ever published by the Free Software Foundation.

  If the Library as you received it specifies that a proxy can decide
whether future versions of the GNU Lesser General Public License shall
apply, that proxy's public statement of acceptance of any version is
permanent authorization for you to choose that version for the
Library.


================================================
FILE: NAMESPACE
================================================
# Generated by roxygen2: do not edit by hand

export("%>%")
export("covariates_names<-")
export("factors_names<-")
export("features_metadata<-")
export("features_names<-")
export("groups_names<-")
export("samples_metadata<-")
export("samples_names<-")
export("views_names<-")
export(add_mofa_factors_to_seurat)
export(calculate_contribution_scores)
export(calculate_variance_explained)
export(calculate_variance_explained_per_sample)
export(cluster_samples)
export(compare_elbo)
export(compare_factors)
export(correlate_factors_with_covariates)
export(covariates_names)
export(create_mofa)
export(create_mofa_from_MultiAssayExperiment)
export(create_mofa_from_Seurat)
export(create_mofa_from_SingleCellExperiment)
export(create_mofa_from_df)
export(create_mofa_from_matrix)
export(factors_names)
export(features_metadata)
export(features_names)
export(get_covariates)
export(get_data)
export(get_default_data_options)
export(get_default_mefisto_options)
export(get_default_model_options)
export(get_default_stochastic_options)
export(get_default_training_options)
export(get_dimensions)
export(get_elbo)
export(get_expectations)
export(get_factors)
export(get_group_kernel)
export(get_imputed_data)
export(get_interpolated_factors)
export(get_lengthscales)
export(get_scales)
export(get_variance_explained)
export(get_weights)
export(groups_names)
export(impute)
export(interpolate_factors)
export(load_model)
export(make_example_data)
export(plot_alignment)
export(plot_ascii_data)
export(plot_data_heatmap)
export(plot_data_overview)
export(plot_data_scatter)
export(plot_data_vs_cov)
export(plot_dimred)
export(plot_enrichment)
export(plot_enrichment_detailed)
export(plot_enrichment_heatmap)
export(plot_factor)
export(plot_factor_cor)
export(plot_factors)
export(plot_factors_vs_cov)
export(plot_group_kernel)
export(plot_interpolation_vs_covariate)
export(plot_sharedness)
export(plot_smoothness)
export(plot_top_weights)
export(plot_variance_explained)
export(plot_variance_explained_by_covariates)
export(plot_variance_explained_per_feature)
export(plot_weights)
export(plot_weights_heatmap)
export(plot_weights_scatter)
export(predict)
export(prepare_mofa)
export(run_enrichment)
export(run_mofa)
export(run_tsne)
export(run_umap)
export(samples_metadata)
export(samples_names)
export(select_model)
export(set_covariates)
export(subset_factors)
export(subset_features)
export(subset_groups)
export(subset_samples)
export(subset_views)
export(summarise_factors)
export(views_names)
exportClasses(MOFA)
exportMethods("covariates_names<-")
exportMethods("factors_names<-")
exportMethods("features_metadata<-")
exportMethods("features_names<-")
exportMethods("groups_names<-")
exportMethods("samples_metadata<-")
exportMethods("samples_names<-")
exportMethods("views_names<-")
exportMethods(covariates_names)
exportMethods(factors_names)
exportMethods(features_metadata)
exportMethods(features_names)
exportMethods(groups_names)
exportMethods(samples_metadata)
exportMethods(samples_names)
exportMethods(views_names)
import(basilisk)
import(cowplot)
import(dplyr)
import(ggplot2)
import(grDevices)
import(methods)
import(pheatmap)
import(reshape2)
import(reticulate)
import(tidyr)
importFrom(DelayedArray,DelayedArray)
importFrom(HDF5Array,HDF5ArraySeed)
importFrom(RColorBrewer,brewer.pal)
importFrom(Rtsne,Rtsne)
importFrom(basilisk,BasiliskEnvironment)
importFrom(corrplot,corrplot)
importFrom(cowplot,plot_grid)
importFrom(dplyr,bind_rows)
importFrom(dplyr,desc)
importFrom(dplyr,filter)
importFrom(dplyr,group_by)
importFrom(dplyr,left_join)
importFrom(dplyr,mutate)
importFrom(dplyr,summarise)
importFrom(dplyr,top_n)
importFrom(forcats,fct_na_value_to_level)
importFrom(ggrepel,geom_text_repel)
importFrom(grDevices,colorRampPalette)
importFrom(magrittr,"%>%")
importFrom(magrittr,set_colnames)
importFrom(pheatmap,pheatmap)
importFrom(reshape2,melt)
importFrom(rhdf5,h5ls)
importFrom(rhdf5,h5read)
importFrom(stats,as.formula)
importFrom(stats,complete.cases)
importFrom(stats,cor)
importFrom(stats,dist)
importFrom(stats,kmeans)
importFrom(stats,median)
importFrom(stats,p.adjust)
importFrom(stats,p.adjust.methods)
importFrom(stats,pnorm)
importFrom(stats,pt)
importFrom(stats,quantile)
importFrom(stats,rbinom)
importFrom(stats,rnorm)
importFrom(stats,rpois)
importFrom(stats,sd)
importFrom(stats,var)
importFrom(stats,wilcox.test)
importFrom(stringi,stri_enc_mark)
importFrom(tidyr,gather)
importFrom(tidyr,spread)
importFrom(utils,as.relistable)
importFrom(utils,head)
importFrom(utils,modifyList)
importFrom(utils,relist)
importFrom(utils,tail)
importFrom(uwot,umap)


================================================
FILE: R/AllClasses.R
================================================

##########################################################
## Define a general class to store a MOFA trained model ##
##########################################################

#' @title Class to store a mofa model
#' @description
#' The \code{MOFA} is an S4 class used to store all relevant data to analyse a MOFA model
#' @slot data The input data
#' @slot intercepts Feature intercepts
#' @slot samples_metadata Samples metadata
#' @slot features_metadata Features metadata.
#' @slot imputed_data The imputed data.
#' @slot expectations expected values of the factors and the loadings.
#' @slot dim_red non-linear dimensionality reduction manifolds.
#' @slot training_stats model training statistics.
#' @slot data_options Data processing options.
#' @slot training_options Model training options.
#' @slot stochastic_options Stochastic variational inference options.
#' @slot model_options Model options.
#' @slot mefisto_options Options for the use of MEFISTO
#' @slot dimensions Dimensionalities of the model: 
#'    M for the number of views, 
#'    G for the number of groups,
#'    N for the number of samples (per group),
#'    C for the number of covariates per sample,
#'    D for the number of features (per view),
#'    K for the number of factors.
#' @slot on_disk Logical indicating whether data is loaded from disk.
#' @slot cache Cache.
#' @slot status Auxiliary variable indicating whether the model has been trained.
#' @slot covariates optional slot to store sample covariate for training in MEFISTO
#' @slot covariates_warped optional slot to store warped sample covariate for training in MEFISTO
#' @slot interpolated_Z optional slot to store interpolated factor values (used only with MEFISTO)
#' @name MOFA
#' @rdname MOFA
#' @aliases MOFA-class
#' @exportClass MOFA

# The covariates slots may legitimately be NULL (a plain MOFA model trained
# without MEFISTO covariates), so allow either a list or NULL for them.
setClassUnion("listOrNULL",members = c("list","NULL"))
setClass("MOFA", 
        slots=c(
            data                = "list",
            covariates          = "listOrNULL",
            covariates_warped   = "listOrNULL",
            intercepts          = "list",
            imputed_data        = "list",
            interpolated_Z      = "list",
            samples_metadata    = "list",
            features_metadata   = "list",
            expectations        = "list", 
            training_stats      = "list",
            data_options        = "list",
            model_options       = "list",
            training_options    = "list",
            stochastic_options  = "list",
            mefisto_options      = "list",
            dimensions          = "list",
            on_disk             = "logical",
            dim_red             = "list",
            cache               = "list",
            status              = "character"
        )
)

# Printing method: pretty-print a one-paragraph summary of a MOFA object.
# The message varies along two axes:
#   - trained vs untrained (from the `status` slot),
#   - MOFA vs MEFISTO (a model counts as MEFISTO when the `covariates`
#     slot exists and is non-NULL; only then is dimension C reported).
setMethod("show", "MOFA", function(object) {
  
  # Guard against malformed/partially-built objects (e.g. loaded from an
  # older serialisation that lacks these slots).
  if (!.hasSlot(object, "dimensions") || length(object@dimensions) == 0)
    stop("Error: dimensions not defined")
  if (!.hasSlot(object, "status") || length(object@status) == 0)
    stop("Error: status not defined")
  
  if (object@status == "trained") {
    # K (number of factors) is only meaningful once the model is trained
    nfactors <- object@dimensions[["K"]]
    if(!.hasSlot(object, "covariates") || is.null(object@covariates)) {
      cat(sprintf("Trained MOFA with the following characteristics: \n Number of views: %d \n Views names: %s \n Number of features (per view): %s \n Number of groups: %d \n Groups names: %s \n Number of samples (per group): %s \n Number of factors: %d \n",
                  object@dimensions[["M"]], paste(views_names(object),  collapse=" "), paste(as.character(object@dimensions[["D"]]), collapse=" "),
                  object@dimensions[["G"]], paste(groups_names(object), collapse=" "), paste(as.character(object@dimensions[["N"]]), collapse=" "),
                  nfactors))
    } else {
      cat(sprintf("Trained MEFISTO with the following characteristics: \n Number of views: %d \n Views names: %s \n Number of features (per view): %s \n Number of groups: %d \n Groups names: %s \n Number of samples (per group): %s \n Number of covariates per sample: %d \n Number of factors: %d \n",
                  object@dimensions[["M"]], paste(views_names(object),  collapse=" "), paste(as.character(object@dimensions[["D"]]), collapse=" "),
                  object@dimensions[["G"]], paste(groups_names(object), collapse=" "), paste(as.character(object@dimensions[["N"]]), collapse=" "),
                  object@dimensions[["C"]], nfactors))
    }
  } else {
    # Untrained model: same summary but without the number of factors
    if(!.hasSlot(object, "covariates") || is.null(object@covariates)) {
      cat(sprintf("Untrained MOFA model with the following characteristics: \n Number of views: %d \n Views names: %s \n Number of features (per view): %s \n Number of groups: %d \n Groups names: %s \n Number of samples (per group): %s \n ",
                  object@dimensions[["M"]], paste(views_names(object),  collapse=" "), paste(as.character(object@dimensions[["D"]]), collapse=" "),
                  object@dimensions[["G"]], paste(groups_names(object), collapse=" "), paste(as.character(object@dimensions[["N"]]), collapse=" ")))
    } else {
      cat(sprintf("Untrained MEFISTO model with the following characteristics: \n Number of views: %d \n Views names: %s \n Number of features (per view): %s \n Number of groups: %d \n Groups names: %s \n Number of samples (per group): %s \n Number of covariates per sample: %d \n ",
                  object@dimensions[["M"]], paste(views_names(object),  collapse=" "), paste(as.character(object@dimensions[["D"]]), collapse=" "),
                  object@dimensions[["G"]], paste(groups_names(object), collapse=" "), paste(as.character(object@dimensions[["N"]]), collapse=" "),
                  object@dimensions[["C"]]))
    }
  }
  cat("\n")
})




================================================
FILE: R/AllGenerics.R
================================================

# Generic declarations for the MOFA accessors. Each section declares a getter
# generic and a matching replacement-function ("<-") setter generic; the
# methods themselves are implemented in get_methods.R / set_methods.R.

##################
## Factor Names ##
##################

#' @title factors_names: set and retrieve factor names
#' @name factors_names
#' @rdname factors_names
#' @export
setGeneric("factors_names", function(object) { standardGeneric("factors_names") })

#' @name factors_names
#' @rdname factors_names
#' @aliases factors_names<-
#' @export
setGeneric("factors_names<-", function(object, value) { standardGeneric("factors_names<-") })

#####################
## Covariate Names ##
#####################

#' @title covariates_names: set and retrieve covariate names
#' @name covariates_names
#' @rdname covariates_names
#' @export
setGeneric("covariates_names", function(object) { standardGeneric("covariates_names") })

#' @name covariates_names
#' @rdname covariates_names
#' @aliases covariates_names<-
#' @export
setGeneric("covariates_names<-", function(object, value) { standardGeneric("covariates_names<-") })


##################
## Sample Names ##
##################

#' @title samples_names: set and retrieve sample names
#' @name samples_names
#' @rdname samples_names
#' @export
setGeneric("samples_names", function(object) { standardGeneric("samples_names") })

#' @name samples_names
#' @rdname samples_names
#' @aliases samples_names<-
#' @export
setGeneric("samples_names<-", function(object, value) { standardGeneric("samples_names<-") })

#####################
## Sample Metadata ##
#####################

#' @title samples_metadata: set and retrieve sample metadata
#' @name samples_metadata
#' @rdname samples_metadata
#' @export
setGeneric("samples_metadata", function(object) { standardGeneric("samples_metadata") })

#' @name samples_metadata
#' @rdname samples_metadata
#' @aliases samples_metadata<-
#' @export
setGeneric("samples_metadata<-", function(object, value) { standardGeneric("samples_metadata<-") })

###################
## Feature Names ##
###################

#' @title features_names: set and retrieve feature names
#' @name features_names
#' @rdname features_names
#' @export
setGeneric("features_names", function(object) { standardGeneric("features_names") })

#' @name features_names
#' @rdname features_names
#' @aliases features_names<-
#' @export
setGeneric("features_names<-", function(object, value) { standardGeneric("features_names<-") })

######################
## Feature Metadata ##
######################

#' @title features_metadata: set and retrieve feature metadata
#' @name features_metadata
#' @rdname features_metadata
#' @export
setGeneric("features_metadata", function(object) { standardGeneric("features_metadata") })

#' @name features_metadata
#' @rdname features_metadata
#' @aliases features_metadata<-
#' @export
setGeneric("features_metadata<-", function(object, value) { standardGeneric("features_metadata<-") })

################
## View Names ##
################

#' @title views_names: set and retrieve view names
#' @name views_names
#' @rdname views_names
#' @export
setGeneric("views_names", function(object) { standardGeneric("views_names") })

#' @name views_names
#' @rdname views_names
#' @aliases views_names<-
#' @export
setGeneric("views_names<-", function(object, value) { standardGeneric("views_names<-") })

#################
## Group Names ##
#################

#' @title groups_names: set and retrieve group names
#' @name groups_names
#' @rdname groups_names
#' @export
setGeneric("groups_names", function(object) { standardGeneric("groups_names") })

#' @name groups_names
#' @rdname groups_names
#' @aliases groups_names<-
#' @export
setGeneric("groups_names<-", function(object, value) { standardGeneric("groups_names<-") })


================================================
FILE: R/QC.R
================================================
#' @importFrom stringi stri_enc_mark
# (internal) Quality control of a MOFA object.
# Validates view/group/sample/feature names, data dimensionalities, covariate
# dimensions and, for trained models, the expectations and factor diagnostics.
# Features duplicated across views are renamed with a "_<viewname>" suffix.
# Returns the (possibly feature-renamed) object.
# Fix: removed two unreachable zero-argument print() calls that followed stop()
# (dead code; print() with no arguments would itself error if ever reached).
.quality_control <- function(object, verbose = FALSE) {
  
  # Sanity checks
  if (!is(object, "MOFA")) stop("'object' has to be an instance of MOFA")
  
  # Check views names: must exist, be unique, and contain no "/" symbol
  if (verbose) message("Checking views names...")
  stopifnot(!is.null(views_names(object)))
  stopifnot(!duplicated(views_names(object)))
  if (any(grepl("/", views_names(object)))) {
    stop("Some of the views names contain `/` symbol, which is not supported.
  This can be fixed e.g. with:
    views_names(object) <- gsub(\"/\", \"-\", views_names(object))")
  }
  
  # Check groups names: must exist, be unique, and contain no "/" symbol
  if (verbose) message("Checking groups names...")
  if (any(grepl("/", groups_names(object)))) {
    stop("Some of the groups names contain `/` symbol, which is not supported.
    This can be fixed e.g. with:
    groups_names(object) <- gsub(\"/\", \"-\", groups_names(object))")
  }
  stopifnot(!is.null(groups_names(object)))
  stopifnot(!duplicated(groups_names(object)))
  
  # Check samples names: unique across all groups and ASCII-encoded
  if (verbose) message("Checking samples names...")
  stopifnot(!is.null(samples_names(object)))
  stopifnot(!duplicated(unlist(samples_names(object))))
  enc <- stringi::stri_enc_mark(unlist(samples_names(object)))
  if (any(enc!="ASCII")) {
    tmp <- unname(unlist(samples_names(object))[enc!="ASCII"])
    stop(sprintf("non-ascii characters detected in the following samples names, please rename them and run again create_mofa():\n- %s ", paste(tmp, collapse="\n- ")))
  }
  
  # Check features names: unique and ASCII-encoded
  if (verbose) message("Checking features names...")
  stopifnot(!is.null(features_names(object)))
  stopifnot(!duplicated(unlist(features_names(object))))
  enc <- stringi::stri_enc_mark(unlist(features_names(object)))
  if (any(enc!="ASCII")) {
    tmp <- unname(unlist(features_names(object))[enc!="ASCII"])
    stop(sprintf("non-ascii characters detected in the following features names, please rename them and run again create_mofa():\n- %s ", paste(tmp, collapse="\n- ")))
  }
  
  # Check dimensionalities in the input data: each view/group matrix must be
  # features (D) x samples (N) with matching dimnames lengths
  if (verbose) message("Checking dimensions...")
  N <- object@dimensions$N
  D <- object@dimensions$D
  for (i in views_names(object)) {
    for (j in groups_names(object)) {
      stopifnot(ncol(object@data[[i]][[j]]) == N[[j]])
      stopifnot(nrow(object@data[[i]][[j]]) == D[[i]])
      stopifnot(length(colnames(object@data[[i]][[j]])) == N[[j]])
      stopifnot(length(rownames(object@data[[i]][[j]])) == D[[i]])
    }
  }
  
  # Check that there are no features with complete missing values (across all groups)
  # (skipped for sparse matrices, where is.na() scanning would be expensive)
  if (object@status == "untrained" || object@data_options[["loaded"]]) {
      if (verbose) message("Checking there are no features with complete missing values...")
      for (i in views_names(object)) {
        if (!(is(object@data[[i]][[1]], "dgCMatrix") || is(object@data[[i]][[1]], "dgTMatrix"))) {
          tmp <- as.data.frame(sapply(object@data[[i]], function(x) rowMeans(is.na(x)), simplify = TRUE))
          if (any(unlist(apply(tmp, 1, function(x) mean(x==1)))==1))
            warning("You have features which do not contain a single observation in any group, consider removing them...")
        }
      }
    }
    
  # Check dimensionalities of sample_covariates (MEFISTO): C covariates x total N samples
  if (verbose) message("Checking sample covariates...")
  if(.hasSlot(object, "covariates") && !is.null(object@covariates)){
    stopifnot(ncol(object@covariates) == sum(object@dimensions$N))
    stopifnot(nrow(object@covariates) == object@dimensions$C)
    stopifnot(all(unlist(samples_names(object)) == colnames(object@covariates)))
  }
  
  # Sanity checks that are exclusive for an untrained model  
  if (object@status == "untrained") {
    
    # Check features names: identical within a view across groups, unique within a view,
    # and disambiguated across views by appending the view name as a suffix
    if (verbose) message("Checking features names...")
    tmp <- lapply(object@data, function(x) unique(lapply(x,rownames)))
    for (x in tmp) stopifnot(length(x)==1)
    for (x in tmp) if (any(duplicated(x[[1]]))) stop("There are duplicated features names within the same view. Please rename")
    all_names <- unname(unlist(tmp))
    duplicated_names <- unique(all_names[duplicated(all_names)])
    if (length(duplicated_names)>0) 
      warning("There are duplicated features names across different views. We will add the suffix *_view* only for those features 
            Example: if you have both TP53 in mRNA and mutation data it will be renamed to TP53_mRNA, TP53_mutation")
    for (i in names(object@data)) {
      for (j in names(object@data[[i]])) {
        tmp <- which(rownames(object@data[[i]][[j]]) %in% duplicated_names)
        if (length(tmp)>0) {
          rownames(object@data[[i]][[j]])[tmp] <- paste(rownames(object@data[[i]][[j]])[tmp], i, sep="_")
        }
      }
    }
    
  # Sanity checks that are exclusive for a trained model  
  } else if (object@status == "trained") {
    # Check expectations: weights (W) and factors (Z) must be present as matrices
    if (verbose) message("Checking expectations...")
    stopifnot(all(c("W", "Z") %in% names(object@expectations)))
    # if(.hasSlot(object, "covariates") && !is.null(object@covariates)) stopifnot("Sigma" %in% names(object@expectations))
    stopifnot(all(sapply(object@expectations$W, is.matrix)))
    stopifnot(all(sapply(object@expectations$Z, is.matrix)))
    
    # Check for intercept factors: a factor strongly correlated (|r|>0.75) with the
    # per-sample mean signal of a view often reflects poor normalisation
    if (object@data_options[["loaded"]]) { 
      if (verbose) message("Checking for intercept factors...")
      if (!is.null(object@data)) {
        factors <- do.call("rbind",get_factors(object))
        r <- suppressWarnings( t(do.call('rbind', lapply(object@data, function(x) 
          abs(cor(colMeans(do.call("cbind",x),na.rm=TRUE),factors, use="pairwise.complete.obs"))
        ))) )
        intercept_factors <- which(rowSums(r>0.75)>0)
        if (length(intercept_factors)) {
            warning(sprintf("Factor(s) %s are strongly correlated with the average expression of features for at least one of your omics. Such factors appear when there are differences in the total 'levels' between your samples, *sometimes* because of poor normalisation in the preprocessing steps.\n",paste(intercept_factors,collapse=", ")))
        }
      }
    }
  
    # Check for correlated factors (tiny noise is added so that constant factors
    # do not make cor() fail with zero standard deviation)
    if (verbose) message("Checking for highly correlated factors...")
    Z <- do.call("rbind",get_factors(object))
    op <- options(warn=-1) # suppress warnings
    
    noise <- matrix(rnorm(n=length(Z), mean=0, sd=1e-10), nrow(Z), ncol(Z))
    tmp <- cor(Z+noise); diag(tmp) <- NA
    options(op) # activate warnings again
    if (max(tmp,na.rm=TRUE)>0.5) {
      warning("The model contains highly correlated factors (see `plot_factor_cor(MOFAobject)`). \nWe recommend that you train the model with less factors and that you let it train for a longer time.\n")
    }
  
  }
  
  return(object)  
}


================================================
FILE: R/basilisk.R
================================================
# .mofapy2_dependencies <- c(
#     "h5py==3.1.0",
#     "pandas==1.2.1",
#     "scikit-learn==0.24.1",
#     "dtw-python==1.1.10"
# )

# Pinned conda packages for the Python environment used to run mofapy2.
# Versions are fixed for reproducibility of training results.
.mofapy2_dependencies <- c(
    "python=3.12.12",
    "numpy=1.26.4",
    "scipy=1.12.0",
    "pandas=2.2.1",
    "h5py=3.10.0",
    "scikit-learn=1.4.0",
    "dtw-python=1.3.1"
)

# mofapy2 version installed into the environment via pip (see `pip` arg below)
.mofapy2_version <- "0.7.3"

# Self-contained conda environment managed by basilisk; activated when MOFA is
# trained from R so that the pinned Python stack above is used.
#' @importFrom basilisk BasiliskEnvironment
mofa_env <- BasiliskEnvironment("mofa_env", pkgname="MOFA2", packages=.mofapy2_dependencies, pip = paste0("mofapy2==",.mofapy2_version))

================================================
FILE: R/calculate_variance_explained.R
================================================
#' @title Calculate variance explained by the model
#' @description  This function takes a trained MOFA model as input and calculates the proportion of variance explained 
#' (i.e. the coefficient of determinations (R^2)) by the MOFA factors across the different views.
#' @name calculate_variance_explained
#' @param object a \code{\link{MOFA}} object.
#' @param views character vector with the view names, or numeric vector with view indexes. Default is 'all'
#' @param groups character vector with the group names, or numeric vector with group indexes. Default is 'all'
#' @param factors character vector with the factor names, or numeric vector with the factor indexes. Default is 'all'
#' @return a list with matrices with the amount of variation explained per factor and view.
#' @importFrom utils relist as.relistable
#' @export
#' @examples
#' # Using an existing trained model on simulated data
#' file <- system.file("extdata", "model.hdf5", package = "MOFA2")
#' model <- load_model(file)
#' 
#' # Calculate variance explained (R2)
#' r2 <- calculate_variance_explained(model)
#' 
#' # Plot variance explained values (view as x-axis, and factor as y-axis)
#' plot_variance_explained(model, x="view", y="factor")
#' 
#' # Plot variance explained values (view as x-axis, and group as y-axis)
#' plot_variance_explained(model, x="view", y="group")
#' 
#' # Plot variance explained values for factors 1 to 3
#' plot_variance_explained(model, x="view", y="group", factors=1:3)
#' 
#' # Scale R2 values
#' plot_variance_explained(model, max_r2 = 0.25)
calculate_variance_explained <- function(object, views = "all", groups = "all", factors = "all") {
  
  # Sanity checks
  if (!is(object, "MOFA")) stop("'object' has to be an instance of MOFA")
  if (any(object@model_options$likelihoods!="gaussian"))
    stop("Not possible to recompute the variance explained estimates when using non-gaussian likelihoods.")
  # Fix: this check was previously nested under a duplicated non-gaussian
  # condition and therefore unreachable; it must apply to every model.
  if (isFALSE(object@data_options$loaded)) stop("Data is not loaded, cannot compute variance explained.")
  
  # Define factors, views and groups
  views  <- .check_and_get_views(object, views)
  groups <- .check_and_get_groups(object, groups)
  factors <- .check_and_get_factors(object, factors)
  K <- length(factors)
  
  # Collect relevant expectations: weights (features x factors), factors
  # (samples x factors) and the data transposed to samples x features
  W <- get_weights(object, views=views, factors=factors)
  Z <- get_factors(object, groups=groups, factors=factors)
  Y <- lapply(get_data(object, add_intercept = FALSE)[views], function(view) view[groups])
  Y <- lapply(Y, function(x) lapply(x,t))
  
  # Replace masked values on Z by 0 (so that they do not contribute to predictions)
  for (g in groups) {
    Z[[g]][is.na(Z[[g]])] <- 0
  }
  
  # Calculate coefficient of determination per group and view:
  # R2 = 1 - SS(residual)/SS(total), with predictions Z %*% t(W)
  r2_m <- tryCatch({
    lapply(groups, function(g) sapply(views, function(m) {
      a <- sum((as.matrix(Y[[m]][[g]]) - tcrossprod(Z[[g]], W[[m]]))**2, na.rm = TRUE)
      b <- sum(Y[[m]][[g]]**2, na.rm = TRUE)
      return(1 - a/b)
    })
    )}, error = function(err) {
      stop(paste0("Calculating explained variance doesn't work with the current version of DelayedArray.\n",
                  "  Do not sort factors if you're trying to load the model (sort_factors = FALSE),\n",
                  "  or load the full dataset into memory (on_disk = FALSE)."))
      return(err)
    })
  r2_m <- .name_views_and_groups(r2_m, groups, views)
  
  # Lower bound is zero
  r2_m = lapply(r2_m, function(x){
    x[x < 0] = 0
    return(x)
  })
  
  # Calculate coefficient of determination per group, factor and view
  # (each factor's contribution assessed in isolation)
  r2_mk <- lapply(groups, function(g) {
    tmp <- sapply(views, function(m) { sapply(factors, function(k) {
      a <- sum((as.matrix(Y[[m]][[g]]) - tcrossprod(Z[[g]][,k], W[[m]][,k]))**2, na.rm = TRUE)
      b <- sum(Y[[m]][[g]]**2, na.rm = TRUE)
      return(1 - a/b)
    })
    })
    tmp <- matrix(tmp, ncol = length(views), nrow = length(factors))
    colnames(tmp) <- views
    rownames(tmp) <- factors
    return(tmp)
  })
  names(r2_mk) <- groups
  
  # Lower bound is 0
  r2_mk = lapply(r2_mk, function(x){
    x[x < 0] = 0
    return(x)
  })
  
  # Transform from fraction to percentage (relist preserves the nested structure)
  r2_mk = utils::relist(unlist(utils::as.relistable(r2_mk)) * 100 ) 
  r2_m = utils::relist(unlist(utils::as.relistable(r2_m)) * 100 )
  
  # Store results
  r2_list <- list(r2_total = r2_m, r2_per_factor = r2_mk)
  
  return(r2_list)
}



#' @title Calculate variance explained by the MOFA factors for each sample
#' @description  This function takes a trained MOFA model as input and calculates, **for each sample** the proportion of variance explained 
#' (i.e. the coefficient of determinations (R^2)) by the MOFA factors across the different views.
#' @name calculate_variance_explained_per_sample
#' @param object a \code{\link{MOFA}} object.
#' @param views character vector with the view names, or numeric vector with view indexes. Default is 'all'
#' @param groups character vector with the group names, or numeric vector with group indexes. Default is 'all'
#' @param factors character vector with the factor names, or numeric vector with the factor indexes. Default is 'all'
#' @return a list with matrices with the amount of variation explained per sample and view.
#' @export
#' @examples
#' # Using an existing trained model on simulated data
#' file <- system.file("extdata", "model.hdf5", package = "MOFA2")
#' model <- load_model(file)
#' 
#' # Calculate variance explained (R2)
#' r2 <- calculate_variance_explained_per_sample(model)
#'
calculate_variance_explained_per_sample <- function(object, views = "all", groups = "all", factors = "all") {
  
  # Sanity checks
  if (!is(object, "MOFA")) stop("'object' has to be an instance of MOFA")
  if (any(object@model_options$likelihoods!="gaussian"))
    stop("Not possible to recompute the variance explained estimates when using non-gaussian likelihoods.")
  # Fix: this check was previously nested under a duplicated non-gaussian
  # condition and therefore unreachable; it must apply to every model.
  if (isFALSE(object@data_options$loaded)) stop("Data is not loaded, cannot compute variance explained.")
  
  # Define factors, views and groups
  views  <- .check_and_get_views(object, views)
  groups <- .check_and_get_groups(object, groups)
  factors <- .check_and_get_factors(object, factors)
  
  # Collect relevant expectations: weights, factor values and the data
  # transposed to samples x features
  W <- get_weights(object, views=views, factors=factors)
  Z <- get_factors(object, groups=groups, factors=factors)
  Y <- lapply(get_data(object, add_intercept = FALSE)[views], function(view) view[groups])
  Y <- lapply(Y, function(x) lapply(x,t))
  
  # Replace masked values on Z by 0 (so that they do not contribute to predictions)
  for (g in groups) { Z[[g]][is.na(Z[[g]])] <- 0 }
  
  # samples <- unlist(samples_names(object)[groups])
  samples <- samples_names(object)[groups]
  
  # Calculate coefficient of determination per sample and view:
  # rowSums give per-sample residual and total sums of squares
  r2 <- lapply(groups, function(g) {
    tmp <- sapply(views, function(m) {
      a <- rowSums((Y[[m]][[g]] - tcrossprod(Z[[g]],W[[m]]))**2, na.rm=TRUE)
      b <- rowSums(Y[[m]][[g]]**2, na.rm = TRUE)
      return(100*(1-a/b))
    })
    tmp <- matrix(tmp, ncol = length(views), nrow = length(samples[[g]]))
    tmp[tmp<0] <- 0  # lower bound is zero
    colnames(tmp) <- views
    rownames(tmp) <- samples[[g]]
    return(tmp)
  }); names(r2) <- groups
  
  return(r2)
}








#' @title Plot variance explained by the model
#' @description plots the variance explained by the MOFA factors across different views and groups, as specified by the user.
#' Consider using cowplot::plot_grid(plotlist = ...) to combine the multiple plots that this function generates.
#' @name plot_variance_explained
#' @param object a \code{\link{MOFA}} object
#' @param x character specifying the dimension for the x-axis ("view", "factor", or "group").
#' @param y character specifying the dimension for the y-axis ("view", "factor", or "group").
#' @param split_by character specifying the dimension to be faceted ("view", "factor", or "group").
#' @param factors character vector with a factor name(s), or numeric vector with the index(es) of the factor(s). Default is "all".
#' @param plot_total logical value to indicate if to plot the total variance explained (for the variable in the x-axis)
#' @param min_r2 minimum variance explained for the color scheme (default is 0).
#' @param max_r2 maximum variance explained for the color scheme.
#' @param legend logical indicating whether to add a legend to the plot  (default is TRUE).
#' @param use_cache logical indicating whether to use cache (default is TRUE)
#' @param ... extra arguments to be passed to \code{\link{calculate_variance_explained}}
#' @import ggplot2
#' @importFrom cowplot plot_grid
#' @importFrom stats as.formula
#' @importFrom reshape2 melt
#' @return A list of \code{\link{ggplot}} objects (if \code{plot_total} is TRUE) or a single \code{\link{ggplot}} object
#' @export
#' @examples
#' # Using an existing trained model on simulated data
#' file <- system.file("extdata", "model.hdf5", package = "MOFA2")
#' model <- load_model(file)
#' 
#' # Calculate variance explained (R2)
#' r2 <- calculate_variance_explained(model)
#' 
#' # Plot variance explained values (view as x-axis, and factor as y-axis)
#' plot_variance_explained(model, x="view", y="factor")
#' 
#' # Plot variance explained values (view as x-axis, and group as y-axis)
#' plot_variance_explained(model, x="view", y="group")
#' 
#' # Plot variance explained values for factors 1 to 3
#' plot_variance_explained(model, x="view", y="group", factors=1:3)
#' 
#' # Scale R2 values
#' plot_variance_explained(model, max_r2=0.25)
plot_variance_explained <- function(object, x = "view", y = "factor", split_by = NA, plot_total = FALSE, 
                                    factors = "all", min_r2 = 0, max_r2 = NULL, legend = TRUE, use_cache = TRUE, ...) {
  
  # Sanity checks: x, y and split_by must jointly cover view/group/factor
  if (length(unique(c(x, y, split_by))) != 3) { 
    stop(paste0("Please ensure x, y, and split_by arguments are different.\n",
                "  Possible values are `view`, `group`, and `factor`."))
  }
  
  # Automatically fill split_by in
  if (is.na(split_by)) split_by <- setdiff(c("view", "factor", "group"), c(x, y, split_by))
  
  # Calculate variance explained, reusing cached values when available
  # (use scalar `&&` throughout the condition so evaluation short-circuits)
  if (use_cache && .hasSlot(object, "cache") && ("variance_explained" %in% names(object@cache))) {
    r2_list <- object@cache$variance_explained
  } else {
    r2_list <- calculate_variance_explained(object, factors = factors, ...)
  }
  
  r2_mk <- r2_list$r2_per_factor
  
  # convert matrix to long data frame for ggplot2
  r2_mk_df <- melt(
    lapply(r2_mk, function(x)
      melt(as.matrix(x), varnames = c("factor", "view"))
    ), id.vars=c("factor", "view", "value")
  )
  colnames(r2_mk_df)[ncol(r2_mk_df)] <- "group"
  
  # Subset factors for plotting
  if ((length(factors) == 1) && (factors[1] == "all")) {
    factors <- factors_names(object)
  } else {
    if (is.numeric(factors)) {
      factors <- factors_names(object)[factors]
    } else { 
      stopifnot(all(factors %in% factors_names(object)))
    }
    r2_mk_df <- r2_mk_df[r2_mk_df$factor %in% factors,]
  }
  
  # Fix ordering of the plotting dimensions
  r2_mk_df$factor <- factor(r2_mk_df$factor, levels = factors)
  r2_mk_df$group <- factor(r2_mk_df$group, levels = groups_names(object))
  r2_mk_df$view <- factor(r2_mk_df$view, levels = views_names(object))
  
  # Detect whether to split by group or by view
  groups <- names(r2_list$r2_total)
  views <- colnames(r2_list$r2_per_factor[[1]])
  
  # Set R2 limits: values below min_r2 are floored to a near-zero value so
  # they still render as (essentially) empty tiles
  if (!is.null(min_r2)) r2_mk_df$value[r2_mk_df$value<min_r2] <- 0.001
  min_r2 = 0
  
  if (!is.null(max_r2)) {
    r2_mk_df$value[r2_mk_df$value>max_r2] <- max_r2
  } else {
    max_r2 = max(r2_mk_df$value)
  }
  
  
  # Grid plot with the variance explained per factor and view/group
  p1 <- ggplot(r2_mk_df, aes(x=.data[[x]], y=.data[[y]])) + 
    geom_tile(aes(fill=.data$value), color="black") +
    facet_wrap(as.formula(sprintf('~%s',split_by)), nrow=1) +
    labs(x="", y="", title="") +
    scale_fill_gradientn(colors=c("gray97","darkblue"), guide="colorbar", limits=c(min_r2,max_r2)) +
    guides(fill=guide_colorbar("Var. (%)")) +
    theme(
      axis.text.x = element_text(size=rel(1.0), color="black"),
      axis.text.y = element_text(size=rel(1.1), color="black"),
      axis.line = element_blank(),
      axis.ticks =  element_blank(),
      panel.background = element_blank(),
      strip.background = element_blank(),
      strip.text = element_text(size=rel(1.0))
    )
  
  if (isFALSE(legend)) p1 <- p1 + theme(legend.position = "none")
  
  # remove facet title
  if (length(unique(r2_mk_df[,split_by]))==1) p1 <- p1 + theme(strip.text = element_blank())
  
  # Add total variance explained bar plots
  if (plot_total) {
    
    r2_m_df <- melt(lapply(r2_list$r2_total, function(x) lapply(x, function(z) z)),
                    varnames=c("view", "group"), value.name="R2")
    colnames(r2_m_df)[(ncol(r2_m_df)-1):ncol(r2_m_df)] <- c("view", "group")
    
    # Consistency: use the bare accessor as elsewhere in this file
    # (was MOFA2::groups_names, a needless self-reference)
    r2_m_df$group <- factor(r2_m_df$group, levels = groups_names(object))
    r2_m_df$view <- factor(r2_m_df$view, levels = views_names(object))
    
    # Barplots for total variance explained
    min_lim_bplt <- min(0, r2_m_df$R2)
    max_lim_bplt <- max(r2_m_df$R2)
    
    # Barplot with variance explained per view/group (across all factors)
    p2 <- ggplot(r2_m_df, aes(x=.data[[x]], y=.data$R2)) + 
      # ggtitle(sprintf("%s\nTotal variance explained per %s", i, x)) +
      geom_bar(stat="identity", fill="deepskyblue4", color="black", width=0.9) +
      facet_wrap(as.formula(sprintf('~%s',split_by)), nrow=1) +
      xlab("") + ylab("Variance explained (%)") +
      scale_y_continuous(limits=c(min_lim_bplt, max_lim_bplt), expand=c(0.005, 0.005)) +
      theme(
        axis.ticks.x = element_blank(),
        axis.text.x = element_text(color="black"),
        axis.text.y = element_text(color="black"),
        axis.title.y = element_text(color="black"),
        axis.line = element_line(color="black"),
        panel.background = element_blank(),
        strip.background = element_blank(),
        strip.text = element_text()
      )
    
    # remove facet title
    if (length(unique(r2_m_df[,split_by]))==1) p2 <- p2 + theme(strip.text = element_blank())
    
    # Bind plots      
    plot_list <- list(p1,p2)
    
  } else {
    plot_list <- p1
  }
  
  return(plot_list)
}


#' @title Plot variance explained by the model for a set of features
#' 
#' @description Returns a tile plot with a group on the X axis and a feature along the Y axis
#' 
#' @name plot_variance_explained_per_feature
#' @param object a \code{\link{MOFA}} object.
#' @param view a view name or index.
#' @param features a vector with indices or names for features from the respective view, 
#' or number of top features to be fetched by their loadings across specified factors. 
#' "all" to plot all features.
#' @param split_by_factor logical indicating whether to split R2 per factor or plot R2 jointly
#' @param group_features_by column name of features metadata to group features by
#' @param groups a vector with indices or names for sample groups (default is all)
#' @param factors a vector with indices or names for factors (default is all)
#' @param min_r2 minimum variance explained for the color scheme (default is 0).
#' @param max_r2 maximum variance explained for the color scheme.
#' @param legend logical indicating whether to add a legend to the plot  (default is TRUE).
#' @param return_data logical indicating whether to return the data frame to plot instead of plotting
#' @param ... extra arguments to be passed to \code{\link{calculate_variance_explained}}
#' @return ggplot object
#' @import ggplot2
#' @importFrom cowplot plot_grid
#' @importFrom stats as.formula
#' @importFrom reshape2 melt
#' @export
#' @examples 
#' # Using an existing trained model
#' file <- system.file("extdata", "model.hdf5", package = "MOFA2")
#' model <- load_model(file)
#' plot_variance_explained_per_feature(model, view = 1)

plot_variance_explained_per_feature <- function(object, view, features = 10,
                                                split_by_factor = FALSE, group_features_by = NULL,
                                                groups = "all", factors = "all",
                                                min_r2 = 0, max_r2 = NULL, legend = TRUE,
                                                return_data = FALSE, ...) {
  
  # Check that one view is requested
  view  <- .check_and_get_views(object, view)
  if (length(view) != 1)
    stop("Please choose a single view to plot features from")
  
  # Fetch loadings, factors, and data  
  if (!is(object, "MOFA")) stop("'object' has to be an instance of MOFA")
  
  # Fetch relevant features: a single number selects the top N features by
  # loading across the requested factors; "all" selects every feature of the view
  if (is.numeric(features) && (length(features) == 1)) {
    features <- as.integer(features)
    features <- .get_top_features_by_loading(object, view = view, factors = factors, nfeatures = features)
  } else if (is.character(features)) {
    if (features[1]=="all") features <- 1:object@dimensions$D[[view]]
  }
  features <- .check_and_get_features_from_view(object, view = view, features)
  
  # Collect relevant expectations
  groups <- .check_and_get_groups(object, groups)
  factors <- .check_and_get_factors(object, factors)
  # 1. Loadings: choose a view, one or multiple factors, and subset chosen features
  W <- get_weights(object, views = view, factors = factors)
  W <- lapply(W, function(W_m) W_m[rownames(W_m) %in% features,,drop=FALSE])
  # 2. Factor values: choose one or multiple groups and factors
  Z <- get_factors(object, groups = groups, factors = factors)
  # 3. Data: Choose a view, one or multiple groups, and subset chosen features
  # (transposed to samples x features)
  # Y <- lapply(get_expectations(object, "Y")[view], function(Y_m) lapply(Y_m[groups], t))
  Y <- lapply(get_data(object, add_intercept = FALSE)[view], function(Y_m) lapply(Y_m[groups], t))
  Y <- lapply(Y, function(Y_m) lapply(Y_m, function(Y_mg) Y_mg[,colnames(Y_mg) %in% features,drop=FALSE]))
  
  # Replace masked values on Z by 0 (so that they do not contribute to predictions)
  for (g in groups) {
    Z[[g]][is.na(Z[[g]])] <- 0
  }
  
  m <- view  # Use shorter notation when calculating R2
  
  if (split_by_factor) {
    
    # Calculate coefficient of determination per group, factor and feature
    # (R2 = 1 - SS(residual)/SS(total), each factor assessed in isolation)
    r2_gdk <- lapply(groups, function(g) {
      r2_g <- sapply(features, function(d) { 
        sapply(factors, function(k) {
        a <- sum((as.matrix(Y[[m]][[g]][,d,drop=FALSE]) - tcrossprod(Z[[g]][,k,drop=FALSE], W[[m]][d,k,drop=FALSE]))**2, na.rm = TRUE)
        b <- sum(Y[[m]][[g]][,d,drop=FALSE]**2, na.rm = TRUE)
        return(1 - a/b)
      })
      })
      r2_g <- matrix(r2_g, ncol = length(features), nrow = length(factors))
      colnames(r2_g) <- features
      rownames(r2_g) <- factors
      # Lower bound is zero
      r2_g[r2_g < 0] <- 0
      r2_g
    })
    names(r2_gdk) <- groups
    
    # Convert matrix to long data frame for ggplot2
    r2_gdk_df <- do.call(rbind, r2_gdk)
    r2_gdk_df <- data.frame(r2_gdk_df, 
                            "group" = rep(groups, lapply(r2_gdk, nrow)),
                            "factor" = rownames(r2_gdk_df))
    r2_gdk_df <- melt(r2_gdk_df, id.vars = c("group", "factor"))
    colnames(r2_gdk_df) <- c("group", "factor", "feature", "value")
    
    r2_gdk_df$group <- factor(r2_gdk_df$group, levels = unique(r2_gdk_df$group))
    
    r2_df <- r2_gdk_df
    
  } else {
    
    # Calculate coefficient of determination per group and feature
    # (all requested factors combined)
    r2_gd <- lapply(groups, function(g) {
      r2_g <- lapply(features, function(d) {
        a <- sum((as.matrix(Y[[m]][[g]][,d,drop=FALSE]) - tcrossprod(Z[[g]], W[[m]][d,,drop=FALSE]))**2, na.rm = TRUE)
        b <- sum(Y[[m]][[g]][,d,drop=FALSE]**2, na.rm = TRUE)
        return(1 - a/b)
      })
      names(r2_g) <- features
      # Lower bound is zero
      r2_g[r2_g < 0] <- 0
      r2_g
    })
    names(r2_gd) <- groups
    
    # Convert matrix to long data frame for ggplot2
    tmp <- as.matrix(data.frame(lapply(r2_gd, unlist))) 
    colnames(tmp) <- groups
    r2_gd_df <- melt(tmp)
    colnames(r2_gd_df) <- c("feature", "group", "value")
    
    r2_gd_df$group <- factor(r2_gd_df$group, levels = unique(r2_gd_df$group))
    
    r2_df <- r2_gd_df
    
  }
  
  # Transform from fraction to percentage
  r2_df$value <- 100*r2_df$value
  
  # Calculate minimum R2 to display: values below min_r2 are floored to a
  # near-zero value so they still render as (essentially) empty tiles
  if (!is.null(min_r2)) {
    r2_df$value[r2_df$value<min_r2] <- 0.001
  }
  min_r2 <- 0
  
  # Calculate maximum R2 to display
  if (!is.null(max_r2)) {
    r2_df$value[r2_df$value>max_r2] <- max_r2
  } else {
    max_r2 <- max(r2_df$value)
  }
  
  # Group features by the requested metadata column(s) for faceting
  if (!is.null(group_features_by)) {
    features_indices <- match(r2_df$feature, features_metadata(object)$feature)
    features_grouped <- features_metadata(object)[,group_features_by,drop=FALSE][features_indices,,drop=FALSE]
    # If features grouped using multiple variables, concatenate them
    if (length(group_features_by) > 1) {
      features_grouped <- apply(features_grouped, 1, function(row) paste0(row, collapse="_"))
    } else {
      features_grouped <- features_grouped[,group_features_by,drop=TRUE]
    }
    r2_df["feature_group"] <- features_grouped
  }
  
  # Optionally return the tidy data frame instead of the plot
  if (return_data)
    return(r2_df)
  
  if (split_by_factor) {
    r2_df$factor <- factor(r2_df$factor, levels = factors_names(object))
  }
  
  # Grid plot with the variance explained per feature in every group
  p <- ggplot(r2_df, aes(x = .data$group, y = .data$feature)) + 
    geom_tile(aes(fill = .data$value), color = "black") +
    guides(fill = guide_colorbar("R2 (%)")) +
    labs(x = "", y = "", title = "") +
    scale_fill_gradientn(colors=c("gray97","darkblue"), guide="colorbar", limits=c(min_r2, max_r2)) +
    theme_classic() +
    theme(
      axis.text = element_text(size = 12),
      axis.line = element_blank(),
      axis.ticks =  element_blank(),
      strip.text = element_text(size = 12),
    )
  
  # Facet layout depends on which of the two grouping options are active
  if (!is.null(group_features_by) && split_by_factor) {
    p <- p + facet_grid(feature_group ~ factor, scales = "free_y")
  } else if (split_by_factor) {
    p <- p + facet_wrap(~factor, nrow = 1)
  } else if (!is.null(group_features_by)) {
    p <- p + facet_wrap(~feature_group, ncol = 1, scales = "free")
  }
  
  if (!legend)
    p <- p + theme(legend.position = "none")
  
  return(p)
}


================================================
FILE: R/cluster_samples.R
================================================

##########################################################
## Functions to cluster samples based on latent factors ##
##########################################################

#' @title K-means clustering on samples based on latent factors
#' @name cluster_samples
#' @description MOFA factors are continuous in nature but they can be used to predict discrete clusters of samples. \cr
#' The clustering can be performed in a single factor, which is equivalent to setting a manual threshold.
#' More interestingly, it can be done using multiple factors, where multiple sources of variation are aggregated. \cr
#' Importantly, this type of clustering is not weighted and does not take into account the different importance of the latent factors. 
#' @param object a trained \code{\link{MOFA}} object.
#' @param k number of clusters (integer).
#' @param factors character vector with the factor name(s), or numeric vector with the index of the factor(s) to use. 
#' Default is 'all'
#' @param ... extra arguments  passed to \code{\link{kmeans}}
#' @details In some cases, due to model technicalities, samples can have missing values in the latent factor space. 
#' In such a case, these samples are currently ignored in the clustering procedure.
#' @return output from \code{\link{kmeans}} function
#' @importFrom stats kmeans
#' @export 
#' @examples
#' # Using an existing trained model on simulated data
#' file <- system.file("extdata", "model.hdf5", package = "MOFA2")
#' model <- load_model(file)
#' 
#' # Cluster samples in the factor space using factors 1 to 3 and K=2 clusters 
#' clusters <- cluster_samples(model, k=2, factors=1:3)
cluster_samples <- function(object, k, factors = "all", ...) {
  
  # Sanity checks
  if (!is(object, "MOFA")) stop("'object' has to be an instance of MOFA")
  
  # Define factors
  factors <- .check_and_get_factors(object, factors)
  
  # Collect the factor matrix (samples x factors), binding groups together
  Z <- get_factors(object, factors=factors)
  if (is(Z, "list")) Z <- do.call(rbind, Z)
  N <- nrow(Z)
  
  # For now remove sample with missing values on factors
  # (TO-DO) incorporate a clustering function that is able to cope with missing values
  haveAllZ <- apply(Z, 1, function(x) all(!is.na(x)))
  if(!all(haveAllZ)) warning(paste("Removing", sum(!haveAllZ), "samples with missing values on at least one factor"))
  # drop=FALSE keeps Z a matrix even if only one sample survives the filter
  Z <- Z[haveAllZ, , drop=FALSE]
  
  # Perform k-means clustering
  kmeans.out <- kmeans(Z, centers=k,  ...)

  return(kmeans.out)  

}


================================================
FILE: R/compare_models.R
================================================

################################################
## Functions to compare different MOFA models ##
################################################


#' @title Plot the correlation of factors between different models
#' @name compare_factors
#' @description Different \code{\link{MOFA}} objects are compared in terms of correlation between their factors.
#' @param models a list with \code{\link{MOFA}} objects.
#' @param ... extra arguments passed to pheatmap
#' @details If assessing model robustness across trials, the output should look like a block diagonal matrix, 
#' suggesting that all factors are robustly detected in all model instances.
#' @return Plots a heatmap of the Pearson correlation between latent factors across all input models.
#' @importFrom stats cor
#' @importFrom pheatmap pheatmap
#' @importFrom grDevices colorRampPalette
#' @export
#' @examples
#' # Using an existing trained model on simulated data
#' file <- system.file("extdata", "model.hdf5", package = "MOFA2")
#' model1 <- load_model(file)
#' model2 <- load_model(file)
#' 
#' # Compare factors between models
#' compare_factors(list(model1,model2))
compare_factors <- function(models, ...) {
  
  # Sanity checks
  if(!is.list(models))
    stop("'models' has to be a list")
  if (!all(sapply(models, function (l) is(l, "MOFA"))))
    stop("Each element of the list 'models' has to be an instance of MOFA")

  # Give generic names if no names present
  if(is.null(names(models))) names(models) <- paste("model", seq_len(length(models)), sep="")

  # Get latent factors (samples x factors), concatenating groups within each model
  LFs <- lapply(seq_along(models), function(i){
    do.call(rbind, get_factors(models[[i]]))
  })
  
  # Sanity check: there must be at least one sample present in all models.
  # Note: intersect() returns a zero-length vector (not NULL) when there is no
  # overlap, so the length has to be tested rather than is.null()
  samples_names <- Reduce(intersect, lapply(LFs, rownames))
  if (length(samples_names) == 0)
    stop("No common samples in all models for comparison")

  # Align samples between models
  LFs <- lapply(LFs, function(z) {
    z[samples_names,,drop=FALSE]
  })
  
  # Prefix factor names with the model name so columns are unique in the heatmap
  for (i in seq_along(LFs))
    colnames(LFs[[i]]) <- paste(names(models)[i], colnames(LFs[[i]]), sep="_")

  # Absolute Pearson correlation between factors across models
  corLFs <- cor(Reduce(cbind, LFs), use="complete.obs")
  corLFs[is.na(corLFs)] <- 0
  corLFs <- abs(corLFs)

  # Plot heatmap (fixed 0..1 color scale)
  breaksList <- seq(0,1, by=0.01)
  colors <- colorRampPalette(c("white",RColorBrewer::brewer.pal(9,name="YlOrRd")))(length(breaksList))
  pheatmap(corLFs, color = colors, breaks = breaksList, ...)
}



#' @title Compare different trained \code{\link{MOFA}} objects in terms of the final value of the ELBO statistics and number of inferred factors
#' @name compare_elbo
#' @description Different objects of \code{\link{MOFA}} are compared in terms of the final value of the ELBO statistics.
#' For model selection the model with the highest ELBO value is selected.
#' @param models a list containing \code{\link{MOFA}} objects.
#' @param log logical indicating whether to plot the log of the ELBO.
#' @param return_data logical indicating whether to return a data.frame with the ELBO values per model
#' @return A \code{\link{ggplot}} object or the underlying data.frame if return_data is TRUE
#' @export
#' @examples
#' # Using an existing trained model on simulated data
#' file <- system.file("extdata", "model.hdf5", package = "MOFA2")
#' model1 <- load_model(file)
#' model2 <- load_model(file)
#' 
#' # Compare ELBO between models
#' \dontrun{compare_elbo(list(model1,model2))}
compare_elbo <- function(models, log = FALSE, return_data = FALSE) {
  
  # Sanity checks
  if(!is.list(models))
    stop("'models' has to be a list")
  if (!all(sapply(models, function (l) is(l, "MOFA"))))
    stop("Each element of the list 'models' has to be an instance of MOFA")
  
  # Give generic names if no names present
  if (is.null(names(models))) names(models) <- paste0("model_", seq_along(models))
  
  # Get the final ELBO value of each model
  elbo_vals <- sapply(models, get_elbo)
  
  df <- data.frame(
    ELBO = elbo_vals, 
    model = names(models)
  )
  
  # Apply exactly one transformation (and print exactly one message):
  # the original code could print two contradictory messages when log=TRUE
  if (log) {
    # ELBO values are typically large and negative; log2(-ELBO) gives a
    # readable scale. Since a higher ELBO is better, a lower log2(-ELBO) is better.
    message("Plotting the log2 of the negative of the ELBO (the smaller the better)")
    df$ELBO <- log2(-df$ELBO)
  } else if (all(df$ELBO<0)) {
    df$ELBO <- abs(df$ELBO)
    message("Plotting the absolute value of the ELBO for every model (the smaller the better)")
  } else {
    message("Plotting the ELBO for every model (the higher the better)")
  }
  
  # Return the (possibly transformed) values instead of plotting
  if (return_data) return(df)
  
  gg <- ggplot(df, aes(x=.data$model, y=.data$ELBO)) + 
    geom_bar(stat="identity", color="black", fill="grey70") +
    labs(x="", y="Evidence Lower Bound (ELBO)") +
    theme_classic()
  
  return(gg)
}



#' @title Select a model from a list of trained \code{\link{MOFA}} objects based on the best ELBO value
#' @name select_model
#' @description Different objects of \code{\link{MOFA}} are compared in terms of the final value of the ELBO statistics
#' and the model with the highest ELBO value is selected.
#' @param models a list containing \code{\link{MOFA}} objects.
#' @param plot boolean indicating whether to show a plot of the ELBO for each model instance
#' @return A \code{\link{MOFA}} object
#' @export
select_model <- function(models, plot = FALSE) {
  # Validate input: a list whose every element is a trained MOFA model
  if(!is.list(models))
    stop("'models' has to be a list")
  all_mofa <- sapply(models, function (l) is(l, "MOFA"))
  if (!all(all_mofa))
    stop("Each element of the the list 'models' has to be an instance of MOFA")

  # Collect the final ELBO of each model; optionally visualise the comparison
  elbos <- sapply(models, get_elbo)
  if(plot) compare_elbo(models)

  # A higher ELBO indicates a better model fit
  models[[which.max(elbos)]]
}


================================================
FILE: R/contribution_scores.R
================================================
#' @title Calculate contribution scores for each view in each sample
#' @description This function calculates, *for each sample* how much each view contributes to its location in the latent manifold, what we call \emph{contribution scores}
#' @name calculate_contribution_scores
#' @param object a trained \code{\link{MOFA}} object.
#' @param views character vector with the view names, or numeric vector with view indexes. Default is 'all'
#' @param groups character vector with the group names, or numeric vector with group indexes. Default is 'all'
#' @param factors character vector with the factor names, or numeric vector with the factor indexes. Default is 'all'
#' @param scale logical indicating whether to scale the sample-wise variance explained values by the total amount of variance explained per view. 
#' This effectively normalises each view by its total variance explained. It is important when different amounts of variance is explained for each view (check with \code{plot_variance_explained(..., plot_total=TRUE)})
#' @details Contribution scores are calculated in three steps:
#' \itemize{
#'  \item{\strong{Step 1}: calculate variance explained for each cell i and each view m (\eqn{R_{im}}), using all factors}
#'  \item{\strong{Step 2} (optional): scale values by the total variance explained for each view}
#'  \item{\strong{Step 3}: calculate contribution score (\eqn{C_{im}}) for cell i and view m as: \deqn{C_{im} = \frac{R2_{im}}{\sum_{m} R2_{im}} } }
#' }
#' Note that contribution scores can be calculated using any number of data modalities, but it is easier to interpret when you specify two. \cr
#' Please note that this functionality is still experimental, contact the authors if you have questions.
#' @return adds the contribution scores to the metadata slot (\code{samples_metadata(MOFAobject)}) and to the \code{MOFAobject@cache} slot
#' @export
#' @examples
#' # Using an existing trained model on simulated data
#' file <- system.file("extdata", "model.hdf5", package = "MOFA2")
#' model <- load_model(file)
#' model <- calculate_contribution_scores(model)
#'
calculate_contribution_scores <- function(object, views = "all", groups = "all", factors = "all", scale = TRUE) {
  
  # Sanity checks
  if (!is(object, "MOFA")) stop("'object' has to be an instance of MOFA")
  if (any(object@model_options$likelihoods!="gaussian"))
    stop("Not possible to compute contribution scores when using non-gaussian likelihoods.")

  # Define factors, views and groups; scores require at least 2 views and 2 factors
  views  <- .check_and_get_views(object, views)
  if (length(views)<2) stop("contribution scores only make sense when having at least 2 views")
  groups <- .check_and_get_groups(object, groups)
  factors <- .check_and_get_factors(object, factors)
  if (length(factors)<2) stop("contribution scores only make sense when having at least 2 factors")
  
  # Step 1: variance explained per sample and view (one matrix per group)
  r2.per.sample <- calculate_variance_explained_per_sample(object, factors=factors, views = views, groups = groups)
  
  # Step 2 (optional): normalise each view by its total variance explained,
  # so that views with globally low R2 are not systematically under-weighted
  if (scale) {
    r2.per.view <- get_variance_explained(object, factors=factors, views = views, groups = groups)[["r2_total"]]
    r2.per.sample <- lapply(seq_along(groups), function(g) sweep(r2.per.sample[[g]], 2, r2.per.view[[g]],"/"))
  }
  
  # Concatenate groups into a single (samples x views) matrix
  r2.per.sample <- do.call("rbind",r2.per.sample)
  
  # Step 3: contribution score = fraction of (relative) variance explained
  # by each view in each sample (rows sum to 1)
  contribution_scores <- r2.per.sample / rowSums(r2.per.sample)
  
  # Add contribution scores to the sample metadata (one column per view, suffixed "_contribution")
  for (i in colnames(contribution_scores)) {
    object <- .add_column_to_metadata(object, contribution_scores[,i], paste0(i,"_contribution"))
  }
  # Cache the full matrix for later retrieval (see get_contribution_scores)
  object@cache[["contribution_scores"]] <- contribution_scores
  
  return(object)
  
}


# Retrieve the contribution scores stored in the MOFA object's cache,
# computing them first if they are not cached yet.
# Returns a (samples x views) matrix, or a long data.frame
# (sample, view, value) when as.data.frame=TRUE.
get_contribution_scores <- function(object, groups = "all", views = "all", factors = "all", 
                                   as.data.frame = FALSE) {
  
  # Sanity checks
  if (!is(object, "MOFA")) stop("'object' has to be an instance of MOFA")
  
  # Get factors, views and groups
  groups <- .check_and_get_groups(object, groups)
  views <- .check_and_get_views(object, views)
  factors <- .check_and_get_factors(object, factors)
  
  # Fetch from the cache; on a cache miss compute the scores first.
  # Note: calculate_contribution_scores() returns the updated MOFA object,
  # so the scores themselves have to be read from its cache slot afterwards
  # (the previous code assigned the whole object to the scores variable).
  if (!(.hasSlot(object, "cache") && ("contribution_scores" %in% names(object@cache)))) {
    object <- calculate_contribution_scores(object, factors = factors, views = views, groups = groups)
  }
  scores <- object@cache[["contribution_scores"]]
  
  # Convert to long data.frame format. The cached value is already a single
  # matrix (groups were concatenated when it was computed), so it is melted directly.
  if (as.data.frame) {
    scores <- reshape2::melt(scores)
    colnames(scores) <- c("sample", "view", "value")
  }
  
  return(scores)
  
}

# Plot contribution scores, either one barplot panel per sample (group_by=NULL)
# or one boxplot panel per sample group (group_by = name of a metadata column).
# Extra arguments in ... are forwarded to get_contribution_scores().
plot_contribution_scores <- function(object, samples = "all", group_by = NULL, return_data = FALSE, ...) {

  # Sanity checks
  if (!is(object, "MOFA")) stop("'object' has to be an instance of MOFA")
  
  # (TO-DO) subset samples according to the 'samples' argument
  
  # Get contribution scores in long format (sample, view, value)
  scores <- get_contribution_scores(object, as.data.frame = TRUE, ...)
  
  # TO-DO: CHECK THAT GROUP IS A CHARACTER/FACTOR
  # Individual samples: one panel per sample, one bar per view
  if (is.null(group_by)) {
    
    to.plot <- scores
    if (return_data) return(to.plot)
    # use .data$ pronoun consistently for all aesthetics (avoids NSE notes)
    p <- ggplot(to.plot, aes(x=.data$view, y=.data$value)) +
      geom_bar(aes(fill=.data$view), stat="identity", color="black") +
      facet_wrap(~sample) +
      labs(x="", y="Contribution score") +
      theme_classic() +
      theme(
        axis.text.x = element_blank(),
        axis.ticks.x = element_blank(),
        legend.position = "top",
        legend.title = element_blank()
      )
    return(p)
    
  # Grouped samples: one panel per group, per-view boxplots of the scores
  } else {
    
    to.plot <- merge(scores, object@samples_metadata[,c("sample",group_by)], by="sample")
    if (return_data) return(to.plot)
    p <- ggplot(to.plot, aes(x=.data$view, y=.data$value)) +
      geom_boxplot(aes(fill=.data$view)) +
      facet_wrap(as.formula(paste("~", group_by))) +
      labs(x="", y="Contribution score") +
      theme_classic() +
      theme(
        axis.text.x = element_blank(),
        axis.ticks.x = element_blank(),
        legend.position = "top",
        legend.title = element_blank()
      )
    return(p)
  }
}

================================================
FILE: R/correlate_covariates.R
================================================
#' @title Plot correlation of factors with external covariates
#' @name correlate_factors_with_covariates
#' @description Function to correlate factor values with external covariates.
#' @param object a trained \code{\link{MOFA}} object.
#' @param covariates 
#' \itemize{
#'   \item{\strong{data.frame}: a data.frame where the samples are stored in the rows and the covariates are stored in the columns. 
#'   Use row names for sample names and column names for covariate names. Columns values must be numeric. }
#'   \item{\strong{character vector}: character vector with names of columns that are present in the sample metadata (\code{samples_metadata(model)})}
#' }
#' @param factors character vector with the factor name(s), or numeric vector with the index of the factor(s) to use. Default is 'all'.
#' @param groups character vector with the groups names, or numeric vector with the indices of the groups of samples to use, or "all" to use samples from all groups.
#' @param abs logical indicating whether to take the absolute value of the correlation coefficient (default is \code{TRUE}).
#' @param plot character indicating whether to plot Pearson correlation coefficients (\code{plot="r"}) or log10 adjusted p-values (\code{plot="log_pval"}).
#' @param return_data logical indicating whether to return the correlation results instead of plotting
#' @param transpose logical indicating whether to transpose the plot
#' @param alpha p-value threshold
#' @param ... extra arguments passed to \code{\link[corrplot]{corrplot}} (if \code{plot=="r"}) or \code{\link[pheatmap]{pheatmap}} (if \code{plot=="log_pval"}).
#' @importFrom pheatmap pheatmap
#' @importFrom corrplot corrplot
#' @return A \code{\link[corrplot]{corrplot}} (if \code{plot=="r"}) or \code{\link[pheatmap]{pheatmap}} (if \code{plot=="log_pval"}) or the underlying data.frame if return_data is TRUE
#' @export
correlate_factors_with_covariates <- function(object, covariates, factors = "all", groups = "all", 
                                              abs = FALSE, plot = c("log_pval","r"), 
                                              alpha = 0.05, return_data = FALSE, transpose = FALSE, ...) {
  
  # Sanity checks
  if (!is(object, "MOFA")) stop("'object' has to be an instance of MOFA")
  groups <- .check_and_get_groups(object,groups)
  
  # Get covariates
  metadata <- samples_metadata(object)
  metadata <- metadata[metadata$group%in%groups,]
  if (is.character(covariates)) {
    # Covariates given as names of sample-metadata columns
    stopifnot(all(covariates %in% colnames(metadata)))
    covariates <- metadata[,covariates,drop=FALSE]
  } else if (is.data.frame(covariates)) {
    # Covariates given as a data.frame (samples in rows, covariates in columns).
    # Align its rows to the metadata sample order; samples without covariate
    # values get NA rows (handled downstream by the correlation test).
    # The previous code discarded the supplied covariates and used the full
    # metadata instead, contradicting the documented behaviour.
    samples <- metadata$sample
    if (is.null(rownames(covariates))) stop("The 'covariates' data.frame does not have samples names")
    stopifnot(all(rownames(covariates) %in% samples))
    covariates <- covariates[match(samples, rownames(covariates)), , drop=FALSE]
  } else {
    stop("covariates argument not recognised. Please read the documentation: ?correlate_factors_with_covariates")
  }
  
  # Convert character columns to factors
  cols <- which(sapply(covariates, is.character))
  if (length(cols) >= 1) {
    covariates[cols] <- lapply(covariates[cols], as.factor)
  }
  
  # Convert any remaining non-numeric columns (e.g. factors, logicals) to numeric
  cols <- which(!sapply(covariates,class)%in%c("numeric","integer"))
  if (length(cols) >= 1) {
    warning("There are non-numeric values in the covariates data.frame, converting to numeric...")
    covariates[cols] <- lapply(covariates[cols], as.numeric)
  }
  stopifnot(all(sapply(covariates,class)%in%c("numeric","integer")))
  
  # Get factor values (samples x factors), concatenating groups
  factors <- .check_and_get_factors(object, factors)
  Z <- get_factors(object, factors = factors, groups = groups, as.data.frame=FALSE)
  Z <- do.call(rbind, Z)
  
  # Pearson correlation between factors and covariates (BH-adjusted p-values)
  cor <- psych::corr.test(Z, covariates, method = "pearson", adjust = "BH")
  
  # Plot
  plot <- match.arg(plot)
  
  if (plot=="r") {
    stat <- cor$r
    if (abs) stat <- abs(stat)
    if (transpose) stat <- t(stat)
    if (return_data) return(stat)
    corrplot(stat, tl.col = "black", title="Pearson correlation coefficient", ...)
    
  } else if (plot=="log_pval") {
    stat <- cor$p
    stat[stat>alpha] <- 1.0   # censor non-significant p-values so they map to 0
    if (all(stat==1.0)) stop("All p-values are 1.0, nothing to plot")
    stat <- -log10(stat)
    stat[is.infinite(stat)] <- 1000   # cap p-values of exactly zero
    if (transpose) stat <- t(stat)
    if (return_data) return(stat)
    col <- colorRampPalette(c("lightgrey", "red"))(n=100)
    pheatmap::pheatmap(stat, main="log10 adjusted p-values", cluster_rows = FALSE, color=col, ...)
    
  } else {
    stop("'plot' argument not recognised. Please read the documentation: ?correlate_factors_with_covariates")
  }
  
}



#' @title Summarise factor values using external groups
#' @name summarise_factors
#' @description Function to summarise factor values using a discrete grouping of samples.
#' @param object a trained \code{\link{MOFA}} object.
#' @param df a data.frame with the columns "sample" and "level", where level is a factor with discrete group assignments for each sample.
#' @param factors character vector with the factor name(s), or numeric vector with the index of the factor(s) to use. Default is 'all'.
#' @param groups character vector with the groups names, or numeric vector with the indices of the groups of samples to use, or "all" to use samples from all groups.
#' @param abs logical indicating whether to take the absolute value of the factors (default is \code{FALSE}).
#' @param return_data logical indicating whether to return the summarised factor values (data.frame) instead of plotting
#' @import ggplot2
#' @importFrom dplyr group_by summarise mutate
#' @importFrom stats median
#' @importFrom magrittr %>%
#' @return A \code{\link{ggplot}} object or a \code{data.frame} if return_data is TRUE
#' @export
summarise_factors <- function(object, df, factors = "all", groups = "all", abs = FALSE, return_data = FALSE) {
  
  # Sanity checks: 'df' must map every sample to a discrete level
  if (!is(object, "MOFA")) stop("'object' has to be an instance of MOFA")
  stopifnot(is.data.frame(df))
  stopifnot((c("sample","level")%in%colnames(df)))
  stopifnot(df$sample %in% unlist(samples_names(object)))
  stopifnot(length(df$level)>1)
  df$level <- as.factor(df$level)
  
  # Get factor values and rescale each factor to [-1, 1] by its maximum
  # absolute value, so factors are comparable on one colour scale
  factors <- .check_and_get_factors(object, factors)
  groups <- .check_and_get_groups(object, groups)
  factors_df <- get_factors(object, factors = factors, groups = groups, as.data.frame=TRUE) %>% 
    group_by(factor) %>% mutate(value=value/max(abs(value),na.rm=TRUE)) # Scale factor values
  
  # Attach the external levels and summarise each (level, factor, group)
  # combination by its median factor value
  to.plot <- merge(factors_df, df, by="sample") %>% 
    group_by(level,factor,group) %>%
    summarise(value=median(value,na.rm=TRUE))
  
  # Optionally collapse sign so only magnitude is shown
  if (abs) {
    to.plot$value <- abs(to.plot$value)
  }
  
  
  # Plot: with multiple groups, facet by factor and put groups on the x-axis;
  # with a single group, put factors directly on the x-axis
  if (length(unique(factors_df$group))>1) {
    to.plot$group <- factor(to.plot$group, levels=groups)
    p <- ggplot(to.plot, aes(x=.data$group, y=.data$level, fill=.data$value)) +
      facet_wrap(~factor)
  } else {
    p <- ggplot(to.plot, aes(x=.data$factor, y=.data$level, fill=.data$value))
  }
  
  p <- p +
    geom_tile() +
    theme_classic() +
    labs(x="", y="", fill="Score") +
    theme(
      axis.text.x = element_text(color="black", angle=30, hjust=1),
      axis.text.y = element_text(color="black")
    )

  # Colour scale: sequential white->red for absolute values,
  # otherwise a diverging palette centred at 0
  if (abs) {
    p <- p + scale_fill_gradient2(low = "white", high = "red")
  } else {
    # center the color scheme at 0
    p <- p + scale_fill_distiller(type = "div", limit = max(abs(to.plot$value),na.rm=TRUE)*c(-1,1))
  } 
  
  # Return data or plot
  if (return_data) {
    return(to.plot)
  } else {
    return(p)
  }
}




================================================
FILE: R/create_mofa.R
================================================

#' @title create a MOFA object
#' @name create_mofa
#' @description Method to create a \code{\link{MOFA}} object. Depending on the input data format, this method calls one of the following functions:
#' \itemize{
#'   \item{\strong{long data.frame}: \code{\link{create_mofa_from_df}}}
#'   \item{\strong{List of matrices}: \code{\link{create_mofa_from_matrix}}}
#'   \item{\strong{MultiAssayExperiment}: \code{\link{create_mofa_from_MultiAssayExperiment}}}
#'   \item{\strong{Seurat}: \code{\link{create_mofa_from_Seurat}}}
#'   \item{\strong{SingleCellExperiment}: \code{\link{create_mofa_from_SingleCellExperiment}}}
#'   }
#'  Please read the documentation of the corresponding function for more details on your specific data format.
#' @param data one of the formats above
#' @param groups group information, only relevant when using the multi-group framework. 
#' @param extract_metadata logical indicating whether to incorporate the sample metadata from the input object into the MOFA object (
#' not relevant when the input is a list of matrices). Default is \code{TRUE}.
#' @param ... further arguments that can be passed to the function depending on the input data format.
#' See the documentation of above functions for details.
#' @return Returns an untrained \code{\link{MOFA}} object
#' @export
#' @examples
#' # Using an existing simulated data with two groups and two views
#' file <- system.file("extdata", "test_data.RData", package = "MOFA2")
#' 
#' # Load data (in long data.frame format)
#' load(file) 
#' MOFAmodel <- create_mofa(dt)
create_mofa <- function(data, groups = NULL, extract_metadata = TRUE, ...) {
  
  # Helper: TRUE when 'data' is a non-empty list whose elements are all of one
  # supported matrix class (dense, dgCMatrix or dgTMatrix)
  all_matrices <- function(x) {
    if (!is(x, "list") || length(x) == 0) return(FALSE)
    any(sapply(c("matrix", "dgCMatrix", "dgTMatrix"),
               function(cl) all(sapply(x, function(m) is(m, cl)))))
  }
  
  # Dispatch on the type of the input data
  if (is(data, "Seurat")) {
    message("Creating MOFA object from a Seurat object...")
    object <- create_mofa_from_Seurat(data, groups, extract_metadata = extract_metadata, ...)
  } else if (is(data, "SingleCellExperiment")) {
    message("Creating MOFA object from a SingleCellExperiment object...")
    object <- create_mofa_from_SingleCellExperiment(data, groups, extract_metadata = extract_metadata, ...)
  } else if (is(data, "data.frame")) {
    message("Creating MOFA object from a data.frame...")
    object <- create_mofa_from_df(data, extract_metadata = extract_metadata)
  } else if (all_matrices(data)) {
    message("Creating MOFA object from a list of matrices (features as rows, sample as columns)...\n")
    object <- create_mofa_from_matrix(data, groups)
  } else if (is(data, "MultiAssayExperiment")) {
    object <- create_mofa_from_MultiAssayExperiment(data, groups, extract_metadata = extract_metadata, ...)
  } else {
    stop("Error: input data has to be provided as a list of matrices, a data frame or a Seurat object. Please read the documentation for more details.")
  }
  
  return(object)
}

#' @title create a MOFA object from a MultiAssayExperiment object
#' @name create_mofa_from_MultiAssayExperiment
#' @description Method to create a \code{\link{MOFA}} object from a MultiAssayExperiment object
#' @param mae a MultiAssayExperiment object
#' @param groups a string specifying column name of the colData to use it as a group variable. 
#' Alternatively, a character vector with group assignment for every sample.
#' Default is \code{NULL} (no group structure).
#' @param extract_metadata logical indicating whether to incorporate the metadata from the MultiAssayExperiment object into the MOFA object
#' @return Returns an untrained \code{\link{MOFA}} object
#' @export
create_mofa_from_MultiAssayExperiment <- function(mae, groups = NULL, extract_metadata = FALSE) {
  
  # Sanity check: MultiAssayExperiment is only a suggested dependency
  if(!requireNamespace("MultiAssayExperiment", quietly = TRUE)){
    stop("Package \"MultiAssayExperiment\" is required but is not installed.", call. = FALSE)
  } else {
    
    # Re-arrange each assay into a (features x samples) matrix over the full
    # set of primary samples, filling samples absent from an assay with NAs
    data_list <- lapply(names(mae), function(m) {
      
      # Extract general sample names (the 'primary' IDs shared across assays)
      primary <- unique(MultiAssayExperiment::sampleMap(mae)[,"primary"])
      
      # Extract view
      subdata <- as.matrix(MultiAssayExperiment::assays(mae)[[m]])
      
      # Rename view-specific sample IDs with the general sample names.
      # NOTE(review): this assumes the assay's column order matches the
      # sampleMap row order for that assay — stopifnot guards that assumption
      stopifnot(colnames(subdata)==MultiAssayExperiment::sampleMap(mae)[MultiAssayExperiment::sampleMap(mae)[,"assay"]==m,"colname"])
      colnames(subdata) <- MultiAssayExperiment::sampleMap(mae)[MultiAssayExperiment::sampleMap(mae)[,"assay"]==m,"primary"]
      
      # Fill subdata with NAs for samples missing in this assay
      subdata_filled <- .subset_augment(subdata, primary)
      return(subdata_filled)
    })
    
    # Define groups: a single string is interpreted as a colData column name
    if (is(groups, 'character') && (length(groups) == 1)) {
      if (!(groups %in% colnames(MultiAssayExperiment::colData(mae))))
        stop(paste0(groups, " is not found in the colData of the MultiAssayExperiment.\n",
                    "If you want to use groups information from MultiAssayExperiment,\n",
                    "please ensure to provide a column name that exists. The columns of colData are:\n",
                    paste0(colnames(MultiAssayExperiment::colData(mae)), collapse = ", ")))
      groups <- MultiAssayExperiment::colData(mae)[,groups]
    }
    
    # If no groups provided, treat all samples as coming from one group
    if (is.null(groups)) {
      # message("No groups provided as argument, we assume that all samples belong to the same group.\n")
      groups <- rep("group1",  length(unique(MultiAssayExperiment::sampleMap(mae)[,"primary"])))
    }
    
    # Initialise MOFA object and split the matrices by group
    object <- new("MOFA")
    object@status <- "untrained"
    object@data <- .split_data_into_groups(data_list, groups)
    
    # groups_nms <- unique(as.character(groups))
    # Use the group names as ordered by .split_data_into_groups
    groups_nms <- names(object@data[[1]])
    
    # Set dimensionalities: M views, G groups, D features per view,
    # N samples per group, K factors (0 before training)
    object@dimensions[["M"]] <- length(data_list)
    object@dimensions[["G"]] <- length(groups_nms)
    object@dimensions[["D"]] <- sapply(data_list, nrow)
    object@dimensions[["N"]] <- sapply(groups_nms, function(x) sum(groups == x))
    object@dimensions[["K"]] <- 0
    
    # Set view names
    views_names(object) <- names(mae)
    
    # Set samples group names
    groups_names(object) <- groups_nms
    
    # Extract metadata from colData (only if it has at least one column)
    if (extract_metadata) {
      if (ncol(MultiAssayExperiment::colData(mae)) > 0) {
        object@samples_metadata <- data.frame(MultiAssayExperiment::colData(mae))
      }
    }

    # Create sample metadata
    object <- .create_samples_metadata(object)

    # Create features metadata
    object <- .create_features_metadata(object)

    # Rename duplicated features
    object <- .rename_duplicated_features(object)

    # Do quality control
    object <- .quality_control(object)
    
    return(object)
  }
}


#' @title create a MOFA object from a data.frame object
#' @name create_mofa_from_df
#' @description Method to create a \code{\link{MOFA}} object from a data.frame object
#' @param df \code{data.frame} object with at most 5 columns: \code{sample}, \code{group}, \code{feature}, \code{view}, \code{value}. 
#'   The \code{group} column (optional) indicates the group of each sample when using the multi-group framework.
#'   The \code{view} column (optional) indicates the view of each feature when having multi-view data.
#' @param extract_metadata  logical indicating whether to incorporate the extra columns as sample metadata into the MOFA object
#' @return Returns an untrained \code{\link{MOFA}} object
#' @export
#' @examples
#' # Using an existing simulated data with two groups and two views
#' file <- system.file("extdata", "test_data.RData", package = "MOFA2")
#' 
#' # Load data (in long data.frame format)
#' load(file) 
#' MOFAmodel <- create_mofa_from_df(dt)
create_mofa_from_df <- function(df, extract_metadata = TRUE) {
  
  # Quality controls: fill in a default group/view when the column is absent
  df <- as.data.frame(df)
  if (!"group" %in% colnames(df)) {
    # message('No "group" column found in the data.frame, we will assume a common group for all samples')
    df$group <- "single_group"
  }
  if (!"view" %in% colnames(df)) {
    # message('No "view" column found in the data.frame, we will assume a common view for all features')
    df$view <- "single_view"
  }
  stopifnot(all(c("sample","feature","value") %in% colnames(df)))
  # stopifnot(all(colnames(df) %in% (c("sample","feature","value","group","view"))))
  stopifnot(all(is.numeric(df$value)))
  
  # Convert 'sample' and 'feature' columns to factors
  if (!is.factor(df$sample))
    df$sample <- as.factor(df$sample)
  if (!is.factor(df$feature))
    df$feature <- as.factor(df$feature)
  
  # Convert 'group' and 'view' columns to factors.
  # Both columns are guaranteed to exist after the defaults above, so the
  # previous "column missing" branches were dead code and have been removed.
  df$group <- factor(df$group)
  df$view <- factor(df$view)
  
  # Build one (features x samples) matrix per view and per group
  data_matrix <- list()
  for (m in levels(df$view)) {
    data_matrix[[m]] <- list()
    features <- as.character( unique( df[df$view==m,"feature",drop=TRUE] ) )
    for (g in levels(df$group)) {
      samples <- as.character( unique( df[df$group==g,"sample",drop=TRUE] ) )
      Y <- df[df$view==m & df$group==g,]
      Y$sample <- factor(Y$sample, levels=samples)
      Y$feature <- factor(Y$feature, levels=features)
      if (nrow(Y)==0) {
        # view m was not measured in group g: use an all-NA matrix
        data_matrix[[m]][[g]] <- matrix(as.numeric(NA), ncol=length(samples), nrow=length(features))
        rownames(data_matrix[[m]][[g]]) <- features
        colnames(data_matrix[[m]][[g]]) <- samples
      } else {
        data_matrix[[m]][[g]] <- .df_to_matrix( reshape2::dcast(Y, feature~sample, value.var="value", fill=NA, drop=FALSE) )
      }
    }
  }
  
  # Create MOFA object
  object <- new("MOFA")
  object@status <- "untrained"
  object@data <- data_matrix
  
  # Set dimensionalities: M views, D features per view, G groups,
  # N samples per group, K factors (0 before training)
  object@dimensions[["M"]] <- length(levels(df$view))
  object@dimensions[["D"]] <- sapply(levels(df$view), function(m) length(unique(df[df$view==m,]$feature)))
  object@dimensions[["G"]] <- length(levels(df$group))
  object@dimensions[["N"]] <- sapply(levels(df$group), function(g) length(unique(df[df$group==g,]$sample)))
  object@dimensions[["K"]] <- 0
  
  # Set view names
  views_names(object) <- levels(df$view)
  
  # Set group names
  groups_names(object) <- levels(df$group)
  
  # Save other sample-level columns to samples metadata (e.g. covariates):
  # keep only the extra columns that take exactly one value per sample
  if(extract_metadata && !all(colnames(df) %in% (c("sample","feature","value","group","view")))) {
    cols2keep <- df %>% group_by(sample) %>% select(-c("view", "feature", "value", "group")) %>%
      summarise(across(!starts_with("sample"), function(x) length(unique(x)),
                       .names = "{col}")) 
    cols2keep <- colnames(cols2keep)[apply(cols2keep, 2, function(x) all(x  == 1))]
    if (length(cols2keep) > 0){
      df_meta <- df[, c("sample",cols2keep)] %>% distinct()
      object@samples_metadata <- df_meta %>% select(-sample)
      rownames(object@samples_metadata) <- df_meta$sample
    }
  }

  # Create sample metadata
  object <- .create_samples_metadata(object)

  # Create features metadata
  object <- .create_features_metadata(object)

  # Rename duplicated features
  object <- .rename_duplicated_features(object)

  # Do quality control
  object <- .quality_control(object)

  return(object)
}


#' @title create a MOFA object from a SingleCellExperiment object
#' @name create_mofa_from_SingleCellExperiment
#' @description Method to create a \code{\link{MOFA}} object from a SingleCellExperiment object
#' @param sce SingleCellExperiment object
#' @param groups a string specifying column name of the colData to use it as a group variable. 
#' Alternatively, a character vector with group assignment for every sample.
#' Default is \code{NULL} (no group structure).
#' @param assay assay to use, default is \code{logcounts}.
#' @param extract_metadata logical indicating whether to incorporate the metadata from the SingleCellExperiment object into the MOFA object
#' @return Returns an untrained \code{\link{MOFA}} object
#' @export
create_mofa_from_SingleCellExperiment <- function(sce, groups = NULL, assay = "logcounts", extract_metadata = FALSE) {
  
  # Check that SingleCellExperiment and SummarizedExperiment are installed
  # (they are only suggested, so they may not be attached; every call below
  # must therefore be namespace-qualified)
  if (!requireNamespace("SingleCellExperiment", quietly = TRUE)) {
    stop("Package \"SingleCellExperiment\" is required but is not installed.", call. = FALSE)
  }
  else if(!requireNamespace("SummarizedExperiment", quietly = TRUE)){
    stop("Package \"SummarizedExperiment\" is required but is not installed.", call. = FALSE)
  } else {
    stopifnot(assay%in%names(SummarizedExperiment::assays(sce)))
    
    # Define groups of cells
    if (is.null(groups)) {
      # No groups provided: assume all samples come from a single group
      groups <- rep("group1", dim(sce)[2])
    } else {
      if (is(groups,'character')) {
        if (length(groups) == 1) {
          # A single string is interpreted as a colData column name.
          # colData() must be qualified: SummarizedExperiment is never attached here
          stopifnot(groups %in% colnames(SummarizedExperiment::colData(sce)))
          groups <- SummarizedExperiment::colData(sce)[,groups]
        } else {
          # Otherwise it must be a per-cell group assignment
          stopifnot(length(groups) == ncol(sce))
        }
      } else {
        stop("groups wrongly specified. Please see the documentation and the examples")
      }
    }
    
    # Extract data matrices (one view, split into per-group matrices)
    data_matrices <- list( .split_sce_into_groups(sce, groups, assay) )
    names(data_matrices) <- assay
    
    # Create MOFA object
    object <- new("MOFA")
    object@status <- "untrained"
    object@data <- data_matrices
    
    # Define dimensions (M views, D features per view, G groups, N samples per group)
    object@dimensions[["M"]] <- length(assay)
    object@dimensions[["D"]] <- vapply(data_matrices, function(m) nrow(m[[1]]), 1L)
    object@dimensions[["G"]] <- length(data_matrices[[1]])
    object@dimensions[["N"]] <- vapply(data_matrices[[1]], function(g) ncol(g), 1L)
    object@dimensions[["K"]] <- 0
    
    # Set views & groups names
    groups_names(object) <- as.character(names(data_matrices[[1]]))
    views_names(object)  <- assay
    
    # Set metadata
    if (extract_metadata) {
      object@samples_metadata <- as.data.frame(SummarizedExperiment::colData(sce))
      # object@features_metadata <- as.data.frame(rowData(sce))
    }
    
    # Create sample metadata
    object <- .create_samples_metadata(object)

    # Create features metadata
    object <- .create_features_metadata(object)

    # Rename duplicated features
    object <- .rename_duplicated_features(object)

    # Do quality control
    object <- .quality_control(object)

    return(object)
  }
}

#' @title create a MOFA object from a Seurat object
#' @name create_mofa_from_Seurat
#' @description Method to create a \code{\link{MOFA}} object from a Seurat object
#' @param seurat Seurat object
#' @param groups a string specifying column name of the samples metadata to use it as a group variable. 
#' Alternatively, a character vector with group assignment for every sample.
#' Default is \code{NULL} (no group structure).
#' @param assays assays to use, default is \code{NULL}, in which case all available assays are fetched
#' @param layer layer to be used (default is data).
#' @param features a list with vectors, which are used to subset features, with names corresponding to assays; a vector can be provided when only one assay is used
#' @param extract_metadata logical indicating whether to incorporate the metadata from the Seurat object into the MOFA object
#' @return Returns an untrained \code{\link{MOFA}} object
#' @export
create_mofa_from_Seurat <- function(seurat, groups = NULL, assays = NULL, layer = "data", features = NULL, extract_metadata = FALSE) {
  
  # Check that Seurat is installed (it is only suggested, not imported)
  if (!requireNamespace("Seurat", quietly = TRUE)) {
    stop("Package \"Seurat\" is required but is not installed.", call. = FALSE)
  } else {
    
    # Check Seurat version
    if (SeuratObject::Version(seurat)$major != 5) stop("Please install Seurat v5")
    
    # Define assays
    if (is.null(assays)) {
      assays <- SeuratObject::Assays(seurat)
      message(paste0("No assays specified, using all assays by default: ", paste(assays,collapse=" ")))
    } else {
      # Consistently use SeuratObject:: (same function is re-exported by Seurat)
      stopifnot(assays %in% SeuratObject::Assays(seurat))
    }
    
    # Define groups of cells: a single string is interpreted as a meta.data column name
    if (is(groups, 'character') && (length(groups) == 1)) {
      if (!(groups %in% colnames(seurat@meta.data)))
        # paste(..., collapse=", ") joins the column names into one string;
        # the previous paste0(..., sep=", ") silently appended ", " to each element instead
        stop(paste0(groups, " is not found in the Seurat@meta.data.\n",
                    "please ensure to provide a column name that exists. The columns of meta data are:\n",
                    paste(colnames(seurat@meta.data), collapse = ", ")))
      groups <- seurat@meta.data[,groups]
    }
    
    # If features to subset are provided,
    # make sure they are a list with respective views (assays) names.
    # A vector is accepted if there's one assay to be used
    if (is(features, "list")) {
      if (!is.null(features) && !all(names(features) %in% assays)) {
        stop("Please make sure all the names of the features list correspond to views (assays) names being used for the model")
      }
    } else {
      # By default select highly variable features if present in the Seurat object
      if (is.null(features)) {
        message("No features specified, using variable features from the Seurat object...")
        # NOTE(review): @var.features is a slot of the legacy Assay class; Assay5
        # objects in Seurat v5 may not carry it — confirm against SeuratObject::VariableFeatures()
        features <- lapply(assays, function(i) seurat@assays[[i]]@var.features)
        names(features) <- assays
        if (any(sapply(features,length)==0)) stop("No list of features provided and variable features not detected in the Seurat object")
      } else if (is(features, "character")) {
        features <- list(features)
        names(features) <- assays
      } else {
        stop("Features not recognised. Please either provide a list of features (per assay) or calculate variable features in the Seurat object")
      }
    }
    
    # If no groups provided, treat all samples as coming from one group
    if (is.null(groups)) {
      groups <- rep("group1", dim(seurat)[2])
    }
    
    # Extract data matrices (one entry per assay, each split into per-group matrices)
    data_matrices <- lapply(assays, function(i) 
      .split_seurat_into_groups(seurat, groups = groups, assay = i, layer = layer, features = features[[i]]))
    names(data_matrices) <- assays
    
    # Create MOFA object
    object <- new("MOFA")
    object@status <- "untrained"
    object@data <- data_matrices
    
    # Define dimensions (M views, D features per view, G groups, N samples per group)
    object@dimensions[["M"]] <- length(assays)
    object@dimensions[["D"]] <- vapply(data_matrices, function(m) nrow(m[[1]]), 1L)
    object@dimensions[["G"]] <- length(data_matrices[[1]])
    object@dimensions[["N"]] <- vapply(data_matrices[[1]], function(g) ncol(g), 1L)
    object@dimensions[["K"]] <- 0
    
    # Set views & groups names
    groups_names(object) <- as.character(names(data_matrices[[1]]))
    views_names(object)  <- assays
    
    # Set metadata
    if (extract_metadata) {
      object@samples_metadata <- seurat@meta.data
      # object@features_metadata <- do.call(rbind, lapply(assays, function(a) seurat@assays[[a]]@meta.features))
    }

    # Create sample metadata
    object <- .create_samples_metadata(object)

    # Create features metadata
    object <- .create_features_metadata(object)

    # Rename duplicated features
    object <- .rename_duplicated_features(object)

    # Do quality control
    object <- .quality_control(object)
    
    return(object)
  }
}


#' @title create a MOFA object from a a list of matrices
#' @name create_mofa_from_matrix
#' @description Method to create a \code{\link{MOFA}} object from a list of matrices
#' @param data A list of matrices, where each entry corresponds to one view.
#'   Samples are stored in columns and features in rows.
#'   Missing values must be filled in prior to creating the MOFA object (see for example the CLL tutorial)
#' @param groups A character vector with group assignment for every sample. Default is \code{NULL}, no group structure.
#' @return Returns an untrained \code{\link{MOFA}} object
#' @export
#' @examples 
#' m <- make_example_data()
#' create_mofa_from_matrix(m$data)

create_mofa_from_matrix <- function(data, groups = NULL) {
  
  # Quality control: matrices must be all numeric (dense) or all sparse dgTMatrix/dgCMatrix
  stopifnot(all(sapply(data, function(g) all(is.numeric(g)))) || all(sapply(data, function(x) class(x) %in% c("dgTMatrix", "dgCMatrix"))))
  
  # Quality control: check that all matrices have the same samples
  tmp <- lapply(data, function(m) colnames(m))
  if(length(unique(sapply(tmp,length)))>1)
    stop("Views have different number of samples (columns)... please make sure that all views contain the same samples in the same order (see documentation)")
  if (length(unique(tmp))>1) 
    stop("Views have different sample names (columns)... please make sure that all views contain the same samples in the same order (see documentation)")
  
  # Make a dgCMatrix (column-compressed) out of dgTMatrix
  if (all(sapply(data, function(x) is(x, "dgTMatrix")))) {
    data <- lapply(data, function(m) as(m, "dgCMatrix"))
  }
  
  # If no groups are provided, assign all samples to a single group
  if (is.null(groups)) {
    groups <- rep("group1", ncol(data[[1]]))
  }
  
  # Set views names
  if (is.null(names(data))) {
    default_views <- paste0("view_", seq_along(data))
    message(paste0("View names are not specified in the data, using default: ", paste(default_views, collapse=", "), "\n"))
    names(data) <- default_views
  }
  views_names <- as.character(names(data))
  
  # Initialise MOFA object
  object <- new("MOFA")
  object@status <- "untrained"
  object@data <- .split_data_into_groups(data, groups)
  
  # Group names in the order produced by .split_data_into_groups (factor level order)
  groups_names <- names(object@data[[1]])
  
  # Set dimensionalities (M views, G groups, D features per view, N samples per group)
  object@dimensions[["M"]] <- length(data)
  object@dimensions[["G"]] <- length(groups_names)
  object@dimensions[["D"]] <- sapply(data, nrow)
  object@dimensions[["N"]] <- sapply(groups_names, function(x) sum(groups == x))
  object@dimensions[["K"]] <- 0
  
  # Set features names
  # (warning text matches the names actually generated below: feature_1_v1, feature_2_v1, ...)
  for (m in seq_along(data)) {
    if (is.null(rownames(data[[m]]))) {
      warning(sprintf("Feature names are not specified for view %d, using default: feature_1_v%d, feature_2_v%d...", m, m, m))
      for (g in seq_along(object@data[[m]])) {
        rownames(object@data[[m]][[g]]) <- paste0("feature_", seq_len(nrow(object@data[[m]][[g]])), "_v", m)
      }
    }
  }
  
  # Set samples names
  # (warning text matches the names actually generated below: sample_1_g1, sample_2_g1, ...)
  for (g in seq_len(object@dimensions[["G"]])) {
    if (is.null(colnames(object@data[[1]][[g]]))) {
      warning(sprintf("Sample names for group %d are not specified, using default: sample_1_g%d, sample_2_g%d,...", g, g, g))
      for (m in seq_len(object@dimensions[["M"]])) {
        colnames(object@data[[m]][[g]]) <- paste0("sample_", seq_len(ncol(object@data[[m]][[g]])), "_g", g)
      }
    }
  }
  
  # Set view names
  views_names(object) <- views_names
  
  # Set samples group names
  groups_names(object) <- groups_names

  # Create sample metadata
  object <- .create_samples_metadata(object)

  # Create features metadata
  object <- .create_features_metadata(object)

  # Rename duplicated features
  object <- .rename_duplicated_features(object)

  # Do quality control
  object <- .quality_control(object)

  return(object)
}


# (Hidden) helper: split each view matrix (features x samples) into a named
# list of per-group sub-matrices, keyed by group label.
.split_data_into_groups <- function(data, groups) {
  # Column indices belonging to each group label; wrapping in factor() with
  # exclude = character(0) keeps NA labels as a group instead of dropping them
  idx_by_group <- split(seq_along(groups), factor(groups, exclude = character(0)))
  lapply(data, function(view_mat) {
    lapply(idx_by_group, function(cols) view_mat[, cols, drop = FALSE])
  })
}

# (Hidden) function to split data in Seurat object into a list of matrices.
# Extracts one assay layer (features x cells), optionally subsets features,
# and returns per-group matrices keyed by group label.
.split_seurat_into_groups <- function(seurat, groups, assay = "RNA", layer = "data", features = NULL) {
  data <- SeuratObject::GetAssayData(object = seurat, assay = assay, layer = layer)
  # Short-circuit || (scalar condition): skip the dim() check entirely when data is NULL
  if (is.null(data) || any(dim(data) == 0)) {
    stop(paste("No data present in the layer",layer, "of the assay",assay ,"in the Seurat object."))
  }
  if (!is.null(features)) data <- data[features, , drop=FALSE]
  .split_data_into_groups(list(data), groups)[[1]]
}

# (Hidden) function to split data in a SingleCellExperiment object into a list of matrices
.split_sce_into_groups <- function(sce, groups, assay) {
  # SummarizedExperiment is only suggested, so verify availability at run time
  if (!requireNamespace("SummarizedExperiment", quietly = TRUE))
    stop("Package \"SummarizedExperiment\" is required but is not installed.", call. = FALSE)
  # Extract the requested assay matrix and partition its columns by group
  assay_matrix <- SummarizedExperiment::assay(sce, i = assay)
  .split_data_into_groups(list(assay_matrix), groups)[[1]]
}

# (Hidden) function to fill NAs for missing samples.
# 'mat' is a features x samples matrix; returns a features x length(unique(samp))
# matrix whose columns follow 'samp', with all-NA columns for samples absent from 'mat'.
.subset_augment <- function(mat, samp) {
  samp <- unique(samp)
  mat <- t(mat)
  # match() returns NA row indices for samples missing from mat; indexing with
  # NA yields all-NA rows, so no separate NA-filled allocation is needed
  # (the previous matrix(NA, ...) pre-allocation was dead code, overwritten immediately)
  aug_mat <- mat[match(samp, rownames(mat)), , drop = FALSE]
  rownames(aug_mat) <- samp
  colnames(aug_mat) <- colnames(mat)
  return(t(aug_mat))
}

# (Hidden) helper: convert a data.frame whose first column holds row names
# into a matrix built from the remaining columns.
.df_to_matrix <- function(x) {
  out <- as.matrix(x[, -1])
  rownames(out) <- x[[1]]
  # With a single value column, x[, -1] drops to a vector and as.matrix()
  # loses the column name — restore it explicitly
  if (ncol(out) == 1) {
    colnames(out) <- colnames(x)[2:ncol(x)]
  }
  out
}

# (Hidden) helper: build the samples_metadata slot (sample + group columns)
# from the column names of the first view; any pre-existing metadata is
# re-aligned to that sample order and appended.
.create_samples_metadata <- function(object) {
  # TO-DO: CHECK SAMPLE AND GROUP COLUMN IN PROVIDED METADATA
  sample_names_per_group <- lapply(object@data[[1]], colnames)
  base_meta <- data.frame(
    sample = unname(unlist(sample_names_per_group)),
    group = unlist(lapply(names(sample_names_per_group), function(g) rep(g, length(sample_names_per_group[[g]])))),
    stringsAsFactors = FALSE
  )
  if (.hasSlot(object, "samples_metadata") && (length(object@samples_metadata) > 0)) {
    # Align provided metadata rows to the sample order before binding
    provided <- object@samples_metadata[match(base_meta$sample, rownames(object@samples_metadata)), , drop = FALSE]
    object@samples_metadata <- cbind(base_meta, provided)
  } else {
    object@samples_metadata <- base_meta
  }
  return(object)
}

# (Hidden) helper: build the features_metadata slot (feature + view columns)
# from the row names of each view (first group taken as representative);
# any pre-existing metadata is re-aligned to that feature order and appended.
.create_features_metadata <- function(object) {
  tmp <- data.frame(
    feature = unname(unlist(lapply(object@data, function(x) rownames(x[[1]])))),
    view = unlist(lapply(seq_len(object@dimensions$M), function(x) rep(views_names(object)[[x]], object@dimensions$D[[x]]) )),
    stringsAsFactors = FALSE
  )
  if (.hasSlot(object, "features_metadata") && (length(object@features_metadata) > 0)) {
    # drop = FALSE keeps a data.frame even when the provided metadata has a
    # single column (consistent with .create_samples_metadata)
    object@features_metadata <- cbind(tmp, object@features_metadata[match(tmp$feature, rownames(object@features_metadata)), , drop = FALSE])
  } else {
    object@features_metadata <- tmp
  }
  return(object)
}

# (Hidden) helper: disambiguate feature names that appear in more than one
# view by appending the view name, both in the data matrices and in the
# features metadata.
.rename_duplicated_features <- function(object) {
  # Collect feature names across all views and keep those seen more than once
  all_features <- unname(unlist(lapply(object@data, function(x) rownames(x[[1]]))))
  duplicated_names <- unique(all_features[duplicated(all_features)])
  if (length(duplicated_names)>0) 
    warning("There are duplicated features names across different views. We will add the suffix *_view* only for those features 
            Example: if you have both TP53 in mRNA and mutation data it will be renamed to TP53_mRNA, TP53_mutation")
  # Append the view name to every duplicated feature in every data matrix
  for (view in names(object@data)) {
    for (group in names(object@data[[view]])) {
      hits <- which(rownames(object@data[[view]][[group]]) %in% duplicated_names)
      if (length(hits) > 0) {
        rownames(object@data[[view]][[group]])[hits] <-
          paste(rownames(object@data[[view]][[group]])[hits], view, sep = "_")
      }
    }
  }
  
  # Apply the same renaming to the features metadata
  is_dup <- object@features_metadata[["feature"]] %in% duplicated_names
  object@features_metadata[is_dup, "feature"] <- paste(object@features_metadata[is_dup, "feature"], object@features_metadata[is_dup, "view"], sep = "_")
  return(object)
}


================================================
FILE: R/dimensionality_reduction.R
================================================

##################################################################
## Functions to do dimensionality reduction on the MOFA factors ##
##################################################################

#' @title Run t-SNE on the MOFA factors
#' @name run_tsne
#' @param object a trained \code{\link{MOFA}} object.
#' @param factors character vector with the factor names, or numeric vector with the indices of the factors to use, or "all" to use all factors (default).
#' @param groups character vector with the groups names, or numeric vector with the indices of the groups of samples to use, or "all" to use all groups (default).
#' @param ... arguments passed to \code{\link{Rtsne}}
#' @details This function calls \code{\link[Rtsne]{Rtsne}} to calculate a TSNE representation from the MOFA factors.
#' Subsequently, you can plot the TSNE representation with \code{\link{plot_dimred}} or fetch the coordinates using \code{plot_dimred(..., method="TSNE", return_data=TRUE)}. 
#' Remember to use set.seed before the function call to get reproducible results. 
#' @return Returns a \code{\link{MOFA}} object with the \code{MOFAobject@dim_red} slot filled with the t-SNE output
#' @importFrom Rtsne Rtsne
#' @export
#' @examples
#' # Using an existing trained model on simulated data
#' file <- system.file("extdata", "model.hdf5", package = "MOFA2")
#' model <- load_model(file)
#' 
#' # Run
#' \dontrun{ model <- run_tsne(model, perplexity = 15) }
#' 
#' # Plot
#' \dontrun{ model <- plot_dimred(model, method="TSNE") }
#' 
#' # Fetch data
#' \dontrun{ tsne.df <- plot_dimred(model, method="TSNE", return_data=TRUE) }
#' 
run_tsne <- function(object, factors = "all", groups = "all", ...) {
  
  # Sanity checks
  if (!is(object, "MOFA")) stop("'object' has to be an instance of MOFA")
  
  # Fetch the factor matrix and stack all groups into one matrix
  Z <- do.call(rbind, get_factors(object, factors = factors, groups = groups))
  
  # Rtsne cannot handle missing values, so impute them with zeros
  Z[is.na(Z)] <- 0
  
  # Compute the t-SNE embedding on the factor space
  tsne_embedding <- Rtsne(Z, check_duplicates = FALSE, pca = FALSE, ...)

  # Store the embedding with sample names and TSNE1..TSNEk column labels
  coords <- tsne_embedding$Y
  object@dim_red$TSNE <- data.frame(rownames(Z), coords)
  colnames(object@dim_red$TSNE) <- c("sample", paste0("TSNE", seq_len(ncol(coords))))
  
  return(object)
}



#' @title Run UMAP on the MOFA factors
#' @name run_umap
#' @param object a trained \code{\link{MOFA}} object.
#' @param factors character vector with the factor names, or numeric vector with the indices of the factors to use, or "all" to use all factors (default).
#' @param groups character vector with the groups names, or numeric vector with the indices of the groups of samples to use, or "all" to use all groups (default).
#' @param n_neighbors number of neighbouring points used in local approximations of manifold structure. Larger values will result in more global structure being preserved at the loss of detailed local structure. In general this parameter should often be in the range 5 to 50.
#' @param min_dist  This controls how tightly the embedding is allowed compress points together. Larger values ensure embedded points are more evenly distributed, while smaller values allow the algorithm to optimise more accurately with regard to local structure. Sensible values are in the range 0.01 to 0.5
#' @param metric choice of metric used to measure distance in the input space
#' @param ... arguments passed to \code{\link[uwot]{umap}}
#' @details This function calls \code{\link[uwot]{umap}} to calculate a UMAP representation from the MOFA factors
#' For details on the hyperparameters of UMAP see the documentation of \code{\link[uwot]{umap}}.
#' Subsequently, you can plot the UMAP representation with \code{\link{plot_dimred}} or fetch the coordinates using \code{plot_dimred(..., method="UMAP", return_data=TRUE)}. 
#' Remember to use set.seed before the function call to get reproducible results. 
#' @return Returns a \code{\link{MOFA}} object with the \code{MOFAobject@dim_red} slot filled with the UMAP output
#' @importFrom uwot umap
#' @export
#' @examples
#' # Using an existing trained model on simulated data
#' file <- system.file("extdata", "model.hdf5", package = "MOFA2")
#' model <- load_model(file)
#' 
#' # Change hyperparameters passed to umap
#' \dontrun{ model <- run_umap(model, min_dist = 0.01, n_neighbors = 10) }

#' # Plot
#' \dontrun{ model <- plot_dimred(model, method="UMAP") }
#' 
#' # Fetch data
#' \dontrun{ umap.df <- plot_dimred(model, method="UMAP", return_data=TRUE) }
#' 
run_umap <- function(object, factors = "all", groups = "all", n_neighbors = 30, min_dist = 0.3, metric = "cosine", ...) {
  
  # Sanity checks
  if (!is(object, "MOFA")) stop("'object' has to be an instance of MOFA")
  
  # Fetch the factor matrix and stack all groups into one matrix
  Z <- do.call(rbind, get_factors(object, factors = factors, groups = groups))
  
  # UMAP cannot handle missing values, so impute them with zeros
  Z[is.na(Z)] <- 0
  
  # Compute the UMAP embedding on the factor space
  umap_embedding <- umap(Z, n_neighbors = n_neighbors, min_dist = min_dist, metric = metric, ...)

  # Store the embedding with sample names and UMAP1..UMAPk column labels
  object@dim_red$UMAP <- data.frame(rownames(Z), umap_embedding)
  colnames(object@dim_red$UMAP) <- c("sample", paste0("UMAP", seq_len(ncol(umap_embedding))))
  
  return(object)
}



#' @title Plot dimensionality reduction based on MOFA factors
#' @name plot_dimred
#' @param object a trained \code{\link{MOFA}} object.
#' @param method string indicating which method has been used for non-linear dimensionality reduction (either 'umap' or 'tsne')
#' @param groups character vector with the groups names, or numeric vector with the indices of the groups of samples to use, or "all" to use samples from all groups.
#' @param show_missing logical indicating whether to include samples for which \code{shape_by} or \code{color_by} is missing
#' @param color_by specifies groups or values used to color the samples. This can be either:
#' (1) a character giving the name of a feature present in the training data.
#' (2) a character giving the name of a column present in the sample metadata.
#' (3) a vector of the same length as the number of samples specifying discrete groups or continuous numeric values.
#' @param shape_by specifies groups or values used to shape the samples. This can be either:
#' (1) a character giving the name of a feature present in the training data, 
#' (2) a character giving the name of a column present in the sample metadata.
#' (3) a vector of the same length as the number of samples specifying discrete groups.
#' @param color_name name for color legend.
#' @param shape_name name for shape legend.
#' @param label logical indicating whether to label the medians of the clusters. Only if color_by is specified
#' @param dot_size numeric indicating dot size.
#' @param stroke numeric indicating the stroke size (the black border around the dots, default is NULL, inferred automatically).
#' @param alpha_missing numeric indicating dot transparency of missing data.
#' @param legend logical indicating whether to add legend.
#' @param return_data logical indicating whether to return the long data frame to plot instead of plotting
#' @param rasterize logical indicating whether to rasterize plot using \code{\link[ggrastr]{geom_point_rast}}
#' @param ... extra arguments passed to \code{\link{run_umap}} or \code{\link{run_tsne}}.
#' @details This function plots dimensionality reduction projections that are stored in the \code{dim_red} slot.
#' Typically this contains UMAP or t-SNE projections computed using \code{\link{run_tsne}} or \code{\link{run_umap}}, respectively.
#' @return Returns a \code{ggplot2} object or a long data.frame (if return_data is TRUE)
#' @import ggplot2
#' @importFrom dplyr filter
#' @importFrom stats complete.cases
#' @importFrom tidyr spread gather
#' @importFrom magrittr %>% set_colnames
#' @importFrom ggrepel geom_text_repel
#' @export
#' @examples
#' # Using an existing trained model on simulated data
#' file <- system.file("extdata", "model.hdf5", package = "MOFA2")
#' model <- load_model(file)
#' 
#' # Run UMAP
#' model <- run_umap(model)
#' 
#' # Plot UMAP
#' plot_dimred(model, method = "UMAP")
#' 
#' # Plot UMAP, colour by Factor 1 values
#' plot_dimred(model, method = "UMAP", color_by = "Factor1")
#' 
#' # Plot UMAP, colour by the values of a specific feature
#' plot_dimred(model, method = "UMAP", color_by = "feature_0_view_0")
#' 
plot_dimred <- function(object, method = c("UMAP", "TSNE"), groups = "all", show_missing = TRUE,
                        color_by = NULL, shape_by = NULL, color_name = NULL, shape_name = NULL, label = FALSE,
                        dot_size = 1.5, stroke = NULL, alpha_missing = 1, legend = TRUE, rasterize = FALSE, return_data = FALSE, ...) {
  
  # Sanity checks
  if (!is(object, "MOFA")) stop("'object' has to be an instance of MOFA")

  # Normalise 'method' to a single value up front: with the default
  # c("UMAP", "TSNE") the length-2 vector would otherwise break the scalar
  # condition below (&& requires a length-1 condition in modern R)
  method <- match.arg(method)

  # If UMAP or TSNE is requested but was not computed, compute the requested embedding
  if (!.hasSlot(object, "dim_red") || !(method %in% names(object@dim_red))) {
    message(paste0(method, " embedding was not computed. Running run_", tolower(method), "()..."))
    if (method == "UMAP") {
      object <- run_umap(object, ...)
    } else if (method == "TSNE") {
      object <- run_tsne(object, ...)
    }
  }
  
  # make sure the slot for the requested method exists
  method <- match.arg(method, names(object@dim_red))  
  
  # Plotting multiple features: recurse once per colour variable and assemble a grid
  if (length(color_by)>1) {
    .args <- as.list(match.call()[-1])
    plist <- lapply(color_by, function(i) {
      .args[["color_by"]] <- i
      do.call(plot_dimred, .args)
    })
    p <- cowplot::plot_grid(plotlist=plist)
    return(p)
  }
  
  # Remember color_name and shape_name if not provided
  if (!is.null(color_by) && (length(color_by) == 1) && is.null(color_name))
    color_name <- color_by
  if (!is.null(shape_by) && (length(shape_by) == 1) && is.null(shape_name))
    shape_name <- shape_by
  
  # Fetch latent manifold and reshape to long format (sample, latent_dimension, value)
  Z <- object@dim_red[[method]]
  latent_dimensions_names <- colnames(Z)[-1]
  Z <- gather(Z, -sample, key="latent_dimension", value="value")
  
  # Subset groups
  groups <- .check_and_get_groups(object, groups)
  Z <- Z[Z$sample%in%unlist(samples_names(object)[groups]),]
  
  # Set color and shape
  color_by <- .set_colorby(object, color_by)
  shape_by <- .set_shapeby(object, shape_by)
  
  # Merge factor values with color and shape information
  df <- merge(Z, color_by, by="sample")
  df <- merge(df, shape_by, by="sample")
  df$shape_by <- as.character(df$shape_by)
  
  # Remove missing values
  if(!show_missing) df <- filter(df, !is.na(color_by) & !is.na(shape_by))
  df$observed <- as.factor(!is.na(df$color_by))
  
  # spread over latent dimensions; the two embedding columns become "x" and "y"
  df <- spread(df, key="latent_dimension", value="value")
  df <- set_colnames(df, c(colnames(df)[seq_len(4)], "x", "y"))
  
  # Return data if requested instead of plotting
  if (return_data) return(df)

  # Set stroke: thin dot borders disappear for large sample numbers
  if (is.null(stroke)) if (length(unique(df$sample))<1000) { stroke <- 0.5 } else { stroke <- 0 }
  
  # Generate plot
  p <- ggplot(df, aes(x = .data$x, y = .data$y)) + 
    labs(x = latent_dimensions_names[1], y = latent_dimensions_names[2]) +
    theme_classic() +
    theme(
      axis.text = element_blank(), 
      axis.title = element_blank(), 
      axis.line = element_line(color = "black", linewidth = 0.5), 
      axis.ticks = element_blank()
    )
  
  # Add dots  
  if (rasterize) {
    message("for rasterizing the plot we use ggrastr::geom_point_rast()")
    p <- p + ggrastr::geom_point_rast(aes(fill = .data$color_by, shape = .data$shape_by, alpha = .data$observed), size = dot_size, stroke = stroke)
  } else {
    p <- p + geom_point(aes(fill = .data$color_by, shape = .data$shape_by, alpha = .data$observed), size = dot_size, stroke = stroke)
    
  }      
  
  # Add legend for alpha (samples with missing colour values are dimmed)
  if (length(unique(df$observed))>1) { 
    p <- p + scale_alpha_manual(values = c("TRUE"=1.0, "FALSE"=alpha_missing))
  } else { 
    p <- p + scale_alpha_manual(values = 1.0)
  }
  p <- p + guides(alpha="none")
    
  # Label clusters at the median position of each colour group
  if (label && length(unique(df$color_by)) > 1 && length(unique(df$color_by))<50) {
    groups <- unique(df$color_by)
    labels.loc <- lapply(
      X = groups,
      FUN = function(i) {
        data.use <- df[df[,"color_by"] == i, , drop = FALSE]
        data.medians <- as.data.frame(x = t(x = apply(X = data.use[, c("x","y"), drop = FALSE], MARGIN = 2, FUN = median, na.rm = TRUE)))
        data.medians[, "color_by"] <- i
        return(data.medians)
      }
    ) %>% do.call("rbind",.)
    p <- p + geom_text_repel(aes(label=.data$color_by), data=labels.loc)
  }
  
  
  # Add legend
  p <- .add_legend(p, df, legend, color_name, shape_name)
  
  return(p)
}


================================================
FILE: R/enrichment.R
================================================
##########################################################
## Functions to perform Feature Set Enrichment Analysis ##
##########################################################

#' @title Run feature set Enrichment Analysis
#' @name run_enrichment 
#' @description Method to perform feature set enrichment analysis. Here we use a slightly modified version of the \code{\link[PCGSE]{pcgse}} function.
#' @param object a \code{\link{MOFA}} object.
#' @param view a character with the view name, or a numeric vector with the index of the view to use.
#' @param feature.sets data structure that holds feature set membership information. 
#' Must be a binary membership matrix (rows are feature sets and columns are features). See details below for some pre-built gene set matrices.
#' @param factors character vector with the factor names, or numeric vector with the index of the factors for which to perform the enrichment.
#' @param set.statistic the set statistic computed from the feature statistics. Must be one of the following: "mean.diff" (default) or "rank.sum".
#' @param statistical.test the statistical test used to compute the significance of the feature set statistics under a competitive null hypothesis.
#' Must be one of the following: "parametric" (default), "cor.adj.parametric", "permutation".
#' @param sign use only "positive" or "negative" weights. Default is "all".
#' @param min.size Minimum size of a feature set (default is 10).
#' @param nperm number of permutations. Only relevant if statistical.test is set to "permutation". Default is 1000
#' @param p.adj.method Method to adjust p-values factor-wise for multiple testing. Can be any method in p.adjust.methods(). Default uses Benjamini-Hochberg procedure.
#' @param alpha FDR threshold to generate lists of significant pathways. Default is 0.1
#' @param verbose boolean indicating whether to print messages on progress 
#' @details 
#'  The aim of this function is to relate each factor to pre-defined biological pathways by performing a gene set enrichment analysis on the feature weights. \cr
#'  This function is particularly useful when a factor is difficult to characterise based only on the genes with the highest weight. \cr
#'  We provide a few pre-built gene set matrices in the MOFAdata package. See \code{https://github.com/bioFAM/MOFAdata} for details. \cr
#'  The function we implemented is based on the \code{\link[PCGSE]{pcgse}} function with some modifications. 
#'  Please read this paper https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4543476 for details on the math.
#' @return a list with five elements: 
#' \item{\strong{pval}:}{ matrices with nominal p-values. }
#' \item{\strong{pval.adj}:}{ matrices with FDR-adjusted p-values. }
#' \item{\strong{feature.statistics}:}{ matrices with the local (feature-wise) statistics.  }
#' \item{\strong{set.statistics}:}{ matrices with the global (gene set-wise) statistics.  }
#' \item{\strong{sigPathways}}{ list with significant pathways per factor. }
#' @importFrom stats p.adjust var p.adjust.methods
#' @export

run_enrichment <- function(object, view, feature.sets, factors = "all",
                           set.statistic = c("mean.diff", "rank.sum"),
                           statistical.test = c("parametric", "cor.adj.parametric", "permutation"), sign = c("all","positive","negative"),
                           min.size = 10, nperm = 1000, p.adj.method = "BH", alpha = 0.1, verbose = TRUE) {
  
  # Quality control
  if (!is(object, "MOFA")) stop("'object' has to be an instance of MOFA")
  # (fix: the old message claimed lists were accepted, but only a binary matrix passes this check)
  if (!(is(feature.sets, "matrix") & all(feature.sets %in% c(0,1)))) stop("feature.sets has to be a binary matrix (feature sets in rows, features in columns).")
  
  # Define views
  view  <- .check_and_get_views(object, view)
  
  # Define factors
  factors  <- .check_and_get_factors(object, factors)
  
  # Parse inputs
  sign <- match.arg(sign)
  set.statistic <- match.arg(set.statistic)
  statistical.test <- match.arg(statistical.test)
  
  # Collect observed data (samples in rows after transposition)
  data <- get_data(object, views = view, as.data.frame = FALSE)[[1]]
  if (is(data, "list")) data <- Reduce(cbind, data) # concatenate groups
  data <- t(data)
  
  # Collect relevant expectations (scaled weights W and factor values Z)
  W <- get_weights(object, views=view, factors=factors, scale = TRUE)[[1]]
  Z <- get_factors(object, factors=factors)
  if (is(Z, "list")) Z <- Reduce(rbind, Z)
  stopifnot(all(rownames(data) == rownames(Z)))
  
  # Remove features with no variance
  idx <- apply(data, 2, function(x) var(x, na.rm=TRUE)) == 0
  if (sum(idx) >= 1) {
    warning(sprintf("%d features were removed because they had no variance in the data.\n", sum(idx)))
    # (fix: drop=FALSE keeps matrix structure in the degenerate one-feature case)
    data <- data[, !idx, drop=FALSE]
    W <- W[!idx, , drop=FALSE]
  }
  
  # Keep only features present in both the model and the feature set annotation
  features <- intersect(colnames(data), colnames(feature.sets))
  if (length(features) == 0) stop("Feature names in feature.sets do not match feature names in model.")
  if (verbose) {
    message(sprintf("Intersecting features names in the model and the gene set annotation results in a total of %d features.", length(features)))
  }
  data <- data[, features, drop=FALSE]
  W <- W[features, , drop=FALSE]
  feature.sets <- feature.sets[, features, drop=FALSE]
  
  # Filter feature sets with small number of features
  feature.sets <- feature.sets[rowSums(feature.sets) >= min.size, , drop=FALSE]
  
  # Subset weights by sign (negative weights are zeroed out and flipped so both
  # directions yield non-negative feature statistics)
  if (sign == "positive") {
    W[W < 0] <- 0
  } else if (sign == "negative") {
    W[W > 0] <- 0
    W <- abs(W)
  }
  
  # Print options
  if (verbose) {
    message("\nRunning feature set Enrichment Analysis with the following options...\n",
            sprintf("View: %s \n", view),
            sprintf("Number of feature sets: %d \n", nrow(feature.sets)),
            sprintf("Set statistic: %s \n", set.statistic),
            sprintf("Statistical test: %s \n", statistical.test)
    )
    if (sign %in% c("positive","negative"))
      message(sprintf("Subsetting weights with %s sign", sign))
    if (statistical.test == "permutation") {
      message(sprintf("Number of permutations: %d", nperm))
    }
    message("\n")
  }
  
  # Non-parametric permutation test
  if (statistical.test == "permutation") {
    
    # (fix: this warning used to fire for parametric tests too, where nperm is irrelevant)
    if (nperm < 100)
      warning("A large number of permutations (at least 1000) is required for the permutation approach!\n")
    
    null_dist_tmp <- lapply(seq_len(nperm), function(i) {
      print(sprintf("Running permutation %d/%d...", i, nperm))
      perm <- sample(ncol(data))
      
      # Permute rows of the weight matrix to obtain a null distribution
      W_null <- W[perm, ]
      rownames(W_null) <- rownames(W)
      colnames(W_null) <- colnames(W)
      
      # Permute columns of the data matrix correspondingly (only matters for cor.adjusted test)
      data_null <- data[, perm]
      rownames(data_null) <- rownames(data)
      
      # Compute null (or background) statistic
      s.background <- .pcgse(
        data = data_null, 
        prcomp.output = list(rotation=W_null, x=Z),
        pc.indexes = seq_along(factors), 
        feature.sets = feature.sets,
        set.statistic = set.statistic,
        set.test = "parametric")$statistics
      return(abs(s.background))
    })
    
    # Compute foreground statistics
    results <- .pcgse(
      data = data, 
      prcomp.output = list(rotation=W, x=Z),
      pc.indexes = seq_along(factors), 
      feature.sets = feature.sets,
      set.statistic = set.statistic,
      set.test = "parametric")
    s.foreground <- results$statistics
    
    # Empirical p-value: fraction of permutations where the null statistic
    # exceeds the observed statistic (per factor and feature set)
    xx <- array(unlist(null_dist_tmp), dim = c(nrow(null_dist_tmp[[1]]), ncol(null_dist_tmp[[1]]), length(null_dist_tmp)))
    ll <- lapply(seq_len(nperm), function(i) xx[,,i] > abs(s.foreground))
    results$p.values <- Reduce("+", ll) / nperm
    
  # Parametric test
  } else {
    results <- .pcgse(
      data = data,
      prcomp.output = list(rotation=W, x=Z),
      pc.indexes = seq_along(factors),
      feature.sets = feature.sets,
      set.statistic = set.statistic,
      set.test = statistical.test
    )
  }
  
  # Parse results
  pathways <- rownames(feature.sets)
  colnames(results$p.values) <- colnames(results$statistics) <- colnames(results$feature.statistics) <- factors
  rownames(results$p.values) <- rownames(results$statistics) <- pathways
  rownames(results$feature.statistics) <- colnames(data)
  
  # Adjust p-values for multiple testing (factor-wise)
  if (!p.adj.method %in% p.adjust.methods) 
    stop("p.adj.method needs to be an element of p.adjust.methods")
  adj.p.values <- apply(results$p.values, 2, function(lfw) p.adjust(lfw, method = p.adj.method))
  
  # If we specify a direction, we are only interested in overrepresented pathways
  # in the selected direction
  # (fix: this block was accidentally duplicated; it is idempotent, so applying it
  #  once is sufficient)
  if (sign %in% c("positive","negative")) {
    results$p.values[results$statistics < 0] <- 1.0
    adj.p.values[results$statistics < 0] <- 1.0
    results$statistics[results$statistics < 0] <- 0
  }
  
  # Obtain list of significant pathways per factor
  sigPathways <- lapply(factors, function(j) rownames(adj.p.values)[adj.p.values[,j] <= alpha])
  
  # Prepare output
  output <- list(
    feature.sets = feature.sets, 
    pval = results$p.values, 
    pval.adj = adj.p.values, 
    feature.statistics = results$feature.statistics,
    set.statistics = results$statistics,
    sigPathways = sigPathways
  )
  return(output)
}


########################
## Plotting functions ##
########################


#' @title Plot output of gene set Enrichment Analysis
#' @name plot_enrichment
#' @description Method to plot the results of the gene set Enrichment Analysis
#' @param enrichment.results output of \link{run_enrichment} function
#' @param factor a string with the factor name or an integer with the factor index
#' @param alpha p.value threshold to filter out gene sets
#' @param max.pathways maximum number of enriched pathways to display
#' @param text_size text size
#' @param dot_size dot size
#' @details it requires \code{\link{run_enrichment}} to be run beforehand.
#' @return a \code{ggplot2} object
#' @import ggplot2
#' @importFrom utils head
#' @export
plot_enrichment <- function(enrichment.results, factor, alpha = 0.1, max.pathways = 25,
                            text_size = 1.0, dot_size = 5.0) {
  
  # Sanity checks
  stopifnot(is.numeric(alpha)) 
  stopifnot(length(factor)==1) 
  if (is.numeric(factor)) factor <- colnames(enrichment.results$pval.adj)[factor]
  if (!factor %in% colnames(enrichment.results$pval)) 
    stop(paste0("No gene set enrichment calculated for factor ", factor))
  
  # Use FDR-adjusted p-values
  p.values <- enrichment.results$pval.adj
  
  # Build plotting data
  # (fix: the column is now consistently named 'pvalue'; it used to be created as
  #  'pvalues' and then accessed as 'tmp$pvalue', silently relying on R's partial
  #  matching of '$')
  tmp <- data.frame(
    pvalue = p.values[, factor, drop=TRUE], 
    pathway = rownames(p.values)
  )
  
  # Filter out non-significant pathways
  tmp <- tmp[tmp$pvalue <= alpha, , drop=FALSE]
  if (nrow(tmp)==0) stop("No significant pathways at the specified alpha threshold")
  
  # If there are too many pathways enriched, just keep the 'max.pathways' most significant
  if (nrow(tmp) > max.pathways) tmp <- head(tmp[order(tmp$pvalue),], n=max.pathways)
  
  # Convert p-values to -log10 scale (small offset avoids -log10(0))
  tmp$logp <- -log10(tmp$pvalue + 1e-100)
  
  # Order pathways according to significance
  # (fix: removed the redundant nested assignment 'tmp$pathway <- rownames(tmp)',
  #  which re-assigned a value the column already held)
  tmp$pathway <- factor(tmp$pathway, levels = tmp$pathway[order(tmp$pvalue, decreasing = TRUE)])
  tmp$start <- 0
  
  # Lollipop plot: one dot per pathway at its -log10 adjusted p-value, with a
  # dashed line marking the significance threshold
  p <- ggplot(tmp, aes(x=.data$pathway, y=.data$logp)) +
    geom_point(size=dot_size) +
    geom_hline(yintercept=-log10(alpha), linetype="longdash") +
    scale_color_manual(values=c("black","red")) +
    geom_segment(aes(xend=.data$pathway, yend=.data$start)) +
    ylab("-log pvalue") +
    coord_flip() +
    theme(
      axis.text.y = element_text(size=rel(text_size), hjust=1, color='black'),
      axis.text.x = element_text(size=rel(1.2), vjust=0.5, color='black'),
      axis.title.y=element_blank(),
      legend.position='none',
      panel.background = element_blank()
    )
  
  return(p)
}

#' @title Heatmap of Feature Set Enrichment Analysis results
#' @name plot_enrichment_heatmap
#' @description This method generates a heatmap with the adjusted p.values that
#'  result from the feature set enrichment analysis. Rows are feature sets and columns are factors.
#' @param enrichment.results output of \link{run_enrichment} function
#' @param alpha FDR threshold used to filter out non-significant feature sets, which are
#'  not represented in the heatmap. Default is 0.10.
#' @param cap cap p-values below this threshold
#' @param log_scale logical indicating whether to plot the -log of the p.values.
#' @param ... extra arguments to be passed to the \link{pheatmap} function
#' @return produces a heatmap
#' @importFrom pheatmap pheatmap
#' @importFrom grDevices colorRampPalette
#' @export
plot_enrichment_heatmap <- function(enrichment.results, alpha = 0.1, cap = 1e-50, log_scale = TRUE, ...) {
  
  # Use FDR-adjusted p-values
  p.values <- enrichment.results$pval.adj
  
  # Remove factors that are entirely NA
  # (fix: drop=FALSE keeps the matrix structure when only one factor remains;
  #  without it the subset collapses to a vector and pheatmap fails)
  p.values <- p.values[, colMeans(is.na(p.values)) < 1, drop = FALSE]
  
  # Cap very small p-values so the colour scale is not dominated by extremes
  p.values[p.values < cap] <- cap
  
  # Optionally transform to -log10 scale; the alpha threshold is transformed
  # accordingly and the colour ramp is flipped so "more significant" is always red
  if (log_scale) {
    p.values <- -log10(p.values + 1e-50)
    alpha <- -log10(alpha)
    col <- colorRampPalette(c("lightgrey","red"))(n=100)
  } else {
    col <- colorRampPalette(c("red","lightgrey"))(n=100)
  }
  
  # Generate heatmap (rows = feature sets, columns = factors)
  pheatmap(p.values, color = col, cluster_cols = FALSE, show_rownames = FALSE, ...)
}


#' @title Plot detailed output of the Feature Set Enrichment Analysis
#' @name plot_enrichment_detailed
#' @description Method to plot a detailed output of the Feature Set Enrichment Analysis (FSEA). \cr
#' Each row corresponds to a significant pathway, sorted by statistical significance, and each dot corresponds to a gene. \cr
#' For each pathway, we display the top genes of the pathway sorted by the corresponding feature statistic (by default, the absolute value of the weight)
#' The top genes with the highest statistic (max.genes argument) are displayed and labelled in black. The remaining genes are colored in grey.
#' @param enrichment.results output of \link{run_enrichment} function
#' @param factor string with factor name or numeric with factor index
#' @param alpha p.value threshold to filter out feature sets
#' @param max.pathways maximum number of enriched pathways to display
#' @param max.genes maximum number of genes to display, per pathway
#' @param text_size size of the text to label the top genes
#' @return a \code{ggplot2} object
#' @import ggplot2
#' @importFrom reshape2 melt
#' @importFrom dplyr top_n
#' @importFrom ggrepel geom_text_repel
#' @export
plot_enrichment_detailed <- function(enrichment.results, factor, 
                                     alpha = 0.1, max.genes = 5, max.pathways = 10, text_size = 3) {
  
  # Sanity checks
  stopifnot(is.list(enrichment.results))
  stopifnot(length(factor)==1) 
  if (!is.numeric(factor)) {
    if(!factor %in% colnames(enrichment.results$pval)) 
      stop(paste0("No feature set enrichment calculated for ", factor))
  }
  
  # Fetch and prepare data  
  
  # Per-feature statistics for the selected factor (long format, one row per feature)
  foo <- reshape2::melt(enrichment.results$feature.statistics[,factor], na.rm=TRUE, value.name="feature.statistic")
  foo$feature <- rownames(foo)
  
  # Pathway membership in long format: one (pathway, feature) row per annotated pair;
  # zeros are set to NA so melt(na.rm=TRUE) drops non-members
  feature.sets <- enrichment.results$feature.sets
  feature.sets[feature.sets==0] <- NA
  bar <- reshape2::melt(feature.sets, na.rm=TRUE)[,c(1,2)]
  colnames(bar) <- c("pathway","feature")
  bar$pathway <- as.character(bar$pathway)
  bar$feature <- as.character(bar$feature)
  
  # FDR-adjusted p-value per pathway for the selected factor
  baz <- reshape2::melt(enrichment.results$pval.adj[,factor], value.name="pvalue", na.rm=TRUE)
  baz$pathway <- rownames(baz)
  
  # Filter out pathways by p-values
  baz <- baz[baz$pvalue<=alpha,,drop=FALSE]
  if(nrow(baz)==0) {
    stop("No significant pathways at the specified alpha threshold. \n
         For an overview use plot_enrichment_heatmap().")
  } else {
    # Keep only the 'max.pathways' most significant pathways
    if (nrow(baz)>max.pathways)
      baz <- head(baz[order(baz$pvalue),],n=max.pathways)
  }
  
  # order pathways according to significance
  baz$pathway <- factor(baz$pathway, levels = baz$pathway[order(baz$pvalue, decreasing = TRUE)])
  
  # Merge feature statistics with pathway membership, then with pathway significance
  foobar <- merge(foo, bar, by="feature")
  tmp <- merge(foobar, baz, by="pathway")
  
  # Select the top N features with the largest feature.statistic (per pathway)
  tmp_filt <- top_n(group_by(tmp, pathway), n=max.genes, abs(feature.statistic))
  
  # Add number of features and p-value per pathway
  pathways <- unique(tmp_filt$pathway)
  
  # Add Ngenes and p-values to the pathway name
  # NOTE(review): 'pathways' is a factor here, so rowSums(...)[pathways] subsets by
  # the factor's integer level codes rather than by name — verify row order matches
  df <- data.frame(pathway=pathways, nfeatures=rowSums(feature.sets,na.rm=TRUE)[pathways])
  df <- merge(df, baz, by="pathway")
  df$pathway_long_name <- sprintf("%s\n (Ngenes = %d) \n (p-val = %0.2g)",df$pathway, df$nfeatures, df$pvalue)
  tmp <- merge(tmp, df[,c("pathway","pathway_long_name")], by="pathway")
  tmp_filt <- merge(tmp_filt, df[,c("pathway","pathway_long_name")], by="pathway")
  
  # sort pathways by p-value
  order_pathways <- df$pathway_long_name[order(df$pvalue,decreasing=TRUE) ]
  tmp$pathway_long_name <- factor(tmp$pathway_long_name, levels=order_pathways)
  tmp_filt$pathway_long_name <- factor(tmp_filt$pathway_long_name, levels=order_pathways)
  
  # Lollipop-style plot: grey dots for all genes in each pathway, black labelled
  # dots (ggrepel) for the top 'max.genes' per pathway
  p <- ggplot(tmp, aes(x=.data[["pathway_long_name"]], y=.data[["feature.statistic"]])) +
    geom_text_repel(aes(x=.data[["pathway_long_name"]], y=.data[["feature.statistic"]], label=.data$feature), size=text_size, color="black", force=1, data=tmp_filt) +
    geom_point(size=0.5, color="lightgrey") +
    geom_point(aes(x=.data[["pathway_long_name"]], y=.data[["feature.statistic"]]), size=1, color="black", data=tmp_filt) +
    labs(x="", y="Weight (scaled)", title="") +
    coord_flip() +
    theme(
      axis.line = element_line(color="black"),
      axis.text.y = element_text(size=rel(0.75), hjust=1, color='black'),
      axis.text.x = element_text(size=rel(1.0), vjust=0.5, color='black'),
      axis.title.y=element_blank(),
      legend.position='none',
      panel.background = element_blank()
    )
  
  return(p)
}



#############################################################
## Internal methods for enrichment analysis (not exported) ##
#############################################################

# Modified version of the PCGSE module: computes the feature-level statistics for
# the requested factors and delegates the set-level test to the chosen backend.
.pcgse <- function(data, prcomp.output, feature.sets, pc.indexes, 
                  set.statistic, set.test) {
  
  # Validate inputs
  if (is.null(feature.sets))
    stop("'feature.sets' must be specified!")
  if (!(set.statistic %in% c("mean.diff", "rank.sum")))
    stop("set.statistic must be 'mean.diff' or 'rank.sum'")
  if (!(set.test %in% c("parametric", "cor.adj.parametric", "permutation")))
    stop("set.test must be one of 'parametric', 'cor.adj.parametric', 'permutation'")
  
  # A binary annotation matrix is converted into a list of per-set column indexes;
  # anything else is assumed to already be in list form
  set.indexes <- if (is.matrix(feature.sets)) {
    .createVarGroupList(var.groups = feature.sets)
  } else {
    feature.sets
  }
  
  # One column of feature-level statistics per requested factor/PC
  feature.statistics <- matrix(0, nrow = ncol(data), ncol = length(pc.indexes))
  for (k in seq_along(pc.indexes)) {
    feature.statistics[, k] <- .compute_feature_statistics(
      data = data,
      prcomp.output = prcomp.output,
      pc.index = pc.indexes[k]
    )
  }
  
  # Set-level statistics: t-test for "mean.diff", Wilcoxon rank-sum for "rank.sum"
  if (set.test %in% c("parametric", "cor.adj.parametric")) {
    backend <- if (set.statistic == "mean.diff") .pcgse_ttest else .pcgse_wmw
    results <- backend(
      data = data,
      prcomp.output = prcomp.output,
      pc.indexes = pc.indexes,
      set.indexes = set.indexes,
      feature.statistics = feature.statistics,
      cor.adjustment = (set.test == "cor.adj.parametric")
    )
  }
  
  # Attach the feature-level statistics to the result list
  results$feature.statistics <- feature.statistics
  
  return(results)
}




# Convert a binary annotation matrix (feature sets in rows, features in columns)
# into a named list where each element holds the column indexes of the features
# belonging to that set.
.createVarGroupList <- function(var.groups) {
  out <- lapply(seq_len(nrow(var.groups)), function(r) which(var.groups[r, ] == 1))
  names(out) <- rownames(var.groups)
  return (out)
}

# Computes the feature-level statistics for one factor/PC: the absolute value of
# each feature's loading (weight) in the given column of the rotation matrix.
# 'data' is unused; the parameter is kept for interface compatibility with .pcgse.
.compute_feature_statistics <- function(data, prcomp.output, pc.index) {
  # abs() is vectorized and preserves names, so the element-wise
  # vapply(x, abs, numeric(1)) it replaces was an unnecessary loop
  return (abs(prcomp.output$rotation[, pc.index]))
}

# Compute set-level enrichment via a two-sample t-test on the feature-level
# statistics: for each feature set and each factor, the mean statistic inside the
# set is compared against the mean outside the set.
#
#   data               samples x features matrix; used for ncol() and, when
#                      cor.adjustment=TRUE, for the within-set correlations
#   prcomp.output      unused here; kept for interface compatibility with .pcgse
#   pc.indexes         indexes of the factors/PCs to test
#   set.indexes        named list of feature-index vectors, one entry per set
#   feature.statistics features x length(pc.indexes) matrix of feature statistics
#   cor.adjustment     if TRUE, inflate the variance by a VIF computed from the
#                      mean pairwise within-set correlation (CAMERA, Wu et al.)
#
# Returns a list with 'p.values' and 'statistics' matrices (sets x factors).
#' @importFrom stats pt var
.pcgse_ttest <- function(data, prcomp.output, pc.indexes,
                         set.indexes, feature.statistics, cor.adjustment) {
  
  num.feature.sets <- length(set.indexes)
  
  # Two-sided p-values, one row per feature set and one column per factor
  p.values <- matrix(0, nrow=num.feature.sets, ncol=length(pc.indexes))  
  rownames(p.values) <- names(set.indexes)
  
  # t-statistics (fix: initialise as numeric NA instead of logical TRUE; every cell
  # is overwritten below, but this avoids relying on implicit logical->numeric coercion)
  set.statistics <- matrix(NA_real_, nrow=num.feature.sets, ncol=length(pc.indexes))
  rownames(set.statistics) <- names(set.indexes)
  
  for (i in seq_len(num.feature.sets)) {
    indexes.for.feature.set <- set.indexes[[i]]
    m1 <- length(indexes.for.feature.set)
    not.set.indexes <- which(!(seq_len(ncol(data)) %in% indexes.for.feature.set))
    m2 <- length(not.set.indexes)
    
    if (cor.adjustment) {      
      # compute sample correlation matrix for members of feature set
      cor.mat <- cor(data[,indexes.for.feature.set], use = "complete.obs")
      # compute the mean pair-wise correlation (diagonal of ones is subtracted out)
      mean.cor <- (sum(cor.mat) - m1)/(m1*(m1-1))    
      # compute the VIF, using CAMERA formula from Wu et al., based on Barry et al.
      vif <- 1 + (m1 -1)*mean.cor
    }
    
    for (j in seq_along(pc.indexes)) {
      # get the feature statistics for this PC
      pc.feature.stats <- feature.statistics[,j]
      # compute the mean difference of the feature-level statistics
      mean.diff <- mean(pc.feature.stats[indexes.for.feature.set],na.rm=TRUE) - mean(pc.feature.stats[not.set.indexes], na.rm=TRUE)
      # compute the pooled standard deviation
      pooled.sd <- sqrt(((m1-1)*var(pc.feature.stats[indexes.for.feature.set], na.rm=TRUE) + (m2-1)*var(pc.feature.stats[not.set.indexes], na.rm=TRUE))/(m1+m2-2))

      # compute the t-statistic; under correlation adjustment the set variance is
      # inflated by the VIF and the degrees of freedom follow the CAMERA formula
      if (cor.adjustment) {
        t.stat <- mean.diff/(pooled.sd*sqrt(vif/m1 + 1/m2))
        df <- nrow(data)-2
      } else {
        t.stat <- mean.diff/(pooled.sd*sqrt(1/m1 + 1/m2))
        df <- m1+m2-2
      }
      set.statistics[i,j] <- t.stat      
      # compute the p-value via a two-sided test
      lower.p <- pt(t.stat, df=df, lower.tail=TRUE)
      upper.p <- pt(t.stat, df=df, lower.tail=FALSE)        
      p.values[i,j] <- 2*min(lower.p, upper.p)      
    }
  } 
  
  # Build the result list
  results <- list()
  results$p.values <- p.values
  results$statistics <- set.statistics  
  
  return (results)
}

# Compute set-level enrichment via the Wilcoxon-Mann-Whitney rank-sum test on the
# feature-level statistics: for each feature set and each factor, the ranks of the
# in-set statistics are compared against those outside the set via a z-test.
#
#   data               samples x features matrix; used for ncol() and, when
#                      cor.adjustment=TRUE, for the within-set correlations
#   prcomp.output      unused here; kept for interface compatibility with .pcgse
#   pc.indexes         indexes of the factors/PCs to test
#   set.indexes        named list of feature-index vectors, one entry per set
#   feature.statistics features x length(pc.indexes) matrix of feature statistics
#   cor.adjustment     if TRUE, use the correlation-adjusted rank-sum variance
#                      from Wu et al. instead of the standard WMW variance
#
# Returns a list with 'p.values' and 'statistics' matrices (sets x factors).
#' @importFrom stats wilcox.test pnorm
.pcgse_wmw <- function(data, prcomp.output, pc.indexes,
                       set.indexes, feature.statistics, cor.adjustment) {
  
  num.feature.sets <- length(set.indexes)
  
  # Two-sided p-values, one row per feature set and one column per factor
  p.values <- matrix(0, nrow=num.feature.sets, ncol=length(pc.indexes))  
  rownames(p.values) <- names(set.indexes)
  
  # z-statistics (fix: initialise as numeric NA instead of logical TRUE; every cell
  # is overwritten below, but this avoids relying on implicit logical->numeric coercion)
  set.statistics <- matrix(NA_real_, nrow=num.feature.sets, ncol=length(pc.indexes))
  rownames(set.statistics) <- names(set.indexes)
  
  for (i in seq_len(num.feature.sets)) {
    indexes.for.feature.set <- set.indexes[[i]]
    m1 <- length(indexes.for.feature.set)
    not.set.indexes <- which(!(seq_len(ncol(data)) %in% indexes.for.feature.set))
    m2 <- length(not.set.indexes)
    
    if (cor.adjustment) {            
      # compute sample correlation matrix for members of feature set
      cor.mat <- cor(data[,indexes.for.feature.set], use="complete.obs")
      # compute the mean pair-wise correlation (diagonal of ones is subtracted out)
      mean.cor <- (sum(cor.mat) - m1)/(m1*(m1-1))    
    }
    
    for (j in seq_along(pc.indexes)) {
      # get the feature-level statistics for this PC
      pc.feature.stats <- feature.statistics[,j]
      # compute the rank sum statistic on the feature-level statistics
      # (normal approximation without continuity correction; the p-value is
      # recomputed below from the z-statistic, not taken from wilcox.test)
      wilcox.results <- wilcox.test(x=pc.feature.stats[indexes.for.feature.set],
                                    y=pc.feature.stats[not.set.indexes],
                                    alternative="two.sided", exact=FALSE, correct=FALSE)
      rank.sum = wilcox.results$statistic                
      if (cor.adjustment) {
        # Using correlation-adjusted formula from Wu et al.
        var.rank.sum <- ((m1*m2)/(2*pi))*
          (asin(1) + (m2 - 1)*asin(.5) + (m1-1)*(m2-1)*asin(mean.cor/2) +(m1-1)*asin((mean.cor+1)/2))
      } else {        
        # standard WMW variance
        var.rank.sum <- m1*m2*(m1+m2+1)/12
      }
      z.stat <- (rank.sum - (m1*m2)/2)/sqrt(var.rank.sum)
      set.statistics[i,j] <- z.stat
      
      # compute the p-value via a two-sided z-test
      lower.p <- pnorm(z.stat, lower.tail=TRUE)
      upper.p <- pnorm(z.stat, lower.tail=FALSE)        
      p.values[i,j] <- 2*min(lower.p, upper.p)
    }
  } 
  
  # Build the result list
  results <- list()
  results$p.values <- p.values
  results$statistics <- set.statistics  
  
  return (results)
}


================================================
FILE: R/get_methods.R
================================================

################################################
## Get functions to fetch data from the model ##
################################################

#' @title Get dimensions
#' @name get_dimensions
#' @description Extract dimensionalities from the model. 
#' @details K indicates the number of factors, M indicates the number of views, D indicates the number of features (per view), 
#' N indicates the number of samples (per group) and C indicates the number of covariates.
#' @param object a \code{\link{MOFA}} object.
#' @return list containing the dimensionalities of the model
#' @export
#' @examples
#' # Using an existing trained model
#' file <- system.file("extdata", "model.hdf5", package = "MOFA2")
#' model <- load_model(file)
#' dims <- get_dimensions(model)

get_dimensions <- function(object) {
  # Only a MOFA object carries the @dimensions slot queried below
  # (K factors, M views, D features, N samples, C covariates)
  if (!is(object, "MOFA")) stop("'object' has to be an instance of MOFA")
  object@dimensions
}

#' @title Get ELBO
#' @name get_elbo
#' @description Extract the value of the ELBO statistics after model training. This can be useful for model selection.
#' @details This can be useful for model selection.
#' @param object a \code{\link{MOFA}} object.
#' @return Value of the ELBO
#' @export
#' @examples
#' # Using an existing trained model
#' file <- system.file("extdata", "model.hdf5", package = "MOFA2")
#' model <- load_model(file)
#' elbo <- get_elbo(model)

get_elbo <- function(object) {
  # Guard against non-MOFA input before touching S4 slots
  if (!is(object, "MOFA")) stop("'object' has to be an instance of MOFA")
  # The stored ELBO trace may contain NAs; report the best (largest) value reached
  max(object@training_stats$elbo, na.rm = TRUE)
}

#' @title Get lengthscales
#' @name get_lengthscales
#' @description Extract the inferred lengthscale for each factor after model training. 
#' @details This can be used only if covariates are passed to the MOFAobject upon creation and GP_factors is set to True.
#' @param object a \code{\link{MOFA}} object.
#' @return A numeric vector containing the lengthscale for each factor.
#' @export
#' @examples 
#' # Using an existing trained model
#' file <- system.file("extdata", "MEFISTO_model.hdf5", package = "MOFA2")
#' model <- load_model(file)
#' ls <- get_lengthscales(model)
get_lengthscales <- function(object) {
  # Sequential guards: valid object -> covariates present -> lengthscales stored
  if (!is(object, "MOFA")) stop("'object' has to be an instance of MOFA")
  if (!.hasSlot(object, "covariates") || is.null(object@covariates)) stop("No covariates specified in 'object'")
  if (is.null(object@training_stats$length_scales)) stop("No lengthscales saved in 'object' \n Make sure you specify the covariates and train setting the option 'GP_factors' to TRUE.")
  # One lengthscale per factor, as stored during training
  object@training_stats$length_scales
}


#' @title Get scales
#' @name get_scales
#' @description Extract the inferred scale for each factor after model training. 
#' @details This can be used only if covariates are passed to the MOFAobject upon creation and GP_factors is set to True.
#' @param object a \code{\link{MOFA}} object.
#' @return A numeric vector containing the scale for each factor.
#' @export
#' @examples 
#' # Using an existing trained model
#' file <- system.file("extdata", "MEFISTO_model.hdf5", package = "MOFA2")
#' model <- load_model(file)
#' s <- get_scales(model)
get_scales <- function(object) {
  # Sequential guards: valid object -> covariates present -> scales stored
  if (!is(object, "MOFA")) stop("'object' has to be an instance of MOFA")
  if (!.hasSlot(object, "covariates") || is.null(object@covariates)) stop("No covariates specified in 'object'")
  if (is.null(object@training_stats$scales)) stop("No scales saved in 'object' \n Make sure you specify the covariates and train setting the option 'GP_factors' to TRUE.")
  # One scale per factor, as stored during training
  object@training_stats$scales
}

#' @title Get group covariance matrix
#' @name get_group_kernel
#' @description Extract the inferred group-group covariance matrix per factor
#' @details This can be used only if covariates are passed to the MOFAobject upon creation and GP_factors is set to True.
#' @param object a \code{\link{MOFA}} object.
#' @return A list of group-group correlation matrices per factor
#' @export
get_group_kernel <- function(object) {
  # Sanity checks
  if (!is(object, "MOFA")) stop("'object' has to be an instance of MOFA")
  if(is.null(object@covariates)) stop("No covariates specified in 'object'")
  # NOTE(review): the message below likely intends "does not have"; the string is
  # runtime behavior and is left unchanged here
  if (is.null(object@mefisto_options)) stop("'object' does have MEFISTO training options.")
  
  # When groups were not modelled (or there is only one group), the group-group
  # covariance is trivially a G x G matrix of ones for every factor
  if(!object@mefisto_options$model_groups || object@dimensions$G == 1) {
    # NOTE(review): this branch still reads dim(object@training_stats$Kg)[3] to
    # determine the number of factors; if Kg is absent this fails — confirm Kg is
    # always stored for trained MEFISTO models
    tmp <- lapply(seq_len(dim(object@training_stats$Kg)[3]), function(x) {
      mat <- matrix(1, nrow = object@dimensions$G, ncol = object@dimensions$G)
      rownames(mat) <- colnames(mat) <- groups_names(object)
      mat
    })
  } else {
  if(is.null(object@training_stats$Kg)) stop("No group kernel saved in 'object' \n Make sure you specify the covariates and train setting the option 'model_groups' to TRUE.")
  # One G x G covariance matrix per factor, sliced from the trained kernel Kg (G x G x K)
  tmp <- lapply(seq_len(dim(object@training_stats$Kg)[3]), function(x) {
    mat <- object@training_stats$Kg[ , , x]
    rownames(mat) <- colnames(mat) <- groups_names(object)
    mat
    })
  }
  # Name the list entries by factor
  names(tmp) <- factors_names(object)
  return(tmp)
}

#' @title Get interpolated factor values
#' @name get_interpolated_factors
#' @description Extract the interpolated factor values
#' @details This can be used only if covariates are passed to the object upon creation, GP_factors is set to True and new covariates were passed for interpolation.
#' @param object a \code{\link{MOFA}} object
#' @param as.data.frame logical indicating whether to return data as a data.frame
#' @param only_mean logical indicating whether include only mean or also uncertainties
#' @return By default, a nested list containing for each group a list with a matrix with the interpolated factor values ("mean"),
#'  their variance ("variance") and the values of the covariate at which interpolation took place ("new_values"). 
#' Alternatively, if \code{as.data.frame} is \code{TRUE}, returns a long-formatted data frame with columns containing the covariates 
#' and (factor, group, mean and variance).
#' @import dplyr
#' @import reshape2
#' @export
get_interpolated_factors <- function(object, as.data.frame = FALSE, only_mean = FALSE) {
  # Sanity checks: require a MOFA object with non-empty interpolated factors
  if (!is(object, "MOFA")) stop("'object' has to be an instance of MOFA")
  if(is.null(object@interpolated_Z)) stop("No interpolated factors present in 'object'")
  if(length(object@interpolated_Z) == 0) stop("No interpolated factors present in 'object'")
  
  if(!as.data.frame){
    # Default: return the stored nested list (per group: mean, variance, new_values)
    return(object@interpolated_Z)
  } else {
    type <- NULL # silence R CMD check NOTE for the NSE variable used below
    # Melt the per-group mean/variance matrices into long format
    # (L1 = group, L2 = "mean"/"variance")
    preds <- lapply(object@interpolated_Z, function(l) l[names(l)[names(l) != "new_values"]])
    df_interpol <- reshape2::melt(preds, varnames = c("factor", "sample_id"))
    df_interpol <- dplyr::rename(df_interpol, group = L1, type = L2)
    if(only_mean){
      # Drop the variance rows, keep only the interpolated means
      df_interpol <- filter(df_interpol, type == "mean")
    }
    
    if("new_values" %in% names(object@interpolated_Z[[1]])) {
      # Attach the covariate values at which interpolation took place,
      # spread to one column per covariate, then join by group and sample
      new_vals <- lapply(object@interpolated_Z, function(l) l[names(l)[names(l) == "new_values"]])
      new_vals <- reshape2::melt(new_vals, varnames = c("covariate","sample_id"))
      new_vals <- mutate(new_vals, covariate = covariates_names(object))
      new_vals <- rename(new_vals, group = L1, covariate_value = value)
      new_vals <- spread(new_vals, key = covariate, value = covariate_value)
      new_vals <- select(new_vals, -L2)
      df_interpol <- left_join(df_interpol, new_vals, by = c("group", "sample_id"))
      df_interpol <- select(df_interpol, -sample_id)
    } else { # compatibility to older objects
      df_interpol <- rename(df_interpol, covariate_value = sample_id)
      df_interpol <- mutate(df_interpol, covariate = covariates_names(object))
    }
    # Replace factor indexes with factor names and widen mean/variance into columns
    df_interpol <- mutate(df_interpol, factor = factors_names(object)[factor])
    df_interpol <- spread(df_interpol, key = type, value = value)
    return(df_interpol)
  }
}


#' @title Get factors
#' @name get_factors
#' @description Extract the latent factors from the model.
#' @param object a trained \code{\link{MOFA}} object.
#' @param factors character vector with the factor name(s), or numeric vector with the factor index(es).
#' Default is "all".
#' @param groups character vector with the group name(s), or numeric vector with the group index(es).
#' Default is "all".
#' @param scale logical indicating whether to scale factor values.
#' @param as.data.frame logical indicating whether to return a long data frame instead of a matrix.
#' Default is \code{FALSE}.
#' @return By default it returns the latent factor matrix of dimensionality (N,K), where N is number of samples and K is number of factors. \cr
#' Alternatively, if \code{as.data.frame} is \code{TRUE}, returns a long-formatted data frame with columns (sample,factor,value).
#' @export
#' 
#' @examples
#' # Using an existing trained model on simulated data
#' file <- system.file("extdata", "model.hdf5", package = "MOFA2")
#' model <- load_model(file)
#'
#' # Fetch factors in matrix format (a list, one matrix per group)
#' factors <- get_factors(model)
#'
#' # Concatenate groups
#' factors <- do.call("rbind",factors)
#'
#' # Fetch factors in data.frame format instead of matrix format
#' factors <- get_factors(model, as.data.frame = TRUE)
get_factors <- function(object, groups = "all", factors = "all", scale = FALSE, as.data.frame = FALSE) {

  # Make sure we were given a MOFA model
  if (!is(object, "MOFA")) stop("'object' has to be an instance of MOFA")

  # Resolve the group and factor selections to canonical names
  groups  <- .check_and_get_groups(object, groups)
  factors <- .check_and_get_factors(object, factors)

  # Fetch the Z expectations (long data.frame or list of matrices, one per group)
  Z <- get_expectations(object, "Z", as.data.frame)

  if (as.data.frame) {
    # Keep only the requested factors and groups
    keep <- (Z$factor %in% factors) & (Z$group %in% groups)
    Z <- Z[keep, ]
    # Optionally rescale all values by the global maximum absolute value
    if (scale) {
      Z$value <- Z$value / max(abs(Z$value), na.rm = TRUE)
    }
  } else {
    # Subset each group's factor matrix to the requested factors
    Z <- lapply(Z[groups], function(z) z[, factors, drop = FALSE])
    if (scale) {
      Z <- lapply(Z, function(z) z / max(abs(z)))
    }
    names(Z) <- groups
  }

  return(Z)
}


#' @title Get weights
#' @name get_weights
#' @description Extract the weights from the model.
#' @param object a trained \code{\link{MOFA}} object.
#' @param views character vector with the view name(s), or numeric vector with the view index(es). 
#' Default is "all".
#' @param factors character vector with the factor name(s) or numeric vector with the factor index(es). \cr
#' Default is "all".
#' @param abs logical indicating whether to take the absolute value of the weights.
#' @param scale logical indicating whether to scale all weights from -1 to 1 (or from 0 to 1 if \code{abs=TRUE}).
#' @param as.data.frame logical indicating whether to return a long data frame instead of a list of matrices. 
#' Default is \code{FALSE}.
#' @return By default it returns a list where each element is a loading matrix with dimensionality (D,K), 
#' where D is the number of features and K is the number of factors. \cr
#' Alternatively, if \code{as.data.frame} is \code{TRUE}, returns a long-formatted data frame with columns (view,feature,factor,value).
#' @export
#' 
#' @examples
#' # Using an existing trained model on simulated data
#' file <- system.file("extdata", "model.hdf5", package = "MOFA2")
#' model <- load_model(file)
#'
#' # Fetch weights in matrix format (a list, one matrix per view)
#' weights <- get_weights(model)
#'
#' # Fetch weights for factor 1 and 2 and view 1
#' weights <- get_weights(model, views = 1, factors = c(1,2))
#'
#' # Fetch weights in data.frame format
#' weights <- get_weights(model, as.data.frame = TRUE)

get_weights <- function(object, views = "all", factors = "all", abs = FALSE, scale = FALSE, as.data.frame = FALSE) {

  # Sanity checks
  if (!is(object, "MOFA")) stop("'object' has to be an instance of MOFA")
  
  # Resolve the view and factor selections to canonical names
  views <- .check_and_get_views(object, views)
  factors <- .check_and_get_factors(object, factors)
  
  # Fetch weights (long data.frame or list of matrices, one per view)
  weights <- get_expectations(object, "W", as.data.frame)
  
  if (as.data.frame) {
    weights <- weights[weights$view %in% views & weights$factor %in% factors, ]
    if (abs) weights$value <- abs(weights$value)
    # na.rm=TRUE so a single missing weight does not turn every scaled value
    # into NA (consistent with the scaling in get_factors)
    if (scale) weights$value <- weights$value/max(abs(weights$value), na.rm=TRUE)
  } else {
    weights <- lapply(weights[views], function(x) x[,factors,drop=FALSE])
    if (abs) weights <- lapply(weights, abs)
    if (scale) weights <- lapply(weights, function(x) x/max(abs(x), na.rm=TRUE) )
    names(weights) <- views
  }
  
  return(weights)
}


#' @title Get data
#' @name get_data
#' @description Fetch the input data
#' @param object a \code{\link{MOFA}} object.
#' @param views character vector with the view name(s), or numeric vector with the view index(es). 
#' Default is "all".
#' @param groups character vector with the group name(s), or numeric vector with the group index(es). 
#' Default is "all".
#' @param features a *named* list of character vectors. Example: list("view1"=c("feature_1","feature_2"), "view2"=c("feature_3","feature_4"))
#' Default is "all".
#' @param as.data.frame logical indicating whether to return a long data frame instead of a list of matrices. Default is \code{FALSE}.
#' @param add_intercept logical indicating whether to add feature intercepts to the data. Default is \code{TRUE}.
#' @param denoise logical indicating whether to return the denoised data (i.e. the model predictions). Default is \code{FALSE}.
#' @param na.rm remove NAs from the data.frame (only if as.data.frame is \code{TRUE}).
#' @details By default this function returns a list where each element is a data matrix with dimensionality (D,N) 
#' where D is the number of features and N is the number of samples. \cr
#' Alternatively, if \code{as.data.frame} is \code{TRUE}, the function returns a long-formatted data frame with columns (view,feature,sample,value).
#' Missing values are not included in the long data.frame format by default. To include them use the argument \code{na.rm=FALSE}.
#' @return A  list of data matrices with dimensionality (D,N) or a \code{data.frame} (if \code{as.data.frame} is TRUE)
#' @export
#' @examples
#' # Using an existing trained model on simulated data
#' file <- system.file("extdata", "model.hdf5", package = "MOFA2")
#' model <- load_model(file)
#'
#' # Fetch data
#' data <- get_data(model)
#'
#' # Fetch a specific view
#' data <- get_data(model, views = "view_0")
#'
#' # Fetch data in data.frame format instead of matrix format
#' data <- get_data(model, as.data.frame = TRUE)
#'
#' # Fetch centered data (do not add the feature intercepts)
#' data <- get_data(model, add_intercept = FALSE)
#' 
#' # Fetch denoised data (do not add the feature intercepts)
#' data <- get_data(model, denoise = TRUE)
get_data <- function(object, views = "all", groups = "all", features = "all", as.data.frame = FALSE, add_intercept = TRUE, denoise = FALSE, na.rm = TRUE) {

  # Sanity checks
  if (!is(object, "MOFA")) stop("'object' has to be an instance of MOFA")
  
  # Get views and groups
  views  <- .check_and_get_views(object, views)
  groups <- .check_and_get_groups(object, groups)
  
  # Get features
  if (is(features, "list")) {
    if (is.null(names(features))) stop("features has to be a *named* list of character vectors. Please see the documentation")
    # all() is required: names(features) can have length > 1, and a vector
    # condition inside if() is an error as of R 4.2
    if (!all(names(features) %in% views_names(object))) stop("Views not recognised")
    if (!all(sapply(names(features), function(i) all(features[[i]] %in% features_names(object)[[i]]) ))) stop("features not recognised")
    if (any(sapply(features,length)<1)) stop("features not recognised, please read the documentation")
    views <- names(features)
  } else {
    if (paste0(features, collapse="") == "all") { 
      features <- features_names(object)[views]
    } else {
      stop("features not recognised, please read the documentation")
    }
  }

  # Fetch data: model predictions when denoise=TRUE, raw training data otherwise
  if (denoise) {
    data <- predict(object, views=views, groups=groups)
  } else {
    data <- lapply(object@data[views], function(x) x[groups])
  }
  # Subset each (view, group) matrix to the requested features (rows)
  data <- lapply(views, function(m) lapply(seq_along(groups), function(p) data[[m]][[p]][as.character(features[[m]]),,drop=FALSE]))
  data <- .name_views_and_groups(data, views, groups)
  
  # Add feature intercepts (only for gaussian likelihoods)
  tryCatch( {
    
    if (add_intercept & length(object@intercepts[[1]])>0) {
      intercepts <- lapply(object@intercepts[views], function(x) x[groups])
      intercepts <- lapply(seq_along(views), function(m) lapply(seq_along(groups), function(p) intercepts[[m]][[p]][as.character(features[[m]])]))
      intercepts <- .name_views_and_groups(intercepts, views, groups)
      
      for (m in names(data)) {
        if (object@model_options$likelihoods[[m]]=="gaussian") {
          for (g in names(data[[m]])) {
            # intercepts are a per-feature vector; recycling adds them along
            # the feature (row) dimension of the (D,N) matrix
            data[[m]][[g]] <- data[[m]][[g]] + intercepts[[m]][[g]][as.character(features[[m]])]
          }
        }
      }
    } }, error = function(e) { NULL })

  # Convert to long data frame
  if (as.data.frame) {
    tmp <- lapply(views, function(m) { 
      lapply(groups, function(p) { 
        tmp <- reshape2::melt(data[[m]][[p]], na.rm=na.rm)
        # test NULL before calling nrow() on the result (&& short-circuits;
        # the old `nrow(tmp) > 0 & !is.null(tmp)` evaluated nrow first)
        if (!is.null(tmp) && nrow(tmp) > 0) {
          colnames(tmp) <- c("feature", "sample", "value")
          tmp <- cbind(view = m, group = p, tmp)
          return(tmp)
        }
      })
    })
    data <- do.call(rbind, do.call(rbind, tmp))
    factor.cols <- c("view","group","feature","sample")
    data[factor.cols] <- lapply(data[factor.cols], factor)
    
  }
  
  return(data)
}


#' @title Get imputed data
#' @name get_imputed_data
#' @description Function to get the imputed data. It requires the previous use of the \code{\link{impute}} method.
#' @param object a trained \code{\link{MOFA}} object.
#' @param views character vector with the view name(s), or numeric vector with the view index(es). 
#' Default is "all".
#' @param groups character vector with the group name(s), or numeric vector with the group index(es).
#' Default is "all".
#' @param features list of character vectors with the feature names or list of numeric vectors with the feature indices. 
#' Default is "all".
#' @param as.data.frame logical indicating whether to return a long-formatted data frame instead of a list of matrices. 
#' Default is \code{FALSE}.
#' @details Data is imputed from the generative model of MOFA.
#' @return A list containing the imputed valued or a data.frame if as.data.frame is TRUE
#' @export
#' @examples
#' # Using an existing trained model
#' file <- system.file("extdata", "model.hdf5", package = "MOFA2")
#' model <- load_model(file)
#' model <- impute(model)
#' imputed <- get_imputed_data(model)

get_imputed_data <- function(object, views = "all", groups = "all", features = "all", as.data.frame = FALSE) {
  
  # Sanity checks
  if (!is(object, "MOFA")) stop("'object' has to be an instance of MOFA")
  if (length(object@imputed_data)==0) stop("imputed data not found, did you run: 'object <- impute(object)'?")
  
  # Get views and groups
  views <- .check_and_get_views(object, views)
  groups <- .check_and_get_groups(object, groups)

  # Get features
  if (is(features, "list")) {
    # Check the lengths agree BEFORE indexing views[i] below; otherwise a
    # mismatch surfaces as an obscure out-of-bounds subscript instead of
    # this explicit check
    stopifnot(length(features)==length(views))
    stopifnot(all(sapply(seq_len(length(features)), function(i) all(features[[i]] %in% features_names(object)[[views[i]]]))))
    if (is.null(names(features))) names(features) <- views
  } else {
    if (paste0(features, collapse="") == "all") { 
      features <- features_names(object)[views]
    } else {
      stop("features not recognised, please read the documentation")
    }
  }
  
  # Fetch the imputed values, subset to the requested features (rows)
  imputed_data <- lapply(object@imputed_data[views], function(x) x[groups] )
  imputed_data <- lapply(seq_len(length(imputed_data)), function(m) lapply(seq_len(length(imputed_data[[1]])), function(p) imputed_data[[m]][[p]][as.character(features[[m]]),,drop=FALSE]))
  imputed_data <- .name_views_and_groups(imputed_data, views, groups)
  
  # NOTE: no feature intercepts are added here; impute() stores the data with
  # intercepts already included (it calls get_data with add_intercept=TRUE)

  # Convert to long data frame
  if (isTRUE(as.data.frame)) {
    
    imputed_data <- lapply(views, function(m) { 
      lapply(groups, function(g) { 
        tmp <- reshape2::melt(imputed_data[[m]][[g]])
        colnames(tmp) <- c("feature", "sample", "value")
        tmp <- cbind(view = m, group = g, tmp)
        return(tmp) 
      })
    })
    imputed_data <- do.call(rbind, do.call(rbind, imputed_data))
    
    factor.cols <- c("view","group","feature","sample")
    imputed_data[factor.cols] <- lapply(imputed_data[factor.cols], factor)
  }
  return(imputed_data)
}


#' @title Get expectations
#' @name get_expectations
#' @description Function to extract the expectations from the (variational) posterior distributions of a trained \code{\link{MOFA}} object.
#' @param object a trained \code{\link{MOFA}} object.
#' @param variable variable name: 'Z' for factors and 'W' for weights.
#' @param as.data.frame logical indicating whether to output the result as a long data frame, default is \code{FALSE}.
#' @details Technical note: MOFA is a Bayesian model where each variable has a prior distribution and a posterior distribution. 
#' In particular, to achieve scalability we used the variational inference framework, thus true posterior distributions are replaced by approximated variational distributions.
#' This function extracts the expectations of the variational distributions, which can be used as final point estimates to analyse the results of the model. \cr 
#' The priors and variational distributions of each variable are extensively described in the supplementary methods of the original paper.
#' @return the output varies depending on the variable of interest: \cr
#' \itemize{
#'  \item{\strong{"Z"}: a matrix with dimensions (samples,factors). If \code{as.data.frame} is \code{TRUE}, a long-formatted data frame with columns (sample,factor,value)}
#'  \item{\strong{"W"}: a list of length (views) where each element is a matrix with dimensions (features,factors). If \code{as.data.frame} is \code{TRUE}, a long-formatted data frame with columns (view,feature,factor,value)}
#' }
#' @export
#' @examples
#' # Using an existing trained model
#' file <- system.file("extdata", "model.hdf5", package = "MOFA2")
#' model <- load_model(file)
#' factors <- get_expectations(model, "Z")
#' weights <- get_expectations(model, "W")

get_expectations <- function(object, variable, as.data.frame = FALSE) {
  
  # Sanity checks
  if (!is(object, "MOFA")) stop("'object' has to be an instance of MOFA")
  stopifnot(variable %in% names(object@expectations))
  
  # Get expectations in single matrix or list of matrices (for multi-view nodes)
  exp <- object@expectations[[variable]]

  # unlist single view nodes - single Sigma node across all groups using time warping
  if (variable == "Sigma")
    exp <- exp[[1]]
  
  # For memory and space efficiency, Y expectations are not saved to the model file when using only gaussian likelihoods.
  if (variable == "Y") {
    # use the canonical option name "likelihoods" (as elsewhere in this file)
    # instead of relying on `$` partial matching with "likelihood"
    if ((length(object@expectations$Y) == 0) && all(object@model_options$likelihoods == "gaussian")) {
      # message("Using training data slot as Y expectations since all the likelihoods are gaussian.")
      exp <- object@data
    }
  }
  
  # Convert to long data frame
  if (as.data.frame) {
    
    # Z node: one matrix per group -> (sample, factor, value, group)
    if (variable=="Z") {
      tmp <- reshape2::melt(exp, na.rm=TRUE)
      colnames(tmp) <- c("sample", "factor", "value", "group")
      tmp$sample <- as.character(tmp$sample)
      factor.cols <- c("sample", "factor", "group")
      # convert the id columns of tmp to factors (the old code indexed the
      # factor.cols vector by itself, leaving tmp untouched)
      tmp[factor.cols] <- lapply(tmp[factor.cols], factor)
    }
    
    # W node: one matrix per view -> (feature, factor, value, view)
    else if (variable=="W") {
      tmp <- lapply(names(exp), function(m) { 
        tmp <- reshape2::melt(exp[[m]], na.rm=TRUE)
        colnames(tmp) <- c("feature","factor","value")
        tmp$view <- m
        factor.cols <- c("view", "feature", "factor")
        tmp[factor.cols] <- lapply(tmp[factor.cols], factor)
        return(tmp)
      })
      tmp <- do.call(rbind.data.frame,tmp)
    }
    
    # Y node: nested list (view -> group) of matrices -> (sample, feature, value, view, group)
    else if (variable=="Y") {
      tmp <- lapply(names(exp), function(m) {
        lapply(names(exp[[m]]), function(g) {
          tmp <- reshape2::melt(exp[[m]][[g]], na.rm=TRUE)
          colnames(tmp) <- c("sample", "feature", "value")
          tmp$view <- m
          tmp$group <- g
          # these data frames have no "factor" column; the old code listed one,
          # which made the subset fail with 'undefined columns selected'
          factor.cols <- c("view", "group", "sample", "feature")
          tmp[factor.cols] <- lapply(tmp[factor.cols], factor)
          return(tmp)
        })
      })
      # flatten the nested view->group list of data frames before row-binding
      # (the old single do.call(rbind, .) rbind-ed lists, not data frames)
      tmp <- do.call(rbind, lapply(tmp, function(x) do.call(rbind, x)))
    }
    
    exp <- tmp
  }
  return(exp)
}


#' @title Get variance explained values
#' @name get_variance_explained
#' @description Extract the variance explained estimates (per group, view and factor) from the model.
#' @param object a trained \code{\link{MOFA}} object.
#' @param factors character vector with the factor name(s), or numeric vector with the factor index(es).
#' Default is "all".
#' @param groups character vector with the group name(s), or numeric vector with the group index(es).
#' Default is "all".
#' @param views character vector with the view name(s), or numeric vector with the view index(es).
#' Default is "all".
#' @param as.data.frame logical indicating whether to return a long data frame instead of a matrix.
#' Default is \code{FALSE}.
#' @return A list of data matrices with variance explained per group or a \code{data.frame} (if \code{as.data.frame} is TRUE)
#' @export
#'
#' @examples
#' # Using an existing trained model
#' file <- system.file("extdata", "model.hdf5", package = "MOFA2")
#' model <- load_model(file)
#'
#' # Fetch variance explained values (in matrix format)
#' r2 <- get_variance_explained(model)
#'
#' # Fetch variance explained values (in data.frame format)
#' r2 <- get_variance_explained(model, as.data.frame = TRUE)
#'
get_variance_explained <- function(object, groups = "all", views = "all", factors = "all", 
                                   as.data.frame = FALSE) {
  
  # Make sure we were given a MOFA model
  if (!is(object, "MOFA")) stop("'object' has to be an instance of MOFA")
  
  # Resolve the group/view/factor selections to canonical names
  groups  <- .check_and_get_groups(object, groups)
  views   <- .check_and_get_views(object, views)
  factors <- .check_and_get_factors(object, factors)
  
  # Prefer the precomputed values in the cache; otherwise compute from scratch
  use_cache <- .hasSlot(object, "cache") && ("variance_explained" %in% names(object@cache))
  if (use_cache) {
    r2_list <- object@cache$variance_explained
  } else {
    r2_list <- calculate_variance_explained(object, factors = factors, views = views, groups = groups)
  }
  
  # Matrix/list format: return as-is
  if (!as.data.frame) return(r2_list)
  
  # Long-format data frame with total R2 per (group, view)
  r2_total <- reshape2::melt( do.call("rbind",r2_list[["r2_total"]]) )
  colnames(r2_total) <- c("group", "view", "value")
  
  # Long-format data frame with R2 per (group, view, factor)
  per_factor_dfs <- lapply(names(r2_list[["r2_per_factor"]]), function(g) {
    df <- reshape2::melt( r2_list[["r2_per_factor"]][[g]] )
    colnames(df) <- c("factor", "view", "value")
    df$factor <- as.factor(df$factor)
    df$group <- g
    df
  })
  r2_per_factor <- do.call("rbind", per_factor_dfs)[,c("group","view","factor","value")]
  
  return(list("r2_per_factor"=r2_per_factor, "r2_total"=r2_total))
}

================================================
FILE: R/imports.R
================================================
#' Re-exporting the pipe operator
#' See \code{magrittr::\link[magrittr]{\%>\%}} for details.
#'
#' @name %>%
#' @rdname pipe
#' @param lhs see \code{magrittr::\link[magrittr]{\%>\%}}
#' @param rhs see \code{magrittr::\link[magrittr]{\%>\%}}
#' @export
#' @importFrom magrittr %>%
#' @usage lhs \%>\% rhs
#' @return depending on lhs and rhs
NULL

================================================
FILE: R/impute.R
================================================

#######################################################
## Functions to perform imputation of missing values ##
#######################################################

#' @title Impute missing values from a fitted MOFA
#' @name impute
#' @description This function uses the latent factors and the loadings to impute missing values.
#' @param object a \code{\link{MOFA}} object.
#' @param views character vector with the view name(s), or numeric vector with view index(es).
#' @param groups character vector with the group name(s), or numeric vector with group index(es).
#' @param factors character vector with the factor names, or numeric vector with the factor index(es).
#' @param add_intercept add feature intercepts to the imputation (default is TRUE).
#' @details MOFA generates a denoised and condensed low-dimensional representation of the data that captures the main sources of heterogeneity of the data.
#' This representation can be used to reconstruct the data, simply using the equation \code{Y = WX}. 
#' For more details read the supplementary methods of the manuscript. \cr
#' Note that with \code{\link{impute}} you can only generate the point estimates (the means of the posterior distributions). 
#' If you want to add uncertainty estimates (the variance) you need to set \code{impute=TRUE} in the training options.
#' See \code{\link{get_default_training_options}}.
#' @return This method fills the \code{imputed_data} slot by replacing the missing values in the input data with the model predictions.
#' @export
#' @examples
#' # Using an existing trained model on simulated data
#' file <- system.file("extdata", "model.hdf5", package = "MOFA2")
#' model <- load_model(file)
#' 
#' # Impute missing values in all data modalities
#' imputed_data <- impute(model, views = "all")
#' 
#' # Impute missing values in all data modalities using factors 1:3
#' imputed_data <- impute(model, views = "all", factors = 1:3)
impute <- function(object, views = "all", groups = "all", factors = "all", 
                  add_intercept = TRUE) {

  # Sanity checks
  if (!is(object, "MOFA")) stop("'object' has to be an instance of MOFA")
  if (length(object@imputed_data)>0) warning("imputed_data slot is already filled. It will be replaced and the variance estimates will be lost...")
  
  # Get views and groups
  views  <- .check_and_get_views(object, views, non_gaussian=FALSE)
  groups <- .check_and_get_groups(object, groups)

  # Model predictions (pass groups explicitly, as get_data does, so predict
  # only computes the groups that were actually requested)
  pred <- predict(object, views=views, groups=groups, factors=factors, add_intercept=add_intercept)

  # Start from the observed data and replace NAs with the predicted values
  imputed <- get_data(object, views=views, groups=groups, add_intercept = add_intercept)
  for (m in views) {
    for (g in groups) {
      non_observed <- is.na(imputed[[m]][[g]])
      imputed[[m]][[g]][non_observed] <- pred[[m]][[g]][non_observed]
    }
  }
  
  # Save imputed data in the corresponding slot
  object@imputed_data <- imputed

  return(object)
}



================================================
FILE: R/load_model.R
================================================

############################################
## Functions to load a trained MOFA model ##
############################################

#' @title Load a trained MOFA
#' @name load_model
#' @description Method to load a trained MOFA \cr
#' The training of mofa is done using a Python framework, and the model output is saved as an .hdf5 file, which has to be loaded in the R package.
#' @param file an hdf5 file saved by the mofa Python framework
#' @param sort_factors logical indicating whether factors should be sorted by variance explained (default is TRUE)
#' @param on_disk logical indicating whether to work from memory (FALSE) or disk (TRUE). \cr
#' This should be set to TRUE when the training data is so big that cannot fit into memory. \cr
#' On-disk operations are performed using the \code{\link{HDF5Array}} and \code{\link{DelayedArray}} framework.
#' @param load_data logical indicating whether to load the training data (default is TRUE, it can be memory expensive)
#' @param remove_outliers logical indicating whether to mask outlier values.
#' @param remove_inactive_factors logical indicating whether to remove inactive factors from the model.
# #' @param remove_intercept_factors logical indicating whether to remove intercept factors for non-Gaussian views.
#' @param verbose logical indicating whether to print verbose output (default is FALSE)
#' @param load_interpol_Z (MEFISTO) logical indicating whether to load predictions for factor values based on latent processed (only
#'  relevant for models trained with covariates and Gaussian processes, where prediction was enabled)
#' @return a \code{\link{MOFA}} model
#' @importFrom rhdf5 h5read h5ls
#' @importFrom HDF5Array HDF5ArraySeed
#' @importFrom DelayedArray DelayedArray
#' @importFrom dplyr bind_rows
#' @export
#' @examples
#' # Using an existing trained model on simulated data
#' file <- system.file("extdata", "model.hdf5", package = "MOFA2")
#' model <- load_model(file)

load_model <- function(file, sort_factors = TRUE, on_disk = FALSE, load_data = TRUE,
                       remove_outliers = FALSE, remove_inactive_factors = TRUE, verbose = FALSE,
                       load_interpol_Z = FALSE) {

  # Create new MOFAodel object
  object <- new("MOFA")
  object@status <- "trained"
  
  # Set on_disk option
  if (on_disk) { 
    object@on_disk <- TRUE 
  } else { 
      object@on_disk <- FALSE 
  }
  
  # Get groups and data set names from the hdf5 file object
  h5ls.out <- h5ls(file, datasetinfo = FALSE)
  
  ########################
  ## Load training data ##
  ########################

  # Load names
  if ("views" %in% h5ls.out$name) {
    view_names <- as.character( h5read(file, "views")[[1]] )
    group_names <- as.character( h5read(file, "groups")[[1]] )
    feature_names <- h5read(file, "features")[view_names]
    sample_names  <- h5read(file, "samples")[group_names] 
  } else {  # for old models
    feature_names <- h5read(file, "features")
    sample_names  <- h5read(file, "samples")
    view_names <- names(feature_names)
    group_names <- names(sample_names)
    h5ls.out <- h5ls.out[grep("variance_explained", h5ls.out$name, invert = TRUE),]
  }
  if("covariates" %in%  h5ls.out$name){
    covariate_names <- as.character( h5read(file, "covariates")[[1]])
  } else {
    covariate_names <- NULL
  }

  # Load training data (as nested list of matrices)
  data <- list(); intercepts <- list()
  if (load_data && "data"%in%h5ls.out$name) {
    
    object@data_options[["loaded"]] <- TRUE
    if (verbose) message("Loading data...")
    
    for (m in view_names) {
      data[[m]] <- list()
      intercepts[[m]] <- list()
      for (g in group_names) {
        if (on_disk) {
          # as DelayedArrays
          data[[m]][[g]] <- DelayedArray::DelayedArray( HDF5ArraySeed(file, name = sprintf("data/%s/%s", m, g) ) )
        } else {
          # as matrices
          data[[m]][[g]] <- h5read(file, sprintf("data/%s/%s", m, g) )
          tryCatch(intercepts[[m]][[g]] <- as.numeric( h5read(file, sprintf("intercepts/%s/%s", m, g) ) ), error = function(e) { NULL })
        }
        # Replace NaN by NA
        data[[m]][[g]][is.nan(data[[m]][[g]])] <- NA # this realised into memory, TO FIX
      }
    }
    
  # Create empty training data (as nested list of empty matrices, with the correct dimensions)
  } else {
    
    object@data_options[["loaded"]] <- FALSE
    
    for (m in view_names) {
      data[[m]] <- list()
      for (g in group_names) {
        data[[m]][[g]] <- .create_matrix_placeholder(rownames = feature_names[[m]], colnames = sample_names[[g]])
      }
    }
  }

  object@data <- data
  object@intercepts <- intercepts


  # Load metadata if any
  if ("samples_metadata" %in% h5ls.out$name) {
    object@samples_metadata <- bind_rows(lapply(group_names, function(g) as.data.frame(h5read(file, sprintf("samples_metadata/%s", g)))))
  }
  if ("features_metadata" %in% h5ls.out$name) {
    object@features_metadata <- bind_rows(lapply(view_names, function(m) as.data.frame(h5read(file, sprintf("features_metadata/%s", m)))))
  }
  
  ############################
  ## Load sample covariates ##
  ############################
  
  if (any(grepl("cov_samples", h5ls.out$group))){
    covariates <- list()
    for (g in group_names) {
      if (on_disk) {
        # as DelayedArrays
        covariates[[g]] <- DelayedArray::DelayedArray( HDF5ArraySeed(file, name = sprintf("cov_samples/%s", g) ) )
      } else {
        # as matrices
        covariates[[g]] <- h5read(file, sprintf("cov_samples/%s", g) )
      }    
    }
  } else covariates <- NULL
  object@covariates <- covariates

  if (any(grepl("cov_samples_transformed", h5ls.out$group))){
    covariates_warped <- list()
    for (g in group_names) {
      if (on_disk) {
        # as DelayedArrays
        covariates_warped[[g]] <- DelayedArray::DelayedArray( HDF5ArraySeed(file, name = sprintf("cov_samples_transformed/%s", g) ) )
      } else {
        # as matrices
        covariates_warped[[g]] <- h5read(file, sprintf("cov_samples_transformed/%s", g) )
      }    
    }
  } else covariates_warped <- NULL
  object@covariates_warped <- covariates_warped
  
  #######################
  ## Load interpolated factor values ##
  #######################
  
  interpolated_Z <- list()
  if (isTRUE(load_interpol_Z)) {
    
    if (isTRUE(verbose)) message("Loading interpolated factor values...")
    
    for (g in group_names) {
      interpolated_Z[[g]] <- list()
      if (on_disk) {
        # as DelayedArrays
        # interpolated_Z[[g]] <- DelayedArray::DelayedArray( HDF5ArraySeed(file, name = sprintf("Z_predictions/%s", g) ) )
      } else {
        # as matrices
        tryCatch( {
          interpolated_Z[[g]][["mean"]] <- h5read(file, sprintf("Z_predictions/%s/mean", g) )
        }, error = function(x) { print("Predictions of Z not found, not loading it...") })
        tryCatch( {
          interpolated_Z[[g]][["variance"]] <- h5read(file, sprintf("Z_predictions/%s/variance", g) )
        }, error = function(x) { print("Variance of predictions of Z not found, not loading it...") })
        tryCatch( {
          interpolated_Z[[g]][["new_values"]] <- h5read(file, "Z_predictions/new_values")
        }, error = function(x) { print("New values of Z not found, not loading it...") })
      }
    }
  }
  object@interpolated_Z <- interpolated_Z
  
  #######################
  ## Load expectations ##
  #######################

  expectations <- list()
  node_names <- h5ls.out[h5ls.out$group=="/expectations","name"]

  if (verbose) message(paste0("Loading expectations for ", length(node_names), " nodes..."))

  if ("AlphaW" %in% node_names)
    expectations[["AlphaW"]] <- h5read(file, "expectations/AlphaW")[view_names]
  if ("AlphaZ" %in% node_names)
    expectations[["AlphaZ"]] <- h5read(file, "expectations/AlphaZ")[group_names]
  if ("Sigma" %in% node_names)
    expectations[["Sigma"]] <- h5read(file, "expectations/Sigma")
  if ("Z" %in% node_names)
    expectations[["Z"]] <- h5read(file, "expectations/Z")[group_names]
  if ("W" %in% node_names)
    expectations[["W"]] <- h5read(file, "expectations/W")[view_names]
  if ("ThetaW" %in% node_names)
    expectations[["ThetaW"]] <- h5read(file, "expectations/ThetaW")[view_names]
  if ("ThetaZ" %in% node_names)
    expectations[["ThetaZ"]] <- h5read(file, "expectations/ThetaZ")[group_names]
  # if ("Tau" %in% node_names)
  #   expectations[["Tau"]] <- h5read(file, "expectations/Tau")
  
  object@expectations <- expectations

  
  ########################
  ## Load model options ##
  ########################

  if (verbose) message("Loading model options...")

  tryCatch( {
    object@model_options <- as.list(h5read(file, 'model_options', read.attributes = TRUE))
  }, error = function(x) { print("Model options not found, not loading it...") })

  # Convert True/False strings to logical values
  for (i in names(object@model_options)) {
    if (object@model_options[i] == "False" || object@model_options[i] == "True") {
      object@model_options[i] <- as.logical(object@model_options[i])
    } else {
      object@model_options[i] <- object@model_options[i]
    }
  }

  ##########################################
  ## Load training options and statistics ##
  ##########################################

  if (verbose) message("Loading training options and statistics...")

  # Load training options
  if (length(object@training_options) == 0) {
    tryCatch( {
      object@training_options <- as.list(h5read(file, 'training_opts', read.attributes = TRUE))
    }, error = function(x) { print("Training opts not found, not loading it...") })
  }

  # Load training statistics
  tryCatch( {
    object@training_stats <- h5read(file, 'training_stats', read.attributes = TRUE)
    object@training_stats <- h5read(file, 'training_stats', read.attributes = TRUE)
  }, error = function(x) { print("Training stats not found, not loading it...") })

  #############################
  ## Load covariates options ##
  #############################
  
  if (any(grepl("cov_samples", h5ls.out$group))) { 
    if (isTRUE(verbose)) message("Loading covariates options...")
    tryCatch( {
      object@mefisto_options <- as.list(h5read(file, 'smooth_opts', read.attributes = TRUE))
    }, error = function(x) { print("Covariates options not found, not loading it...") })
    
    # Convert True/False strings to logical values
    for (i in names(object@mefisto_options)) {
      if (object@mefisto_options[i] == "False" | object@mefisto_options[i] == "True") {
        object@mefisto_options[i] <- as.logical(object@mefisto_options[i])
      } else {
        object@mefisto_options[i] <- object@mefisto_options[i]
      }
    }
    
  }
  
  
    
  #######################################
  ## Load variance explained estimates ##
  #######################################
  
  if ("variance_explained" %in% h5ls.out$name) {
    r2_list <- list(
      r2_total = h5read(file, "variance_explained/r2_total")[group_names],
      r2_per_factor = h5read(file, "variance_explained/r2_per_factor")[group_names]
    )
    object@cache[["variance_explained"]] <- r2_list
  }
  
  # Hack to fix the problems where variance explained values range from 0 to 1 (%)
  if (max(sapply(object@cache$variance_explained$r2_total,max,na.rm=TRUE),na.rm=TRUE)<1) {
    for (m in 1:length(view_names)) {
      for (g in 1:length(group_names)) {
        object@cache$variance_explained$r2_total[[g]][[m]] <- 100 * object@cache$variance_explained$r2_total[[g]][[m]]
        object@cache$variance_explained$r2_per_factor[[g]][,m] <- 100 * object@cache$variance_explained$r2_per_factor[[g]][,m]
      }
    }
  }
  
  ##############################
  ## Specify dimensionalities ##
  ##############################
  
  # Specify dimensionality of the data
  object@dimensions[["M"]] <- length(data)                            # number of views
  object@dimensions[["G"]] <- length(data[[1]])                       # number of groups
  object@dimensions[["N"]] <- sapply(data[[1]], ncol)                 # number of samples (per group)
  object@dimensions[["D"]] <- sapply(data, function(e) nrow(e[[1]]))  # number of features (per view)
  object@dimensions[["C"]] <- nrow(covariates[[1]])                        # number of covariates
  object@dimensions[["K"]] <- ncol(object@expectations$Z[[1]])        # number of factors
  
  # Assign sample and feature names (slow for large matrices)
  if (verbose) message("Assigning names to the different dimensions...")

  # Create default features names if they are null
  if (is.null(feature_names)) {
    print("Features names not found, generating default: feature1_view1, ..., featureD_viewM")
    feature_names <- lapply(seq_len(object@dimensions[["M"]]),
                            function(m) sprintf("feature%d_view_&d", as.character(seq_len(object@dimensions[["D"]][m])), m))
  } else {
    # Check duplicated features names
    all_names <- unname(unlist(feature_names))
    duplicated_names <- unique(all_names[duplicated(all_names)])
    if (length(duplicated_names)>0) 
      warning("There are duplicated features names across different views. We will add the suffix *_view* only for those features 
            Example: if you have both TP53 in mRNA and mutation data it will be renamed to TP53_mRNA, TP53_mutation")
    for (m in names(feature_names)) {
      tmp <- which(feature_names[[m]] %in% duplicated_names)
      if (length(tmp)>0) feature_names[[m]][tmp] <- paste(feature_names[[m]][tmp], m, sep="_")
    }
  }
  features_names(object) <- feature_names
  
  # Create default samples names if they are null
  if (is.null(sample_names)) {
    print("Samples names not found, generating default: sample1, ..., sampleN")
    sample_names <- lapply(object@dimensions[["N"]], function(n) paste0("sample", as.character(seq_len(n))))
  }
  samples_names(object) <- sample_names

  # Add covariates names
  if(!is.null(object@covariates)){
    # Create default covariates names if they are null
    if (is.null(covariate_names)) {
      print("Covariate names not found, generating default: covariate1, ..., covariateC")
      covariate_names <- paste0("sample", as.character(seq_len(object@dimensions[["C"]])))
    }
    covariates_names(object) <- covariate_names
  }
  
  # Set views names
  if (is.null(names(object@data))) {
    print("Views names not found, generating default: view1, ..., viewM")
    view_names <- paste0("view", as.character(seq_len(object@dimensions[["M"]])))
  }
  views_names(object) <- view_names
  
  # Set groups names
  if (is.null(names(object@data[[1]]))) {
    print("Groups names not found, generating default: group1, ..., groupG")
    group_names <- paste0("group", as.character(seq_len(object@dimensions[["G"]])))
  }
  groups_names(object) <- group_names
  
  # Set factors names
  factors_names(object)  <- paste0("Factor", as.character(seq_len(object@dimensions[["K"]])))
  
  ###################
  ## Parse factors ##
  ###################
  
  # Calculate variance explained estimates per factor
  if (is.null(object@cache[["variance_explained"]])) {
    object@cache[["variance_explained"]] <- calculate_variance_explained(object)
  } 
  
  # Remove inactive factors
  if (remove_inactive_factors) {
    r2 <- rowSums(do.call('cbind', lapply(object@cache[["variance_explained"]]$r2_per_factor, rowSums, na.rm=TRUE)))
    var.threshold <- 0.0001
    if (all(r2 < var.threshold)) {
      warning(sprintf("All %s factors were found to explain little or no variance so remove_inactive_factors option has been disabled.", length(r2)))
    } else if (any(r2 < var.threshold)) {
      object <- subset_factors(object, which(r2>=var.threshold), recalculate_variance_explained=FALSE)
      message(sprintf("%s factors were found to explain no variance and they were removed for downstream analysis. You can disable this option by setting load_model(..., remove_inactive_factors = FALSE)", sum(r2 < var.threshold)))
    }
  }
  
  # [Done in mofapy2] Sort factors by total variance explained
  if (sort_factors && object@dimensions$K>1) {

    # Sanity checks
    if (verbose) message("Re-ordering factors by their variance explained...")

    # Calculate variance explained per factor across all views
    r2 <- rowSums(sapply(object@cache[["variance_explained"]]$r2_per_factor, function(e) rowSums(e, na.rm = TRUE)))
    order_factors <- c(names(r2)[order(r2, decreasing = TRUE)])

    # re-order factors
    object <- subset_factors(object, order_factors)
  }

  # Mask outliers
  if (remove_outliers) {
    if (verbose) message("Removing outliers...")
    object <- .detect_outliers(object)
  }
  
  # Mask intercepts for non-Gaussian data
  if (any(object@model_options$likelihoods!="gaussian")) {
    for (m in names(which(object@model_options$likelihoods!="gaussian"))) {
      for (g in names(object@intercepts[[m]])) {
        object@intercepts[[m]][[g]] <- NA
      }
    }
  }

  ######################
  ## Quality controls ##
  ######################

  if (verbose) message("Doing quality control...")
  object <- .quality_control(object, verbose = verbose)
  
  return(object)
}


================================================
FILE: R/make_example_data.R
================================================

#' @title Simulate a data set using the generative model of MOFA
#' @name make_example_data
#' @description Function to simulate an example multi-view multi-group data set according to the generative model of MOFA2.
#' @param n_views number of views
#' @param n_features number of features in each view 
#' @param n_samples number of samples in each group
#' @param n_groups number of groups
#' @param n_factors number of factors
#' @param likelihood likelihood for each view, one of "gaussian" (default), "bernoulli", "poisson",
#'  or a character vector of length n_views
#' @param lscales vector of lengthscales, needs to be of length n_factors (default is 1; note that no smooth factors are simulated unless sample_cov is provided, and setting all lscales to 0 also disables smooth factors)
#' @param sample_cov (only for use with MEFISTO) matrix of sample covariates for one group with covariates in rows and samples in columns 
#' or "equidistant" for sequential ordering, default is NULL (no smooth factors)
#' @param as.data.frame logical indicating whether to return the data and covariates as a long data frame (default is FALSE)
#' @return Returns a list containing the simulated data and simulation parameters.
#' @importFrom stats rnorm rbinom rpois
#' @importFrom dplyr left_join
#' @importFrom stats dist
#' @export
#' @examples
#' # Generate a simulated data set
#' MOFAexample <- make_example_data()


make_example_data <- function(n_views=3, n_features=100, n_samples = 50, n_groups = 1,
                            n_factors = 5, likelihood = "gaussian",
                            lscales = 1, sample_cov = NULL, as.data.frame = FALSE) {
  
  # Sanity checks
  if (!all(likelihood %in% c("gaussian", "bernoulli", "poisson")))
    stop("Likelihood not implemented: Use either gaussian, bernoulli or poisson")
  
  if(length(lscales) == 1)
    lscales = rep(lscales, n_factors)
  if(!length(lscales) == n_factors)
    stop("Lengthscales lscales need to be of length n_factors")
  if(all(lscales == 0)){
    sample_cov <- NULL
  }
  
  if (length(likelihood)==1) likelihood <- rep(likelihood, n_views) 
  if (!length(likelihood) == n_views) 
    stop("Likelihood needs to be a single string or matching the number of views!")
  
  if(!is.null(sample_cov)){
    if(sample_cov[1] == "equidistant") {
      sample_cov <- seq_len(n_samples)
    }
    if(is.null(dim(sample_cov))) sample_cov <- matrix(sample_cov, nrow = 1)
    if(ncol(sample_cov) != n_samples){
      stop("Number of columns in sample_cov must match number of samples n_samples.")
    }
  
    # Simulate covariance for factors
    Sigma = lapply(lscales, function(ls) {
      if(ls == 0) diag(1, n_samples)
      else (1) * exp(-as.matrix(stats::dist(t(sample_cov)))^2/(2*ls^2))
      # else (1-0.001) * exp(-as.matrix(stats::dist(t(sample_cov)))^2/(2*ls^2)) + diag(0.001, n_samples)
    })
  
    # simulate factors
    alpha_z <- NULL
    S_z <- lapply(seq_len(n_groups), function(vw) matrix(1, nrow=n_samples, ncol=n_factors))
    Z <-  vapply(seq_len(n_factors), function(fc) mvtnorm::rmvnorm(1, rep(0, n_samples), Sigma[[fc]]), numeric(n_samples))
    colnames(Z) <- paste0("simulated_factor_", 1:ncol(Z))
    Z <- lapply(seq_len(n_groups), function(gr) Z)
    sample_cov <- Reduce(cbind, lapply(seq_len(n_groups), function(gr) sample_cov))
  } else {
    # set sparsity for factors
    theta_z <- 0.5
    
    # set ARD prior for factors, each factor being active in at least one group
    alpha_z <- vapply(seq_len(n_factors), function(fc) {
      active_gw <- sample(seq_len(n_groups), 1)
      alpha_fc <- sample(c(1, 1000), n_groups, replace = TRUE)
      if(all(alpha_fc==1000)) alpha_fc[active_gw] <- 1
      alpha_fc
    }, numeric(n_groups))
    alpha_z <- matrix(alpha_z, nrow=n_factors, ncol=n_groups, byrow=TRUE)
    
    # simulate factors 
    S_z <- lapply(seq_len(n_groups), function(vw) matrix(rbinom(n_samples * n_factors, 1, theta_z),
                                                         nrow=n_samples, ncol=n_factors))
    Z <- lapply(seq_len(n_groups), function(vw) vapply(seq_len(n_factors), function(fc) rnorm(n_samples, 0, sqrt(1/alpha_z[fc,vw])), numeric(n_samples)))
  }
  
  # set sparsity for weights
  theta_w <- 0.5
  
  # set ARD prior, each factor being active in at least one view
  alpha_w <- vapply(seq_len(n_factors), function(fc) {
    active_vw <- sample(seq_len(n_views), 1)
    alpha_fc <- sample(c(1, 1000), n_views, replace = TRUE)
    if(all(alpha_fc==1000)) alpha_fc[active_vw] <- 1
    alpha_fc
  }, numeric(n_views))
  alpha_w <- matrix(alpha_w, nrow=n_factors, ncol=n_views, byrow=TRUE)
  
  # simulate weights 
  S_w <- lapply(seq_len(n_views), function(vw) matrix(rbinom(n_features*n_factors, 1, theta_w),
                                             nrow=n_features, ncol=n_factors))
  W <- lapply(seq_len(n_views), function(vw) vapply(seq_len(n_factors), function(fc) rnorm(n_features, 0, sqrt(1/alpha_w[f
Download .txt
gitextract_72sduoul/

├── .Rbuildignore
├── .gitattributes
├── .gitignore
├── .gitmodules
├── DESCRIPTION
├── Dockerfile
├── LICENSE
├── NAMESPACE
├── R/
│   ├── AllClasses.R
│   ├── AllGenerics.R
│   ├── QC.R
│   ├── basilisk.R
│   ├── calculate_variance_explained.R
│   ├── cluster_samples.R
│   ├── compare_models.R
│   ├── contribution_scores.R
│   ├── correlate_covariates.R
│   ├── create_mofa.R
│   ├── dimensionality_reduction.R
│   ├── enrichment.R
│   ├── get_methods.R
│   ├── imports.R
│   ├── impute.R
│   ├── load_model.R
│   ├── make_example_data.R
│   ├── mefisto.R
│   ├── plot_data.R
│   ├── plot_factors.R
│   ├── plot_weights.R
│   ├── predict.R
│   ├── prepare_mofa.R
│   ├── run_mofa.R
│   ├── set_methods.R
│   ├── subset.R
│   └── utils.R
├── README.md
├── configure
├── configure.win
├── inst/
│   ├── CITATION
│   ├── extdata/
│   │   └── test_data.RData
│   └── scripts/
│       ├── template_script.R
│       ├── template_script.py
│       ├── template_script_dataframe.py
│       └── template_script_matrix.py
├── man/
│   ├── .Rapp.history
│   ├── MOFA.Rd
│   ├── add_mofa_factors_to_seurat.Rd
│   ├── calculate_contribution_scores.Rd
│   ├── calculate_variance_explained.Rd
│   ├── calculate_variance_explained_per_sample.Rd
│   ├── cluster_samples.Rd
│   ├── compare_elbo.Rd
│   ├── compare_factors.Rd
│   ├── correlate_factors_with_covariates.Rd
│   ├── covariates_names.Rd
│   ├── create_mofa.Rd
│   ├── create_mofa_from_MultiAssayExperiment.Rd
│   ├── create_mofa_from_Seurat.Rd
│   ├── create_mofa_from_SingleCellExperiment.Rd
│   ├── create_mofa_from_df.Rd
│   ├── create_mofa_from_matrix.Rd
│   ├── factors_names.Rd
│   ├── features_metadata.Rd
│   ├── features_names.Rd
│   ├── get_covariates.Rd
│   ├── get_data.Rd
│   ├── get_default_data_options.Rd
│   ├── get_default_mefisto_options.Rd
│   ├── get_default_model_options.Rd
│   ├── get_default_stochastic_options.Rd
│   ├── get_default_training_options.Rd
│   ├── get_dimensions.Rd
│   ├── get_elbo.Rd
│   ├── get_expectations.Rd
│   ├── get_factors.Rd
│   ├── get_group_kernel.Rd
│   ├── get_imputed_data.Rd
│   ├── get_interpolated_factors.Rd
│   ├── get_lengthscales.Rd
│   ├── get_scales.Rd
│   ├── get_variance_explained.Rd
│   ├── get_weights.Rd
│   ├── groups_names.Rd
│   ├── impute.Rd
│   ├── interpolate_factors.Rd
│   ├── load_model.Rd
│   ├── make_example_data.Rd
│   ├── pipe.Rd
│   ├── plot_alignment.Rd
│   ├── plot_ascii_data.Rd
│   ├── plot_data_heatmap.Rd
│   ├── plot_data_overview.Rd
│   ├── plot_data_scatter.Rd
│   ├── plot_data_vs_cov.Rd
│   ├── plot_dimred.Rd
│   ├── plot_enrichment.Rd
│   ├── plot_enrichment_detailed.Rd
│   ├── plot_enrichment_heatmap.Rd
│   ├── plot_factor.Rd
│   ├── plot_factor_cor.Rd
│   ├── plot_factors.Rd
│   ├── plot_factors_vs_cov.Rd
│   ├── plot_group_kernel.Rd
│   ├── plot_interpolation_vs_covariate.Rd
│   ├── plot_sharedness.Rd
│   ├── plot_smoothness.Rd
│   ├── plot_top_weights.Rd
│   ├── plot_variance_explained.Rd
│   ├── plot_variance_explained_by_covariates.Rd
│   ├── plot_variance_explained_per_feature.Rd
│   ├── plot_weights.Rd
│   ├── plot_weights_heatmap.Rd
│   ├── plot_weights_scatter.Rd
│   ├── predict.Rd
│   ├── prepare_mofa.Rd
│   ├── run_enrichment.Rd
│   ├── run_mofa.Rd
│   ├── run_tsne.Rd
│   ├── run_umap.Rd
│   ├── samples_metadata.Rd
│   ├── samples_names.Rd
│   ├── select_model.Rd
│   ├── set_covariates.Rd
│   ├── subset_factors.Rd
│   ├── subset_features.Rd
│   ├── subset_groups.Rd
│   ├── subset_samples.Rd
│   ├── subset_views.Rd
│   ├── summarise_factors.Rd
│   └── views_names.Rd
├── setup.py
├── tests/
│   ├── testthat/
│   │   ├── barcodes.tsv
│   │   ├── genes.tsv
│   │   ├── matrix.csv
│   │   ├── matrix.mtx
│   │   ├── test_create_model.R
│   │   ├── test_load_model.R
│   │   ├── test_plot.R
│   │   └── test_prepare_model.R
│   └── testthat.R
└── vignettes/
    ├── MEFISTO_temporal.Rmd
    ├── downstream_analysis.Rmd
    └── getting_started_R.Rmd
Download .txt
SYMBOL INDEX (1 symbols across 1 files)

FILE: setup.py
  function setup_package (line 8) | def setup_package():
Condensed preview — 143 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (773K chars).
[
  {
    "path": ".Rbuildignore",
    "chars": 56,
    "preview": "^.*\\.Rproj$\n^\\.Rproj\\.user$\nmofapy2\nDockerfile\nsetup.py\n"
  },
  {
    "path": ".gitattributes",
    "chars": 17,
    "preview": "*.sh text eol=lf\n"
  },
  {
    "path": ".gitignore",
    "chars": 410,
    "preview": "# Resilio Sync\n.sync\n\n# MAC\n*Icon*\n.DS_Store\n\n# Rstudio projects\n*.Rproj\n.Rhistory\n\n*_site/\n# Pycharm\n.idea\n\n# HTML\n# *."
  },
  {
    "path": ".gitmodules",
    "chars": 75,
    "preview": "[submodule \"mofapy2\"]\n\tpath = mofapy2\n\turl = git@github.com:bioFAM/mofapy2\n"
  },
  {
    "path": "DESCRIPTION",
    "chars": 2367,
    "preview": "Package: MOFA2\nType: Package\nTitle: Multi-Omics Factor Analysis v2\nVersion: 1.21.3\nMaintainer: Ricard Argelaguet <ricard"
  },
  {
    "path": "Dockerfile",
    "chars": 1076,
    "preview": "FROM r-base:4.0.2\n\nWORKDIR /mofa2\nADD . /mofa2\n\nRUN apt-get update && apt-get install -f && apt-get install -y python3 p"
  },
  {
    "path": "LICENSE",
    "chars": 7652,
    "preview": "                   GNU LESSER GENERAL PUBLIC LICENSE\n                       Version 3, 29 June 2007\n\n Copyright (C) 2007"
  },
  {
    "path": "NAMESPACE",
    "chars": 4586,
    "preview": "# Generated by roxygen2: do not edit by hand\n\nexport(\"%>%\")\nexport(\"covariates_names<-\")\nexport(\"factors_names<-\")\nexpor"
  },
  {
    "path": "R/AllClasses.R",
    "chars": 5735,
    "preview": "\n##########################################################\n## Define a general class to store a MOFA trained model ##\n#"
  },
  {
    "path": "R/AllGenerics.R",
    "chars": 3607,
    "preview": "\n##################\n## Factor Names ##\n##################\n\n#' @title factors_names: set and retrieve factor names\n#' @na"
  },
  {
    "path": "R/QC.R",
    "chars": 6858,
    "preview": "#' @importFrom stringi stri_enc_mark\n.quality_control <- function(object, verbose = FALSE) {\n  \n  # Sanity checks\n  if ("
  },
  {
    "path": "R/basilisk.R",
    "chars": 526,
    "preview": "# .mofapy2_dependencies <- c(\n#     \"h5py==3.1.0\",\n#     \"pandas==1.2.1\",\n#     \"scikit-learn==0.24.1\",\n#     \"dtw-pytho"
  },
  {
    "path": "R/calculate_variance_explained.R",
    "chars": 22376,
    "preview": "#' @title Calculate variance explained by the model\n#' @description  This function takes a trained MOFA model as input a"
  },
  {
    "path": "R/cluster_samples.R",
    "chars": 2451,
    "preview": "\n##########################################################\n## Functions to cluster samples based on latent factors ##\n#"
  },
  {
    "path": "R/compare_models.R",
    "chars": 5566,
    "preview": "\n################################################\n## Functions to compare different MOFA models ##\n#####################"
  },
  {
    "path": "R/contribution_scores.R",
    "chars": 6353,
    "preview": "#' @title Calculate contribution scores for each view in each sample\n#' @description This function calculates, *for each"
  },
  {
    "path": "R/correlate_covariates.R",
    "chars": 7689,
    "preview": "#' @title Plot correlation of factors with external covariates\n#' @name correlate_factors_with_covariates\n#' @descriptio"
  },
  {
    "path": "R/create_mofa.R",
    "chars": 27926,
    "preview": "\n#' @title create a MOFA object\n#' @name create_mofa\n#' @description Method to create a \\code{\\link{MOFA}} object. Depen"
  },
  {
    "path": "R/dimensionality_reduction.R",
    "chars": 12842,
    "preview": "\n##################################################################\n## Functions to do dimensionality reduction on the M"
  },
  {
    "path": "R/enrichment.R",
    "chars": 26108,
    "preview": "##########################################################\n## Functions to perform Feature Set Enrichment Analysis ##\n##"
  },
  {
    "path": "R/get_methods.R",
    "chars": 27396,
    "preview": "\n################################################\n## Get functions to fetch data from the model ##\n#####################"
  },
  {
    "path": "R/imports.R",
    "chars": 345,
    "preview": "#' Re-exporting the pipe operator\n#' See \\code{magrittr::\\link[magrittr]{\\%>\\%}} for details.\n#'\n#' @name %>%\n#' @rdname"
  },
  {
    "path": "R/impute.R",
    "chars": 2997,
    "preview": "\n#######################################################\n## Functions to perform imputation of missing values ##\n#######"
  },
  {
    "path": "R/load_model.R",
    "chars": 17155,
    "preview": "\n############################################\n## Functions to load a trained MOFA model ##\n#############################"
  },
  {
    "path": "R/make_example_data.R",
    "chars": 6625,
    "preview": "\n#' @title Simulate a data set using the generative model of MOFA\n#' @name make_example_data\n#' @description Function to"
  },
  {
    "path": "R/mefisto.R",
    "chars": 50133,
    "preview": "##########################################################################\n## Functions to use continuous covariates, as"
  },
  {
    "path": "R/plot_data.R",
    "chars": 25846,
    "preview": "###########################################\n## Functions to visualise the input data ##\n################################"
  },
  {
    "path": "R/plot_factors.R",
    "chars": 21842,
    "preview": "\n###########################################\n## Functions to visualise latent factors ##\n###############################"
  },
  {
    "path": "R/plot_weights.R",
    "chars": 24609,
    "preview": "########################################\n## Functions to visualise the weights ##\n######################################"
  },
  {
    "path": "R/predict.R",
    "chars": 3180,
    "preview": "\n######################################\n## Functions to perform predictions ##\n######################################\n\n#"
  },
  {
    "path": "R/prepare_mofa.R",
    "chars": 26193,
    "preview": "\n#######################################################\n## Functions to prepare a MOFA object for training ##\n#########"
  },
  {
    "path": "R/run_mofa.R",
    "chars": 11972,
    "preview": "#######################################\n## Functions to train a MOFA model ##\n#######################################\n\n#"
  },
  {
    "path": "R/set_methods.R",
    "chars": 34281,
    "preview": "\n\n####################################\n## Set and retrieve factors names ##\n####################################\n\n#' @rd"
  },
  {
    "path": "R/subset.R",
    "chars": 16340,
    "preview": "\n################################\n## Functions to do subsetting ##\n################################\n\n#' @title Subset gr"
  },
  {
    "path": "R/utils.R",
    "chars": 24845,
    "preview": "\n# Function to find \"intercept\" factors\n# .detectInterceptFactors <- function(object, cor_threshold = 0.75) {\n#   \n#   #"
  },
  {
    "path": "README.md",
    "chars": 289,
    "preview": "\n# Multi-Omics Factor Analysis\n\nMOFA is a factor analysis model that provides a general framework for the integration of"
  },
  {
    "path": "configure",
    "chars": 70,
    "preview": "#!/bin/sh\n\n${R_HOME}/bin/Rscript -e \"basilisk::configureBasiliskEnv()\""
  },
  {
    "path": "configure.win",
    "chars": 87,
    "preview": "#!/bin/sh\n\n${R_HOME}/bin${R_ARCH_BIN}/Rscript.exe -e \"basilisk::configureBasiliskEnv()\""
  },
  {
    "path": "inst/CITATION",
    "chars": 2807,
    "preview": "citEntry(entry=\"article\",\n         title = \"Multi‐Omics Factor Analysis—a framework for unsupervised integration of mult"
  },
  {
    "path": "inst/scripts/template_script.R",
    "chars": 2643,
    "preview": "library(MOFA2)\nlibrary(data.table)\n\n# (Optional) set up reticulate connection with Python\n# library(reticulate)\n# reticu"
  },
  {
    "path": "inst/scripts/template_script.py",
    "chars": 5275,
    "preview": "\n######################################################\n## Template script to train a MOFA+ model in Python ##\n#########"
  },
  {
    "path": "inst/scripts/template_script_dataframe.py",
    "chars": 3460,
    "preview": "\n######################################################\n## Template script to train a MOFA+ model in Python ##\n#########"
  },
  {
    "path": "inst/scripts/template_script_matrix.py",
    "chars": 4745,
    "preview": "\n######################################################\n## Template script to train a MOFA+ model in Python ##\n#########"
  },
  {
    "path": "man/.Rapp.history",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "man/MOFA.Rd",
    "chars": 1822,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/AllClasses.R\n\\docType{class}\n\\name{MOFA}\n\\"
  },
  {
    "path": "man/add_mofa_factors_to_seurat.Rd",
    "chars": 1102,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/utils.R\n\\name{add_mofa_factors_to_seurat}\n"
  },
  {
    "path": "man/calculate_contribution_scores.Rd",
    "chars": 2315,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/contribution_scores.R\n\\name{calculate_cont"
  },
  {
    "path": "man/calculate_variance_explained.Rd",
    "chars": 1701,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/calculate_variance_explained.R\n\\name{calcu"
  },
  {
    "path": "man/calculate_variance_explained_per_sample.Rd",
    "chars": 1357,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/calculate_variance_explained.R\n\\name{calcu"
  },
  {
    "path": "man/cluster_samples.Rd",
    "chars": 1573,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/cluster_samples.R\n\\name{cluster_samples}\n\\"
  },
  {
    "path": "man/compare_elbo.Rd",
    "chars": 1132,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/compare_models.R\n\\name{compare_elbo}\n\\alia"
  },
  {
    "path": "man/compare_factors.Rd",
    "chars": 1016,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/compare_models.R\n\\name{compare_factors}\n\\a"
  },
  {
    "path": "man/correlate_factors_with_covariates.Rd",
    "chars": 2147,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/correlate_covariates.R\n\\name{correlate_fac"
  },
  {
    "path": "man/covariates_names.Rd",
    "chars": 947,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/AllGenerics.R, R/set_methods.R\n\\name{covar"
  },
  {
    "path": "man/create_mofa.Rd",
    "chars": 1684,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/create_mofa.R\n\\name{create_mofa}\n\\alias{cr"
  },
  {
    "path": "man/create_mofa_from_MultiAssayExperiment.Rd",
    "chars": 911,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/create_mofa.R\n\\name{create_mofa_from_Multi"
  },
  {
    "path": "man/create_mofa_from_Seurat.Rd",
    "chars": 1181,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/create_mofa.R\n\\name{create_mofa_from_Seura"
  },
  {
    "path": "man/create_mofa_from_SingleCellExperiment.Rd",
    "chars": 990,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/create_mofa.R\n\\name{create_mofa_from_Singl"
  },
  {
    "path": "man/create_mofa_from_df.Rd",
    "chars": 1119,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/create_mofa.R\n\\name{create_mofa_from_df}\n\\"
  },
  {
    "path": "man/create_mofa_from_matrix.Rd",
    "chars": 839,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/create_mofa.R\n\\name{create_mofa_from_matri"
  },
  {
    "path": "man/factors_names.Rd",
    "chars": 860,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/AllGenerics.R, R/set_methods.R\n\\name{facto"
  },
  {
    "path": "man/features_metadata.Rd",
    "chars": 989,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/AllGenerics.R, R/set_methods.R\n\\name{featu"
  },
  {
    "path": "man/features_names.Rd",
    "chars": 923,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/AllGenerics.R, R/set_methods.R\n\\name{featu"
  },
  {
    "path": "man/get_covariates.Rd",
    "chars": 1071,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/mefisto.R\n\\name{get_covariates}\n\\alias{get"
  },
  {
    "path": "man/get_data.Rd",
    "chars": 2414,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/get_methods.R\n\\name{get_data}\n\\alias{get_d"
  },
  {
    "path": "man/get_default_data_options.Rd",
    "chars": 1867,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/prepare_mofa.R\n\\name{get_default_data_opti"
  },
  {
    "path": "man/get_default_mefisto_options.Rd",
    "chars": 2994,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/mefisto.R\n\\name{get_default_mefisto_option"
  },
  {
    "path": "man/get_default_model_options.Rd",
    "chars": 2141,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/prepare_mofa.R\n\\name{get_default_model_opt"
  },
  {
    "path": "man/get_default_stochastic_options.Rd",
    "chars": 2131,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/prepare_mofa.R\n\\name{get_default_stochasti"
  },
  {
    "path": "man/get_default_training_options.Rd",
    "chars": 2735,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/prepare_mofa.R\n\\name{get_default_training_"
  },
  {
    "path": "man/get_dimensions.Rd",
    "chars": 751,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/get_methods.R\n\\name{get_dimensions}\n\\alias"
  },
  {
    "path": "man/get_elbo.Rd",
    "chars": 587,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/get_methods.R\n\\name{get_elbo}\n\\alias{get_e"
  },
  {
    "path": "man/get_expectations.Rd",
    "chars": 1917,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/get_methods.R\n\\name{get_expectations}\n\\ali"
  },
  {
    "path": "man/get_factors.Rd",
    "chars": 1513,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/get_methods.R\n\\name{get_factors}\n\\alias{ge"
  },
  {
    "path": "man/get_group_kernel.Rd",
    "chars": 534,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/get_methods.R\n\\name{get_group_kernel}\n\\ali"
  },
  {
    "path": "man/get_imputed_data.Rd",
    "chars": 1332,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/get_methods.R\n\\name{get_imputed_data}\n\\ali"
  },
  {
    "path": "man/get_interpolated_factors.Rd",
    "chars": 1136,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/get_methods.R\n\\name{get_interpolated_facto"
  },
  {
    "path": "man/get_lengthscales.Rd",
    "chars": 714,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/get_methods.R\n\\name{get_lengthscales}\n\\ali"
  },
  {
    "path": "man/get_scales.Rd",
    "chars": 671,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/get_methods.R\n\\name{get_scales}\n\\alias{get"
  },
  {
    "path": "man/get_variance_explained.Rd",
    "chars": 1390,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/get_methods.R\n\\name{get_variance_explained"
  },
  {
    "path": "man/get_weights.Rd",
    "chars": 1723,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/get_methods.R\n\\name{get_weights}\n\\alias{ge"
  },
  {
    "path": "man/groups_names.Rd",
    "chars": 916,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/AllGenerics.R, R/set_methods.R\n\\name{group"
  },
  {
    "path": "man/impute.Rd",
    "chars": 1925,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/impute.R\n\\name{impute}\n\\alias{impute}\n\\tit"
  },
  {
    "path": "man/interpolate_factors.Rd",
    "chars": 1212,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/mefisto.R\n\\name{interpolate_factors}\n\\alia"
  },
  {
    "path": "man/load_model.Rd",
    "chars": 1842,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/load_model.R\n\\name{load_model}\n\\alias{load"
  },
  {
    "path": "man/make_example_data.Rd",
    "chars": 1454,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/make_example_data.R\n\\name{make_example_dat"
  },
  {
    "path": "man/pipe.Rd",
    "chars": 501,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/imports.R\n\\name{\\%>\\%}\n\\alias{\\%>\\%}\n\\titl"
  },
  {
    "path": "man/plot_alignment.Rd",
    "chars": 661,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/mefisto.R\n\\name{plot_alignment}\n\\alias{plo"
  },
  {
    "path": "man/plot_ascii_data.Rd",
    "chars": 746,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/plot_data.R\n\\name{plot_ascii_data}\n\\alias{"
  },
  {
    "path": "man/plot_data_heatmap.Rd",
    "chars": 3194,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/plot_data.R\n\\name{plot_data_heatmap}\n\\alia"
  },
  {
    "path": "man/plot_data_overview.Rd",
    "chars": 1509,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/plot_data.R\n\\name{plot_data_overview}\n\\ali"
  },
  {
    "path": "man/plot_data_scatter.Rd",
    "chars": 3487,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/plot_data.R\n\\name{plot_data_scatter}\n\\alia"
  },
  {
    "path": "man/plot_data_vs_cov.Rd",
    "chars": 4215,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/mefisto.R\n\\name{plot_data_vs_cov}\n\\alias{p"
  },
  {
    "path": "man/plot_dimred.Rd",
    "chars": 3369,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/dimensionality_reduction.R\n\\name{plot_dimr"
  },
  {
    "path": "man/plot_enrichment.Rd",
    "chars": 853,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/enrichment.R\n\\name{plot_enrichment}\n\\alias"
  },
  {
    "path": "man/plot_enrichment_detailed.Rd",
    "chars": 1317,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/enrichment.R\n\\name{plot_enrichment_detaile"
  },
  {
    "path": "man/plot_enrichment_heatmap.Rd",
    "chars": 950,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/enrichment.R\n\\name{plot_enrichment_heatmap"
  },
  {
    "path": "man/plot_factor.Rd",
    "chars": 4624,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/plot_factors.R\n\\name{plot_factor}\n\\alias{p"
  },
  {
    "path": "man/plot_factor_cor.Rd",
    "chars": 1501,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/plot_factors.R\n\\name{plot_factor_cor}\n\\ali"
  },
  {
    "path": "man/plot_factors.Rd",
    "chars": 3026,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/plot_factors.R\n\\name{plot_factors}\n\\alias{"
  },
  {
    "path": "man/plot_factors_vs_cov.Rd",
    "chars": 3628,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/mefisto.R\n\\name{plot_factors_vs_cov}\n\\alia"
  },
  {
    "path": "man/plot_group_kernel.Rd",
    "chars": 1437,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/mefisto.R\n\\name{plot_group_kernel}\n\\alias{"
  },
  {
    "path": "man/plot_interpolation_vs_covariate.Rd",
    "chars": 1141,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/mefisto.R\n\\name{plot_interpolation_vs_cova"
  },
  {
    "path": "man/plot_sharedness.Rd",
    "chars": 844,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/mefisto.R\n\\name{plot_sharedness}\n\\alias{pl"
  },
  {
    "path": "man/plot_smoothness.Rd",
    "chars": 955,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/mefisto.R\n\\name{plot_smoothness}\n\\alias{pl"
  },
  {
    "path": "man/plot_top_weights.Rd",
    "chars": 1909,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/plot_weights.R\n\\name{plot_top_weights}\n\\al"
  },
  {
    "path": "man/plot_variance_explained.Rd",
    "chars": 2439,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/calculate_variance_explained.R\n\\name{plot_"
  },
  {
    "path": "man/plot_variance_explained_by_covariates.Rd",
    "chars": 2184,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/mefisto.R\n\\name{plot_variance_explained_by"
  },
  {
    "path": "man/plot_variance_explained_per_feature.Rd",
    "chars": 1879,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/calculate_variance_explained.R\n\\name{plot_"
  },
  {
    "path": "man/plot_weights.Rd",
    "chars": 3557,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/plot_weights.R\n\\name{plot_weights}\n\\alias{"
  },
  {
    "path": "man/plot_weights_heatmap.Rd",
    "chars": 1562,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/plot_weights.R\n\\name{plot_weights_heatmap}"
  },
  {
    "path": "man/plot_weights_scatter.Rd",
    "chars": 2584,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/plot_weights.R\n\\name{plot_weights_scatter}"
  },
  {
    "path": "man/predict.Rd",
    "chars": 1537,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/predict.R\n\\name{predict}\n\\alias{predict}\n\\"
  },
  {
    "path": "man/prepare_mofa.Rd",
    "chars": 2386,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/prepare_mofa.R\n\\name{prepare_mofa}\n\\alias{"
  },
  {
    "path": "man/run_enrichment.Rd",
    "chars": 3218,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/enrichment.R\n\\name{run_enrichment}\n\\alias{"
  },
  {
    "path": "man/run_mofa.Rd",
    "chars": 2033,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/run_mofa.R\n\\name{run_mofa}\n\\alias{run_mofa"
  },
  {
    "path": "man/run_tsne.Rd",
    "chars": 1574,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/dimensionality_reduction.R\n\\name{run_tsne}"
  },
  {
    "path": "man/run_umap.Rd",
    "chars": 2446,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/dimensionality_reduction.R\n\\name{run_umap}"
  },
  {
    "path": "man/samples_metadata.Rd",
    "chars": 1031,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/AllGenerics.R, R/set_methods.R\n\\name{sampl"
  },
  {
    "path": "man/samples_names.Rd",
    "chars": 909,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/AllGenerics.R, R/set_methods.R\n\\name{sampl"
  },
  {
    "path": "man/select_model.Rd",
    "chars": 663,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/compare_models.R\n\\name{select_model}\n\\alia"
  },
  {
    "path": "man/set_covariates.Rd",
    "chars": 1326,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/mefisto.R\n\\name{set_covariates}\n\\alias{set"
  },
  {
    "path": "man/subset_factors.Rd",
    "chars": 848,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/subset.R\n\\name{subset_factors}\n\\alias{subs"
  },
  {
    "path": "man/subset_features.Rd",
    "chars": 591,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/subset.R\n\\name{subset_features}\n\\alias{sub"
  },
  {
    "path": "man/subset_groups.Rd",
    "chars": 710,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/subset.R\n\\name{subset_groups}\n\\alias{subse"
  },
  {
    "path": "man/subset_samples.Rd",
    "chars": 662,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/subset.R\n\\name{subset_samples}\n\\alias{subs"
  },
  {
    "path": "man/subset_views.Rd",
    "chars": 698,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/subset.R\n\\name{subset_views}\n\\alias{subset"
  },
  {
    "path": "man/summarise_factors.Rd",
    "chars": 1200,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/correlate_covariates.R\n\\name{summarise_fac"
  },
  {
    "path": "man/views_names.Rd",
    "chars": 898,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/AllGenerics.R, R/set_methods.R\n\\name{views"
  },
  {
    "path": "setup.py",
    "chars": 964,
    "preview": "import sys\nimport os\nfrom setuptools import setup\nfrom setuptools import find_packages\n\nexec(open(os.path.join(os.path.d"
  },
  {
    "path": "tests/testthat/barcodes.tsv",
    "chars": 2414,
    "preview": "CACCGGGACGTGTA-1\nCGTGTAGAGTTCAG-1\nCCTGCAACACGTTG-1\nCCGATAGACCTAAG-1\nGATATAACACGCAT-1\nTACTACTGATGTCG-1\nAGCCTCACTGTCAG-1\nC"
  },
  {
    "path": "tests/testthat/genes.tsv",
    "chars": 5045,
    "preview": "ENSGXXXXXX\tTDRG1\nENSGXXXXXX\tTCTE1\nENSGXXXXXX\tCCDC106\nENSGXXXXXX\tTIGD6\nENSGXXXXXX\tMSANTD3-TMEFF1\nENSGXXXXXX\tQRSL1\nENSGXXX"
  },
  {
    "path": "tests/testthat/matrix.csv",
    "chars": 96074,
    "preview": "V1,V2,V3,V4,V5,V6,V7,V8,V9,V10,V11,V12,V13,V14,V15,V16,V17,V18,V19,V20,V21,V22,V23,V24,V25,V26,V27,V28,V29,V30,V31,V32,V"
  },
  {
    "path": "tests/testthat/matrix.mtx",
    "chars": 9620,
    "preview": "%%MatrixMarket matrix coordinate integer general\n%\n252 142 1059\n42 1 1\n43 1 1\n87 1 1\n92 1 37\n217 1 1\n233 1 1\n92 2 11\n160"
  },
  {
    "path": "tests/testthat/test_create_model.R",
    "chars": 2580,
    "preview": "context(\"Creating the model from different objects\")\nlibrary(MOFA2)\n\ntest_that(\"a model can be created from a list of ma"
  },
  {
    "path": "tests/testthat/test_load_model.R",
    "chars": 220,
    "preview": "context(\"Loading the model\")\nlibrary(MOFA2)\n\ntest_that(\"a pre-trained model can be loaded from disk\", {\n  filepath <- sy"
  },
  {
    "path": "tests/testthat/test_plot.R",
    "chars": 1395,
    "preview": "context(\"Making plots\")\nlibrary(MOFA2)\n\nfilepath <- system.file(\"extdata\", \"model.hdf5\", package = \"MOFA2\")\ntest_mofa2 <"
  },
  {
    "path": "tests/testthat/test_prepare_model.R",
    "chars": 2174,
    "preview": "context(\"Prepare the model from different objects\")\nlibrary(MOFA2)\n\n\ntest_that(\"a MOFA model can be prepared from a list"
  },
  {
    "path": "tests/testthat.R",
    "chars": 111,
    "preview": "library(testthat)\nlibrary(MOFA2)\n\ntest_check(\"MOFA2\")\n\n# setwd(\"/Users/rargelaguet/mofa/MOFA2/tests/testthat\")\n"
  },
  {
    "path": "vignettes/MEFISTO_temporal.Rmd",
    "chars": 7181,
    "preview": "---\ntitle: \"Illustration of MEFISTO on simulated data with a temporal covariate\"\nauthor:\n- name: \"Britta Velten\"\n  affil"
  },
  {
    "path": "vignettes/downstream_analysis.Rmd",
    "chars": 9221,
    "preview": "---\ntitle: \"MOFA+: downstream analysis in R\"\nauthor:\n- name: \"Ricard Argelaguet\"\n  affiliation: \"European Bioinformatics"
  },
  {
    "path": "vignettes/getting_started_R.Rmd",
    "chars": 8169,
    "preview": "---\ntitle: \"MOFA2: training a model in R\"\nauthor:\n- name: \"Ricard Argelaguet\"\n  affiliation: \"European Bioinformatics In"
  }
]

// ... and 1 more file (download for full content)

About this extraction

This page contains the full source code of the bioFAM/MOFA2 GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 143 files (721.6 KB), approximately 220.9k tokens, and a symbol index with 1 extracted symbol (covering functions, classes, methods, constants, and types). Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — a free GitHub-repository-to-text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!