Showing preview only (3,023K chars total). Download the full file or copy to clipboard to get everything.
Repository: facebookresearch/CPA
Branch: main
Commit: 50e283a7a1b3
Files: 30
Total size: 36.5 MB
Directory structure:
gitextract_97iggw1v/
├── .gitignore
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── cpa/
│ ├── __init__.py
│ ├── api.py
│ ├── data.py
│ ├── helper.py
│ ├── model.py
│ ├── plotting.py
│ └── train.py
├── datasets/
│ └── .gitkeep
├── notebooks/
│ └── demo.ipynb
├── preprocessing/
│ ├── GSM.ipynb
│ ├── Norman19.ipynb
│ ├── cross_species.ipynb
│ ├── lincs.ipynb
│ ├── pachter.ipynb
│ ├── sciplex3.ipynb
│ └── sciplex3_round_robin.ipynb
├── pretrained_models/
│ ├── .gitattributes
│ └── .gitkeep
├── requirements.txt
├── scripts/
│ ├── .gitkeep
│ ├── run_collect_results.sh
│ ├── run_one_epoch.sh
│ └── run_sweeps.sh
├── setup.py
└── tests/
└── test.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
__pycache__
*.pyc
*.egg-info/
*.ipynb_checkpoints/
*.pt
================================================
FILE: CODE_OF_CONDUCT.md
================================================
# Code of Conduct
Facebook has adopted a Code of Conduct that we expect project participants to adhere to.
Please read the [full text](https://code.fb.com/codeofconduct/)
so that you can understand what actions will and will not be tolerated.
================================================
FILE: CONTRIBUTING.md
================================================
# Contributing to `CPA`
We want to make contributing to this project as easy and transparent as
possible.
## Pull Requests
We actively welcome your pull requests.
1. Fork the repo and create your branch from `master`.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes.
5. Make sure your code lints.
6. If you haven't already, complete the Contributor License Agreement ("CLA").
## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Facebook's open source projects.
Complete your CLA here: <https://code.facebook.com/cla>
## Issues
We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.
Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
disclosure of security bugs. In those cases, please go through the process
outlined on that page and do not file a public issue.
## License
By contributing to `CPA`, you agree that your contributions
will be licensed under the LICENSE file in the root directory of this source
tree.
================================================
FILE: LICENSE
================================================
The MIT License
Copyright (c) Facebook, Inc. and its affiliates.
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
================================================
FILE: README.md
================================================
# CPA - Compositional Perturbation Autoencoder
# This code in not being maintained anymore, please use the new implementation [here](https://github.com/theislab/cpa).
## What is CPA?

`CPA` is a framework to learn effects of perturbations at the single-cell level. CPA encodes and learns phenotypic drug response across different cell types, doses and drug combinations. CPA allows:
* Out-of-distribution predicitons of unseen drug combinations at various doses and among different cell types.
* Learn interpretable drug and cell type latent spaces.
* Estimate dose response curve for each perturbation and their combinations.
* Access the uncertainty of the estimations of the model.
## Package Structure
The repository is centered around the `cpa` module:
* [`cpa.train`](cpa/train.py) contains scripts to train the model.
* [`cpa.api`](cpa/api.py) contains user friendly scripts to interact with the model via scanpy.
* [`cpa.plotting`](cpa/plotting.py) contains scripts to plotting functions.
* [`cpa.model`](cpa/model.py) contains modules of cpa model.
* [`cpa.data`](cpa/data.py) contains data loader, which transforms anndata structure to a class compatible with cpa model.
Additional files and folders:
* [`datasets`](datasets/) contains both versions of the data: raw and pre-processed.
* [`preprocessing`](preprocessing/) contains notebooks to reproduce the datasets pre-processing from raw data.
## Usage
- As a first step, download the contents of `datasets/` and `pretrained_models/` from [this tarball](https://dl.fbaipublicfiles.com/dlp/cpa_binaries.tar).
To learn how to use this repository, check
[`./notebooks/demo.ipynb`](notebooks/demo.ipynb), and the following scripts:
* Note that hyperparameters in the `demo.ipynb` are set as default but might not work work for new datasets.
## Examples and Reproducibility
you can find more example and hyperparamters tuning scripts and also reproducbility notebooks for the plots in the paper in the [`reproducibility`](https://github.com/theislab/cpa-reproducibility) repo.
## Curation of your own data to train CPA
* To prepare your data to train CPA, you need to add specific fields to adata object and perfrom data split. Examples on how to add
necessary fields for multiple datasets used in the paper can be found in [`preprocessing/`](/https://github.com/facebookresearch/CPA/tree/master/preprocessing) folder.
## Training a model
There are two ways to train a cpa model:
* Using the command line, e.g.: `python -m cpa.train --data datasets/GSM_new.h5ad --save_dir /tmp --max_epochs 1 --doser_type sigm`
* From jupyter notebook: example in [`./notebooks/demo.ipynb`](notebooks/demo.ipynb)
## Documentation
Currently you can access the documentation via `help` function in IPython. For example:
```python
from cpa.api import API
help(API)
from cpa.plotting import CPAVisuals
help(CPAVisuals)
```
A separate page with the documentation is coming soon.
## Support and contribute
If you have a question or noticed a problem, you can post an [`issue`](https://github.com/facebookresearch/CPA/issues/new).
## Reference
Please cite the following publication if you find CPA useful in your research.
```
@article{lotfollahi2023predicting,
title={Predicting cellular responses to complex perturbations in high-throughput screens},
author={Lotfollahi, Mohammad and Klimovskaia Susmelj, Anna and De Donno, Carlo and Hetzel, Leon and Ji, Yuge and Ibarra, Ignacio L and Srivatsan, Sanjay R and Naghipourfar, Mohsen and Daza, Riza M and Martin, Beth and others},
journal={Molecular Systems Biology},
pages={e11517},
year={2023}
}
```
The paper titled **Predicting cellular responses to complex perturbations in high-throughput screens** can be found [here](https://www.biorxiv.org/content/10.1101/2021.04.14.439903v2](https://www.embopress.org/doi/full/10.15252/msb.202211517).
## License
This source code is released under the MIT license, included [here](LICENSE).
================================================
FILE: cpa/__init__.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from cpa.api import API
from cpa.plotting import CPAVisuals
================================================
FILE: cpa/api.py
================================================
import copy
import itertools
import os
import pprint
import time
from collections import defaultdict
from typing import Optional, Union, Tuple
import numpy as np
import pandas as pd
import scanpy as sc
import torch
from torch.distributions import (
NegativeBinomial,
Normal
)
from cpa.train import evaluate, prepare_cpa
from cpa.helper import _convert_mean_disp_to_counts_logits
from sklearn.metrics import r2_score
from sklearn.metrics.pairwise import cosine_distances, euclidean_distances
from tqdm import tqdm
class API:
"""
API for CPA model to make it compatible with scanpy.
"""
def __init__(
self,
data,
perturbation_key="condition",
covariate_keys=["cell_type"],
split_key="split",
dose_key="dose_val",
control=None,
doser_type="mlp",
decoder_activation="linear",
loss_ae="gauss",
patience=200,
seed=0,
pretrained=None,
device="cuda",
save_dir="/tmp/", # directory to save the model
hparams={},
only_parameters=False,
):
"""
Parameters
----------
data : str or `AnnData`
AnndData object or a full path to the file in the .h5ad format.
covariate_keys : list (default: ['cell_type'])
List of names in the .obs of AnnData that should be used as
covariates.
split_key : str (default: 'split')
Name of the column in .obs of AnnData to use for splitting the
dataset into train, test and validation.
perturbation_key : str (default: 'condition')
Name of the column in .obs of AnnData to use for perturbation
variable.
dose_key : str (default: 'dose_val')
Name of the column in .obs of AnnData to use for continious
covariate.
doser_type : str (default: 'mlp')
Type of the nonlinearity in the latent space for the continious
covariate encoding: sigm, logsigm, mlp.
decoder_activation : str (default: 'linear')
Last layer of the decoder.
loss_ae : str (default: 'gauss')
Loss (currently only gaussian loss is supported).
patience : int (default: 200)
Patience for early stopping.
seed : int (default: 0)
Random seed.
pretrained : str (default: None)
Full path to the pretrained model.
only_parameters : bool (default: False)
Whether to load only arguments or also weights from pretrained model.
save_dir : str (default: '/tmp/')
Folder to save the model.
device : str (default: 'cpu')
Device for model computations. If None, will try to use CUDA if
available.
hparams : dict (default: {})
Parameters for the architecture of the CPA model.
control: str
Obs columns with booleans that identify control. If it is not provided
the model will look for them in adata.obs["control"]
"""
args = locals()
del args["self"]
if not (pretrained is None):
state, self.used_args, self.history = torch.load(
pretrained, map_location=torch.device(device)
)
self.args = self.used_args
self.args["data"] = data
self.args["covariate_keys"] = covariate_keys
self.args["device"] = device
self.args["control"] = control
if only_parameters:
state = None
print(f"Loaded ARGS of the model from:\t{pretrained}")
else:
print(f"Loaded pretrained model from:\t{pretrained}")
else:
state = None
self.args = args
self.model, self.datasets = prepare_cpa(self.args, state_dict=state)
if not (pretrained is None) and (not only_parameters):
self.model.history = self.history
self.args["save_dir"] = save_dir
self.args["hparams"] = self.model.hparams
if not (save_dir is None):
if not os.path.exists(save_dir):
os.makedirs(save_dir)
dataset = self.datasets["training"]
self.perturbation_key = dataset.perturbation_key
self.dose_key = dataset.dose_key
self.covariate_keys = covariate_keys # very important, specifies the order of
# covariates during training
self.min_dose = dataset.drugs[dataset.drugs > 0].min().item()
self.max_dose = dataset.drugs[dataset.drugs > 0].max().item()
self.var_names = dataset.var_names
self.unique_perts = list(dataset.perts_dict.keys())
self.unique_covars = {}
for cov in dataset.covars_dict:
self.unique_covars[cov] = list(dataset.covars_dict[cov].keys())
self.num_drugs = dataset.num_drugs
self.perts_dict = dataset.perts_dict
self.covars_dict = dataset.covars_dict
self.drug_ohe = torch.Tensor(list(dataset.perts_dict.values()))
self.covars_ohe = {}
for cov in dataset.covars_dict:
self.covars_ohe[cov] = torch.LongTensor(
list(dataset.covars_dict[cov].values())
)
self.emb_covars = {}
for cov in dataset.covars_dict:
self.emb_covars[cov] = None
self.emb_perts = None
self.seen_covars_perts = None
self.comb_emb = None
self.control_cat = None
self.seen_covars_perts = {}
for k in self.datasets.keys():
self.seen_covars_perts[k] = np.unique(self.datasets[k].pert_categories)
self.measured_points = {}
self.num_measured_points = {}
for k in self.datasets.keys():
self.measured_points[k] = {}
self.num_measured_points[k] = {}
for pert in np.unique(self.datasets[k].pert_categories):
num_points = len(np.where(self.datasets[k].pert_categories == pert)[0])
self.num_measured_points[k][pert] = num_points
*cov_list, drug, dose = pert.split("_")
cov = "_".join(cov_list)
if not ("+" in dose):
dose = float(dose)
if cov in self.measured_points[k].keys():
if drug in self.measured_points[k][cov].keys():
self.measured_points[k][cov][drug].append(dose)
else:
self.measured_points[k][cov][drug] = [dose]
else:
self.measured_points[k][cov] = {drug: [dose]}
self.measured_points["all"] = copy.deepcopy(self.measured_points["training"])
for cov in self.measured_points["ood"].keys():
for pert in self.measured_points["ood"][cov].keys():
if pert in self.measured_points["training"][cov].keys():
self.measured_points["all"][cov][pert] = (
self.measured_points["training"][cov][pert].copy()
+ self.measured_points["ood"][cov][pert].copy()
)
else:
self.measured_points["all"][cov][pert] = self.measured_points[
"ood"
][cov][pert].copy()
def load_from_old(self, pretrained):
"""
Parameters
----------
pretrained : str
Full path to the pretrained model.
"""
print(f"Loaded pretrained model from:\t{pretrained}")
state, self.used_args, self.history = torch.load(
pretrained, map_location=torch.device(self.args["device"])
)
self.model.load_state_dict(state_dict)
self.model.history = self.history
def print_args(self):
pprint.pprint(self.args)
def load(self, pretrained):
"""
Parameters
----------
pretrained : str
Full path to the pretrained model.
""" # TODO fix compatibility
print(f"Loaded pretrained model from:\t{pretrained}")
state, self.used_args, self.history = torch.load(
pretrained, map_location=torch.device(self.args["device"])
)
self.model.load_state_dict(state_dict)
def train(
self,
max_epochs=1,
checkpoint_freq=20,
run_eval=False,
max_minutes=60,
filename="model.pt",
batch_size=None,
save_dir=None,
seed=0,
):
"""
Parameters
----------
max_epochs : int (default: 1)
Maximum number epochs for training.
checkpoint_freq : int (default: 20)
Checkoint frequencty to save intermediate results.
run_eval : bool (default: False)
Whether or not to run disentanglement and R2 evaluation during training.
max_minutes : int (default: 60)
Maximum computation time in minutes.
filename : str (default: 'model.pt')
Name of the file without the directoty path to save the model.
Name should be with .pt extension.
batch_size : int, optional (default: None)
Batch size for training. If None, uses default batch size specified
in hparams.
save_dir : str, optional (default: None)
Full path to the folder to save the model. If None, will use from
the path specified during init.
seed : int (default: None)
Random seed. If None, uses default random seed specified during init.
"""
args = locals()
del args["self"]
if batch_size is None:
batch_size = self.model.hparams["batch_size"]
args["batch_size"] = batch_size
self.args["batch_size"] = batch_size
if save_dir is None:
save_dir = self.args["save_dir"]
print("Results will be saved to the folder:", save_dir)
self.datasets.update(
{
"loader_tr": torch.utils.data.DataLoader(
self.datasets["training"], batch_size=batch_size, shuffle=True
)
}
)
self.model.train()
start_time = time.time()
pbar = tqdm(range(max_epochs), ncols=80)
try:
for epoch in pbar:
epoch_training_stats = defaultdict(float)
for data in self.datasets["loader_tr"]:
genes, drugs, covariates = data[0], data[1], data[2:]
minibatch_training_stats = self.model.update(
genes, drugs, covariates
)
for key, val in minibatch_training_stats.items():
epoch_training_stats[key] += val
for key, val in epoch_training_stats.items():
epoch_training_stats[key] = val / len(self.datasets["loader_tr"])
if not (key in self.model.history.keys()):
self.model.history[key] = []
self.model.history[key].append(epoch_training_stats[key])
self.model.history["epoch"].append(epoch)
ellapsed_minutes = (time.time() - start_time) / 60
self.model.history["elapsed_time_min"] = ellapsed_minutes
# decay learning rate if necessary
# also check stopping condition: patience ran out OR
# time ran out OR max epochs achieved
stop = ellapsed_minutes > max_minutes or (epoch == max_epochs - 1)
pbar.set_description(
f"Rec: {epoch_training_stats['loss_reconstruction']:.4f}, "
+ f"AdvPert: {epoch_training_stats['loss_adv_drugs']:.2f}, "
+ f"AdvCov: {epoch_training_stats['loss_adv_covariates']:.2f}"
)
if (epoch % checkpoint_freq) == 0 or stop:
if run_eval == True:
evaluation_stats = evaluate(self.model, self.datasets)
for key, val in evaluation_stats.items():
if not (key in self.model.history.keys()):
self.model.history[key] = []
self.model.history[key].append(val)
self.model.history["stats_epoch"].append(epoch)
stop = stop or self.model.early_stopping(
np.mean(evaluation_stats["test"])
)
else:
stop = stop or self.model.early_stopping(
np.mean(epoch_training_stats["test"])
)
evaluation_stats = None
if stop:
self.save(f"{save_dir}{filename}")
pprint.pprint(
{
"epoch": epoch,
"training_stats": epoch_training_stats,
"evaluation_stats": evaluation_stats,
"ellapsed_minutes": ellapsed_minutes,
}
)
print(f"Stop epoch: {epoch}")
break
except KeyboardInterrupt:
self.save(f"{save_dir}{filename}")
self.save(f"{save_dir}{filename}")
def save(self, filename):
"""
Parameters
----------
filename : str
Full path to save pretrained model.
"""
torch.save((self.model.state_dict(), self.args, self.model.history), filename)
self.history = self.model.history
print(f"Model saved to: {filename}")
def _init_pert_embeddings(self):
dose = 1.0
self.emb_perts = (
self.model.compute_drug_embeddings_(
dose * self.drug_ohe.to(self.model.device)
)
.cpu()
.clone()
.detach()
.numpy()
)
def get_drug_embeddings(self, dose=1.0, return_anndata=True):
"""
Parameters
----------
dose : int (default: 1.0)
Dose at which to evaluate latent embedding vector.
return_anndata : bool, optional (default: True)
Return embedding wrapped into anndata object.
Returns
-------
If return_anndata is True, returns anndata object. Otherwise, doesn't
return anything. Always saves embeddding in self.emb_perts.
"""
self._init_pert_embeddings()
emb_perts = (
self.model.compute_drug_embeddings_(
dose * self.drug_ohe.to(self.model.device)
)
.cpu()
.clone()
.detach()
.numpy()
)
if return_anndata:
adata = sc.AnnData(emb_perts)
adata.obs[self.perturbation_key] = self.unique_perts
return adata
def _init_covars_embeddings(self):
combo_list = []
for covars_key in self.covariate_keys:
combo_list.append(self.unique_covars[covars_key])
if self.emb_covars[covars_key] is None:
i_cov = self.covariate_keys.index(covars_key)
self.emb_covars[covars_key] = dict(
zip(
self.unique_covars[covars_key],
self.model.covariates_embeddings[i_cov](
self.covars_ohe[covars_key].to(self.model.device).argmax(1)
)
.cpu()
.clone()
.detach()
.numpy(),
)
)
self.emb_covars_combined = {}
for combo in list(itertools.product(*combo_list)):
combo_name = "_".join(combo)
for i, cov in enumerate(combo):
covars_key = self.covariate_keys[i]
if i == 0:
emb = self.emb_covars[covars_key][cov]
else:
emb += self.emb_covars[covars_key][cov]
self.emb_covars_combined[combo_name] = emb
def get_covars_embeddings_combined(self, return_anndata=True):
"""
Parameters
----------
return_anndata : bool, optional (default: True)
Return embedding wrapped into anndata object.
Returns
-------
If return_anndata is True, returns anndata object. Otherwise, doesn't
return anything. Always saves embeddding in self.emb_covars.
"""
self._init_covars_embeddings()
if return_anndata:
adata = sc.AnnData(np.array(list(self.emb_covars_combined.values())))
adata.obs["covars"] = self.emb_covars_combined.keys()
return adata
def get_covars_embeddings(self, covars_tgt, return_anndata=True):
"""
Parameters
----------
covars_tgt : str
Name of covariate for which to return AnnData
return_anndata : bool, optional (default: True)
Return embedding wrapped into anndata object.
Returns
-------
If return_anndata is True, returns anndata object. Otherwise, doesn't
return anything. Always saves embeddding in self.emb_covars.
"""
self._init_covars_embeddings()
if return_anndata:
adata = sc.AnnData(np.array(list(self.emb_covars[covars_tgt].values())))
adata.obs[covars_tgt] = self.emb_covars[covars_tgt].keys()
return adata
def _get_drug_encoding(self, drugs, doses=None):
"""
Parameters
----------
drugs : str
Drugs combination as a string, where individual drugs are separated
with a plus.
doses : str, optional (default: None)
Doses corresponding to the drugs combination as a string. Individual
drugs are separated with a plus.
Returns
-------
One hot encodding for a mixture of drugs.
"""
drug_mix = np.zeros([1, self.num_drugs])
atomic_drugs = drugs.split("+")
doses = str(doses)
if doses is None:
doses_list = [1.0] * len(atomic_drugs)
else:
doses_list = [float(d) for d in str(doses).split("+")]
for j, drug in enumerate(atomic_drugs):
drug_mix += doses_list[j] * self.perts_dict[drug]
return drug_mix
def mix_drugs(self, drugs_list, doses_list=None, return_anndata=True):
"""
Gets a list of drugs combinations to mix, e.g. ['A+B', 'B+C'] and
corresponding doses.
Parameters
----------
drugs_list : list
List of drug combinations, where each drug combination is a string.
Individual drugs in the combination are separated with a plus.
doses_list : str, optional (default: None)
List of corresponding doses, where each dose combination is a string.
Individual doses in the combination are separated with a plus.
return_anndata : bool, optional (default: True)
Return embedding wrapped into anndata object.
Returns
-------
If return_anndata is True, returns anndata structure of the combinations,
otherwise returns a np.array of corresponding embeddings.
"""
drug_mix = np.zeros([len(drugs_list), self.num_drugs])
for i, drug_combo in enumerate(drugs_list):
drug_mix[i] = self._get_drug_encoding(drug_combo, doses=doses_list[i])
emb = (
self.model.compute_drug_embeddings_(
torch.Tensor(drug_mix).to(self.model.device)
)
.cpu()
.clone()
.detach()
.numpy()
)
if return_anndata:
adata = sc.AnnData(emb)
adata.obs[self.perturbation_key] = drugs_list
adata.obs[self.dose_key] = doses_list
return adata
else:
return emb
def latent_dose_response(
self, perturbations=None, dose=None, contvar_min=0, contvar_max=1, n_points=100
):
"""
Parameters
----------
perturbations : list
List containing two names for which to return complete pairwise
dose-response.
doses : np.array (default: None)
Doses values. If None, default values will be generated on a grid:
n_points in range [contvar_min, contvar_max].
contvar_min : float (default: 0)
Minimum dose value to generate for default option.
contvar_max : float (default: 0)
Maximum dose value to generate for default option.
n_points : int (default: 100)
Number of dose points to generate for default option.
Returns
-------
pd.DataFrame
"""
# dosers work only for atomic drugs. TODO add drug combinations
self.model.eval()
if perturbations is None:
perturbations = self.unique_perts
if dose is None:
dose = np.linspace(contvar_min, contvar_max, n_points)
n_points = len(dose)
df = pd.DataFrame(columns=[self.perturbation_key, self.dose_key, "response"])
for drug in perturbations:
d = np.where(self.perts_dict[drug] == 1)[0][0]
this_drug = torch.Tensor(dose).to(self.model.device).view(-1, 1)
if self.model.doser_type == "mlp":
response = (
(self.model.dosers[d](this_drug).sigmoid() * this_drug.gt(0))
.cpu()
.clone()
.detach()
.numpy()
.reshape(-1)
)
else:
response = (
self.model.dosers.one_drug(this_drug.view(-1), d)
.cpu()
.clone()
.detach()
.numpy()
.reshape(-1)
)
df_drug = pd.DataFrame(
list(zip([drug] * n_points, dose, list(response))),
columns=[self.perturbation_key, self.dose_key, "response"],
)
df = pd.concat([df, df_drug])
return df
def latent_dose_response2D(
self,
perturbations,
dose=None,
contvar_min=0,
contvar_max=1,
n_points=100,
):
"""
Parameters
----------
perturbations : list, optional (default: None)
List of atomic drugs for which to return latent dose response.
Currently drug combinations are not supported.
doses : np.array (default: None)
Doses values. If None, default values will be generated on a grid:
n_points in range [contvar_min, contvar_max].
contvar_min : float (default: 0)
Minimum dose value to generate for default option.
contvar_max : float (default: 0)
Maximum dose value to generate for default option.
n_points : int (default: 100)
Number of dose points to generate for default option.
Returns
-------
pd.DataFrame
"""
# dosers work only for atomic drugs. TODO add drug combinations
assert len(perturbations) == 2, "You should provide a list of 2 perturbations."
self.model.eval()
if dose is None:
dose = np.linspace(contvar_min, contvar_max, n_points)
n_points = len(dose)
df = pd.DataFrame(columns=perturbations + ["response"])
response = {}
for drug in perturbations:
d = np.where(self.perts_dict[drug] == 1)[0][0]
this_drug = torch.Tensor(dose).to(self.model.device).view(-1, 1)
if self.model.doser_type == "mlp":
response[drug] = (
(self.model.dosers[d](this_drug).sigmoid() * this_drug.gt(0))
.cpu()
.clone()
.detach()
.numpy()
.reshape(-1)
)
else:
response[drug] = (
self.model.dosers.one_drug(this_drug.view(-1), d)
.cpu()
.clone()
.detach()
.numpy()
.reshape(-1)
)
l = 0
for i in range(len(dose)):
for j in range(len(dose)):
df.loc[l] = [
dose[i],
dose[j],
response[perturbations[0]][i] + response[perturbations[1]][j],
]
l += 1
return df
def compute_comb_emb(self, thrh=30):
"""
Generates an AnnData object containing all the latent vectors of the
cov+dose*pert combinations seen during training.
Called in api.compute_uncertainty(), stores the AnnData in self.comb_emb.
Parameters
----------
Returns
-------
"""
if self.seen_covars_perts["training"] is None:
raise ValueError("Need to run parse_training_conditions() first!")
emb_covars = self.get_covars_embeddings_combined(return_anndata=True)
# Generate adata with all cov+pert latent vect combinations
tmp_ad_list = []
for cov_pert in self.seen_covars_perts["training"]:
if self.num_measured_points["training"][cov_pert] > thrh:
*cov_list, pert_loop, dose_loop = cov_pert.split("_")
cov_loop = "_".join(cov_list)
emb_perts_loop = []
if "+" in pert_loop:
pert_loop_list = pert_loop.split("+")
dose_loop_list = dose_loop.split("+")
for _dose in pd.Series(dose_loop_list).unique():
tmp_ad = self.get_drug_embeddings(dose=float(_dose))
tmp_ad.obs["pert_dose"] = tmp_ad.obs.condition + "_" + _dose
emb_perts_loop.append(tmp_ad)
emb_perts_loop = emb_perts_loop[0].concatenate(emb_perts_loop[1:])
X = emb_covars.X[
emb_covars.obs.covars == cov_loop
] + np.expand_dims(
emb_perts_loop.X[
emb_perts_loop.obs.pert_dose.isin(
[
pert_loop_list[i] + "_" + dose_loop_list[i]
for i in range(len(pert_loop_list))
]
)
].sum(axis=0),
axis=0,
)
if X.shape[0] > 1:
raise ValueError("Error with comb computation")
else:
emb_perts = self.get_drug_embeddings(dose=float(dose_loop))
X = (
emb_covars.X[emb_covars.obs.covars == cov_loop]
+ emb_perts.X[emb_perts.obs.condition == pert_loop]
)
tmp_ad = sc.AnnData(X=X)
tmp_ad.obs["cov_pert"] = "_".join([cov_loop, pert_loop, dose_loop])
tmp_ad_list.append(tmp_ad)
self.comb_emb = tmp_ad_list[0].concatenate(tmp_ad_list[1:])
def compute_uncertainty(self, cov, pert, dose, thrh=30):
"""
Compute uncertainties for the queried covariate+perturbation combination.
The distance from the closest condition in the training set is used as a
proxy for uncertainty.
Parameters
----------
cov: dict
Provide a value for each covariate (eg. cell_type) as a dictionaty
for the queried uncertainty (e.g. cov_dict={'cell_type': 'A549'}).
pert: string
Perturbation for the queried uncertainty. In case of combinations the
format has to be 'pertA+pertB'
dose: string
String which contains the dose of the perturbation queried. In case
of combinations the format has to be 'doseA+doseB'
Returns
-------
min_cos_dist: float
Minimum cosine distance with the training set.
min_eucl_dist: float
Minimum euclidean distance with the training set.
closest_cond_cos: string
Closest training condition wrt cosine distances.
closest_cond_eucl: string
Closest training condition wrt euclidean distances.
"""
if self.comb_emb is None:
self.compute_comb_emb(thrh=30)
drug_ohe = torch.Tensor(self._get_drug_encoding(pert, doses=dose)).to(
self.model.device
)
pert = drug_ohe.expand([1, self.drug_ohe.shape[1]])
drug_emb = self.model.compute_drug_embeddings_(pert).detach().cpu().numpy()
cond_emb = drug_emb
for cov_key in cov:
cond_emb += self.emb_covars[cov_key][cov[cov_key]]
cos_dist = cosine_distances(cond_emb, self.comb_emb.X)[0]
min_cos_dist = np.min(cos_dist)
cos_idx = np.argmin(cos_dist)
closest_cond_cos = self.comb_emb.obs.cov_pert[cos_idx]
eucl_dist = euclidean_distances(cond_emb, self.comb_emb.X)[0]
min_eucl_dist = np.min(eucl_dist)
eucl_idx = np.argmin(eucl_dist)
closest_cond_eucl = self.comb_emb.obs.cov_pert[eucl_idx]
return min_cos_dist, min_eucl_dist, closest_cond_cos, closest_cond_eucl
def predict(
self,
genes,
cov,
pert,
dose,
uncertainty=True,
return_anndata=True,
sample=False,
n_samples=1,
):
"""Predict values of control 'genes' conditions specified in df.
Parameters
----------
genes : np.array
Control cells.
cov: dict of lists
Provide a value for each covariate (eg. cell_type) as a dictionaty
for the queried uncertainty (e.g. cov_dict={'cell_type': 'A549'}).
pert: list
Perturbation for the queried uncertainty. In case of combinations the
format has to be 'pertA+pertB'
dose: list
String which contains the dose of the perturbation queried. In case
of combinations the format has to be 'doseA+doseB'
uncertainty: bool (default: True)
Compute uncertainties for the generated cells.
return_anndata : bool, optional (default: True)
Return embedding wrapped into anndata object.
sample : bool (default: False)
If sample is True, returns samples from gausssian distribution with
mean and variance estimated by the model. Otherwise, returns just
means and variances estimated by the model.
n_samples : int (default: 10)
Number of samples to sample if sampling is True.
Returns
-------
If return_anndata is True, returns anndata structure. Otherwise, returns
np.arrays for gene_means, gene_vars and a data frame for the corresponding
conditions df_obs.
"""
assert len(dose) == len(pert), "Check the length of pert, dose"
for cov_key in cov:
assert len(cov[cov_key]) == len(pert), "Check the length of covariates"
df = pd.concat(
[
pd.DataFrame({self.perturbation_key: pert, self.dose_key: dose}),
pd.DataFrame(cov),
],
axis=1,
)
self.model.eval()
num = genes.shape[0]
dim = genes.shape[1]
genes = torch.Tensor(genes).to(self.model.device)
gene_means_list = []
gene_vars_list = []
df_list = []
for i in range(len(df)):
comb_name = pert[i]
dose_name = dose[i]
covar_name = {}
for cov_key in cov:
covar_name[cov_key] = cov[cov_key][i]
drug_ohe = torch.Tensor(
self._get_drug_encoding(comb_name, doses=dose_name)
).to(self.model.device)
drugs = drug_ohe.expand([num, self.drug_ohe.shape[1]])
covars = []
for cov_key in self.covariate_keys:
covar_ohe = torch.Tensor(
self.covars_dict[cov_key][covar_name[cov_key]]
).to(self.model.device)
covars.append(covar_ohe.expand([num, covar_ohe.shape[0]]).clone())
gene_reconstructions = (
self.model.predict(genes, drugs, covars).cpu().clone().detach().numpy()
)
if sample:
df_list.append(
pd.DataFrame(
[df.loc[i].values] * num * n_samples, columns=df.columns
)
)
if self.args['loss_ae'] == 'gauss':
dist = Normal(
torch.Tensor(gene_reconstructions[:, :dim]),
torch.Tensor(gene_reconstructions[:, dim:]),
)
elif self.args['loss_ae'] == 'nb':
counts, logits = _convert_mean_disp_to_counts_logits(
torch.clamp(
torch.Tensor(gene_reconstructions[:, :dim]),
min=1e-8,
max=1e8,
),
torch.clamp(
torch.Tensor(gene_reconstructions[:, dim:]),
min=1e-8,
max=1e8,
)
)
dist = NegativeBinomial(
total_count=counts,
logits=logits
)
sampled_gexp = (
dist.sample(torch.Size([n_samples]))
.cpu()
.detach()
.numpy()
.reshape(-1, dim)
)
sampled_gexp[sampled_gexp < 0] = 0 #set negative values to 0, since gexp can't be negative
gene_means_list.append(sampled_gexp)
else:
df_list.append(
pd.DataFrame([df.loc[i].values] * num, columns=df.columns)
)
gene_means_list.append(gene_reconstructions[:, :dim])
if uncertainty:
(
cos_dist,
eucl_dist,
closest_cond_cos,
closest_cond_eucl,
) = self.compute_uncertainty(
cov=covar_name, pert=comb_name, dose=dose_name
)
df_list[-1] = df_list[-1].assign(
uncertainty_cosine=cos_dist,
uncertainty_euclidean=eucl_dist,
closest_cond_cosine=closest_cond_cos,
closest_cond_euclidean=closest_cond_eucl,
)
gene_vars_list.append(gene_reconstructions[:, dim:])
gene_means = np.concatenate(gene_means_list)
gene_vars = np.concatenate(gene_vars_list)
df_obs = pd.concat(df_list)
del df_list, gene_means_list, gene_vars_list
if return_anndata:
adata = sc.AnnData(gene_means)
adata.var_names = self.var_names
adata.obs = df_obs
if not sample:
adata.layers["variance"] = gene_vars
adata.obs.index = adata.obs.index.astype(str) # type fix
del gene_means, gene_vars, df_obs
return adata
else:
return gene_means, gene_vars, df_obs
def get_latent(
self,
genes,
cov,
pert,
dose,
return_anndata=True,
):
"""Get latent values of control 'genes' with conditions specified in df.
Parameters
----------
genes : np.array
Control cells.
cov: dict of lists
Provide a value for each covariate (eg. cell_type) as a dictionaty
for the queried uncertainty (e.g. cov_dict={'cell_type': 'A549'}).
pert: list
Perturbation for the queried uncertainty. In case of combinations the
format has to be 'pertA+pertB'
dose: list
String which contains the dose of the perturbation queried. In case
of combinations the format has to be 'doseA+doseB'
return_anndata : bool, optional (default: True)
Return embedding wrapped into anndata object.
Returns
-------
If return_anndata is True, returns anndata structure. Otherwise, returns
np.arrays for latent and a data frame for the corresponding
conditions df_obs.
"""
assert len(dose) == len(pert), "Check the length of pert, dose"
for cov_key in cov:
assert len(cov[cov_key]) == len(pert), "Check the length of covariates"
df = pd.concat(
[
pd.DataFrame({self.perturbation_key: pert, self.dose_key: dose}),
pd.DataFrame(cov),
],
axis=1,
)
self.model.eval()
num = genes.shape[0]
genes = torch.Tensor(genes).to(self.model.device)
latent_list = []
df_list = []
for i in range(len(df)):
comb_name = pert[i]
dose_name = dose[i]
covar_name = {}
for cov_key in cov:
covar_name[cov_key] = cov[cov_key][i]
drug_ohe = torch.Tensor(
self._get_drug_encoding(comb_name, doses=dose_name)
).to(self.model.device)
drugs = drug_ohe.expand([num, self.drug_ohe.shape[1]])
covars = []
for cov_key in self.covariate_keys:
covar_ohe = torch.Tensor(
self.covars_dict[cov_key][covar_name[cov_key]]
).to(self.model.device)
covars.append(covar_ohe.expand([num, covar_ohe.shape[0]]).clone())
_, latent_treated = self.model.predict(
genes,
drugs,
covars,
return_latent_treated=True,
)
latent_treated = latent_treated.cpu().clone().detach().numpy()
df_list.append(
pd.DataFrame([df.loc[i].values] * num, columns=df.columns)
)
latent_list.append(latent_treated)
latent = np.concatenate(latent_list)
df_obs = pd.concat(df_list)
del df_list
if return_anndata:
adata = sc.AnnData(latent)
adata.obs = df_obs
adata.obs.index = adata.obs.index.astype(str) # type fix
return adata
else:
return latent, df_obs
def get_response(
self,
genes_control=None,
doses=None,
contvar_min=None,
contvar_max=None,
n_points=10,
ncells_max=100,
perturbations=None,
control_name="test",
):
"""Decoded dose response data frame.
Parameters
----------
genes_control : np.array (deafult: None)
Genes for which to predict values. If None, take from 'test_control'
split in datasets.
doses : np.array (default: None)
Doses values. If None, default values will be generated on a grid:
n_points in range [contvar_min, contvar_max].
contvar_min : float (default: 0)
Minimum dose value to generate for default option.
contvar_max : float (default: 0)
Maximum dose value to generate for default option.
n_points : int (default: 100)
Number of dose points to generate for default option.
perturbations : list (default: None)
List of perturbations for dose response
Returns
-------
pd.DataFrame
of decoded response values of genes and average response.
"""
if genes_control is None:
genes_control = self.datasets["test"].subset_condition(control=True).genes
if contvar_min is None:
contvar_min = 0
if contvar_max is None:
contvar_max = self.max_dose
self.model.eval()
if doses is None:
doses = np.linspace(contvar_min, contvar_max, n_points)
if perturbations is None:
perturbations = self.unique_perts
response = pd.DataFrame(
columns=self.covariate_keys
+ [self.perturbation_key, self.dose_key, "response"]
+ list(self.var_names)
)
if ncells_max < len(genes_control):
ncells_max = min(ncells_max, len(genes_control))
idx = torch.LongTensor(
np.random.choice(range(len(genes_control)), ncells_max, replace=False)
)
genes_control = genes_control[idx]
j = 0
for covar_combo in self.emb_covars_combined:
cov_dict = {}
for i, cov_val in enumerate(covar_combo.split("_")):
cov_dict[self.covariate_keys[i]] = [cov_val]
print(cov_dict)
for _, drug in enumerate(perturbations):
if not (drug in self.datasets[control_name].subset_condition(control=True).ctrl_name):
for dose in doses:
# TODO handle covars
gene_means, _, _ = self.predict(
genes_control,
cov=cov_dict,
pert=[drug],
dose=[dose],
return_anndata=False,
)
predicted_data = np.mean(gene_means, axis=0).reshape(-1)
response.loc[j] = (
covar_combo.split("_")
+ [drug, dose, np.linalg.norm(predicted_data)]
+ list(predicted_data)
)
j += 1
return response
def get_response_reference(self, perturbations=None):
"""Computes reference values of the response.
Parameters
----------
dataset : CompPertDataset
The file location of the spreadsheet
perturbations : list (default: None)
List of perturbations for dose response
Returns
-------
pd.DataFrame
of decoded response values of genes and average response.
"""
if perturbations is None:
perturbations = self.unique_perts
reference_response_curve = pd.DataFrame(
columns=self.covariate_keys
+ [self.perturbation_key, self.dose_key, "split", "num_cells", "response"]
+ list(self.var_names)
)
dataset_ctr = self.datasets["training"].subset_condition(control=True)
i = 0
for split in ["training", "ood"]:
if split == 'ood':
dataset = self.datasets[split]
else:
dataset = self.datasets["training"].subset_condition(control=False)
for pert in self.seen_covars_perts[split]:
*covars, drug, dose_val = pert.split("_")
if drug in perturbations:
if not ("+" in dose_val):
dose = float(dose_val)
else:
dose = dose_val
idx = np.where((dataset.pert_categories == pert))[0]
if len(idx):
y_true = dataset.genes[idx, :].numpy().mean(axis=0)
reference_response_curve.loc[i] = (
covars
+ [drug, dose, split, len(idx), np.linalg.norm(y_true)]
+ list(y_true)
)
i += 1
reference_response_curve = reference_response_curve.replace(
"training_treated", "train"
)
return reference_response_curve
def get_response2D(
self,
perturbations,
covar,
genes_control=None,
doses=None,
contvar_min=None,
contvar_max=None,
n_points=10,
ncells_max=100,
#fixed_drugs="",
#fixed_doses="",
):
"""Decoded dose response data frame.
Parameters
----------
perturbations : list
List of length 2 of perturbations for dose response.
covar : dict
Name of a covariate for which to compute dose-response.
genes_control : np.array (deafult: None)
Genes for which to predict values. If None, take from 'test_control'
split in datasets.
doses : np.array (default: None)
Doses values. If None, default values will be generated on a grid:
n_points in range [contvar_min, contvar_max].
contvar_min : float (default: 0)
Minimum dose value to generate for default option.
contvar_max : float (default: 0)
Maximum dose value to generate for default option.
n_points : int (default: 100)
Number of dose points to generate for default option.
Returns
-------
pd.DataFrame
of decoded response values of genes and average response.
"""
assert len(perturbations) == 2, "You should provide a list of 2 perturbations."
if contvar_min is None:
contvar_min = self.min_dose
if contvar_max is None:
contvar_max = self.max_dose
self.model.eval()
# doses = torch.Tensor(np.linspace(contvar_min, contvar_max, n_points))
if doses is None:
doses = np.linspace(contvar_min, contvar_max, n_points)
# genes_control = dataset.genes[dataset.indices['control']]
if genes_control is None:
genes_control = self.datasets["test"].subset_condition(control=True).genes
ncells_max = min(ncells_max, len(genes_control))
idx = torch.LongTensor(np.random.choice(range(len(genes_control)), ncells_max))
genes_control = genes_control[idx]
response = pd.DataFrame(
columns=perturbations+["response"]+list(self.var_names)
)
drug = perturbations[0] + "+" + perturbations[1]
dose_vals = [f"{d[0]}+{d[1]}" for d in itertools.product(*[doses, doses])]
dose_comb = [list(d) for d in itertools.product(*[doses, doses])]
i = 0
if not (drug in self.datasets['training'].subset_condition(control=True).ctrl_name):
for dose in dose_vals:
gene_means, _, _ = self.predict(
genes_control,
cov=covar,
pert=[drug],# + fixed_drugs],
dose=[dose],# + fixed_doses],
return_anndata=False,
)
predicted_data = np.mean(gene_means, axis=0).reshape(-1)
response.loc[i] = (
dose_comb[i]
+ [np.linalg.norm(predicted_data)]
+ list(predicted_data)
)
i += 1
# i = 0
# if not (drug in ["Vehicle", "EGF", "unst", "control", "ctrl"]):
# for dose in dose_vals:
# gene_means, _, _ = self.predict(
# genes_control,
# cov=covar,
# pert=[drug + fixed_drugs],
# dose=[dose + fixed_doses],
# return_anndata=False,
# )
# predicted_data = np.mean(gene_means, axis=0).reshape(-1)
# response.loc[i] = (
# dose_comb[i]
# + [np.linalg.norm(predicted_data)]
# + list(predicted_data)
# )
# i += 1
return response
def evaluate_r2(self, dataset, genes_control, adata_random=None):
"""
Measures different quality metrics about an CPA `autoencoder`, when
tasked to translate some `genes_control` into each of the drug/cell_type
combinations described in `dataset`.
Considered metrics are R2 score about means and variances for all genes, as
well as R2 score about means and variances about differentially expressed
(_de) genes.
"""
self.model.eval()
scores = pd.DataFrame(
columns=self.covariate_keys
+ [
self.perturbation_key,
self.dose_key,
"R2_mean",
"R2_mean_DE",
"R2_var",
"R2_var_DE",
"model",
"num_cells",
]
)
num, dim = genes_control.size(0), genes_control.size(1)
total_cells = len(dataset)
icond = 0
for pert_category in np.unique(dataset.pert_categories):
# pert_category category contains: 'celltype_perturbation_dose' info
de_idx = np.where(
dataset.var_names.isin(np.array(dataset.de_genes[pert_category]))
)[0]
idx = np.where(dataset.pert_categories == pert_category)[0]
*covars, pert, dose = pert_category.split("_")
cov_dict = {}
for i, cov_key in enumerate(self.covariate_keys):
cov_dict[cov_key] = [covars[i]]
if len(idx) > 0:
mean_predict, var_predict, _ = self.predict(
genes_control,
cov=cov_dict,
pert=[pert],
dose=[dose],
return_anndata=False,
sample=False,
)
# estimate metrics only for reasonably-sized drug/cell-type combos
y_true = dataset.genes[idx, :].numpy()
# true means and variances
yt_m = y_true.mean(axis=0)
yt_v = y_true.var(axis=0)
# predicted means and variances
yp_m = mean_predict.mean(0)
yp_v = var_predict.mean(0)
#yp_v = np.var(mean_predict, axis=0)
mean_score = r2_score(yt_m, yp_m)
var_score = r2_score(yt_v, yp_v)
mean_score_de = r2_score(yt_m[de_idx], yp_m[de_idx])
var_score_de = r2_score(yt_v[de_idx], yp_v[de_idx])
scores.loc[icond] = pert_category.split("_") + [
mean_score,
mean_score_de,
var_score,
var_score_de,
"cpa",
len(idx),
]
icond += 1
if adata_random is not None:
yp_m_bl = np.mean(adata_random, axis=0)
yp_v_bl = np.var(adata_random, axis=0)
mean_score_bl = r2_score(yt_m, yp_m_bl)
var_score_bl = r2_score(yt_v, yp_v_bl)
mean_score_de_bl = r2_score(yt_m[de_idx], yp_m_bl[de_idx])
var_score_de_bl = r2_score(yt_v[de_idx], yp_v_bl[de_idx])
scores.loc[icond] = pert_category.split("_") + [
mean_score_bl,
mean_score_de_bl,
var_score_bl,
var_score_de_bl,
"baseline",
len(idx),
]
icond += 1
return scores
def get_reference_from_combo(perturbations_list, datasets, splits=["training", "ood"]):
"""
A simple function that produces a pd.DataFrame of individual
drugs-doses combinations used among the splits (for a fixed covariate).
"""
df_list = []
for split_name in splits:
full_dataset = datasets[split_name]
ref = {"num_cells": []}
for pp in perturbations_list:
ref[pp] = []
ndrugs = len(perturbations_list)
for pert_cat in np.unique(full_dataset.pert_categories):
_, pert, dose = pert_cat.split("_")
pert_list = pert.split("+")
if set(pert_list) == set(perturbations_list):
dose_list = dose.split("+")
ncells = len(
full_dataset.pert_categories[
full_dataset.pert_categories == pert_cat
]
)
for j in range(ndrugs):
ref[pert_list[j]].append(float(dose_list[j]))
ref["num_cells"].append(ncells)
print(pert, dose, ncells)
df = pd.DataFrame.from_dict(ref)
df["split"] = split_name
df_list.append(df)
return pd.concat(df_list)
def linear_interp(y1, y2, x1, x2, x):
a = (y1 - y2) / (x1 - x2)
b = y1 - a * x1
y = a * x + b
return y
def evaluate_r2_benchmark(cpa_api, datasets, pert_category, pert_category_list):
scores = pd.DataFrame(
columns=[
cpa_api.covars_key,
cpa_api.perturbation_key,
cpa_api.dose_key,
"R2_mean",
"R2_mean_DE",
"R2_var",
"R2_var_DE",
"num_cells",
"benchmark",
"method",
]
)
de_idx = np.where(
datasets["ood"].var_names.isin(
np.array(datasets["ood"].de_genes[pert_category])
)
)[0]
idx = np.where(datasets["ood"].pert_categories == pert_category)[0]
y_true = datasets["ood"].genes[idx, :].numpy()
# true means and variances
yt_m = y_true.mean(axis=0)
yt_v = y_true.var(axis=0)
icond = 0
if len(idx) > 0:
for pert_category_predict in pert_category_list:
if "+" in pert_category_predict:
pert1, pert2 = pert_category_predict.split("+")
idx_pred1 = np.where(datasets["training"].pert_categories == pert1)[0]
idx_pred2 = np.where(datasets["training"].pert_categories == pert2)[0]
y_pred1 = datasets["training"].genes[idx_pred1, :].numpy()
y_pred2 = datasets["training"].genes[idx_pred2, :].numpy()
x1 = float(pert1.split("_")[2])
x2 = float(pert2.split("_")[2])
x = float(pert_category.split("_")[2])
yp_m1 = y_pred1.mean(axis=0)
yp_m2 = y_pred2.mean(axis=0)
yp_v1 = y_pred1.var(axis=0)
yp_v2 = y_pred2.var(axis=0)
yp_m = linear_interp(yp_m1, yp_m2, x1, x2, x)
yp_v = linear_interp(yp_v1, yp_v2, x1, x2, x)
# yp_m = (y_pred1.mean(axis=0) + y_pred2.mean(axis=0))/2
# yp_v = (y_pred1.var(axis=0) + y_pred2.var(axis=0))/2
else:
idx_pred = np.where(
datasets["training"].pert_categories == pert_category_predict
)[0]
print(pert_category_predict, len(idx_pred))
y_pred = datasets["training"].genes[idx_pred, :].numpy()
# predicted means and variances
yp_m = y_pred.mean(axis=0)
yp_v = y_pred.var(axis=0)
mean_score = r2_score(yt_m, yp_m)
var_score = r2_score(yt_v, yp_v)
mean_score_de = r2_score(yt_m[de_idx], yp_m[de_idx])
var_score_de = r2_score(yt_v[de_idx], yp_v[de_idx])
scores.loc[icond] = pert_category.split("_") + [
mean_score,
mean_score_de,
var_score,
var_score_de,
len(idx),
pert_category_predict,
"benchmark",
]
icond += 1
return scores
================================================
FILE: cpa/data.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import warnings
import numpy as np
import torch
warnings.simplefilter(action="ignore", category=FutureWarning)
from typing import Union
import pandas as pd
import scanpy as sc
import scipy
from cpa.helper import rank_genes_groups
from sklearn.preprocessing import OneHotEncoder
def ranks_to_df(data, key="rank_genes_groups"):
"""Converts an `sc.tl.rank_genes_groups` result into a MultiIndex dataframe.
You can access various levels of the MultiIndex with `df.loc[[category]]`.
Params
------
data : `AnnData`
key : str (default: 'rank_genes_groups')
Field in `.uns` of data where `sc.tl.rank_genes_groups` result is
stored.
"""
d = data.uns[key]
dfs = []
for k in d.keys():
if k == "params":
continue
series = pd.DataFrame.from_records(d[k]).unstack()
series.name = k
dfs.append(series)
return pd.concat(dfs, axis=1)
def check_adata(adata, special_fields):
replaced = False
for sf in special_fields:
if sf in adata.obs:
flag = 0
for el in adata.obs[sf].values:
if "_" in str(el):
flag += 1
if flag:
print(
f"WARNING. Special characters ('_') were found in: '{sf}'.",
"They will be replaced with '-'.",
"Be careful, it may lead to errors downstream.",
)
adata.obs[sf] = [s.replace("_", "-") for s in adata.obs[sf].values]
replaced = True
return adata, replaced
indx = lambda a, i: a[i] if a is not None else None
class Dataset:
def __init__(
self,
data,
perturbation_key=None,
dose_key=None,
covariate_keys=None,
split_key="split",
control=None,
):
if type(data) == str:
data = sc.read(data)
#Assert that keys are present in the adata object
assert perturbation_key in data.obs.columns, f"Perturbation {perturbation_key} is missing in the provided adata"
for key in covariate_keys:
assert key in data.obs.columns, f"Covariate {key} is missing in the provided adata"
assert dose_key in data.obs.columns, f"Dose {dose_key} is missing in the provided adata"
assert split_key in data.obs.columns, f"Split {split_key} is missing in the provided adata"
assert not (split_key is None), "split_key can not be None"
#If covariate keys is empty list create dummy covariate
if len(covariate_keys) == 0:
print("Adding a dummy covariate...")
data.obs['dummy_cov'] = 'dummy_cov'
covariate_keys = ['dummy_cov']
self.perturbation_key = perturbation_key
self.dose_key = dose_key
if scipy.sparse.issparse(data.X):
self.genes = torch.Tensor(data.X.A)
else:
self.genes = torch.Tensor(data.X)
self.var_names = data.var_names
if isinstance(covariate_keys, str):
covariate_keys = [covariate_keys]
self.covariate_keys = covariate_keys
data, replaced = check_adata(
data, [perturbation_key, dose_key] + covariate_keys
)
for cov in covariate_keys:
if not (cov in data.obs):
data.obs[cov] = "unknown"
if split_key in data.obs:
pass
else:
print("Performing automatic train-test split with 0.25 ratio.")
from sklearn.model_selection import train_test_split
data.obs[split_key] = "train"
idx = list(range(len(data)))
idx_train, idx_test = train_test_split(
data.obs_names, test_size=0.25, random_state=42
)
data.obs[split_key].loc[idx_train] = "train"
data.obs[split_key].loc[idx_test] = "test"
if "control" in data.obs:
self.ctrl = data.obs["control"].values
else:
print(f"Assigning control values for {control}")
assert_msg = "Please provide a name for control condition."
assert not (control is None), assert_msg
data.obs["control"] = 0
if dose_key in data.obs:
pert, dose = control.split("_")
data.obs.loc[
(data.obs[perturbation_key] == pert) & (data.obs[dose_key] == dose),
"control",
] = 1
else:
pert = control
data.obs.loc[(data.obs[perturbation_key] == pert), "control"] = 1
self.ctrl = data.obs["control"].values
assert_msg = "Cells to assign as control not found! Please check the name of control variable."
assert sum(self.ctrl), assert_msg
print(f"Assigned {sum(self.ctrl)} control cells")
if perturbation_key is not None:
if dose_key is None:
raise ValueError(
f"A 'dose_key' is required when provided a 'perturbation_key'({perturbation_key})."
)
if not (dose_key in data.obs):
print(
f"Creating a default entrance for dose_key {dose_key}:",
"1.0 per perturbation",
)
dose_val = []
for i in range(len(data)):
pert = data.obs[perturbation_key].values[i].split("+")
dose_val.append("+".join(["1.0"] * len(pert)))
data.obs[dose_key] = dose_val
if not ("cov_drug_dose_name" in data.obs) or replaced:
print("Creating 'cov_drug_dose_name' field.")
cov_drug_dose_name = []
for i in range(len(data)):
comb_name = ""
for cov_key in self.covariate_keys:
comb_name += f"{data.obs[cov_key].values[i]}_"
comb_name += f"{data.obs[perturbation_key].values[i]}_{data.obs[dose_key].values[i]}"
cov_drug_dose_name.append(comb_name)
data.obs["cov_drug_dose_name"] = cov_drug_dose_name
if not ("rank_genes_groups_cov" in data.uns) or replaced:
print("Ranking genes for DE genes.")
rank_genes_groups(data, groupby="cov_drug_dose_name")
self.pert_categories = np.array(data.obs["cov_drug_dose_name"].values)
self.de_genes = data.uns["rank_genes_groups_cov"]
self.drugs_names = np.array(data.obs[perturbation_key].values)
self.dose_names = np.array(data.obs[dose_key].values)
# get unique drugs
drugs_names_unique = set()
for d in self.drugs_names:
[drugs_names_unique.add(i) for i in d.split("+")]
self.drugs_names_unique = np.array(list(drugs_names_unique))
# save encoder for a comparison with Mo's model
# later we need to remove this part
encoder_drug = OneHotEncoder(sparse=False)
encoder_drug.fit(self.drugs_names_unique.reshape(-1, 1))
# Store as attribute for molecular featurisation
self.encoder_drug = encoder_drug
self.perts_dict = dict(
zip(
self.drugs_names_unique,
encoder_drug.transform(self.drugs_names_unique.reshape(-1, 1)),
)
)
# get drug combinations
drugs = []
for i, comb in enumerate(self.drugs_names):
drugs_combos = encoder_drug.transform(
np.array(comb.split("+")).reshape(-1, 1)
)
dose_combos = str(data.obs[dose_key].values[i]).split("+")
for j, d in enumerate(dose_combos):
if j == 0:
drug_ohe = float(d) * drugs_combos[j]
else:
drug_ohe += float(d) * drugs_combos[j]
drugs.append(drug_ohe)
self.drugs = torch.Tensor(drugs)
atomic_ohe = encoder_drug.transform(self.drugs_names_unique.reshape(-1, 1))
self.drug_dict = {}
for idrug, drug in enumerate(self.drugs_names_unique):
i = np.where(atomic_ohe[idrug] == 1)[0][0]
self.drug_dict[i] = drug
else:
self.pert_categories = None
self.de_genes = None
self.drugs_names = None
self.dose_names = None
self.drugs_names_unique = None
self.perts_dict = None
self.drug_dict = None
self.drugs = None
if isinstance(covariate_keys, list) and covariate_keys:
if not len(covariate_keys) == len(set(covariate_keys)):
raise ValueError(f"Duplicate keys were given in: {covariate_keys}")
self.covariate_names = {}
self.covariate_names_unique = {}
self.covars_dict = {}
self.covariates = []
for cov in covariate_keys:
self.covariate_names[cov] = np.array(data.obs[cov].values)
self.covariate_names_unique[cov] = np.unique(self.covariate_names[cov])
names = self.covariate_names_unique[cov]
encoder_cov = OneHotEncoder(sparse=False)
encoder_cov.fit(names.reshape(-1, 1))
self.covars_dict[cov] = dict(
zip(list(names), encoder_cov.transform(names.reshape(-1, 1)))
)
names = self.covariate_names[cov]
self.covariates.append(
torch.Tensor(encoder_cov.transform(names.reshape(-1, 1))).float()
)
else:
self.covariate_names = None
self.covariate_names_unique = None
self.covars_dict = None
self.covariates = None
if perturbation_key is not None:
self.ctrl_name = list(
np.unique(data[data.obs["control"] == 1].obs[self.perturbation_key])
)
else:
self.ctrl_name = None
if self.covariates is not None:
self.num_covariates = [
len(names) for names in self.covariate_names_unique.values()
]
else:
self.num_covariates = [0]
self.num_genes = self.genes.shape[1]
self.num_drugs = len(self.drugs_names_unique) if self.drugs is not None else 0
self.is_control = data.obs["control"].values.astype(bool)
self.indices = {
"all": list(range(len(self.genes))),
"control": np.where(data.obs["control"] == 1)[0].tolist(),
"treated": np.where(data.obs["control"] != 1)[0].tolist(),
"train": np.where(data.obs[split_key] == "train")[0].tolist(),
"test": np.where(data.obs[split_key] == "test")[0].tolist(),
"ood": np.where(data.obs[split_key] == "ood")[0].tolist(),
}
def subset(self, split, condition="all"):
idx = list(set(self.indices[split]) & set(self.indices[condition]))
return SubDataset(self, idx)
def __getitem__(self, i):
return (
self.genes[i],
indx(self.drugs, i),
*[indx(cov, i) for cov in self.covariates],
)
def __len__(self):
return len(self.genes)
class SubDataset:
"""
Subsets a `Dataset` by selecting the examples given by `indices`.
"""
def __init__(self, dataset, indices):
self.perturbation_key = dataset.perturbation_key
self.dose_key = dataset.dose_key
self.covariate_keys = dataset.covariate_keys
self.perts_dict = dataset.perts_dict
self.covars_dict = dataset.covars_dict
self.genes = dataset.genes[indices]
self.drugs = indx(dataset.drugs, indices)
self.covariates = [indx(cov, indices) for cov in dataset.covariates]
self.drugs_names = indx(dataset.drugs_names, indices)
self.pert_categories = indx(dataset.pert_categories, indices)
self.covariate_names = {}
for cov in self.covariate_keys:
self.covariate_names[cov] = indx(dataset.covariate_names[cov], indices)
self.var_names = dataset.var_names
self.de_genes = dataset.de_genes
self.ctrl_name = indx(dataset.ctrl_name, 0)
self.num_covariates = dataset.num_covariates
self.num_genes = dataset.num_genes
self.num_drugs = dataset.num_drugs
self.is_control = dataset.is_control[indices]
def __getitem__(self, i):
return (
self.genes[i],
indx(self.drugs, i),
*[indx(cov, i) for cov in self.covariates],
)
def subset_condition(self, control=True):
idx = np.where(self.is_control == control)[0].tolist()
return SubDataset(self, idx)
def __len__(self):
return len(self.genes)
def load_dataset_splits(
data: str,
perturbation_key: Union[str, None],
dose_key: Union[str, None],
covariate_keys: Union[list, str, None],
split_key: str,
control: Union[str, None],
return_dataset: bool = False,
):
dataset = Dataset(
data, perturbation_key, dose_key, covariate_keys, split_key, control
)
splits = {
"training": dataset.subset("train", "all"),
"test": dataset.subset("test", "all"),
"ood": dataset.subset("ood", "all"),
}
if return_dataset:
return splits, dataset
else:
return splits
================================================
FILE: cpa/helper.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import warnings
import numpy as np
import pandas as pd
import scanpy as sc
from sklearn.metrics import r2_score
from scipy.sparse import issparse
from scipy.stats import wasserstein_distance
import torch
warnings.filterwarnings("ignore")
import sys
if not sys.warnoptions:
warnings.simplefilter("ignore")
warnings.simplefilter(action="ignore", category=FutureWarning)
def _convert_mean_disp_to_counts_logits(mu, theta, eps=1e-6):
r"""NB parameterizations conversion
Parameters
----------
mu :
mean of the NB distribution.
theta :
inverse overdispersion.
eps :
constant used for numerical log stability. (Default value = 1e-6)
Returns
-------
type
the number of failures until the experiment is stopped
and the success probability.
"""
assert (mu is None) == (
theta is None
), "If using the mu/theta NB parameterization, both parameters must be specified"
logits = (mu + eps).log() - (theta + eps).log()
total_count = theta
return total_count, logits
def rank_genes_groups_by_cov(
adata,
groupby,
control_group,
covariate,
pool_doses=False,
n_genes=50,
rankby_abs=True,
key_added="rank_genes_groups_cov",
return_dict=False,
):
"""
Function that generates a list of differentially expressed genes computed
separately for each covariate category, and using the respective control
cells as reference.
Usage example:
rank_genes_groups_by_cov(
adata,
groupby='cov_product_dose',
covariate_key='cell_type',
control_group='Vehicle_0'
)
Parameters
----------
adata : AnnData
AnnData dataset
groupby : str
Obs column that defines the groups, should be
cartesian product of covariate_perturbation_cont_var,
it is important that this format is followed.
control_group : str
String that defines the control group in the groupby obs
covariate : str
Obs column that defines the main covariate by which we
want to separate DEG computation (eg. cell type, species, etc.)
n_genes : int (default: 50)
Number of DEGs to include in the lists
rankby_abs : bool (default: True)
If True, rank genes by absolute values of the score, thus including
top downregulated genes in the top N genes. If False, the ranking will
have only upregulated genes at the top.
key_added : str (default: 'rank_genes_groups_cov')
Key used when adding the dictionary to adata.uns
return_dict : str (default: False)
Signals whether to return the dictionary or not
Returns
-------
Adds the DEG dictionary to adata.uns
If return_dict is True returns:
gene_dict : dict
Dictionary where groups are stored as keys, and the list of DEGs
are the corresponding values
"""
gene_dict = {}
cov_categories = adata.obs[covariate].unique()
for cov_cat in cov_categories:
print(cov_cat)
# name of the control group in the groupby obs column
control_group_cov = "_".join([cov_cat, control_group])
# subset adata to cells belonging to a covariate category
adata_cov = adata[adata.obs[covariate] == cov_cat]
# compute DEGs
sc.tl.rank_genes_groups(
adata_cov,
groupby=groupby,
reference=control_group_cov,
rankby_abs=rankby_abs,
n_genes=n_genes,
)
# add entries to dictionary of gene sets
de_genes = pd.DataFrame(adata_cov.uns["rank_genes_groups"]["names"])
for group in de_genes:
gene_dict[group] = de_genes[group].tolist()
adata.uns[key_added] = gene_dict
if return_dict:
return gene_dict
def rank_genes_groups(
adata,
groupby,
pool_doses=False,
n_genes=50,
rankby_abs=True,
key_added="rank_genes_groups_cov",
return_dict=False,
):
"""
Function that generates a list of differentially expressed genes computed
separately for each covariate category, and using the respective control
cells as reference.
Usage example:
rank_genes_groups_by_cov(
adata,
groupby='cov_product_dose',
covariate_key='cell_type',
control_group='Vehicle_0'
)
Parameters
----------
adata : AnnData
AnnData dataset
groupby : str
Obs column that defines the groups, should be
cartesian product of covariate_perturbation_cont_var,
it is important that this format is followed.
control_group : str
String that defines the control group in the groupby obs
covariate : str
Obs column that defines the main covariate by which we
want to separate DEG computation (eg. cell type, species, etc.)
n_genes : int (default: 50)
Number of DEGs to include in the lists
rankby_abs : bool (default: True)
If True, rank genes by absolute values of the score, thus including
top downregulated genes in the top N genes. If False, the ranking will
have only upregulated genes at the top.
key_added : str (default: 'rank_genes_groups_cov')
Key used when adding the dictionary to adata.uns
return_dict : str (default: False)
Signals whether to return the dictionary or not
Returns
-------
Adds the DEG dictionary to adata.uns
If return_dict is True returns:
gene_dict : dict
Dictionary where groups are stored as keys, and the list of DEGs
are the corresponding values
"""
covars_comb = []
for i in range(len(adata)):
cov = "_".join(adata.obs["cov_drug_dose_name"].values[i].split("_")[:-2])
covars_comb.append(cov)
adata.obs["covars_comb"] = covars_comb
gene_dict = {}
for cov_cat in np.unique(adata.obs["covars_comb"].values):
adata_cov = adata[adata.obs["covars_comb"] == cov_cat]
control_group_cov = (
adata_cov[adata_cov.obs["control"] == 1].obs[groupby].values[0]
)
# compute DEGs
sc.tl.rank_genes_groups(
adata_cov,
groupby=groupby,
reference=control_group_cov,
rankby_abs=rankby_abs,
n_genes=n_genes,
)
# add entries to dictionary of gene sets
de_genes = pd.DataFrame(adata_cov.uns["rank_genes_groups"]["names"])
for group in de_genes:
gene_dict[group] = de_genes[group].tolist()
adata.uns[key_added] = gene_dict
if return_dict:
return gene_dict
# def evaluate_r2_(adata, pred_adata, condition_key, sampled=False):
# r2_list = []
# if issparse(adata.X):
# adata.X = adata.X.A
# if issparse(pred_adata.X):
# pred_adata.X = pred_adata.X.A
# for cond in pred_adata.obs[condition_key].unique():
# adata_ = adata[adata.obs[condition_key] == cond]
# pred_adata_ = pred_adata[pred_adata.obs[condition_key] == cond]
# r2_mean = r2_score(adata_.X.mean(0), pred_adata_.X.mean(0))
# if sampled:
# r2_var = r2_score(adata_.X.var(0), pred_adata_.X.var(0))
# else:
# r2_var = r2_score(
# adata_.X.var(0),
# pred_adata_.layers['variance'].var(0)
# )
# r2_list.append(
# {
# 'condition': cond,
# 'r2_mean': r2_mean,
# 'r2_var': r2_var,
# }
# )
# r2_df = pd.DataFrame(r2_list).set_index('condition')
# return r2_df
def evaluate_r2_(adata, pred_adata, condition_key, sampled=False, de_genes_dict=None):
r2_list = []
if issparse(adata.X):
adata.X = adata.X.A
if issparse(pred_adata.X):
pred_adata.X = pred_adata.X.A
for cond in pred_adata.obs[condition_key].unique():
adata_ = adata[adata.obs[condition_key] == cond]
pred_adata_ = pred_adata[pred_adata.obs[condition_key] == cond]
r2_mean = r2_score(adata_.X.mean(0), pred_adata_.X.mean(0))
if sampled:
r2_var = r2_score(adata_.X.var(0), pred_adata_.X.var(0))
else:
r2_var = r2_score(
adata_.X.var(0),
pred_adata_.layers['variance'].var(0)
)
r2_list.append(
{
'condition': cond,
'r2_mean': r2_mean,
'r2_var': r2_var,
}
)
if de_genes_dict:
de_genes = de_genes_dict[cond]
sub_adata_ = adata_[:, de_genes]
sub_pred_adata_ = pred_adata_[:, de_genes]
r2_mean_deg = r2_score(sub_adata_.X.mean(0), sub_pred_adata_.X.mean(0))
if sampled:
r2_var_deg = r2_score(sub_adata_.X.var(0), sub_pred_adata_.X.var(0))
else:
r2_var_deg = r2_score(
sub_adata_.X.var(0),
sub_pred_adata_.layers['variance'].var(0)
)
r2_list[-1]['r2_mean_deg'] = r2_mean_deg
r2_list[-1]['r2_var_deg'] = r2_var_deg
r2_df = pd.DataFrame(r2_list).set_index('condition')
return r2_df
def evaluate_mmd(adata, pred_adata, condition_key, de_genes_dict=None):
mmd_list = []
for cond in pred_adata.obs[condition_key].unique():
adata_ = adata[adata.obs[condition_key] == cond].copy()
pred_adata_ = pred_adata[pred_adata.obs[condition_key] == cond].copy()
if issparse(adata_.X):
adata_.X = adata_.X.A
if issparse(pred_adata_.X):
pred_adata_.X = pred_adata_.X.A
mmd = mmd_loss_calc(torch.Tensor(adata_.X), torch.Tensor(pred_adata_.X))
mmd_list.append(
{
'condition': cond,
'mmd': mmd.detach().cpu().numpy()
}
)
if de_genes_dict:
de_genes = de_genes_dict[cond]
sub_adata_ = adata_[:, de_genes]
sub_pred_adata_ = pred_adata_[:, de_genes]
mmd_deg = mmd_loss_calc(torch.Tensor(sub_adata_.X), torch.Tensor(sub_pred_adata_.X))
mmd_list[-1]['mmd_deg'] = mmd_deg.detach().cpu().numpy()
mmd_df = pd.DataFrame(mmd_list).set_index('condition')
return mmd_df
def evaluate_emd(adata, pred_adata, condition_key, de_genes_dict=None):
emd_list = []
for cond in pred_adata.obs[condition_key].unique():
adata_ = adata[adata.obs[condition_key] == cond].copy()
pred_adata_ = pred_adata[pred_adata.obs[condition_key] == cond].copy()
if issparse(adata_.X):
adata_.X = adata_.X.A
if issparse(pred_adata_.X):
pred_adata_.X = pred_adata_.X.A
wd = []
for i, _ in enumerate(adata_.var_names):
wd.append(
wasserstein_distance(torch.Tensor(adata_.X[:, i]), torch.Tensor(pred_adata_.X[:, i]))
)
emd_list.append(
{
'condition': cond,
'emd': np.mean(wd)
}
)
if de_genes_dict:
de_genes = de_genes_dict[cond]
sub_adata_ = adata_[:, de_genes]
sub_pred_adata_ = pred_adata_[:, de_genes]
wd_deg = []
for i, _ in enumerate(sub_adata_.var_names):
wd_deg.append(
wasserstein_distance(torch.Tensor(sub_adata_.X[:, i]), torch.Tensor(sub_pred_adata_.X[:, i]))
)
emd_list[-1]['emd_deg'] = np.mean(wd_deg)
emd_df = pd.DataFrame(emd_list).set_index('condition')
return emd_df
def pairwise_distance(x, y):
x = x.view(x.shape[0], x.shape[1], 1)
y = torch.transpose(y, 0, 1)
output = torch.sum((x - y) ** 2, 1)
output = torch.transpose(output, 0, 1)
return output
def gaussian_kernel_matrix(x, y, alphas):
"""Computes multiscale-RBF kernel between x and y.
Parameters
----------
x: torch.Tensor
Tensor with shape [batch_size, z_dim].
y: torch.Tensor
Tensor with shape [batch_size, z_dim].
alphas: Tensor
Returns
-------
Returns the computed multiscale-RBF kernel between x and y.
"""
dist = pairwise_distance(x, y).contiguous()
dist_ = dist.view(1, -1)
alphas = alphas.view(alphas.shape[0], 1)
beta = 1. / (2. * alphas)
s = torch.matmul(beta, dist_)
return torch.sum(torch.exp(-s), 0).view_as(dist)
def mmd_loss_calc(source_features, target_features):
"""Initializes Maximum Mean Discrepancy(MMD) between source_features and target_features.
- Gretton, Arthur, et al. "A Kernel Two-Sample Test". 2012.
Parameters
----------
source_features: torch.Tensor
Tensor with shape [batch_size, z_dim]
target_features: torch.Tensor
Tensor with shape [batch_size, z_dim]
Returns
-------
Returns the computed MMD between x and y.
"""
alphas = [
1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1, 5, 10, 15, 20, 25, 30, 35, 100,
1e3, 1e4, 1e5, 1e6
]
alphas = torch.autograd.Variable(torch.FloatTensor(alphas)).to(device=source_features.device)
cost = torch.mean(gaussian_kernel_matrix(source_features, source_features, alphas))
cost += torch.mean(gaussian_kernel_matrix(target_features, target_features, alphas))
cost -= 2 * torch.mean(gaussian_kernel_matrix(source_features, target_features, alphas))
return cost
================================================
FILE: cpa/model.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from http.client import RemoteDisconnected
import json
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
class NBLoss(torch.nn.Module):
def __init__(self):
super(NBLoss, self).__init__()
def forward(self, mu, y, theta, eps=1e-8):
"""Negative binomial negative log-likelihood. It assumes targets `y` with n
rows and d columns, but estimates `yhat` with n rows and 2d columns.
The columns 0:d of `yhat` contain estimated means, the columns d:2*d of
`yhat` contain estimated variances. This module assumes that the
estimated mean and inverse dispersion are positive---for numerical
stability, it is recommended that the minimum estimated variance is
greater than a small number (1e-3).
Parameters
----------
yhat: Tensor
Torch Tensor of reeconstructed data.
y: Tensor
Torch Tensor of ground truth data.
eps: Float
numerical stability constant.
"""
if theta.ndimension() == 1:
# In this case, we reshape theta for broadcasting
theta = theta.view(1, theta.size(0))
log_theta_mu_eps = torch.log(theta + mu + eps)
res = (
theta * (torch.log(theta + eps) - log_theta_mu_eps)
+ y * (torch.log(mu + eps) - log_theta_mu_eps)
+ torch.lgamma(y + theta)
- torch.lgamma(theta)
- torch.lgamma(y + 1)
)
res = _nan2inf(res)
return -torch.mean(res)
def _nan2inf(x):
return torch.where(torch.isnan(x), torch.zeros_like(x) + np.inf, x)
class MLP(torch.nn.Module):
"""
A multilayer perceptron with ReLU activations and optional BatchNorm.
"""
def __init__(self, sizes, batch_norm=True, last_layer_act="linear"):
super(MLP, self).__init__()
layers = []
for s in range(len(sizes) - 1):
layers += [
torch.nn.Linear(sizes[s], sizes[s + 1]),
torch.nn.BatchNorm1d(sizes[s + 1])
if batch_norm and s < len(sizes) - 2
else None,
torch.nn.ReLU(),
]
layers = [l for l in layers if l is not None][:-1]
self.activation = last_layer_act
if self.activation == "linear":
pass
elif self.activation == "ReLU":
self.relu = torch.nn.ReLU()
else:
raise ValueError("last_layer_act must be one of 'linear' or 'ReLU'")
self.network = torch.nn.Sequential(*layers)
def forward(self, x):
if self.activation == "ReLU":
x = self.network(x)
dim = x.size(1) // 2
return torch.cat((self.relu(x[:, :dim]), x[:, dim:]), dim=1)
return self.network(x)
class GeneralizedSigmoid(torch.nn.Module):
"""
Sigmoid, log-sigmoid or linear functions for encoding dose-response for
drug perurbations.
"""
def __init__(self, dim, device, nonlin="sigmoid"):
"""Sigmoid modeling of continuous variable.
Params
------
nonlin : str (default: logsigm)
One of logsigm, sigm.
"""
super(GeneralizedSigmoid, self).__init__()
self.nonlin = nonlin
self.beta = torch.nn.Parameter(
torch.ones(1, dim, device=device), requires_grad=True
)
self.bias = torch.nn.Parameter(
torch.zeros(1, dim, device=device), requires_grad=True
)
def forward(self, x):
if self.nonlin == "logsigm":
c0 = self.bias.sigmoid()
return (torch.log1p(x) * self.beta + self.bias).sigmoid() - c0
elif self.nonlin == "sigm":
c0 = self.bias.sigmoid()
return (x * self.beta + self.bias).sigmoid() - c0
else:
return x
def one_drug(self, x, i):
if self.nonlin == "logsigm":
c0 = self.bias[0][i].sigmoid()
return (torch.log1p(x) * self.beta[0][i] + self.bias[0][i]).sigmoid() - c0
elif self.nonlin == "sigm":
c0 = self.bias[0][i].sigmoid()
return (x * self.beta[0][i] + self.bias[0][i]).sigmoid() - c0
else:
return x
class CPA(torch.nn.Module):
"""
Our main module, the CPA autoencoder
"""
def __init__(
self,
num_genes,
num_drugs,
num_covariates,
device="cuda",
seed=0,
patience=5,
loss_ae="gauss",
doser_type="mlp",
decoder_activation="linear",
hparams="",
):
super(CPA, self).__init__()
# set generic attributes
self.num_genes = num_genes
self.num_drugs = num_drugs
self.num_covariates = num_covariates
self.device = device
self.seed = seed
self.loss_ae = loss_ae
# early-stopping
self.patience = patience
self.best_score = -1e3
self.patience_trials = 0
# set hyperparameters
self.set_hparams_(hparams)
# set models
self.encoder = MLP(
[num_genes]
+ [self.hparams["autoencoder_width"]] * self.hparams["autoencoder_depth"]
+ [self.hparams["dim"]]
)
self.decoder = MLP(
[self.hparams["dim"]]
+ [self.hparams["autoencoder_width"]] * self.hparams["autoencoder_depth"]
+ [num_genes * 2],
last_layer_act=decoder_activation,
)
self.adversary_drugs = MLP(
[self.hparams["dim"]]
+ [self.hparams["adversary_width"]] * self.hparams["adversary_depth"]
+ [num_drugs]
)
self.loss_adversary_drugs = torch.nn.BCEWithLogitsLoss()
self.doser_type = doser_type
if doser_type == "mlp":
self.dosers = torch.nn.ModuleList()
for _ in range(num_drugs):
self.dosers.append(
MLP(
[1]
+ [self.hparams["dosers_width"]] * self.hparams["dosers_depth"]
+ [1],
batch_norm=False,
)
)
else:
self.dosers = GeneralizedSigmoid(num_drugs, self.device, nonlin=doser_type)
if self.num_covariates == [0]:
pass
else:
assert 0 not in self.num_covariates
self.adversary_covariates = []
self.loss_adversary_covariates = []
self.covariates_embeddings = (
[]
) # TODO: Continue with checking that dict assignment is possible via covaraites names and if dict are possible to use in optimisation
for num_covariate in self.num_covariates:
self.adversary_covariates.append(
MLP(
[self.hparams["dim"]]
+ [self.hparams["adversary_width"]]
* self.hparams["adversary_depth"]
+ [num_covariate]
)
)
self.loss_adversary_covariates.append(torch.nn.CrossEntropyLoss())
self.covariates_embeddings.append(
torch.nn.Embedding(num_covariate, self.hparams["dim"])
)
self.covariates_embeddings = torch.nn.Sequential(
*self.covariates_embeddings
)
self.drug_embeddings = torch.nn.Embedding(self.num_drugs, self.hparams["dim"])
# losses
if self.loss_ae == "nb":
self.loss_autoencoder = NBLoss()
elif self.loss_ae == 'gauss':
self.loss_autoencoder = nn.GaussianNLLLoss()
self.iteration = 0
self.to(self.device)
# optimizers
has_drugs = self.num_drugs > 0
has_covariates = self.num_covariates[0] > 0
get_params = lambda model, cond: list(model.parameters()) if cond else []
_parameters = (
get_params(self.encoder, True)
+ get_params(self.decoder, True)
+ get_params(self.drug_embeddings, has_drugs)
)
for emb in self.covariates_embeddings:
_parameters.extend(get_params(emb, has_covariates))
self.optimizer_autoencoder = torch.optim.Adam(
_parameters,
lr=self.hparams["autoencoder_lr"],
weight_decay=self.hparams["autoencoder_wd"],
)
_parameters = get_params(self.adversary_drugs, has_drugs)
for adv in self.adversary_covariates:
_parameters.extend(get_params(adv, has_covariates))
self.optimizer_adversaries = torch.optim.Adam(
_parameters,
lr=self.hparams["adversary_lr"],
weight_decay=self.hparams["adversary_wd"],
)
if has_drugs:
self.optimizer_dosers = torch.optim.Adam(
self.dosers.parameters(),
lr=self.hparams["dosers_lr"],
weight_decay=self.hparams["dosers_wd"],
)
# learning rate schedulers
self.scheduler_autoencoder = torch.optim.lr_scheduler.StepLR(
self.optimizer_autoencoder, step_size=self.hparams["step_size_lr"]
)
self.scheduler_adversary = torch.optim.lr_scheduler.StepLR(
self.optimizer_adversaries, step_size=self.hparams["step_size_lr"]
)
if has_drugs:
self.scheduler_dosers = torch.optim.lr_scheduler.StepLR(
self.optimizer_dosers, step_size=self.hparams["step_size_lr"]
)
self.history = {"epoch": [], "stats_epoch": []}
def set_hparams_(self, hparams):
"""
Set hyper-parameters to default values or values fixed by user for those
hyper-parameters specified in the JSON string `hparams`.
"""
self.hparams = {
"dim": 128,
"dosers_width": 128,
"dosers_depth": 2,
"dosers_lr": 4e-3,
"dosers_wd": 1e-7,
"autoencoder_width": 128,
"autoencoder_depth": 3,
"adversary_width": 64,
"adversary_depth": 2,
"reg_adversary": 60,
"penalty_adversary": 60,
"autoencoder_lr": 3e-4,
"adversary_lr": 3e-4,
"autoencoder_wd": 4e-7,
"adversary_wd": 4e-7,
"adversary_steps": 3,
"batch_size": 256,
"step_size_lr": 45,
}
# the user may fix some hparams
if hparams != "":
if isinstance(hparams, str):
self.hparams.update(json.loads(hparams))
else:
self.hparams.update(hparams)
return self.hparams
def move_inputs_(self, genes, drugs, covariates):
"""
Move minibatch tensors to CPU/GPU.
"""
if genes.device.type != self.device:
genes = genes.to(self.device)
if drugs is not None:
drugs = drugs.to(self.device)
if covariates is not None:
covariates = [cov.to(self.device) for cov in covariates]
return (genes, drugs, covariates)
def compute_drug_embeddings_(self, drugs):
"""
Compute sum of drug embeddings, each of them multiplied by its
dose-response curve.
"""
if self.doser_type == "mlp":
doses = []
for d in range(drugs.size(1)):
this_drug = drugs[:, d].view(-1, 1)
doses.append(self.dosers[d](this_drug).sigmoid() * this_drug.gt(0))
return torch.cat(doses, 1) @ self.drug_embeddings.weight
else:
return self.dosers(drugs) @ self.drug_embeddings.weight
def predict(
self,
genes,
drugs,
covariates,
return_latent_basal=False,
return_latent_treated=False,
):
"""
Predict "what would have the gene expression `genes` been, had the
cells in `genes` with cell types `cell_types` been treated with
combination of drugs `drugs`.
"""
genes, drugs, covariates = self.move_inputs_(genes, drugs, covariates)
if self.loss_ae == 'nb':
genes = torch.log1p(genes)
latent_basal = self.encoder(genes)
latent_treated = latent_basal
if self.num_drugs > 0:
latent_treated = latent_treated + self.compute_drug_embeddings_(drugs)
if self.num_covariates[0] > 0:
for i, emb in enumerate(self.covariates_embeddings):
emb = emb.to(self.device)
latent_treated = latent_treated + emb(
covariates[i].argmax(1)
) #argmax because OHE
gene_reconstructions = self.decoder(latent_treated)
if self.loss_ae == 'gauss':
# convert variance estimates to a positive value in [1e-3, \infty)
dim = gene_reconstructions.size(1) // 2
gene_means = gene_reconstructions[:, :dim]
gene_vars = F.softplus(gene_reconstructions[:, dim:]).add(1e-3)
#gene_vars = gene_reconstructions[:, dim:].exp().add(1).log().add(1e-3)
if self.loss_ae == 'nb':
gene_means = F.softplus(gene_means).add(1e-3)
#gene_reconstructions[:, :dim] = torch.clamp(gene_reconstructions[:, :dim], min=1e-4, max=1e4)
#gene_reconstructions[:, dim:] = torch.clamp(gene_reconstructions[:, dim:], min=1e-4, max=1e4)
gene_reconstructions = torch.cat([gene_means, gene_vars], dim=1)
if return_latent_basal:
if return_latent_treated:
return gene_reconstructions, latent_basal, latent_treated
else:
return gene_reconstructions, latent_basal
if return_latent_treated:
return gene_reconstructions, latent_treated
return gene_reconstructions
def early_stopping(self, score):
"""
Decays the learning rate, and possibly early-stops training.
"""
self.scheduler_autoencoder.step()
self.scheduler_adversary.step()
self.scheduler_dosers.step()
if score > self.best_score:
self.best_score = score
self.patience_trials = 0
else:
self.patience_trials += 1
return self.patience_trials > self.patience
def update(self, genes, drugs, covariates):
"""
Update CPA's parameters given a minibatch of genes, drugs, and
cell types.
"""
genes, drugs, covariates = self.move_inputs_(genes, drugs, covariates)
gene_reconstructions, latent_basal = self.predict(
genes,
drugs,
covariates,
return_latent_basal=True,
)
dim = gene_reconstructions.size(1) // 2
gene_means = gene_reconstructions[:, :dim]
gene_vars = gene_reconstructions[:, dim:]
reconstruction_loss = self.loss_autoencoder(gene_means, genes, gene_vars)
adversary_drugs_loss = torch.tensor([0.0], device=self.device)
if self.num_drugs > 0:
adversary_drugs_predictions = self.adversary_drugs(latent_basal)
adversary_drugs_loss = self.loss_adversary_drugs(
adversary_drugs_predictions, drugs.gt(0).float()
)
adversary_covariates_loss = torch.tensor(
[0.0], device=self.device
)
if self.num_covariates[0] > 0:
adversary_covariate_predictions = []
for i, adv in enumerate(self.adversary_covariates):
adv = adv.to(self.device)
adversary_covariate_predictions.append(adv(latent_basal))
adversary_covariates_loss += self.loss_adversary_covariates[i](
adversary_covariate_predictions[-1], covariates[i].argmax(1)
)
# two place-holders for when adversary is not executed
adversary_drugs_penalty = torch.tensor([0.0], device=self.device)
adversary_covariates_penalty = torch.tensor([0.0], device=self.device)
if self.iteration % self.hparams["adversary_steps"]:
def compute_gradients(output, input):
grads = torch.autograd.grad(output, input, create_graph=True)
grads = grads[0].pow(2).mean()
return grads
if self.num_drugs > 0:
adversary_drugs_penalty = compute_gradients(
adversary_drugs_predictions.sum(), latent_basal
)
if self.num_covariates[0] > 0:
adversary_covariates_penalty = torch.tensor([0.0], device=self.device)
for pred in adversary_covariate_predictions:
adversary_covariates_penalty += compute_gradients(
pred.sum(), latent_basal
) # TODO: Adding up tensor sum, is that right?
self.optimizer_adversaries.zero_grad()
(
adversary_drugs_loss
+ self.hparams["penalty_adversary"] * adversary_drugs_penalty
+ adversary_covariates_loss
+ self.hparams["penalty_adversary"] * adversary_covariates_penalty
).backward()
self.optimizer_adversaries.step()
else:
self.optimizer_autoencoder.zero_grad()
if self.num_drugs > 0:
self.optimizer_dosers.zero_grad()
(
reconstruction_loss
- self.hparams["reg_adversary"] * adversary_drugs_loss
- self.hparams["reg_adversary"] * adversary_covariates_loss
).backward()
self.optimizer_autoencoder.step()
if self.num_drugs > 0:
self.optimizer_dosers.step()
self.iteration += 1
return {
"loss_reconstruction": reconstruction_loss.item(),
"loss_adv_drugs": adversary_drugs_loss.item(),
"loss_adv_covariates": adversary_covariates_loss.item(),
"penalty_adv_drugs": adversary_drugs_penalty.item(),
"penalty_adv_covariates": adversary_covariates_penalty.item(),
}
@classmethod
def defaults(self):
"""
Returns the list of default hyper-parameters for CPA
"""
return self.set_hparams_(self, "")
================================================
FILE: cpa/plotting.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from collections import defaultdict
import matplotlib.font_manager
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from adjustText import adjust_text
from sklearn.decomposition import KernelPCA
from sklearn.metrics import r2_score
from sklearn.metrics.pairwise import cosine_similarity
FONT_SIZE = 10
font = {"size": FONT_SIZE}
matplotlib.rc("font", **font)
matplotlib.rc("ytick", labelsize=FONT_SIZE)
matplotlib.rc("xtick", labelsize=FONT_SIZE)
class CPAVisuals:
"""
A wrapper for automatic plotting CompPert latent embeddings and dose-response
curve. Sets up prefix for all files and default dictionaries for atomic
perturbations and cell types.
"""
def __init__(
self,
cpa,
fileprefix=None,
perts_palette=None,
covars_palette=None,
plot_params={"fontsize": None},
):
"""
Parameters
----------
cpa : CompPertAPI
Variable from API class.
fileprefix : str, optional (default: None)
Prefix (with path) to the filename to save all embeddings in a
standartized manner. If None, embeddings are not saved to file.
perts_palette : dict (default: None)
Dictionary of colors for the embeddings of perturbations. Keys
correspond to perturbations and values to their colors. If None,
default dicitonary will be set up.
covars_palette : dict (default: None)
Dictionary of colors for the embeddings of covariates. Keys
correspond to covariates and values to their colors. If None,
default dicitonary will be set up.
"""
self.fileprefix = fileprefix
self.perturbation_key = cpa.perturbation_key
self.dose_key = cpa.dose_key
self.covariate_keys = cpa.covariate_keys
self.measured_points = cpa.measured_points
self.unique_perts = cpa.unique_perts
self.unique_covars = cpa.unique_covars
if perts_palette is None:
self.perts_palette = dict(
zip(self.unique_perts, get_palette(len(self.unique_perts)))
)
else:
self.perts_palette = perts_palette
if covars_palette is None:
self.covars_palette = {}
for cov in self.unique_covars:
self.covars_palette[cov] = dict(
zip(
self.unique_covars[cov],
get_palette(len(self.unique_covars[cov]), palette_name="tab10"),
)
)
else:
self.covars_palette = covars_palette
if plot_params["fontsize"] is None:
self.fontsize = FONT_SIZE
else:
self.fontsize = plot_params["fontsize"]
def plot_latent_embeddings(
self,
emb,
titlename="Example",
kind="perturbations",
palette=None,
labels=None,
dimred="KernelPCA",
filename=None,
show_text=True,
):
"""
Parameters
----------
emb : np.array
Multi-dimensional embedding of perturbations or covariates.
titlename : str, optional (default: 'Example')
Title.
kind : int, optional, optional (default: 'perturbations')
Specify if this is embedding of perturbations, covariates or some
other. If it is perturbations or covariates, it will use default
saved dictionaries for colors.
palette : dict, optional (default: None)
If embedding of kind not perturbations or covariates, the user can
specify color dictionary for the embedding.
labels : list, optional (default: None)
Labels for the embeddings.
dimred : str, optional (default: 'KernelPCA')
Dimensionality reduction method for plotting low dimensional
representations. Options: 'KernelPCA', 'UMAPpre', 'UMAPcos', None.
If None, uses first 2 dimensions of the embedding.
filename : str (default: None)
Name of the file to save the plot. If None, will automatically
generate name from prefix file.
"""
if filename is None:
if self.fileprefix is None:
filename = None
file_name_similarity = None
else:
filename = f"{self.fileprefix}_emebdding.png"
file_name_similarity = f"{self.fileprefix}_emebdding_similarity.png"
else:
file_name_similarity = filename.split(".")[0] + "_similarity.png"
if labels is None:
if kind == "perturbations":
palette = self.perts_palette
labels = self.unique_perts
elif kind in self.unique_covars:
palette = self.covars_palette[kind]
labels = self.unique_covars[kind]
if len(emb) < 2:
print(f"Embedding contains only {len(emb)} vectors. Not enough to plot.")
else:
plot_embedding(
fast_dimred(emb, method=dimred),
labels,
show_lines=True,
show_text=show_text,
col_dict=palette,
title=titlename,
file_name=filename,
fontsize=self.fontsize,
)
plot_similarity(
emb,
labels,
col_dict=palette,
fontsize=self.fontsize,
file_name=file_name_similarity,
)
def plot_contvar_response2D(
self,
df_response2D,
df_ref=None,
levels=15,
figsize=(4, 4),
xlims=(0, 1.03),
ylims=(0, 1.03),
palette="coolwarm",
response_name="response",
title_name=None,
fontsize=None,
postfix="",
filename=None,
alpha=0.4,
sizes=(40, 160),
logdose=False,
):
"""
Parameters
----------
df_response2D : pd.DataFrame
Data frame with responses of combinations with columns=(dose1, dose2,
response).
levels: int, optional (default: 15)
Number of levels for contour plot.
response_name : str (default: 'response')
Name of column in df_response to plot as response.
alpha: float (default: 0.4)
Transparency of the background contour.
figsize: tuple (default: (4,4))
Size of the figure in inches.
palette : dict, optional (default: None)
Colors dictionary for perturbations to plot.
title_name : str, optional (default: None)
Title for the plot.
postfix : str, optional (defualt: '')
Postfix to add to the output file name to save the model.
filename : str, optional (defualt: None)
Name of the file to save the plot. If None, will automatically
generate name from prefix file.
logdose: bool (default: False)
If True, dose values will be log10. 0 values will be mapped to
minumum value -1,e.g.
if smallest non-zero dose was 0.001, 0 will be mapped to -4.
"""
sns.set_style("white")
if (filename is None) and not (self.fileprefix is None):
filename = f"{self.fileprefix}_{postfix}response2D.png"
if fontsize is None:
fontsize = self.fontsize
x_name, y_name = df_response2D.columns[:2]
x = df_response2D[x_name].values
y = df_response2D[y_name].values
if logdose:
x = log10_with0(x)
y = log10_with0(y)
z = df_response2D[response_name].values
n = int(np.sqrt(len(x)))
X = x.reshape(n, n)
Y = y.reshape(n, n)
Z = z.reshape(n, n)
fig, ax = plt.subplots(figsize=figsize)
CS = ax.contourf(X, Y, Z, cmap=palette, levels=levels, alpha=alpha)
CS = ax.contour(X, Y, Z, levels=15, cmap=palette)
ax.clabel(CS, inline=1, fontsize=fontsize)
ax.set(xlim=(0, 1), ylim=(0, 1))
ax.axis("equal")
ax.axis("square")
ax.yaxis.set_tick_params(labelsize=fontsize)
ax.xaxis.set_tick_params(labelsize=fontsize)
ax.set_xlabel(x_name, fontsize=fontsize, fontweight="bold")
ax.set_ylabel(y_name, fontsize=fontsize, fontweight="bold")
ax.set_xlim(xlims)
ax.set_ylim(ylims)
# sns.despine(left=False, bottom=False, right=True)
sns.despine()
if not (df_ref is None):
sns.scatterplot(
x=x_name,
y=y_name,
hue="split",
size="num_cells",
sizes=sizes,
alpha=1.0,
palette={"train": "#000000", "training": "#000000", "ood": "#e41a1c"},
data=df_ref,
ax=ax,
)
ax.legend_.remove()
ax.set_title(title_name, fontweight="bold", fontsize=fontsize)
plt.tight_layout()
if filename:
save_to_file(fig, filename)
def plot_contvar_response(
self,
df_response,
response_name="response",
var_name=None,
df_ref=None,
palette=None,
title_name=None,
postfix="",
xlabelname=None,
filename=None,
logdose=False,
fontsize=None,
measured_points=None,
bbox=(1.35, 1.0),
figsize=(7.0, 4.0),
):
"""
Parameters
----------
df_response : pd.DataFrame
Data frame of responses.
response_name : str (default: 'response')
Name of column in df_response to plot as response.
var_name : str, optional (default: None)
Name of conditioning variable, e.g. could correspond to covariates.
df_ref : pd.DataFrame, optional (default: None)
Reference values. Fields for plotting should correspond to
df_response.
palette : dict, optional (default: None)
Colors dictionary for perturbations to plot.
title_name : str, optional (default: None)
Title for the plot.
postfix : str, optional (defualt: '')
Postfix to add to the output file name to save the model.
filename : str, optional (defualt: None)
Name of the file to save the plot. If None, will automatically
generate name from prefix file.
logdose: bool (default: False)
If True, dose values will be log10. 0 values will be mapped to
minumum value -1,e.g.
if smallest non-zero dose was 0.001, 0 will be mapped to -4.
figsize: tuple (default: (7., 4.))
Size of output figure
"""
if (filename is None) and not (self.fileprefix is None):
filename = f"{self.fileprefix}_{postfix}response.png"
if fontsize is None:
fontsize = self.fontsize
if logdose:
dose_name = f"log10-{self.dose_key}"
df_response[dose_name] = log10_with0(df_response[self.dose_key].values)
if not (df_ref is None):
df_ref[dose_name] = log10_with0(df_ref[self.dose_key].values)
else:
dose_name = self.dose_key
if var_name is None:
if len(self.unique_covars) > 1:
var_name = self.covars_key
else:
var_name = self.perturbation_key
if palette is None:
if var_name == self.perturbation_key:
palette = self.perts_palette
elif var_name in self.covariate_keys:
palette = self.covars_palette[var_name]
plot_dose_response(
df_response,
dose_name,
var_name,
xlabelname=xlabelname,
df_ref=df_ref,
response_name=response_name,
title_name=title_name,
use_ref_response=(not (df_ref is None)),
col_dict=palette,
plot_vertical=False,
f1=figsize[0],
f2=figsize[1],
fname=filename,
logscale=measured_points,
measured_points=measured_points,
bbox=bbox,
fontsize=fontsize,
figformat="png",
)
def plot_scatter(
self,
df,
x_axis,
y_axis,
hue=None,
size=None,
style=None,
figsize=(4.5, 4.5),
title=None,
palette=None,
filename=None,
alpha=0.75,
sizes=(30, 90),
text_dict=None,
postfix="",
fontsize=14,
):
sns.set_style("white")
if (filename is None) and not (self.fileprefix is None):
filename = f"{self.fileprefix}_scatter{postfix}.png"
if fontsize is None:
fontsize = self.fontsize
fig = plt.figure(figsize=figsize)
ax = plt.gca()
sns.scatterplot(
x=x_axis,
y=y_axis,
hue=hue,
style=style,
size=size,
sizes=sizes,
alpha=alpha,
palette=palette,
data=df,
)
ax.legend_.remove()
ax.set_xlabel(x_axis, fontsize=fontsize)
ax.set_ylabel(y_axis, fontsize=fontsize)
ax.xaxis.set_tick_params(labelsize=fontsize)
ax.yaxis.set_tick_params(labelsize=fontsize)
ax.set_title(title)
if not (text_dict is None):
texts = []
for label in text_dict.keys():
texts.append(
ax.text(
text_dict[label][0],
text_dict[label][1],
label,
fontsize=fontsize,
)
)
adjust_text(
texts, arrowprops=dict(arrowstyle="-", color="black", lw=0.1), ax=ax
)
plt.tight_layout()
if filename:
save_to_file(fig, filename)
def log10_with0(x):
mx = np.min(x[x > 0])
x[x == 0] = mx / 10
return np.log10(x)
def get_palette(n_colors, palette_name="Set1"):
try:
palette = sns.color_palette(palette_name)
except:
print("Palette not found. Using default palette tab10")
palette = sns.color_palette()
while len(palette) < n_colors:
palette += palette
return palette
def fast_dimred(emb, method="KernelPCA"):
"""
Takes high dimensional embeddings and produces a 2-dimensional representation
for plotting.
emb: np.array
Embeddings matrix.
method: str (default: 'KernelPCA')
Method for dimensionality reduction: KernelPCA, UMAPpre, UMAPcos, tSNE.
If None return first 2 dimensions of the embedding vector.
"""
if method is None:
return emb[:, :2]
elif method == "KernelPCA":
similarity_matrix = cosine_similarity(emb)
np.fill_diagonal(similarity_matrix, 1.0)
X = KernelPCA(n_components=2, kernel="precomputed").fit_transform(
similarity_matrix
)
else:
raise NotImplementedError
return X
def plot_dose_response(
df,
contvar_key,
perturbation_key,
df_ref=None,
response_name="response",
use_ref_response=False,
palette=None,
col_dict=None,
fontsize=8,
measured_points=None,
interpolate=True,
f1=7,
f2=3.0,
bbox=(1.35, 1.0),
ref_name="origin",
title_name="None",
plot_vertical=True,
fname=None,
logscale=None,
xlabelname=None,
figformat="png",
):
"""Plotting decoding of the response with respect to dose.
Params
------
df : `DataFrame`
Table with columns=[perturbation_key, contvar_key, response_name].
The last column is always "response".
contvar_key : str
Name of the column in df for values to use for x axis.
perturbation_key : str
Name of the column in df for the perturbation or covariate to plot.
response_name: str (default: response)
Name of the column in df for values to use for y axis.
df_ref : `DataFrame` (default: None)
Table with the same columns as in df to plot ground_truth or another
condition for comparison. Could
also be used to just extract reference values for x-axis.
use_ref_response : bool (default: False)
A flag indicating if to use values for y axis from df_ref (True) or j
ust to extract reference values for x-axis.
col_dict : dictionary (default: None)
Dictionary with colors for each value in perturbation_key.
bbox : tuple (default: (1.35, 1.))
Coordinates to adjust the legend.
plot_vertical : boolean (default: False)
Flag if to plot reference values for x axis from df_ref dataframe.
f1 : float (default: 7.0))
Width in inches for the plot.
f2 : float (default: 3.0))
Hight in inches for the plot.
fname : str (default: None)
Name of the file to export the plot. The name comes without format
extension.
format : str (default: png)
Format for the file to export the plot.
"""
sns.set_style("white")
if use_ref_response and not (df_ref is None):
df[ref_name] = "predictions"
df_ref[ref_name] = "observations"
if interpolate:
df_plt = pd.concat([df, df_ref])
else:
df_plt = df
else:
df_plt = df
df_plt = df_plt.reset_index()
atomic_drugs = np.unique(df[perturbation_key].values)
if palette is None:
current_palette = get_palette(len(list(atomic_drugs)))
if col_dict is None:
col_dict = dict(zip(list(atomic_drugs), current_palette))
fig = plt.figure(figsize=(f1, f2))
ax = plt.gca()
if use_ref_response:
sns.lineplot(
x=contvar_key,
y=response_name,
palette=col_dict,
hue=perturbation_key,
style=ref_name,
dashes=[(1, 0), (2, 1)],
legend="full",
style_order=["predictions", "observations"],
data=df_plt,
ax=ax,
)
df_ref = df_ref.replace("training_treated", "train")
sns.scatterplot(
x=contvar_key,
y=response_name,
hue="split",
size="num_cells",
sizes=(10, 100),
alpha=1.0,
palette={"train": "#000000", "training": "#000000", "ood": "#e41a1c"},
data=df_ref,
ax=ax,
)
sns.despine()
ax.legend_.remove()
else:
sns.lineplot(
x=contvar_key,
y=response_name,
palette=col_dict,
hue=perturbation_key,
data=df_plt,
ax=ax,
)
ax.legend(loc="upper right", bbox_to_anchor=bbox, fontsize=fontsize)
sns.despine()
if not (title_name is None):
ax.set_title(title_name, fontsize=fontsize, fontweight="bold")
ax.grid("off")
if xlabelname is None:
ax.set_xlabel(contvar_key, fontsize=fontsize)
else:
ax.set_xlabel(xlabelname, fontsize=fontsize)
ax.set_ylabel(f"{response_name}", fontsize=fontsize)
ax.xaxis.set_tick_params(labelsize=fontsize)
ax.yaxis.set_tick_params(labelsize=fontsize)
if not (logscale is None):
ax.set_xticks(np.log10(logscale))
ax.set_xticklabels(logscale, rotation=90)
if not (df_ref is None):
atomic_drugs = np.unique(df_ref[perturbation_key].values)
for drug in atomic_drugs:
x = df_ref[df_ref[perturbation_key] == drug][contvar_key].values
m1 = np.min(df[df[perturbation_key] == drug][response_name].values)
m2 = np.max(df[df[perturbation_key] == drug][response_name].values)
if plot_vertical:
for x_dot in x:
ax.plot(
[x_dot, x_dot],
[m1, m2],
":",
color="black",
linewidth=0.5,
alpha=0.5,
)
fig.tight_layout()
if fname:
plt.savefig(f"{fname}.{figformat}", format=figformat, dpi=600)
return fig
def plot_uncertainty_comb_dose(
cpa_api,
cov,
pert,
N=11,
metric="cosine",
measured_points=None,
cond_key="condition",
figsize=(4, 4),
vmin=None,
vmax=None,
sizes=(40, 160),
df_ref=None,
xlims=(0, 1.03),
ylims=(0, 1.03),
fixed_drugs="",
fixed_doses="",
title="",
filename=None,
):
"""Plotting uncertainty for a single perturbation at a dose range for a
particular covariate.
Params
------
cpa_api
Api object for the model class.
cov : dict
Name of covariate.
pert : str
Name of the perturbation.
N : int
Number of dose values.
metric: str (default: 'cosine')
Metric to evaluate uncertainty.
measured_points : dict (default: None)
A dicitionary of dictionaries. Per each covariate a dictionary with
observed doses per perturbation, e.g. {'covar1': {'pert1':
[0.1, 0.5, 1.0], 'pert2': [0.3]}
cond_key : str (default: 'condition')
Name of the variable to use for plotting.
filename : str (default: None)
Full path to the file to export the plot. File extension should be
included.
Returns
-------
pd.DataFrame of uncertainty estimations.
"""
cov_name = "_".join([cov[cov_key] for cov_key in cpa_api.covariate_keys])
df_list = []
for i in np.round(np.linspace(0, 1, N), decimals=2):
for j in np.round(np.linspace(0, 1, N), decimals=2):
df_list.append(
{
"covariates": cov_name,
"condition": pert + fixed_drugs,
"dose_val": str(i) + "+" + str(j) + fixed_doses,
}
)
df_pred = pd.DataFrame(df_list)
uncert_cos = []
uncert_eucl = []
closest_cond_cos = []
closest_cond_eucl = []
for i in range(df_pred.shape[0]):
(
uncert_cos_,
uncert_eucl_,
closest_cond_cos_,
closest_cond_eucl_,
) = cpa_api.compute_uncertainty(
cov=cov, pert=df_pred.iloc[i]["condition"], dose=df_pred.iloc[i]["dose_val"]
)
uncert_cos.append(uncert_cos_)
uncert_eucl.append(uncert_eucl_)
closest_cond_cos.append(closest_cond_cos_)
closest_cond_eucl.append(closest_cond_eucl_)
df_pred["uncertainty_cosine"] = uncert_cos
df_pred["uncertainty_eucl"] = uncert_eucl
df_pred["closest_cond_cos"] = closest_cond_cos
df_pred["closest_cond_eucl"] = closest_cond_eucl
doses = df_pred.dose_val.apply(lambda x: x.split("+"))
X = np.array(doses.apply(lambda x: x[0]).astype(float)).reshape(N, N)
Y = np.array(doses.apply(lambda x: x[1]).astype(float)).reshape(N, N)
Z = np.array(df_pred[f"uncertainty_{metric}"].values.astype(float)).reshape(N, N)
fig, ax = plt.subplots(1, 1, figsize=figsize)
CS = ax.contourf(X, Y, Z, cmap="coolwarm", levels=20, alpha=1, vmin=vmin, vmax=vmax)
ax.set_xlabel(pert.split("+")[0], fontweight="bold")
ax.set_ylabel(pert.split("+")[1], fontweight="bold")
if not (df_ref is None):
sns.scatterplot(
x=pert.split("+")[0],
y=pert.split("+")[1],
hue="split",
size="num_cells",
sizes=sizes,
alpha=1.0,
palette={"train": "#000000", "training": "#000000", "ood": "#e41a1c"},
data=df_ref,
ax=ax,
)
ax.legend_.remove()
if measured_points:
ticks = measured_points[cov_name][pert]
xticks = [float(x.split("+")[0]) for x in ticks]
yticks = [float(x.split("+")[1]) for x in ticks]
ax.set_xticks(xticks)
ax.set_xticklabels(xticks, rotation=90)
ax.set_yticks(yticks)
fig.colorbar(CS)
sns.despine()
ax.axis("equal")
ax.axis("square")
ax.set_xlim(xlims)
ax.set_ylim(ylims)
ax.set_title(title, fontsize=10, fontweight='bold')
plt.tight_layout()
if filename:
plt.savefig(filename, dpi=600)
return df_pred
def plot_uncertainty_dose(
cpa_api,
cov,
pert,
N=11,
metric="cosine",
measured_points=None,
cond_key="condition",
log=False,
min_dose=None,
filename=None,
):
"""Plotting uncertainty for a single perturbation at a dose range for a
particular covariate.
Params
------
cpa_api
Api object for the model class.
cov : str
Name of covariate.
pert : str
Name of the perturbation.
N : int
Number of dose values.
metric: str (default: 'cosine')
Metric to evaluate uncertainty.
measured_points : dict (default: None)
A dicitionary of dictionaries. Per each covariate a dictionary with
observed doses per perturbation, e.g. {'covar1': {'pert1':
[0.1, 0.5, 1.0], 'pert2': [0.3]}
cond_key : str (default: 'condition')
Name of the variable to use for plotting.
log : boolean (default: False)
A flag if to plot on a log scale.
min_dose : float (default: None)
Minimum dose for the uncertainty estimate.
filename : str (default: None)
Full path to the file to export the plot. File extension should be included.
Returns
-------
pd.DataFrame of uncertainty estimations.
"""
df_list = []
if log:
if min_dose is None:
min_dose = 1e-3
N_val = np.round(np.logspace(np.log10(min_dose), np.log10(1), N), decimals=10)
else:
if min_dose is None:
min_dose = 0
N_val = np.round(np.linspace(min_dose, 1.0, N), decimals=3)
cov_name = "_".join([cov[cov_key] for cov_key in cpa_api.covariate_keys])
for i in N_val:
df_list.append({"covariates": cov_name, "condition": pert, "dose_val": repr(i)})
df_pred = pd.DataFrame(df_list)
uncert_cos = []
uncert_eucl = []
closest_cond_cos = []
closest_cond_eucl = []
for i in range(df_pred.shape[0]):
(
uncert_cos_,
uncert_eucl_,
closest_cond_cos_,
closest_cond_eucl_,
) = cpa_api.compute_uncertainty(
cov=cov, pert=df_pred.iloc[i]["condition"], dose=df_pred.iloc[i]["dose_val"]
)
uncert_cos.append(uncert_cos_)
uncert_eucl.append(uncert_eucl_)
closest_cond_cos.append(closest_cond_cos_)
closest_cond_eucl.append(closest_cond_eucl_)
df_pred["uncertainty_cosine"] = uncert_cos
df_pred["uncertainty_eucl"] = uncert_eucl
df_pred["closest_cond_cos"] = closest_cond_cos
df_pred["closest_cond_eucl"] = closest_cond_eucl
x = df_pred.dose_val.values.astype(float)
y = df_pred[f"uncertainty_{metric}"].values.astype(float)
fig, ax = plt.subplots(1, 1)
ax.plot(x, y)
ax.set_xlabel(pert)
ax.set_ylabel("Uncertainty")
ax.set_title(cov_name)
if log:
ax.set_xscale("log")
if measured_points:
ticks = measured_points[cov_name][pert]
ax.set_xticks(ticks)
ax.set_xticklabels(ticks, rotation=90)
else:
plt.draw()
ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
sns.despine()
plt.tight_layout()
if filename:
plt.savefig(filename)
return df_pred
def save_to_file(fig, file_name, file_format=None):
if file_format is None:
if file_name.split(".")[-1] in ["png", "pdf"]:
file_format = file_name.split(".")[-1]
savename = file_name
else:
file_format = "pdf"
savename = f"{file_name}.{file_format}"
else:
savename = file_name
fig.savefig(savename, format=file_format, dpi=600)
print(f"Saved file to: {savename}")
def plot_embedding(
emb,
labels=None,
col_dict=None,
title=None,
show_lines=False,
show_text=False,
show_legend=True,
axis_equal=True,
circle_size=40,
circe_transparency=1.0,
line_transparency=0.8,
line_width=1.0,
fontsize=9,
fig_width=4,
fig_height=4,
file_name=None,
file_format=None,
labels_name=None,
width_ratios=[7, 1],
bbox=(1.3, 0.7),
):
sns.set_style("white")
# create data structure suitable for embedding
df = pd.DataFrame(emb, columns=["dim1", "dim2"])
if not (labels is None):
if labels_name is None:
labels_name = "labels"
df[labels_name] = labels
fig = plt.figure(figsize=(fig_width, fig_height))
ax = plt.gca()
sns.despine(left=False, bottom=False, right=True)
if (col_dict is None) and not (labels is None):
col_dict = get_colors(labels)
sns.scatterplot(
x="dim1",
y="dim2",
hue=labels_name,
palette=col_dict,
alpha=circe_transparency,
edgecolor="none",
s=circle_size,
data=df,
ax=ax,
)
try:
ax.legend_.remove()
except:
pass
if show_lines:
for i in range(len(emb)):
if col_dict is None:
ax.plot(
[0, emb[i, 0]],
[0, emb[i, 1]],
alpha=line_transparency,
linewidth=line_width,
c=None,
)
else:
ax.plot(
[0, emb[i, 0]],
[0, emb[i, 1]],
alpha=line_transparency,
linewidth=line_width,
c=col_dict[labels[i]],
)
if show_text and not (labels is None):
texts = []
labels = np.array(labels)
unique_labels = np.unique(labels)
for label in unique_labels:
idx_label = np.where(labels == label)[0]
texts.append(
ax.text(
np.mean(emb[idx_label, 0]),
np.mean(emb[idx_label, 1]),
label,
#fontsize=fontsize,
)
)
adjust_text(
texts, arrowprops=dict(arrowstyle="-", color="black", lw=0.1), ax=ax
)
if axis_equal:
ax.axis("equal")
ax.axis("square")
if title:
ax.set_title(title, fontweight="bold")
ax.set_xlabel("dim1"),# fontsize=fontsize)
ax.set_ylabel("dim2"),# fontsize=fontsize)
#ax.xaxis.set_tick_params(labelsize=fontsize)
#ax.yaxis.set_tick_params(labelsize=fontsize)
plt.tight_layout()
if file_name:
save_to_file(fig, file_name, file_format)
return plt
def get_colors(labels, palette=None, palette_name=None):
n_colors = len(labels)
if palette is None:
palette = get_palette(n_colors, palette_name)
col_dict = dict(zip(labels, palette[:n_colors]))
return col_dict
def plot_similarity(
emb,
labels=None,
col_dict=None,
fig_width=4,
fig_height=4,
cmap="coolwarm",
fmt="png",
fontsize=7,
file_format=None,
file_name=None,
):
# first we take construct similarity matrix
# add another similarity
similarity_matrix = cosine_similarity(emb)
df = pd.DataFrame(
similarity_matrix,
columns=labels,
index=labels,
)
if col_dict is None:
col_dict = get_colors(labels)
network_colors = pd.Series(df.columns, index=df.columns).map(col_dict)
sns_plot = sns.clustermap(
df,
cmap=cmap,
center=0,
row_colors=network_colors,
col_colors=network_colors,
mask=False,
metric="euclidean",
figsize=(fig_height, fig_width),
vmin=-1,
vmax=1,
fmt=file_format,
)
sns_plot.ax_heatmap.xaxis.set_tick_params(labelsize=fontsize)
sns_plot.ax_heatmap.yaxis.set_tick_params(labelsize=fontsize)
sns_plot.ax_heatmap.axis("equal")
sns_plot.cax.yaxis.set_tick_params(labelsize=fontsize)
if file_name:
save_to_file(sns_plot, file_name, file_format)
from scipy import sparse, stats
from sklearn.metrics import r2_score
def mean_plot(
adata,
pred,
condition_key,
exp_key,
path_to_save="./reg_mean.pdf",
gene_list=None,
deg_list=None,
show=False,
title=None,
verbose=False,
x_coeff=0.30,
y_coeff=0.8,
fontsize=11,
R2_type="R2",
figsize=(3.5, 3.5),
**kwargs,
):
"""
Plots mean matching.
# Parameters
adata: `~anndata.AnnData`
Contains real v
pred: `~anndata.AnnData`
Contains predicted values.
condition_key: Str
adata.obs key to look for x-axis and y-axis condition
exp_key: Str
Condition in adata.obs[condition_key] to be ploted
path_to_save: basestring
Path to save the plot.
gene_list: list
List of gene names to be plotted.
deg_list: list
List of DEGs to compute R2
show: boolean
if True plots the figure
Verbose: boolean
If true prints the value
title: Str
Title of the plot
x_coeff: float
Shifts R2 text horizontally by x_coeff
y_coeff: float
Shifts R2 text vertically by y_coeff
show: bool
if `True`: will show to the plot after saving it.
fontsize: int
Font size for R2 texts
R2_type: Str
How to compute R2 value, should be either Pearson R2 or R2 (sklearn)
Returns:
Calluated R2 values
"""
r2_types = ["R2", "Pearson R2"]
if R2_type not in r2_types:
raise ValueError("R2 caclulation should be one of" + str(r2_types))
if sparse.issparse(adata.X):
adata.X = adata.X.A
if sparse.issparse(pred.X):
pred.X = pred.X.A
diff_genes = deg_list
real = adata[adata.obs[condition_key] == exp_key]
pred = pred[pred.obs[condition_key] == exp_key]
if diff_genes is not None:
if hasattr(diff_genes, "tolist"):
diff_genes = diff_genes.tolist()
real_diff = adata[:, diff_genes][adata.obs[condition_key] == exp_key]
pred_diff = pred[:, diff_genes][pred.obs[condition_key] == exp_key]
x_diff = np.average(pred_diff.X, axis=0)
y_diff = np.average(real_diff.X, axis=0)
if R2_type == "R2":
r2_diff = r2_score(y_diff, x_diff)
if R2_type == "Pearson R2":
m, b, pearson_r_diff, p_value_diff, std_err_diff = stats.linregress(
y_diff, x_diff
)
r2_diff = pearson_r_diff ** 2
if verbose:
print(f"Top {len(diff_genes)} DEGs var: ", r2_diff)
x = np.average(pred.X, axis=0)
y = np.average(real.X, axis=0)
if R2_type == "R2":
r2 = r2_score(y, x)
if R2_type == "Pearson R2":
m, b, pearson_r, p_value, std_err = stats.linregress(y, x)
r2 = pearson_r ** 2
if verbose:
print("All genes var: ", r2)
df = pd.DataFrame({f"{exp_key}_true": x, f"{exp_key}_pred": y})
plt.figure(figsize=figsize)
ax = sns.regplot(x=f"{exp_key}_true", y=f"{exp_key}_pred", data=df)
ax.tick_params(labelsize=fontsize)
if "range" in kwargs:
start, stop, step = kwargs.get("range")
ax.set_xticks(np.arange(start, stop, step))
ax.set_yticks(np.arange(start, stop, step))
ax.set_xlabel("true", fontsize=fontsize)
ax.set_ylabel("pred", fontsize=fontsize)
if gene_list is not None:
for i in gene_list:
j = adata.var_names.tolist().index(i)
x_bar = x[j]
y_bar = y[j]
plt.text(x_bar, y_bar, i, fontsize=fontsize, color="black")
plt.plot(x_bar, y_bar, "o", color="red", markersize=5)
if title is None:
plt.title(f"", fontsize=fontsize, fontweight="bold")
else:
plt.title(title, fontsize=fontsize, fontweight="bold")
ax.text(
max(x) - max(x) * x_coeff,
max(y) - y_coeff * max(y),
r"$\mathrm{R^2_{\mathrm{\mathsf{all\ genes}}}}$= " + f"{r2:.2f}",
fontsize=fontsize,
)
if diff_genes is not None:
ax.text(
max(x) - max(x) * x_coeff,
max(y) - (y_coeff + 0.15) * max(y),
r"$\mathrm{R^2_{\mathrm{\mathsf{DEGs}}}}$= " + f"{r2_diff:.2f}",
fontsize=fontsize,
)
plt.savefig(f"{path_to_save}", bbox_inches="tight", dpi=100)
if show:
plt.show()
plt.close()
if diff_genes is not None:
return r2, r2_diff
else:
return r2
def plot_r2_matrix(pred, adata, de_genes=None, **kwds):
"""Plots a pairwise R2 heatmap between predicted and control conditions.
Params
------
pred : `AnnData`
Must have the field `cov_drug_dose_name`
adata : `AnnData`
Original gene expression data, with the field `cov_drug_dose_name`.
de_genes : `dict`
Dictionary of de_genes, where the keys
match the categories in `cov_drug_dose_name`
"""
r2s_mean = defaultdict(list)
r2s_var = defaultdict(list)
conditions = pred.obs["cov_drug_dose_name"].cat.categories
for cond in conditions:
if de_genes:
degs = de_genes[cond]
y_pred = pred[:, degs][pred.obs["cov_drug_dose_name"] == cond].X
y_true_adata = adata[:, degs]
else:
y_pred = pred[pred.obs["cov_drug_dose_name"] == cond].X
y_true_adata = adata
# calculate r2 between pairwise
for cond_real in conditions:
y_true = y_true_adata[
y_true_adata.obs["cov_drug_dose_name"] == cond_real
].X.toarray()
r2s_mean[cond_real].append(
r2_score(y_true.mean(axis=0), y_pred.mean(axis=0))
)
r2s_var[cond_real].append(r2_score(y_true.var(axis=0), y_pred.var(axis=0)))
for r2_dict in [r2s_mean, r2s_var]:
r2_df = pd.DataFrame.from_dict(r2_dict, orient="index")
r2_df.columns = conditions
plt.figure(figsize=(5, 5))
p = sns.heatmap(
data=r2_df,
vmin=max(r2_df.min(0).min(), 0),
cmap="Blues",
cbar=False,
annot=True,
fmt=".2f",
annot_kws={"fontsize": 5},
**kwds,
)
plt.xticks(fontsize=6)
plt.yticks(fontsize=6)
plt.xlabel("y_true")
plt.ylabel("y_pred")
plt.show()
def arrange_history(history):
print(history.keys())
class CPAHistory:
"""
A wrapper for automatic plotting history of CPA model..
"""
def __init__(self, cpa_api, fileprefix=None):
"""
Parameters
----------
cpa_api : dict
cpa api instance
fileprefix : str, optional (default: None)
Prefix (with path) to the filename to save all embeddings in a
standartized manner. If None, embeddings are not saved to file.
"""
self.history = cpa_api.history
self.time = self.history["elapsed_time_min"]
self.losses_list = [
"loss_reconstruction",
"loss_adv_drugs",
"loss_adv_covariates",
]
self.penalties_list = ["penalty_adv_drugs", "penalty_adv_covariates"]
subset_keys = ["epoch"] + self.losses_list + self.penalties_list
self.losses = pd.DataFrame(
dict((k, self.history[k]) for k in subset_keys if k in self.history)
)
self.header = ["mean", "mean_DE", "var", "var_DE"]
self.eval_metrics = False
if 'perturbation disentanglement' in list (self.history): #check that metrics were evaluated
self.eval_metrics = True
self.metrics = pd.DataFrame(columns=["epoch", "split"] + self.header)
for split in ["training", "test", "ood"]:
df_split = pd.DataFrame(np.array(self.history[split]), columns=self.header)
df_split["split"] = split
df_split["epoch"] = self.history["stats_epoch"]
self.metrics = pd.concat([self.metrics, df_split])
self.covariate_names = list(cpa_api.datasets["training"].covariate_names)
self.disent = pd.DataFrame(
dict(
(k, self.history[k])
for k in
['perturbation disentanglement']
+ [f'{cov} disentanglement' for cov in self.covariate_names]
if k in self.history
)
)
self.disent["epoch"] = self.history["stats_epoch"]
self.fileprefix = fileprefix
def print_time(self):
print(f"Computation time: {self.time:.0f} min")
def plot_losses(self, filename=None):
"""
Parameters
----------
filename : str (default: None)
Name of the file to save the plot. If None, will automatically
generate name from prefix file.
"""
if filename is None:
if self.fileprefix is None:
filename = None
else:
filename = f"{self.fileprefix}_history_losses.png"
fig, ax = plt.subplots(1, 4, sharex=True, sharey=False, figsize=(12, 3))
i = 0
for i in range(4):
if i < 3:
ax[i].plot(
self.losses["epoch"].values, self.losses[self.losses_list[i]].values
)
ax[i].set_title(self.losses_list[i], fontweight="bold")
else:
ax[i].plot(
self.losses["epoch"].values, self.losses[self.penalties_list].values
)
ax[i].set_title("Penalties", fontweight="bold")
sns.despine()
plt.tight_layout()
if filename:
save_to_file(fig, filename)
def plot_r2_metrics(self, epoch_min=0, filename=None):
"""
Parameters
----------
epoch_min : int (default: 0)
Epoch from which to show metrics history plot. Done for readability.
filename : str (default: None)
Name of the file to save the plot. If None, will automatically
generate name from prefix file.
"""
assert self.eval_metrics == True, 'The evaluation metrics were not computed'
if filename is None:
if self.fileprefix is None:
filename = None
else:
filename = f"{self.fileprefix}_history_metrics.png"
df = self.metrics.melt(id_vars=["epoch", "split"])
col_dict = dict(
zip(["training", "test", "ood"], ["#377eb8", "#4daf4a", "#e41a1c"])
)
fig, axs = plt.subplots(2, 2, sharex=True, sharey=False, figsize=(7, 5))
ax = plt.gca()
i = 0
for i1 in range(2):
for i2 in range(2):
sns.lineplot(
data=df[
(df["variable"] == self.header[i]) & (df["epoch"] > epoch_min)
],
x="epoch",
y="value",
palette=col_dict,
hue="split",
ax=axs[i1, i2],
)
axs[i1, i2].set_title(self.header[i], fontweight="bold")
i += 1
sns.despine()
plt.tight_layout()
if filename:
save_to_file(fig, filename)
def plot_disentanglement_metrics(self, epoch_min=0, filename=None):
"""
Parameters
----------
epoch_min : int (default: 0)
Epoch from which to show metrics history plot. Done for readability.
filename : str (default: None)
Name of the file to save the plot. If None, will automatically
generate name from prefix file.
"""
assert self.eval_metrics == True, 'The evaluation metrics were not computed'
if filename is None:
if self.fileprefix is None:
filename = None
else:
filename = f"{self.fileprefix}_history_metrics.png"
fig, axs = plt.subplots(
1,
1+len(self.covariate_names),
sharex=True,
sharey=False,
figsize=(2 + 5*(len(self.covariate_names)), 3)
)
ax = plt.gca()
sns.lineplot(
data=self.disent[self.disent["epoch"] > epoch_min],
x="epoch",
y="perturbation disentanglement",
legend=False,
ax=axs[0],
)
axs[0].set_title("perturbation disent", fontweight="bold")
for i, cov in enumerate(self.covariate_names):
sns.lineplot(
data=self.disent[self.disent['epoch'] > epoch_min],
x="epoch",
y=f"{cov} disentanglement",
legend=False,
ax=axs[1+i]
)
axs[1+i].set_title(f"{cov} disent", fontweight="bold")
fig.tight_layout()
sns.despine()
================================================
FILE: cpa/train.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import argparse
import json
import os
import time
from collections import defaultdict
import numpy as np
import torch
from cpa.data import load_dataset_splits
from cpa.model import CPA, MLP
from sklearn.metrics import r2_score
from torch.autograd import Variable
from torch.distributions import NegativeBinomial
from torch import nn
def pjson(s):
"""
Prints a string in JSON format and flushes stdout
"""
print(json.dumps(s), flush=True)
def _convert_mean_disp_to_counts_logits(mu, theta, eps=1e-6):
r"""NB parameterizations conversion
Parameters
----------
mu :
mean of the NB distribution.
theta :
inverse overdispersion.
eps :
constant used for numerical log stability. (Default value = 1e-6)
Returns
-------
type
the number of failures until the experiment is stopped
and the success probability.
"""
assert (mu is None) == (
theta is None
), "If using the mu/theta NB parameterization, both parameters must be specified"
logits = (mu + eps).log() - (theta + eps).log()
total_count = theta
return total_count, logits
def evaluate_disentanglement(autoencoder, dataset):
"""
Given a CPA model, this function measures the correlation between
its latent space and 1) a dataset's drug vectors 2) a datasets covariate
vectors.
"""
with torch.no_grad():
_, latent_basal = autoencoder.predict(
dataset.genes,
dataset.drugs,
dataset.covariates,
return_latent_basal=True,
)
mean = latent_basal.mean(dim=0, keepdim=True)
stddev = latent_basal.std(0, unbiased=False, keepdim=True)
normalized_basal = (latent_basal - mean) / stddev
criterion = nn.CrossEntropyLoss()
pert_scores, cov_scores = 0, []
def compute_score(labels):
if len(np.unique(labels)) > 1:
unique_labels = set(labels)
label_to_idx = {labels: idx for idx, labels in enumerate(unique_labels)}
labels_tensor = torch.tensor(
[label_to_idx[label] for label in labels], dtype=torch.long, device=autoencoder.device
)
assert normalized_basal.size(0) == len(labels_tensor)
#might have to perform a train/test split here
dataset = torch.utils.data.TensorDataset(normalized_basal, labels_tensor)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=256, shuffle=True)
# 2 non-linear layers of size <input_dimension>
# followed by a linear layer.
disentanglement_classifier = MLP(
[normalized_basal.size(1)]
+ [normalized_basal.size(1) for _ in range(2)]
+ [len(unique_labels)]
).to(autoencoder.device)
optimizer = torch.optim.Adam(disentanglement_classifier.parameters(), lr=1e-2)
for epoch in range(50):
for X, y in data_loader:
pred = disentanglement_classifier(X)
loss = Variable(criterion(pred, y), requires_grad=True)
optimizer.zero_grad()
loss.backward()
optimizer.step()
with torch.no_grad():
pred = disentanglement_classifier(normalized_basal).argmax(dim=1)
acc = torch.sum(pred == labels_tensor) / len(labels_tensor)
return acc.item()
else:
return 0
if dataset.perturbation_key is not None:
pert_scores = compute_score(dataset.drugs_names)
for cov in list(dataset.covariate_names):
cov_scores = []
if len(np.unique(dataset.covariate_names[cov])) == 0:
cov_scores = [0]
break
else:
cov_scores.append(compute_score(dataset.covariate_names[cov]))
return [np.mean(pert_scores), *[np.mean(cov_score) for cov_score in cov_scores]]
def evaluate_r2(autoencoder, dataset, genes_control):
"""
Measures different quality metrics about an CPA `autoencoder`, when
tasked to translate some `genes_control` into each of the drug/covariates
combinations described in `dataset`.
Considered metrics are R2 score about means and variances for all genes, as
well as R2 score about means and variances about differentially expressed
(_de) genes.
"""
mean_score, var_score, mean_score_de, var_score_de = [], [], [], []
num, dim = genes_control.size(0), genes_control.size(1)
total_cells = len(dataset)
for pert_category in np.unique(dataset.pert_categories):
# pert_category category contains: 'celltype_perturbation_dose' info
de_idx = np.where(
dataset.var_names.isin(np.array(dataset.de_genes[pert_category]))
)[0]
idx = np.where(dataset.pert_categories == pert_category)[0]
if len(idx) > 30:
emb_drugs = dataset.drugs[idx][0].view(1, -1).repeat(num, 1).clone()
emb_covars = [
covar[idx][0].view(1, -1).repeat(num, 1).clone()
for covar in dataset.covariates
]
genes_predict = (
autoencoder.predict(genes_control, emb_drugs, emb_covars).detach().cpu()
)
mean_predict = genes_predict[:, :dim]
var_predict = genes_predict[:, dim:]
if autoencoder.loss_ae == 'nb':
counts, logits = _convert_mean_disp_to_counts_logits(
torch.clamp(
torch.Tensor(mean_predict),
min=1e-4,
max=1e4,
),
torch.clamp(
torch.Tensor(var_predict),
min=1e-4,
max=1e4,
)
)
dist = NegativeBinomial(
total_count=counts,
logits=logits
)
nb_sample = dist.sample().cpu().numpy()
yp_m = nb_sample.mean(0)
yp_v = nb_sample.var(0)
else:
# predicted means and variances
yp_m = mean_predict.mean(0)
yp_v = var_predict.mean(0)
# estimate metrics only for reasonably-sized drug/cell-type combos
y_true = dataset.genes[idx, :].numpy()
# true means and variances
yt_m = y_true.mean(axis=0)
yt_v = y_true.var(axis=0)
mean_score.append(r2_score(yt_m, yp_m))
var_score.append(r2_score(yt_v, yp_v))
mean_score_de.append(r2_score(yt_m[de_idx], yp_m[de_idx]))
var_score_de.append(r2_score(yt_v[de_idx], yp_v[de_idx]))
return [
np.mean(s) if len(s) else -1
for s in [mean_score, mean_score_de, var_score, var_score_de]
]
def evaluate(autoencoder, datasets):
"""
Measure quality metrics using `evaluate()` on the training, test, and
out-of-distribution (ood) splits.
"""
autoencoder.eval()
with torch.no_grad():
stats_test = evaluate_r2(
autoencoder,
datasets["test"].subset_condition(control=False),
datasets["test"].subset_condition(control=True).genes
)
disent_scores = evaluate_disentanglement(autoencoder, datasets["test"])
stats_disent_pert = disent_scores[0]
stats_disent_cov = disent_scores[1:]
evaluation_stats = {
"training": evaluate_r2(
autoencoder,
datasets["training"].subset_condition(control=False),
datasets["training"].subset_condition(control=True).genes,
),
"test": stats_test,
"ood": evaluate_r2(
autoencoder, datasets["ood"], datasets["test"].subset_condition(control=True).genes
),
"perturbation disentanglement": stats_disent_pert,
"optimal for perturbations": 1 / datasets["test"].num_drugs
if datasets["test"].num_drugs > 0
else None,
}
if len(stats_disent_cov) > 0:
for i in range(len(stats_disent_cov)):
evaluation_stats[
f"{list(datasets['test'].covariate_names)[i]} disentanglement"
] = stats_disent_cov[i]
evaluation_stats[
f"optimal for {list(datasets['test'].covariate_names)[i]}"
] = 1 / datasets["test"].num_covariates[i]
autoencoder.train()
return evaluation_stats
def prepare_cpa(args, state_dict=None):
"""
Instantiates autoencoder and dataset to run an experiment.
"""
device = "cuda" if torch.cuda.is_available() else "cpu"
datasets = load_dataset_splits(
args["data"],
args["perturbation_key"],
args["dose_key"],
args["covariate_keys"],
args["split_key"],
args["control"],
)
autoencoder = CPA(
datasets["training"].num_genes,
datasets["training"].num_drugs,
datasets["training"].num_covariates,
device=device,
seed=args["seed"],
loss_ae=args["loss_ae"],
doser_type=args["doser_type"],
patience=args["patience"],
hparams=args["hparams"],
decoder_activation=args["decoder_activation"],
)
if state_dict is not None:
autoencoder.load_state_dict(state_dict)
return autoencoder, datasets
def train_cpa(args, return_model=False):
"""
Trains a CPA autoencoder
"""
autoencoder, datasets = prepare_cpa(args)
datasets.update(
{
"loader_tr": torch.utils.data.DataLoader(
datasets["training"],
batch_size=autoencoder.hparams["batch_size"],
shuffle=True,
)
}
)
pjson({"training_args": args})
pjson({"autoencoder_params": autoencoder.hparams})
args["hparams"] = autoencoder.hparams
start_time = time.time()
for epoch in range(args["max_epochs"]):
epoch_training_stats = defaultdict(float)
for data in datasets["loader_tr"]:
genes, drugs, covariates = data[0], data[1], data[2:]
minibatch_training_stats = autoencoder.update(genes, drugs, covariates)
for key, val in minibatch_training_stats.items():
epoch_training_stats[key] += val
for key, val in epoch_training_stats.items():
epoch_training_stats[key] = val / len(datasets["loader_tr"])
if not (key in autoencoder.history.keys()):
autoencoder.history[key] = []
autoencoder.history[key].append(epoch_training_stats[key])
autoencoder.history["epoch"].append(epoch)
ellapsed_minutes = (time.time() - start_time) / 60
autoencoder.history["elapsed_time_min"] = ellapsed_minutes
# decay learning rate if necessary
# also check stopping condition: patience ran out OR
# time ran out OR max epochs achieved
stop = ellapsed_minutes > args["max_minutes"] or (
epoch == args["max_epochs"] - 1
)
if (epoch % args["checkpoint_freq"]) == 0 or stop:
evaluation_stats = evaluate(autoencoder, datasets)
for key, val in evaluation_stats.items():
if not (key in autoencoder.history.keys()):
autoencoder.history[key] = []
autoencoder.history[key].append(val)
autoencoder.history["stats_epoch"].append(epoch)
pjson(
{
"epoch": epoch,
"training_stats": epoch_training_stats,
"evaluation_stats": evaluation_stats,
"ellapsed_minutes": ellapsed_minutes,
}
)
torch.save(
(autoencoder.state_dict(), args, autoencoder.history),
os.path.join(
args["save_dir"],
"model_seed={}_epoch={}.pt".format(args["seed"], epoch),
),
)
pjson(
{
"model_saved": "model_seed={}_epoch={}.pt\n".format(
args["seed"], epoch
)
}
)
stop = stop or autoencoder.early_stopping(np.mean(evaluation_stats["test"]))
if stop:
pjson({"early_stop": epoch})
break
if return_model:
return autoencoder, datasets
def parse_arguments():
"""
Read arguments if this script is called from a terminal.
"""
parser = argparse.ArgumentParser(description="Drug combinations.")
# dataset arguments
parser.add_argument("--data", type=str, required=True)
parser.add_argument("--perturbation_key", type=str, default="condition")
parser.add_argument("--control", type=str, default=None)
parser.add_argument("--dose_key", type=str, default="dose_val")
parser.add_argument("--covariate_keys", nargs="*", type=str, default="cell_type")
parser.add_argument("--split_key", type=str, default="split")
parser.add_argument("--loss_ae", type=str, default="gauss")
parser.add_argument("--doser_type", type=str, default="sigm")
parser.add_argument("--decoder_activation", type=str, default="linear")
# CPA arguments (see set_hparams_() in cpa.model.CPA)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--hparams", type=str, default="")
# training arguments
parser.add_argument("--max_epochs", type=int, default=2000)
parser.add_argument("--max_minutes", type=int, default=300)
parser.add_argument("--patience", type=int, default=20)
parser.add_argument("--checkpoint_freq", type=int, default=20)
# output folder
parser.add_argument("--save_dir", type=str, required=True)
# number of trials when executing cpa.sweep
parser.add_argument("--sweep_seeds", type=int, default=200)
return dict(vars(parser.parse_args()))
if __name__ == "__main__":
train_cpa(parse_arguments())
================================================
FILE: datasets/.gitkeep
================================================
================================================
FILE: notebooks/demo.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# A tour of the CPA model"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# some standard packages to assist this tutorial\n",
"import cpa\n",
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"import scanpy as sc\n",
"\n",
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Training your model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Init your model"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"adata = sc.read('../../cpa-reproducibility/datasets/GSM_new.h5ad')"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING. Special characters ('_') were found in: 'dummy_cov'. They will be replaced with '-'. Be careful, it may lead to errors downstream.\n",
"Creating 'cov_drug_dose_name' field.\n",
"Ranking genes for DE genes.\n",
"WARNING: Default of the method has been changed to 't-test' from 't-test_overestim_var'\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Trying to set attribute `.obs` of view, copying.\n",
"... storing 'cov_drug_dose_name' as categorical\n",
"Trying to set attribute `.obs` of view, copying.\n",
"... storing 'dummy_cov' as categorical\n",
"Trying to set attribute `.obs` of view, copying.\n",
"... storing 'covars_comb' as categorical\n"
]
}
],
"source": [
"cpa_api = cpa.api.API(\n",
" adata, \n",
" perturbation_key='condition',\n",
" doser_type='logsigm',\n",
" split_key='split',\n",
" covariate_keys=[],\n",
" only_parameters=False,\n",
" hparams={}, \n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can also load a pretrained model"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loaded pretrained model from:\t../../cpa-reproducibility/notebooks/sciplex2_other.pt\n"
]
}
],
"source": [
"cpa_api = cpa.api.API(\n",
" adata, \n",
" pretrained='../../cpa-reproducibility/notebooks/sciplex2_other.pt'\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can print parameters of the model:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"CPA(\n",
" (encoder): MLP(\n",
" (network): Sequential(\n",
" (0): Linear(in_features=4999, out_features=128, bias=True)\n",
" (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ReLU()\n",
" (3): Linear(in_features=128, out_features=128, bias=True)\n",
" (4): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (5): ReLU()\n",
" (6): Linear(in_features=128, out_features=128, bias=True)\n",
" )\n",
" )\n",
" (decoder): MLP(\n",
" (network): Sequential(\n",
" (0): Linear(in_features=128, out_features=128, bias=True)\n",
" (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ReLU()\n",
" (3): Linear(in_features=128, out_features=128, bias=True)\n",
" (4): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (5): ReLU()\n",
" (6): Linear(in_features=128, out_features=9998, bias=True)\n",
" )\n",
" )\n",
" (adversary_drugs): MLP(\n",
" (network): Sequential(\n",
" (0): Linear(in_features=128, out_features=64, bias=True)\n",
" (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ReLU()\n",
" (3): Linear(in_features=64, out_features=64, bias=True)\n",
" (4): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (5): ReLU()\n",
" (6): Linear(in_features=64, out_features=5, bias=True)\n",
" )\n",
" )\n",
" (loss_adversary_drugs): BCEWithLogitsLoss()\n",
" (dosers): GeneralizedSigmoid()\n",
" (covariates_embeddings): Sequential(\n",
" (0): Embedding(1, 128)\n",
" )\n",
" (drug_embeddings): Embedding(5, 128)\n",
" (loss_autoencoder): GaussianNLLLoss()\n",
")"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cpa_api.model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Start training"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Results will be saved to the folder: /tmp/\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Rec: -1.6080, AdvPert: 0.49, AdvCov: 0.00: 100%|▉| 999/1000 [12:45<00:00, 1.30i"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model saved to: /tmp/None\n",
"{'ellapsed_minutes': 12.652799173196156,\n",
" 'epoch': 999,\n",
" 'evaluation_stats': {'cell_type disentanglement': 0.0,\n",
" 'ood': [0.9116713597573521,\n",
" 0.8682491544314631,\n",
" 0.6297205495339959,\n",
" -0.2600031534834341],\n",
" 'optimal for cell_type': 1.0,\n",
" 'optimal for perturbations': 0.2,\n",
" 'perturbation disentanglement': 0.2595809996128082,\n",
" 'test': [0.9242513761083746,\n",
" 0.7737687104344115,\n",
" 0.8351889804394074,\n",
" 0.0565223552642065],\n",
" 'training': [0.9295517385365168,\n",
" 0.7780530537975899,\n",
" 0.8627033078033616,\n",
" 0.12276226755805468]},\n",
" 'training_stats': defaultdict(<class 'float'>,\n",
" {'loss_adv_covariates': 0.0,\n",
" 'loss_adv_drugs': 0.49258487340476775,\n",
" 'loss_reconstruction': -1.6079967220624287,\n",
" 'penalty_adv_covariates': 0.0,\n",
" 'penalty_adv_drugs': 2.152925470492543e-07})}\n",
"Stop epoch: 999\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model saved to: /tmp/None\n"
]
}
],
"source": [
"cpa_api.train(\n",
" max_epochs=1000, \n",
" run_eval=True, \n",
" checkpoint_freq=20,\n",
" filename=None, \n",
" max_minutes=2*60\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Visualize training history"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Computation time: 19 min\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1gAAADQCAYAAAAalMCAAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAABD/UlEQVR4nO3deZxcVZ338c+3qruzQwgkEEIwgEFFZDMCiggIaMAl4uMCLoDjPJERVMbREYdnHB0fZ3h03BAFERlBRWQUJEpkEUFkRoSwJ4QlxAghgSQsISFLp7t/zx/3VKdSqe6u7qruqq7+vl+velXde8+599zqOn3vuWdTRGBmZmZmZmbVy9U7AWZmZmZmZs3CBSwzMzMzM7MacQHLzMzMzMysRlzAMjMzMzMzqxEXsMzMzMzMzGrEBSwzMzMzM7MaGVEFLEmRXjPqnRarnKSj099tWb3TYrUzHPOjpGUpze+q8X5PT/u9r5b7teHH+aKxFOXNW+udFrNaqPSeStIXU7gfDU3KmsuIKmDZwAzlBV/SrelYpxetXg58G7h0sI9vZmZW5CGy688vKo0wHAvJNnSKHkgUXmsk3SBpVh3TVO43ewfZb//G+qRqeGupdwJGOkmtEbGl3umohcE6l4hYApxd6/2aDSfN9L/CbDhIee5O4M56p8Wa0m+AvwBHAW8BXifplRGxqr7JykTE9cD19U7HcDWia7AkTZZ0iaQnJL0o6Q5Js4u2Hy/pbkkvSVor6R5J707bDpH0xxRvvaSFkv6uj+MVmhrcLulCSeuAc9O2d0q6M+3vr5K+LmlsUdzDJd0oaXU63h2F7ZIOkHR9egqyWtKvJb2iKG7hack5ku5N5zNf0k5p+06S/ivF3yTpL5K+n7ZF0Sn8Je3n6KKq419IukrSRuCDkn6U1n8xxZ9ReDJSlJ49JF2WznOTpMWSXpeaYByVgv1nYT/lqrOrPWdrPHXIjwekYzwvaYuklZIukNRWFObjkp5Mv7HPlsQ/Jv3GHihad1Rat7CPY++e8vNLkv4I7FWyvTvfSDpD0grgxh7ywja1vpJ2lPTz9F08IOnTafsLabsk/Vs6r82Snlb29HTn3tJs9THC8sVYSV+S9LCkjZKWS/rfaVurpM+nbS8pu278vaScpPHp/DokTSnaV2Hdbul7ujd9R1uUXX++VHTsstdnlTQRlDRV0m3Krj1b0nfwE0kT0/ay18y07W8k3Z/S9Zikf5LUkrbNUHZNez6d+yPF6bOm9MOI+CRwbFreCXi9pP0lXSdpVfp9/VLSnoVI2nptOEvSo5LWpd9gW9reZx4u1tNvVmWaCKqXe1X1ci85Eo3YApakHDAP+CiwBrgWeC1wnaQjUrD/BA4EfpleXcD+adv5wBvJqk5/Bjyf4lfiCODNwBXAUklvTcffK72vAT4NfDel9dXArcDxwGLg58AuQJukqcAfgLeSVefeC7wduFXbFya+ADwAbAJOSMcA+AfgPcBj6ZwXA29I275dFP8/0/LyonX/C9gH+DHwdF8nnjLi74FTUzp+TPbd7U7WBOOpFPSmdKw7yuyjFudsDaRO+XEy0J72dSnQCZxJ+o2km6Lvkv02bwQ+BEwvin8r8ATwGkmvSuvel94v7+PYV5Dl5yfInmB+rpewXwF+C/xPH/ssOD+l40XgbuCLJduPBT5Pdr4/BG4DXgNMqHD/NkRGYL74Adn/7CkpvfcA+6ZtXwH+jex3eiXZNfAbwOciYj1wDZAnuyYBvA0YB9wYEU8D08i+wyvJrjsTgC9IOrkkDdtcn8ukcQIwBvh1Su/zwAeB89L2stdMSR8jy287kV3rOtM5nZvC/l+ya9pdZN/Tk8BhvXxX1gRSHj+qaNUasv/JxwO3A38G3g3cIGlUSfQvkV0XWsh+gx9O63vNw2X0dZ9XSGuv96r0fi858kTEiHkBkV4zgEPT53XAuLT9m2ndFWn5GWA92Q/mFWQF0nza9ucU9m/ILmathW29HP/0FOdFYGLR+uvS+huBbwHfS8tdwFjggrR8bVGcfErPP6ZttxRtuzetm5uWl6Xlz6blL6Xl36Tl/5eWv5W+l/HF51L8vRWt+2Ja9zjQUrT+R2n9F9PyjEL8tPyetLwCGFsUrzW935q2n1607ei0bllarvqc/ar/q975McV7E1lh4xvAzYV8mLZdkpZ/mJYnkV20AnhXWveVwu89pedpsovZ7r0cc4+ic5+e1n09Ld+XlmcUhXlzT3khrevOM2T/Fzan5aPS9r9Pyy+k5RPS8u/S/qYAAnL1/k34NaLzxS5F535w0frW9PtcX/K7npOWV6Tl49PyrWn5v9LyyWk5B5wI/J/0Hd6Vtl+ctp9O+evz6cX7TesOJrsO/QdZwS6AR8v9DYvWLUrrfkF2rf1JWn46bf95Wv582v/oSv5Wfg2/F1vvT0pf88getgVZ379vpdeqtG52ye/rvWn5srR8QdExesvDR7P9daS3+7wfpeW+7lV7vZccaa+R3AdrRnp/MiJeSp8fTu8vS+8fA75G9o8a4FngLLInYJ8m+3FdwtZ//l8g+8fdl0UR8UKZtByfXgUC9mZr86Hu2pyI6ISsWUFatbgo3sPAQUXnUXBvei8ce3x6/xbZE9CPA58iuxD+XNKHI6Krj3O5MyI6etmeL1kunMuDEbGhsDL617dkRnqv5pytscxI70OWHyV9nuyJeKnJ6X1aen8EICKek/QssFtR2MuAfwLeT1aruivZRWxFz6favd+NEfFk+vxoL+H/u5dtsG0e2wUoNAMp5I+HSsLfSPZdfRi4Ja27i+yGdWUfx7KhNSO9j4R8Ubg2tEdE4f82EbElNfsbl1YVfteF72Fqavp0M9lT9yMlvZysMLUW+FUKdyEwt5fzKii9Pm9D0ilktVt97afUjPT+v0rW7yppPNnN7B7Al8m+/83Ad4DPYs3qN8ASsjx7N1l/pwvStlelV7GXlyyXvb+pIA8P1Iz03tO96rcY+L1k0xmxTQTJniAATNfWvk6FPjx/Te+/jYiZZDct7wF2JnsyB7AgIg4kq+4/muwp23mF9tR92NxDWj4ZESq8gH0iYiFZEyIoai6Q2p2rKO4ri/ZXeh4FhYJQlKx/LiJmkzV9OJDsSdsHyJpKQPZ0Asr/XkrPpXATsEN6379ke+FcXiNpTGFl0ffW2cuxCpal92rO2RrLsvQ+lPnx/en9C2RNLArN9JTeC81VXwEgaVI6ZreIeJSsluCVZLWk0HczqMJ+x0gqNK3at6fAEVGcxwr5a0JKU2tJ3DVktQkAM9N7cT6BrEB2FjCR7IJ9OfA64G/7SLcNvWXpfSTki8K1oU3SQYWVKa2r2frbL/yeC9/DyohoTzdwPyW7dlxK9kT9vyJiU8l5nU6WBy4sOa+C0mtaqcJ+LgFGFS0X76fcNXNZen9nyXV+78iaOC6NiCOAHcme/j8HfKbof4Q1nx9GxN9HxP+NiN9GVk20LG27uuR3MpWsiWmxnu5v+srD5fR2n1dQSFtP96p93UuOKCO5BmsB2QXgMOCPkhYBp5D9UL+XwtyrrDP5E2xtY/5Cev+1pDxZE7kdyf7RPsvWAkJ/XED2tO2rkt4AbAQOILto7QVcRHbzMyd1tH0UOBJ4PVkzg38CjpE0j+zp9cFkzUYqHVb2HEnvBB4kuzmbkdavTe9Pkj0tvUDSo2xtM15O4YnKaZI6yNoFF5tP1j53Jtn3+weyC+Y3yNr0Fp7of0rSAWTteEvV4pytsdQjPz6T3j9E9vTtXSXbryDr+3K6pNFkDwvK/c+8PKX7SLKmXNf0dqIRsVzSbWRNOG6UdBdbL4h9eRTYAEySdDlZrcGUon13SvoZcBrwM0m/Y/sn5m8ga8r7J7KbuMLF74UK02BDZyTlizWSriC7IbtZ0q/ICoaPRcTnJF0IfAa4QtL1wDtT1AuKdnMZ2c3kkUVpKD6vHYFPko3YdlJv6elF4fs5gayQdmKZMOWumReQ/c1+IukashvZWWTNv44GvqdsoKaHyb7PXcj+TusHmE4bnn5Kdn/zbkk3kBVq9iHrpzWTrYWc3vSVh8up5D6vr3vVvu4lR5Z6t1EcyhclbUzJbkwuJfthrSMbivXtReHPJ7swbST7gdxCahtOlgEWkz1VW5/iHtvH8U+npC130bZ3kTUBXEt2cbwTOLto++FkAz+sSce8g9SHiaxwcQPZhXMNWbXzq4riLkvHPTotn12cDuAdZBfxF8gGg3gU+ERR/Pen76grxduFkra5RWFHkf2DeJGsadJnCt97UZjpZBe+J9LxFgOvS9teA9wPbEnx3kP59sJVnbNf9X81QH7cj+wGdhNZp+J/pqgfVArzCbJmR8+RPQ0s/K7eVRRmElv7Pf1nhee+B1l+3kCWl/+9+NiU9F0sifthslqEVWQ3eH+iqN8i2U3kVek7fCB9NwE8k7bPTMdeRXYRXJH2M6revwm/Rny+GAv8K1nzw03pGP87bWsj6z9VeMjwMNn1JV+yj0LfqqWAitYfmb6LjWR9Xb6Zwv0qbT+dMteI0vVkg3v8Pu3nvvRdBKmPYwpT7popsoLpfelvURjM4EMpTmHbunR+DwLvr/fv0a/av8rll5LtB5ANovJMyruLyZrfjU/bS/9HfItt+0r1mocpf09V0X0evdyr0se95Eh7KX0pZmbWBCRNANZH+ude1B7/9og4stfIZmZmVrWR3ERwUEg6lKyJQ6k7I6Jcx1gzGyT1zI+SzmL7TsmQjfS0ZBAPfSzwfyT9lqzpxkfS+vMH8Zg2jIzQfGFmNmRcwKq9/chGTyl1GeVHHjKzwVPP/Pgetp3fpOBXZCNHDZYnyDrx/wNZE8D7ga9HxH/1GstGkpGYL8zMhoybCJqZmZmZmdXISB6m3ayhSZot6RFJSySdU2b7jpJ+Lel+SYskfaTSuGZmZmY2OIZlAWv27NmFEVT88muoX0MiDa38XbKhgPcDTpG0X0mwM4GHIpvn5mjg65LaKoy7Decpv+r4alrOV37V6dW0nKf8qtOr34ZlAWvNmjX1ToLZYDsUWBIRSyOiHbgSmFMSJoAJacLp8WRDJndUGHcbzlNmted8ZVZbzlM2XNSkgFVBUyZJOj9tf0DSIZXGNRuhprF10mXI5oOZVhLmAuBVZPMYPQh8KiK6KoxrZmZmZoOg6gJWhc2RTiCb3HImMJdsUstK45qNRCqzrrSa+q1kE1PuDhxENgP7DhXGRdJcSQskLVi9enV1qTUzMzMzoDY1WJU0R5oDXB6ZO4CJkqZWGNdsJFoOTC9a3oOspqrYR4CrU75aAvwFeGWFcYmIiyNiVkTMmjx5ck0Tb2ZmZjZS1aKAVUlzpJ7CVNyUyU/bbYS5C5gpaS9JbcDJwLySME+QTSqLpF2BVwBLK4xrZmZmZoOgFgWsSpoj9RSmoqZMUNnT9q6uYEtnV29pNRsWIqIDOAu4AVgMXBURiySdIemMFOzLwBskPQjcDHwuItb0FHfoz8LMzGzwdXYFHb7/swbSUoN9VNIcqacwbRXErdi5v1rI7xY/w13nHjfQXZg1jIiYD8wvWXdR0ecVwFsqjTsQjz2zjuO/eRvXfPwNHLznTtXuzszMrOY+9uMFrFy7ies+eWS9k2IG1KYGq5LmSPOAU9NogocDayNiZYVxK5YTRAxouHozK+OWR1YBcN0DK+ucEjMzM7PhoeoarIjokFRojpQHLi00ZUrbLyJ7kn4isATYQNY5v8e4A01LTqLL5Suzmik8r1C5xrxmZmZmtp1aNBGspClTAGdWGnegcsra4ZpZbcklLDMza2BuwGSNpCYTDTeKXE50OYeZ1Yxzk5mZNT4/BLTG0lwFLMlPMMxqqLuJYH2TYWZmZjZsNFkBy00EzWopcAnLzMwan+/+rJE0VwHLTQTNBoVcwjIzswblbsLWaJqrgOUmgmY15fxkZmZm1j9NVsCCTt8RmtWcnw6amVkj8zyo1kiaqoCVl5sImtXSlAmjAJg4prXOKTEzMyvPzwCt0TRVAUupiaCfYpjVxgF7TARg+qSx9U2ImZmZ2TDRVAWsXGrH5IEEzWqj0DTQNcPWDCTNlvSIpCWSzimzXZLOT9sfkHRIyfa8pHsl/WboUm1mZsNNUxWw8ulsfDNoVhuFZhfOUjbcScoD3wVOAPYDTpG0X0mwE4CZ6TUXuLBk+6eAxYOcVDMzG+aaqoCl7hos3w3a8FfB0/bPSrovvRZK6pQ0KW1bJunBtG3BwNOQvTtHWRM4FFgSEUsjoh24EphTEmYOcHlk7gAmSpoKIGkP4G3AJUOZaDPrmwdiskbTVAWs7iaCXXVOiFmVKnnaHhFfi4iDIuIg4PPAHyLiuaIgx6Tts6pIR+FYA92FWaOYBjxZtLw8ras0zLeAfwR6vcJImitpgaQFq1evrirBZmY2PDVVActNBK2JVPK0vdgpwM9qnQg3EbQmUu4Zd+kvu2wYSW8HVkXE3X0dJCIujohZETFr8uTJA0mnmQ2Ar1PWSJqqgJVzE0FrHpU8bQdA0lhgNvDLotUB3Cjpbklze4jX55P27hosNxK04W85ML1oeQ9gRYVhjgDeKWkZ2cOON0v6yeAl1cz6Qx6o3RpMUxWw5CaC1jwqedpe8A7gv0uaBx4REYeQNTE8U9KbtttZBU/aXYNlTeQuYKakvSS1AScD80rCzANOTaMJHg6sjYiVEfH5iNgjImakeL+PiA8NaerNzGzYqKqAJWmSpJskPZbedyoTZrqkWyQtlrRI0qeKtn1R0lNFHfVPrCY9OQ8pbc2jkqftBSdT0jwwIlak91XANWRNDvst190HayCxzRpHRHQAZwE3kI0EeFVELJJ0hqQzUrD5wFJgCfAD4ON1SayZ9ZtbWlgjaaky/jnAzRFxXhrl7BzgcyVhOoB/iIh7JE0A7pZ0U0Q8lLZ/MyL+o8p0AJDPuYmgNY3up+3AU2SFqA+UBpK0I3AU8KGideOAXESsS5/fAvzrQBLhebCsmUTEfLJCVPG6i4o+B3BmH/u4Fbh1EJJnZgPkUQSt0VTbRHAOcFn6fBnwrtIAqXnFPenzOrInh2X7klRLnmjYmkSFT9sBTgJujIiXitbtCtwu6X7gTuC6iLi+qvRUE9nMzMxsBKm2BmvXiFgJWUFK0pTeAkuaARwM/Llo9VmSTgUWkNV0Pd9D3LlkEz+y5557lt2/mwhaM+nraXta/hHwo5J1S4EDa5GG7qeCzlJmZmZmFemzBkvS79IkpqWv3oaMLref8WSjnJ0dES+m1RcC+wAHASuBr/cUv5IO+XmPImhWUzmPImhmZsOAb/2skfRZwIqI4yJi/zKva4Fnima5nwqsKrcPSa1khaufRsTVRft+JiI6I6KLrEPxgDrid5+Mmwia1dTWPlj1TYeZmdWPpNmSHpG0JPW5L90uSeen7Q9IOqSSuJI+kbYtkvTVgadvoDHNBke1fbDmAaelz6cB15YGUNYx6ofA4oj4Rsm2qUWLJwELq0lM982g7wbNaqIwt4ifDJqZjUyS8sB3yab92A84RdJ+JcFOAGam11yyFkq9xpV0DFlf/gMi4tVATQY8M2sE1RawzgOOl/QYcHxaRtLukgp9R44APkw2MWPpcOxflfSgpAeAY4C/ryYxHkXQrLYKDy3cRNDMbMQ6FFgSEUsjop1ssu3SbiJzgMsjcwcwMT1E7y3u3wHnRcRm6J5WZMB8lbJGUtUgFxHxLHBsmfUrgBPT59spP2kqEfHhao5fyk0EzWqru4DlPGVmNlJNA54sWl4OHFZBmGl9xN0XOFLSV4BNwGci4q7Sg1cyyJnK32aa1U21NVgNpXAz2OkSlllNbG0i6DxlZjZClSu9lF4UegrTW9wWYCfgcOCzwFXS9r2pKhnkzKzRVDtMe0MpNBH0zaBZbWxtImhmZiPUcmB60fIewIoKw7T1Enc5cHWa4PtOSV3ALsDqgSTS937WSJqqBstNBM1qq3saLOcpM7OR6i5gpqS9JLUBJ5MNclZsHnBqGk3wcGBtmie1t7i/At4MIGlfssLYmgGl0C0ErcE0VQ1Wzk0EzWqqex4sl7DMzEakiOiQdBZwA5AHLo2IRZLOSNsvAuaT9b1fAmwAPtJb3LTrS4FLJS0E2oHTwhcbaxJNVsDyKIJmteR5sMzMLCLmkxWiitddVPQ5gDMrjZvWtwMfqm1KzRpDUzYRdPnKrDa6B7moczrMzMx64+uUNZLmKmCls+l0CcusNrqHaXeeMjOzxuQuWNZomquA5SaCZjWV81XLzMzMrF+asoDlp+3WDCTNlvSIpCWSzimz/bOS7kuvhZI6JU2qJG4/0gD4oYWZmTU4X6asgTRlAcsd8m24k5QHvgucAOwHnCJpv+IwEfG1iDgoIg4CPg/8ISKeqyRuxenoPtbAzsPMzGywlZmf2KyumqyAlb17mHZrAocCSyJiaRpp6UpgTi/hTwF+NsC4PfJEw2ZmZmb901wFrJybM1nTmAY8WbS8PK3bjqSxwGzgl/2JK2mupAWSFqxevbpsItyv0czMhgNfpayRNFcBy8O0W/Mo196hp1/2O4D/jojn+hM3Ii6OiFkRMWvy5Mm9JsZ5yszMGpUbCFqjabICVvbuJoLWBJYD04uW9wBW9BD2ZLY2D+xv3F65WbuZmZlZ/zRXActNBK153AXMlLSXpDayQtS80kCSdgSOAq7tb9xKeGROayYVjMwpSeen7Q9IOiStny7pFkmLJS2S9KmhT72ZmQ0XVRWwJE2SdJOkx9L7Tj2EWybpwTSc9IL+xq+Umwhas4iIDuAs4AZgMXBVRCySdIakM4qCngTcGBEv9RV3IOkoVGC5UtiGuwpH1zwBmJlec4EL0/oO4B8i4lXA4cCZAx2Z08wGhx8EWiOptgbrHODmiJgJ3JyWe3JMGlJ61gDj98lNBK2ZRMT8iNg3IvaJiK+kdRdFxEVFYX4UESdXEncg5IcW1jwqGV1zDnB5ZO4AJkqaGhErI+IegIhYR/bgouygM2Y29Nyc3RpNtQWsOcBl6fNlwLuGOP42POKZWW11z4Pl8Zls+KtkdM0+w0iaARwM/LncQSoZndPMzJpbtQWsXSNiJUB6n9JDuABulHS3pLkDiN/PIaX7fR5mVkb3PFjOUzb8VTK6Zq9hJI0nmw7h7Ih4sdxB+jM6p5nVji9T1kha+gog6XfAbmU2nduP4xwRESskTQFukvRwRNzWj/hExMXAxQCzZs0qm49yqbjoGiyz2pAHubDmUcnomj2GkdRKVrj6aURcPYjpNLN+cgtBazR9FrAi4rietkl6ptA+XdJUYFUP+1iR3ldJuoasLfxtQEXxK5V3E0GzmpP8ZNCaQvfomsBTZKNrfqAkzDzgLElXAocBa9P1ScAPgcUR8Y2hTLSZmQ0/1TYRnAeclj6fxrZDRQMgaZykCYXPwFuAhZXG7w+5iaBZzQk3EbThr8KROecDS4ElwA+Aj6f1RwAfBt6cRsO9T9KJQ3sGZtYbX6eskfRZg9WH84CrJH0UeAJ4L4Ck3YFLIuJEYFfgmlT4aQGuiIjre4s/UIVRBLtcwjKrmZzkQS6sKUTEfLJCVPG64lE5AzizTLzbcSsks4YlDyNoDaaqAlZEPAscW2b9CuDE9HkpcGB/4g9U3hMNm9Wc5FphMzMzs0pV20SwoXgUQbPa29IZrHpxc72TYWZmZjYsNFUBS24iaDYofnnP8nonwczMrEduym6NpKkKWJ5o2MzMzGxkcQ8sazRNVcDa2gerzgkxMzMzM7MRqakKWN1NBF2DZWZmZjZi+NbPGklTFbDcRNDMzMxshHEbQWswTVXAyhcKWG4jaE1A0mxJj0haIumcHsIcnSY9XSTpD0Xrl0l6MG1bMHSpNjMzMxvZqp1ouKF4mHZrFpLywHeB44HlwF2S5kXEQ0VhJgLfA2ZHxBOSppTs5piIWDNUaTYzM6sXN16yRtJUNVhKZ+MmgtYEDgWWRMTSiGgHrgTmlIT5AHB1RDwBEBGrBisx++46frB2bWZmVhW5jaA1mKYqYOXdB8uaxzTgyaLl5WldsX2BnSTdKuluSacWbQvgxrR+brkDSJoraYGkBatXr+4xIQdNn8iuO4we2FmYmZmZjTBuImjWmMo9jiv9ZbcArwWOBcYAf5J0R0Q8ChwREStSs8GbJD0cEbdts7OIi4GLAWbNmtVjrsnn5KYXZmZmZhVqqhqswjDtnS5h2fC3HJhetLwHsKJMmOsj4qXU1+o24ECAiFiR3lcB15A1ORyQnJynzMzMzCrVVAWs7omGfTNow99dwExJe0lqA04G5pWEuRY4UlKLpLHAYcBiSeMkTQCQNA54C7BwoAnJSW52a2Y2gvU1qq0y56ftD0g6pB9xPyMpJO0y8PQNNKbZ4GiqJoJ5NxG0JhERHZLOAm4A8sClEbFI0hlp+0URsVjS9cADQBdwSUQslLQ3cI2y/NACXBER1w80LTmJjq6uak/JzMyGoUpGtQVOAGam12HAhcBhfcWVND1te2KozsdsKDRVASuXarA6fTNoTSAi5gPzS9ZdVLL8NeBrJeuWkpoK1kI+JzZ3+KmFmdkI1T2qLYCkwqi2xQWsOcDlERHAHZImSpoKzOgj7jeBfyRrkVGVcEsLayBN1UQQoCUnOp3JzGpm7cYtLF65rt7JMDOz+qhkVNuewvQYV9I7gaci4v7eDl7JiLduIWiNpqoClqRJkm6S9Fh636lMmFdIuq/o9aKks9O2L0p6qmjbidWkB7JarA63ETSrmQefWsvGLZ31ToaZmdVHJaPa9hSm7PrUb/hc4At9HTwiLo6IWRExa/LkyX0m1qwRVFuDdQ5wc0TMBG5Oy9uIiEci4qCIOIhsSOkNZKOaFXyzsD01iapKS04e5MLMzMysNiod1bZcmJ7W7wPsBdwvaVlaf4+k3QaaSN/5WSOptoA1B7gsfb4MeFcf4Y8FHo+Iv1Z53B7l5RosMzMzsxqpZFTbecCpaTTBw4G1EbGyp7gR8WBETImIGRExg6wgdkhEPD2QBHoUQWs01Rawdk0ZiPQ+pY/wJwM/K1l3VhrS89JyTQwLKmmDC5DPuwbLzMy2N5hDTZs1q4joAAqj2i4GriqMalsY2ZZsQKalwBLgB8DHe4s7xKdgNuT6HEVQ0u+AclW25/bnQOnJxTuBzxetvhD4MlnN7peBrwN/Uy5+RFwMXAwwa9asHktQrsEyGxwRgfyY0IapwRxq2qzZ9TWqbRo98MxK45YJM6P6VJo1jj4LWBFxXE/bJD0jaWpErEzDca7qZVcnAPdExDNF++7+LOkHwG8qS3bP8jlPimo2GCLcDMOGtcEcarpfPnXlvfgyZdWYNWMnTn39jHono6E4T1kjqXYerHnAacB56b23eQxOoaR5YKFwlhZPAhZWmR7yOdHR6VxmVmvOVTbMlRsu+rAKwvQ01HRpXCBrzg7MBdhzzz3LJuTB5Wudn6wqUyaMqncSGoo8ULs1mGoLWOcBV0n6KNks3O8FkLQ7cElEnJiWx5I1rfhYSfyvSjqI7N5tWZnt/Zb3PFhmNfWOA3fn1/evoCuCvC9iNnzVfKjpcgeppDn77z9zdI+JNDOz4a+qAlZEPEs2MmDp+hXAiUXLG4Cdy4T7cDXHLyefE53ug2VWM6/cbQK/vt/NL2zYq2ao6bYK4ppZHYXrha2BVDuKYMNxActscLhvow1zNR9qeigTb2Y9c/9gazTVNhFsOHm5gGVWSzlfuawJRESHpMJw0Xng0sJQ02n7RWQjnZ1INtT0BuAjvcWtw2mYmdkw0HwFLNdgWZOQNBv4NtkN3SURcV6ZMEcD3wJagTURcVSlcStPR/buGiwb7gZ7qGkzqx9foqyRuIBl1oAqmXdH0kTge8DsiHhC0pRK4/ZHrruANeDTMTMzMxsxmq4PVotHEbTm0D1nT0S0A4V5d4p9ALg6Ip4AiIhV/YhbsUITwXC+MjOzBuSW7NZomq6AlXMNljWHnubjKbYvsJOkWyXdLenUfsTtt/kPruw7kJmZmdkI13QFrBYXsKw5VDLvTgvwWuBtwFuBf5a0b4VxkTRX0gJJC1avXt1jQgo1WJ/75YOVpdzMzGyI+c7PGknTFbByEh0uYNnwV+mcPddHxEsRsQa4DTiwwrhExMURMSsiZk2ePLnHhLjphZmZNTZfqKyxNF0BqyUvulzAsuGvknl3rgWOlNQiaSxwGLC4wrgVc42wmZmZWeWabhRB12BZM6hkzp6IWCzpeuABoItsOPaFALWcs+cXdy+v8mzMzMwGl8dhskbSdAWslpw8X481hb7m7EnLXwO+VkncgWrv6KrFbszMzAaFm7Jbo2m6JoL5nOjodAHLrFacm8zMzMwq15QFLNdgmZmZmY0kvvezxtGUBSz3wTKrHU8wbGZmZla5Jixg5TyKoJmZmdkI4S5Y1miqKmBJeq+kRZK6JM3qJdxsSY9IWiLpnKL1kyTdJOmx9L5TNekByAvXYJnVkHOTmZmZWeWqrcFaCLybbILTsiTlge8CJwD7AadI2i9tPge4OSJmAjen5arkcznP22NWQ24haGZmjc7XKmskVRWwImJxRDzSR7BDgSURsTQi2oErgTlp2xzgsvT5MuBd1aQHIJ/zxKhmZmZmI4WHabdGMxR9sKYBTxYtL0/rAHaNiJUA6X1KTzuRNFfSAkkLVq9e3ePB8rkcnX6MYVYz4UaCZmZmZhXrs4Al6XeSFpZ5zekrbmEXZdb1+44tIi6OiFkRMWvy5Mk9hnMNllltnXviq+qdBDMzs175zs8aSUtfASLiuCqPsRyYXrS8B7AifX5G0tSIWClpKrCqymPR4j5YZjX1it12qHcSzMzMeiSPI2gNZiiaCN4FzJS0l6Q24GRgXto2DzgtfT4NuLbag+UkF7DMaijn65aZmZlZxaodpv0kScuB1wPXSbohrd9d0nyAiOgAzgJuABYDV0XEorSL84DjJT0GHJ+Wq9KSdwHLrJZy7j1sZmYNLtz/3hpIn00EexMR1wDXlFm/AjixaHk+ML9MuGeBY6tJQynXYJmZWTFJk4CfAzOAZcD7IuL5MuFmA98G8sAlEXFeWv814B1AO/A48JGIeGEo0m5mZsPPUDQRHFItOXkUQWsKPU3QXbT9aElrJd2XXl8o2rZM0oNp/YJq0pFzG0Eb/vqcc7GPORtvAvaPiAOAR4HPD0mqzawibmhhjaaqGqxGlMtlNVgRgZzjbJgqutk7nmygmLskzYuIh0qC/jEi3t7Dbo6JiDVVp6XaHZjV3xzg6PT5MuBW4HMlYbrnbASQVJiz8aGIuLEo3B3AewYzsWZmNrw1ZQ0WgFsJ2jDX2wTdQ8p9sKwJVDLnYm9zNhb7G+C3PR2o0jkbzay2fNtnjaTpClj5VMDq6Oqqc0rMqlLpzd7rJd0v6beSXl20PoAbJd0taW65A1R6I1jcQvCRp9f14xTMhs5QzNko6VygA/hpTzupdM5Gs+GkgibrknR+2v6ApEP6iivpa5IeTuGvkTRxwOkbaESzQdK0BSyXr2yYq2SC7nuAl0XEgcB3gF8VbTsiIg4h609ypqQ3bbezSm8Ei1LyxHMbKky+2dCKiOMiYv8yr2tJcy4C9DLnYm9zNiLpNODtwAfDw5XZCNJH/8SCE4CZ6TUXuLCCuO7baE2r+QpYcg2WNYVeb/YAIuLFiFifPs8HWiXtkpZXpPdVZCN9HjrQhLiJoDWBSuZc7HHOxjS64OeAd0aEnzLYSFNJk/U5wOWRuQOYmB5m9Bg3Im5MU/lA1rdxj2oS6cce1kiar4DlGixrDr1N0A2ApN2URnKRdChZfn5W0jhJE9L6ccBbgIUDTUhx8coP7m2YKjvnYj/mbLwAmADclEbmvGioT8Csjippst5TmKr7NlbSnN2DmlmjabpRBN0Hy5pBRHRIKtzs5YFLI2KRpDPS9ovIRjL7O0kdwEbg5IgISbsC16QLTgtwRURcP9C0FNdguXhlw1FPcy72Y87Glw9qAs0aWyVN1nsKU3Xfxoi4GLgYYNasWb4M2bDQdAWswpw9ngvLhrtyN3upYFX4fAHZk/XSeEuBA2uVDjcRNDMb0fpsst5LmLbe4hb1bTy22r6NbmFhjaTpmggWhmnv9DjtZrXh8pWZ2UjWZ5P1tHxqGk3wcGBtmhLBfRttRGq6AlZhkAsXsMxqo3iY9ivvfKJ+CTEzsyHXU/9ESWcUmq2TtbZYCiwBfgB8vLe4KY77NlrTaromgnnXYJnVVHHn4VseWc2SVet5+ZTxdUyRmZkNpQqarAdwZqVx03r3bbSm1Xw1WC5gmdVUrqSJ4JZODyBjZmaNxXd91khcwDKzXpUOcnHCt//IfU++UJ/EmJmZlfBYTNZomreA5dFkzAbN5X9aVu8kmJmZmTWkqgpYkt4raZGkLkmzeggzXdItkhansJ8q2vZFSU+lzo33STqx3D76o3serE4XsMxqwcO0m5lZw/NtnzWQage5WAi8G/h+L2E6gH+IiHskTQDulnRTRDyUtn8zIv6jynR0K4wi2OUaLLOaKO2DZWZm1kjk+USswVRVwIqIxbDtKGNlwqwEVqbP6yQtBqYBD/UYqQr5fKrBch8ss5roLX+bmZnV25tW/JC9eQp4a72TYgYMcR8sSTOAg4E/F60+S9IDki6VtFO1xyg0Z3px45Zqd2VmlK/B6qryAcbKtRu5+6/PV7UPMzMzgKkbHua1LK53Msy69VnAkvQ7SQvLvOb050CSxgO/BM6OiBfT6guBfYCDyGq5vt5L/LmSFkhasHr16h6P8+elzwLwhWsX9RjGzCpXrgbrV/et4O3f+SMxwKa4R33tVv7Xhf9TbdLMzMwIcsidsKyB9NlEMCKOq/YgklrJClc/jYiri/b9TFGYHwC/6SUdFwMXA8yaNavHXFQYnv3pFzdVm2wz68XCp15kQ3sn40a18NdnX2LKhNGMactXFLe9w3NpmZlZjQhyLmBZAxn0JoLKHn//EFgcEd8o2Ta1aPEkskEzqrLDmFYAWtwz32zQFaZDOOprtzL3xwsAWPjUWs695sGqmxGamZlVIsjhYQStkVQ7TPtJkpYDrweuk3RDWr+7pPkp2BHAh4E3lxmO/auSHpT0AHAM8PfVpAfgI0fMAOCjb9yr2l2Z1ZWk2ZIekbRE0jllth8taW1RvvpCpXFr5e5lz7NoxVoA/vjYGp5dv5m3f+d2fvrnJ3h89Xr+sual7rAvbGjnEz+7lyef2zBYyTEzs5FIcg2WNZRqRxG8BrimzPoVwInp8+1QfvzMiPhwNccvZ0xrnnxOeJR2G84k5YHvAscDy4G7JM0rmt6g4I8R8fYBxq3Ybz7xRt7+ndu3W/+RH921zfK///bh7s8f/uGdPP3iJpb+24nkcuL7ty3l1/ev4Nf3rxhoMrq9sKGdP//lOd766t2q3peZmQ1zcg2WNZZq58FqOJIY05pn45bOeifFrBqHAksiYimApCuBOVQ2vUE1ccuaNnFMReF+cffy7s+FfpAbtnTy3Pr27v6RxW566Bkmjm3ldTMm9Ss9Z11xL7cvWQPAhR88hBNeM3W7MOff/Bg3LHqamVPG8833H+Th5s3MmpY8yIU1lKYrYAGMdgHLhr9pwJNFy8uBw8qEe72k+4EVwGciYlGlcSXNBeYC7Lnnnr0mJldFn8ZP//w+bnzombLb/vflWb+tdx8yjX13ncCadZu55Pa/APDxo/fh/a+bzst2HtcdfsmqdSx86kX++tzWpod/99N7+P0/HMVeu4zrLkQ9sPwFvnHTowAsWvEi++2+A287YPdtCopPPLuBPzy6ipMO2YPxo1p4aXMHY9vyLoiZmQ0zmzuDtgjaO7poaxnSGYjMymrKAtb4UXnWbvA8WDaslbvLL308dw/wsohYn/o1/gqYWWHcikfmhOoGjempcFXs6nue2m7d9259nO/d+jg7jG7hxU0dvcZ/89f/AMC+u47n0WfWb7f93+Y/zL/Nf5jJE0axet1mxrbl2dCePYT552sXMaolx+Y0suG4tjxnH7cvmzuy7TuObeOQPSfy8injac3lWN/ewbI1L7F+cwcTx7Sx8/g28jmxpbOLjs7sa5w4tpW1G7ewob2TnLL5+YLsvVCTlxMUKvUKX2+Q/fGUwkUEWzoDpX3kc9qmJlDaGj6fEy2peXQQCNHe2QUELbkc+ZyQoCPtr1KdXUFrPkd7Z1f3PIOw9UdWWJWT6IrsuKX7nzSujXGjmvJyY2YN4K/PbeQAdfH7h59h9v7bt2gwG2pNecXbY6exrFi7sd7JMKvGcmB60fIeZLVU3YrmkyMi5kv6nqRdKonbX/k6jsrZV+GqWLnCVbHV6zYDdBeuCjYXDRv/UnsnX5nvCStr6fxTDuadB+5e72SYWZPqQt0Pm8waQVMWsKZMGMXV965h4VNredXUHcjJmc6GnbuAmZL2Ap4CTgY+UBxA0m7AMxERkg4lGxX0WeCFvuL2Vz0LWPU0YVQLF3zwENZu3MJfVmfNEttacuy1yzh2GNPC8y9t4YWN7Wzp6KIln8uapgQ8v6GdCaNb2WFMC11B95D1nV1BS16phqqru7anuFlLoQYqJ5GTaM1ntVJdAR1dXbTkckhbw2Xv0NHZRUdXkFe2z66A1nSsLR1dbO7oymrTcuoO05cIaMmL9o4uJBXVxkX39sJ7Z0qIyvy/PXj6xJr8PczMyhk/qpVcexcTXFNuDaIpf4kTRmenVRj17LTXv4wvzdm/nkky65eI6JB0FnADkAcujYhFks5I2y8C3gP8naQOYCNwckQEUDZuNenJ1+kBxTsP3J33v246M6eM56X2To75j1t7DHvtmUfw24VPs8v4Nv7vddvXQO21yzhu/vRRdHQFG7d0ct0DK/mnax4E4L4vHM+CZc/T1pLjgD12pK0lx6YtXUwa1zZYp2ZDSNIk4OfADGAZ8L6IeL5MuNnAt8nyzSURcV7J9s8AXwMmR8SaQU62mVXo1dMm0vkXKp7s3mywNWUBK5/btoPjZX/6K19856tdi2XDSkTMB+aXrLuo6PMFwAWVxq1GNYNcDNQrd5vAv7/7Ndv03Vl23tv48Z+W8c/XbltenDxhFAdOn8iBqabk8L133mZY+RP2340LP/RaANpyoq0lxymHTmf1us286+DdmTi2jeP223WbfY512aqZnAPcHBHnpXnhzgE+Vxygr+kNJE1P254Y0pSbWZ9yuRxdCro8R481iKYcamVsmScYa9a31yElZtYf3znlYJZ85QSWnfc2rj/7TWUHRvjw62fwsp3HbrPuhrPftM3y/tN25HsfPKR7+avvOWC7/UjiU8fN3GaUQmtac4DL0ufLgHeVCdM9vUFEtAOF6Q0Kvgn8I55sx6zhSDlEcO8TL9Q7KWZAkxawPnj4ntvN2/Mv8xbyszuf4E1fvYVf37+CC299HIDPX/0gZ15xTz2SaWbJ2cfNBGDvyeNoyff9b+n6T72JthTuq+85oGxTvhNfM5VTX/8yzjhqHyaMbq1tgm242TUiVgKk9yllwpSb3mAagKR3Ak9FxP19HUjSXEkLJC1YvXp19Sk3sz5t7uwiR5RtHm5WD03ZRHDqjmP473PezIxzruteN//Bp5n/4NMAfOJn9wIwbacx/OzOrLXHN97XyagWt901q4e/PXJv3jtresUTGo9py3PLZ4/mnr8+zzt6GZ3uX933csSQ9DtgtzKbzq10F2XWhaSxaR9vqWQn/Zn+wMxqQ7m8Jxq2htKUBayC+75wPJ+88j5ue7T8U8RPpoIWwCv+z/V855SDueXhVXzsqH3oiuBVU3cYqqSajWht+VzFhauCaRPH9DuONa+IOK6nbZKekTQ1IlZKmgqsKhOsp+kN9gH2Au5P/Xj3AO6RdGhEPF2zEzCzKoicC1jWQJq6gDVxbBuX/82hdHUFr/vK73j2pd77YRVqtq6+d+ukp996/0F8/aZH+Pvj9uXdh+wxqOk1G47OOublXHDLkorD7zd1Bx5amU3hNX5UC+s3d9Ca9wA0NqjmAacB56X3a8uEKTs1QhqBs7tJoaRlwCyPImjWOLIpzrv6Dmg2RJqyD1apXE7c/c/Hs+y8t/FfZ7wegI+9ae/u7aNbe/4azv75fTz53EY+fdX9zDjnOr7/h8e757QxG+k+dPiefOLYl/caZt5ZRzB90hjeeeDunH/KwVz3yTd2b/v08fuy7Ly3eYRPG2znAcdLeoxsJMDzACTtLmk+ZFMjAIXpDRYDV1U7vYGZDY221pbuNr5HfvX3dU2LGTR5DVY5r5sxifu/8BbGj27hH2e/kqdf3MTUHUZz57LnOPXSO2nv6CKfE1+esz9f+vUiNnds+0Tk33/7MP/+24c5cuYu7DN5PK/efQfe89o9fINoTe+nf3sYH7zkz9usy0vb9F289swj2Lilk5MvvoP9p+3Abz5xJAB//Mc3D2lazYpFxLPAsWXWrwBOLFruc3qDiJhR6/SZWXV223EMz6UarCef21jn1JiNwAIWwI5jt44oVujDcfjeO/PIl2cDdBeWPnDYnvzPkjWsXr+ZfSaP5+bFq/jNAyt4bNV6/vjYGv74WNZC5LO/eIAjZ+7C3x65Nxs2d3DoXpPYefyo7mNEBCvXbmLZmpcY1ZrjtS+bNFSnalYzuTIPEUrnxzpw+kTWbdoCwFnHzBySdJmZ2Qin3Daj1Kzf3MH4MtN8mA0V//qKlKuFesPLd+n+vP+0HflUGk76xkVP85M/P8HajVtYWlLggqzTfntnF6NactvVgs162U785G8PY3Tr1if/XV1Rl8lczapRKHR98/0H8vDKdQBMGN3KsvPeVlF8970yM7PqiVxRH6z9/+UGHv7y7G3us8yGUlUFLEnvBb4IvAo4NCIW9BBuGbAO6AQ6ImJWWj8J+DkwA1gGvC8inq8mTUPlLa/ejbe8euuIwHcte47v3rIEAbc8spr2ziyjlxauABb89XnecN7v2WV8G8e8cgodncEPb//LNmFOOXRP/uUd+/mfgzWMyRO2n2uq8EzgpIP3gIMr39erpu7A4pUvMmWH0TVKnZmZjVglNVgAtz+2huP227UuyTGrtgZrIfBu4PsVhD2mzKhL5wA3R8R5ks5Jy5+rMk118boZk/jRRw7dbn1EbFMz1tUVXHDLEv70+LP8aemzPPrM+rL7+9mdT3TP0fXGl+/Cy6eMZ979Kzhq38kcsudEjn3VruzuIaptCL18yoTt1pVrNliJX591BLc8sprjXlVuvlczM7N+kLabB+tTV97Lon+dXacE2UhXVQErIhZD+aZ1FZoDHJ0+XwbcyjAtYPWk9LvJ5cQnj53JJ4/Nmho+91I7E0a30JKqAl7c1MHo1hy/e2gV/3X3k9yx9FluX7KG25dkZdNr7n2Ka+59in++Nhvc6o0v34WZu45n9x3H8OHXv8w1XjaoLjl1Fn97+daK6oE2a23J5zjeTxbNzKwW2sYxQRu54H37cdZVDwHwUnsn3//D43zsqH3qnDgbiYaqD1YAN0oK4PtppnuAXSNiJUCaALLHx9mS5gJzAfbcc8/BTu+QmTRu22ZXO47JBuB42wFTedsBUwHo6OzijqXPkc+JfE585/eP8T+PP0tnV3DvE893F77+/beLeeurd+OVu+3A2LY8Y9ryjGnNl3xu4cVNW5g4tpW2fI5cTrTmcoxuzTGqJU9ri2jL58jn5JERbTvFzS3e+PJdmHvk3r2ENjOzZiBpNvBtIA9cEhHnlWxX2n4isAE4PSLu6S1uTbuJ7PpqAN6+yyrGnf46PvKju4Bs5GcXsKwe+ixgSfodsFuZTedGRLnJGss5IiJWpALUTZIejojb+pPQVCi7GGDWrFkjaiKqlnyON87cOtjGjz96WPfniOD5DVv4yR1/5Ye3/4WbF6/itwufrslxW/NixzFtjGrZWgAb1ZpjdEtWYFu3aQs7jxvFuFEttLWIHUa30prP5hQb05YnJzF+dAvjR+UZP6qVcaPyjB/VQk5idGueUS1ZAW9cKgDmJRfsivR1QSsK9zrgDuD9EfGLtG4ZZfo91tJ3P3hI9wMBMzNrTpLywHfJ5pBbDtwlaV5EPFQU7ARgZnodBlwIHNZH3Np1E9k5zcf45B0cc8RhXH/2kcz+1h8B+OZNj3L2cTN9b2FDqs8CVkQcV+1B0lwjRMQqSdcAhwK3Ac9Imppqr6YCq6o91kgjiUnj2rZpdtjR2cWGLZ1sau9kQ3pt3NLJxvZONrR3sKmji03tnYxpy9MVQUdnsGFLJ+0dXWzp7GJLRxfPvtRORLClK9i0pZPNHV1sTu/rNnWwduMWHlr5IjuOaWVcW57nNrTT1UX34B4DlROMa2thVGuOllyOiWNb6Ypg3KgWRrfkaWvJ0ZrPMaolR2te3cvjR7cwKp99/p/Hn6W1JceJ++9GV0BnBJ2dXTz3UjvjRrWw8/hRdEWQk9jQ3sGa9e3stctYWnI5NrR38L5Z0+v+j7jCC1oh3P8jmxy1VLl+j1UrjJCZ96iXZmYjwaHAkohYCiDpSrIuHsXXoznA5RERwB2SJqb7uhm9xK1dN5FJqTXFTV8A4JUT92T+WzfztZuW8MAt9/J3t+WYOnE0a9ZvZnNHF2+cOZnRLXnaOzoZk+45RrfkyeVEV1fQGUEOGN2aJ8geZre25GjL5dg6moYo3CoUVtXi1qGSXfR9HF+f+yuXb+E1R727Zvsb9CaCksYBuYhYlz6/BfjXtHkecBpwXnqvtEbMetGSz7FDPscOo4e2diEi6AroiuClzR0Isbmzk/WbOnhpcyfrNm9h9brNjG7NsykV6P6y5iVa8jlac6IzgvWbOmjv7OKlzZ0EwTMvbmJ0S571mzvo6Opiw8YsXntHJ1s6I/vc2dUdr9htj64e0Hm87YDdG2H+jEouaACfAH4JvG6oEvaxo/bmO79fQluqrTQzs6Y2DXiyaHk5WS1VX2Gm9RG3om4iFXURaRsHB34A7r+iu5C1H/Cfxb0w1qX3PLC0/G5s5FofY6BRCliSTgK+A0wGrpN0X0S8VdLuZE2aTgR2Ba5JNQItwBURcX3axXnAVZI+CjwBvLea9Fh9SSIvyCMmji38V2ulzOBzg6K9o4v1mzu6a9xacqIlnzU7bMnl6Ojq4sWNW2jJ5bqf/rywYQtdEUwc20Z7RxfjR7cwtjEGCunzgiZpGnAS8Ga2L2D11O+xOP6A+jV+4s0zOf0NM2hrcQHLzGwEKFcdUtpVo6cwlcTtVcVdRE66EI7/Emx6ETo3Q2c7RJlWNQFB9kBYwJbOLoJgQ3snAPlcjrxEe2cnWzoC5bLWNZu3dHU/yI2iVEQ/zqavsFHBVzOi+sgMISnHzBrur9pRBK8BrimzfgVZR0fSE/gDe4j/LHBsNWkwK2hryTGpZfu5mopN
gitextract_97iggw1v/
├── .gitignore
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── cpa/
│ ├── __init__.py
│ ├── api.py
│ ├── data.py
│ ├── helper.py
│ ├── model.py
│ ├── plotting.py
│ └── train.py
├── datasets/
│ └── .gitkeep
├── notebooks/
│ └── demo.ipynb
├── preprocessing/
│ ├── GSM.ipynb
│ ├── Norman19.ipynb
│ ├── cross_species.ipynb
│ ├── lincs.ipynb
│ ├── pachter.ipynb
│ ├── sciplex3.ipynb
│ └── sciplex3_round_robin.ipynb
├── pretrained_models/
│ ├── .gitattributes
│ └── .gitkeep
├── requirements.txt
├── scripts/
│ ├── .gitkeep
│ ├── run_collect_results.sh
│ ├── run_one_epoch.sh
│ └── run_sweeps.sh
├── setup.py
└── tests/
└── test.py
SYMBOL INDEX (103 symbols across 7 files)
FILE: cpa/api.py
class API (line 23) | class API:
method __init__ (line 28) | def __init__(
method load_from_old (line 196) | def load_from_old(self, pretrained):
method print_args (line 210) | def print_args(self):
method load (line 213) | def load(self, pretrained):
method train (line 226) | def train(
method save (line 355) | def save(self, filename):
method _init_pert_embeddings (line 366) | def _init_pert_embeddings(self):
method get_drug_embeddings (line 378) | def get_drug_embeddings(self, dose=1.0, return_anndata=True):
method _init_covars_embeddings (line 409) | def _init_covars_embeddings(self):
method get_covars_embeddings_combined (line 438) | def get_covars_embeddings_combined(self, return_anndata=True):
method get_covars_embeddings (line 456) | def get_covars_embeddings(self, covars_tgt, return_anndata=True):
method _get_drug_encoding (line 477) | def _get_drug_encoding(self, drugs, doses=None):
method mix_drugs (line 506) | def mix_drugs(self, drugs_list, doses_list=None, return_anndata=True):
method latent_dose_response (line 550) | def latent_dose_response(
method latent_dose_response2D (line 613) | def latent_dose_response2D(
method compute_comb_emb (line 687) | def compute_comb_emb(self, thrh=30):
method compute_uncertainty (line 746) | def compute_uncertainty(self, cov, pert, dose, thrh=30):
method predict (line 803) | def predict(
method get_latent (line 973) | def get_latent(
method get_response (line 1073) | def get_response(
method get_response_reference (line 1164) | def get_response_reference(self, perturbations=None):
method get_response2D (line 1222) | def get_response2D(
method evaluate_r2 (line 1332) | def evaluate_r2(self, dataset, genes_control, adata_random=None):
function get_reference_from_combo (line 1432) | def get_reference_from_combo(perturbations_list, datasets, splits=["trai...
function linear_interp (line 1466) | def linear_interp(y1, y2, x1, x2, x):
function evaluate_r2_benchmark (line 1473) | def evaluate_r2_benchmark(cpa_api, datasets, pert_category, pert_categor...
FILE: cpa/data.py
function ranks_to_df (line 18) | def ranks_to_df(data, key="rank_genes_groups"):
function check_adata (line 42) | def check_adata(adata, special_fields):
class Dataset (line 65) | class Dataset:
method __init__ (line 66) | def __init__(
method subset (line 292) | def subset(self, split, condition="all"):
method __getitem__ (line 296) | def __getitem__(self, i):
method __len__ (line 303) | def __len__(self):
class SubDataset (line 307) | class SubDataset:
method __init__ (line 312) | def __init__(self, dataset, indices):
method __getitem__ (line 339) | def __getitem__(self, i):
method subset_condition (line 346) | def subset_condition(self, control=True):
method __len__ (line 350) | def __len__(self):
function load_dataset_splits (line 354) | def load_dataset_splits(
FILE: cpa/helper.py
function _convert_mean_disp_to_counts_logits (line 21) | def _convert_mean_disp_to_counts_logits(mu, theta, eps=1e-6):
function rank_genes_groups_by_cov (line 45) | def rank_genes_groups_by_cov(
function rank_genes_groups (line 136) | def rank_genes_groups(
function evaluate_r2_ (line 254) | def evaluate_r2_(adata, pred_adata, condition_key, sampled=False, de_gen...
function evaluate_mmd (line 295) | def evaluate_mmd(adata, pred_adata, condition_key, de_genes_dict=None):
function evaluate_emd (line 321) | def evaluate_emd(adata, pred_adata, condition_key, de_genes_dict=None):
function pairwise_distance (line 354) | def pairwise_distance(x, y):
function gaussian_kernel_matrix (line 362) | def gaussian_kernel_matrix(x, y, alphas):
function mmd_loss_calc (line 387) | def mmd_loss_calc(source_features, target_features):
FILE: cpa/model.py
class NBLoss (line 11) | class NBLoss(torch.nn.Module):
method __init__ (line 12) | def __init__(self):
method forward (line 15) | def forward(self, mu, y, theta, eps=1e-8):
function _nan2inf (line 46) | def _nan2inf(x):
class MLP (line 49) | class MLP(torch.nn.Module):
method __init__ (line 54) | def __init__(self, sizes, batch_norm=True, last_layer_act="linear"):
method forward (line 77) | def forward(self, x):
class GeneralizedSigmoid (line 85) | class GeneralizedSigmoid(torch.nn.Module):
method __init__ (line 91) | def __init__(self, dim, device, nonlin="sigmoid"):
method forward (line 107) | def forward(self, x):
method one_drug (line 117) | def one_drug(self, x, i):
class CPA (line 128) | class CPA(torch.nn.Module):
method __init__ (line 133) | def __init__(
method set_hparams_ (line 286) | def set_hparams_(self, hparams):
method move_inputs_ (line 322) | def move_inputs_(self, genes, drugs, covariates):
method compute_drug_embeddings_ (line 334) | def compute_drug_embeddings_(self, drugs):
method predict (line 349) | def predict(
method early_stopping (line 403) | def early_stopping(self, score):
method update (line 419) | def update(self, genes, drugs, covariates):
method defaults (line 509) | def defaults(self):
FILE: cpa/plotting.py
class CPAVisuals (line 23) | class CPAVisuals:
method __init__ (line 30) | def __init__(
method plot_latent_embeddings (line 90) | def plot_latent_embeddings(
method plot_contvar_response2D (line 165) | def plot_contvar_response2D(
method plot_contvar_response (line 274) | def plot_contvar_response(
method plot_scatter (line 366) | def plot_scatter(
function log10_with0 (line 435) | def log10_with0(x):
function get_palette (line 441) | def get_palette(n_colors, palette_name="Set1"):
function fast_dimred (line 454) | def fast_dimred(emb, method="KernelPCA"):
function plot_dose_response (line 478) | def plot_dose_response(
function plot_uncertainty_comb_dose (line 640) | def plot_uncertainty_comb_dose(
function plot_uncertainty_dose (line 771) | def plot_uncertainty_dose(
function save_to_file (line 882) | def save_to_file(fig, file_name, file_format=None):
function plot_embedding (line 897) | def plot_embedding(
function get_colors (line 1011) | def get_colors(labels, palette=None, palette_name=None):
function plot_similarity (line 1019) | def plot_similarity(
function mean_plot (line 1074) | def mean_plot(
function plot_r2_matrix (line 1211) | def plot_r2_matrix(pred, adata, de_genes=None, **kwds):
function arrange_history (line 1268) | def arrange_history(history):
class CPAHistory (line 1273) | class CPAHistory:
method __init__ (line 1278) | def __init__(self, cpa_api, fileprefix=None):
method print_time (line 1326) | def print_time(self):
method plot_losses (line 1329) | def plot_losses(self, filename=None):
method plot_r2_metrics (line 1363) | def plot_r2_metrics(self, epoch_min=0, filename=None):
method plot_disentanglement_metrics (line 1409) | def plot_disentanglement_metrics(self, epoch_min=0, filename=None):
FILE: cpa/train.py
function pjson (line 19) | def pjson(s):
function _convert_mean_disp_to_counts_logits (line 25) | def _convert_mean_disp_to_counts_logits(mu, theta, eps=1e-6):
function evaluate_disentanglement (line 48) | def evaluate_disentanglement(autoencoder, dataset):
function evaluate_r2 (line 117) | def evaluate_r2(autoencoder, dataset, genes_control):
function evaluate (line 199) | def evaluate(autoencoder, datasets):
function prepare_cpa (line 243) | def prepare_cpa(args, state_dict=None):
function train_cpa (line 277) | def train_cpa(args, return_model=False):
function parse_arguments (line 368) | def parse_arguments():
FILE: tests/test.py
function sim_adata (line 10) | def sim_adata():
Condensed preview — 30 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (3,055K chars).
[
{
"path": ".gitignore",
"chars": 55,
"preview": "__pycache__\n*.pyc\n*.egg-info/\n*.ipynb_checkpoints/\n*.pt"
},
{
"path": "CODE_OF_CONDUCT.md",
"chars": 244,
"preview": "# Code of Conduct\n\nFacebook has adopted a Code of Conduct that we expect project participants to adhere to.\nPlease read "
},
{
"path": "CONTRIBUTING.md",
"chars": 1248,
"preview": "# Contributing to `CPA` \nWe want to make contributing to this project as easy and transparent as\npossible.\n\n## Pull Requ"
},
{
"path": "LICENSE",
"chars": 1090,
"preview": "The MIT License\n\nCopyright (c) Facebook, Inc. and its affiliates.\n\nPermission is hereby granted, free of charge, to any "
},
{
"path": "README.md",
"chars": 3995,
"preview": "# CPA - Compositional Perturbation Autoencoder\n\n# This code in not being maintained anymore, please use the new implemen"
},
{
"path": "cpa/__init__.py",
"chars": 132,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved\n\nfrom cpa.api import API\nfrom cpa.plotting import"
},
{
"path": "cpa/api.py",
"chars": 56512,
"preview": "import copy\nimport itertools\nimport os\nimport pprint\nimport time\nfrom collections import defaultdict\nfrom typing import "
},
{
"path": "cpa/data.py",
"chars": 13673,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved\n\nimport warnings\n\nimport numpy as np\nimport torch"
},
{
"path": "cpa/helper.py",
"chars": 13590,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved\n\nimport warnings\n\nimport numpy as np\nimport panda"
},
{
"path": "cpa/model.py",
"chars": 18503,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved\n\nfrom http.client import RemoteDisconnected\nimpor"
},
{
"path": "cpa/plotting.py",
"chars": 45462,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved\n\nfrom collections import defaultdict\n\nimport matp"
},
{
"path": "cpa/train.py",
"chars": 14255,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved\n\nimport argparse\nimport json\nimport os\nimport tim"
},
{
"path": "datasets/.gitkeep",
"chars": 0,
"preview": ""
},
{
"path": "notebooks/demo.ipynb",
"chars": 726674,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"metadata\": {},\n \"source\": [\n \"# A tour of the CPA model\"\n ]\n }"
},
{
"path": "preprocessing/GSM.ipynb",
"chars": 1922268,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"code\",\n \"execution_count\": null,\n \"metadata\": {},\n \"outputs\": [],\n \"source\": "
},
{
"path": "preprocessing/Norman19.ipynb",
"chars": 52957,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"code\",\n \"execution_count\": 1,\n \"metadata\": {},\n \"outputs\": [\n {\n \"name\":"
},
{
"path": "preprocessing/lincs.ipynb",
"chars": 27413,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"code\",\n \"execution_count\": 1,\n \"metadata\": {},\n \"outputs\": [\n {\n \"name\":"
},
{
"path": "preprocessing/pachter.ipynb",
"chars": 48280,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"code\",\n \"execution_count\": 1,\n \"metadata\": {},\n \"outputs\": [\n {\n \"name\":"
},
{
"path": "preprocessing/sciplex3.ipynb",
"chars": 39478,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"code\",\n \"execution_count\": 1,\n \"id\": \"satellite-immigration\",\n \"metadata\": {},\n"
},
{
"path": "preprocessing/sciplex3_round_robin.ipynb",
"chars": 23541,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"code\",\n \"execution_count\": 1,\n \"metadata\": {},\n \"outputs\": [\n {\n \"name\":"
},
{
"path": "pretrained_models/.gitattributes",
"chars": 69,
"preview": "Norman2010_prep_new_deg_collect/ filter=lfs diff=lfs merge=lfs -text\n"
},
{
"path": "pretrained_models/.gitkeep",
"chars": 0,
"preview": ""
},
{
"path": "requirements.txt",
"chars": 91,
"preview": "adjustText\nargparse\nmatplotlib\nnumpy\npandas\nscanpy \nscipy \nseaborn \nsklearn\nsubmitit\ntorch\n"
},
{
"path": "scripts/.gitkeep",
"chars": 0,
"preview": ""
},
{
"path": "scripts/run_collect_results.sh",
"chars": 2146,
"preview": "#!/bin/bash\n\npython -m cpa.collect_results --save_dir /checkpoint/$USER/sweep_GSM_new_logsigm\npython -m cpa.collect_"
},
{
"path": "scripts/run_one_epoch.sh",
"chars": 849,
"preview": "#!/bin/bash\n\n# change the path vairable to your path to the datasets folder\npath='../cpa_binaries'\n\npython -m cpa.train "
},
{
"path": "scripts/run_sweeps.sh",
"chars": 3602,
"preview": "#!/bin/bash\n\n# rm -rf /checkpoint/$USER/sweep_GSM_2k_hvg\n# rm -rf /checkpoint/$USER/sweep_GSM_4k_hvg\n# rm -rf /checkpoin"
},
{
"path": "setup.py",
"chars": 542,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved\n\nfrom distutils.core import setup\n\next_modules = "
},
{
"path": "tests/test.py",
"chars": 885,
"preview": "import sys\n\nsys.path.append(\"../\")\nimport cpa\nimport scanpy as sc\nimport scvi\nfrom cpa.helper import rank_genes_groups_b"
}
]
// ... and 1 more files (download for full content)
About this extraction
This page contains the full source code of the facebookresearch/CPA GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 30 files (36.5 MB), approximately 755.7k tokens, and a symbol index with 103 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.