Full Code of facebookresearch/CPA for AI

main 50e283a7a1b3 cached
30 files
36.5 MB
755.7k tokens
103 symbols
1 requests
Download .txt
Showing preview only (3,023K chars total). Download the full file or copy to clipboard to get everything.
Repository: facebookresearch/CPA
Branch: main
Commit: 50e283a7a1b3
Files: 30
Total size: 36.5 MB

Directory structure:
gitextract_97iggw1v/

├── .gitignore
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── cpa/
│   ├── __init__.py
│   ├── api.py
│   ├── data.py
│   ├── helper.py
│   ├── model.py
│   ├── plotting.py
│   └── train.py
├── datasets/
│   └── .gitkeep
├── notebooks/
│   └── demo.ipynb
├── preprocessing/
│   ├── GSM.ipynb
│   ├── Norman19.ipynb
│   ├── cross_species.ipynb
│   ├── lincs.ipynb
│   ├── pachter.ipynb
│   ├── sciplex3.ipynb
│   └── sciplex3_round_robin.ipynb
├── pretrained_models/
│   ├── .gitattributes
│   └── .gitkeep
├── requirements.txt
├── scripts/
│   ├── .gitkeep
│   ├── run_collect_results.sh
│   ├── run_one_epoch.sh
│   └── run_sweeps.sh
├── setup.py
└── tests/
    └── test.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
__pycache__
*.pyc
*.egg-info/
*.ipynb_checkpoints/
*.pt

================================================
FILE: CODE_OF_CONDUCT.md
================================================
# Code of Conduct

Facebook has adopted a Code of Conduct that we expect project participants to adhere to.
Please read the [full text](https://code.fb.com/codeofconduct/)
so that you can understand what actions will and will not be tolerated.


================================================
FILE: CONTRIBUTING.md
================================================
# Contributing to `CPA` 
We want to make contributing to this project as easy and transparent as
possible.

## Pull Requests
We actively welcome your pull requests.

1. Fork the repo and create your branch from `master`.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes.
5. Make sure your code lints.
6. If you haven't already, complete the Contributor License Agreement ("CLA").

## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Facebook's open source projects.

Complete your CLA here: <https://code.facebook.com/cla>

## Issues
We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.

Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
disclosure of security bugs. In those cases, please go through the process
outlined on that page and do not file a public issue.

## License
By contributing to `CPA`, you agree that your contributions
will be licensed under the LICENSE file in the root directory of this source
tree.


================================================
FILE: LICENSE
================================================
The MIT License

Copyright (c) Facebook, Inc. and its affiliates.

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.


================================================
FILE: README.md
================================================
# CPA - Compositional Perturbation Autoencoder

# This code in not being maintained anymore, please use the new implementation [here](https://github.com/theislab/cpa).


## What is CPA?
![Screenshot](Figure1.png)

`CPA` is a framework to learn effects of perturbations at the single-cell level. CPA encodes and learns phenotypic drug response across different cell types, doses and drug combinations. CPA allows:

* Out-of-distribution predicitons of unseen drug combinations at various doses and among different cell types.
* Learn interpretable drug and cell type latent spaces.
* Estimate dose response curve for each perturbation and their combinations.
* Access the uncertainty of the estimations of the model.

## Package Structure

The repository is centered around the `cpa` module:

* [`cpa.train`](cpa/train.py) contains scripts to train the model.
* [`cpa.api`](cpa/api.py) contains user friendly scripts to interact with the model via scanpy.
* [`cpa.plotting`](cpa/plotting.py) contains scripts to plotting functions.
* [`cpa.model`](cpa/model.py) contains modules of cpa model.
* [`cpa.data`](cpa/data.py) contains data loader, which transforms anndata structure to a class compatible with cpa model.

Additional files and folders:

* [`datasets`](datasets/) contains both versions of the data: raw and pre-processed.
* [`preprocessing`](preprocessing/) contains notebooks to reproduce the datasets pre-processing from raw data.

## Usage

- As a first step, download the contents of `datasets/` and `pretrained_models/` from [this tarball](https://dl.fbaipublicfiles.com/dlp/cpa_binaries.tar).


To learn how to use this repository, check 
[`./notebooks/demo.ipynb`](notebooks/demo.ipynb), and the following scripts:


* Note that hyperparameters in the `demo.ipynb` are set as default but might not work work for new datasets.
## Examples and Reproducibility
you can find more example and  hyperparamters tuning scripts and also reproducbility notebooks for the plots in the paper in the [`reproducibility`](https://github.com/theislab/cpa-reproducibility) repo.

## Curation of your own data to train CPA

* To prepare your data to train CPA, you need to add specific fields to adata object and perfrom data split. Examples on how to add 
necessary fields for multiple datasets used in the paper can be found in [`preprocessing/`](/https://github.com/facebookresearch/CPA/tree/master/preprocessing) folder.

## Training a model

There are two ways to train a cpa model:

* Using the command line, e.g.: `python -m cpa.train --data datasets/GSM_new.h5ad  --save_dir /tmp --max_epochs 1 --doser_type sigm`
* From jupyter notebook: example in [`./notebooks/demo.ipynb`](notebooks/demo.ipynb)


## Documentation

Currently you can access the documentation via `help` function in IPython. For example:

```python
from cpa.api import API

help(API)

from cpa.plotting import CPAVisuals

help(CPAVisuals)

```

A separate page with the documentation is coming soon.

## Support and contribute

If you have a question or noticed a problem, you can post an [`issue`](https://github.com/facebookresearch/CPA/issues/new).

## Reference

Please cite the following publication if you find CPA useful in your research.
```
@article{lotfollahi2023predicting,
  title={Predicting cellular responses to complex perturbations in high-throughput screens},
  author={Lotfollahi, Mohammad and Klimovskaia Susmelj, Anna and De Donno, Carlo and Hetzel, Leon and Ji, Yuge and Ibarra, Ignacio L and Srivatsan, Sanjay R and Naghipourfar, Mohsen and Daza, Riza M and Martin, Beth and others},
  journal={Molecular Systems Biology},
  pages={e11517},
  year={2023}
}
```

The paper titled **Predicting cellular responses to complex perturbations in high-throughput screens** can be found [here](https://www.biorxiv.org/content/10.1101/2021.04.14.439903v2](https://www.embopress.org/doi/full/10.15252/msb.202211517).
## License

This source code is released under the MIT license, included [here](LICENSE).


================================================
FILE: cpa/__init__.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from cpa.api import API
from cpa.plotting import CPAVisuals


================================================
FILE: cpa/api.py
================================================
import copy
import itertools
import os
import pprint
import time
from collections import defaultdict
from typing import Optional, Union, Tuple

import numpy as np
import pandas as pd
import scanpy as sc
import torch
from torch.distributions import (
    NegativeBinomial,
    Normal
)
from cpa.train import evaluate, prepare_cpa
from cpa.helper import _convert_mean_disp_to_counts_logits
from sklearn.metrics import r2_score
from sklearn.metrics.pairwise import cosine_distances, euclidean_distances
from tqdm import tqdm

class API:
    """
    API for CPA model to make it compatible with scanpy.
    """

    def __init__(
        self,
        data,
        perturbation_key="condition",
        covariate_keys=["cell_type"],
        split_key="split",
        dose_key="dose_val",
        control=None,
        doser_type="mlp",
        decoder_activation="linear",
        loss_ae="gauss",
        patience=200,
        seed=0,
        pretrained=None,
        device="cuda",
        save_dir="/tmp/",  # directory to save the model
        hparams={},
        only_parameters=False,
    ):
        """
        Parameters
        ----------
        data : str or `AnnData`
            AnndData object or a full path to the file in the .h5ad format.
        covariate_keys : list (default: ['cell_type'])
            List of names in the .obs of AnnData that should be used as
            covariates.
        split_key : str (default: 'split')
            Name of the column in .obs of AnnData to use for splitting the
            dataset into train, test and validation.
        perturbation_key : str (default: 'condition')
            Name of the column in .obs of AnnData to use for perturbation
            variable.
        dose_key : str (default: 'dose_val')
            Name of the column in .obs of AnnData to use for continious
            covariate.
        doser_type : str (default: 'mlp')
            Type of the nonlinearity in the latent space for the continious
            covariate encoding: sigm, logsigm, mlp.
        decoder_activation : str (default: 'linear')
            Last layer of the decoder.
        loss_ae : str (default: 'gauss')
            Loss (currently only gaussian loss is supported).
        patience : int (default: 200)
            Patience for early stopping.
        seed : int (default: 0)
            Random seed.
        pretrained : str (default: None)
            Full path to the pretrained model.
        only_parameters : bool (default: False)
            Whether to load only arguments or also weights from pretrained model.
        save_dir : str (default: '/tmp/')
            Folder to save the model.
        device : str (default: 'cpu')
            Device for model computations. If None, will try to use CUDA if
            available.
        hparams : dict (default: {})
            Parameters for the architecture of the CPA model.
        control: str
            Obs columns with booleans that identify control. If it is not provided
            the model will look for them in adata.obs["control"]
        """

        args = locals()
        del args["self"]

        if not (pretrained is None):
            state, self.used_args, self.history = torch.load(
                pretrained, map_location=torch.device(device)
            )
            self.args = self.used_args
            self.args["data"] = data
            self.args["covariate_keys"] = covariate_keys
            self.args["device"] = device
            self.args["control"] = control
            if only_parameters:
                state = None
                print(f"Loaded ARGS of the model from:\t{pretrained}")
            else:
                print(f"Loaded pretrained model from:\t{pretrained}")
        else:
            state = None
            self.args = args

        self.model, self.datasets = prepare_cpa(self.args, state_dict=state)
        if not (pretrained is None) and (not only_parameters):
            self.model.history = self.history
        self.args["save_dir"] = save_dir
        self.args["hparams"] = self.model.hparams

        if not (save_dir is None):
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)

        dataset = self.datasets["training"]
        self.perturbation_key = dataset.perturbation_key
        self.dose_key = dataset.dose_key
        self.covariate_keys = covariate_keys  # very important, specifies the order of
        # covariates during training
        self.min_dose = dataset.drugs[dataset.drugs > 0].min().item()
        self.max_dose = dataset.drugs[dataset.drugs > 0].max().item()

        self.var_names = dataset.var_names

        self.unique_perts = list(dataset.perts_dict.keys())

        self.unique_covars = {}
        for cov in dataset.covars_dict:
            self.unique_covars[cov] = list(dataset.covars_dict[cov].keys())
        self.num_drugs = dataset.num_drugs

        self.perts_dict = dataset.perts_dict
        self.covars_dict = dataset.covars_dict

        self.drug_ohe = torch.Tensor(list(dataset.perts_dict.values()))

        self.covars_ohe = {}
        for cov in dataset.covars_dict:
            self.covars_ohe[cov] = torch.LongTensor(
                list(dataset.covars_dict[cov].values())
            )

        self.emb_covars = {}
        for cov in dataset.covars_dict:
            self.emb_covars[cov] = None
        self.emb_perts = None
        self.seen_covars_perts = None
        self.comb_emb = None
        self.control_cat = None

        self.seen_covars_perts = {}
        for k in self.datasets.keys():
            self.seen_covars_perts[k] = np.unique(self.datasets[k].pert_categories)

        self.measured_points = {}
        self.num_measured_points = {}
        for k in self.datasets.keys():
            self.measured_points[k] = {}
            self.num_measured_points[k] = {}
            for pert in np.unique(self.datasets[k].pert_categories):
                num_points = len(np.where(self.datasets[k].pert_categories == pert)[0])
                self.num_measured_points[k][pert] = num_points

                *cov_list, drug, dose = pert.split("_")
                cov = "_".join(cov_list)
                if not ("+" in dose):
                    dose = float(dose)
                if cov in self.measured_points[k].keys():
                    if drug in self.measured_points[k][cov].keys():
                        self.measured_points[k][cov][drug].append(dose)
                    else:
                        self.measured_points[k][cov][drug] = [dose]
                else:
                    self.measured_points[k][cov] = {drug: [dose]}

        self.measured_points["all"] = copy.deepcopy(self.measured_points["training"])
        for cov in self.measured_points["ood"].keys():
            for pert in self.measured_points["ood"][cov].keys():
                if pert in self.measured_points["training"][cov].keys():
                    self.measured_points["all"][cov][pert] = (
                        self.measured_points["training"][cov][pert].copy()
                        + self.measured_points["ood"][cov][pert].copy()
                    )
                else:
                    self.measured_points["all"][cov][pert] = self.measured_points[
                        "ood"
                    ][cov][pert].copy()

    def load_from_old(self, pretrained):
        """
        Parameters
        ----------
        pretrained : str
            Full path to the pretrained model.
        """
        print(f"Loaded pretrained model from:\t{pretrained}")
        state, self.used_args, self.history = torch.load(
            pretrained, map_location=torch.device(self.args["device"])
        )
        self.model.load_state_dict(state_dict)
        self.model.history = self.history

    def print_args(self):
        pprint.pprint(self.args)

    def load(self, pretrained):
        """
        Parameters
        ----------
        pretrained : str
            Full path to the pretrained model.
        """  # TODO fix compatibility
        print(f"Loaded pretrained model from:\t{pretrained}")
        state, self.used_args, self.history = torch.load(
            pretrained, map_location=torch.device(self.args["device"])
        )
        self.model.load_state_dict(state_dict)

    def train(
        self,
        max_epochs=1,
        checkpoint_freq=20,
        run_eval=False,
        max_minutes=60,
        filename="model.pt",
        batch_size=None,
        save_dir=None,
        seed=0,
    ):
        """
        Parameters
        ----------
        max_epochs : int (default: 1)
            Maximum number epochs for training.
        checkpoint_freq : int (default: 20)
            Checkoint frequencty to save intermediate results.
        run_eval : bool (default: False)
            Whether or not to run disentanglement and R2 evaluation during training.
        max_minutes : int (default: 60)
            Maximum computation time in minutes.
        filename : str (default: 'model.pt')
            Name of the file without the directoty path to save the model.
            Name should be with .pt extension.
        batch_size : int, optional (default: None)
            Batch size for training. If None, uses default batch size specified
            in hparams.
        save_dir : str, optional (default: None)
            Full path to the folder to save the model. If None, will use from
            the path specified during init.
        seed : int (default: None)
            Random seed. If None, uses default random seed specified during init.
        """
        args = locals()
        del args["self"]

        if batch_size is None:
            batch_size = self.model.hparams["batch_size"]
            args["batch_size"] = batch_size
            self.args["batch_size"] = batch_size

        if save_dir is None:
            save_dir = self.args["save_dir"]
        print("Results will be saved to the folder:", save_dir)

        self.datasets.update(
            {
                "loader_tr": torch.utils.data.DataLoader(
                    self.datasets["training"], batch_size=batch_size, shuffle=True
                )
            }
        )

        self.model.train()

        start_time = time.time()
        pbar = tqdm(range(max_epochs), ncols=80)
        try:
            for epoch in pbar:
                epoch_training_stats = defaultdict(float)

                for data in self.datasets["loader_tr"]:
                    genes, drugs, covariates = data[0], data[1], data[2:]
                    minibatch_training_stats = self.model.update(
                        genes, drugs, covariates
                    )

                    for key, val in minibatch_training_stats.items():
                        epoch_training_stats[key] += val

                for key, val in epoch_training_stats.items():
                    epoch_training_stats[key] = val / len(self.datasets["loader_tr"])
                    if not (key in self.model.history.keys()):
                        self.model.history[key] = []
                    self.model.history[key].append(epoch_training_stats[key])
                self.model.history["epoch"].append(epoch)

                ellapsed_minutes = (time.time() - start_time) / 60
                self.model.history["elapsed_time_min"] = ellapsed_minutes

                # decay learning rate if necessary
                # also check stopping condition: patience ran out OR
                # time ran out OR max epochs achieved
                stop = ellapsed_minutes > max_minutes or (epoch == max_epochs - 1)

                pbar.set_description(
                    f"Rec: {epoch_training_stats['loss_reconstruction']:.4f}, "
                    + f"AdvPert: {epoch_training_stats['loss_adv_drugs']:.2f}, "
                    + f"AdvCov: {epoch_training_stats['loss_adv_covariates']:.2f}"
                )

                if (epoch % checkpoint_freq) == 0 or stop:
                    if run_eval == True:
                        evaluation_stats = evaluate(self.model, self.datasets)
                        for key, val in evaluation_stats.items():
                            if not (key in self.model.history.keys()):
                                self.model.history[key] = []
                            self.model.history[key].append(val)
                        self.model.history["stats_epoch"].append(epoch)
                        stop = stop or self.model.early_stopping(
                            np.mean(evaluation_stats["test"])
                        )
                    else:
                        stop = stop or self.model.early_stopping(
                            np.mean(epoch_training_stats["test"])
                        )
                        evaluation_stats = None

                    if stop:
                        self.save(f"{save_dir}{filename}")
                        pprint.pprint(
                            {
                                "epoch": epoch,
                                "training_stats": epoch_training_stats,
                                "evaluation_stats": evaluation_stats,
                                "ellapsed_minutes": ellapsed_minutes,
                            }
                        )

                        print(f"Stop epoch: {epoch}")
                        break


        except KeyboardInterrupt:
            self.save(f"{save_dir}{filename}")

        self.save(f"{save_dir}{filename}")

    def save(self, filename):
        """
        Parameters
        ----------
        filename : str
            Full path to save pretrained model.
        """
        torch.save((self.model.state_dict(), self.args, self.model.history), filename)
        self.history = self.model.history
        print(f"Model saved to: {filename}")

    def _init_pert_embeddings(self):
        dose = 1.0
        self.emb_perts = (
            self.model.compute_drug_embeddings_(
                dose * self.drug_ohe.to(self.model.device)
            )
            .cpu()
            .clone()
            .detach()
            .numpy()
        )

    def get_drug_embeddings(self, dose=1.0, return_anndata=True):
        """
        Parameters
        ----------
        dose : int (default: 1.0)
            Dose at which to evaluate latent embedding vector.
        return_anndata : bool, optional (default: True)
            Return embedding wrapped into anndata object.

        Returns
        -------
        If return_anndata is True, returns anndata object. Otherwise, doesn't
        return anything. Always saves embeddding in self.emb_perts.
        """
        self._init_pert_embeddings()

        emb_perts = (
            self.model.compute_drug_embeddings_(
                dose * self.drug_ohe.to(self.model.device)
            )
            .cpu()
            .clone()
            .detach()
            .numpy()
        )

        if return_anndata:
            adata = sc.AnnData(emb_perts)
            adata.obs[self.perturbation_key] = self.unique_perts
            return adata

    def _init_covars_embeddings(self):
        combo_list = []
        for covars_key in self.covariate_keys:
            combo_list.append(self.unique_covars[covars_key])
            if self.emb_covars[covars_key] is None:
                i_cov = self.covariate_keys.index(covars_key)
                self.emb_covars[covars_key] = dict(
                    zip(
                        self.unique_covars[covars_key],
                        self.model.covariates_embeddings[i_cov](
                            self.covars_ohe[covars_key].to(self.model.device).argmax(1)
                        )
                        .cpu()
                        .clone()
                        .detach()
                        .numpy(),
                    )
                )
        self.emb_covars_combined = {}
        for combo in list(itertools.product(*combo_list)):
            combo_name = "_".join(combo)
            for i, cov in enumerate(combo):
                covars_key = self.covariate_keys[i]
                if i == 0:
                    emb = self.emb_covars[covars_key][cov]
                else:
                    emb += self.emb_covars[covars_key][cov]
            self.emb_covars_combined[combo_name] = emb

    def get_covars_embeddings_combined(self, return_anndata=True):
        """
        Parameters
        ----------
        return_anndata : bool, optional (default: True)
            Return embedding wrapped into anndata object.

        Returns
        -------
        If return_anndata is True, returns anndata object. Otherwise, doesn't
        return anything. Always saves embeddding in self.emb_covars.
        """
        self._init_covars_embeddings()
        if return_anndata:
            adata = sc.AnnData(np.array(list(self.emb_covars_combined.values())))
            adata.obs["covars"] = self.emb_covars_combined.keys()
            return adata

    def get_covars_embeddings(self, covars_tgt, return_anndata=True):
        """
        Parameters
        ----------
        covars_tgt : str
            Name of covariate for which to return AnnData
        return_anndata : bool, optional (default: True)
            Return embedding wrapped into anndata object.

        Returns
        -------
        If return_anndata is True, returns anndata object. Otherwise, doesn't
        return anything. Always saves embeddding in self.emb_covars.
        """
        self._init_covars_embeddings()

        if return_anndata:
            adata = sc.AnnData(np.array(list(self.emb_covars[covars_tgt].values())))
            adata.obs[covars_tgt] = self.emb_covars[covars_tgt].keys()
            return adata

    def _get_drug_encoding(self, drugs, doses=None):
        """
        Parameters
        ----------
        drugs : str
            Drugs combination as a string, where individual drugs are separated
            with a plus.
        doses : str, optional (default: None)
            Doses corresponding to the drugs combination as a string. Individual
            drugs are separated with a plus.

        Returns
        -------
        One hot encodding for a mixture of drugs.
        """

        drug_mix = np.zeros([1, self.num_drugs])
        atomic_drugs = drugs.split("+")
        doses = str(doses)

        if doses is None:
            doses_list = [1.0] * len(atomic_drugs)
        else:
            doses_list = [float(d) for d in str(doses).split("+")]
        for j, drug in enumerate(atomic_drugs):
            drug_mix += doses_list[j] * self.perts_dict[drug]

        return drug_mix

    def mix_drugs(self, drugs_list, doses_list=None, return_anndata=True):
        """
        Gets a list of drugs combinations to mix, e.g. ['A+B', 'B+C'] and
        corresponding doses.

        Parameters
        ----------
        drugs_list : list
            List of drug combinations, where each drug combination is a string.
            Individual drugs in the combination are separated with a plus.
        doses_list : str, optional (default: None)
            List of corresponding doses, where each dose combination is a string.
            Individual doses in the combination are separated with a plus.
        return_anndata : bool, optional (default: True)
            Return embedding wrapped into anndata object.

        Returns
        -------
        If return_anndata is True, returns anndata structure of the combinations,
        otherwise returns a np.array of corresponding embeddings.
        """

        drug_mix = np.zeros([len(drugs_list), self.num_drugs])
        for i, drug_combo in enumerate(drugs_list):
            drug_mix[i] = self._get_drug_encoding(drug_combo, doses=doses_list[i])

        emb = (
            self.model.compute_drug_embeddings_(
                torch.Tensor(drug_mix).to(self.model.device)
            )
            .cpu()
            .clone()
            .detach()
            .numpy()
        )

        if return_anndata:
            adata = sc.AnnData(emb)
            adata.obs[self.perturbation_key] = drugs_list
            adata.obs[self.dose_key] = doses_list
            return adata
        else:
            return emb

    def latent_dose_response(
        self, perturbations=None, dose=None, contvar_min=0, contvar_max=1, n_points=100
    ):
        """
        Parameters
        ----------
        perturbations : list
            List containing two names for which to return complete pairwise
            dose-response.
        doses : np.array (default: None)
            Doses values. If None, default values will be generated on a grid:
            n_points in range [contvar_min, contvar_max].
        contvar_min : float (default: 0)
            Minimum dose value to generate for default option.
        contvar_max : float (default: 0)
            Maximum dose value to generate for default option.
        n_points : int (default: 100)
            Number of dose points to generate for default option.
        Returns
        -------
        pd.DataFrame
        """
        # dosers work only for atomic drugs. TODO add drug combinations
        self.model.eval()

        if perturbations is None:
            perturbations = self.unique_perts

        if dose is None:
            dose = np.linspace(contvar_min, contvar_max, n_points)
        n_points = len(dose)

        df = pd.DataFrame(columns=[self.perturbation_key, self.dose_key, "response"])
        for drug in perturbations:
            d = np.where(self.perts_dict[drug] == 1)[0][0]
            this_drug = torch.Tensor(dose).to(self.model.device).view(-1, 1)
            if self.model.doser_type == "mlp":
                response = (
                    (self.model.dosers[d](this_drug).sigmoid() * this_drug.gt(0))
                    .cpu()
                    .clone()
                    .detach()
                    .numpy()
                    .reshape(-1)
                )
            else:
                response = (
                    self.model.dosers.one_drug(this_drug.view(-1), d)
                    .cpu()
                    .clone()
                    .detach()
                    .numpy()
                    .reshape(-1)
                )

            df_drug = pd.DataFrame(
                list(zip([drug] * n_points, dose, list(response))),
                columns=[self.perturbation_key, self.dose_key, "response"],
            )
            df = pd.concat([df, df_drug])

        return df

    def latent_dose_response2D(
        self,
        perturbations,
        dose=None,
        contvar_min=0,
        contvar_max=1,
        n_points=100,
    ):
        """
        Parameters
        ----------
        perturbations : list, optional (default: None)
            List of atomic drugs for which to return latent dose response.
            Currently drug combinations are not supported.
        doses : np.array (default: None)
            Doses values. If None, default values will be generated on a grid:
            n_points in range [contvar_min, contvar_max].
        contvar_min : float (default: 0)
            Minimum dose value to generate for default option.
        contvar_max : float (default: 0)
            Maximum dose value to generate for default option.
        n_points : int (default: 100)
            Number of dose points to generate for default option.
        Returns
        -------
        pd.DataFrame
        """
        # dosers work only for atomic drugs. TODO add drug combinations

        assert len(perturbations) == 2, "You should provide a list of 2 perturbations."

        self.model.eval()

        if dose is None:
            dose = np.linspace(contvar_min, contvar_max, n_points)
        n_points = len(dose)

        df = pd.DataFrame(columns=perturbations + ["response"])
        response = {}

        for drug in perturbations:
            d = np.where(self.perts_dict[drug] == 1)[0][0]
            this_drug = torch.Tensor(dose).to(self.model.device).view(-1, 1)
            if self.model.doser_type == "mlp":
                response[drug] = (
                    (self.model.dosers[d](this_drug).sigmoid() * this_drug.gt(0))
                    .cpu()
                    .clone()
                    .detach()
                    .numpy()
                    .reshape(-1)
                )
            else:
                response[drug] = (
                    self.model.dosers.one_drug(this_drug.view(-1), d)
                    .cpu()
                    .clone()
                    .detach()
                    .numpy()
                    .reshape(-1)
                )

        l = 0
        for i in range(len(dose)):
            for j in range(len(dose)):
                df.loc[l] = [
                    dose[i],
                    dose[j],
                    response[perturbations[0]][i] + response[perturbations[1]][j],
                ]
                l += 1

        return df

    def compute_comb_emb(self, thrh=30):
        """
        Generates an AnnData object containing all the latent vectors of the
        cov+dose*pert combinations seen during training.
        Called in api.compute_uncertainty(), stores the AnnData in self.comb_emb.

        Parameters
        ----------
        Returns
        -------
        """
        if self.seen_covars_perts["training"] is None:
            raise ValueError("Need to run parse_training_conditions() first!")

        emb_covars = self.get_covars_embeddings_combined(return_anndata=True)

        # Generate adata with all cov+pert latent vect combinations
        tmp_ad_list = []
        for cov_pert in self.seen_covars_perts["training"]:
            if self.num_measured_points["training"][cov_pert] > thrh:
                *cov_list, pert_loop, dose_loop = cov_pert.split("_")
                cov_loop = "_".join(cov_list)
                emb_perts_loop = []
                if "+" in pert_loop:
                    pert_loop_list = pert_loop.split("+")
                    dose_loop_list = dose_loop.split("+")
                    for _dose in pd.Series(dose_loop_list).unique():
                        tmp_ad = self.get_drug_embeddings(dose=float(_dose))
                        tmp_ad.obs["pert_dose"] = tmp_ad.obs.condition + "_" + _dose
                        emb_perts_loop.append(tmp_ad)

                    emb_perts_loop = emb_perts_loop[0].concatenate(emb_perts_loop[1:])
                    X = emb_covars.X[
                        emb_covars.obs.covars == cov_loop
                    ] + np.expand_dims(
                        emb_perts_loop.X[
                            emb_perts_loop.obs.pert_dose.isin(
                                [
                                    pert_loop_list[i] + "_" + dose_loop_list[i]
                                    for i in range(len(pert_loop_list))
                                ]
                            )
                        ].sum(axis=0),
                        axis=0,
                    )
                    if X.shape[0] > 1:
                        raise ValueError("Error with comb computation")
                else:
                    emb_perts = self.get_drug_embeddings(dose=float(dose_loop))
                    X = (
                        emb_covars.X[emb_covars.obs.covars == cov_loop]
                        + emb_perts.X[emb_perts.obs.condition == pert_loop]
                    )
                tmp_ad = sc.AnnData(X=X)
                tmp_ad.obs["cov_pert"] = "_".join([cov_loop, pert_loop, dose_loop])
            tmp_ad_list.append(tmp_ad)

        self.comb_emb = tmp_ad_list[0].concatenate(tmp_ad_list[1:])

    def compute_uncertainty(self, cov, pert, dose, thrh=30):
        """
        Compute uncertainties for the queried covariate+perturbation combination.
        The distance from the closest condition in the training set is used as a
        proxy for uncertainty.

        Parameters
        ----------
        cov: dict
            Provide a value for each covariate (eg. cell_type) as a dictionaty
            for the queried uncertainty (e.g. cov_dict={'cell_type': 'A549'}).
        pert: string
            Perturbation for the queried uncertainty. In case of combinations the
            format has to be 'pertA+pertB'
        dose: string
            String which contains the dose of the perturbation queried. In case
            of combinations the format has to be 'doseA+doseB'

        Returns
        -------
        min_cos_dist: float
            Minimum cosine distance with the training set.
        min_eucl_dist: float
            Minimum euclidean distance with the training set.
        closest_cond_cos: string
            Closest training condition wrt cosine distances.
        closest_cond_eucl: string
            Closest training condition wrt euclidean distances.
        """

        if self.comb_emb is None:
            self.compute_comb_emb(thrh=30)

        drug_ohe = torch.Tensor(self._get_drug_encoding(pert, doses=dose)).to(
            self.model.device
        )

        pert = drug_ohe.expand([1, self.drug_ohe.shape[1]])

        drug_emb = self.model.compute_drug_embeddings_(pert).detach().cpu().numpy()

        cond_emb = drug_emb
        for cov_key in cov:
            cond_emb += self.emb_covars[cov_key][cov[cov_key]]

        cos_dist = cosine_distances(cond_emb, self.comb_emb.X)[0]
        min_cos_dist = np.min(cos_dist)
        cos_idx = np.argmin(cos_dist)
        closest_cond_cos = self.comb_emb.obs.cov_pert[cos_idx]

        eucl_dist = euclidean_distances(cond_emb, self.comb_emb.X)[0]
        min_eucl_dist = np.min(eucl_dist)
        eucl_idx = np.argmin(eucl_dist)
        closest_cond_eucl = self.comb_emb.obs.cov_pert[eucl_idx]

        return min_cos_dist, min_eucl_dist, closest_cond_cos, closest_cond_eucl

    def predict(
        self,
        genes,
        cov,
        pert,
        dose,
        uncertainty=True,
        return_anndata=True,
        sample=False,
        n_samples=1,
    ):
        """Predict values of control 'genes' conditions specified in df.

        Parameters
        ----------
        genes : np.array
            Control cells.
        cov: dict of lists
            Provide a value for each covariate (eg. cell_type) as a dictionaty
            for the queried uncertainty (e.g. cov_dict={'cell_type': 'A549'}).
        pert: list
            Perturbation for the queried uncertainty. In case of combinations the
            format has to be 'pertA+pertB'
        dose: list
            String which contains the dose of the perturbation queried. In case
            of combinations the format has to be 'doseA+doseB'

        uncertainty: bool (default: True)
            Compute uncertainties for the generated cells.
        return_anndata : bool, optional (default: True)
            Return embedding wrapped into anndata object.
        sample : bool (default: False)
            If sample is True, returns samples from gausssian distribution with
            mean and variance estimated by the model. Otherwise, returns just
            means and variances estimated by the model.
        n_samples : int (default: 10)
            Number of samples to sample if sampling is True.
        Returns
        -------
        If return_anndata is True, returns anndata structure. Otherwise, returns
        np.arrays for gene_means, gene_vars and a data frame for the corresponding
        conditions df_obs.

        """

        assert len(dose) == len(pert), "Check the length of pert, dose"
        for cov_key in cov:
            assert len(cov[cov_key]) == len(pert), "Check the length of covariates"

        df = pd.concat(
            [
                pd.DataFrame({self.perturbation_key: pert, self.dose_key: dose}),
                pd.DataFrame(cov),
            ],
            axis=1,
        )

        self.model.eval()
        num = genes.shape[0]
        dim = genes.shape[1]
        genes = torch.Tensor(genes).to(self.model.device)

        gene_means_list = []
        gene_vars_list = []
        df_list = []

        for i in range(len(df)):
            comb_name = pert[i]
            dose_name = dose[i]
            covar_name = {}
            for cov_key in cov:
                covar_name[cov_key] = cov[cov_key][i]

            drug_ohe = torch.Tensor(
                self._get_drug_encoding(comb_name, doses=dose_name)
            ).to(self.model.device)

            drugs = drug_ohe.expand([num, self.drug_ohe.shape[1]])

            covars = []
            for cov_key in self.covariate_keys:
                covar_ohe = torch.Tensor(
                    self.covars_dict[cov_key][covar_name[cov_key]]
                ).to(self.model.device)
                covars.append(covar_ohe.expand([num, covar_ohe.shape[0]]).clone())

            gene_reconstructions = (
                self.model.predict(genes, drugs, covars).cpu().clone().detach().numpy()
            )

            if sample:
                df_list.append(
                    pd.DataFrame(
                        [df.loc[i].values] * num * n_samples, columns=df.columns
                    )
                )
                if self.args['loss_ae'] == 'gauss':
                    dist = Normal(
                        torch.Tensor(gene_reconstructions[:, :dim]),
                        torch.Tensor(gene_reconstructions[:, dim:]),
                    )
                elif self.args['loss_ae'] == 'nb':
                    counts, logits = _convert_mean_disp_to_counts_logits(
                        torch.clamp(
                            torch.Tensor(gene_reconstructions[:, :dim]),
                            min=1e-8,
                            max=1e8,
                        ),
                        torch.clamp(
                            torch.Tensor(gene_reconstructions[:, dim:]),
                            min=1e-8,
                            max=1e8,
                        )
                    )
                    dist = NegativeBinomial(
                        total_count=counts,
                        logits=logits
                    )
                sampled_gexp = (
                    dist.sample(torch.Size([n_samples]))
                    .cpu()
                    .detach()
                    .numpy()
                    .reshape(-1, dim)
                )
                sampled_gexp[sampled_gexp < 0] = 0 #set negative values to 0, since gexp can't be negative
                gene_means_list.append(sampled_gexp)
            else:
                df_list.append(
                    pd.DataFrame([df.loc[i].values] * num, columns=df.columns)
                )

                gene_means_list.append(gene_reconstructions[:, :dim])

            if uncertainty:
                (
                    cos_dist,
                    eucl_dist,
                    closest_cond_cos,
                    closest_cond_eucl,
                ) = self.compute_uncertainty(
                    cov=covar_name, pert=comb_name, dose=dose_name
                )
                df_list[-1] = df_list[-1].assign(
                    uncertainty_cosine=cos_dist,
                    uncertainty_euclidean=eucl_dist,
                    closest_cond_cosine=closest_cond_cos,
                    closest_cond_euclidean=closest_cond_eucl,
                )

            gene_vars_list.append(gene_reconstructions[:, dim:])

        gene_means = np.concatenate(gene_means_list)
        gene_vars = np.concatenate(gene_vars_list)
        df_obs = pd.concat(df_list)
        del df_list, gene_means_list, gene_vars_list

        if return_anndata:
            adata = sc.AnnData(gene_means)
            adata.var_names = self.var_names
            adata.obs = df_obs
            if not sample:
                adata.layers["variance"] = gene_vars

            adata.obs.index = adata.obs.index.astype(str)  # type fix
            del gene_means, gene_vars, df_obs
            return adata
        else:
            return gene_means, gene_vars, df_obs

    def get_latent(
        self,
        genes,
        cov,
        pert,
        dose,
        return_anndata=True,
    ):
        """Get latent values of control 'genes' with conditions specified in df.

        Parameters
        ----------
        genes : np.array
            Control cells.
        cov: dict of lists
            Provide a value for each covariate (eg. cell_type) as a dictionaty
            for the queried uncertainty (e.g. cov_dict={'cell_type': 'A549'}).
        pert: list
            Perturbation for the queried uncertainty. In case of combinations the
            format has to be 'pertA+pertB'
        dose: list
            String which contains the dose of the perturbation queried. In case
            of combinations the format has to be 'doseA+doseB'
        return_anndata : bool, optional (default: True)
            Return embedding wrapped into anndata object.

        Returns
        -------
        If return_anndata is True, returns anndata structure. Otherwise, returns
        np.arrays for latent and a data frame for the corresponding
        conditions df_obs.

        """

        assert len(dose) == len(pert), "Check the length of pert, dose"
        for cov_key in cov:
            assert len(cov[cov_key]) == len(pert), "Check the length of covariates"

        df = pd.concat(
            [
                pd.DataFrame({self.perturbation_key: pert, self.dose_key: dose}),
                pd.DataFrame(cov),
            ],
            axis=1,
        )

        self.model.eval()
        num = genes.shape[0]
        genes = torch.Tensor(genes).to(self.model.device)

        latent_list = []
        df_list = []

        for i in range(len(df)):
            comb_name = pert[i]
            dose_name = dose[i]
            covar_name = {}
            for cov_key in cov:
                covar_name[cov_key] = cov[cov_key][i]

            drug_ohe = torch.Tensor(
                self._get_drug_encoding(comb_name, doses=dose_name)
            ).to(self.model.device)

            drugs = drug_ohe.expand([num, self.drug_ohe.shape[1]])

            covars = []
            for cov_key in self.covariate_keys:
                covar_ohe = torch.Tensor(
                    self.covars_dict[cov_key][covar_name[cov_key]]
                ).to(self.model.device)
                covars.append(covar_ohe.expand([num, covar_ohe.shape[0]]).clone())

            _, latent_treated = self.model.predict(
                    genes,
                    drugs, 
                    covars,
                    return_latent_treated=True,
            )

            latent_treated = latent_treated.cpu().clone().detach().numpy()

            df_list.append(
                pd.DataFrame([df.loc[i].values] * num, columns=df.columns)
            )

            latent_list.append(latent_treated)

        latent = np.concatenate(latent_list)
        df_obs = pd.concat(df_list)
        del df_list

        if return_anndata:
            adata = sc.AnnData(latent)
            adata.obs = df_obs
            adata.obs.index = adata.obs.index.astype(str)  # type fix
            return adata
        else:
            return latent, df_obs

    def get_response(
        self,
        genes_control=None,
        doses=None,
        contvar_min=None,
        contvar_max=None,
        n_points=10,
        ncells_max=100,
        perturbations=None,
        control_name="test",
    ):
        """Decoded dose response data frame.

        Parameters
        ----------
        genes_control : np.array (deafult: None)
            Genes for which to predict values. If None, take from 'test_control'
            split in datasets.
        doses : np.array (default: None)
            Doses values. If None, default values will be generated on a grid:
            n_points in range [contvar_min, contvar_max].
        contvar_min : float (default: 0)
            Minimum dose value to generate for default option.
        contvar_max : float (default: 0)
            Maximum dose value to generate for default option.
        n_points : int (default: 100)
            Number of dose points to generate for default option.
        perturbations : list (default: None)
            List of perturbations for dose response

        Returns
        -------
        pd.DataFrame
            of decoded response values of genes and average response.
        """

        if genes_control is None:
            genes_control = self.datasets["test"].subset_condition(control=True).genes

        if contvar_min is None:
            contvar_min = 0
        if contvar_max is None:
            contvar_max = self.max_dose

        self.model.eval()
        if doses is None:
            doses = np.linspace(contvar_min, contvar_max, n_points)

        if perturbations is None:
            perturbations = self.unique_perts

        response = pd.DataFrame(
            columns=self.covariate_keys
            + [self.perturbation_key, self.dose_key, "response"]
            + list(self.var_names)
        )

        if ncells_max < len(genes_control):
            ncells_max = min(ncells_max, len(genes_control))
            idx = torch.LongTensor(
                np.random.choice(range(len(genes_control)), ncells_max, replace=False)
            )
            genes_control = genes_control[idx]

        j = 0
        for covar_combo in self.emb_covars_combined:
            cov_dict = {}
            for i, cov_val in enumerate(covar_combo.split("_")):
                cov_dict[self.covariate_keys[i]] = [cov_val]
                print(cov_dict)
                for _, drug in enumerate(perturbations):
                    if not (drug in self.datasets[control_name].subset_condition(control=True).ctrl_name):
                        for dose in doses:
                            # TODO handle covars

                            gene_means, _, _ = self.predict(
                                genes_control,
                                cov=cov_dict,
                                pert=[drug],
                                dose=[dose],
                                return_anndata=False,
                            )
                            predicted_data = np.mean(gene_means, axis=0).reshape(-1)
                            response.loc[j] = (
                                covar_combo.split("_")
                                + [drug, dose, np.linalg.norm(predicted_data)]
                                + list(predicted_data)
                            )
                            j += 1
        return response

    def get_response_reference(self, perturbations=None):

        """Computes reference values of the response.

        Parameters
        ----------
        dataset : CompPertDataset
            The file location of the spreadsheet
        perturbations : list (default: None)
            List of perturbations for dose response

        Returns
        -------
        pd.DataFrame
            of decoded response values of genes and average response.
        """
        if perturbations is None:
            perturbations = self.unique_perts

        reference_response_curve = pd.DataFrame(
            columns=self.covariate_keys
            + [self.perturbation_key, self.dose_key, "split", "num_cells", "response"]
            + list(self.var_names)
        )

        dataset_ctr = self.datasets["training"].subset_condition(control=True)

        i = 0
        for split in ["training", "ood"]:
            if split == 'ood':
                dataset = self.datasets[split]
            else:
                dataset = self.datasets["training"].subset_condition(control=False)
            for pert in self.seen_covars_perts[split]:
                *covars, drug, dose_val = pert.split("_")
                if drug in perturbations:
                    if not ("+" in dose_val):
                        dose = float(dose_val)
                    else:
                        dose = dose_val

                    idx = np.where((dataset.pert_categories == pert))[0]

                    if len(idx):
                        y_true = dataset.genes[idx, :].numpy().mean(axis=0)
                        reference_response_curve.loc[i] = (
                            covars
                            + [drug, dose, split, len(idx), np.linalg.norm(y_true)]
                            + list(y_true)
                        )

                        i += 1

        reference_response_curve = reference_response_curve.replace(
            "training_treated", "train"
        )
        return reference_response_curve

    def get_response2D(
        self,
        perturbations,
        covar,
        genes_control=None,
        doses=None,
        contvar_min=None,
        contvar_max=None,
        n_points=10,
        ncells_max=100,
        #fixed_drugs="",
        #fixed_doses="",
    ):
        """Decoded dose response data frame.

        Parameters
        ----------
        perturbations : list
            List of length 2 of perturbations for dose response.
        covar : dict
            Name of a covariate for which to compute dose-response.
        genes_control : np.array (deafult: None)
            Genes for which to predict values. If None, take from 'test_control'
            split in datasets.
        doses : np.array (default: None)
            Doses values. If None, default values will be generated on a grid:
            n_points in range [contvar_min, contvar_max].
        contvar_min : float (default: 0)
            Minimum dose value to generate for default option.
        contvar_max : float (default: 0)
            Maximum dose value to generate for default option.
        n_points : int (default: 100)
            Number of dose points to generate for default option.

        Returns
        -------
        pd.DataFrame
            of decoded response values of genes and average response.
        """

        assert len(perturbations) == 2, "You should provide a list of 2 perturbations."

        if contvar_min is None:
            contvar_min = self.min_dose

        if contvar_max is None:
            contvar_max = self.max_dose

        self.model.eval()
        # doses = torch.Tensor(np.linspace(contvar_min, contvar_max, n_points))
        if doses is None:
            doses = np.linspace(contvar_min, contvar_max, n_points)

        # genes_control = dataset.genes[dataset.indices['control']]
        if genes_control is None:
            genes_control = self.datasets["test"].subset_condition(control=True).genes

        ncells_max = min(ncells_max, len(genes_control))
        idx = torch.LongTensor(np.random.choice(range(len(genes_control)), ncells_max))
        genes_control = genes_control[idx]

        response = pd.DataFrame(
            columns=perturbations+["response"]+list(self.var_names)
        )

        drug = perturbations[0] + "+" + perturbations[1]

        dose_vals = [f"{d[0]}+{d[1]}" for d in itertools.product(*[doses, doses])]
        dose_comb = [list(d) for d in itertools.product(*[doses, doses])]

        i = 0
        if not (drug in self.datasets['training'].subset_condition(control=True).ctrl_name):
            for dose in dose_vals:
                gene_means, _, _ = self.predict(
                    genes_control,
                    cov=covar,
                    pert=[drug],# + fixed_drugs],
                    dose=[dose],# + fixed_doses],
                    return_anndata=False,
                )
                predicted_data = np.mean(gene_means, axis=0).reshape(-1)
                response.loc[i] = (
                    dose_comb[i]
                    + [np.linalg.norm(predicted_data)]
                    + list(predicted_data)
                )
                i += 1

        # i = 0
        # if not (drug in ["Vehicle", "EGF", "unst", "control", "ctrl"]):
        #     for dose in dose_vals:
        #         gene_means, _, _ = self.predict(
        #             genes_control,
        #             cov=covar,
        #             pert=[drug + fixed_drugs],
        #             dose=[dose + fixed_doses],
        #             return_anndata=False,
        #         )

        #         predicted_data = np.mean(gene_means, axis=0).reshape(-1)

        #         response.loc[i] = (
        #             dose_comb[i]
        #             + [np.linalg.norm(predicted_data)]
        #             + list(predicted_data)
        #         )
        #         i += 1

        return response

    def evaluate_r2(self, dataset, genes_control, adata_random=None):
        """
        Measures different quality metrics about an CPA `autoencoder`, when
        tasked to translate some `genes_control` into each of the drug/cell_type
        combinations described in `dataset`.

        Considered metrics are R2 score about means and variances for all genes, as
        well as R2 score about means and variances about differentially expressed
        (_de) genes.
        """
        self.model.eval()
        scores = pd.DataFrame(
            columns=self.covariate_keys
            + [
                self.perturbation_key,
                self.dose_key,
                "R2_mean",
                "R2_mean_DE",
                "R2_var",
                "R2_var_DE",
                "model",
                "num_cells",
            ]
        )

        num, dim = genes_control.size(0), genes_control.size(1)

        total_cells = len(dataset)

        icond = 0
        for pert_category in np.unique(dataset.pert_categories):
            # pert_category category contains: 'celltype_perturbation_dose' info
            de_idx = np.where(
                dataset.var_names.isin(np.array(dataset.de_genes[pert_category]))
            )[0]

            idx = np.where(dataset.pert_categories == pert_category)[0]
            *covars, pert, dose = pert_category.split("_")
            cov_dict = {}
            for i, cov_key in enumerate(self.covariate_keys):
                cov_dict[cov_key] = [covars[i]]

            if len(idx) > 0:
                mean_predict, var_predict, _ = self.predict(
                    genes_control,
                    cov=cov_dict,
                    pert=[pert],
                    dose=[dose],
                    return_anndata=False,
                    sample=False,
                )

                # estimate metrics only for reasonably-sized drug/cell-type combos
                y_true = dataset.genes[idx, :].numpy()

                # true means and variances
                yt_m = y_true.mean(axis=0)
                yt_v = y_true.var(axis=0)
                # predicted means and variances
                yp_m = mean_predict.mean(0)
                yp_v = var_predict.mean(0)
                #yp_v = np.var(mean_predict, axis=0)

                mean_score = r2_score(yt_m, yp_m)
                var_score = r2_score(yt_v, yp_v)

                mean_score_de = r2_score(yt_m[de_idx], yp_m[de_idx])
                var_score_de = r2_score(yt_v[de_idx], yp_v[de_idx])

                scores.loc[icond] = pert_category.split("_") + [
                    mean_score,
                    mean_score_de,
                    var_score,
                    var_score_de,
                    "cpa",
                    len(idx),
                ]
                icond += 1
                if adata_random is not None:
                    yp_m_bl = np.mean(adata_random, axis=0)
                    yp_v_bl = np.var(adata_random, axis=0)

                    mean_score_bl = r2_score(yt_m, yp_m_bl)
                    var_score_bl = r2_score(yt_v, yp_v_bl)

                    mean_score_de_bl = r2_score(yt_m[de_idx], yp_m_bl[de_idx])
                    var_score_de_bl = r2_score(yt_v[de_idx], yp_v_bl[de_idx])


                    scores.loc[icond] = pert_category.split("_") + [
                        mean_score_bl,
                        mean_score_de_bl,
                        var_score_bl,
                        var_score_de_bl,
                        "baseline",
                        len(idx),
                    ]
                    icond += 1
        return scores

def get_reference_from_combo(perturbations_list, datasets, splits=["training", "ood"]):
    """
    A simple function that produces a pd.DataFrame of individual
    drugs-doses combinations used among the splits (for a fixed covariate).
    """
    df_list = []
    for split_name in splits:
        full_dataset = datasets[split_name]
        ref = {"num_cells": []}
        for pp in perturbations_list:
            ref[pp] = []

        ndrugs = len(perturbations_list)
        for pert_cat in np.unique(full_dataset.pert_categories):
            _, pert, dose = pert_cat.split("_")
            pert_list = pert.split("+")
            if set(pert_list) == set(perturbations_list):
                dose_list = dose.split("+")
                ncells = len(
                    full_dataset.pert_categories[
                        full_dataset.pert_categories == pert_cat
                    ]
                )
                for j in range(ndrugs):
                    ref[pert_list[j]].append(float(dose_list[j]))
                ref["num_cells"].append(ncells)
                print(pert, dose, ncells)
        df = pd.DataFrame.from_dict(ref)
        df["split"] = split_name
        df_list.append(df)

    return pd.concat(df_list)


def linear_interp(y1, y2, x1, x2, x):
    a = (y1 - y2) / (x1 - x2)
    b = y1 - a * x1
    y = a * x + b
    return y


def evaluate_r2_benchmark(cpa_api, datasets, pert_category, pert_category_list):
    scores = pd.DataFrame(
        columns=[
            cpa_api.covars_key,
            cpa_api.perturbation_key,
            cpa_api.dose_key,
            "R2_mean",
            "R2_mean_DE",
            "R2_var",
            "R2_var_DE",
            "num_cells",
            "benchmark",
            "method",
        ]
    )

    de_idx = np.where(
        datasets["ood"].var_names.isin(
            np.array(datasets["ood"].de_genes[pert_category])
        )
    )[0]
    idx = np.where(datasets["ood"].pert_categories == pert_category)[0]
    y_true = datasets["ood"].genes[idx, :].numpy()
    # true means and variances
    yt_m = y_true.mean(axis=0)
    yt_v = y_true.var(axis=0)

    icond = 0
    if len(idx) > 0:
        for pert_category_predict in pert_category_list:
            if "+" in pert_category_predict:
                pert1, pert2 = pert_category_predict.split("+")
                idx_pred1 = np.where(datasets["training"].pert_categories == pert1)[0]
                idx_pred2 = np.where(datasets["training"].pert_categories == pert2)[0]

                y_pred1 = datasets["training"].genes[idx_pred1, :].numpy()
                y_pred2 = datasets["training"].genes[idx_pred2, :].numpy()

                x1 = float(pert1.split("_")[2])
                x2 = float(pert2.split("_")[2])
                x = float(pert_category.split("_")[2])
                yp_m1 = y_pred1.mean(axis=0)
                yp_m2 = y_pred2.mean(axis=0)
                yp_v1 = y_pred1.var(axis=0)
                yp_v2 = y_pred2.var(axis=0)

                yp_m = linear_interp(yp_m1, yp_m2, x1, x2, x)
                yp_v = linear_interp(yp_v1, yp_v2, x1, x2, x)

            #                     yp_m = (y_pred1.mean(axis=0) + y_pred2.mean(axis=0))/2
            #                     yp_v = (y_pred1.var(axis=0) + y_pred2.var(axis=0))/2

            else:
                idx_pred = np.where(
                    datasets["training"].pert_categories == pert_category_predict
                )[0]
                print(pert_category_predict, len(idx_pred))
                y_pred = datasets["training"].genes[idx_pred, :].numpy()
                # predicted means and variances
                yp_m = y_pred.mean(axis=0)
                yp_v = y_pred.var(axis=0)

            mean_score = r2_score(yt_m, yp_m)
            var_score = r2_score(yt_v, yp_v)

            mean_score_de = r2_score(yt_m[de_idx], yp_m[de_idx])
            var_score_de = r2_score(yt_v[de_idx], yp_v[de_idx])
            scores.loc[icond] = pert_category.split("_") + [
                mean_score,
                mean_score_de,
                var_score,
                var_score_de,
                len(idx),
                pert_category_predict,
                "benchmark",
            ]
            icond += 1

    return scores


================================================
FILE: cpa/data.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import warnings

import numpy as np
import torch

warnings.simplefilter(action="ignore", category=FutureWarning)
from typing import Union

import pandas as pd
import scanpy as sc
import scipy
from cpa.helper import rank_genes_groups
from sklearn.preprocessing import OneHotEncoder


def ranks_to_df(data, key="rank_genes_groups"):
    """Converts an `sc.tl.rank_genes_groups` result into a MultiIndex dataframe.

    You can access various levels of the MultiIndex with `df.loc[[category]]`.

    Params
    ------
    data : `AnnData`
    key : str (default: 'rank_genes_groups')
        Field in `.uns` of data where `sc.tl.rank_genes_groups` result is
        stored.
    """
    d = data.uns[key]
    dfs = []
    for k in d.keys():
        if k == "params":
            continue
        series = pd.DataFrame.from_records(d[k]).unstack()
        series.name = k
        dfs.append(series)

    return pd.concat(dfs, axis=1)


def check_adata(adata, special_fields):
    replaced = False
    for sf in special_fields:
        if sf in adata.obs:
            flag = 0
            for el in adata.obs[sf].values:
                if "_" in str(el):
                    flag += 1
            if flag:
                print(
                    f"WARNING. Special characters ('_') were found in: '{sf}'.",
                    "They will be replaced with '-'.",
                    "Be careful, it may lead to errors downstream.",
                )
                adata.obs[sf] = [s.replace("_", "-") for s in adata.obs[sf].values]
                replaced = True

    return adata, replaced


indx = lambda a, i: a[i] if a is not None else None


class Dataset:
    def __init__(
        self,
        data,
        perturbation_key=None,
        dose_key=None,
        covariate_keys=None,
        split_key="split",
        control=None,
    ):
        if type(data) == str:
            data = sc.read(data)
        #Assert that keys are present in the adata object
        assert perturbation_key in data.obs.columns, f"Perturbation {perturbation_key} is missing in the provided adata"
        for key in covariate_keys:
            assert key in data.obs.columns, f"Covariate {key} is missing in the provided adata"
        assert dose_key in data.obs.columns, f"Dose {dose_key} is missing in the provided adata"
        assert split_key in data.obs.columns, f"Split {split_key} is missing in the provided adata"
        assert not (split_key is None), "split_key can not be None"

        #If covariate keys is empty list create dummy covariate
        if len(covariate_keys) == 0:
            print("Adding a dummy covariate...")
            data.obs['dummy_cov'] = 'dummy_cov'
            covariate_keys = ['dummy_cov']

        self.perturbation_key = perturbation_key
        self.dose_key = dose_key

        if scipy.sparse.issparse(data.X):
            self.genes = torch.Tensor(data.X.A)
        else:
            self.genes = torch.Tensor(data.X)

        self.var_names = data.var_names

        if isinstance(covariate_keys, str):
            covariate_keys = [covariate_keys]
        self.covariate_keys = covariate_keys

        data, replaced = check_adata(
            data, [perturbation_key, dose_key] + covariate_keys
        )

        for cov in covariate_keys:
            if not (cov in data.obs):
                data.obs[cov] = "unknown"

        if split_key in data.obs:
            pass
        else:
            print("Performing automatic train-test split with 0.25 ratio.")
            from sklearn.model_selection import train_test_split

            data.obs[split_key] = "train"
            idx = list(range(len(data)))
            idx_train, idx_test = train_test_split(
                data.obs_names, test_size=0.25, random_state=42
            )
            data.obs[split_key].loc[idx_train] = "train"
            data.obs[split_key].loc[idx_test] = "test"

        if "control" in data.obs:
            self.ctrl = data.obs["control"].values
        else:
            print(f"Assigning control values for {control}")
            assert_msg = "Please provide a name for control condition."
            assert not (control is None), assert_msg
            data.obs["control"] = 0
            if dose_key in data.obs:
                pert, dose = control.split("_")
                data.obs.loc[
                    (data.obs[perturbation_key] == pert) & (data.obs[dose_key] == dose),
                    "control",
                ] = 1
            else:
                pert = control
                data.obs.loc[(data.obs[perturbation_key] == pert), "control"] = 1

            self.ctrl = data.obs["control"].values
            assert_msg = "Cells to assign as control not found! Please check the name of control variable."
            assert sum(self.ctrl), assert_msg
            print(f"Assigned {sum(self.ctrl)} control cells")

        if perturbation_key is not None:
            if dose_key is None:
                raise ValueError(
                    f"A 'dose_key' is required when provided a 'perturbation_key'({perturbation_key})."
                )
            if not (dose_key in data.obs):
                print(
                    f"Creating a default entrance for dose_key {dose_key}:",
                    "1.0 per perturbation",
                )
                dose_val = []
                for i in range(len(data)):
                    pert = data.obs[perturbation_key].values[i].split("+")
                    dose_val.append("+".join(["1.0"] * len(pert)))
                data.obs[dose_key] = dose_val

            if not ("cov_drug_dose_name" in data.obs) or replaced:
                print("Creating 'cov_drug_dose_name' field.")
                cov_drug_dose_name = []
                for i in range(len(data)):
                    comb_name = ""
                    for cov_key in self.covariate_keys:
                        comb_name += f"{data.obs[cov_key].values[i]}_"
                    comb_name += f"{data.obs[perturbation_key].values[i]}_{data.obs[dose_key].values[i]}"
                    cov_drug_dose_name.append(comb_name)
                data.obs["cov_drug_dose_name"] = cov_drug_dose_name

            if not ("rank_genes_groups_cov" in data.uns) or replaced:
                print("Ranking genes for DE genes.")
                rank_genes_groups(data, groupby="cov_drug_dose_name")

            self.pert_categories = np.array(data.obs["cov_drug_dose_name"].values)
            self.de_genes = data.uns["rank_genes_groups_cov"]

            self.drugs_names = np.array(data.obs[perturbation_key].values)
            self.dose_names = np.array(data.obs[dose_key].values)

            # get unique drugs
            drugs_names_unique = set()
            for d in self.drugs_names:
                [drugs_names_unique.add(i) for i in d.split("+")]
            self.drugs_names_unique = np.array(list(drugs_names_unique))

            # save encoder for a comparison with Mo's model
            # later we need to remove this part
            encoder_drug = OneHotEncoder(sparse=False)
            encoder_drug.fit(self.drugs_names_unique.reshape(-1, 1))

            # Store as attribute for molecular featurisation
            self.encoder_drug = encoder_drug

            self.perts_dict = dict(
                zip(
                    self.drugs_names_unique,
                    encoder_drug.transform(self.drugs_names_unique.reshape(-1, 1)),
                )
            )

            # get drug combinations
            drugs = []
            for i, comb in enumerate(self.drugs_names):
                drugs_combos = encoder_drug.transform(
                    np.array(comb.split("+")).reshape(-1, 1)
                )
                dose_combos = str(data.obs[dose_key].values[i]).split("+")
                for j, d in enumerate(dose_combos):
                    if j == 0:
                        drug_ohe = float(d) * drugs_combos[j]
                    else:
                        drug_ohe += float(d) * drugs_combos[j]
                drugs.append(drug_ohe)
            self.drugs = torch.Tensor(drugs)

            atomic_ohe = encoder_drug.transform(self.drugs_names_unique.reshape(-1, 1))

            self.drug_dict = {}
            for idrug, drug in enumerate(self.drugs_names_unique):
                i = np.where(atomic_ohe[idrug] == 1)[0][0]
                self.drug_dict[i] = drug
        else:
            self.pert_categories = None
            self.de_genes = None
            self.drugs_names = None
            self.dose_names = None
            self.drugs_names_unique = None
            self.perts_dict = None
            self.drug_dict = None
            self.drugs = None

        if isinstance(covariate_keys, list) and covariate_keys:
            if not len(covariate_keys) == len(set(covariate_keys)):
                raise ValueError(f"Duplicate keys were given in: {covariate_keys}")
            self.covariate_names = {}
            self.covariate_names_unique = {}
            self.covars_dict = {}
            self.covariates = []
            for cov in covariate_keys:
                self.covariate_names[cov] = np.array(data.obs[cov].values)
                self.covariate_names_unique[cov] = np.unique(self.covariate_names[cov])

                names = self.covariate_names_unique[cov]
                encoder_cov = OneHotEncoder(sparse=False)
                encoder_cov.fit(names.reshape(-1, 1))

                self.covars_dict[cov] = dict(
                    zip(list(names), encoder_cov.transform(names.reshape(-1, 1)))
                )

                names = self.covariate_names[cov]
                self.covariates.append(
                    torch.Tensor(encoder_cov.transform(names.reshape(-1, 1))).float()
                )
        else:
            self.covariate_names = None
            self.covariate_names_unique = None
            self.covars_dict = None
            self.covariates = None

        if perturbation_key is not None:
            self.ctrl_name = list(
                np.unique(data[data.obs["control"] == 1].obs[self.perturbation_key])
            )
        else:
            self.ctrl_name = None

        if self.covariates is not None:
            self.num_covariates = [
                len(names) for names in self.covariate_names_unique.values()
            ]
        else:
            self.num_covariates = [0]
        self.num_genes = self.genes.shape[1]
        self.num_drugs = len(self.drugs_names_unique) if self.drugs is not None else 0
        self.is_control = data.obs["control"].values.astype(bool)
        self.indices = {
            "all": list(range(len(self.genes))),
            "control": np.where(data.obs["control"] == 1)[0].tolist(),
            "treated": np.where(data.obs["control"] != 1)[0].tolist(),
            "train": np.where(data.obs[split_key] == "train")[0].tolist(),
            "test": np.where(data.obs[split_key] == "test")[0].tolist(),
            "ood": np.where(data.obs[split_key] == "ood")[0].tolist(),
        }

    def subset(self, split, condition="all"):
        idx = list(set(self.indices[split]) & set(self.indices[condition]))
        return SubDataset(self, idx)

    def __getitem__(self, i):
        return (
            self.genes[i],
            indx(self.drugs, i),
            *[indx(cov, i) for cov in self.covariates],
        )

    def __len__(self):
        return len(self.genes)


class SubDataset:
    """
    Subsets a `Dataset` by selecting the examples given by `indices`.
    """

    def __init__(self, dataset, indices):
        self.perturbation_key = dataset.perturbation_key
        self.dose_key = dataset.dose_key
        self.covariate_keys = dataset.covariate_keys

        self.perts_dict = dataset.perts_dict
        self.covars_dict = dataset.covars_dict

        self.genes = dataset.genes[indices]
        self.drugs = indx(dataset.drugs, indices)
        self.covariates = [indx(cov, indices) for cov in dataset.covariates]

        self.drugs_names = indx(dataset.drugs_names, indices)
        self.pert_categories = indx(dataset.pert_categories, indices)
        self.covariate_names = {}
        for cov in self.covariate_keys:
            self.covariate_names[cov] = indx(dataset.covariate_names[cov], indices)

        self.var_names = dataset.var_names
        self.de_genes = dataset.de_genes
        self.ctrl_name = indx(dataset.ctrl_name, 0)

        self.num_covariates = dataset.num_covariates
        self.num_genes = dataset.num_genes
        self.num_drugs = dataset.num_drugs
        self.is_control = dataset.is_control[indices]

    def __getitem__(self, i):
        return (
            self.genes[i],
            indx(self.drugs, i),
            *[indx(cov, i) for cov in self.covariates],
        )

    def subset_condition(self, control=True):
        idx = np.where(self.is_control == control)[0].tolist()
        return SubDataset(self, idx)

    def __len__(self):
        return len(self.genes)


def load_dataset_splits(
    data: str,
    perturbation_key: Union[str, None],
    dose_key: Union[str, None],
    covariate_keys: Union[list, str, None],
    split_key: str,
    control: Union[str, None],
    return_dataset: bool = False,
):

    dataset = Dataset(
        data, perturbation_key, dose_key, covariate_keys, split_key, control
    )

    splits = {
        "training": dataset.subset("train", "all"),
        "test": dataset.subset("test", "all"),
        "ood": dataset.subset("ood", "all"),
    }

    if return_dataset:
        return splits, dataset
    else:
        return splits


================================================
FILE: cpa/helper.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import warnings

import numpy as np
import pandas as pd
import scanpy as sc
from sklearn.metrics import r2_score
from scipy.sparse import issparse
from scipy.stats import wasserstein_distance
import torch

warnings.filterwarnings("ignore")

import sys

if not sys.warnoptions:
    warnings.simplefilter("ignore")
warnings.simplefilter(action="ignore", category=FutureWarning)

def _convert_mean_disp_to_counts_logits(mu, theta, eps=1e-6):
    r"""NB parameterizations conversion
    Parameters
    ----------
    mu :
        mean of the NB distribution.
    theta :
        inverse overdispersion.
    eps :
        constant used for numerical log stability. (Default value = 1e-6)
    Returns
    -------
    type
        the number of failures until the experiment is stopped
        and the success probability.
    """
    assert (mu is None) == (
        theta is None
    ), "If using the mu/theta NB parameterization, both parameters must be specified"
    logits = (mu + eps).log() - (theta + eps).log()
    total_count = theta
    return total_count, logits


def rank_genes_groups_by_cov(
    adata,
    groupby,
    control_group,
    covariate,
    pool_doses=False,
    n_genes=50,
    rankby_abs=True,
    key_added="rank_genes_groups_cov",
    return_dict=False,
):

    """
    Function that generates a list of differentially expressed genes computed
    separately for each covariate category, and using the respective control
    cells as reference.

    Usage example:

    rank_genes_groups_by_cov(
        adata,
        groupby='cov_product_dose',
        covariate_key='cell_type',
        control_group='Vehicle_0'
    )

    Parameters
    ----------
    adata : AnnData
        AnnData dataset
    groupby : str
        Obs column that defines the groups, should be
        cartesian product of covariate_perturbation_cont_var,
        it is important that this format is followed.
    control_group : str
        String that defines the control group in the groupby obs
    covariate : str
        Obs column that defines the main covariate by which we
        want to separate DEG computation (eg. cell type, species, etc.)
    n_genes : int (default: 50)
        Number of DEGs to include in the lists
    rankby_abs : bool (default: True)
        If True, rank genes by absolute values of the score, thus including
        top downregulated genes in the top N genes. If False, the ranking will
        have only upregulated genes at the top.
    key_added : str (default: 'rank_genes_groups_cov')
        Key used when adding the dictionary to adata.uns
    return_dict : str (default: False)
        Signals whether to return the dictionary or not

    Returns
    -------
    Adds the DEG dictionary to adata.uns

    If return_dict is True returns:
    gene_dict : dict
        Dictionary where groups are stored as keys, and the list of DEGs
        are the corresponding values

    """

    gene_dict = {}
    cov_categories = adata.obs[covariate].unique()
    for cov_cat in cov_categories:
        print(cov_cat)
        # name of the control group in the groupby obs column
        control_group_cov = "_".join([cov_cat, control_group])

        # subset adata to cells belonging to a covariate category
        adata_cov = adata[adata.obs[covariate] == cov_cat]

        # compute DEGs
        sc.tl.rank_genes_groups(
            adata_cov,
            groupby=groupby,
            reference=control_group_cov,
            rankby_abs=rankby_abs,
            n_genes=n_genes,
        )

        # add entries to dictionary of gene sets
        de_genes = pd.DataFrame(adata_cov.uns["rank_genes_groups"]["names"])
        for group in de_genes:
            gene_dict[group] = de_genes[group].tolist()

    adata.uns[key_added] = gene_dict

    if return_dict:
        return gene_dict


def rank_genes_groups(
    adata,
    groupby,
    pool_doses=False,
    n_genes=50,
    rankby_abs=True,
    key_added="rank_genes_groups_cov",
    return_dict=False,
):

    """
    Function that generates a list of differentially expressed genes computed
    separately for each covariate category, and using the respective control
    cells as reference.

    Usage example:

    rank_genes_groups_by_cov(
        adata,
        groupby='cov_product_dose',
        covariate_key='cell_type',
        control_group='Vehicle_0'
    )

    Parameters
    ----------
    adata : AnnData
        AnnData dataset
    groupby : str
        Obs column that defines the groups, should be
        cartesian product of covariate_perturbation_cont_var,
        it is important that this format is followed.
    control_group : str
        String that defines the control group in the groupby obs
    covariate : str
        Obs column that defines the main covariate by which we
        want to separate DEG computation (eg. cell type, species, etc.)
    n_genes : int (default: 50)
        Number of DEGs to include in the lists
    rankby_abs : bool (default: True)
        If True, rank genes by absolute values of the score, thus including
        top downregulated genes in the top N genes. If False, the ranking will
        have only upregulated genes at the top.
    key_added : str (default: 'rank_genes_groups_cov')
        Key used when adding the dictionary to adata.uns
    return_dict : str (default: False)
        Signals whether to return the dictionary or not

    Returns
    -------
    Adds the DEG dictionary to adata.uns

    If return_dict is True returns:
    gene_dict : dict
        Dictionary where groups are stored as keys, and the list of DEGs
        are the corresponding values

    """

    covars_comb = []
    for i in range(len(adata)):
        cov = "_".join(adata.obs["cov_drug_dose_name"].values[i].split("_")[:-2])
        covars_comb.append(cov)
    adata.obs["covars_comb"] = covars_comb

    gene_dict = {}
    for cov_cat in np.unique(adata.obs["covars_comb"].values):
        adata_cov = adata[adata.obs["covars_comb"] == cov_cat]
        control_group_cov = (
            adata_cov[adata_cov.obs["control"] == 1].obs[groupby].values[0]
        )

        # compute DEGs
        sc.tl.rank_genes_groups(
            adata_cov,
            groupby=groupby,
            reference=control_group_cov,
            rankby_abs=rankby_abs,
            n_genes=n_genes,
        )

        # add entries to dictionary of gene sets
        de_genes = pd.DataFrame(adata_cov.uns["rank_genes_groups"]["names"])
        for group in de_genes:
            gene_dict[group] = de_genes[group].tolist()

    adata.uns[key_added] = gene_dict

    if return_dict:
        return gene_dict

# def evaluate_r2_(adata, pred_adata, condition_key, sampled=False):
#     r2_list = []
#     if issparse(adata.X): 
#         adata.X = adata.X.A
#     if issparse(pred_adata.X): 
#         pred_adata.X = pred_adata.X.A
#     for cond in pred_adata.obs[condition_key].unique():
#         adata_ = adata[adata.obs[condition_key] == cond]
#         pred_adata_ = pred_adata[pred_adata.obs[condition_key] == cond]
#         r2_mean = r2_score(adata_.X.mean(0), pred_adata_.X.mean(0))
#         if sampled:
#             r2_var = r2_score(adata_.X.var(0), pred_adata_.X.var(0))
#         else:
#             r2_var = r2_score(
#                 adata_.X.var(0), 
#                 pred_adata_.layers['variance'].var(0)
#             )
#         r2_list.append(
#             {
#                 'condition': cond,
#                 'r2_mean': r2_mean,
#                 'r2_var': r2_var,
#             }
#         )
#     r2_df = pd.DataFrame(r2_list).set_index('condition')
#     return r2_df

def evaluate_r2_(adata, pred_adata, condition_key, sampled=False, de_genes_dict=None):
    r2_list = []
    if issparse(adata.X): 
        adata.X = adata.X.A
    if issparse(pred_adata.X): 
        pred_adata.X = pred_adata.X.A
    for cond in pred_adata.obs[condition_key].unique():
        adata_ = adata[adata.obs[condition_key] == cond]
        pred_adata_ = pred_adata[pred_adata.obs[condition_key] == cond]
        r2_mean = r2_score(adata_.X.mean(0), pred_adata_.X.mean(0))
        if sampled:
            r2_var = r2_score(adata_.X.var(0), pred_adata_.X.var(0))
        else:
            r2_var = r2_score(
                adata_.X.var(0), 
                pred_adata_.layers['variance'].var(0)
            )
        r2_list.append(
            {
                'condition': cond,
                'r2_mean': r2_mean,
                'r2_var': r2_var,
            }
        )
        if de_genes_dict:
            de_genes = de_genes_dict[cond]
            sub_adata_ = adata_[:, de_genes]
            sub_pred_adata_ = pred_adata_[:, de_genes]
            r2_mean_deg = r2_score(sub_adata_.X.mean(0), sub_pred_adata_.X.mean(0))
            if sampled:
                r2_var_deg = r2_score(sub_adata_.X.var(0), sub_pred_adata_.X.var(0))
            else:
                r2_var_deg = r2_score(
                    sub_adata_.X.var(0), 
                    sub_pred_adata_.layers['variance'].var(0)
                )
            r2_list[-1]['r2_mean_deg'] = r2_mean_deg
            r2_list[-1]['r2_var_deg'] = r2_var_deg
    r2_df = pd.DataFrame(r2_list).set_index('condition')
    return r2_df
    
def evaluate_mmd(adata, pred_adata, condition_key, de_genes_dict=None):
    mmd_list = []
    for cond in pred_adata.obs[condition_key].unique():
        adata_ = adata[adata.obs[condition_key] == cond].copy()
        pred_adata_ = pred_adata[pred_adata.obs[condition_key] == cond].copy()
        if issparse(adata_.X): 
            adata_.X = adata_.X.A
        if issparse(pred_adata_.X): 
            pred_adata_.X = pred_adata_.X.A

        mmd = mmd_loss_calc(torch.Tensor(adata_.X), torch.Tensor(pred_adata_.X))
        mmd_list.append(
            {
                'condition': cond,
                'mmd': mmd.detach().cpu().numpy()
            }
        )
        if de_genes_dict:
            de_genes = de_genes_dict[cond]
            sub_adata_ = adata_[:, de_genes]
            sub_pred_adata_ = pred_adata_[:, de_genes]
            mmd_deg = mmd_loss_calc(torch.Tensor(sub_adata_.X), torch.Tensor(sub_pred_adata_.X))
            mmd_list[-1]['mmd_deg'] = mmd_deg.detach().cpu().numpy()
    mmd_df = pd.DataFrame(mmd_list).set_index('condition')
    return mmd_df

def evaluate_emd(adata, pred_adata, condition_key, de_genes_dict=None):
    emd_list = []
    for cond in pred_adata.obs[condition_key].unique():
        adata_ = adata[adata.obs[condition_key] == cond].copy()
        pred_adata_ = pred_adata[pred_adata.obs[condition_key] == cond].copy()
        if issparse(adata_.X): 
            adata_.X = adata_.X.A
        if issparse(pred_adata_.X): 
            pred_adata_.X = pred_adata_.X.A
        wd = []
        for i, _ in enumerate(adata_.var_names):
            wd.append(
                wasserstein_distance(torch.Tensor(adata_.X[:, i]), torch.Tensor(pred_adata_.X[:, i]))
            )
        emd_list.append(
            {
                'condition': cond,
                'emd': np.mean(wd)
            }
        )
        if de_genes_dict:
            de_genes = de_genes_dict[cond]
            sub_adata_ = adata_[:, de_genes]
            sub_pred_adata_ = pred_adata_[:, de_genes]
            wd_deg = []
            for i, _ in enumerate(sub_adata_.var_names):
                wd_deg.append(
                    wasserstein_distance(torch.Tensor(sub_adata_.X[:, i]), torch.Tensor(sub_pred_adata_.X[:, i]))
                )
            emd_list[-1]['emd_deg'] = np.mean(wd_deg)
    emd_df = pd.DataFrame(emd_list).set_index('condition')
    return emd_df

def pairwise_distance(x, y):
    x = x.view(x.shape[0], x.shape[1], 1)
    y = torch.transpose(y, 0, 1)
    output = torch.sum((x - y) ** 2, 1)
    output = torch.transpose(output, 0, 1)
    return output


def gaussian_kernel_matrix(x, y, alphas):
    """Computes multiscale-RBF kernel between x and y.
       Parameters
       ----------
       x: torch.Tensor
            Tensor with shape [batch_size, z_dim].
       y: torch.Tensor
            Tensor with shape [batch_size, z_dim].
       alphas: Tensor
       Returns
       -------
       Returns the computed multiscale-RBF kernel between x and y.
    """

    dist = pairwise_distance(x, y).contiguous()
    dist_ = dist.view(1, -1)

    alphas = alphas.view(alphas.shape[0], 1)
    beta = 1. / (2. * alphas)

    s = torch.matmul(beta, dist_)

    return torch.sum(torch.exp(-s), 0).view_as(dist)


def mmd_loss_calc(source_features, target_features):
    """Initializes Maximum Mean Discrepancy(MMD) between source_features and target_features.
       - Gretton, Arthur, et al. "A Kernel Two-Sample Test". 2012.
       Parameters
       ----------
       source_features: torch.Tensor
            Tensor with shape [batch_size, z_dim]
       target_features: torch.Tensor
            Tensor with shape [batch_size, z_dim]
       Returns
       -------
       Returns the computed MMD between x and y.
    """
    alphas = [
        1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1, 5, 10, 15, 20, 25, 30, 35, 100,
        1e3, 1e4, 1e5, 1e6
    ]
    alphas = torch.autograd.Variable(torch.FloatTensor(alphas)).to(device=source_features.device)

    cost = torch.mean(gaussian_kernel_matrix(source_features, source_features, alphas))
    cost += torch.mean(gaussian_kernel_matrix(target_features, target_features, alphas))
    cost -= 2 * torch.mean(gaussian_kernel_matrix(source_features, target_features, alphas))

    return cost

================================================
FILE: cpa/model.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from http.client import RemoteDisconnected
import json

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

class NBLoss(torch.nn.Module):
    def __init__(self):
        super(NBLoss, self).__init__()

    def forward(self, mu, y, theta, eps=1e-8):
        """Negative binomial negative log-likelihood. It assumes targets `y` with n
        rows and d columns, but estimates `yhat` with n rows and 2d columns.
        The columns 0:d of `yhat` contain estimated means, the columns d:2*d of
        `yhat` contain estimated variances. This module assumes that the
        estimated mean and inverse dispersion are positive---for numerical
        stability, it is recommended that the minimum estimated variance is
        greater than a small number (1e-3).
        Parameters
        ----------
        yhat: Tensor
                Torch Tensor of reeconstructed data.
        y: Tensor
                Torch Tensor of ground truth data.
        eps: Float
                numerical stability constant.
        """
        if theta.ndimension() == 1:
            # In this case, we reshape theta for broadcasting
            theta = theta.view(1, theta.size(0))
        log_theta_mu_eps = torch.log(theta + mu + eps)
        res = (
            theta * (torch.log(theta + eps) - log_theta_mu_eps)
            + y * (torch.log(mu + eps) - log_theta_mu_eps)
            + torch.lgamma(y + theta)
            - torch.lgamma(theta)
            - torch.lgamma(y + 1)
        )
        res = _nan2inf(res)
        return -torch.mean(res)
    
def _nan2inf(x):
    return torch.where(torch.isnan(x), torch.zeros_like(x) + np.inf, x)

class MLP(torch.nn.Module):
    """
    A multilayer perceptron with ReLU activations and optional BatchNorm.
    """

    def __init__(self, sizes, batch_norm=True, last_layer_act="linear"):
        super(MLP, self).__init__()
        layers = []
        for s in range(len(sizes) - 1):
            layers += [
                torch.nn.Linear(sizes[s], sizes[s + 1]),
                torch.nn.BatchNorm1d(sizes[s + 1])
                if batch_norm and s < len(sizes) - 2
                else None,
                torch.nn.ReLU(),
            ]

        layers = [l for l in layers if l is not None][:-1]
        self.activation = last_layer_act
        if self.activation == "linear":
            pass
        elif self.activation == "ReLU":
            self.relu = torch.nn.ReLU()
        else:
            raise ValueError("last_layer_act must be one of 'linear' or 'ReLU'")

        self.network = torch.nn.Sequential(*layers)

    def forward(self, x):
        if self.activation == "ReLU":
            x = self.network(x)
            dim = x.size(1) // 2
            return torch.cat((self.relu(x[:, :dim]), x[:, dim:]), dim=1)
        return self.network(x)


class GeneralizedSigmoid(torch.nn.Module):
    """
    Sigmoid, log-sigmoid or linear functions for encoding dose-response for
    drug perurbations.
    """

    def __init__(self, dim, device, nonlin="sigmoid"):
        """Sigmoid modeling of continuous variable.
        Params
        ------
        nonlin : str (default: logsigm)
            One of logsigm, sigm.
        """
        super(GeneralizedSigmoid, self).__init__()
        self.nonlin = nonlin
        self.beta = torch.nn.Parameter(
            torch.ones(1, dim, device=device), requires_grad=True
        )
        self.bias = torch.nn.Parameter(
            torch.zeros(1, dim, device=device), requires_grad=True
        )

    def forward(self, x):
        if self.nonlin == "logsigm":
            c0 = self.bias.sigmoid()
            return (torch.log1p(x) * self.beta + self.bias).sigmoid() - c0
        elif self.nonlin == "sigm":
            c0 = self.bias.sigmoid()
            return (x * self.beta + self.bias).sigmoid() - c0
        else:
            return x

    def one_drug(self, x, i):
        if self.nonlin == "logsigm":
            c0 = self.bias[0][i].sigmoid()
            return (torch.log1p(x) * self.beta[0][i] + self.bias[0][i]).sigmoid() - c0
        elif self.nonlin == "sigm":
            c0 = self.bias[0][i].sigmoid()
            return (x * self.beta[0][i] + self.bias[0][i]).sigmoid() - c0
        else:
            return x


class CPA(torch.nn.Module):
    """
    Our main module, the CPA autoencoder
    """

    def __init__(
        self,
        num_genes,
        num_drugs,
        num_covariates,
        device="cuda",
        seed=0,
        patience=5,
        loss_ae="gauss",
        doser_type="mlp",
        decoder_activation="linear",
        hparams="",
    ):
        super(CPA, self).__init__()
        # set generic attributes
        self.num_genes = num_genes
        self.num_drugs = num_drugs
        self.num_covariates = num_covariates
        self.device = device
        self.seed = seed
        self.loss_ae = loss_ae
        # early-stopping
        self.patience = patience
        self.best_score = -1e3
        self.patience_trials = 0

        # set hyperparameters
        self.set_hparams_(hparams)

        # set models
        self.encoder = MLP(
            [num_genes]
            + [self.hparams["autoencoder_width"]] * self.hparams["autoencoder_depth"]
            + [self.hparams["dim"]]
        )

        self.decoder = MLP(
            [self.hparams["dim"]]
            + [self.hparams["autoencoder_width"]] * self.hparams["autoencoder_depth"]
            + [num_genes * 2],
            last_layer_act=decoder_activation,
        )

        self.adversary_drugs = MLP(
            [self.hparams["dim"]]
            + [self.hparams["adversary_width"]] * self.hparams["adversary_depth"]
            + [num_drugs]
        )

        self.loss_adversary_drugs = torch.nn.BCEWithLogitsLoss()
        self.doser_type = doser_type
        if doser_type == "mlp":
            self.dosers = torch.nn.ModuleList()
            for _ in range(num_drugs):
                self.dosers.append(
                    MLP(
                        [1]
                        + [self.hparams["dosers_width"]] * self.hparams["dosers_depth"]
                        + [1],
                        batch_norm=False,
                    )
                )
        else:
            self.dosers = GeneralizedSigmoid(num_drugs, self.device, nonlin=doser_type)

        if self.num_covariates == [0]:
            pass
        else:
            assert 0 not in self.num_covariates
            self.adversary_covariates = []
            self.loss_adversary_covariates = []
            self.covariates_embeddings = (
                []
            )  # TODO: Continue with checking that dict assignment is possible via covaraites names and if dict are possible to use in optimisation
            for num_covariate in self.num_covariates:
                self.adversary_covariates.append(
                    MLP(
                        [self.hparams["dim"]]
                        + [self.hparams["adversary_width"]]
                        * self.hparams["adversary_depth"]
                        + [num_covariate]
                    )
                )
                self.loss_adversary_covariates.append(torch.nn.CrossEntropyLoss())
                self.covariates_embeddings.append(
                    torch.nn.Embedding(num_covariate, self.hparams["dim"])
                )
            self.covariates_embeddings = torch.nn.Sequential(
                *self.covariates_embeddings
            )

        self.drug_embeddings = torch.nn.Embedding(self.num_drugs, self.hparams["dim"])
        # losses
        if self.loss_ae == "nb":
            self.loss_autoencoder = NBLoss()
        elif self.loss_ae == 'gauss':
            self.loss_autoencoder = nn.GaussianNLLLoss()

        self.iteration = 0

        self.to(self.device)

        # optimizers
        has_drugs = self.num_drugs > 0
        has_covariates = self.num_covariates[0] > 0
        get_params = lambda model, cond: list(model.parameters()) if cond else []
        _parameters = (
            get_params(self.encoder, True)
            + get_params(self.decoder, True)
            + get_params(self.drug_embeddings, has_drugs)
        )
        for emb in self.covariates_embeddings:
            _parameters.extend(get_params(emb, has_covariates))

        self.optimizer_autoencoder = torch.optim.Adam(
            _parameters,
            lr=self.hparams["autoencoder_lr"],
            weight_decay=self.hparams["autoencoder_wd"],
        )

        _parameters = get_params(self.adversary_drugs, has_drugs)
        for adv in self.adversary_covariates:
            _parameters.extend(get_params(adv, has_covariates))

        self.optimizer_adversaries = torch.optim.Adam(
            _parameters,
            lr=self.hparams["adversary_lr"],
            weight_decay=self.hparams["adversary_wd"],
        )

        if has_drugs:
            self.optimizer_dosers = torch.optim.Adam(
                self.dosers.parameters(),
                lr=self.hparams["dosers_lr"],
                weight_decay=self.hparams["dosers_wd"],
            )

        # learning rate schedulers
        self.scheduler_autoencoder = torch.optim.lr_scheduler.StepLR(
            self.optimizer_autoencoder, step_size=self.hparams["step_size_lr"]
        )

        self.scheduler_adversary = torch.optim.lr_scheduler.StepLR(
            self.optimizer_adversaries, step_size=self.hparams["step_size_lr"]
        )

        if has_drugs:
            self.scheduler_dosers = torch.optim.lr_scheduler.StepLR(
                self.optimizer_dosers, step_size=self.hparams["step_size_lr"]
            )

        self.history = {"epoch": [], "stats_epoch": []}

    def set_hparams_(self, hparams):
        """
        Set hyper-parameters to default values or values fixed by user for those
        hyper-parameters specified in the JSON string `hparams`.
        """

        self.hparams = {
            "dim": 128,
            "dosers_width": 128,
            "dosers_depth": 2,
            "dosers_lr": 4e-3,
            "dosers_wd": 1e-7,
            "autoencoder_width": 128,
            "autoencoder_depth": 3,
            "adversary_width": 64,
            "adversary_depth": 2,
            "reg_adversary": 60,
            "penalty_adversary": 60,
            "autoencoder_lr": 3e-4,
            "adversary_lr": 3e-4,
            "autoencoder_wd": 4e-7,
            "adversary_wd": 4e-7,
            "adversary_steps": 3,
            "batch_size": 256,
            "step_size_lr": 45,
        }

        # the user may fix some hparams
        if hparams != "":
            if isinstance(hparams, str):
                self.hparams.update(json.loads(hparams))
            else:
                self.hparams.update(hparams)

        return self.hparams

    def move_inputs_(self, genes, drugs, covariates):
        """
        Move minibatch tensors to CPU/GPU.
        """
        if genes.device.type != self.device:
            genes = genes.to(self.device)
            if drugs is not None:
                drugs = drugs.to(self.device)
            if covariates is not None:
                covariates = [cov.to(self.device) for cov in covariates]
        return (genes, drugs, covariates)

    def compute_drug_embeddings_(self, drugs):
        """
        Compute sum of drug embeddings, each of them multiplied by its
        dose-response curve.
        """

        if self.doser_type == "mlp":
            doses = []
            for d in range(drugs.size(1)):
                this_drug = drugs[:, d].view(-1, 1)
                doses.append(self.dosers[d](this_drug).sigmoid() * this_drug.gt(0))
            return torch.cat(doses, 1) @ self.drug_embeddings.weight
        else:
            return self.dosers(drugs) @ self.drug_embeddings.weight

    def predict(
        self, 
        genes, 
        drugs, 
        covariates, 
        return_latent_basal=False,
        return_latent_treated=False,
    ):
        """
        Predict "what would have the gene expression `genes` been, had the
        cells in `genes` with cell types `cell_types` been treated with
        combination of drugs `drugs`.
        """

        genes, drugs, covariates = self.move_inputs_(genes, drugs, covariates)
        if self.loss_ae == 'nb':
            genes = torch.log1p(genes)

        latent_basal = self.encoder(genes)

        latent_treated = latent_basal

        if self.num_drugs > 0:
            latent_treated = latent_treated + self.compute_drug_embeddings_(drugs)
        if self.num_covariates[0] > 0:
            for i, emb in enumerate(self.covariates_embeddings):
                emb = emb.to(self.device)
                latent_treated = latent_treated + emb(
                    covariates[i].argmax(1)
                )  #argmax because OHE

        gene_reconstructions = self.decoder(latent_treated)
        if self.loss_ae == 'gauss':
            # convert variance estimates to a positive value in [1e-3, \infty)
            dim = gene_reconstructions.size(1) // 2
            gene_means = gene_reconstructions[:, :dim]
            gene_vars = F.softplus(gene_reconstructions[:, dim:]).add(1e-3)
            #gene_vars = gene_reconstructions[:, dim:].exp().add(1).log().add(1e-3)

        if self.loss_ae == 'nb':
            gene_means = F.softplus(gene_means).add(1e-3)
            #gene_reconstructions[:, :dim] = torch.clamp(gene_reconstructions[:, :dim], min=1e-4, max=1e4)
            #gene_reconstructions[:, dim:] = torch.clamp(gene_reconstructions[:, dim:], min=1e-4, max=1e4)
        gene_reconstructions = torch.cat([gene_means, gene_vars], dim=1)
                
        if return_latent_basal:
            if return_latent_treated:
                return gene_reconstructions, latent_basal, latent_treated
            else:
                return gene_reconstructions, latent_basal
        if return_latent_treated:
            return gene_reconstructions, latent_treated
        return gene_reconstructions

    def early_stopping(self, score):
        """
        Decays the learning rate, and possibly early-stops training.
        """
        self.scheduler_autoencoder.step()
        self.scheduler_adversary.step()
        self.scheduler_dosers.step()

        if score > self.best_score:
            self.best_score = score
            self.patience_trials = 0
        else:
            self.patience_trials += 1

        return self.patience_trials > self.patience

    def update(self, genes, drugs, covariates):
        """
        Update CPA's parameters given a minibatch of genes, drugs, and
        cell types.
        """
        genes, drugs, covariates = self.move_inputs_(genes, drugs, covariates)
        gene_reconstructions, latent_basal = self.predict(
            genes,
            drugs,
            covariates,
            return_latent_basal=True,
        )

        dim = gene_reconstructions.size(1) // 2
        gene_means = gene_reconstructions[:, :dim]
        gene_vars = gene_reconstructions[:, dim:]
        reconstruction_loss = self.loss_autoencoder(gene_means, genes, gene_vars)
        adversary_drugs_loss = torch.tensor([0.0], device=self.device)
        if self.num_drugs > 0:
            adversary_drugs_predictions = self.adversary_drugs(latent_basal)
            adversary_drugs_loss = self.loss_adversary_drugs(
                adversary_drugs_predictions, drugs.gt(0).float()
            )

        adversary_covariates_loss = torch.tensor(
            [0.0], device=self.device
        )
        if self.num_covariates[0] > 0:
            adversary_covariate_predictions = []
            for i, adv in enumerate(self.adversary_covariates):
                adv = adv.to(self.device)
                adversary_covariate_predictions.append(adv(latent_basal))
                adversary_covariates_loss += self.loss_adversary_covariates[i](
                    adversary_covariate_predictions[-1], covariates[i].argmax(1)
                )

        # two place-holders for when adversary is not executed
        adversary_drugs_penalty = torch.tensor([0.0], device=self.device)
        adversary_covariates_penalty = torch.tensor([0.0], device=self.device)

        if self.iteration % self.hparams["adversary_steps"]:

            def compute_gradients(output, input):
                grads = torch.autograd.grad(output, input, create_graph=True)
                grads = grads[0].pow(2).mean()
                return grads

            if self.num_drugs > 0:
                adversary_drugs_penalty = compute_gradients(
                    adversary_drugs_predictions.sum(), latent_basal
                )

            if self.num_covariates[0] > 0:
                adversary_covariates_penalty = torch.tensor([0.0], device=self.device)
                for pred in adversary_covariate_predictions:
                    adversary_covariates_penalty += compute_gradients(
                        pred.sum(), latent_basal
                    )  # TODO: Adding up tensor sum, is that right?

            self.optimizer_adversaries.zero_grad()
            (
                adversary_drugs_loss
                + self.hparams["penalty_adversary"] * adversary_drugs_penalty
                + adversary_covariates_loss
                + self.hparams["penalty_adversary"] * adversary_covariates_penalty
            ).backward()
            self.optimizer_adversaries.step()
        else:
            self.optimizer_autoencoder.zero_grad()
            if self.num_drugs > 0:
                self.optimizer_dosers.zero_grad()
            (
                reconstruction_loss
                - self.hparams["reg_adversary"] * adversary_drugs_loss
                - self.hparams["reg_adversary"] * adversary_covariates_loss
            ).backward()
            self.optimizer_autoencoder.step()
            if self.num_drugs > 0:
                self.optimizer_dosers.step()
        self.iteration += 1

        return {
            "loss_reconstruction": reconstruction_loss.item(),
            "loss_adv_drugs": adversary_drugs_loss.item(),
            "loss_adv_covariates": adversary_covariates_loss.item(),
            "penalty_adv_drugs": adversary_drugs_penalty.item(),
            "penalty_adv_covariates": adversary_covariates_penalty.item(),
        }

    @classmethod
    def defaults(self):
        """
        Returns the list of default hyper-parameters for CPA
        """

        return self.set_hparams_(self, "")


================================================
FILE: cpa/plotting.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from collections import defaultdict

import matplotlib.font_manager
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from adjustText import adjust_text
from sklearn.decomposition import KernelPCA
from sklearn.metrics import r2_score
from sklearn.metrics.pairwise import cosine_similarity

FONT_SIZE = 10
font = {"size": FONT_SIZE}

matplotlib.rc("font", **font)
matplotlib.rc("ytick", labelsize=FONT_SIZE)
matplotlib.rc("xtick", labelsize=FONT_SIZE)


class CPAVisuals:
    """
    A wrapper for automatic plotting CompPert latent embeddings and dose-response
    curve. Sets up prefix for all files and default dictionaries for atomic
    perturbations and cell types.
    """

    def __init__(
        self,
        cpa,
        fileprefix=None,
        perts_palette=None,
        covars_palette=None,
        plot_params={"fontsize": None},
    ):
        """
        Parameters
        ----------
        cpa : CompPertAPI
            Variable from API class.
        fileprefix : str, optional (default: None)
            Prefix (with path) to the filename to save all embeddings in a
            standartized manner. If None, embeddings are not saved to file.
        perts_palette : dict (default: None)
            Dictionary of colors for the embeddings of perturbations. Keys
            correspond to perturbations and values to their colors. If None,
            default dicitonary will be set up.
        covars_palette : dict (default: None)
            Dictionary of colors for the embeddings of covariates. Keys
            correspond to covariates and values to their colors. If None,
            default dicitonary will be set up.
        """

        self.fileprefix = fileprefix

        self.perturbation_key = cpa.perturbation_key
        self.dose_key = cpa.dose_key
        self.covariate_keys = cpa.covariate_keys
        self.measured_points = cpa.measured_points

        self.unique_perts = cpa.unique_perts
        self.unique_covars = cpa.unique_covars

        if perts_palette is None:
            self.perts_palette = dict(
                zip(self.unique_perts, get_palette(len(self.unique_perts)))
            )
        else:
            self.perts_palette = perts_palette

        if covars_palette is None:
            self.covars_palette = {}
            for cov in self.unique_covars:
                self.covars_palette[cov] = dict(
                    zip(
                        self.unique_covars[cov],
                        get_palette(len(self.unique_covars[cov]), palette_name="tab10"),
                    )
                )
        else:
            self.covars_palette = covars_palette

        if plot_params["fontsize"] is None:
            self.fontsize = FONT_SIZE
        else:
            self.fontsize = plot_params["fontsize"]

    def plot_latent_embeddings(
        self,
        emb,
        titlename="Example",
        kind="perturbations",
        palette=None,
        labels=None,
        dimred="KernelPCA",
        filename=None,
        show_text=True,
    ):
        """
        Parameters
        ----------
        emb : np.array
            Multi-dimensional embedding of perturbations or covariates.
        titlename : str, optional (default: 'Example')
            Title.
        kind : int, optional, optional (default: 'perturbations')
            Specify if this is embedding of perturbations, covariates or some
            other. If it is perturbations or covariates, it will use default
            saved dictionaries for colors.
        palette : dict, optional (default: None)
            If embedding of kind not perturbations or covariates, the user can
            specify color dictionary for the embedding.
        labels : list, optional (default: None)
            Labels for the embeddings.
        dimred : str, optional (default: 'KernelPCA')
            Dimensionality reduction method for plotting low dimensional
            representations. Options: 'KernelPCA', 'UMAPpre', 'UMAPcos', None.
            If None, uses first 2 dimensions of the embedding.
        filename : str (default: None)
            Name of the file to save the plot. If None, will automatically
            generate name from prefix file.
        """
        if filename is None:
            if self.fileprefix is None:
                filename = None
                file_name_similarity = None
            else:
                filename = f"{self.fileprefix}_emebdding.png"
                file_name_similarity = f"{self.fileprefix}_emebdding_similarity.png"
        else:
            file_name_similarity = filename.split(".")[0] + "_similarity.png"

        if labels is None:
            if kind == "perturbations":
                palette = self.perts_palette
                labels = self.unique_perts
            elif kind in self.unique_covars:
                palette = self.covars_palette[kind]
                labels = self.unique_covars[kind]

        if len(emb) < 2:
            print(f"Embedding contains only {len(emb)} vectors. Not enough to plot.")
        else:
            plot_embedding(
                fast_dimred(emb, method=dimred),
                labels,
                show_lines=True,
                show_text=show_text,
                col_dict=palette,
                title=titlename,
                file_name=filename,
                fontsize=self.fontsize,
            )

            plot_similarity(
                emb,
                labels,
                col_dict=palette,
                fontsize=self.fontsize,
                file_name=file_name_similarity,
            )

    def plot_contvar_response2D(
        self,
        df_response2D,
        df_ref=None,
        levels=15,
        figsize=(4, 4),
        xlims=(0, 1.03),
        ylims=(0, 1.03),
        palette="coolwarm",
        response_name="response",
        title_name=None,
        fontsize=None,
        postfix="",
        filename=None,
        alpha=0.4,
        sizes=(40, 160),
        logdose=False,
    ):

        """
        Parameters
        ----------
        df_response2D : pd.DataFrame
            Data frame with responses of combinations with columns=(dose1, dose2,
            response).
        levels: int, optional (default: 15)
            Number of levels for contour plot.
        response_name : str (default: 'response')
            Name of column in df_response to plot as response.
        alpha: float (default: 0.4)
            Transparency of the background contour.
        figsize: tuple (default: (4,4))
            Size of the figure in inches.
        palette : dict, optional (default: None)
            Colors dictionary for perturbations to plot.
        title_name : str, optional (default: None)
            Title for the plot.
        postfix : str, optional (defualt: '')
            Postfix to add to the output file name to save the model.
        filename : str, optional (defualt: None)
            Name of the file to save the plot.  If None, will automatically
            generate name from prefix file.
        logdose: bool (default: False)
            If True, dose values will be log10. 0 values will be mapped to
            minumum value -1,e.g.
            if smallest non-zero dose was 0.001, 0 will be mapped to -4.
        """
        sns.set_style("white")

        if (filename is None) and not (self.fileprefix is None):
            filename = f"{self.fileprefix}_{postfix}response2D.png"
        if fontsize is None:
            fontsize = self.fontsize

        x_name, y_name = df_response2D.columns[:2]

        x = df_response2D[x_name].values
        y = df_response2D[y_name].values

        if logdose:
            x = log10_with0(x)
            y = log10_with0(y)

        z = df_response2D[response_name].values

        n = int(np.sqrt(len(x)))

        X = x.reshape(n, n)
        Y = y.reshape(n, n)
        Z = z.reshape(n, n)

        fig, ax = plt.subplots(figsize=figsize)

        CS = ax.contourf(X, Y, Z, cmap=palette, levels=levels, alpha=alpha)
        CS = ax.contour(X, Y, Z, levels=15, cmap=palette)
        ax.clabel(CS, inline=1, fontsize=fontsize)
        ax.set(xlim=(0, 1), ylim=(0, 1))
        ax.axis("equal")
        ax.axis("square")
        ax.yaxis.set_tick_params(labelsize=fontsize)
        ax.xaxis.set_tick_params(labelsize=fontsize)
        ax.set_xlabel(x_name, fontsize=fontsize, fontweight="bold")
        ax.set_ylabel(y_name, fontsize=fontsize, fontweight="bold")
        ax.set_xlim(xlims)
        ax.set_ylim(ylims)

        # sns.despine(left=False, bottom=False, right=True)
        sns.despine()

        if not (df_ref is None):
            sns.scatterplot(
                x=x_name,
                y=y_name,
                hue="split",
                size="num_cells",
                sizes=sizes,
                alpha=1.0,
                palette={"train": "#000000", "training": "#000000", "ood": "#e41a1c"},
                data=df_ref,
                ax=ax,
            )
            ax.legend_.remove()

        ax.set_title(title_name, fontweight="bold", fontsize=fontsize)
        plt.tight_layout()

        if filename:
            save_to_file(fig, filename)

    def plot_contvar_response(
        self,
        df_response,
        response_name="response",
        var_name=None,
        df_ref=None,
        palette=None,
        title_name=None,
        postfix="",
        xlabelname=None,
        filename=None,
        logdose=False,
        fontsize=None,
        measured_points=None,
        bbox=(1.35, 1.0),
        figsize=(7.0, 4.0),
    ):
        """
        Parameters
        ----------
        df_response : pd.DataFrame
            Data frame of responses.
        response_name : str (default: 'response')
            Name of column in df_response to plot as response.
        var_name : str, optional  (default: None)
            Name of conditioning variable, e.g. could correspond to covariates.
        df_ref : pd.DataFrame, optional  (default: None)
            Reference values. Fields for plotting should correspond to
            df_response.
        palette : dict, optional (default: None)
            Colors dictionary for perturbations to plot.
        title_name : str, optional (default: None)
            Title for the plot.
        postfix : str, optional (defualt: '')
            Postfix to add to the output file name to save the model.
        filename : str, optional (defualt: None)
            Name of the file to save the plot.  If None, will automatically
            generate name from prefix file.
        logdose: bool (default: False)
            If True, dose values will be log10. 0 values will be mapped to
            minumum value -1,e.g.
            if smallest non-zero dose was 0.001, 0 will be mapped to -4.
        figsize: tuple (default: (7., 4.))
            Size of output figure
        """
        if (filename is None) and not (self.fileprefix is None):
            filename = f"{self.fileprefix}_{postfix}response.png"

        if fontsize is None:
            fontsize = self.fontsize

        if logdose:
            dose_name = f"log10-{self.dose_key}"
            df_response[dose_name] = log10_with0(df_response[self.dose_key].values)
            if not (df_ref is None):
                df_ref[dose_name] = log10_with0(df_ref[self.dose_key].values)
        else:
            dose_name = self.dose_key

        if var_name is None:
            if len(self.unique_covars) > 1:
                var_name = self.covars_key
            else:
                var_name = self.perturbation_key

        if palette is None:
            if var_name == self.perturbation_key:
                palette = self.perts_palette
            elif var_name in self.covariate_keys:
                palette = self.covars_palette[var_name]

        plot_dose_response(
            df_response,
            dose_name,
            var_name,
            xlabelname=xlabelname,
            df_ref=df_ref,
            response_name=response_name,
            title_name=title_name,
            use_ref_response=(not (df_ref is None)),
            col_dict=palette,
            plot_vertical=False,
            f1=figsize[0],
            f2=figsize[1],
            fname=filename,
            logscale=measured_points,
            measured_points=measured_points,
            bbox=bbox,
            fontsize=fontsize,
            figformat="png",
        )

    def plot_scatter(
        self,
        df,
        x_axis,
        y_axis,
        hue=None,
        size=None,
        style=None,
        figsize=(4.5, 4.5),
        title=None,
        palette=None,
        filename=None,
        alpha=0.75,
        sizes=(30, 90),
        text_dict=None,
        postfix="",
        fontsize=14,
    ):

        sns.set_style("white")

        if (filename is None) and not (self.fileprefix is None):
            filename = f"{self.fileprefix}_scatter{postfix}.png"

        if fontsize is None:
            fontsize = self.fontsize

        fig = plt.figure(figsize=figsize)
        ax = plt.gca()
        sns.scatterplot(
            x=x_axis,
            y=y_axis,
            hue=hue,
            style=style,
            size=size,
            sizes=sizes,
            alpha=alpha,
            palette=palette,
            data=df,
        )

        ax.legend_.remove()
        ax.set_xlabel(x_axis, fontsize=fontsize)
        ax.set_ylabel(y_axis, fontsize=fontsize)
        ax.xaxis.set_tick_params(labelsize=fontsize)
        ax.yaxis.set_tick_params(labelsize=fontsize)
        ax.set_title(title)
        if not (text_dict is None):
            texts = []
            for label in text_dict.keys():
                texts.append(
                    ax.text(
                        text_dict[label][0],
                        text_dict[label][1],
                        label,
                        fontsize=fontsize,
                    )
                )

            adjust_text(
                texts, arrowprops=dict(arrowstyle="-", color="black", lw=0.1), ax=ax
            )

        plt.tight_layout()

        if filename:
            save_to_file(fig, filename)


def log10_with0(x):
    mx = np.min(x[x > 0])
    x[x == 0] = mx / 10
    return np.log10(x)


def get_palette(n_colors, palette_name="Set1"):

    try:
        palette = sns.color_palette(palette_name)
    except:
        print("Palette not found. Using default palette tab10")
        palette = sns.color_palette()
    while len(palette) < n_colors:
        palette += palette

    return palette


def fast_dimred(emb, method="KernelPCA"):
    """
    Takes high dimensional embeddings and produces a 2-dimensional representation
    for plotting.
    emb: np.array
        Embeddings matrix.
    method: str (default: 'KernelPCA')
        Method for dimensionality reduction: KernelPCA, UMAPpre, UMAPcos, tSNE.
        If None return first 2 dimensions of the embedding vector.
    """
    if method is None:
        return emb[:, :2]
    elif method == "KernelPCA":
        similarity_matrix = cosine_similarity(emb)
        np.fill_diagonal(similarity_matrix, 1.0)
        X = KernelPCA(n_components=2, kernel="precomputed").fit_transform(
            similarity_matrix
        )
    else:
        raise NotImplementedError

    return X


def plot_dose_response(
    df,
    contvar_key,
    perturbation_key,
    df_ref=None,
    response_name="response",
    use_ref_response=False,
    palette=None,
    col_dict=None,
    fontsize=8,
    measured_points=None,
    interpolate=True,
    f1=7,
    f2=3.0,
    bbox=(1.35, 1.0),
    ref_name="origin",
    title_name="None",
    plot_vertical=True,
    fname=None,
    logscale=None,
    xlabelname=None,
    figformat="png",
):

    """Plotting decoding of the response with respect to dose.

    Params
    ------
    df : `DataFrame`
        Table with columns=[perturbation_key, contvar_key, response_name].
        The last column is always "response".
    contvar_key : str
        Name of the column in df for values to use for x axis.
    perturbation_key : str
        Name of the column in df for the perturbation or covariate to plot.
    response_name: str (default: response)
        Name of the column in df for values to use for y axis.
    df_ref : `DataFrame` (default: None)
        Table with the same columns as in df to plot ground_truth or another
        condition for comparison. Could
        also be used to just extract reference values for x-axis.
    use_ref_response : bool (default: False)
        A flag indicating if to use values for y axis from df_ref (True) or j
        ust to extract reference values for x-axis.
    col_dict : dictionary (default: None)
        Dictionary with colors for each value in perturbation_key.
    bbox : tuple (default: (1.35, 1.))
        Coordinates to adjust the legend.
    plot_vertical : boolean (default: False)
        Flag if to plot reference values for x axis from df_ref dataframe.
    f1 : float (default: 7.0))
        Width in inches for the plot.
    f2 : float (default: 3.0))
        Hight in inches for the plot.
    fname : str (default: None)
        Name of the file to export the plot. The name comes without format
        extension.
    format : str (default: png)
        Format for the file to export the plot.
    """
    sns.set_style("white")
    if use_ref_response and not (df_ref is None):
        df[ref_name] = "predictions"
        df_ref[ref_name] = "observations"
        if interpolate:
            df_plt = pd.concat([df, df_ref])
        else:
            df_plt = df
    else:
        df_plt = df
    df_plt = df_plt.reset_index()

    atomic_drugs = np.unique(df[perturbation_key].values)
    if palette is None:
        current_palette = get_palette(len(list(atomic_drugs)))

    if col_dict is None:
        col_dict = dict(zip(list(atomic_drugs), current_palette))

    fig = plt.figure(figsize=(f1, f2))
    ax = plt.gca()
    if use_ref_response:
        sns.lineplot(
            x=contvar_key,
            y=response_name,
            palette=col_dict,
            hue=perturbation_key,
            style=ref_name,
            dashes=[(1, 0), (2, 1)],
            legend="full",
            style_order=["predictions", "observations"],
            data=df_plt,
            ax=ax,
        )

        df_ref = df_ref.replace("training_treated", "train")
        sns.scatterplot(
            x=contvar_key,
            y=response_name,
            hue="split",
            size="num_cells",
            sizes=(10, 100),
            alpha=1.0,
            palette={"train": "#000000", "training": "#000000", "ood": "#e41a1c"},
            data=df_ref,
            ax=ax,
        )
        sns.despine()
        ax.legend_.remove()
    else:
        sns.lineplot(
            x=contvar_key,
            y=response_name,
            palette=col_dict,
            hue=perturbation_key,
            data=df_plt,
            ax=ax,
        )
        ax.legend(loc="upper right", bbox_to_anchor=bbox, fontsize=fontsize)
        sns.despine()
    if not (title_name is None):
        ax.set_title(title_name, fontsize=fontsize, fontweight="bold")
    ax.grid("off")

    if xlabelname is None:
        ax.set_xlabel(contvar_key, fontsize=fontsize)
    else:
        ax.set_xlabel(xlabelname, fontsize=fontsize)

    ax.set_ylabel(f"{response_name}", fontsize=fontsize)

    ax.xaxis.set_tick_params(labelsize=fontsize)
    ax.yaxis.set_tick_params(labelsize=fontsize)

    if not (logscale is None):
        ax.set_xticks(np.log10(logscale))
        ax.set_xticklabels(logscale, rotation=90)

    if not (df_ref is None):
        atomic_drugs = np.unique(df_ref[perturbation_key].values)
        for drug in atomic_drugs:
            x = df_ref[df_ref[perturbation_key] == drug][contvar_key].values
            m1 = np.min(df[df[perturbation_key] == drug][response_name].values)
            m2 = np.max(df[df[perturbation_key] == drug][response_name].values)

            if plot_vertical:
                for x_dot in x:
                    ax.plot(
                        [x_dot, x_dot],
                        [m1, m2],
                        ":",
                        color="black",
                        linewidth=0.5,
                        alpha=0.5,
                    )
    fig.tight_layout()
    if fname:
        plt.savefig(f"{fname}.{figformat}", format=figformat, dpi=600)

    return fig


def plot_uncertainty_comb_dose(
    cpa_api,
    cov,
    pert,
    N=11,
    metric="cosine",
    measured_points=None,
    cond_key="condition",
    figsize=(4, 4),
    vmin=None,
    vmax=None,
    sizes=(40, 160),
    df_ref=None,
    xlims=(0, 1.03),
    ylims=(0, 1.03),
    fixed_drugs="",
    fixed_doses="",
    title="",
    filename=None,
):
    """Plotting uncertainty for a single perturbation at a dose range for a
    particular covariate.

    Params
    ------
    cpa_api
        Api object for the model class.
    cov : dict
        Name of covariate.
    pert : str
        Name of the perturbation.
    N : int
        Number of dose values.
    metric: str (default: 'cosine')
        Metric to evaluate uncertainty.
    measured_points : dict (default: None)
        A dicitionary of dictionaries. Per each covariate a dictionary with
        observed doses per perturbation, e.g. {'covar1': {'pert1':
        [0.1, 0.5, 1.0], 'pert2': [0.3]}
    cond_key : str (default: 'condition')
        Name of the variable to use for plotting.
    filename : str (default: None)
        Full path to the file to export the plot. File extension should be
        included.

    Returns
        -------
        pd.DataFrame of uncertainty estimations.
    """

    cov_name = "_".join([cov[cov_key] for cov_key in cpa_api.covariate_keys])
    df_list = []
    for i in np.round(np.linspace(0, 1, N), decimals=2):
        for j in np.round(np.linspace(0, 1, N), decimals=2):
            df_list.append(
                {
                    "covariates": cov_name,
                    "condition": pert + fixed_drugs,
                    "dose_val": str(i) + "+" + str(j) + fixed_doses,
                }
            )
    df_pred = pd.DataFrame(df_list)
    uncert_cos = []
    uncert_eucl = []
    closest_cond_cos = []
    closest_cond_eucl = []
    for i in range(df_pred.shape[0]):
        (
            uncert_cos_,
            uncert_eucl_,
            closest_cond_cos_,
            closest_cond_eucl_,
        ) = cpa_api.compute_uncertainty(
            cov=cov, pert=df_pred.iloc[i]["condition"], dose=df_pred.iloc[i]["dose_val"]
        )
        uncert_cos.append(uncert_cos_)
        uncert_eucl.append(uncert_eucl_)
        closest_cond_cos.append(closest_cond_cos_)
        closest_cond_eucl.append(closest_cond_eucl_)

    df_pred["uncertainty_cosine"] = uncert_cos
    df_pred["uncertainty_eucl"] = uncert_eucl
    df_pred["closest_cond_cos"] = closest_cond_cos
    df_pred["closest_cond_eucl"] = closest_cond_eucl
    doses = df_pred.dose_val.apply(lambda x: x.split("+"))
    X = np.array(doses.apply(lambda x: x[0]).astype(float)).reshape(N, N)
    Y = np.array(doses.apply(lambda x: x[1]).astype(float)).reshape(N, N)
    Z = np.array(df_pred[f"uncertainty_{metric}"].values.astype(float)).reshape(N, N)

    fig, ax = plt.subplots(1, 1, figsize=figsize)
    CS = ax.contourf(X, Y, Z, cmap="coolwarm", levels=20, alpha=1, vmin=vmin, vmax=vmax)

    ax.set_xlabel(pert.split("+")[0], fontweight="bold")
    ax.set_ylabel(pert.split("+")[1], fontweight="bold")

    if not (df_ref is None):
        sns.scatterplot(
            x=pert.split("+")[0],
            y=pert.split("+")[1],
            hue="split",
            size="num_cells",
            sizes=sizes,
            alpha=1.0,
            palette={"train": "#000000", "training": "#000000", "ood": "#e41a1c"},
            data=df_ref,
            ax=ax,
        )
        ax.legend_.remove()

    if measured_points:
        ticks = measured_points[cov_name][pert]
        xticks = [float(x.split("+")[0]) for x in ticks]
        yticks = [float(x.split("+")[1]) for x in ticks]
        ax.set_xticks(xticks)
        ax.set_xticklabels(xticks, rotation=90)
        ax.set_yticks(yticks)
    fig.colorbar(CS)
    sns.despine()
    ax.axis("equal")
    ax.axis("square")
    ax.set_xlim(xlims)
    ax.set_ylim(ylims)
    ax.set_title(title, fontsize=10, fontweight='bold')
    plt.tight_layout()

    if filename:
        plt.savefig(filename, dpi=600)

    return df_pred


def plot_uncertainty_dose(
    cpa_api,
    cov,
    pert,
    N=11,
    metric="cosine",
    measured_points=None,
    cond_key="condition",
    log=False,
    min_dose=None,
    filename=None,
):
    """Plotting uncertainty for a single perturbation at a dose range for a
    particular covariate.

    Params
    ------
    cpa_api
        Api object for the model class.
    cov : str
        Name of covariate.
    pert : str
        Name of the perturbation.
    N : int
        Number of dose values.
    metric: str (default: 'cosine')
        Metric to evaluate uncertainty.
    measured_points : dict (default: None)
        A dicitionary of dictionaries. Per each covariate a dictionary with
        observed doses per perturbation, e.g. {'covar1': {'pert1':
        [0.1, 0.5, 1.0], 'pert2': [0.3]}
    cond_key : str (default: 'condition')
        Name of the variable to use for plotting.
    log : boolean (default: False)
        A flag if to plot on a log scale.
    min_dose : float (default: None)
        Minimum dose for the uncertainty estimate.
    filename : str (default: None)
        Full path to the file to export the plot. File extension should be included.

    Returns
        -------
        pd.DataFrame of uncertainty estimations.
    """

    df_list = []
    if log:
        if min_dose is None:
            min_dose = 1e-3
        N_val = np.round(np.logspace(np.log10(min_dose), np.log10(1), N), decimals=10)
    else:
        if min_dose is None:
            min_dose = 0
        N_val = np.round(np.linspace(min_dose, 1.0, N), decimals=3)

    cov_name = "_".join([cov[cov_key] for cov_key in cpa_api.covariate_keys])

    for i in N_val:
        df_list.append({"covariates": cov_name, "condition": pert, "dose_val": repr(i)})

    df_pred = pd.DataFrame(df_list)
    uncert_cos = []
    uncert_eucl = []
    closest_cond_cos = []
    closest_cond_eucl = []

    for i in range(df_pred.shape[0]):
        (
            uncert_cos_,
            uncert_eucl_,
            closest_cond_cos_,
            closest_cond_eucl_,
        ) = cpa_api.compute_uncertainty(
            cov=cov, pert=df_pred.iloc[i]["condition"], dose=df_pred.iloc[i]["dose_val"]
        )
        uncert_cos.append(uncert_cos_)
        uncert_eucl.append(uncert_eucl_)
        closest_cond_cos.append(closest_cond_cos_)
        closest_cond_eucl.append(closest_cond_eucl_)

    df_pred["uncertainty_cosine"] = uncert_cos
    df_pred["uncertainty_eucl"] = uncert_eucl
    df_pred["closest_cond_cos"] = closest_cond_cos
    df_pred["closest_cond_eucl"] = closest_cond_eucl

    x = df_pred.dose_val.values.astype(float)
    y = df_pred[f"uncertainty_{metric}"].values.astype(float)
    fig, ax = plt.subplots(1, 1)
    ax.plot(x, y)
    ax.set_xlabel(pert)
    ax.set_ylabel("Uncertainty")
    ax.set_title(cov_name)
    if log:
        ax.set_xscale("log")
    if measured_points:
        ticks = measured_points[cov_name][pert]
        ax.set_xticks(ticks)
        ax.set_xticklabels(ticks, rotation=90)
    else:
        plt.draw()
        ax.set_xticklabels(ax.get_xticklabels(), rotation=90)

    sns.despine()
    plt.tight_layout()

    if filename:
        plt.savefig(filename)

    return df_pred


def save_to_file(fig, file_name, file_format=None):
    if file_format is None:
        if file_name.split(".")[-1] in ["png", "pdf"]:
            file_format = file_name.split(".")[-1]
            savename = file_name
        else:
            file_format = "pdf"
            savename = f"{file_name}.{file_format}"
    else:
        savename = file_name

    fig.savefig(savename, format=file_format, dpi=600)
    print(f"Saved file to: {savename}")


def plot_embedding(
    emb,
    labels=None,
    col_dict=None,
    title=None,
    show_lines=False,
    show_text=False,
    show_legend=True,
    axis_equal=True,
    circle_size=40,
    circe_transparency=1.0,
    line_transparency=0.8,
    line_width=1.0,
    fontsize=9,
    fig_width=4,
    fig_height=4,
    file_name=None,
    file_format=None,
    labels_name=None,
    width_ratios=[7, 1],
    bbox=(1.3, 0.7),
):
    sns.set_style("white")

    # create data structure suitable for embedding
    df = pd.DataFrame(emb, columns=["dim1", "dim2"])
    if not (labels is None):
        if labels_name is None:
            labels_name = "labels"
        df[labels_name] = labels

    fig = plt.figure(figsize=(fig_width, fig_height))
    ax = plt.gca()

    sns.despine(left=False, bottom=False, right=True)

    if (col_dict is None) and not (labels is None):
        col_dict = get_colors(labels)

    sns.scatterplot(
        x="dim1",
        y="dim2",
        hue=labels_name,
        palette=col_dict,
        alpha=circe_transparency,
        edgecolor="none",
        s=circle_size,
        data=df,
        ax=ax,
    )

    try:
        ax.legend_.remove()
    except:
        pass

    if show_lines:
        for i in range(len(emb)):
            if col_dict is None:
                ax.plot(
                    [0, emb[i, 0]],
                    [0, emb[i, 1]],
                    alpha=line_transparency,
                    linewidth=line_width,
                    c=None,
                )
            else:
                ax.plot(
                    [0, emb[i, 0]],
                    [0, emb[i, 1]],
                    alpha=line_transparency,
                    linewidth=line_width,
                    c=col_dict[labels[i]],
                )

    if show_text and not (labels is None):
        texts = []
        labels = np.array(labels)
        unique_labels = np.unique(labels)
        for label in unique_labels:
            idx_label = np.where(labels == label)[0]
            texts.append(
                ax.text(
                    np.mean(emb[idx_label, 0]),
                    np.mean(emb[idx_label, 1]),
                    label,
                    #fontsize=fontsize,
                )
            )

        adjust_text(
            texts, arrowprops=dict(arrowstyle="-", color="black", lw=0.1), ax=ax
        )

    if axis_equal:
        ax.axis("equal")
        ax.axis("square")

    if title:
        ax.set_title(title, fontweight="bold")

    ax.set_xlabel("dim1"),# fontsize=fontsize)
    ax.set_ylabel("dim2"),# fontsize=fontsize)
    #ax.xaxis.set_tick_params(labelsize=fontsize)
    #ax.yaxis.set_tick_params(labelsize=fontsize)

    plt.tight_layout()

    if file_name:
        save_to_file(fig, file_name, file_format)

    return plt


def get_colors(labels, palette=None, palette_name=None):
    n_colors = len(labels)
    if palette is None:
        palette = get_palette(n_colors, palette_name)
    col_dict = dict(zip(labels, palette[:n_colors]))
    return col_dict


def plot_similarity(
    emb,
    labels=None,
    col_dict=None,
    fig_width=4,
    fig_height=4,
    cmap="coolwarm",
    fmt="png",
    fontsize=7,
    file_format=None,
    file_name=None,
):

    # first we take construct similarity matrix
    # add another similarity
    similarity_matrix = cosine_similarity(emb)

    df = pd.DataFrame(
        similarity_matrix,
        columns=labels,
        index=labels,
    )

    if col_dict is None:
        col_dict = get_colors(labels)

    network_colors = pd.Series(df.columns, index=df.columns).map(col_dict)

    sns_plot = sns.clustermap(
        df,
        cmap=cmap,
        center=0,
        row_colors=network_colors,
        col_colors=network_colors,
        mask=False,
        metric="euclidean",
        figsize=(fig_height, fig_width),
        vmin=-1,
        vmax=1,
        fmt=file_format,
    )

    sns_plot.ax_heatmap.xaxis.set_tick_params(labelsize=fontsize)
    sns_plot.ax_heatmap.yaxis.set_tick_params(labelsize=fontsize)
    sns_plot.ax_heatmap.axis("equal")
    sns_plot.cax.yaxis.set_tick_params(labelsize=fontsize)

    if file_name:
        save_to_file(sns_plot, file_name, file_format)


from scipy import sparse, stats
from sklearn.metrics import r2_score


def mean_plot(
    adata,
    pred,
    condition_key,
    exp_key,
    path_to_save="./reg_mean.pdf",
    gene_list=None,
    deg_list=None,
    show=False,
    title=None,
    verbose=False,
    x_coeff=0.30,
    y_coeff=0.8,
    fontsize=11,
    R2_type="R2",
    figsize=(3.5, 3.5),
    **kwargs,
):
    """
    Plots mean matching.

    # Parameters
    adata: `~anndata.AnnData`
        Contains real v
    pred: `~anndata.AnnData`
        Contains predicted values.
    condition_key: Str
        adata.obs key to look for x-axis and y-axis condition
    exp_key: Str
        Condition in adata.obs[condition_key] to be ploted
    path_to_save: basestring
        Path to save the plot.
    gene_list: list
        List of gene names to be plotted.
    deg_list: list
        List of DEGs to compute R2
    show: boolean
        if True plots the figure
    Verbose: boolean
        If true prints the value
    title: Str
        Title of the plot
    x_coeff: float
        Shifts R2 text horizontally by x_coeff
    y_coeff: float
        Shifts R2 text vertically by y_coeff
    show: bool
        if `True`: will show to the plot after saving it.
    fontsize: int
        Font size for R2 texts
    R2_type: Str
        How to compute R2 value, should be either Pearson R2 or R2 (sklearn)

    Returns:
    Calluated R2 values
    """

    r2_types = ["R2", "Pearson R2"]
    if R2_type not in r2_types:
        raise ValueError("R2 caclulation should be one of" + str(r2_types))
    if sparse.issparse(adata.X):
        adata.X = adata.X.A
    if sparse.issparse(pred.X):
        pred.X = pred.X.A
    diff_genes = deg_list
    real = adata[adata.obs[condition_key] == exp_key]
    pred = pred[pred.obs[condition_key] == exp_key]
    if diff_genes is not None:
        if hasattr(diff_genes, "tolist"):
            diff_genes = diff_genes.tolist()
        real_diff = adata[:, diff_genes][adata.obs[condition_key] == exp_key]
        pred_diff = pred[:, diff_genes][pred.obs[condition_key] == exp_key]
        x_diff = np.average(pred_diff.X, axis=0)
        y_diff = np.average(real_diff.X, axis=0)
        if R2_type == "R2":
            r2_diff = r2_score(y_diff, x_diff)
        if R2_type == "Pearson R2":
            m, b, pearson_r_diff, p_value_diff, std_err_diff = stats.linregress(
                y_diff, x_diff
            )
            r2_diff = pearson_r_diff ** 2
        if verbose:
            print(f"Top {len(diff_genes)} DEGs var: ", r2_diff)
    x = np.average(pred.X, axis=0)
    y = np.average(real.X, axis=0)
    if R2_type == "R2":
        r2 = r2_score(y, x)
    if R2_type == "Pearson R2":
        m, b, pearson_r, p_value, std_err = stats.linregress(y, x)
        r2 = pearson_r ** 2
    if verbose:
        print("All genes var: ", r2)
    df = pd.DataFrame({f"{exp_key}_true": x, f"{exp_key}_pred": y})

    plt.figure(figsize=figsize)
    ax = sns.regplot(x=f"{exp_key}_true", y=f"{exp_key}_pred", data=df)
    ax.tick_params(labelsize=fontsize)
    if "range" in kwargs:
        start, stop, step = kwargs.get("range")
        ax.set_xticks(np.arange(start, stop, step))
        ax.set_yticks(np.arange(start, stop, step))
    ax.set_xlabel("true", fontsize=fontsize)
    ax.set_ylabel("pred", fontsize=fontsize)
    if gene_list is not None:
        for i in gene_list:
            j = adata.var_names.tolist().index(i)
            x_bar = x[j]
            y_bar = y[j]
            plt.text(x_bar, y_bar, i, fontsize=fontsize, color="black")
            plt.plot(x_bar, y_bar, "o", color="red", markersize=5)
    if title is None:
        plt.title(f"", fontsize=fontsize, fontweight="bold")
    else:
        plt.title(title, fontsize=fontsize, fontweight="bold")
    ax.text(
        max(x) - max(x) * x_coeff,
        max(y) - y_coeff * max(y),
        r"$\mathrm{R^2_{\mathrm{\mathsf{all\ genes}}}}$= " + f"{r2:.2f}",
        fontsize=fontsize,
    )
    if diff_genes is not None:
        ax.text(
            max(x) - max(x) * x_coeff,
            max(y) - (y_coeff + 0.15) * max(y),
            r"$\mathrm{R^2_{\mathrm{\mathsf{DEGs}}}}$= " + f"{r2_diff:.2f}",
            fontsize=fontsize,
        )
    plt.savefig(f"{path_to_save}", bbox_inches="tight", dpi=100)
    if show:
        plt.show()
    plt.close()
    if diff_genes is not None:
        return r2, r2_diff
    else:
        return r2


def plot_r2_matrix(pred, adata, de_genes=None, **kwds):
    """Plots a pairwise R2 heatmap between predicted and control conditions.

    Params
    ------
    pred : `AnnData`
        Must have the field `cov_drug_dose_name`
    adata : `AnnData`
        Original gene expression data, with the field `cov_drug_dose_name`.
    de_genes : `dict`
        Dictionary of de_genes, where the keys
        match the categories in `cov_drug_dose_name`
    """
    r2s_mean = defaultdict(list)
    r2s_var = defaultdict(list)
    conditions = pred.obs["cov_drug_dose_name"].cat.categories
    for cond in conditions:
        if de_genes:
            degs = de_genes[cond]
            y_pred = pred[:, degs][pred.obs["cov_drug_dose_name"] == cond].X
            y_true_adata = adata[:, degs]
        else:
            y_pred = pred[pred.obs["cov_drug_dose_name"] == cond].X
            y_true_adata = adata

        # calculate r2 between pairwise
        for cond_real in conditions:
            y_true = y_true_adata[
                y_true_adata.obs["cov_drug_dose_name"] == cond_real
            ].X.toarray()
            r2s_mean[cond_real].append(
                r2_score(y_true.mean(axis=0), y_pred.mean(axis=0))
            )
            r2s_var[cond_real].append(r2_score(y_true.var(axis=0), y_pred.var(axis=0)))

    for r2_dict in [r2s_mean, r2s_var]:
        r2_df = pd.DataFrame.from_dict(r2_dict, orient="index")
        r2_df.columns = conditions

        plt.figure(figsize=(5, 5))
        p = sns.heatmap(
            data=r2_df,
            vmin=max(r2_df.min(0).min(), 0),
            cmap="Blues",
            cbar=False,
            annot=True,
            fmt=".2f",
            annot_kws={"fontsize": 5},
            **kwds,
        )
        plt.xticks(fontsize=6)
        plt.yticks(fontsize=6)
        plt.xlabel("y_true")
        plt.ylabel("y_pred")
        plt.show()


def arrange_history(history):

    print(history.keys())


class CPAHistory:
    """
    A wrapper for automatic plotting history of CPA model..
    """

    def __init__(self, cpa_api, fileprefix=None):
        """
        Parameters
        ----------
        cpa_api : dict
            cpa api instance
        fileprefix : str, optional (default: None)
            Prefix (with path) to the filename to save all embeddings in a
            standartized manner. If None, embeddings are not saved to file.
        """
        self.history = cpa_api.history
        self.time = self.history["elapsed_time_min"]
        self.losses_list = [
            "loss_reconstruction",
            "loss_adv_drugs",
            "loss_adv_covariates",
        ]
        self.penalties_list = ["penalty_adv_drugs", "penalty_adv_covariates"]

        subset_keys = ["epoch"] + self.losses_list + self.penalties_list

        self.losses = pd.DataFrame(
            dict((k, self.history[k]) for k in subset_keys if k in self.history)
        )

        self.header = ["mean", "mean_DE", "var", "var_DE"]
        self.eval_metrics = False
        if 'perturbation disentanglement' in list (self.history):               #check that metrics were evaluated
            self.eval_metrics = True
            self.metrics = pd.DataFrame(columns=["epoch", "split"] + self.header)
            for split in ["training", "test", "ood"]:
                df_split = pd.DataFrame(np.array(self.history[split]), columns=self.header)
                df_split["split"] = split
                df_split["epoch"] = self.history["stats_epoch"]
                self.metrics = pd.concat([self.metrics, df_split])
            self.covariate_names = list(cpa_api.datasets["training"].covariate_names)
            self.disent = pd.DataFrame(
                dict(
                    (k, self.history[k])
                    for k in 
                    ['perturbation disentanglement'] 
                    + [f'{cov} disentanglement' for cov in self.covariate_names]
                    if k in self.history
                )
            )
            self.disent["epoch"] = self.history["stats_epoch"]
        self.fileprefix = fileprefix

    def print_time(self):
        print(f"Computation time: {self.time:.0f} min")

    def plot_losses(self, filename=None):
        """
        Parameters
        ----------
        filename : str (default: None)
            Name of the file to save the plot. If None, will automatically
            generate name from prefix file.
        """
        if filename is None:
            if self.fileprefix is None:
                filename = None
            else:
                filename = f"{self.fileprefix}_history_losses.png"

        fig, ax = plt.subplots(1, 4, sharex=True, sharey=False, figsize=(12, 3))

        i = 0
        for i in range(4):
            if i < 3:
                ax[i].plot(
                    self.losses["epoch"].values, self.losses[self.losses_list[i]].values
                )
                ax[i].set_title(self.losses_list[i], fontweight="bold")
            else:
                ax[i].plot(
                    self.losses["epoch"].values, self.losses[self.penalties_list].values
                )
                ax[i].set_title("Penalties", fontweight="bold")
        sns.despine()
        plt.tight_layout()

        if filename:
            save_to_file(fig, filename)

    def plot_r2_metrics(self, epoch_min=0, filename=None):
        """
        Parameters
        ----------
        epoch_min : int (default: 0)
            Epoch from which to show metrics history plot. Done for readability.

        filename : str (default: None)
            Name of the file to save the plot. If None, will automatically
            generate name from prefix file.
        """

        assert self.eval_metrics == True, 'The evaluation metrics were not computed'

        if filename is None:
            if self.fileprefix is None:
                filename = None
            else:
                filename = f"{self.fileprefix}_history_metrics.png"

        df = self.metrics.melt(id_vars=["epoch", "split"])
        col_dict = dict(
            zip(["training", "test", "ood"], ["#377eb8", "#4daf4a", "#e41a1c"])
        )
        fig, axs = plt.subplots(2, 2, sharex=True, sharey=False, figsize=(7, 5))
        ax = plt.gca()
        i = 0
        for i1 in range(2):
            for i2 in range(2):
                sns.lineplot(
                    data=df[
                        (df["variable"] == self.header[i]) & (df["epoch"] > epoch_min)
                    ],
                    x="epoch",
                    y="value",
                    palette=col_dict,
                    hue="split",
                    ax=axs[i1, i2],
                )
                axs[i1, i2].set_title(self.header[i], fontweight="bold")
                i += 1
        sns.despine()
        plt.tight_layout()
        if filename:
            save_to_file(fig, filename)

    def plot_disentanglement_metrics(self, epoch_min=0, filename=None):
        """
        Parameters
        ----------
        epoch_min : int (default: 0)
            Epoch from which to show metrics history plot. Done for readability.

        filename : str (default: None)
            Name of the file to save the plot. If None, will automatically
            generate name from prefix file.
        """
        assert self.eval_metrics == True, 'The evaluation metrics were not computed'
        
        if filename is None:
            if self.fileprefix is None:
                filename = None
            else:
                filename = f"{self.fileprefix}_history_metrics.png"

        fig, axs = plt.subplots(
            1, 
            1+len(self.covariate_names), 
            sharex=True, 
            sharey=False, 
            figsize=(2 + 5*(len(self.covariate_names)), 3)
        )

        ax = plt.gca()
        sns.lineplot(
            data=self.disent[self.disent["epoch"] > epoch_min],
            x="epoch",
            y="perturbation disentanglement",
            legend=False,
            ax=axs[0],
        )
        axs[0].set_title("perturbation disent", fontweight="bold")

        for i, cov in enumerate(self.covariate_names):
            sns.lineplot(
                data=self.disent[self.disent['epoch'] > epoch_min],
                x="epoch",
                y=f"{cov} disentanglement",
                legend=False,
                ax=axs[1+i]
            )
            axs[1+i].set_title(f"{cov} disent", fontweight="bold")
        fig.tight_layout()
        sns.despine()

================================================
FILE: cpa/train.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import argparse
import json
import os
import time
from collections import defaultdict

import numpy as np
import torch
from cpa.data import load_dataset_splits
from cpa.model import CPA, MLP
from sklearn.metrics import r2_score
from torch.autograd import Variable
from torch.distributions import NegativeBinomial
from torch import nn


def pjson(s):
    """
    Prints a string in JSON format and flushes stdout
    """
    print(json.dumps(s), flush=True)

def _convert_mean_disp_to_counts_logits(mu, theta, eps=1e-6):
    r"""NB parameterizations conversion
    Parameters
    ----------
    mu :
        mean of the NB distribution.
    theta :
        inverse overdispersion.
    eps :
        constant used for numerical log stability. (Default value = 1e-6)
    Returns
    -------
    type
        the number of failures until the experiment is stopped
        and the success probability.
    """
    assert (mu is None) == (
        theta is None
    ), "If using the mu/theta NB parameterization, both parameters must be specified"
    logits = (mu + eps).log() - (theta + eps).log()
    total_count = theta
    return total_count, logits

def evaluate_disentanglement(autoencoder, dataset):
    """
    Given a CPA model, this function measures the correlation between
    its latent space and 1) a dataset's drug vectors 2) a datasets covariate
    vectors.

    """
    with torch.no_grad():
        _, latent_basal = autoencoder.predict(
            dataset.genes,
            dataset.drugs,
            dataset.covariates,
            return_latent_basal=True,
        )
    
    mean = latent_basal.mean(dim=0, keepdim=True)
    stddev = latent_basal.std(0, unbiased=False, keepdim=True)
    normalized_basal = (latent_basal - mean) / stddev
    criterion = nn.CrossEntropyLoss()
    pert_scores, cov_scores = 0, []

    def compute_score(labels):
        if len(np.unique(labels)) > 1:
            unique_labels = set(labels)
            label_to_idx = {labels: idx for idx, labels in enumerate(unique_labels)}
            labels_tensor = torch.tensor(
                [label_to_idx[label] for label in labels], dtype=torch.long, device=autoencoder.device
            )
            assert normalized_basal.size(0) == len(labels_tensor)
            #might have to perform a train/test split here
            dataset = torch.utils.data.TensorDataset(normalized_basal, labels_tensor)
            data_loader = torch.utils.data.DataLoader(dataset, batch_size=256, shuffle=True)

            # 2 non-linear layers of size <input_dimension>
            # followed by a linear layer.
            disentanglement_classifier = MLP(
                [normalized_basal.size(1)]
                + [normalized_basal.size(1) for _ in range(2)]
                + [len(unique_labels)]
            ).to(autoencoder.device)
            optimizer = torch.optim.Adam(disentanglement_classifier.parameters(), lr=1e-2)

            for epoch in range(50):
                for X, y in data_loader:
                    pred = disentanglement_classifier(X)
                    loss = Variable(criterion(pred, y), requires_grad=True)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

            with torch.no_grad():
                pred = disentanglement_classifier(normalized_basal).argmax(dim=1)
                acc = torch.sum(pred == labels_tensor) / len(labels_tensor)
            return acc.item()
        else:
            return 0

    if dataset.perturbation_key is not None:
        pert_scores = compute_score(dataset.drugs_names)
    for cov in list(dataset.covariate_names):
        cov_scores = []
        if len(np.unique(dataset.covariate_names[cov])) == 0:
            cov_scores = [0]
            break
        else:
            cov_scores.append(compute_score(dataset.covariate_names[cov]))
        return [np.mean(pert_scores), *[np.mean(cov_score) for cov_score in cov_scores]]


def evaluate_r2(autoencoder, dataset, genes_control):
    """
    Measures different quality metrics about an CPA `autoencoder`, when
    tasked to translate some `genes_control` into each of the drug/covariates
    combinations described in `dataset`.

    Considered metrics are R2 score about means and variances for all genes, as
    well as R2 score about means and variances about differentially expressed
    (_de) genes.
    """

    mean_score, var_score, mean_score_de, var_score_de = [], [], [], []
    num, dim = genes_control.size(0), genes_control.size(1)

    total_cells = len(dataset)

    for pert_category in np.unique(dataset.pert_categories):
        # pert_category category contains: 'celltype_perturbation_dose' info
        de_idx = np.where(
            dataset.var_names.isin(np.array(dataset.de_genes[pert_category]))
        )[0]

        idx = np.where(dataset.pert_categories == pert_category)[0]

        if len(idx) > 30:
            emb_drugs = dataset.drugs[idx][0].view(1, -1).repeat(num, 1).clone()
            emb_covars = [
                covar[idx][0].view(1, -1).repeat(num, 1).clone()
                for covar in dataset.covariates
            ]

            genes_predict = (
                autoencoder.predict(genes_control, emb_drugs, emb_covars).detach().cpu()
            )

            mean_predict = genes_predict[:, :dim]
            var_predict = genes_predict[:, dim:]

            if autoencoder.loss_ae == 'nb':
                counts, logits = _convert_mean_disp_to_counts_logits(
                    torch.clamp(
                        torch.Tensor(mean_predict),
                        min=1e-4,
                        max=1e4,
                    ),
                    torch.clamp(
                        torch.Tensor(var_predict),
                        min=1e-4,
                        max=1e4,
                    )
                )
                dist = NegativeBinomial(
                    total_count=counts,
                    logits=logits
                )
                nb_sample = dist.sample().cpu().numpy()
                yp_m = nb_sample.mean(0)
                yp_v = nb_sample.var(0)
            else:
                # predicted means and variances
                yp_m = mean_predict.mean(0)
                yp_v = var_predict.mean(0)
            # estimate metrics only for reasonably-sized drug/cell-type combos

            y_true = dataset.genes[idx, :].numpy()

            # true means and variances
            yt_m = y_true.mean(axis=0)
            yt_v = y_true.var(axis=0)

            mean_score.append(r2_score(yt_m, yp_m))
            var_score.append(r2_score(yt_v, yp_v))

            mean_score_de.append(r2_score(yt_m[de_idx], yp_m[de_idx]))
            var_score_de.append(r2_score(yt_v[de_idx], yp_v[de_idx]))

    return [
        np.mean(s) if len(s) else -1
        for s in [mean_score, mean_score_de, var_score, var_score_de]
    ]


def evaluate(autoencoder, datasets):
    """
    Measure quality metrics using `evaluate()` on the training, test, and
    out-of-distribution (ood) splits.
    """

    autoencoder.eval()
    with torch.no_grad():
        stats_test = evaluate_r2(
            autoencoder, 
            datasets["test"].subset_condition(control=False), 
            datasets["test"].subset_condition(control=True).genes
        )

        disent_scores = evaluate_disentanglement(autoencoder, datasets["test"])
        stats_disent_pert = disent_scores[0]
        stats_disent_cov = disent_scores[1:]

        evaluation_stats = {
            "training": evaluate_r2(
                autoencoder,
                datasets["training"].subset_condition(control=False),
                datasets["training"].subset_condition(control=True).genes,
            ),
            "test": stats_test,
            "ood": evaluate_r2(
                autoencoder, datasets["ood"], datasets["test"].subset_condition(control=True).genes
            ),
            "perturbation disentanglement": stats_disent_pert,
            "optimal for perturbations": 1 / datasets["test"].num_drugs
            if datasets["test"].num_drugs > 0
            else None,
        }
        if len(stats_disent_cov) > 0:
            for i in range(len(stats_disent_cov)):
                evaluation_stats[
                    f"{list(datasets['test'].covariate_names)[i]} disentanglement"
                ] = stats_disent_cov[i]
                evaluation_stats[
                    f"optimal for {list(datasets['test'].covariate_names)[i]}"
                ] = 1 / datasets["test"].num_covariates[i]
    autoencoder.train()
    return evaluation_stats

def prepare_cpa(args, state_dict=None):
    """
    Instantiates autoencoder and dataset to run an experiment.
    """

    device = "cuda" if torch.cuda.is_available() else "cpu"

    datasets = load_dataset_splits(
        args["data"],
        args["perturbation_key"],
        args["dose_key"],
        args["covariate_keys"],
        args["split_key"],
        args["control"],
    )

    autoencoder = CPA(
        datasets["training"].num_genes,
        datasets["training"].num_drugs,
        datasets["training"].num_covariates,
        device=device,
        seed=args["seed"],
        loss_ae=args["loss_ae"],
        doser_type=args["doser_type"],
        patience=args["patience"],
        hparams=args["hparams"],
        decoder_activation=args["decoder_activation"],
    )
    if state_dict is not None:
        autoencoder.load_state_dict(state_dict)

    return autoencoder, datasets


def train_cpa(args, return_model=False):
    """
    Trains a CPA autoencoder
    """

    autoencoder, datasets = prepare_cpa(args)

    datasets.update(
        {
            "loader_tr": torch.utils.data.DataLoader(
                datasets["training"],
                batch_size=autoencoder.hparams["batch_size"],
                shuffle=True,
            )
        }
    )

    pjson({"training_args": args})
    pjson({"autoencoder_params": autoencoder.hparams})
    args["hparams"] = autoencoder.hparams

    start_time = time.time()
    for epoch in range(args["max_epochs"]):
        epoch_training_stats = defaultdict(float)

        for data in datasets["loader_tr"]:
            genes, drugs, covariates = data[0], data[1], data[2:]

            minibatch_training_stats = autoencoder.update(genes, drugs, covariates)

            for key, val in minibatch_training_stats.items():
                epoch_training_stats[key] += val

        for key, val in epoch_training_stats.items():
            epoch_training_stats[key] = val / len(datasets["loader_tr"])
            if not (key in autoencoder.history.keys()):
                autoencoder.history[key] = []
            autoencoder.history[key].append(epoch_training_stats[key])
        autoencoder.history["epoch"].append(epoch)

        ellapsed_minutes = (time.time() - start_time) / 60
        autoencoder.history["elapsed_time_min"] = ellapsed_minutes

        # decay learning rate if necessary
        # also check stopping condition: patience ran out OR
        # time ran out OR max epochs achieved
        stop = ellapsed_minutes > args["max_minutes"] or (
            epoch == args["max_epochs"] - 1
        )

        if (epoch % args["checkpoint_freq"]) == 0 or stop:
            evaluation_stats = evaluate(autoencoder, datasets)
            for key, val in evaluation_stats.items():
                if not (key in autoencoder.history.keys()):
                    autoencoder.history[key] = []
                autoencoder.history[key].append(val)
            autoencoder.history["stats_epoch"].append(epoch)

            pjson(
                {
                    "epoch": epoch,
                    "training_stats": epoch_training_stats,
                    "evaluation_stats": evaluation_stats,
                    "ellapsed_minutes": ellapsed_minutes,
                }
            )

            torch.save(
                (autoencoder.state_dict(), args, autoencoder.history),
                os.path.join(
                    args["save_dir"],
                    "model_seed={}_epoch={}.pt".format(args["seed"], epoch),
                ),
            )

            pjson(
                {
                    "model_saved": "model_seed={}_epoch={}.pt\n".format(
                        args["seed"], epoch
                    )
                }
            )
            stop = stop or autoencoder.early_stopping(np.mean(evaluation_stats["test"]))
            if stop:
                pjson({"early_stop": epoch})
                break

    if return_model:
        return autoencoder, datasets


def parse_arguments():
    """
    Read arguments if this script is called from a terminal.
    """

    parser = argparse.ArgumentParser(description="Drug combinations.")
    # dataset arguments
    parser.add_argument("--data", type=str, required=True)
    parser.add_argument("--perturbation_key", type=str, default="condition")
    parser.add_argument("--control", type=str, default=None)
    parser.add_argument("--dose_key", type=str, default="dose_val")
    parser.add_argument("--covariate_keys", nargs="*", type=str, default="cell_type")
    parser.add_argument("--split_key", type=str, default="split")
    parser.add_argument("--loss_ae", type=str, default="gauss")
    parser.add_argument("--doser_type", type=str, default="sigm")
    parser.add_argument("--decoder_activation", type=str, default="linear")

    # CPA arguments (see set_hparams_() in cpa.model.CPA)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--hparams", type=str, default="")

    # training arguments
    parser.add_argument("--max_epochs", type=int, default=2000)
    parser.add_argument("--max_minutes", type=int, default=300)
    parser.add_argument("--patience", type=int, default=20)
    parser.add_argument("--checkpoint_freq", type=int, default=20)

    # output folder
    parser.add_argument("--save_dir", type=str, required=True)
    # number of trials when executing cpa.sweep
    parser.add_argument("--sweep_seeds", type=int, default=200)
    return dict(vars(parser.parse_args()))


if __name__ == "__main__":
    train_cpa(parse_arguments())


================================================
FILE: datasets/.gitkeep
================================================


================================================
FILE: notebooks/demo.ipynb
================================================
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# A tour of the CPA model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# some standard packages to assist this tutorial\n",
    "import cpa\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import scanpy as sc\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training your model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Init your model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "adata = sc.read('../../cpa-reproducibility/datasets/GSM_new.h5ad')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING. Special characters ('_') were found in: 'dummy_cov'. They will be replaced with '-'. Be careful, it may lead to errors downstream.\n",
      "Creating 'cov_drug_dose_name' field.\n",
      "Ranking genes for DE genes.\n",
      "WARNING: Default of the method has been changed to 't-test' from 't-test_overestim_var'\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Trying to set attribute `.obs` of view, copying.\n",
      "... storing 'cov_drug_dose_name' as categorical\n",
      "Trying to set attribute `.obs` of view, copying.\n",
      "... storing 'dummy_cov' as categorical\n",
      "Trying to set attribute `.obs` of view, copying.\n",
      "... storing 'covars_comb' as categorical\n"
     ]
    }
   ],
   "source": [
    "cpa_api = cpa.api.API(\n",
    "    adata, \n",
    "    perturbation_key='condition',\n",
    "    doser_type='logsigm',\n",
    "    split_key='split',\n",
    "    covariate_keys=[],\n",
    "    only_parameters=False,\n",
    "    hparams={}, \n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can also load a pretrained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded pretrained model from:\t../../cpa-reproducibility/notebooks/sciplex2_other.pt\n"
     ]
    }
   ],
   "source": [
    "cpa_api = cpa.api.API(\n",
    "    adata, \n",
    "    pretrained='../../cpa-reproducibility/notebooks/sciplex2_other.pt'\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can print parameters of the model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "CPA(\n",
       "  (encoder): MLP(\n",
       "    (network): Sequential(\n",
       "      (0): Linear(in_features=4999, out_features=128, bias=True)\n",
       "      (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (2): ReLU()\n",
       "      (3): Linear(in_features=128, out_features=128, bias=True)\n",
       "      (4): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (5): ReLU()\n",
       "      (6): Linear(in_features=128, out_features=128, bias=True)\n",
       "    )\n",
       "  )\n",
       "  (decoder): MLP(\n",
       "    (network): Sequential(\n",
       "      (0): Linear(in_features=128, out_features=128, bias=True)\n",
       "      (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (2): ReLU()\n",
       "      (3): Linear(in_features=128, out_features=128, bias=True)\n",
       "      (4): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (5): ReLU()\n",
       "      (6): Linear(in_features=128, out_features=9998, bias=True)\n",
       "    )\n",
       "  )\n",
       "  (adversary_drugs): MLP(\n",
       "    (network): Sequential(\n",
       "      (0): Linear(in_features=128, out_features=64, bias=True)\n",
       "      (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (2): ReLU()\n",
       "      (3): Linear(in_features=64, out_features=64, bias=True)\n",
       "      (4): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (5): ReLU()\n",
       "      (6): Linear(in_features=64, out_features=5, bias=True)\n",
       "    )\n",
       "  )\n",
       "  (loss_adversary_drugs): BCEWithLogitsLoss()\n",
       "  (dosers): GeneralizedSigmoid()\n",
       "  (covariates_embeddings): Sequential(\n",
       "    (0): Embedding(1, 128)\n",
       "  )\n",
       "  (drug_embeddings): Embedding(5, 128)\n",
       "  (loss_autoencoder): GaussianNLLLoss()\n",
       ")"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cpa_api.model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Start training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Results will be saved to the folder: /tmp/\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Rec: -1.6080, AdvPert: 0.49, AdvCov: 0.00: 100%|▉| 999/1000 [12:45<00:00,  1.30i"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model saved to: /tmp/None\n",
      "{'ellapsed_minutes': 12.652799173196156,\n",
      " 'epoch': 999,\n",
      " 'evaluation_stats': {'cell_type disentanglement': 0.0,\n",
      "                      'ood': [0.9116713597573521,\n",
      "                              0.8682491544314631,\n",
      "                              0.6297205495339959,\n",
      "                              -0.2600031534834341],\n",
      "                      'optimal for cell_type': 1.0,\n",
      "                      'optimal for perturbations': 0.2,\n",
      "                      'perturbation disentanglement': 0.2595809996128082,\n",
      "                      'test': [0.9242513761083746,\n",
      "                               0.7737687104344115,\n",
      "                               0.8351889804394074,\n",
      "                               0.0565223552642065],\n",
      "                      'training': [0.9295517385365168,\n",
      "                                   0.7780530537975899,\n",
      "                                   0.8627033078033616,\n",
      "                                   0.12276226755805468]},\n",
      " 'training_stats': defaultdict(<class 'float'>,\n",
      "                               {'loss_adv_covariates': 0.0,\n",
      "                                'loss_adv_drugs': 0.49258487340476775,\n",
      "                                'loss_reconstruction': -1.6079967220624287,\n",
      "                                'penalty_adv_covariates': 0.0,\n",
      "                                'penalty_adv_drugs': 2.152925470492543e-07})}\n",
      "Stop epoch: 999\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model saved to: /tmp/None\n"
     ]
    }
   ],
   "source": [
    "cpa_api.train(\n",
    "    max_epochs=1000, \n",
    "    run_eval=True, \n",
    "    checkpoint_freq=20,\n",
    "    filename=None, \n",
    "    max_minutes=2*60\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualize training history"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Computation time: 19 min\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAA1gAAADQCAYAAAAalMCAAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAABD/UlEQVR4nO3deZxcVZ338c+3qruzQwgkEEIwgEFFZDMCiggIaMAl4uMCLoDjPJERVMbREYdnHB0fZ3h03BAFERlBRWQUJEpkEUFkRoSwJ4QlxAghgSQsISFLp7t/zx/3VKdSqe6u7qruqq7+vl+velXde8+599zqOn3vuWdTRGBmZmZmZmbVy9U7AWZmZmZmZs3CBSwzMzMzM7MacQHLzMzMzMysRlzAMjMzMzMzqxEXsMzMzMzMzGrEBSwzMzMzM7MaGVEFLEmRXjPqnRarnKSj099tWb3TYrUzHPOjpGUpze+q8X5PT/u9r5b7teHH+aKxFOXNW+udFrNaqPSeStIXU7gfDU3KmsuIKmDZwAzlBV/SrelYpxetXg58G7h0sI9vZmZW5CGy688vKo0wHAvJNnSKHkgUXmsk3SBpVh3TVO43ewfZb//G+qRqeGupdwJGOkmtEbGl3umohcE6l4hYApxd6/2aDSfN9L/CbDhIee5O4M56p8Wa0m+AvwBHAW8BXifplRGxqr7JykTE9cD19U7HcDWia7AkTZZ0iaQnJL0o6Q5Js4u2Hy/pbkkvSVor6R5J707bDpH0xxRvvaSFkv6uj+MVmhrcLulCSeuAc9O2d0q6M+3vr5K+LmlsUdzDJd0oaXU63h2F7ZIOkHR9egqyWtKvJb2iKG7hack5ku5N5zNf0k5p+06S/ivF3yTpL5K+n7ZF0Sn8Je3n6KKq419IukrSRuCDkn6U1n8xxZ9ReDJSlJ49JF2WznOTpMWSXpeaYByVgv1nYT/lqrOrPWdrPHXIjwekYzwvaYuklZIukNRWFObjkp5Mv7HPlsQ/Jv3GHihad1Rat7CPY++e8vNLkv4I7FWyvTvfSDpD0grgxh7ywja1vpJ2lPTz9F08IOnTafsLabsk/Vs6r82Snlb29HTn3tJs9THC8sVYSV+S9LCkjZKWS/rfaVurpM+nbS8pu278vaScpPHp/DokTSnaV2Hdbul7ujd9R1uUXX++VHTsstdnlTQRlDRV0m3Krj1b0nfwE0kT0/ay18y07W8k3Z/S9Zikf5LUkrbNUHZNez6d+yPF6bOm9MOI+CRwbFreCXi9pP0lXSdpVfp9/VLSnoVI2nptOEvSo5LWpd9gW9reZx4u1tNvVmWaCKqXe1X1ci85Eo3YApakHDAP+CiwBrgWeC1wnaQjUrD/BA4EfpleXcD+adv5wBvJqk5/Bjyf4lfiCODNwBXAUklvTcffK72vAT4NfDel9dXArcDxwGLg58AuQJukqcAfgLeSVefeC7wduFXbFya+ADwAbAJOSMcA+AfgPcBj6ZwXA29I275dFP8/0/LyonX/C9gH+DHwdF8nnjLi74FTUzp+TPbd7U7WBOOpFPSmdKw7yuyjFudsDaRO+XEy0J72dSnQCZxJ+o2km6Lvkv02bwQ+BEwvin8r8ATwGkmvSuvel94v7+PYV5Dl5yfInmB+rpewXwF+C/xPH/ssOD+l40XgbuCLJduPBT5Pdr4/BG4DXgNMqHD/NkRGYL74Adn/7CkpvfcA+6ZtXwH+jex3eiXZNfAbwOciYj1wDZAnuyYBvA0YB9wYEU8D08i+wyvJrjsTgC9IOrkkDdtcn8ukcQIwBvh1Su/zwAeB89L2stdMSR8jy287kV3rOtM5nZvC/l+ya9pdZN/Tk8BhvXxX1gRSHj+qaNUasv/JxwO3A38G3g3cIGlUSfQvkV0XWsh+gx9O63vNw2X0dZ9XSGuv96r0fi858kTEiHkBkV4zgEPT53XAuLT9m2ndFWn5GWA92Q/mFWQF0nza9ucU9m/ILmathW29HP/0FOdFYGLR+uvS+huBbwHfS8tdwFjggrR8bVGcfErPP6ZttxRtuzetm5uWl6Xlz6blL6Xl36Tl/5eWv5W+l/HF51L8vRWt+2Ja9zjQUrT+R2n9F9PyjEL8tPyetLwCGFsUrzW935q2n1607ei0bllarvqc/ar/q975McV7E1lh4xvAzYV8mLZdkpZ/mJYnkV20AnhXWveVwu89pedpsovZ7r0cc4+ic5+e1n09Ld+XlmcUhXlzT3khrevOM2T/Fzan5aPS9r9Pyy+k5RPS8u/S/qYAAnL1/k34NaLzxS5F535w0frW9PtcX/K7npOWV6Tl49PyrWn5v9LyyWk5B5wI/J/0Hd6Vtl+ctp9O+evz6cX7TesOJrsO/QdZwS6AR8v9DYvWLUrrfkF2rf1JWn46bf95Wv582v/oSv5Wfg2/F1vvT0pf88getgVZ379vpdeqtG52ye/rvWn5srR8QdExesvDR7P9daS3+7wfpeW+7lV7vZccaa+R3AdrRnp/MiJeSp8fTu8vS+8fA75G9o8a4FngLLInYJ8m+3FdwtZ//l8g+8fdl0UR8UKZtByfXgUC9mZr86Hu2pyI6ISsWUFatbgo3sPAQUXnUXBvei8ce3x6/xbZE9CPA58iuxD+XNKHI6Krj3O5MyI6etmeL1kunMuDEbGhsDL617dkRnqv5pytscxI70OWHyV9nuyJeKnJ6X1aen8EICKek/QssFtR2MuAfwLeT1aruivZRWxFz6favd+NEfFk+vxoL+H/u5dtsG0e2wUoNAMp5I+HSsLfSPZdfRi4Ja27i+yGdWUfx7KhNSO9j4R8Ubg2tEdE4f82EbElNfsbl1YVfteF72Fqavp0M9lT9yMlvZysMLUW+FUKdyEwt5fzKii9Pm9D0ilktVt97afUjPT+v0rW7yppPNnN7B7Al8m+/83Ad4DPYs3qN8ASsjx7N1l/pwvStlelV7GXlyyXvb+pIA8P1Iz03tO96rcY+L1k0xmxTQTJniAATNfWvk6FPjx/Te+/jYiZZDct7wF2JnsyB7AgIg4kq+4/muwp23mF9tR92NxDWj4ZESq8gH0iYiFZEyIoai6Q2p2rKO4ri/ZXeh4FhYJQlKx/LiJmkzV9OJDsSdsHyJpKQPZ0Asr/XkrPpXATsEN6379ke+FcXiNpTGFl0ffW2cuxCpal92rO2RrLsvQ+lPnx/en9C2RNLArN9JTeC81VXwEgaVI6ZreIeJSsluCVZLWk0HczqMJ+x0gqNK3at6fAEVGcxwr5a0JKU2tJ3DVktQkAM9N7cT6BrEB2FjCR7IJ9OfA64G/7SLcNvWXpfSTki8K1oU3SQYWVKa2r2frbL/yeC9/DyohoTzdwPyW7dlxK9kT9vyJiU8l5nU6WBy4sOa+C0mtaqcJ+LgFGFS0X76fcNXNZen9nyXV+78iaOC6NiCOAHcme/j8HfKbof4Q1nx9GxN9HxP+NiN9GVk20LG27uuR3MpWsiWmxnu5v+srD5fR2n1dQSFtP96p93UuOKCO5BmsB2QXgMOCPkhYBp5D9UL+XwtyrrDP5E2xtY/5Cev+1pDxZE7kdyf7RPsvWAkJ/XED2tO2rkt4AbAQOILto7QVcRHbzMyd1tH0UOBJ4PVkzg38CjpE0j+zp9cFkzUYqHVb2HEnvBB4kuzmbkdavTe9Pkj0tvUDSo2xtM15O4YnKaZI6yNoFF5tP1j53Jtn3+weyC+Y3yNr0Fp7of0rSAWTteEvV4pytsdQjPz6T3j9E9vTtXSXbryDr+3K6pNFkDwvK/c+8PKX7SLKmXNf0dqIRsVzSbWRNOG6UdBdbL4h9eRTYAEySdDlZrcGUon13SvoZcBrwM0m/Y/sn5m8ga8r7J7KbuMLF74UK02BDZyTlizWSriC7IbtZ0q/ICoaPRcTnJF0IfAa4QtL1wDtT1AuKdnMZ2c3kkUVpKD6vHYFPko3YdlJv6elF4fs5gayQdmKZMOWumReQ/c1+IukashvZWWTNv44GvqdsoKaHyb7PXcj+TusHmE4bnn5Kdn/zbkk3kBVq9iHrpzWTrYWc3vSVh8up5D6vr3vVvu4lR5Z6t1EcyhclbUzJbkwuJfthrSMbivXtReHPJ7swbST7gdxCahtOlgEWkz1VW5/iHtvH8U+npC130bZ3kTUBXEt2cbwTOLto++FkAz+sSce8g9SHiaxwcQPZhXMNWbXzq4riLkvHPTotn12cDuAdZBfxF8gGg3gU+ERR/Pen76grxduFkra5RWFHkf2DeJGsadJnCt97UZjpZBe+J9LxFgOvS9teA9wPbEnx3kP59sJVnbNf9X81QH7cj+wGdhNZp+J/pqgfVArzCbJmR8+RPQ0s/K7eVRRmElv7Pf1nhee+B1l+3kCWl/+9+NiU9F0sifthslqEVWQ3eH+iqN8i2U3kVek7fCB9NwE8k7bPTMdeRXYRXJH2M6revwm/Rny+GAv8K1nzw03pGP87bWsj6z9VeMjwMNn1JV+yj0LfqqWAitYfmb6LjWR9Xb6Zwv0qbT+dMteI0vVkg3v8Pu3nvvRdBKmPYwpT7popsoLpfelvURjM4EMpTmHbunR+DwLvr/fv0a/av8rll5LtB5ANovJMyruLyZrfjU/bS/9HfItt+0r1mocpf09V0X0evdyr0se95Eh7KX0pZmbWBCRNANZH+ude1B7/9og4stfIZmZmVrWR3ERwUEg6lKyJQ6k7I6Jcx1gzGyT1zI+SzmL7TsmQjfS0ZBAPfSzwfyT9lqzpxkfS+vMH8Zg2jIzQfGFmNmRcwKq9/chGTyl1GeVHHjKzwVPP/Pgetp3fpOBXZCNHDZYnyDrx/wNZE8D7ga9HxH/1GstGkpGYL8zMhoybCJqZmZmZmdXISB6m3ayhSZot6RFJSySdU2b7jpJ+Lel+SYskfaTSuGZmZmY2OIZlAWv27NmFEVT88muoX0MiDa38XbKhgPcDTpG0X0mwM4GHIpvn5mjg65LaKoy7Decpv+r4alrOV37V6dW0nKf8qtOr34ZlAWvNmjX1ToLZYDsUWBIRSyOiHbgSmFMSJoAJacLp8WRDJndUGHcbzlNmted8ZVZbzlM2XNSkgFVBUyZJOj9tf0DSIZXGNRuhprF10mXI5oOZVhLmAuBVZPMYPQh8KiK6KoxrZmZmZoOg6gJWhc2RTiCb3HImMJdsUstK45qNRCqzrrSa+q1kE1PuDhxENgP7DhXGRdJcSQskLVi9enV1qTUzMzMzoDY1WJU0R5oDXB6ZO4CJkqZWGNdsJFoOTC9a3oOspqrYR4CrU75aAvwFeGWFcYmIiyNiVkTMmjx5ck0Tb2ZmZjZS1aKAVUlzpJ7CVNyUyU/bbYS5C5gpaS9JbcDJwLySME+QTSqLpF2BVwBLK4xrZmZmZoOgFgWsSpoj9RSmoqZMUNnT9q6uYEtnV29pNRsWIqIDOAu4AVgMXBURiySdIemMFOzLwBskPQjcDHwuItb0FHfoz8LMzGzwdXYFHb7/swbSUoN9VNIcqacwbRXErdi5v1rI7xY/w13nHjfQXZg1jIiYD8wvWXdR0ecVwFsqjTsQjz2zjuO/eRvXfPwNHLznTtXuzszMrOY+9uMFrFy7ies+eWS9k2IG1KYGq5LmSPOAU9NogocDayNiZYVxK5YTRAxouHozK+OWR1YBcN0DK+ucEjMzM7PhoeoarIjokFRojpQHLi00ZUrbLyJ7kn4isATYQNY5v8e4A01LTqLL5Suzmik8r1C5xrxmZmZmtp1aNBGspClTAGdWGnegcsra4ZpZbcklLDMza2BuwGSNpCYTDTeKXE50OYeZ1Yxzk5mZNT4/BLTG0lwFLMlPMMxqqLuJYH2TYWZmZjZsNFkBy00EzWopcAnLzMwan+/+rJE0VwHLTQTNBoVcwjIzswblbsLWaJqrgOUmgmY15fxkZmZm1j9NVsCCTt8RmtWcnw6amVkj8zyo1kiaqoCVl5sImtXSlAmjAJg4prXOKTEzMyvPzwCt0TRVAUupiaCfYpjVxgF7TARg+qSx9U2ImZmZ2TDRVAWsXGrH5IEEzWqj0DTQNcPWDCTNlvSIpCWSzimzXZLOT9sfkHRIyfa8pHsl/WboUm1mZsNNUxWw8ulsfDNoVhuFZhfOUjbcScoD3wVOAPYDTpG0X0mwE4CZ6TUXuLBk+6eAxYOcVDMzG+aaqoCl7hos3w3a8FfB0/bPSrovvRZK6pQ0KW1bJunBtG3BwNOQvTtHWRM4FFgSEUsjoh24EphTEmYOcHlk7gAmSpoKIGkP4G3AJUOZaDPrmwdiskbTVAWs7iaCXXVOiFmVKnnaHhFfi4iDIuIg4PPAHyLiuaIgx6Tts6pIR+FYA92FWaOYBjxZtLw8ras0zLeAfwR6vcJImitpgaQFq1evrirBZmY2PDVVActNBK2JVPK0vdgpwM9qnQg3EbQmUu4Zd+kvu2wYSW8HVkXE3X0dJCIujohZETFr8uTJA0mnmQ2Ar1PWSJqqgJVzE0FrHpU8bQdA0lhgNvDLotUB3Cjpbklze4jX55P27hosNxK04W85ML1oeQ9gRYVhjgDeKWkZ2cOON0v6yeAl1cz6Qx6o3RpMUxWw5CaC1jwqedpe8A7gv0uaBx4REYeQNTE8U9KbtttZBU/aXYNlTeQuYKakvSS1AScD80rCzANOTaMJHg6sjYiVEfH5iNgjImakeL+PiA8NaerNzGzYqKqAJWmSpJskPZbedyoTZrqkWyQtlrRI0qeKtn1R0lNFHfVPrCY9OQ8pbc2jkqftBSdT0jwwIlak91XANWRNDvst190HayCxzRpHRHQAZwE3kI0EeFVELJJ0hqQzUrD5wFJgCfAD4ON1SayZ9ZtbWlgjaaky/jnAzRFxXhrl7BzgcyVhOoB/iIh7JE0A7pZ0U0Q8lLZ/MyL+o8p0AJDPuYmgNY3up+3AU2SFqA+UBpK0I3AU8KGideOAXESsS5/fAvzrQBLhebCsmUTEfLJCVPG6i4o+B3BmH/u4Fbh1EJJnZgPkUQSt0VTbRHAOcFn6fBnwrtIAqXnFPenzOrInh2X7klRLnmjYmkSFT9sBTgJujIiXitbtCtwu6X7gTuC6iLi+qvRUE9nMzMxsBKm2BmvXiFgJWUFK0pTeAkuaARwM/Llo9VmSTgUWkNV0Pd9D3LlkEz+y5557lt2/mwhaM+nraXta/hHwo5J1S4EDa5GG7qeCzlJmZmZmFemzBkvS79IkpqWv3oaMLref8WSjnJ0dES+m1RcC+wAHASuBr/cUv5IO+XmPImhWUzmPImhmZsOAb/2skfRZwIqI4yJi/zKva4Fnima5nwqsKrcPSa1khaufRsTVRft+JiI6I6KLrEPxgDrid5+Mmwia1dTWPlj1TYeZmdWPpNmSHpG0JPW5L90uSeen7Q9IOqSSuJI+kbYtkvTVgadvoDHNBke1fbDmAaelz6cB15YGUNYx6ofA4oj4Rsm2qUWLJwELq0lM982g7wbNaqIwt4ifDJqZjUyS8sB3yab92A84RdJ+JcFOAGam11yyFkq9xpV0DFlf/gMi4tVATQY8M2sE1RawzgOOl/QYcHxaRtLukgp9R44APkw2MWPpcOxflfSgpAeAY4C/ryYxHkXQrLYKDy3cRNDMbMQ6FFgSEUsjop1ssu3SbiJzgMsjcwcwMT1E7y3u3wHnRcRm6J5WZMB8lbJGUtUgFxHxLHBsmfUrgBPT59spP2kqEfHhao5fyk0EzWqru4DlPGVmNlJNA54sWl4OHFZBmGl9xN0XOFLSV4BNwGci4q7Sg1cyyJnK32aa1U21NVgNpXAz2OkSlllNbG0i6DxlZjZClSu9lF4UegrTW9wWYCfgcOCzwFXS9r2pKhnkzKzRVDtMe0MpNBH0zaBZbWxtImhmZiPUcmB60fIewIoKw7T1Enc5cHWa4PtOSV3ALsDqgSTS937WSJqqBstNBM1qq3saLOcpM7OR6i5gpqS9JLUBJ5MNclZsHnBqGk3wcGBtmie1t7i/At4MIGlfssLYmgGl0C0ErcE0VQ1Wzk0EzWqqex4sl7DMzEakiOiQdBZwA5AHLo2IRZLOSNsvAuaT9b1fAmwAPtJb3LTrS4FLJS0E2oHTwhcbaxJNVsDyKIJmteR5sMzMLCLmkxWiitddVPQ5gDMrjZvWtwMfqm1KzRpDUzYRdPnKrDa6B7moczrMzMx64+uUNZLmKmCls+l0CcusNrqHaXeeMjOzxuQuWNZomquA5SaCZjWV81XLzMzMrF+asoDlp+3WDCTNlvSIpCWSzimz/bOS7kuvhZI6JU2qJG4/0gD4oYWZmTU4X6asgTRlAcsd8m24k5QHvgucAOwHnCJpv+IwEfG1iDgoIg4CPg/8ISKeqyRuxenoPtbAzsPMzGywlZmf2KyumqyAlb17mHZrAocCSyJiaRpp6UpgTi/hTwF+NsC4PfJEw2ZmZmb901wFrJybM1nTmAY8WbS8PK3bjqSxwGzgl/2JK2mupAWSFqxevbpsItyv0czMhgNfpayRNFcBy8O0W/Mo196hp1/2O4D/jojn+hM3Ii6OiFkRMWvy5Mm9JsZ5yszMGpUbCFqjabICVvbuJoLWBJYD04uW9wBW9BD2ZLY2D+xv3F65WbuZmZlZ/zRXActNBK153AXMlLSXpDayQtS80kCSdgSOAq7tb9xKeGROayYVjMwpSeen7Q9IOiStny7pFkmLJS2S9KmhT72ZmQ0XVRWwJE2SdJOkx9L7Tj2EWybpwTSc9IL+xq+Umwhas4iIDuAs4AZgMXBVRCySdIakM4qCngTcGBEv9RV3IOkoVGC5UtiGuwpH1zwBmJlec4EL0/oO4B8i4lXA4cCZAx2Z08wGhx8EWiOptgbrHODmiJgJ3JyWe3JMGlJ61gDj98lNBK2ZRMT8iNg3IvaJiK+kdRdFxEVFYX4UESdXEncg5IcW1jwqGV1zDnB5ZO4AJkqaGhErI+IegIhYR/bgouygM2Y29Nyc3RpNtQWsOcBl6fNlwLuGOP42POKZWW11z4Pl8Zls+KtkdM0+w0iaARwM/LncQSoZndPMzJpbtQWsXSNiJUB6n9JDuABulHS3pLkDiN/PIaX7fR5mVkb3PFjOUzb8VTK6Zq9hJI0nmw7h7Ih4sdxB+jM6p5nVji9T1kha+gog6XfAbmU2nduP4xwRESskTQFukvRwRNzWj/hExMXAxQCzZs0qm49yqbjoGiyz2pAHubDmUcnomj2GkdRKVrj6aURcPYjpNLN+cgtBazR9FrAi4rietkl6ptA+XdJUYFUP+1iR3ldJuoasLfxtQEXxK5V3E0GzmpP8ZNCaQvfomsBTZKNrfqAkzDzgLElXAocBa9P1ScAPgcUR8Y2hTLSZmQ0/1TYRnAeclj6fxrZDRQMgaZykCYXPwFuAhZXG7w+5iaBZzQk3EbThr8KROecDS4ElwA+Aj6f1RwAfBt6cRsO9T9KJQ3sGZtYbX6eskfRZg9WH84CrJH0UeAJ4L4Ck3YFLIuJEYFfgmlT4aQGuiIjre4s/UIVRBLtcwjKrmZzkQS6sKUTEfLJCVPG64lE5AzizTLzbcSsks4YlDyNoDaaqAlZEPAscW2b9CuDE9HkpcGB/4g9U3hMNm9Wc5FphMzMzs0pV20SwoXgUQbPa29IZrHpxc72TYWZmZjYsNFUBS24iaDYofnnP8nonwczMrEduym6NpKkKWJ5o2MzMzGxkcQ8sazRNVcDa2gerzgkxMzMzM7MRqakKWN1NBF2DZWZmZjZi+NbPGklTFbDcRNDMzMxshHEbQWswTVXAyhcKWG4jaE1A0mxJj0haIumcHsIcnSY9XSTpD0Xrl0l6MG1bMHSpNjMzMxvZqp1ouKF4mHZrFpLywHeB44HlwF2S5kXEQ0VhJgLfA2ZHxBOSppTs5piIWDNUaTYzM6sXN16yRtJUNVhKZ+MmgtYEDgWWRMTSiGgHrgTmlIT5AHB1RDwBEBGrBisx++46frB2bWZmVhW5jaA1mKYqYOXdB8uaxzTgyaLl5WldsX2BnSTdKuluSacWbQvgxrR+brkDSJoraYGkBatXr+4xIQdNn8iuO4we2FmYmZmZjTBuImjWmMo9jiv9ZbcArwWOBcYAf5J0R0Q8ChwREStSs8GbJD0cEbdts7OIi4GLAWbNmtVjrsnn5KYXZmZmZhVqqhqswjDtnS5h2fC3HJhetLwHsKJMmOsj4qXU1+o24ECAiFiR3lcB15A1ORyQnJynzMzMzCrVVAWs7omGfTNow99dwExJe0lqA04G5pWEuRY4UlKLpLHAYcBiSeMkTQCQNA54C7BwoAnJSW52a2Y2gvU1qq0y56ftD0g6pB9xPyMpJO0y8PQNNKbZ4GiqJoJ5NxG0JhERHZLOAm4A8sClEbFI0hlp+0URsVjS9cADQBdwSUQslLQ3cI2y/NACXBER1w80LTmJjq6uak/JzMyGoUpGtQVOAGam12HAhcBhfcWVND1te2KozsdsKDRVASuXarA6fTNoTSAi5gPzS9ZdVLL8NeBrJeuWkpoK1kI+JzZ3+KmFmdkI1T2qLYCkwqi2xQWsOcDlERHAHZImSpoKzOgj7jeBfyRrkVGVcEsLayBN1UQQoCUnOp3JzGpm7cYtLF65rt7JMDOz+qhkVNuewvQYV9I7gaci4v7eDl7JiLduIWiNpqoClqRJkm6S9Fh636lMmFdIuq/o9aKks9O2L0p6qmjbidWkB7JarA63ETSrmQefWsvGLZ31ToaZmdVHJaPa9hSm7PrUb/hc4At9HTwiLo6IWRExa/LkyX0m1qwRVFuDdQ5wc0TMBG5Oy9uIiEci4qCIOIhsSOkNZKOaFXyzsD01iapKS04e5MLMzMysNiod1bZcmJ7W7wPsBdwvaVlaf4+k3QaaSN/5WSOptoA1B7gsfb4MeFcf4Y8FHo+Iv1Z53B7l5RosMzMzsxqpZFTbecCpaTTBw4G1EbGyp7gR8WBETImIGRExg6wgdkhEPD2QBHoUQWs01Rawdk0ZiPQ+pY/wJwM/K1l3VhrS89JyTQwLKmmDC5DPuwbLzMy2N5hDTZs1q4joAAqj2i4GriqMalsY2ZZsQKalwBLgB8DHe4s7xKdgNuT6HEVQ0u+AclW25/bnQOnJxTuBzxetvhD4MlnN7peBrwN/Uy5+RFwMXAwwa9asHktQrsEyGxwRgfyY0IapwRxq2qzZ9TWqbRo98MxK45YJM6P6VJo1jj4LWBFxXE/bJD0jaWpErEzDca7qZVcnAPdExDNF++7+LOkHwG8qS3bP8jlPimo2GCLcDMOGtcEcarpfPnXlvfgyZdWYNWMnTn39jHono6E4T1kjqXYerHnAacB56b23eQxOoaR5YKFwlhZPAhZWmR7yOdHR6VxmVmvOVTbMlRsu+rAKwvQ01HRpXCBrzg7MBdhzzz3LJuTB5Wudn6wqUyaMqncSGoo8ULs1mGoLWOcBV0n6KNks3O8FkLQ7cElEnJiWx5I1rfhYSfyvSjqI7N5tWZnt/Zb3PFhmNfWOA3fn1/evoCuCvC9iNnzVfKjpcgeppDn77z9zdI+JNDOz4a+qAlZEPEs2MmDp+hXAiUXLG4Cdy4T7cDXHLyefE53ug2VWM6/cbQK/vt/NL2zYq2ao6bYK4ppZHYXrha2BVDuKYMNxActscLhvow1zNR9qeigTb2Y9c/9gazTVNhFsOHm5gGVWSzlfuawJRESHpMJw0Xng0sJQ02n7RWQjnZ1INtT0BuAjvcWtw2mYmdkw0HwFLNdgWZOQNBv4NtkN3SURcV6ZMEcD3wJagTURcVSlcStPR/buGiwb7gZ7qGkzqx9foqyRuIBl1oAqmXdH0kTge8DsiHhC0pRK4/ZHrruANeDTMTMzMxsxmq4PVotHEbTm0D1nT0S0A4V5d4p9ALg6Ip4AiIhV/YhbsUITwXC+MjOzBuSW7NZomq6AlXMNljWHnubjKbYvsJOkWyXdLenUfsTtt/kPruw7kJmZmdkI13QFrBYXsKw5VDLvTgvwWuBtwFuBf5a0b4VxkTRX0gJJC1avXt1jQgo1WJ/75YOVpdzMzGyI+c7PGknTFbByEh0uYNnwV+mcPddHxEsRsQa4DTiwwrhExMURMSsiZk2ePLnHhLjphZmZNTZfqKyxNF0BqyUvulzAsuGvknl3rgWOlNQiaSxwGLC4wrgVc42wmZmZWeWabhRB12BZM6hkzp6IWCzpeuABoItsOPaFALWcs+cXdy+v8mzMzMwGl8dhskbSdAWslpw8X481hb7m7EnLXwO+VkncgWrv6KrFbszMzAaFm7Jbo2m6JoL5nOjodAHLrFacm8zMzMwq15QFLNdgmZmZmY0kvvezxtGUBSz3wTKrHU8wbGZmZla5Jixg5TyKoJmZmdkI4S5Y1miqKmBJeq+kRZK6JM3qJdxsSY9IWiLpnKL1kyTdJOmx9L5TNekByAvXYJnVkHOTmZmZWeWqrcFaCLybbILTsiTlge8CJwD7AadI2i9tPge4OSJmAjen5arkcznP22NWQ24haGZmjc7XKmskVRWwImJxRDzSR7BDgSURsTQi2oErgTlp2xzgsvT5MuBd1aQHIJ/zxKhmZmZmI4WHabdGMxR9sKYBTxYtL0/rAHaNiJUA6X1KTzuRNFfSAkkLVq9e3ePB8rkcnX6MYVYz4UaCZmZmZhXrs4Al6XeSFpZ5zekrbmEXZdb1+44tIi6OiFkRMWvy5Mk9hnMNllltnXviq+qdBDMzs175zs8aSUtfASLiuCqPsRyYXrS8B7AifX5G0tSIWClpKrCqymPR4j5YZjX1it12qHcSzMzMeiSPI2gNZiiaCN4FzJS0l6Q24GRgXto2DzgtfT4NuLbag+UkF7DMaijn65aZmZlZxaodpv0kScuB1wPXSbohrd9d0nyAiOgAzgJuABYDV0XEorSL84DjJT0GHJ+Wq9KSdwHLrJZy7j1sZmYNLtz/3hpIn00EexMR1wDXlFm/AjixaHk+ML9MuGeBY6tJQynXYJmZWTFJk4CfAzOAZcD7IuL5MuFmA98G8sAlEXFeWv814B1AO/A48JGIeGEo0m5mZsPPUDQRHFItOXkUQWsKPU3QXbT9aElrJd2XXl8o2rZM0oNp/YJq0pFzG0Eb/vqcc7GPORtvAvaPiAOAR4HPD0mqzawibmhhjaaqGqxGlMtlNVgRgZzjbJgqutk7nmygmLskzYuIh0qC/jEi3t7Dbo6JiDVVp6XaHZjV3xzg6PT5MuBW4HMlYbrnbASQVJiz8aGIuLEo3B3AewYzsWZmNrw1ZQ0WgFsJ2jDX2wTdQ8p9sKwJVDLnYm9zNhb7G+C3PR2o0jkbzay2fNtnjaTpClj5VMDq6Oqqc0rMqlLpzd7rJd0v6beSXl20PoAbJd0taW65A1R6I1jcQvCRp9f14xTMhs5QzNko6VygA/hpTzupdM5Gs+GkgibrknR+2v6ApEP6iivpa5IeTuGvkTRxwOkbaESzQdK0BSyXr2yYq2SC7nuAl0XEgcB3gF8VbTsiIg4h609ypqQ3bbezSm8Ei1LyxHMbKky+2dCKiOMiYv8yr2tJcy4C9DLnYm9zNiLpNODtwAfDw5XZCNJH/8SCE4CZ6TUXuLCCuO7baE2r+QpYcg2WNYVeb/YAIuLFiFifPs8HWiXtkpZXpPdVZCN9HjrQhLiJoDWBSuZc7HHOxjS64OeAd0aEnzLYSFNJk/U5wOWRuQOYmB5m9Bg3Im5MU/lA1rdxj2oS6cce1kiar4DlGixrDr1N0A2ApN2URnKRdChZfn5W0jhJE9L6ccBbgIUDTUhx8coP7m2YKjvnYj/mbLwAmADclEbmvGioT8Csjippst5TmKr7NlbSnN2DmlmjabpRBN0Hy5pBRHRIKtzs5YFLI2KRpDPS9ovIRjL7O0kdwEbg5IgISbsC16QLTgtwRURcP9C0FNdguXhlw1FPcy72Y87Glw9qAs0aWyVN1nsKU3Xfxoi4GLgYYNasWb4M2bDQdAWswpw9ngvLhrtyN3upYFX4fAHZk/XSeEuBA2uVDjcRNDMb0fpsst5LmLbe4hb1bTy22r6NbmFhjaTpmggWhmnv9DjtZrXh8pWZ2UjWZ5P1tHxqGk3wcGBtmhLBfRttRGq6AlZhkAsXsMxqo3iY9ivvfKJ+CTEzsyHXU/9ESWcUmq2TtbZYCiwBfgB8vLe4KY77NlrTaromgnnXYJnVVHHn4VseWc2SVet5+ZTxdUyRmZkNpQqarAdwZqVx03r3bbSm1Xw1WC5gmdVUrqSJ4JZODyBjZmaNxXd91khcwDKzXpUOcnHCt//IfU++UJ/EmJmZlfBYTNZomreA5dFkzAbN5X9aVu8kmJmZmTWkqgpYkt4raZGkLkmzeggzXdItkhansJ8q2vZFSU+lzo33STqx3D76o3serE4XsMxqwcO0m5lZw/NtnzWQage5WAi8G/h+L2E6gH+IiHskTQDulnRTRDyUtn8zIv6jynR0K4wi2OUaLLOaKO2DZWZm1kjk+USswVRVwIqIxbDtKGNlwqwEVqbP6yQtBqYBD/UYqQr5fKrBch8ss5roLX+bmZnV25tW/JC9eQp4a72TYgYMcR8sSTOAg4E/F60+S9IDki6VtFO1xyg0Z3px45Zqd2VmlK/B6qryAcbKtRu5+6/PV7UPMzMzgKkbHua1LK53Msy69VnAkvQ7SQvLvOb050CSxgO/BM6OiBfT6guBfYCDyGq5vt5L/LmSFkhasHr16h6P8+elzwLwhWsX9RjGzCpXrgbrV/et4O3f+SMxwKa4R33tVv7Xhf9TbdLMzMwIcsidsKyB9NlEMCKOq/YgklrJClc/jYiri/b9TFGYHwC/6SUdFwMXA8yaNavHXFQYnv3pFzdVm2wz68XCp15kQ3sn40a18NdnX2LKhNGMactXFLe9w3NpmZlZjQhyLmBZAxn0JoLKHn//EFgcEd8o2Ta1aPEkskEzqrLDmFYAWtwz32zQFaZDOOprtzL3xwsAWPjUWs695sGqmxGamZlVIsjhYQStkVQ7TPtJkpYDrweuk3RDWr+7pPkp2BHAh4E3lxmO/auSHpT0AHAM8PfVpAfgI0fMAOCjb9yr2l2Z1ZWk2ZIekbRE0jllth8taW1RvvpCpXFr5e5lz7NoxVoA/vjYGp5dv5m3f+d2fvrnJ3h89Xr+sual7rAvbGjnEz+7lyef2zBYyTEzs5FIcg2WNZRqRxG8BrimzPoVwInp8+1QfvzMiPhwNccvZ0xrnnxOeJR2G84k5YHvAscDy4G7JM0rmt6g4I8R8fYBxq3Ybz7xRt7+ndu3W/+RH921zfK///bh7s8f/uGdPP3iJpb+24nkcuL7ty3l1/ev4Nf3rxhoMrq9sKGdP//lOd766t2q3peZmQ1zcg2WNZZq58FqOJIY05pn45bOeifFrBqHAksiYimApCuBOVQ2vUE1ccuaNnFMReF+cffy7s+FfpAbtnTy3Pr27v6RxW566Bkmjm3ldTMm9Ss9Z11xL7cvWQPAhR88hBNeM3W7MOff/Bg3LHqamVPG8833H+Th5s3MmpY8yIU1lKYrYAGMdgHLhr9pwJNFy8uBw8qEe72k+4EVwGciYlGlcSXNBeYC7Lnnnr0mJldFn8ZP//w+bnzombLb/vflWb+tdx8yjX13ncCadZu55Pa/APDxo/fh/a+bzst2HtcdfsmqdSx86kX++tzWpod/99N7+P0/HMVeu4zrLkQ9sPwFvnHTowAsWvEi++2+A287YPdtCopPPLuBPzy6ipMO2YPxo1p4aXMHY9vyLoiZmQ0zmzuDtgjaO7poaxnSGYjMymrKAtb4UXnWbvA8WDaslbvLL308dw/wsohYn/o1/gqYWWHcikfmhOoGjempcFXs6nue2m7d9259nO/d+jg7jG7hxU0dvcZ/89f/AMC+u47n0WfWb7f93+Y/zL/Nf5jJE0axet1mxrbl2dCePYT552sXMaolx+Y0suG4tjxnH7cvmzuy7TuObeOQPSfy8injac3lWN/ewbI1L7F+cwcTx7Sx8/g28jmxpbOLjs7sa5w4tpW1G7ewob2TnLL5+YLsvVCTlxMUKvUKX2+Q/fGUwkUEWzoDpX3kc9qmJlDaGj6fEy2peXQQCNHe2QUELbkc+ZyQoCPtr1KdXUFrPkd7Z1f3PIOw9UdWWJWT6IrsuKX7nzSujXGjmvJyY2YN4K/PbeQAdfH7h59h9v7bt2gwG2pNecXbY6exrFi7sd7JMKvGcmB60fIeZLVU3YrmkyMi5kv6nqRdKonbX/k6jsrZV+GqWLnCVbHV6zYDdBeuCjYXDRv/UnsnX5nvCStr6fxTDuadB+5e72SYWZPqQt0Pm8waQVMWsKZMGMXV965h4VNredXUHcjJmc6GnbuAmZL2Ap4CTgY+UBxA0m7AMxERkg4lGxX0WeCFvuL2Vz0LWPU0YVQLF3zwENZu3MJfVmfNEttacuy1yzh2GNPC8y9t4YWN7Wzp6KIln8uapgQ8v6GdCaNb2WFMC11B95D1nV1BS16phqqru7anuFlLoQYqJ5GTaM1ntVJdAR1dXbTkckhbw2Xv0NHZRUdXkFe2z66A1nSsLR1dbO7oymrTcuoO05cIaMmL9o4uJBXVxkX39sJ7Z0qIyvy/PXj6xJr8PczMyhk/qpVcexcTXFNuDaIpf4kTRmenVRj17LTXv4wvzdm/nkky65eI6JB0FnADkAcujYhFks5I2y8C3gP8naQOYCNwckQEUDZuNenJ1+kBxTsP3J33v246M6eM56X2To75j1t7DHvtmUfw24VPs8v4Nv7vddvXQO21yzhu/vRRdHQFG7d0ct0DK/mnax4E4L4vHM+CZc/T1pLjgD12pK0lx6YtXUwa1zZYp2ZDSNIk4OfADGAZ8L6IeL5MuNnAt8nyzSURcV7J9s8AXwMmR8SaQU62mVXo1dMm0vkXKp7s3mywNWUBK5/btoPjZX/6K19856tdi2XDSkTMB+aXrLuo6PMFwAWVxq1GNYNcDNQrd5vAv7/7Ndv03Vl23tv48Z+W8c/XbltenDxhFAdOn8iBqabk8L133mZY+RP2340LP/RaANpyoq0lxymHTmf1us286+DdmTi2jeP223WbfY512aqZnAPcHBHnpXnhzgE+Vxygr+kNJE1P254Y0pSbWZ9yuRxdCro8R481iKYcamVsmScYa9a31yElZtYf3znlYJZ85QSWnfc2rj/7TWUHRvjw62fwsp3HbrPuhrPftM3y/tN25HsfPKR7+avvOWC7/UjiU8fN3GaUQmtac4DL0ufLgHeVCdM9vUFEtAOF6Q0Kvgn8I55sx6zhSDlEcO8TL9Q7KWZAkxawPnj4ntvN2/Mv8xbyszuf4E1fvYVf37+CC299HIDPX/0gZ15xTz2SaWbJ2cfNBGDvyeNoyff9b+n6T72JthTuq+85oGxTvhNfM5VTX/8yzjhqHyaMbq1tgm242TUiVgKk9yllwpSb3mAagKR3Ak9FxP19HUjSXEkLJC1YvXp19Sk3sz5t7uwiR5RtHm5WD03ZRHDqjmP473PezIxzruteN//Bp5n/4NMAfOJn9wIwbacx/OzOrLXHN97XyagWt901q4e/PXJv3jtresUTGo9py3PLZ4/mnr8+zzt6GZ3uX933csSQ9DtgtzKbzq10F2XWhaSxaR9vqWQn/Zn+wMxqQ7m8Jxq2htKUBayC+75wPJ+88j5ue7T8U8RPpoIWwCv+z/V855SDueXhVXzsqH3oiuBVU3cYqqSajWht+VzFhauCaRPH9DuONa+IOK6nbZKekTQ1IlZKmgqsKhOsp+kN9gH2Au5P/Xj3AO6RdGhEPF2zEzCzKoicC1jWQJq6gDVxbBuX/82hdHUFr/vK73j2pd77YRVqtq6+d+ukp996/0F8/aZH+Pvj9uXdh+wxqOk1G47OOublXHDLkorD7zd1Bx5amU3hNX5UC+s3d9Ca9wA0NqjmAacB56X3a8uEKTs1QhqBs7tJoaRlwCyPImjWOLIpzrv6Dmg2RJqyD1apXE7c/c/Hs+y8t/FfZ7wegI+9ae/u7aNbe/4azv75fTz53EY+fdX9zDjnOr7/h8e757QxG+k+dPiefOLYl/caZt5ZRzB90hjeeeDunH/KwVz3yTd2b/v08fuy7Ly3eYRPG2znAcdLeoxsJMDzACTtLmk+ZFMjAIXpDRYDV1U7vYGZDY221pbuNr5HfvX3dU2LGTR5DVY5r5sxifu/8BbGj27hH2e/kqdf3MTUHUZz57LnOPXSO2nv6CKfE1+esz9f+vUiNnds+0Tk33/7MP/+24c5cuYu7DN5PK/efQfe89o9fINoTe+nf3sYH7zkz9usy0vb9F289swj2Lilk5MvvoP9p+3Abz5xJAB//Mc3D2lazYpFxLPAsWXWrwBOLFruc3qDiJhR6/SZWXV223EMz6UarCef21jn1JiNwAIWwI5jt44oVujDcfjeO/PIl2cDdBeWPnDYnvzPkjWsXr+ZfSaP5+bFq/jNAyt4bNV6/vjYGv74WNZC5LO/eIAjZ+7C3x65Nxs2d3DoXpPYefyo7mNEBCvXbmLZmpcY1ZrjtS+bNFSnalYzuTIPEUrnxzpw+kTWbdoCwFnHzBySdJmZ2Qin3Daj1Kzf3MH4MtN8mA0V//qKlKuFesPLd+n+vP+0HflUGk76xkVP85M/P8HajVtYWlLggqzTfntnF6NactvVgs162U785G8PY3Tr1if/XV1Rl8lczapRKHR98/0H8vDKdQBMGN3KsvPeVlF8970yM7PqiVxRH6z9/+UGHv7y7G3us8yGUlUFLEnvBb4IvAo4NCIW9BBuGbAO6AQ6ImJWWj8J+DkwA1gGvC8inq8mTUPlLa/ejbe8euuIwHcte47v3rIEAbc8spr2ziyjlxauABb89XnecN7v2WV8G8e8cgodncEPb//LNmFOOXRP/uUd+/mfgzWMyRO2n2uq8EzgpIP3gIMr39erpu7A4pUvMmWH0TVKnZmZjVglNVgAtz+2huP227UuyTGrtgZrIfBu4PsVhD2mzKhL5wA3R8R5ks5Jy5+rMk118boZk/jRRw7dbn1EbFMz1tUVXHDLEv70+LP8aemzPPrM+rL7+9mdT3TP0fXGl+/Cy6eMZ979Kzhq38kcsudEjn3VruzuIaptCL18yoTt1pVrNliJX591BLc8sprjXlVuvlczM7N+kLabB+tTV97Lon+dXacE2UhXVQErIhZD+aZ1FZoDHJ0+XwbcyjAtYPWk9LvJ5cQnj53JJ4/Nmho+91I7E0a30JKqAl7c1MHo1hy/e2gV/3X3k9yx9FluX7KG25dkZdNr7n2Ka+59in++Nhvc6o0v34WZu45n9x3H8OHXv8w1XjaoLjl1Fn97+daK6oE2a23J5zjeTxbNzKwW2sYxQRu54H37cdZVDwHwUnsn3//D43zsqH3qnDgbiYaqD1YAN0oK4PtppnuAXSNiJUCaALLHx9mS5gJzAfbcc8/BTu+QmTRu22ZXO47JBuB42wFTedsBUwHo6OzijqXPkc+JfE585/eP8T+PP0tnV3DvE893F77+/beLeeurd+OVu+3A2LY8Y9ryjGnNl3xu4cVNW5g4tpW2fI5cTrTmcoxuzTGqJU9ri2jL58jn5JERbTvFzS3e+PJdmHvk3r2ENjOzZiBpNvBtIA9cEhHnlWxX2n4isAE4PSLu6S1uTbuJ7PpqAN6+yyrGnf46PvKju4Bs5GcXsKwe+ixgSfodsFuZTedGRLnJGss5IiJWpALUTZIejojb+pPQVCi7GGDWrFkjaiKqlnyON87cOtjGjz96WPfniOD5DVv4yR1/5Ye3/4WbF6/itwufrslxW/NixzFtjGrZWgAb1ZpjdEtWYFu3aQs7jxvFuFEttLWIHUa30prP5hQb05YnJzF+dAvjR+UZP6qVcaPyjB/VQk5idGueUS1ZAW9cKgDmJRfsivR1QSsK9zrgDuD9EfGLtG4ZZfo91tJ3P3hI9wMBMzNrTpLywHfJ5pBbDtwlaV5EPFQU7ARgZnodBlwIHNZH3Np1E9k5zcf45B0cc8RhXH/2kcz+1h8B+OZNj3L2cTN9b2FDqs8CVkQcV+1B0lwjRMQqSdcAhwK3Ac9Imppqr6YCq6o91kgjiUnj2rZpdtjR2cWGLZ1sau9kQ3pt3NLJxvZONrR3sKmji03tnYxpy9MVQUdnsGFLJ+0dXWzp7GJLRxfPvtRORLClK9i0pZPNHV1sTu/rNnWwduMWHlr5IjuOaWVcW57nNrTT1UX34B4DlROMa2thVGuOllyOiWNb6Ypg3KgWRrfkaWvJ0ZrPMaolR2te3cvjR7cwKp99/p/Hn6W1JceJ++9GV0BnBJ2dXTz3UjvjRrWw8/hRdEWQk9jQ3sGa9e3stctYWnI5NrR38L5Z0+v+j7jCC1oh3P8jmxy1VLl+j1UrjJCZ96iXZmYjwaHAkohYCiDpSrIuHsXXoznA5RERwB2SJqb7uhm9xK1dN5FJqTXFTV8A4JUT92T+WzfztZuW8MAt9/J3t+WYOnE0a9ZvZnNHF2+cOZnRLXnaOzoZk+45RrfkyeVEV1fQGUEOGN2aJ8geZre25GjL5dg6moYo3CoUVtXi1qGSXfR9HF+f+yuXb+E1R727Zvsb9CaCksYBuYhYlz6/BfjXtHkecBpwXnqvtEbMetGSz7FDPscOo4e2diEi6AroiuClzR0Isbmzk/WbOnhpcyfrNm9h9brNjG7NsykV6P6y5iVa8jlac6IzgvWbOmjv7OKlzZ0EwTMvbmJ0S571mzvo6Opiw8YsXntHJ1s6I/vc2dUdr9htj64e0Hm87YDdG2H+jEouaACfAH4JvG6oEvaxo/bmO79fQluqrTQzs6Y2DXiyaHk5WS1VX2Gm9RG3om4iFXURaRsHB34A7r+iu5C1H/Cfxb0w1qX3PLC0/G5s5FofY6BRCliSTgK+A0wGrpN0X0S8VdLuZE2aTgR2Ba5JNQItwBURcX3axXnAVZI+CjwBvLea9Fh9SSIvyCMmji38V2ulzOBzg6K9o4v1mzu6a9xacqIlnzU7bMnl6Ojq4sWNW2jJ5bqf/rywYQtdEUwc20Z7RxfjR7cwtjEGCunzgiZpGnAS8Ga2L2D11O+xOP6A+jV+4s0zOf0NM2hrcQHLzGwEKFcdUtpVo6cwlcTtVcVdRE66EI7/Emx6ETo3Q2c7RJlWNQFB9kBYwJbOLoJgQ3snAPlcjrxEe2cnWzoC5bLWNZu3dHU/yI2iVEQ/zqavsFHBVzOi+sgMISnHzBrur9pRBK8BrimzfgVZR0fSE/gDe4j/LHBsNWkwK2hryTGpZfu5mopN
Download .txt
gitextract_97iggw1v/

├── .gitignore
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── cpa/
│   ├── __init__.py
│   ├── api.py
│   ├── data.py
│   ├── helper.py
│   ├── model.py
│   ├── plotting.py
│   └── train.py
├── datasets/
│   └── .gitkeep
├── notebooks/
│   └── demo.ipynb
├── preprocessing/
│   ├── GSM.ipynb
│   ├── Norman19.ipynb
│   ├── cross_species.ipynb
│   ├── lincs.ipynb
│   ├── pachter.ipynb
│   ├── sciplex3.ipynb
│   └── sciplex3_round_robin.ipynb
├── pretrained_models/
│   ├── .gitattributes
│   └── .gitkeep
├── requirements.txt
├── scripts/
│   ├── .gitkeep
│   ├── run_collect_results.sh
│   ├── run_one_epoch.sh
│   └── run_sweeps.sh
├── setup.py
└── tests/
    └── test.py
Download .txt
SYMBOL INDEX (103 symbols across 7 files)

FILE: cpa/api.py
  class API (line 23) | class API:
    method __init__ (line 28) | def __init__(
    method load_from_old (line 196) | def load_from_old(self, pretrained):
    method print_args (line 210) | def print_args(self):
    method load (line 213) | def load(self, pretrained):
    method train (line 226) | def train(
    method save (line 355) | def save(self, filename):
    method _init_pert_embeddings (line 366) | def _init_pert_embeddings(self):
    method get_drug_embeddings (line 378) | def get_drug_embeddings(self, dose=1.0, return_anndata=True):
    method _init_covars_embeddings (line 409) | def _init_covars_embeddings(self):
    method get_covars_embeddings_combined (line 438) | def get_covars_embeddings_combined(self, return_anndata=True):
    method get_covars_embeddings (line 456) | def get_covars_embeddings(self, covars_tgt, return_anndata=True):
    method _get_drug_encoding (line 477) | def _get_drug_encoding(self, drugs, doses=None):
    method mix_drugs (line 506) | def mix_drugs(self, drugs_list, doses_list=None, return_anndata=True):
    method latent_dose_response (line 550) | def latent_dose_response(
    method latent_dose_response2D (line 613) | def latent_dose_response2D(
    method compute_comb_emb (line 687) | def compute_comb_emb(self, thrh=30):
    method compute_uncertainty (line 746) | def compute_uncertainty(self, cov, pert, dose, thrh=30):
    method predict (line 803) | def predict(
    method get_latent (line 973) | def get_latent(
    method get_response (line 1073) | def get_response(
    method get_response_reference (line 1164) | def get_response_reference(self, perturbations=None):
    method get_response2D (line 1222) | def get_response2D(
    method evaluate_r2 (line 1332) | def evaluate_r2(self, dataset, genes_control, adata_random=None):
  function get_reference_from_combo (line 1432) | def get_reference_from_combo(perturbations_list, datasets, splits=["trai...
  function linear_interp (line 1466) | def linear_interp(y1, y2, x1, x2, x):
  function evaluate_r2_benchmark (line 1473) | def evaluate_r2_benchmark(cpa_api, datasets, pert_category, pert_categor...

FILE: cpa/data.py
  function ranks_to_df (line 18) | def ranks_to_df(data, key="rank_genes_groups"):
  function check_adata (line 42) | def check_adata(adata, special_fields):
  class Dataset (line 65) | class Dataset:
    method __init__ (line 66) | def __init__(
    method subset (line 292) | def subset(self, split, condition="all"):
    method __getitem__ (line 296) | def __getitem__(self, i):
    method __len__ (line 303) | def __len__(self):
  class SubDataset (line 307) | class SubDataset:
    method __init__ (line 312) | def __init__(self, dataset, indices):
    method __getitem__ (line 339) | def __getitem__(self, i):
    method subset_condition (line 346) | def subset_condition(self, control=True):
    method __len__ (line 350) | def __len__(self):
  function load_dataset_splits (line 354) | def load_dataset_splits(

FILE: cpa/helper.py
  function _convert_mean_disp_to_counts_logits (line 21) | def _convert_mean_disp_to_counts_logits(mu, theta, eps=1e-6):
  function rank_genes_groups_by_cov (line 45) | def rank_genes_groups_by_cov(
  function rank_genes_groups (line 136) | def rank_genes_groups(
  function evaluate_r2_ (line 254) | def evaluate_r2_(adata, pred_adata, condition_key, sampled=False, de_gen...
  function evaluate_mmd (line 295) | def evaluate_mmd(adata, pred_adata, condition_key, de_genes_dict=None):
  function evaluate_emd (line 321) | def evaluate_emd(adata, pred_adata, condition_key, de_genes_dict=None):
  function pairwise_distance (line 354) | def pairwise_distance(x, y):
  function gaussian_kernel_matrix (line 362) | def gaussian_kernel_matrix(x, y, alphas):
  function mmd_loss_calc (line 387) | def mmd_loss_calc(source_features, target_features):

FILE: cpa/model.py
  class NBLoss (line 11) | class NBLoss(torch.nn.Module):
    method __init__ (line 12) | def __init__(self):
    method forward (line 15) | def forward(self, mu, y, theta, eps=1e-8):
  function _nan2inf (line 46) | def _nan2inf(x):
  class MLP (line 49) | class MLP(torch.nn.Module):
    method __init__ (line 54) | def __init__(self, sizes, batch_norm=True, last_layer_act="linear"):
    method forward (line 77) | def forward(self, x):
  class GeneralizedSigmoid (line 85) | class GeneralizedSigmoid(torch.nn.Module):
    method __init__ (line 91) | def __init__(self, dim, device, nonlin="sigmoid"):
    method forward (line 107) | def forward(self, x):
    method one_drug (line 117) | def one_drug(self, x, i):
  class CPA (line 128) | class CPA(torch.nn.Module):
    method __init__ (line 133) | def __init__(
    method set_hparams_ (line 286) | def set_hparams_(self, hparams):
    method move_inputs_ (line 322) | def move_inputs_(self, genes, drugs, covariates):
    method compute_drug_embeddings_ (line 334) | def compute_drug_embeddings_(self, drugs):
    method predict (line 349) | def predict(
    method early_stopping (line 403) | def early_stopping(self, score):
    method update (line 419) | def update(self, genes, drugs, covariates):
    method defaults (line 509) | def defaults(self):

FILE: cpa/plotting.py
  class CPAVisuals (line 23) | class CPAVisuals:
    method __init__ (line 30) | def __init__(
    method plot_latent_embeddings (line 90) | def plot_latent_embeddings(
    method plot_contvar_response2D (line 165) | def plot_contvar_response2D(
    method plot_contvar_response (line 274) | def plot_contvar_response(
    method plot_scatter (line 366) | def plot_scatter(
  function log10_with0 (line 435) | def log10_with0(x):
  function get_palette (line 441) | def get_palette(n_colors, palette_name="Set1"):
  function fast_dimred (line 454) | def fast_dimred(emb, method="KernelPCA"):
  function plot_dose_response (line 478) | def plot_dose_response(
  function plot_uncertainty_comb_dose (line 640) | def plot_uncertainty_comb_dose(
  function plot_uncertainty_dose (line 771) | def plot_uncertainty_dose(
  function save_to_file (line 882) | def save_to_file(fig, file_name, file_format=None):
  function plot_embedding (line 897) | def plot_embedding(
  function get_colors (line 1011) | def get_colors(labels, palette=None, palette_name=None):
  function plot_similarity (line 1019) | def plot_similarity(
  function mean_plot (line 1074) | def mean_plot(
  function plot_r2_matrix (line 1211) | def plot_r2_matrix(pred, adata, de_genes=None, **kwds):
  function arrange_history (line 1268) | def arrange_history(history):
  class CPAHistory (line 1273) | class CPAHistory:
    method __init__ (line 1278) | def __init__(self, cpa_api, fileprefix=None):
    method print_time (line 1326) | def print_time(self):
    method plot_losses (line 1329) | def plot_losses(self, filename=None):
    method plot_r2_metrics (line 1363) | def plot_r2_metrics(self, epoch_min=0, filename=None):
    method plot_disentanglement_metrics (line 1409) | def plot_disentanglement_metrics(self, epoch_min=0, filename=None):

FILE: cpa/train.py
  function pjson (line 19) | def pjson(s):
  function _convert_mean_disp_to_counts_logits (line 25) | def _convert_mean_disp_to_counts_logits(mu, theta, eps=1e-6):
  function evaluate_disentanglement (line 48) | def evaluate_disentanglement(autoencoder, dataset):
  function evaluate_r2 (line 117) | def evaluate_r2(autoencoder, dataset, genes_control):
  function evaluate (line 199) | def evaluate(autoencoder, datasets):
  function prepare_cpa (line 243) | def prepare_cpa(args, state_dict=None):
  function train_cpa (line 277) | def train_cpa(args, return_model=False):
  function parse_arguments (line 368) | def parse_arguments():

FILE: tests/test.py
  function sim_adata (line 10) | def sim_adata():
Condensed preview — 30 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (3,055K chars).
[
  {
    "path": ".gitignore",
    "chars": 55,
    "preview": "__pycache__\n*.pyc\n*.egg-info/\n*.ipynb_checkpoints/\n*.pt"
  },
  {
    "path": "CODE_OF_CONDUCT.md",
    "chars": 244,
    "preview": "# Code of Conduct\n\nFacebook has adopted a Code of Conduct that we expect project participants to adhere to.\nPlease read "
  },
  {
    "path": "CONTRIBUTING.md",
    "chars": 1248,
    "preview": "# Contributing to `CPA` \nWe want to make contributing to this project as easy and transparent as\npossible.\n\n## Pull Requ"
  },
  {
    "path": "LICENSE",
    "chars": 1090,
    "preview": "The MIT License\n\nCopyright (c) Facebook, Inc. and its affiliates.\n\nPermission is hereby granted, free of charge, to any "
  },
  {
    "path": "README.md",
    "chars": 3995,
    "preview": "# CPA - Compositional Perturbation Autoencoder\n\n# This code in not being maintained anymore, please use the new implemen"
  },
  {
    "path": "cpa/__init__.py",
    "chars": 132,
    "preview": "# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved\n\nfrom cpa.api import API\nfrom cpa.plotting import"
  },
  {
    "path": "cpa/api.py",
    "chars": 56512,
    "preview": "import copy\nimport itertools\nimport os\nimport pprint\nimport time\nfrom collections import defaultdict\nfrom typing import "
  },
  {
    "path": "cpa/data.py",
    "chars": 13673,
    "preview": "# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved\n\nimport warnings\n\nimport numpy as np\nimport torch"
  },
  {
    "path": "cpa/helper.py",
    "chars": 13590,
    "preview": "# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved\n\nimport warnings\n\nimport numpy as np\nimport panda"
  },
  {
    "path": "cpa/model.py",
    "chars": 18503,
    "preview": "# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved\n\nfrom http.client import RemoteDisconnected\nimpor"
  },
  {
    "path": "cpa/plotting.py",
    "chars": 45462,
    "preview": "# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved\n\nfrom collections import defaultdict\n\nimport matp"
  },
  {
    "path": "cpa/train.py",
    "chars": 14255,
    "preview": "# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved\n\nimport argparse\nimport json\nimport os\nimport tim"
  },
  {
    "path": "datasets/.gitkeep",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "notebooks/demo.ipynb",
    "chars": 726674,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# A tour of the CPA model\"\n   ]\n  }"
  },
  {
    "path": "preprocessing/GSM.ipynb",
    "chars": 1922268,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": "
  },
  {
    "path": "preprocessing/Norman19.ipynb",
    "chars": 52957,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\":"
  },
  {
    "path": "preprocessing/lincs.ipynb",
    "chars": 27413,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\":"
  },
  {
    "path": "preprocessing/pachter.ipynb",
    "chars": 48280,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\":"
  },
  {
    "path": "preprocessing/sciplex3.ipynb",
    "chars": 39478,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"id\": \"satellite-immigration\",\n   \"metadata\": {},\n"
  },
  {
    "path": "preprocessing/sciplex3_round_robin.ipynb",
    "chars": 23541,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\":"
  },
  {
    "path": "pretrained_models/.gitattributes",
    "chars": 69,
    "preview": "Norman2010_prep_new_deg_collect/ filter=lfs diff=lfs merge=lfs -text\n"
  },
  {
    "path": "pretrained_models/.gitkeep",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "requirements.txt",
    "chars": 91,
    "preview": "adjustText\nargparse\nmatplotlib\nnumpy\npandas\nscanpy \nscipy \nseaborn \nsklearn\nsubmitit\ntorch\n"
  },
  {
    "path": "scripts/.gitkeep",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "scripts/run_collect_results.sh",
    "chars": 2146,
    "preview": "#!/bin/bash\n\npython -m cpa.collect_results     --save_dir /checkpoint/$USER/sweep_GSM_new_logsigm\npython -m cpa.collect_"
  },
  {
    "path": "scripts/run_one_epoch.sh",
    "chars": 849,
    "preview": "#!/bin/bash\n\n# change the path vairable to your path to the datasets folder\npath='../cpa_binaries'\n\npython -m cpa.train "
  },
  {
    "path": "scripts/run_sweeps.sh",
    "chars": 3602,
    "preview": "#!/bin/bash\n\n# rm -rf /checkpoint/$USER/sweep_GSM_2k_hvg\n# rm -rf /checkpoint/$USER/sweep_GSM_4k_hvg\n# rm -rf /checkpoin"
  },
  {
    "path": "setup.py",
    "chars": 542,
    "preview": "# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved\n\nfrom distutils.core import setup\n\next_modules = "
  },
  {
    "path": "tests/test.py",
    "chars": 885,
    "preview": "import sys\n\nsys.path.append(\"../\")\nimport cpa\nimport scanpy as sc\nimport scvi\nfrom cpa.helper import rank_genes_groups_b"
  }
]

// ... and 1 more files (download for full content)

About this extraction

This page contains the full source code of the facebookresearch/CPA GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 30 files (36.5 MB), approximately 755.7k tokens, and a symbol index with 103 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!