" """ A generic description appended to model uploads that are automatically uploaded to zenodo via Zenodo API call in medigan""" ZENODO_GENERIC_MODEL_DESCRIPTION = ( f"
Usage:
This GAN is used as part of the medigan library. " f"This GANs metadata is therefore stored in and retrieved from medigan's " f"config file. medigan " f"is an open-source Python library on Github that allows developers and " f"researchers to easily add synthetic imaging data into their model training pipelines. medigan is documented " f"here and can be used via pip install:
" f"pip install medigan To run this model in medigan, use the following commands.
" f" from medigan import Generators "
f" generators = Generators() "
f" generators.generate(model_id='YOUR_MODEL_ID',num_samples=10)" ) """ The REST API to interact with Zenodo """ ZENODO_API_URL = "https://zenodo.org/api/deposit/depositions" # "https://sandbox.zenodo.org/api/deposit/depositions" """ The HEADER for Zenodo REST API requests""" ZENODO_HEADERS = {"Content-Type": "application/json"} """ The title of the Github Issue when adding a model to medigan""" GITHUB_TITLE = "Model Integration Request for medigan" """ The repository of the Github Issue when adding a model to medigan""" GITHUB_REPO = "RichardObi/medigan" # "RichardObi/medigan-models" """ The assignee of the Github Issue when adding a model to medigan""" GITHUB_ASSIGNEE = "RichardObi" ================================================ FILE: src/medigan/contribute_model/__init__.py ================================================ ================================================ FILE: src/medigan/contribute_model/base_model_uploader.py ================================================ # -*- coding: utf-8 -*- # ! /usr/bin/env python """Base Model uploader class that uploads models to medigan associated data storage services. """ from __future__ import absolute_import class BaseModelUploader: """`BaseModelUploader` class: Uploads a user's model and metadata to third party storage to allow its inclusion into the medigan library. 
Parameters ---------- model_id: str The generative model's unique id metadata: dict The model's corresponding metadata Attributes ---------- model_id: str The generative model's unique id metadata: dict The model's corresponding metadata """ def __init__( self, model_id: str, metadata: dict, ): self.model_id = model_id self.metadata = metadata def push(self): raise NotImplementedError def __repr__(self): return f"BaseModelUploader(model_id={self.model_id}, metadata={self.metadata})" def __len__(self): raise NotImplementedError def __getitem__(self, idx: int): raise NotImplementedError ================================================ FILE: src/medigan/contribute_model/github_model_uploader.py ================================================ # -*- coding: utf-8 -*- # ! /usr/bin/env python """Github Model uploader class that uploads the metadata of a new model to the medigan github repository. """ from __future__ import absolute_import import json import logging from github import Github from ..constants import ( CONFIG_FILE_KEY_EXECUTION, CONFIG_FILE_KEY_PACKAGE_LINK, GITHUB_ASSIGNEE, GITHUB_REPO, GITHUB_TITLE, ) from ..utils import Utils from .base_model_uploader import BaseModelUploader class GithubModelUploader(BaseModelUploader): """`GithubModelUploader` class: Pushes the metadata of a user's model to the medigan repo, where it creates a dedicated github issue. 
Parameters ---------- model_id: str The generative model's unique id access_token: str a personal access token linked to your github user account, used as means of authentication Attributes ---------- model_id: str The generative model's unique id access_token: str a personal access token linked to your github user account, used as means of authentication """ def __init__( self, model_id: str, access_token: str, ): self.model_id = model_id self.access_token = access_token def push( self, metadata: dict, package_link: str = None, creator_name: str = "n.a.", creator_affiliation: str = "n.a.", model_description: str = "n.a.", ): """Upload the model's metadata inside a github issue to the medigan github repository. To add your model to medigan, your metadata will be reviewed on Github and added to medigan's official model metadata The medigan repository issues page: https://github.com/RichardObi/medigan/issues Get your Github access token here: https://github.com/settings/tokens Parameters ---------- metadata: dict The model's corresponding metadata package_link: a package link creator_name: str the creator name that will appear on the corresponding github issue creator_affiliation: str the creator affiliation that will appear on the corresponding github issue model_description: list the model_description that will appear on the corresponding github issue Returns ------- str Returns the url pointing to the corresponding issue on github """ # Check if the package_link is already in the metadata. If not, add it to metadata. 
metadata = self.add_package_link_to_metadata( metadata=metadata, package_link=package_link ) # First use pyGithub to create a Github instance based on san access token g = Github(self.access_token) repo = g.get_repo(GITHUB_REPO) logging.debug(f"Repo: {repo}") # Create metadata for github issue title = f"{GITHUB_TITLE}: {self.model_id}" line_break = "\n" body = ( f"{line_break + '**Creator:** ' if creator_name != '' else ''}{creator_name}" f"{line_break + '**Affiliation:** ' if creator_affiliation != '' else ''}{creator_affiliation}" f"{line_break + '**Description:** ' if model_description != '' else ''}{model_description}" f"{line_break} **Stored in:** {package_link}" f"{line_break} ### Model Metadata: {line_break} ```json{json.dumps(metadata, indent=3)}" ) # As a logged-in github user, let's now push to medigan repo using pyGithub github_issue = repo.create_issue( title=title, body=body, assignee=GITHUB_ASSIGNEE ) logging.info( f"{self.model_id}: Successfully created github issue in '{GITHUB_REPO}': '{github_issue.html_url}'" ) return github_issue.html_url def add_package_link_to_metadata( self, metadata: dict, package_link: str = None, is_update_forced: bool = False ) -> dict: """Update `package_link` in the model's metadata if current `package_link` does not containing a valid url. Parameters ---------- metadata: dict The model's corresponding metadata package_link: str the new package link to used to replace the old one is_update_forced: bool flag to update metadata even though metadata already contains a valid url in its `package_link` Returns ------- dict Returns the updated metadata dict with replaced `package_link` if replacement was applicable. 
""" current_pl = None try: # Get the package link from the metadata object current_pl = metadata[self.model_id][CONFIG_FILE_KEY_EXECUTION][ CONFIG_FILE_KEY_PACKAGE_LINK ] except Exception as e: logging.debug( f"{self.model_id}: Package Link could not be located in metadata for key {self.model_id}.{CONFIG_FILE_KEY_EXECUTION}.{CONFIG_FILE_KEY_PACKAGE_LINK}: {e}" ) # Check if the package link in the metadata contains a valid URL if ( current_pl is not None and not is_update_forced and ( (Utils.is_url_valid(current_pl) and current_pl.startswith("http")) or current_pl.startswith("models/") ) ): # If there is already a valid (non-local) url to a zip file, # we assume that this URL is validly pointing to the model. # Note: The package link can start with models/ indicating that the model is hosted directly in medigan # instead of Zenodo, see model 00007 for an example. pass else: # We update the metadata with the retrieved package_link. Note: Also, in case the metadata points to a path on a # user's machine, we avoid publishing that path to github issue by making this update. try: metadata[self.model_id][CONFIG_FILE_KEY_EXECUTION][ CONFIG_FILE_KEY_PACKAGE_LINK ] = package_link except Exception as e: logging.warning( f"{self.model_id}: Package Link could not be update in metadata for key {self.model_id}.{CONFIG_FILE_KEY_EXECUTION}.{CONFIG_FILE_KEY_PACKAGE_LINK}: {e}" ) logging.debug( f"{self.model_id}: Before creating github issue, updated package link from '{current_pl}' to '{package_link}'" ) return metadata def __repr__(self): return ( f"GithubModelUploader(model_id={self.model_id}, metadata={self.metadata})" ) def __len__(self): raise NotImplementedError def __getitem__(self, idx: int): raise NotImplementedError ================================================ FILE: src/medigan/contribute_model/model_contributor.py ================================================ # -*- coding: utf-8 -*- # ! 
/usr/bin/env python """Model contributor class that tests models, creates metadata entries, uploads and contributes them to medigan. """ from __future__ import absolute_import import importlib import logging import sys from pathlib import Path from ..constants import ( CONFIG_FILE_KEY_DEPENDENCIES, CONFIG_FILE_KEY_EXECUTION, CONFIG_FILE_KEY_GENERATE, CONFIG_FILE_KEY_GENERATE_NAME, CONFIG_FILE_KEY_MODEL_EXTENSION, CONFIG_FILE_KEY_MODEL_NAME, CONFIG_FILE_KEY_PACKAGE_LINK, CONFIG_FILE_KEY_PACKAGE_NAME, CONFIG_TEMPLATE_FILE_NAME_AND_EXTENSION, CONFIG_TEMPLATE_FILE_URL, INIT_PY_FILE, TEMPLATE_FOLDER, ) from ..utils import Utils from .github_model_uploader import GithubModelUploader from .zenodo_model_uploader import ZenodoModelUploader class ModelContributor: """`ModelContributor` class: Contributes a user's local model to the public medigan library Parameters ---------- model_id: str The generative model's unique id init_py_path: str The path to the local model's `__init__.py` file needed for importing and running this model. Attributes ---------- model_id: str The generative model's unique id init_py_path: str The path to the local model's __init__.py file needed for importing and running this model. package_path: str Path as string to the generative model's python package package_name: str Name of the model's python package i.e. the name of the model's zip file and unzipped package folder metadata_file_path: str Path as string to the generative model's metadata file e.g. default is relative path to package root. zenodo_model_uploader: str An instance of the `ZenodoModelUploader` class github_model_uploader: str An instance of the `GithubModelUploader` class. 
""" def __init__( self, model_id: str, init_py_path: str, ): self.validate_model_id(model_id) self.model_id = model_id self.init_py_path = init_py_path self.validate_init_py_path(init_py_path) self.package_path = self.init_py_path.replace(INIT_PY_FILE, "") self.package_name = Path(self.package_path).name self.metadata_file_path = "" # Default is relative path to package root. self.validate_local_model_import() self.zenodo_model_uploader = None self.github_model_uploader = None ############################ VALIDATION ############################ def validate_model_id( self, model_id: str, max_chars: int = 30, min_chars: int = 13 ) -> bool: """Asserts if the `model_id` is in the correct format and has a valid length Parameters ---------- model_id: str The generative model's unique id max_chars: int the maximum of chars allowed in the model_id min_chars: int the minimum of chars allowed in the model_id Returns ------- bool Returns flag indicating whether the `model_id` is correctly formatted. """ num_chars = len(model_id) assert ( num_chars <= max_chars ), f"The model_id {model_id} is too large ({num_chars}). Please reduce to a maximum of {max_chars} characters. Format Convention: '00001_GANTYPE_MODALITY'" assert ( num_chars >= min_chars ), f"The model_id {model_id} is too small ({num_chars}). Please reduce to a minimum of {min_chars} characters. Format Convention: '00001_GANTYPE_MODALITY'" for i in range(5): assert model_id[ i ].isdigit(), f"Your model_id's ({model_id}) character '{model_id[i]}' at position {i} is not a digit. The first 5 characters should be digits as in '00001_GANTYPE_MODALITY'. Please adjust." logging.info( f"The provided model_id is valid and will now be used to refer to the contributed model in medigan: {model_id}" ) return True def validate_init_py_path(self, init_py_path) -> bool: """Asserts whether the `init_py_path` exists and points to a valid `__init__.py` correct file. 
Parameters ---------- init_py_path: str The path to the local model's __init__.py file needed for importing and running this model. """ assert ( Path(init_py_path).exists() and Path(init_py_path).is_file() ), f"{self.model_id}: The path to your model's __init__.py function does not exist or does not point to a file. Please revise path {init_py_path}. Note: You can find an __init__.py example in https://github.com/RichardObi/medigan/tree/main/templates" assert Utils.is_file_in( folder_path=self.init_py_path.replace(f"/{INIT_PY_FILE}", ""), filename=INIT_PY_FILE, ), f"{self.model_id}: No __init__.py was found in your path {init_py_path}. Please revise. Note: You can find an __init__.py example in /templates in https://github.com/RichardObi/medigan" logging.info( f"The provided path to your model's __init__.py function was valid and points to a __init__.py file: {init_py_path}" ) return True def validate_and_update_model_weights_path(self) -> dict: """Check if the model files can be found in the `package_path` or based on the `path_to_metadata`. Ideally, the user provided `package_path` and the `path_to_metadata` should both point to the same model package containing weights, config, license, etc. Here we check both of these paths to find the model weights. 
Returns ------- dict Returns the metadata after updating the path to the model's checkpoint's weights """ metadata_dir_path = Path(self.metadata_file_path).parent potential_weight_paths: list = [] execution_metadata = self.metadata[self.model_id][CONFIG_FILE_KEY_EXECUTION] # package_path + package_path + file + extension try: potential_weight_paths.append( Path( self.package_path + f"/{execution_metadata[CONFIG_FILE_KEY_MODEL_NAME]}{execution_metadata[CONFIG_FILE_KEY_MODEL_EXTENSION]}" ) ) except KeyError as e: raise e # metadata_dir + package_path + file + extension try: potential_weight_paths.append( Path( str(metadata_dir_path) + f"/{execution_metadata[CONFIG_FILE_KEY_MODEL_NAME]}{execution_metadata[CONFIG_FILE_KEY_MODEL_EXTENSION]}" ) ) except KeyError as e: raise e # metadata_dir + package_path + file + extension try: potential_weight_paths.append( Path( str(metadata_dir_path) + "/" + self.package_path + f"/{execution_metadata[CONFIG_FILE_KEY_MODEL_NAME]}{execution_metadata[CONFIG_FILE_KEY_MODEL_EXTENSION]}" ) ) except KeyError as e: raise e for potential_weight_path in potential_weight_paths: if potential_weight_path.is_file(): # Checking if there is a weights/checkpoint (model name + extension) file in the package /metadata path self.package_path = str( Path(potential_weight_path).parent.resolve(strict=False) ) # strict=False, as models might be not on user's disc. self.metadata[self.model_id][CONFIG_FILE_KEY_EXECUTION][ CONFIG_FILE_KEY_PACKAGE_LINK ] = self.package_path logging.info( f"The model weights path is valid and was added to the metadata of your model: {self.package_path}" ) return self.metadata raise FileNotFoundError( f"{self.model_id}: Error validating metadata. There was no valid model weights file found. Please revise. 
Tested paths: '{potential_weight_paths}'" ) def validate_local_model_import(self): """Check if the model package in the `package_path` can be imported as python library using importlib.""" # Validation: Import module as python library to check if generate function is inside the # path_to_script_w_generate_function python file and no errors occur. try: sys.path.insert(1, str(self.package_path).replace(self.package_name, "")) importlib.import_module(name=self.package_name) logging.info( f"Model import test successful: The model was successfully imported using importlib: {self.package_name}" ) except Exception as e: raise Exception( f"{self.model_id}: Error while testing importlib model import. Is your {INIT_PY_FILE} erroneous? " f"Please revise if the provided path ({self.init_py_path}) is valid and accessible and try again. " f"Exception: {e}" ) from e ############################ UPLOAD ############################ def push_to_zenodo( self, access_token: str, creator_name: str, creator_affiliation: str, model_description: str = "", ): """Upload the model files as zip archive to a public Zenodo repository where the model will be persistently stored. 
Get your Zenodo access token here: https://zenodo.org/account/settings/applications/tokens/new/ (Enable scopes `deposit:actions` and `deposit:write`) Parameters ---------- access_token: str a personal access token in Zenodo linked to a user account for authentication creator_name: str the creator name that will appear on the corresponding Zenodo model upload homepage creator_affiliation: str the creator affiliation that will appear on the corresponding Zenodo model upload homepage model_description: list the model_description that will appear on the corresponding Zenodo model upload homepage Returns ------- str Returns the url pointing to the corresponding Zenodo model upload homepage """ if self.zenodo_model_uploader is None: self.zenodo_model_uploader = ZenodoModelUploader( model_id=self.model_id, access_token=access_token ) # Update in case previous access token gave an error self.zenodo_model_uploader.access_token = access_token return self.zenodo_model_uploader.push( metadata=self.metadata, package_path=self.package_path, package_name=self.package_name, creator_name=creator_name, creator_affiliation=creator_affiliation, model_description=model_description, ) def push_to_github( self, access_token: str, package_link: str = None, creator_name: str = "", creator_affiliation: str = "", model_description: str = "", ): """Upload the model's metadata inside a github issue to the medigan github repository. 
To add your model to medigan, your metadata will be reviewed on Github and added to medigan's official model metadata The medigan repository issues page: https://github.com/RichardObi/medigan/issues Get your Github access token here: https://github.com/settings/tokens Parameters ---------- access_token: str a personal access token linked to your github user account, used as means of authentication package_link: a package link creator_name: str the creator name that will appear on the corresponding github issue creator_affiliation: str the creator affiliation that will appear on the corresponding github issue model_description: list the model_description that will appear on the corresponding github issue Returns ------- str Returns the url pointing to the corresponding issue on github """ if self.github_model_uploader is None: self.github_model_uploader = GithubModelUploader( model_id=self.model_id, access_token=access_token ) # Update in case previous access token gave an error self.github_model_uploader.access_token = access_token return self.github_model_uploader.push( metadata=self.metadata, package_link=package_link, creator_name=creator_name, creator_affiliation=creator_affiliation, model_description=model_description, ) ############################ METADATA ############################ def load_metadata_template(self) -> dict: """Loads and parses (json to dict) a default medigan metadata template. 
Returns ------- dict Returns the metadata template as dict """ path_to_metadata_template = Path( f"{TEMPLATE_FOLDER}/{CONFIG_TEMPLATE_FILE_NAME_AND_EXTENSION}" ) Utils.mkdirs(TEMPLATE_FOLDER) Utils.is_file_located_or_downloaded( download_link=CONFIG_TEMPLATE_FILE_URL, path_as_string=path_to_metadata_template, ) metadata_template = Utils.read_in_json(path_as_string=path_to_metadata_template) if self.model_id is not None: # Replacing the placeholder id of template with model_id metadata_template[self.model_id] = metadata_template[ list(metadata_template)[0] ] del metadata_template[list(metadata_template)[0]] return metadata_template def add_metadata_from_file(self, metadata_file_path) -> dict: """Read and parse the metadata of a local model, identified by `model_id`, from a metadata file in json format. Parameters ---------- model_id: str The generative model's unique id metadata_file_path: str the path pointing to the metadata file Returns ------- dict Returns a dict containing the contents of parsed metadata json file. """ if Path(metadata_file_path).is_file(): self.metadata = Utils.read_in_json(path_as_string=metadata_file_path) self.metadata_file_path = metadata_file_path elif Path(metadata_file_path + "/metadata.json").is_file(): self.metadata = Utils.read_in_json( path_as_string=metadata_file_path + "/metadata.json" ) self.metadata_file_path = metadata_file_path + "/metadata.json" else: raise FileNotFoundError( f"{self.model_id}: No metadata json file was found in the path you provided ({metadata_file_path}). " f"If you do not have a metadata file, create one using the add_metadata_from_input() function." 
) self.validate_and_update_model_weights_path() return self.metadata def add_metadata_from_input( self, model_weights_name: str = None, model_weights_extension: str = None, generate_method_name: str = None, dependencies: list = [], fill_more_fields_interactively: bool = True, output_path: str = "config", ): """Create a metadata dict for a local model, identified by `model_id`, given the necessary minimum metadata contents. Parameters ---------- model_id: str The generative model's unique id model_weights_name: str the name of the checkpoint file containing the model's weights model_weights_extension: str the extension (e.g. .pt) of the checkpoint file containing the model's weights generate_method_name: str the name of the sample generation method inside the models __init__.py file dependencies: list the list of dependencies that need to be installed via pip to run the model. fill_more_fields_interactively: bool flag indicating whether a user will be interactively asked via command line for further input to fill out missing metadata content output_path: str the path where the created metadata json file will be stored. Returns ------- dict Returns a dict containing the contents of the metadata json file. """ # Get the metadata template to guide data structure and formatting of metadata. 
self.metadata_template = self.load_metadata_template() # Generate metadata with variables provided as parameters metadata = self.metadata_template[self.model_id][CONFIG_FILE_KEY_EXECUTION] metadata.update({CONFIG_FILE_KEY_PACKAGE_LINK: self.package_path}) metadata.update({CONFIG_FILE_KEY_PACKAGE_NAME: self.package_name}) metadata.update({CONFIG_FILE_KEY_MODEL_NAME: model_weights_name}) metadata.update({CONFIG_FILE_KEY_MODEL_EXTENSION: model_weights_extension}) metadata.update({CONFIG_FILE_KEY_DEPENDENCIES: dependencies}) metadata[CONFIG_FILE_KEY_GENERATE][ CONFIG_FILE_KEY_GENERATE_NAME ] = generate_method_name metadata_final = self.metadata_template metadata_final[self.model_id].update({CONFIG_FILE_KEY_EXECUTION: metadata}) Utils.store_dict_as( dictionary=metadata_final, extension=".json", output_path=output_path, filename=self.model_id, ) logging.info( f"{self.model_id}: Your model's metadata was stored in {output_path}." ) if fill_more_fields_interactively: # Add more information to the metadata dict via user prompts metadata_final = self._recursively_fill_metadata(metadata=metadata_final) # Store again as additional fields should have now been filled Utils.store_dict_as( dictionary=metadata_final, extension=".json", output_path=output_path, filename=self.model_id, ) logging.info( f"{self.model_id}: Your model's metadata was updated. Find it in {output_path}/{self.model_id}.json" ) self.metadata = metadata_final self.validate_and_update_model_weights_path() return self.metadata def is_value_for_key_already_set( self, key: str, metadata: dict, nested_key ) -> bool: """Check if the value of a `key` in a `metadata` dictionary is already set and e.g. not an empty string, dict or list. 
Parameters ---------- key: str The key in the currently traversed part of the model's metadata dictionary metadata: dict The currently traversed part of the model's metadata dictionary nested_key: str the `nested_key` indicates which subpart of the model's metadata we are currently traversing Returns ------- bool Flag indicating whether a value exists for the `key` in the dict """ if ( metadata.get(key) is None or metadata.get(key) == "" or (isinstance(metadata.get(key), list) and not metadata.get(key)) or isinstance(metadata.get(key), dict) ): # Note: If metadata.get(key) is referencing a dict, we always want to go inside the dict and add values. return False else: logging.debug( f"{self.model_id}: Key value pair ({key}:{metadata.get(key)}) already exists in metadata for key " f"'{nested_key}'. Not prompting user to insert value for this key." ) return True def _recursively_fill_metadata( self, metadata_template: dict = None, metadata: dict = {}, nested_key: str = "" ) -> dict: """Filling a model metadata template with values retrieved via user input prompts and by traversing nested dicts and list recursively. Parameters ---------- metadata_template: dict The template containing all keys expected in a model's metadata dictionary. metadata: dict The currently traversed part of the model's metadata dictionary nested_key: str the `nested_key` indicates which subpart of the model's metadata we are currently traversing Returns ------- dict The final fully filled metadata dictionary. """ if metadata_template is None: metadata_template = self.metadata_template # Prompt user for optional metadata input retrieved_nested_key = nested_key for key in metadata_template: # nested_key to know where we are inside the metadata dict. 
nested_key = ( key if retrieved_nested_key == "" else f"{retrieved_nested_key}.{key}" ) if not self.is_value_for_key_already_set( key=key, metadata=metadata, nested_key=nested_key ): value_template = metadata_template.get(key) if value_template is None: input_value = input( f"{self.model_id}: Please enter value of type float or int for your model for key '{nested_key}': " ) try: value_assigned = float(input_value.replace(",", ".")) except ValueError: value_assigned = ( int(input_value) if input_value.isdigit() else None ) elif isinstance(value_template, list): input_value = input( f"{self.model_id}: Please enter a comma-separated list of values for your model for key: '{nested_key}': " ) value_assigned = ( [value.strip() for value in input_value.split(",")] if input_value != "" else [] ) elif isinstance(value_template, str): value_assigned = str( input( f"{self.model_id}: Please enter value of type string for your model for key '{nested_key}': " ) ) elif isinstance(value_template, dict): if len(value_template) == 0: # If dict is empty, no recursion. Instead, we ask the user directly for input. iterations = int( input( f"{self.model_id}: How many key-value pairs do you want to nest below key '{nested_key}' " f"in your model's metadata. Type a number: " ) or "0" ) nested_metadata: dict = {} for i in range(iterations): nested_key_input = str( input(f"{self.model_id}: Enter key {i + 1}: ") ) nested_value_input = input( f"{self.model_id}: For key{i + 1}={nested_key_input}, enter value: " ) nested_metadata.update( {nested_key_input: nested_value_input} ) value_assigned = nested_metadata else: # From metadata, get the nested dict below the key. If metadata has no nested dict, get the # template's nested dict instead, which is stored in value_template temp_metadata = ( metadata.get(key) if metadata.get(key) is not None else value_template ) # Filling nested dicts via recursion. value_assigned is of type dict in this case. 
value_assigned = self._recursively_fill_metadata( metadata_template=value_template, nested_key=nested_key, metadata=temp_metadata, ) logging.debug( f"{self.model_id}: You provided this key-value pair: {key}={value_assigned}" ) metadata.update({key: value_assigned}) return metadata def __repr__(self): return f"ModelContributor(model_id={self.model_id}, metadata={self.metadata})" def __len__(self): raise NotImplementedError def __getitem__(self, idx: int): raise NotImplementedError ================================================ FILE: src/medigan/contribute_model/zenodo_model_uploader.py ================================================ # -*- coding: utf-8 -*- # ! /usr/bin/env python """Zenodo Model uploader class that uploads models to medigan associated data storage on Zenodo. """ from __future__ import absolute_import import json import logging import shutil from pathlib import Path import requests from ..constants import ( CONFIG_FILE_KEY_DESCRIPTION, CONFIG_FILE_KEY_SELECTION, CONFIG_FILE_KEY_TAGS, ZENODO_API_URL, ZENODO_GENERIC_MODEL_DESCRIPTION, ZENODO_HEADERS, ZENODO_LINE_BREAK, ) from .base_model_uploader import BaseModelUploader class ZenodoModelUploader(BaseModelUploader): """`ZenodoModelUploader` class: Uploads a user's model via API to Zenodo, here it is permanently stored with DOI. 
Parameters ---------- model_id: str The generative model's unique id access_token: str a personal access token in Zenodo linked to a user account for authentication Attributes ---------- model_id: str The generative model's unique id access_token: str a personal access token in Zenodo linked to a user account for authentication """ def __init__( self, model_id, access_token, ): self.model_id = model_id self.params = {"access_token": access_token} ############################ UPLOAD ############################ def create_upload_description( self, metadata: dict, model_description: str = "" ) -> str: """Create a string containing the textual description that will accompany the upload model files. The string contains tags and a text retrieved from the description subsection of the model metadata. Parameters ---------- metadata: dict The model's corresponding metadata model_description: str the model_description that will appear on the corresponding Zenodo model upload homepage Returns ------- str Returns the textual description of the model upload """ try: tags = f"{ZENODO_LINE_BREAK}
Tags:
{metadata[self.model_id][CONFIG_FILE_KEY_SELECTION][CONFIG_FILE_KEY_TAGS]}" except: tags = "" try: description_from_config = f"Description from model config:
: {json.dumps(metadata[self.model_id][CONFIG_FILE_KEY_DESCRIPTION])}" except: description_from_config = "" return f"{model_description}Model ID:
{self.model_id}. {ZENODO_LINE_BREAK}Uploaded via:
API {tags} {ZENODO_LINE_BREAK} {ZENODO_GENERIC_MODEL_DESCRIPTION.replace('YOUR_MODEL_ID', self.model_id)} {description_from_config} {ZENODO_LINE_BREAK}"

    def create_upload_json_data(
        self, creator_name: str, creator_affiliation: str, description: str = ""
    ) -> dict:
        """Create the descriptive data (dict) that is uploaded and stored alongside the model files.

        Parameters
        ----------
        creator_name: str
            the creator name that will appear on the corresponding Zenodo model upload homepage
        creator_affiliation: str
            the creator affiliation that will appear on the corresponding Zenodo model upload homepage
        description: str
            the model description that will appear on the corresponding Zenodo model upload homepage

        Returns
        -------
        dict
            A dict with a single "metadata" key holding the upload title, the upload type ("software"),
            the description, and a one-element list of creators.
        """
        return {
            "metadata": {
                "title": f"MEDIGAN MODEL UPLOAD: {self.model_id}",
                "upload_type": "software",
                "description": description,
                "creators": [
                    {
                        "name": f"{creator_name}",
                        "affiliation": f"{creator_affiliation}",
                    }
                ],
            }
        }

    def locate_or_create_model_zip_file(
        self, package_path: str, package_name: str
    ) -> "tuple[str, str]":
        """Return the model's zip archive, creating it first if `package_path` is not already a zip file.

        If `package_path` is an existing file ending in ".zip" it is used as is. Otherwise the
        directory at `package_path` is archived with `shutil.make_archive` into
        `<parent folder of package_path>/<package_name>.zip`.

        Parameters
        ----------
        package_path: str
            Path as string to the generative model's python package containing an `__init__.py` file,
            or to an already zipped package
        package_name: str
            Name of the model's python package i.e. the name of the model's zip file and unzipped package folder

        Returns
        -------
        tuple
            A tuple of two strings: the `filename` of the zipped python package and its `file_path`
        """
        # Check if zip file already exists
        if not (Path(package_path).is_file() and package_path.endswith(".zip")):
            # Create a zip archive containing the model package. Note that the archive is written next to
            # the package, i.e. into the PARENT folder of package_path (see base_name below).
            package_parent_path = str(Path(package_path).parent)
            logging.info(
                f"Archiving the model package as zip archive: base_name={package_parent_path+ '/' + package_name}, root_dir={package_path + '/'} "
            )
            # make_archive returns the full path of the created archive.
            filename = shutil.make_archive(
                base_name=package_parent_path + "/" + package_name,
                format="zip",
                root_dir=package_path,
            )
            file_path = filename
            filename = Path(file_path).name
        else:
            filename = Path(package_path).name
            file_path = package_path
        logging.info(
            f"Model was successfully archived as zip archive: filename={filename}, file_path={file_path} "
        )
        return filename, file_path

    def empty_upload(self) -> "requests.Response":
        """Create an empty placeholder deposition on Zenodo (upload step 1).

        This is required to retrieve a `deposition_id` and `bucket_url`, which are needed for the
        file upload and publishing in the subsequent upload steps.

        Returns
        -------
        requests.Response
            The response retrieved via the Zenodo API (status code 201 on success)

        Raises
        ------
        Exception
            If the Zenodo API does not answer with status code 201
        """
        # NOTE(review): self.params is set outside this view; presumably it carries the Zenodo access token.
        r = requests.post(
            ZENODO_API_URL,
            params=self.params,
            json={},
            headers=ZENODO_HEADERS,
        )
        # NOTE(review): r.json() in the error message raises itself if the error body is not JSON.
        if not r.status_code == 201:
            raise Exception(
                f"{self.model_id}: Error ({r.status_code}!=201) during Zenodo ('{ZENODO_API_URL}') upload (step 1: creating empty upload template): {r.json()}."
            )
        return r

    def upload(self, file_path: str, filename: str, bucket_url: str) -> "requests.Response":
        """Upload a file into the bucket of the Zenodo deposition of the model (upload step 2).

        Parameters
        ----------
        file_path: str
            The path of the file that is uploaded to Zenodo
        filename: str
            The name of the file that is uploaded to Zenodo
        bucket_url: str
            The bucket url used in the PUT request to upload the data file.

        Returns
        -------
        requests.Response
            The response retrieved via the Zenodo API (status code 200 or 201 on success)

        Raises
        ------
        Exception
            If the Zenodo API answers with a status code other than 200 or 201
        """
        with open(file_path, "rb") as fp:
            # The open file object is sent as the raw PUT body to "<bucket_url>/<filename>".
            r = requests.put(
                "%s/%s" % (bucket_url, filename),
                data=fp,
                params=self.params,
            )
        if not r.status_code == 200 and not r.status_code == 201:
            raise Exception(
                f"{self.model_id}: Error ({r.status_code}!= any of (200, 201) ) during Zenodo ('{bucket_url}') upload (step 2: uploading model as zip file): {r.json()}"
            )
        return r

    def upload_descriptive_data(self, deposition_id: str, data: dict) -> "requests.Response":
        """Upload textual descriptive data to the Zenodo entry of the uploaded model files (upload step 3).

        Parameters
        ----------
        deposition_id: str
            The deposition id assigned by Zenodo to the uploaded model file
        data: dict
            The descriptive information that will be uploaded to Zenodo and associated with the deposition_id
            (e.g. as created by `create_upload_json_data`)

        Returns
        -------
        requests.Response
            The response retrieved via the Zenodo API (status code 200 or 201 on success)

        Raises
        ------
        Exception
            If the Zenodo API answers with a status code other than 200 or 201
        """
        deposition_url = f"{ZENODO_API_URL}/{deposition_id}"
        # The metadata dict is serialised manually and sent with the JSON content-type header.
        r = requests.put(
            deposition_url,
            params=self.params,
            data=json.dumps(data),
            headers=ZENODO_HEADERS,
        )
        if not r.status_code == 200 and not r.status_code == 201:
            raise Exception(
                f"{self.model_id}: Error ({r.status_code}!= any of (200, 201) ) during Zenodo ('{deposition_url}') upload (step 3: updating metadata): {r.json()}"
            )
        return r

    def publish(self, deposition_id: str) -> "requests.Response":
        """Publish a Zenodo deposition (upload step 4) after explicit confirmation by the user.

        Publishing makes the upload official: it then becomes publicly accessible and persistently
        stored on Zenodo with an associated DOI. The user has to type exactly 'Yes' on stdin to proceed.

        Parameters
        ----------
        deposition_id: str
            The deposition id assigned by Zenodo to the uploaded model file

        Returns
        -------
        requests.Response
            The response retrieved via the Zenodo API (status code 202 on success)

        Raises
        ------
        Exception
            If the user does not type 'Yes', or if the Zenodo API does not answer with status code 202
        """
        # Get explicit user approval to publish on Zenodo. Published files cannot be deleted.
        is_user_sure = str(
            input(
                f"You are about to publish model {self.model_id} with Zenodo-ID {deposition_id} permanently on {ZENODO_API_URL.replace('/api/deposit/depositions','')}. To proceed, type 'Yes': "
            )
        )
        publish_url = f"{ZENODO_API_URL}/{deposition_id}/actions/publish"
        if is_user_sure == "Yes":
            r = requests.post(
                publish_url,
                params=self.params,
            )
        else:
            raise Exception(
                f"{self.model_id}: Error during Zenodo ('{publish_url}') upload (step 4: publishing uploaded model) due to user opt-out: You typed '{is_user_sure}' instead of 'Yes'. Model was not published. Try again. Your Zenodo deposition ID (if retrieved): '{deposition_id}'."
            )
        if not r.status_code == 202:
            raise Exception(
                f"{self.model_id}: Error ({r.status_code}!=202) during Zenodo ('{publish_url}') upload (step 4: publishing uploaded model): {r.json()}"
            )
        logging.info(
            f"{self.model_id}: Successfully pushed model to Zenodo with DOI '{r.json()['doi']}': '{r.json()['links']['record_html']}"
        )
        logging.debug(
            f"{self.model_id}: Full Zenodo API response after successful publishing of model: {r.json()}"
        )
        return r

    def push(
        self,
        metadata: dict,
        package_path: str,
        package_name: str,
        creator_name: str,
        creator_affiliation: str,
        model_description: str = "",
    ) -> str:
        """Upload the model files as zip archive to a public Zenodo repository where the model will be persistently stored.

        Runs the whole pipeline: locate/create the zip archive, create an empty deposition (step 1),
        upload the archive (step 2), upload the descriptive metadata (step 3) and publish (step 4).

        Get your Zenodo access token here: https://zenodo.org/account/settings/applications/tokens/new/
        (Enable scopes `deposit:actions` and `deposit:write`)

        Parameters
        ----------
        metadata: dict
            The model's corresponding metadata
        package_path: str
            The path to the packaged model files (package folder or an existing zip archive)
        package_name: str
            The name of the packaged model files
        creator_name: str
            the creator name that will appear on the corresponding Zenodo model upload homepage
        creator_affiliation: str
            the creator affiliation that will appear on the corresponding Zenodo model upload homepage
        model_description: str
            the model_description that will appear on the corresponding Zenodo model upload homepage

        Returns
        -------
        str
            Returns the url pointing to the corresponding Zenodo model upload homepage
        """
        # Check if zip file exists, else create new one for upload.
        filename, file_path = self.locate_or_create_model_zip_file(
            package_path=package_path, package_name=package_name
        )
        # create empty upload to Zenodo to get deposition_id and bucket_url
        response = self.empty_upload()
        logging.debug(f"API Response after creating empty upload template: {response}")
        # Get the deposition id from the response
        deposition_id = response.json()["id"]
        # Using bucket as defined by Zenodo API for zip file model upload
        bucket_url = response.json()["links"]["bucket"]
        logging.info(
            f"Starting Zenodo upload of model with deposition_id {deposition_id} to {bucket_url}"
        )
        response = self.upload(
            file_path=file_path,
            filename=filename,
            bucket_url=bucket_url,
        )
        logging.debug(
            f"API Response after uploading model to '{bucket_url}': {response}"
        )
        # get the model description i.e. model type, metadata info, etc.
        description = self.create_upload_description(
            metadata=metadata, model_description=model_description
        )
        # get the data that includes description, but also creator information
        data = self.create_upload_json_data(
            description=description,
            creator_name=creator_name,
            creator_affiliation=creator_affiliation,
        )
        # upload the descriptive data (the zip file itself was already uploaded above)
        response = self.upload_descriptive_data(deposition_id=deposition_id, data=data)
        logging.debug(
            f"API Response after uploading descriptive model data: {response}"
        )
        # publish to Zenodo. Model will get DOI after this step and become part of Zenodo's permanent record.
        response = self.publish(deposition_id=deposition_id)
        logging.debug(
            f"API Response after publishing the deposition {deposition_id} on Zenodo: {response}"
        )
        return response.json()["links"]["record_html"]  # zenodo_record_url

    def __repr__(self):
        return f"ZenodoModelUploader(model_id={self.model_id}, zenodo_url={ZENODO_API_URL})"

    def __len__(self):
        raise NotImplementedError

    def __getitem__(self, idx: int):
        raise NotImplementedError

================================================
FILE: src/medigan/exceptions.py
================================================
# -*- coding: utf-8 -*-
#! /usr/bin/env python
"""Custom exceptions to handle module specific errors and facilitate bug fixes and debugging."""

# TODO Add custom exceptions for improved exception handling.
# See requests.exceptions for reference.

================================================
FILE: src/medigan/execute_model/__init__.py
================================================

================================================
FILE: src/medigan/execute_model/install_model_dependencies.py
================================================
# -*- coding: utf-8 -*-
# ! /usr/bin/env python
""" Functionality for automated installation of a model's python package dependencies.
""" import argparse import subprocess import sys try: # if called as script (__main__) or from inside medigan from ..config_manager import ConfigManager from ..constants import CONFIG_FILE_KEY_DEPENDENCIES, CONFIG_FILE_KEY_EXECUTION except: # if called from outside medigan from medigan.config_manager import ConfigManager from medigan.constants import ( CONFIG_FILE_KEY_DEPENDENCIES, CONFIG_FILE_KEY_EXECUTION, ) def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument( "--model_id", type=str, default=None, nargs="+", help="Model ids to install dependencies for", ) args = parser.parse_args() return args def install_model( model_id: str, config_manager: ConfigManager = None, execution_config: dict = None ): """installing the dependencies required for this model as stated in config""" if execution_config is None: if config_manager is None: config_manager = ConfigManager() config = config_manager.get_config_by_id(model_id) execution_config = config[CONFIG_FILE_KEY_EXECUTION] dependencies = execution_config[CONFIG_FILE_KEY_DEPENDENCIES] for package in dependencies: subprocess.check_call([sys.executable, "-m", "pip", "install", package]) if __name__ == "__main__": """ This script is used to install dependencies for models. If no model_id is provided, all models from the config file are installed. """ args = parse_args() config_manager = ConfigManager() if args.model_id: for model_id in args.model_id: install_model(model_id) else: for model_id in config_manager.config_dict.keys(): install_model(model_id) ================================================ FILE: src/medigan/execute_model/model_executor.py ================================================ # -*- coding: utf-8 -*- # ! /usr/bin/env python """Model executor class that downloads models, loads them as python packages, and runs their generate functions. 
""" # Import python native libs from __future__ import absolute_import import importlib import logging import os import time # Import pypi libs from pathlib import Path import pkg_resources from tqdm import tqdm # Import library internal modules from ..constants import ( CONFIG_FILE_KEY_DEPENDENCIES, CONFIG_FILE_KEY_GENERATE, CONFIG_FILE_KEY_GENERATE_ARGS, CONFIG_FILE_KEY_GENERATE_ARGS_BASE, CONFIG_FILE_KEY_GENERATE_ARGS_CUSTOM, CONFIG_FILE_KEY_GENERATE_ARGS_INPUT_LATENT_VECTOR_SIZE, CONFIG_FILE_KEY_GENERATE_ARGS_MODEL_FILE, CONFIG_FILE_KEY_GENERATE_ARGS_NUM_SAMPLES, CONFIG_FILE_KEY_GENERATE_ARGS_OUTPUT_PATH, CONFIG_FILE_KEY_GENERATE_ARGS_SAVE_IMAGES, CONFIG_FILE_KEY_GENERATE_NAME, CONFIG_FILE_KEY_IMAGE_SIZE, CONFIG_FILE_KEY_MODEL_EXTENSION, CONFIG_FILE_KEY_MODEL_NAME, CONFIG_FILE_KEY_PACKAGE_LINK, CONFIG_FILE_KEY_PACKAGE_NAME, DEFAULT_OUTPUT_FOLDER, MODEL_FOLDER, PACKAGE_EXTENSION, ) from ..utils import Utils from .install_model_dependencies import install_model class ModelExecutor: """`ModelExecutor` class: Find config links to download models, init models as python packages, run generate methods. Parameters ---------- model_id: str The generative model's unique id execution_config: dict The part of the config below the 'execution' key download_package: bool Flag indicating, if True, that the model's package should be downloaded instead of using an existing one that was downloaded previously install_dependencies: bool flag indicating whether a generative model's dependencies are automatically installed. Else error is raised if missing dependencies are detected. 
Attributes ---------- model_id: str The generative model's unique id execution_config: dict The part of the config below the 'execution' key download_package: bool Flag indicating, if True, that the model's package should be downloaded instead of using an existing one that was downloaded previously image_size: int Pixel dimension of the generated samples, where images are assumed to have the same width and height dependencies: list List of the dependencies of a models python package. model_name: str Name of the generative model model_extension: str File extension of the generative model's weights file. package_name: str Name of the model's python package i.e. the name of the model's zip file and unzipped package folder package_link: str The link to the zipped model package. Note: Convention is to host models on Zenodo (reason: static doi content) generate_method_name: str The name of the model's generate method inside the model package. This method is called to generate samples. generate_method_args: dict The args of the model's generate method inside the model package serialised_model_file_path: str Path as string to the generative model's weights file package_path: str Path as string to the generative model's python package containing an `__init__.py` file deserialized_model_as_lib The generative model's package imported as python library. Generate method inside this library can be called. 
""" def __init__( self, model_id: str, execution_config: dict, download_package: bool = True, install_dependencies: bool = False, ): self.model_id = model_id self.execution_config = execution_config self.download_package = download_package self.install_dependencies = install_dependencies self.image_size = None self.dependencies = None self.model_name = None self.model_extension = None self.package_name = None self.package_link = None self.generate_method_name = None self.generate_method_args = None self.generate_method_input_latent_vector_size = None self.serialised_model_file_path = None self.package_path = None self.deserialized_model_as_lib = None self._setup_model_package() def _setup_model_package(self): """Use specific keys to retrieve needed model config values and load and initialize the model as package.""" self.image_size = self.execution_config[CONFIG_FILE_KEY_IMAGE_SIZE] self.dependencies = self.execution_config[CONFIG_FILE_KEY_DEPENDENCIES] self.model_name = self.execution_config[CONFIG_FILE_KEY_MODEL_NAME] self.model_extension = self.execution_config[CONFIG_FILE_KEY_MODEL_EXTENSION] self.package_name = self.execution_config[CONFIG_FILE_KEY_PACKAGE_NAME] self.package_link = self.execution_config[CONFIG_FILE_KEY_PACKAGE_LINK] self.generate_method_name = self.execution_config[CONFIG_FILE_KEY_GENERATE][ CONFIG_FILE_KEY_GENERATE_NAME ] self.generate_method_args = self.execution_config[CONFIG_FILE_KEY_GENERATE][ CONFIG_FILE_KEY_GENERATE_ARGS ] if ( CONFIG_FILE_KEY_GENERATE_ARGS_INPUT_LATENT_VECTOR_SIZE in self.execution_config[CONFIG_FILE_KEY_GENERATE] ): self.generate_method_input_latent_vector_size = self.execution_config[ CONFIG_FILE_KEY_GENERATE ][CONFIG_FILE_KEY_GENERATE_ARGS_INPUT_LATENT_VECTOR_SIZE] self._check_package_resources() if not self.is_model_already_unpacked(): self._get_and_store_package() self._import_package_as_lib() def _check_package_resources(self): """Check if the dependencies inside the generative model's package are installed in 
the current setup.""" logging.debug( f"{self.model_id}: Now checking availability of dependencies of model: {self.dependencies}" ) try: pkg_resources.require(self.dependencies) logging.debug( f"{self.model_id}: All necessary dependencies for model are available: {self.dependencies}" ) except Exception as e: if self.install_dependencies: logging.info( f"{self.model_id}: Now installing dependencies using pip for model {self.dependencies}. This may take a few minutes." ) install_model( model_id=self.model_id, execution_config=self.execution_config ) else: raise Exception( f"{self.model_id}: Some of the necessary dependencies ({self.dependencies}) for this model " f"are missing. Either set install_dependencies=True or manually run 'python src/medigan/install_model_dependencies.py --model_id {self.model_id}' to install them. Error: {e}" ) def _get_and_store_package(self): """Load and store the generative model's python package using the link from the model's `execution_config`.""" if self.package_path is None: assert Utils.mkdirs(path_as_string=f"{MODEL_FOLDER}/{self.model_id}"), ( f"{self.model_id}: The model folder was not found nor created " f"in {MODEL_FOLDER}/{self.model_id}." ) package_path = Path( f"{MODEL_FOLDER}/{self.model_id}/{self.package_name}{PACKAGE_EXTENSION}" ) try: if not Utils.is_file_located_or_downloaded( path_as_string=package_path, download_if_not_found=True, download_link=self.package_link, ): error_string = ( f"{self.model_id}: The package archive ({self.package_name}{PACKAGE_EXTENSION}) " f"was not found in {package_path} nor downloaded from {self.package_link}." ) raise FileNotFoundError(error_string) except Exception as e: raise e self.package_path = package_path logging.info( f"{self.model_id}: Model package should now be available in: {self.package_path}." 
) def is_model_already_unpacked(self) -> bool: """Check if a valid path to the model files exists and, if so, set the `package_path`""" path_option_1 = Path( f"{MODEL_FOLDER}/{self.model_id}/{self.package_name}/{self.model_name}{self.model_extension}" ) path_option_2 = Path( f"{MODEL_FOLDER}/{self.model_id}/{self.model_name}{self.model_extension}" ) if path_option_1.is_file(): self.package_path = path_option_1 return True if path_option_2.is_file(): self.package_path = path_option_2 return True return False def _import_package_as_lib(self): """Unzip and import the generative model's python package using importlib.""" logging.debug( f"{self.model_id}: Now importing model package ({self.package_name}) as lib using " f"importlib from {self.package_path}." ) is_model_already_unpacked = self.is_model_already_unpacked() # if is_model_already_unpacked == True, then the package was already unzipped previously. if ( self.package_path.is_file() and PACKAGE_EXTENSION == ".zip" and not is_model_already_unpacked ): # Unzip the model package in {MODEL_FOLDER}/{model_id}/{MODEL_PACKAGE}{PACKAGE_EXTENSION} Utils.unzip_archive( source_path=self.package_path, target_path=f"{MODEL_FOLDER}/{self.model_id}", ) else: logging.debug( f"{self.model_id}: Either no file found (== {self.package_path.is_file()}) or package " f"already unarchived (=={is_model_already_unpacked}) in {self.package_path}. " f"No action was taken." ) try: # Installing generative model as python library self.deserialized_model_as_lib = importlib.import_module( name=f"{MODEL_FOLDER}.{self.model_id}.{self.package_name}" ) if not hasattr( self.deserialized_model_as_lib, f"{self.generate_method_name}" ): # if generate method is not in lib path, generating samples will not work. Next: Check fallback folder. 
raise ModuleNotFoundError self.serialised_model_file_path = f"{MODEL_FOLDER}/{self.model_id}/{self.package_name}/{self.model_name}{self.model_extension}" except ModuleNotFoundError: try: # Fallback: The zip's content might have been unzipped in the model_id folder without generating the package_name subfolder. self.deserialized_model_as_lib = importlib.import_module( name=f"{MODEL_FOLDER}.{self.model_id}" ) if not hasattr( self.deserialized_model_as_lib, f"{self.generate_method_name}" ): # if generate method is not in lib path, generating samples will not work. Next: Check fallback folder. raise AttributeError( f"Module '{MODEL_FOLDER}.{self.model_id}' has no attribute " f"'{self.generate_method_name}' (generate method). We also tried module " f"'{MODEL_FOLDER}.{self.model_id}.{self.package_name}'. Please check if " f"generate_method_name and package_name are correct for this model in its " f"global.json entry." ) self.serialised_model_file_path = f"{MODEL_FOLDER}/{self.model_id}/{self.model_name}{self.model_extension}" except Exception as e: logging.error( f"{self.model_id}: Error occurred while trying to import " f"'{MODEL_FOLDER}.{self.model_id}.{self.package_name}'." f"Fallback import of '{MODEL_FOLDER}.{self.model_id}' also failed. " f"Please make sure the module '{MODEL_FOLDER}' is not imported from elsewhere in your syspath: {e}" ) raise e def generate( self, num_samples: int = 20, output_path: str = None, save_images: bool = True, is_gen_function_returned: bool = False, batch_size: int = 32, **kwargs, ): """Generate samples using the generative model or return the model's generate function. The name amd additional parameters of the generate function of the respective generative model are retrieved from the model's `execution_config`. 
Parameters ---------- num_samples: int the number of samples that will be generated output_path: str the path as str to the output folder where the generated samples will be stored save_images: bool flag indicating whether generated samples are returned (i.e. as list of numpy arrays) or rather stored in file system (i.e in `output_path`) is_gen_function_returned: bool flag indicating whether, instead of generating samples, the sample generation function will be returned batch_size: int the batch size for the sample generation function **kwargs arbitrary number of keyword arguments passed to the model's sample generation function Returns ------- list Returns images as list of numpy arrays if `save_images` is False. However, if `is_gen_function_returned` is True, it returns the internal generate function of the model. Raises ------ Exception If the generate method of the model does not exist, cannot be called, or is called with missing params, or if the sample generation inside the model package returns an exception. """ if output_path is None: output_path = f"{DEFAULT_OUTPUT_FOLDER}/{self.model_id}/{time.time()}/" assert Utils.mkdirs( path_as_string=output_path ), f"{self.model_id}: The output folder was not found nor created in {output_path}." try: generate_method = getattr( self.deserialized_model_as_lib, f"{self.generate_method_name}" ) prepared_kwargs = self._prepare_generate_method_args( model_file=self.serialised_model_file_path, num_samples=num_samples, output_path=output_path, save_images=save_images, **kwargs, ) logging.debug(f"The generate function's parameters are: {prepared_kwargs}") if is_gen_function_returned: def gen(**some_other_kwargs): logging.debug( f"Generate method called with the following params. 
(i) default: {prepared_kwargs}, " f"(ii) custom: {some_other_kwargs}" ) prepared_kwargs.update(some_other_kwargs) return generate_method(**prepared_kwargs) return gen elif save_images: sample_index = 1 prepared_kwargs.update({"num_samples": batch_size}) for batch_num in tqdm(range(0, num_samples // batch_size + 1)): if batch_num == num_samples // batch_size: batch_size = num_samples % batch_size prepared_kwargs.update({"num_samples": batch_size}) batch_path = ( os.path.join(output_path, "batch_" + str(batch_num)) + "/" ) # Generate the path in case it is not yet available. assert Utils.mkdirs( path_as_string=batch_path ), f"{self.model_id}: The batch path was not found nor created in {batch_path}." prepared_kwargs.update({"output_path": batch_path}) generate_method(**prepared_kwargs) for filename in os.listdir(batch_path): os.rename( os.path.join(batch_path, filename), os.path.join( output_path, "batch_" + str(batch_num) + "_" + filename ), ) sample_index += 1 os.rmdir(batch_path) else: return generate_method(**prepared_kwargs) except Exception as e: logging.error( f"{self.model_id}: Error while trying to generate images with model " f"{self.serialised_model_file_path}: {e}" ) raise e def _prepare_generate_method_args( self, model_file: str, num_samples: int, output_path: str, save_images: bool, **kwargs, ): """Prepare the keyword arguments that will be passed to the models generate function. Prepares the keyword arguments that need to be passed to the generative model's generate function to generate samples. This contains the steps: - Update keyword args dict with default values for all params from model config - Update keyword args dict with the `**args` provided by user thus overwriting the previously set default values for which user has provided key-value pairs. - Checking if all mandatory 'base' values are set in model config. - Update keyword args dict with 'base' key-value pairs, which i.e. 
are the param values for `model_file`, `num_samples` and `output_path`, thus overwriting any previously set value for these keys. - Returning the updated and prepared keyword args dict Parameters ---------- model_file : str the path to the serialized weights of the generative model. num_samples: int the number of samples that will be generated output_path: str the path as str to the output folder where the generated samples will be stored **kwargs arbitrary number of keyword arguments passed to the model's sample generation function Returns ------- dict kwargs as dictionary containing both user input params (prioritized) and config input params of the model """ prepared_kwargs: dict = {} # get keys of mandatory custom dictionary input args and assign the default value from config to values of keys prepared_kwargs.update( self.generate_method_args[CONFIG_FILE_KEY_GENERATE_ARGS_CUSTOM] ) # update: If one of these keys was provided in **kwargs, then change default value to value provided in **kwargs prepared_kwargs.update(kwargs) try: # validating that these specific keys are available in the config. 
also retrieving default values base_config_list = [ self.generate_method_args[CONFIG_FILE_KEY_GENERATE_ARGS_BASE][0], self.generate_method_args[CONFIG_FILE_KEY_GENERATE_ARGS_BASE][1], self.generate_method_args[CONFIG_FILE_KEY_GENERATE_ARGS_BASE][2], self.generate_method_args[CONFIG_FILE_KEY_GENERATE_ARGS_BASE][3], ] if not all( x in base_config_list for x in [ CONFIG_FILE_KEY_GENERATE_ARGS_MODEL_FILE, CONFIG_FILE_KEY_GENERATE_ARGS_NUM_SAMPLES, CONFIG_FILE_KEY_GENERATE_ARGS_OUTPUT_PATH, CONFIG_FILE_KEY_GENERATE_ARGS_SAVE_IMAGES, ] ): raise KeyError except KeyError as e: logging.warning( f"{self.model_id}: Warning: In this model's generate args ({self.generate_method_args}), some " f"required generate method keys ({CONFIG_FILE_KEY_GENERATE_ARGS_MODEL_FILE}, " f"{CONFIG_FILE_KEY_GENERATE_ARGS_NUM_SAMPLES}, {CONFIG_FILE_KEY_GENERATE_ARGS_OUTPUT_PATH}, " f"{CONFIG_FILE_KEY_GENERATE_ARGS_SAVE_IMAGES}) are missing: {e}. A value for this key will be " f"provided nevertheless when calling the model's generate method ({self.generate_method_name})'. " f"This could hence cause an error." ) # Adding the always necessary base parameters to kwargs. They are updated if erroneously # introduced via the user-provided kwargs. 
prepared_kwargs.update( { CONFIG_FILE_KEY_GENERATE_ARGS_MODEL_FILE: model_file, CONFIG_FILE_KEY_GENERATE_ARGS_NUM_SAMPLES: num_samples, CONFIG_FILE_KEY_GENERATE_ARGS_OUTPUT_PATH: output_path, CONFIG_FILE_KEY_GENERATE_ARGS_SAVE_IMAGES: save_images, } ) return prepared_kwargs def __repr__(self): return ( f"ModelExecutor(model_id={self.model_id}, name={self.model_name}, package={self.package_name}, " f"image_size={self.image_size}, dependencies={self.dependencies}, link={self.package_link}, " f"path={self.serialised_model_file_path}, generate_method={self.generate_method_name}, " f"generate_method_args={self.generate_method_args})" ) def __len__(self): raise NotImplementedError def __getitem__(self, idx: int): raise NotImplementedError ================================================ FILE: src/medigan/execute_model/synthetic_dataset.py ================================================ # -*- coding: utf-8 -*- # ! /usr/bin/env python """ `SyntheticDataset` allows to return a generative model as torch dataset. """ from torch.utils.data import Dataset class SyntheticDataset(Dataset): """A synthetic dataset containing data generated by a model of medigan Parameters ---------- samples: list List of data points in the dataset e.g. generated images as numpy array. masks: list List of segmentation masks, if applicable, pertaining to the `samples` items other_imaging_output: list List of other imaging output produced by the generative model (e.g. specific types of other masks/modalities) labels: list list of labels, if applicable, pertaining to the `samples` items transform: torch compose transform functions that are applied to the torch dataset. Attributes ---------- samples: list List of data points in the dataset e.g. generated images as numpy array. masks: list List of segmentation masks, if applicable, pertaining to the `samples` items other_imaging_output: list List of other imaging output produced by the generative model (e.g. 
specific types of other masks/modalities) labels: list list of labels, if applicable, pertaining to the `samples` items transform: torch compose transform functions that are applied to the torch dataset. """ def __init__( self, samples, masks=None, other_imaging_output=None, labels=None, transform=None, ): self.samples = samples self.masks = masks self.other_imaging_output = other_imaging_output self.labels = labels self.transform = transform def __getitem__(self, index): x = self.samples[index] y = self.labels[index] if self.labels is not None else None mask = self.masks[index] if self.masks is not None else None other_imaging_output = ( self.other_imaging_output[index] if self.other_imaging_output is not None else None ) if self.transform: if mask is not None: if other_imaging_output is not None: x, mask, other_imaging_output = self.transform( x, mask, other_imaging_output ) # transform needs to be applied to both mask and image. x, mask = self.transform(x, mask) elif other_imaging_output is not None: x, other_imaging_output = self.transform(x, other_imaging_output) else: x = self.transform(x) item = {"sample": x} # extendable dictionary if y is not None: item["label"] = y if mask is not None: item["mask"] = mask if other_imaging_output is not None: item["other_imaging_output"] = other_imaging_output return item def __len__(self): return len(self.samples) ================================================ FILE: src/medigan/generators.py ================================================ # -*- coding: utf-8 -*- # ! /usr/bin/env python """ Base class providing user-library interaction methods for config management, and model selection and execution. 
""" # Import python native libs from __future__ import absolute_import import logging from torch.utils.data import DataLoader, Dataset # Import library internal modules from .config_manager import ConfigManager from .constants import CONFIG_FILE_KEY_EXECUTION, MODEL_ID from .contribute_model.model_contributor import ModelContributor from .execute_model.model_executor import ModelExecutor from .execute_model.synthetic_dataset import SyntheticDataset from .model_visualizer import ModelVisualizer from .select_model.model_selector import ModelSelector from .utils import Utils # Import pypi libs class Generators: """`Generators` class: Contains medigan's public methods to facilitate users' automated sample generation workflows. Parameters ---------- config_manager: ConfigManager Provides the config dictionary, based on which `model_ids` are retrieved and models are selected and executed model_selector: ModelSelector Provides model comparison, search, and selection based on keys/values in the selection part of the config dict model_executors: list List of initialized `ModelExecutor` instances that handle model package download, init, and sample generation initialize_all_models: bool Flag indicating, if True, that one `ModelExecutor` for each `model_id` in the config dict should be initialized triggered by creation of `Generators` class instance. Note that, if False, the `Generators` class will only initialize a `ModelExecutor` on the fly when need be i.e. when the generate method for the respective model is called. 
Attributes ---------- config_manager: ConfigManager Provides the config dictionary, based on which model_ids are retrieved and models are selected and executed model_selector: ModelSelector Provides model comparison, search, and selection based on keys/values in the selection part of the config dict model_executors: list List of initialized `ModelExecutor` instances that handle model package download, init, and sample generation """ def __init__( self, config_manager: ConfigManager = None, model_selector: ModelSelector = None, model_executors: list = None, model_contributors: list = None, initialize_all_models: bool = False, ): if config_manager is None: self.config_manager = ConfigManager() logging.debug(f"Initialized ConfigManager instance: {self.config_manager}") else: self.config_manager = config_manager if model_selector is None: self.model_selector = ModelSelector(config_manager=self.config_manager) logging.debug(f"Initialized ModelSelector instance: {self.model_selector}") else: self.model_selector = model_selector if model_executors is None: self.model_executors = [] else: self.model_executors = model_executors if model_contributors is None: self.model_contributors = [] else: self.model_contributors = model_contributors if initialize_all_models: self.add_all_model_executors() ############################ CONFIG MANAGER METHODS ############################ def get_config_by_id(self, model_id: str, config_key: str = None) -> dict: """Get and return the part of the config below a `config_key` for a specific `model_id`. The config_key parameters can be separated by a '.' (dot) to allow for retrieval of nested config keys, e.g, 'execution.generator.name' This function calls an identically named function in a `ConfigManager` instance. 
Parameters ---------- model_id: str The generative model's unique id config_key: str A key of interest present in the config dict Returns ------- dict a dictionary from the part of the config file corresponding to `model_id` and `config_key`. """ model_id = self.config_manager.match_model_id(provided_model_id=model_id) return self.config_manager.get_config_by_id( model_id=model_id, config_key=config_key ) def is_model_metadata_valid( self, model_id: str, metadata: dict, is_local_model: bool = True ) -> bool: """Checking if a model's corresponding metadata is valid. Specific fields in the model's metadata are mandatory. It is asserted if these key value pairs are present. Parameters ---------- model_id: str The generative model's unique id metadata: dict The model's corresponding metadata is_local_model: bool flag indicating whether the tested model is a new local user model i.e not yet part of medigan's official models Returns ------- bool Flag indicating whether the specific model's metadata format and fields are valid """ return self.config_manager.is_model_metadata_valid( model_id=model_id, metadata=metadata, is_local_model=is_local_model ) def add_model_to_config( self, model_id: str, metadata: dict, is_local_model: bool = None, overwrite_existing_metadata: bool = False, store_new_config: bool = True, ) -> bool: """Adding or updating a model entry in the global metadata. Parameters ---------- model_id: str The generative model's unique id metadata: dict The model's corresponding metadata is_local_model: bool flag indicating whether the tested model is a new local user model i.e not yet part of medigan's official models overwrite_existing_metadata: bool in case of `is_local_model`, flag indicating whether existing metadata for this model in medigan's `config/global.json` should be overwritten. store_new_config: bool flag indicating whether the current model metadata should be stored on disk i.e. 
in config/ Returns ------- bool Flag indicating whether model metadata update was successfully concluded """ if is_local_model is None: model_id = self.config_manager.match_model_id(provided_model_id=model_id) # if no model contributor can be found the model is assumed to be not a local model. is_local_model = not is_local_model == self.get_model_contributor_by_id( model_id=model_id ) return self.config_manager.add_model_to_config( model_id=model_id, metadata=metadata, is_local_model=is_local_model, overwrite_existing_metadata=overwrite_existing_metadata, store_new_config=store_new_config, ) ############################ MODEL SELECTOR METHODS ############################ def list_models(self) -> list: """Return the list of model_ids as strings based on config. Returns ------- list """ return [model_id for model_id in self.config_manager.model_ids] def get_selection_criteria_by_id( self, model_id: str, is_model_id_removed: bool = True ) -> dict: """Get and return the selection config dict for a specific model_id. This function calls an identically named function in a `ModelSelector` instance. Parameters ---------- model_id: str The generative model's unique id is_model_id_removed: bool flag to to remove the model_ids from first level of dictionary. Returns ------- dict a dictionary corresponding to the selection config of a model """ model_id = self.config_manager.match_model_id(provided_model_id=model_id) return self.model_selector.get_selection_criteria_by_id(model_id=model_id) def get_selection_criteria_by_ids( self, model_ids: list = None, are_model_ids_removed: bool = True ) -> list: """Get and return a list of selection config dicts for each of the specified model_ids. This function calls an identically named function in a `ModelSelector` instance. Parameters ---------- model_ids: list A list of generative models' unique ids or ids abbreviated as integers (e.g. 1, 2, .. 
21) are_model_ids_removed: bool flag to remove the model_ids from first level of dictionary. Returns ------- list a list of dictionaries each corresponding to the selection config of a model """ mapped_model_ids = [] for model_id in model_ids: mapped_model_ids.append( self.config_manager.match_model_id(provided_model_id=model_id) ) return self.model_selector.get_selection_criteria_by_ids( model_ids=mapped_model_ids, are_model_ids_removed=are_model_ids_removed ) def get_selection_values_for_key(self, key: str, model_id: str = None) -> list: """Get and return the value of a specified key of the selection dict in the config for a specific model_id. The key param can contain '.' (dot) separations to allow for retrieval of nested config keys such as 'execution.generator.name' This function calls an identically named function in a `ModelSelector` instance. Parameters ---------- key: str The key in the selection dict model_id: str The generative model's unique id Returns ------- list a list of the values that correspond to the key in the selection config of the `model_id`. """ return self.model_selector.get_selection_values_for_key( key=key, model_id=model_id ) def get_selection_keys(self, model_id: str = None) -> list: """Get and return all first level keys from the selection config dict for a specific model_id. This function calls an identically named function in a `ModelSelector` instance. Parameters ---------- model_id: str The generative model's unique id Returns ------- list a list containing the keys as strings of the selection config of the `model_id`. """ return self.model_selector.get_selection_keys(model_id=model_id) def get_models_by_key_value_pair( self, key1: str, value1: str, is_case_sensitive: bool = False ) -> list: """Get and return a list of model_id dicts that contain the specified key value pair in their selection config. The key param can contain '.' 
(dot) separations to allow for retrieval of nested config keys such as 'execution.generator.name' This function calls an identically named function in a `ModelSelector` instance. Parameters ---------- key1: str The key in the selection dict value1: str The value in the selection dict that corresponds to key1 is_case_sensitive: bool flag to evaluate keys and values with case sensitivity if set to True Returns ------- list a list of the dictionaries each containing a models id and the found key-value pair in the models config """ return self.model_selector.get_models_by_key_value_pair( key1=key1, value1=value1, is_case_sensitive=is_case_sensitive ) def rank_models_by_performance( self, model_ids: list = None, metric: str = "SSIM", order: str = "asc" ) -> list: """Rank model based on a provided metric and return sorted list of model dicts. The metric param can contain '.' (dot) separations to allow for retrieval of nested metric config keys such as 'downstream_task.CLF.accuracy' This function calls an identically named function in a `ModelSelector` instance. Parameters ---------- model_ids: list only evaluate the `model_ids` in this list. If none, evaluate all available `model_ids` metric: str The key in the selection dict that corresponds to the metric of interest order: str the sorting order of the ranked results. Should be either "asc" (ascending) or "desc" (descending) Returns ------- list a list of model dictionaries containing metric and `model_id`, sorted by metric. """ return self.model_selector.rank_models_by_performance( model_ids=model_ids, metric=metric, order=order ) def find_matching_models_by_values( self, values: list, target_values_operator: str = "AND", are_keys_also_matched: bool = False, is_case_sensitive: bool = False, ) -> list: """Search for values (and keys) in model configs and return a list of each matching `ModelMatchCandidate`. This function calls an identically named function in a `ModelSelector` instance. 
Parameters ---------- values: list list of values used to search and find models corresponding to these `values` target_values_operator: str the operator indicating the relationship between `values` in the evaluation of model search results. Should be either "AND", "OR", or "XOR". are_keys_also_matched: bool flag indicating whether, apart from values, the keys in the model config should also be searchable is_case_sensitive: bool flag indicating whether the search for values (and) keys in the model config should be case-sensitive. Returns ------- list a list of `ModelMatchCandidate` class instances each of which was successfully matched against the search values. """ return self.model_selector.find_matching_models_by_values( values=values, target_values_operator=target_values_operator, are_keys_also_matched=are_keys_also_matched, is_case_sensitive=is_case_sensitive, ) def find_models_and_rank( self, values: list, target_values_operator: str = "AND", are_keys_also_matched: bool = False, is_case_sensitive: bool = False, metric: str = "SSIM", order: str = "asc", ) -> list: """Search for values (and keys) in model configs, rank results and return sorted list of model dicts. This function calls an identically named function in a `ModelSelector` instance. Parameters ---------- values: list` list of values used to search and find models corresponding to these `values` target_values_operator: str the operator indicating the relationship between `values` in the evaluation of model search results. Should be either "AND", "OR", or "XOR". are_keys_also_matched: bool flag indicating whether, apart from values, the keys in the model config should also be searchable is_case_sensitive: bool flag indicating whether the search for values (and) keys in the model config should be case-sensitive. metric: str The key in the selection dict that corresponds to the metric of interest order: str the sorting order of the ranked results. 
Should be either "asc" (ascending) or "desc" (descending) Returns ------- list a list of the searched and matched model dictionaries containing metric and model_id, sorted by metric. """ ranked_models = [] matching_models = self.model_selector.find_matching_models_by_values( values=values, target_values_operator=target_values_operator, are_keys_also_matched=are_keys_also_matched, is_case_sensitive=is_case_sensitive, ) if len(matching_models) < 1: logging.warning( f"For your input, there were {len(matching_models)} matching models, while at least 1 is needed. " f"Please adjust either your metric your search value inputs {values} to find at least one match." ) else: matching_model_ids = [model.model_id for model in matching_models] logging.debug(f"matching_model_ids: {matching_model_ids}") ranked_models = self.model_selector.rank_models_by_performance( model_ids=matching_model_ids, metric=metric, order=order ) if len(ranked_models) < 1: logging.warning( f"None ({len(ranked_models)}) of the {len(matching_model_ids)} found matching models, had a valid metric entry for {metric}. " f"Please adjust your metric to enable ranking of the found models." ) return ranked_models def find_models_rank_and_generate( self, values: list, target_values_operator: str = "AND", are_keys_also_matched: bool = False, is_case_sensitive: bool = False, metric: str = "SSIM", order: str = "asc", num_samples: int = 30, output_path: str = None, is_gen_function_returned: bool = False, install_dependencies: bool = False, **kwargs, ): """Search for values (and keys) in model configs, rank results to generate samples with highest ranked model. Parameters ---------- values: list list of values used to search and find models corresponding to these `values` target_values_operator: str the operator indicating the relationship between `values` in the evaluation of model search results. Should be either "AND", "OR", or "XOR". 
are_keys_also_matched: bool flag indicating whether, apart from values, the keys in the model config should also be searchable is_case_sensitive: bool flag indicating whether the search for values (and) keys in the model config should be case-sensitive. metric: str The key in the selection dict that corresponds to the metric of interest order: str the sorting order of the ranked results. Should be either "asc" (ascending) or "desc" (descending) num_samples: int the number of samples that will be generated output_path: str the path as str to the output folder where the generated samples will be stored is_gen_function_returned: bool flag indicating whether, instead of generating samples, the sample generation function will be returned install_dependencies: bool flag indicating whether a generative model's dependencies are automatically installed. Else error is raised if missing dependencies are detected. **kwargs arbitrary number of keyword arguments passed to the model's sample generation function Returns ------- None However, if `is_gen_function_returned` is True, it returns the internal generate function of the model. """ ranked_models = self.find_models_and_rank( values=values, target_values_operator=target_values_operator, are_keys_also_matched=are_keys_also_matched, is_case_sensitive=is_case_sensitive, metric=metric, order=order, ) assert ranked_models is not None and len(ranked_models) > 0, ( f"None of the models fulfilled both, the matching (values: {values}) AND " f"ranking (metric: {metric}) criteria you provided." ) # Get the ID of the highest ranking model to generate() with that model highest_ranking_model_id = ranked_models[0][MODEL_ID] # Let's generate with the best-ranked model logging.info( f"For your input, there were {len(ranked_models)} models found and ranked. 
" f"The highest ranked model ({highest_ranking_model_id}) will now be used for generation: " f"{ranked_models[0]}" ) return self.generate( model_id=highest_ranking_model_id, num_samples=num_samples, output_path=output_path, is_gen_function_returned=is_gen_function_returned, install_dependencies=install_dependencies, **kwargs, ) def find_model_and_generate( self, values: list, target_values_operator: str = "AND", are_keys_also_matched: bool = False, is_case_sensitive: bool = False, num_samples: int = 30, output_path: str = None, is_gen_function_returned: bool = False, install_dependencies: bool = False, **kwargs, ): """Search for values (and keys) in model configs to generate samples with the found model. Note that the number of found models should be ==1. Else no samples will be generated and a error is logged to console. Parameters ---------- values: list list of values used to search and find models corresponding to these `values` target_values_operator: str the operator indicating the relationship between `values` in the evaluation of model search results. Should be either "AND", "OR", or "XOR". are_keys_also_matched: bool flag indicating whether, apart from values, the keys in the model config should also be searchable is_case_sensitive: bool flag indicating whether the search for values (and) keys in the model config should be case-sensitive. num_samples: int the number of samples that will be generated output_path: str the path as str to the output folder where the generated samples will be stored is_gen_function_returned: bool flag indicating whether, instead of generating samples, the sample generation function will be returned install_dependencies: bool flag indicating whether a generative model's dependencies are automatically installed. Else error is raised if missing dependencies are detected. 
**kwargs arbitrary number of keyword arguments passed to the model's sample generation function Returns ------- None However, if `is_gen_function_returned` is True, it returns the internal generate function of the model. """ matching_models: list = self.model_selector.find_matching_models_by_values( values=values, target_values_operator=target_values_operator, are_keys_also_matched=are_keys_also_matched, is_case_sensitive=is_case_sensitive, ) if len(matching_models) > 1: logging.error( f"For your input, there were more than 1 matching model ({len(matching_models)}). " f"Please choose one of the models (see model_ids below) or use find_models_rank_and_generate() instead." f"Alternatively, you may also further specify additional search values apart from the provided ones " f"to find exactly one model: {values}. The matching models were the following: \n {matching_models}" ) elif len(matching_models) < 1: logging.error( f"For your input, there were {len(matching_models)} matching models, while 1 is needed. " f"Please adjust your search value inputs {values} to find at least one match." ) else: # Exactly one matching model. Let's generate with this model logging.info( f"For your input, there was {len(matching_models)} model matched. " f"This model will now be used for generation: {matching_models}" ) matched_model_id = matching_models[0].model_id return self.generate( model_id=matched_model_id, num_samples=num_samples, output_path=output_path, is_gen_function_returned=is_gen_function_returned, install_dependencies=install_dependencies, **kwargs, ) ############################ MODEL EXECUTOR METHODS ############################ def add_all_model_executors(self): """Add `ModelExecutor` class instances for all models available in the config. 
Returns ------- None """ for model_id in self.config_manager.model_ids: execution_config = self.config_manager.get_config_by_id( model_id=model_id, config_key=CONFIG_FILE_KEY_EXECUTION ) self._add_model_executor( model_id=model_id, execution_config=execution_config ) def add_model_executor(self, model_id: str, install_dependencies: bool = False): """Add one `ModelExecutor` class instance corresponding to the specified `model_id`. Parameters ---------- model_id: str The generative model's unique id install_dependencies: bool flag indicating whether a generative model's dependencies are automatically installed. Else error is raised if missing dependencies are detected. Returns ------- None """ if not self.is_model_executor_already_added(model_id): execution_config = self.config_manager.get_config_by_id( model_id=model_id, config_key=CONFIG_FILE_KEY_EXECUTION ) self._add_model_executor( model_id=model_id, execution_config=execution_config, install_dependencies=install_dependencies, ) def _add_model_executor( self, model_id: str, execution_config: dict, install_dependencies: bool = False ): """Add one `ModelExecutor` class instance corresponding to the specified `model_id` and `execution_config`. Parameters ---------- model_id: str The generative model's unique id execution_config: dict The part of the config below the 'execution' key install_dependencies: bool flag indicating whether a generative model's dependencies are automatically installed. Else error is raised if missing dependencies are detected. Returns ------- None """ if not self.is_model_executor_already_added(model_id): model_executor = ModelExecutor( model_id=model_id, execution_config=execution_config, download_package=True, install_dependencies=install_dependencies, ) self.model_executors.append(model_executor) def is_model_executor_already_added(self, model_id) -> bool: """Check whether the `ModelExecutor` instance of this model_id is already in `self.model_executors` list. 
Parameters ---------- model_id: str The generative model's unique id Returns ------- bool indicating whether this `ModelExecutor` had been already previously added to `self.model_executors` """ model_id = self.config_manager.match_model_id(provided_model_id=model_id) if self.find_model_executor_by_id(model_id=model_id) is None: logging.debug( f"{model_id}: The model has not yet been added to the model_executor list." ) return False return True def find_model_executor_by_id(self, model_id: str) -> ModelExecutor: """Find and return the `ModelExecutor` instance of this model_id in the `self.model_executors` list. Parameters ---------- model_id: str The generative model's unique id Returns ------- ModelExecutor `ModelExecutor` class instance corresponding to the `model_id` """ model_id = self.config_manager.match_model_id(provided_model_id=model_id) for idx, model_executor in enumerate(self.model_executors): if model_executor.model_id == model_id: return model_executor return None def get_model_executor( self, model_id: str, install_dependencies: bool = False ) -> ModelExecutor: """Add and return the `ModelExecutor` instance of this model_id from the `self.model_executors` list. Relies on `self.add_model_executor` and `self.find_model_executor_by_id` functions. Parameters ---------- model_id: str The generative model's unique id install_dependencies: bool flag indicating whether a generative model's dependencies are automatically installed. Else error is raised if missing dependencies are detected. 
Returns ------- ModelExecutor `ModelExecutor` class instance corresponding to the `model_id` """ model_id = self.config_manager.match_model_id(provided_model_id=model_id) try: self.add_model_executor( model_id=model_id, install_dependencies=install_dependencies, ) # only adds after checking that is not already added return self.find_model_executor_by_id(model_id=model_id) except Exception as e: logging.error( f"{model_id}: This model could not be added to model_executor list: {e}" ) raise e def generate( self, model_id: str, num_samples: int = 30, output_path: str = None, save_images: bool = True, is_gen_function_returned: bool = False, install_dependencies: bool = False, **kwargs, ): """Generate samples with the model corresponding to the `model_id` or return the model's generate function. Parameters ---------- model_id: str The generative model's unique id num_samples: int the number of samples that will be generated output_path: str the path as str to the output folder where the generated samples will be stored save_images: bool flag indicating whether generated samples are returned (i.e. as list of numpy arrays) or rather stored in file system (i.e in `output_path`) is_gen_function_returned: bool flag indicating whether, instead of generating samples, the sample generation function will be returned install_dependencies: bool flag indicating whether a generative model's dependencies are automatically installed. Else error is raised if missing dependencies are detected. **kwargs arbitrary number of keyword arguments passed to the model's sample generation function Returns ------- list Returns images as list of numpy arrays if `save_images` is False. However, if `is_gen_function_returned` is True, it returns the internal generate function of the model. 
""" model_id = self.config_manager.match_model_id(provided_model_id=model_id) model_executor = self.get_model_executor( model_id=model_id, install_dependencies=install_dependencies ) return model_executor.generate( num_samples=num_samples, output_path=output_path, save_images=save_images, is_gen_function_returned=is_gen_function_returned, **kwargs, ) def get_generate_function( self, model_id: str, num_samples: int = 30, output_path: str = None, install_dependencies: bool = False, **kwargs, ): """Return the model's generate function. Relies on the `self.generate` function. Parameters ---------- model_id: str The generative model's unique id num_samples: int the number of samples that will be generated output_path: str the path as str to the output folder where the generated samples will be stored install_dependencies: bool flag indicating whether a generative model's dependencies are automatically installed. Else error is raised if missing dependencies are detected. **kwargs arbitrary number of keyword arguments passed to the model's sample generation function Returns ------- function The internal reusable generate function of the generative model. """ return self.generate( model_id=model_id, num_samples=num_samples, output_path=output_path, is_gen_function_returned=True, install_dependencies=install_dependencies, **kwargs, ) ############################ MODEL CONTRIBUTOR METHODS ############################ def add_model_contributor( self, model_id: str, init_py_path: str = None, ) -> ModelContributor: """Add a `ModelContributor` instance of this model_id to the `self.model_contributors` list. Parameters ---------- model_id: str The generative model's unique id init_py_path: str The path to the local model's __init__.py file needed for importing and running this model. 
Returns ------- ModelContributor `ModelContributor` class instance corresponding to the `model_id` """ model_id = self.config_manager.match_model_id(provided_model_id=model_id) model_contributor = self.get_model_contributor_by_id(model_id=model_id) if model_contributor is not None: logging.warning( f"{model_id}: For this model_id, there already exists a ModelContributor. None was added. Returning the existing one." ) else: model_contributor = ModelContributor( model_id=model_id, init_py_path=init_py_path ) self.model_contributors.append(model_contributor) return model_contributor def get_model_contributor_by_id(self, model_id: str) -> ModelContributor: """Find and return the `ModelContributor` instance of this model_id in the `self.model_contributors` list. Parameters ---------- model_id: str The generative model's unique id Returns ------- ModelContributor `ModelContributor` class instance corresponding to the `model_id` """ model_id = self.config_manager.match_model_id(provided_model_id=model_id) for idx, model_contributor in enumerate(self.model_contributors): if model_contributor.model_id == model_id: return model_contributor return None def add_metadata_from_file(self, model_id: str, metadata_file_path: str) -> dict: """Read and parse the metadata of a local model, identified by `model_id`, from a metadata file in json format. Parameters ---------- model_id: str The generative model's unique id metadata_file_path: str the path pointing to the metadata file Returns ------- dict Returns a dict containing the contents of parsed metadata json file. """ model_id = self.config_manager.match_model_id(provided_model_id=model_id) model_contributor = self.get_model_contributor_by_id(model_id=model_id) assert ( model_contributor is not None ), f"{model_id}: No model_contributor is initialized for this model_id in Generators. Add a model_contributor first by running 'add_model_contributor()'." 
return model_contributor.add_metadata_from_file( metadata_file_path=metadata_file_path ) def add_metadata_from_input( self, model_id: str, model_weights_name: str, model_weights_extension: str, generate_method_name: str, dependencies: list, fill_more_fields_interactively: bool = True, output_path: str = "config", ) -> dict: """Create a metadata dict for a local model, identified by `model_id`, given the necessary minimum metadata contents. Parameters ---------- model_id: str The generative model's unique id model_weights_name: str the name of the checkpoint file containing the model's weights model_weights_extension: str the extension (e.g. .pt) of the checkpoint file containing the model's weights generate_method_name: str the name of the sample generation method inside the models __init__.py file dependencies: list the list of dependencies that need to be installed via pip to run the model fill_more_fields_interactively: bool flag indicating whether a user will be interactively asked via command line for further input to fill out missing metadata content output_path: str the path where the created metadata json file will be stored Returns ------- dict Returns a dict containing the contents of the metadata json file. """ model_id = self.config_manager.match_model_id(provided_model_id=model_id) model_contributor = self.get_model_contributor_by_id(model_id=model_id) assert ( model_contributor is not None ), f"{model_id}: No model_contributor is initialized for this model_id in Generators. Add a model_contributor first by running 'add_model_contributor()'." 
return model_contributor.add_metadata_from_input( model_weights_name=model_weights_name, model_weights_extension=model_weights_extension, generate_method_name=generate_method_name, dependencies=dependencies, fill_more_fields_interactively=fill_more_fields_interactively, output_path=output_path, ) def push_to_zenodo( self, model_id: str, zenodo_access_token: str, creator_name: str = "unknown name", creator_affiliation: str = "unknown affiliation", model_description: str = "", ) -> str: """Upload the model files as zip archive to a public Zenodo repository where the model will be persistently stored. Get your Zenodo access token here: https://zenodo.org/account/settings/applications/tokens/new/ (Enable scopes `deposit:actions` and `deposit:write`) Parameters ---------- model_id: str The generative model's unique id zenodo_access_token: str a personal access token in Zenodo linked to a user account for authentication creator_name: str the creator name that will appear on the corresponding Zenodo model upload homepage creator_affiliation: str the creator affiliation that will appear on the corresponding Zenodo model upload homepage model_description: list the model_description that will appear on the corresponding Zenodo model upload homepage Returns ------- str Returns the url pointing to the corresponding Zenodo model upload homepage """ model_id = self.config_manager.match_model_id(provided_model_id=model_id) model_contributor = self.get_model_contributor_by_id(model_id=model_id) assert ( model_contributor is not None ), f"{model_id}: No model_contributor is initialized for this model_id in Generators. Add a model_contributor first by running 'add_model_contributor()'." 
return model_contributor.push_to_zenodo( access_token=zenodo_access_token, creator_name=creator_name, creator_affiliation=creator_affiliation, model_description=model_description, ) def push_to_github( self, model_id: str, github_access_token: str, package_link: str = None, creator_name: str = "", creator_affiliation: str = "", model_description: str = "", ): """Upload the model's metadata inside a github issue to the medigan github repository. To add your model to medigan, your metadata will be reviewed on Github and added to medigan's official model metadata The medigan repository issues page: https://github.com/RichardObi/medigan/issues Get your Github access token here: https://github.com/settings/tokens Parameters ---------- model_id: str The generative model's unique id github_access_token: str a personal access token linked to your github user account, used as means of authentication package_link: a package link creator_name: str the creator name that will appear on the corresponding github issue creator_affiliation: str the creator affiliation that will appear on the corresponding github issue model_description: list the model_description that will appear on the corresponding github issue Returns ------- str Returns the url pointing to the corresponding issue on github """ model_id = self.config_manager.match_model_id(provided_model_id=model_id) model_contributor = self.get_model_contributor_by_id(model_id=model_id) assert ( model_contributor is not None ), f"{model_id}: No model_contributor is initialized for this model_id in Generators. Add a model_contributor first by running 'add_model_contributor()'." 
return model_contributor.push_to_github( access_token=github_access_token, package_link=package_link, creator_name=creator_name, creator_affiliation=creator_affiliation, model_description=model_description, ) def test_model( self, model_id: str, is_local_model: bool = True, overwrite_existing_metadata: bool = False, store_new_config: bool = True, num_samples: int = 3, install_dependencies: bool = False, ): """Test if a model generates and returns a specific number of samples in the correct format Parameters ---------- model_id: str The generative model's unique id is_local_model: bool flag indicating whether the tested model is a new local user model i.e not yet part of medigan's official models overwrite_existing_metadata: bool in case of `is_local_model`, flag indicating whether existing metadata for this model in medigan's `config/global.json` should be overwritten. store_new_config: bool flag indicating whether the current model metadata should be stored on disk i.e. in config/ num_samples: int the number of samples that will be generated install_dependencies: bool flag indicating whether a generative model's dependencies are automatically installed. Else error is raised if missing dependencies are detected. """ model_id = self.config_manager.match_model_id(provided_model_id=model_id) if is_local_model: model_contributor = self.get_model_contributor_by_id(model_id=model_id) assert model_contributor is not None, ( f"{model_id}: No model_contributor is initialized for this model_id. Try to set 'is_local_model=False'" f"or add a model_contributor first by running 'add_model_contributor(model_id, init_py_path)' ." 
) self.add_model_to_config( model_id=model_id, metadata=model_contributor.metadata, is_local_model=is_local_model, overwrite_existing_metadata=overwrite_existing_metadata, store_new_config=store_new_config, ) samples = self.generate( model_id=model_id, save_images=False, install_dependencies=install_dependencies, num_samples=num_samples, ) assert ( samples is not None and isinstance(samples, list) and ( (len(samples) == num_samples) or (len(samples) > num_samples) ) # e.g., len(samples) = num_samples + 1, as sample generation can be restricted to be balanced among classes ), ( f"{model_id}: Model test was not successful. The generated samples {'is None, but ' if samples is None else ''}" f"should be a list (actual type: {type(samples)}) and of length {num_samples} (actual length: " f"{'None' if samples is None else len(samples)}). Check if input params (e.g. input_path) to model are valid. " ) # {f'Generated samples: {samples}' if samples is not None else ''}" logging.info( f"{model_id}: The test of " f"{'this new local user model' if is_local_model else 'this existing medigan model'} " f"was successful, as model created the expected number ({num_samples}) of synthetic " f"samples." ) def contribute( self, model_id: str, init_py_path: str, github_access_token: str, zenodo_access_token: str, metadata_file_path: str = None, model_weights_name: str = None, model_weights_extension: str = None, generate_method_name: str = None, dependencies: list = None, fill_more_fields_interactively: bool = True, overwrite_existing_metadata: bool = False, output_path: str = "config", creator_name: str = "unknown name", creator_affiliation: str = "unknown affiliation", model_description: str = "", install_dependencies: bool = False, ): """Implements the full model contribution workflow including model metadata generation, model test, model Zenodo upload, and medigan github issue creation. 
Parameters ---------- model_id: str The generative model's unique id init_py_path: str The path to the local model's __init__.py file needed for importing and running this model. github_access_token: str a personal access token linked to your github user account, used as means of authentication zenodo_access_token: str a personal access token in Zenodo linked to a user account for authentication metadata_file_path: str the path pointing to the metadata file model_weights_name: str the name of the checkpoint file containing the model's weights model_weights_extension: str the extension (e.g. .pt) of the checkpoint file containing the model's weights generate_method_name: str the name of the sample generation method inside the models __init__.py file dependencies: list the list of dependencies that need to be installed via pip to run the model fill_more_fields_interactively: bool flag indicating whether a user will be interactively asked via command line for further input to fill out missing metadata content overwrite_existing_metadata: bool flag indicating whether existing metadata for this model in medigan's `config/global.json` should be overwritten. output_path: str the path where the created metadata json file will be stored creator_name: str the creator name that will appear on the corresponding github issue creator_affiliation: str the creator affiliation that will appear on the corresponding github issue model_description: list the model_description that will appear on the corresponding github issue install_dependencies: bool flag indicating whether a generative model's dependencies are automatically installed. Else error is raised if missing dependencies are detected. 
Returns ------- str Returns the url pointing to the corresponding issue on github """ # Create model contributor self.add_model_contributor(model_id=model_id, init_py_path=init_py_path) # Adding the metadata of the model from input if metadata_file_path is not None: # Using an existing metadata json metadata = self.add_metadata_from_file( model_id=model_id, metadata_file_path=metadata_file_path ) else: # Creating the metadata json metadata = self.add_metadata_from_input( model_id=model_id, model_weights_name=model_weights_name, model_weights_extension=model_weights_extension, generate_method_name=generate_method_name, dependencies=dependencies, fill_more_fields_interactively=fill_more_fields_interactively, output_path=output_path, ) logging.debug( f"{model_id}: The following model metadata was created: {metadata}" ) try: self.test_model( model_id=model_id, is_local_model=True, overwrite_existing_metadata=overwrite_existing_metadata, install_dependencies=install_dependencies, ) except Exception as e: logging.error( f"{model_id}: Error while testing this local model. " f"Please revise and run model contribute() again. 
{e}" ) raise e # Model Upload to Zenodo zenodo_record_url = self.push_to_zenodo( model_id=model_id, zenodo_access_token=zenodo_access_token, creator_name=creator_name, creator_affiliation=creator_affiliation, model_description=model_description, ) # Creating and returning an issue with model metadata in medigan's Github return self.push_to_github( model_id=model_id, package_link=zenodo_record_url, github_access_token=github_access_token, creator_name=creator_name, creator_affiliation=creator_affiliation, model_description=model_description, ) ############################ OTHER METHODS ############################ def get_as_torch_dataloader( self, dataset=None, model_id: str = None, num_samples: int = 1000, install_dependencies: bool = False, transform=None, batch_size=None, shuffle=None, sampler=None, batch_sampler=None, num_workers=None, collate_fn=None, pin_memory=None, drop_last=None, timeout=None, worker_init_fn=None, prefetch_factor: int = None, persistent_workers: bool = None, pin_memory_device: str = None, **kwargs, ) -> DataLoader: """Get torch Dataloader sampling synthetic data from medigan model. Dataloader combines a dataset and a sampler, and provides an iterable over the given torch dataset. Dataloader is created for synthetic data for the specified medigan model. Pytorch native parameters are set to ``None`` per default. Only those params are are passed to the Dataloader() initialization function that are not ``None``. Args: dataset (Dataset): dataset from which to load the data. model_id: str The generative model's unique id num_samples: int the number of samples that will be generated install_dependencies: bool flag indicating whether a generative model's dependencies are automatically installed. Else error is raised if missing dependencies are detected. **kwargs arbitrary number of keyword arguments passed to the model's sample generation function (e.g. the input path for image-to-image translation models in medigan). 
transform the torch data transformation functions to be applied to the data in the dataset. batch_size (int, optional): how many samples per batch to load (default: ``None``). shuffle (bool, optional): set to ``True`` to have the data reshuffled at every epoch (default: ``None``). sampler (Sampler or Iterable, optional): defines the strategy to draw samples from the dataset. Can be any ``Iterable`` with ``__len__`` implemented. If specified, :attr:`shuffle` must not be specified. (default: ``None``) batch_sampler (Sampler or Iterable, optional): like :attr:`sampler`, but returns a batch of indices at a time. Mutually exclusive with :attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`, and :attr:`drop_last`. (default: ``None``) num_workers (int, optional): how many subprocesses to use for data loading. ``0`` means that the data will be loaded in the main process. (default: ``None``) collate_fn (callable, optional): merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset. (default: ``None``) pin_memory (bool, optional): If ``True``, the data loader will copy Tensors into CUDA pinned memory before returning them. If your data elements are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type, see the example below. (default: ``None``) drop_last (bool, optional): set to ``True`` to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If ``False`` and the size of dataset is not divisible by the batch size, then the last batch will be smaller. (default: ``None``) timeout (numeric, optional): if positive, the timeout value for collecting a batch from workers. Should always be non-negative. (default: ``None``) worker_init_fn (callable, optional): If not ``None``, this will be called on each worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as input, after seeding and before data loading. 
(default: ``None``) prefetch_factor (int, optional, keyword-only arg): Number of batches loaded in advance by each worker. ``2`` means there will be a total of 2 * num_workers batches prefetched across all workers. (default: ``None``). persistent_workers (bool, optional): If ``True``, the data loader will not shutdown the worker processes after a dataset has been consumed once. This allows to maintain the workers `Dataset` instances alive. (default: ``None``) pin_memory_device (str, optional): the device to pin memory to if ``pin_memory`` is ``True`` (default: ``None``). Returns ------- DataLoader a torch.utils.data.DataLoader object with data generated by model corresponding to inputted `Dataset` or `model_id`. """ dataset = ( self.get_as_torch_dataset( model_id=model_id, num_samples=num_samples, install_dependencies=install_dependencies, transform=transform, **kwargs, ) if dataset is None else dataset ) # Reducing dependency on torch.util.data.DataLoader param default values by passing # only the ones specified by the user. dataloader = Utils.call_without_removable_params( my_callable=DataLoader, removable_param_values=[None], dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, pin_memory_device=pin_memory_device, ) return dataloader def get_as_torch_dataset( self, model_id: str, num_samples: int = 100, install_dependencies: bool = False, transform=None, **kwargs, ) -> Dataset: """Get synthetic data in a torch Dataset for specified medigan model. The dataset returns a dict with keys sample (== image), labels (== condition), and mask (== segmentation mask). While key 'sample' is mandatory, the other key value pairs are only returned if applicable to generative model. 
Args: model_id: str The generative model's unique id num_samples: int the number of samples that will be generated install_dependencies: bool flag indicating whether a generative model's dependencies are automatically installed. Else error is raised if missing dependencies are detected. transform the torch data transformation functions to be applied to the data in the dataset. **kwargs arbitrary number of keyword arguments passed to the model's sample generation function (e.g. the input path for image-to-image translation models in medigan). Returns ------- Dataset a torch.utils.data.Dataset object with data generated by model corresponding to `model_id`. """ data = self.generate( model_id=model_id, num_samples=num_samples, is_gen_function_returned=False, install_dependencies=install_dependencies, save_images=False, # design decision: temporary storage in memory instead of I/O from disk **kwargs, ) logging.debug(f"data: {data}") ( samples, masks, other_imaging_output, labels, ) = Utils.split_images_masks_and_labels(data=data, num_samples=num_samples) logging.debug( f"samples: {samples} \n masks: {masks} \n other_imaging_output: {other_imaging_output} \n labels: {labels}" ) return SyntheticDataset( samples=samples, masks=masks, other_imaging_output=other_imaging_output, labels=labels, transform=transform, ) def visualize( self, model_id: str, slider_grouper: int = 10, auto_close: bool = False, install_dependencies: bool = False, ) -> None: """Initialize and run `ModelVisualizer` of this model_id if it is available. It allows to visualize a sample from the model's output. UI window will pop up allowing the user to control the generation parameters (conditional and unconditional ones). Parameters ---------- model_id: str The generative model's unique id to visualize. slider_grouper: int Number of input parameters to group together within one slider. auto_close: bool Flag for closing the user interface automatically after time. Used while testing. 
install_dependencies: bool flag indicating whether a generative model's dependencies are automatically installed. Else error is raised if missing dependencies are detected. """ model_id = self.config_manager.match_model_id(provided_model_id=model_id) config = self.get_config_by_id(model_id=model_id) model_executor = self.get_model_executor( model_id=model_id, install_dependencies=install_dependencies ) ModelVisualizer(model_executor=model_executor, config=config).visualize( slider_grouper=slider_grouper, auto_close=auto_close ) def __repr__(self): return ( f"Generators(model_ids={self.config_manager.model_ids}, model_executors={self.model_executors}, " f"model_selector: {self.model_selector})" ) def __len__(self): return len(self.model_executors) def __getitem__(self, idx: int): return self.model_executors[idx] ================================================ FILE: src/medigan/model_visualizer.py ================================================ # -*- coding: utf-8 -*- # ! /usr/bin/env python """ `ModelVisualizer` class providing visualizing corresponding model input and model output changes. """ import matplotlib.pyplot as plt import numpy as np from matplotlib.widgets import Button, Slider class ModelVisualizer: """`ModelVisualizer` class: Visualises synthetic data through a user interface. Depending on a model, it is possible to control the input latent vector values and conditional input. 
Parameters ---------- model_executor: ModelExecutor The generative model's executor object config: dict The config dict containing the model metadata Attributes ---------- model_executor: ModelExecutor The generative model's executor object input_latent_vector_size: int Size of the latent vector used as an input for generation conditional: bool Flag for models with conditional input condition: Union[int, float] Value of the conditinal input to the model max_input_value: float Absolute value used for setting latent values input range """ def __init__(self, model_executor, config: None): self.model_executor = model_executor self.model_id = self.model_executor.model_id self.config = config self.num_samples = 1 self.max_input_value = 3 self.conditional = False self.condition = None self.input_latent_vector_size = ( self.model_executor.generate_method_input_latent_vector_size ) if not self.input_latent_vector_size: raise ValueError( f"{self.model_id}: Visualization of this model is not supported. Reason: This model does not use a random vector 'z' as input, which is needed for visualization. This is determined via the absence of the 'input_latent_vector_size' variable in this model's metadata in config/global.json." ) self.gen_function = self.model_executor.generate( num_samples=1, save_images=False, is_gen_function_returned=True, ) if "condition" in self.model_executor.generate_method_args["custom"]: self.conditional = True self.condition = self.model_executor.generate_method_args["custom"][ "condition" ] def visualize(self, slider_grouper: int = 10, auto_close=False): """ Visualize the model's output. This method is called by the user. It opens up a user interface with available controls. Parameters ---------- slider_grouper: int Number of input parameters to group together within one slider. auto_close: bool Flag for closing the user interface automatically after time. Used while testing. 
Returns ------- None """ z = np.random.randn( self.num_samples, self.input_latent_vector_size, 1, 1 ).astype(np.float32) mask = None if self.conditional: output = self.gen_function(condition=self.condition, input_latent_vector=z) else: output = self.gen_function(input_latent_vector=z) image, mask = self._unpack_output(output) images_to_show = 1 if mask is not None: images_to_show += 1 fig, ax = plt.subplots(ncols=images_to_show) if images_to_show == 1: ax.axis("off") ax.set_title("Generated image") display = ax.imshow(image, cmap="gray", vmin=0, vmax=255) if images_to_show == 2: ax[0].axis("off") ax[0].set_title("Generated image") display = ax[0].imshow(image, cmap="gray", vmin=0, vmax=255) ax[1].axis("off") ax[1].set_title("Generated mask") display_mask = ax[1].imshow(mask, cmap="gray", vmin=0, vmax=255) fig.suptitle( "Model " + self.model_id, fontsize=15, # fontweight="bold", ) if self.config: plt.text( x=0.5, y=0.88, s=self.config["description"]["title"], fontsize=8, ha="center", transform=fig.transFigure, wrap=True, ) # adjust the main plot to make room for the sliders plt.subplots_adjust(left=0.45, bottom=0.3, top=0.8) padding = 0.03 sliders_x = 0.1 sliders_y = 0.75 sliders_width = 0.25 sliders_height = 0.02 sliders = [] row_index = 0 if self.conditional: condition_ax = plt.axes( (sliders_x, sliders_y, sliders_width, sliders_height) ) condition_slider = Slider( condition_ax, None, 0, 1, valinit=0.0, valstep=1, initcolor="none", # valfmt="%.2f", ) condition_ax.set_title("Input condition: " + output[0][1]) row_index += 5 offset_ax = plt.axes( (sliders_x, sliders_y - row_index * padding, sliders_width, sliders_height) ) offset_ax.set_title("Input latent vector") offset_slider = Slider( offset_ax, "offset", -self.max_input_value * 2, self.max_input_value * 2, valinit=0.0, initcolor="none", valfmt="%.2f", ) row_index += 2 # for i in range(int(self.input_latent_vector_size)): for i in range(int(self.input_latent_vector_size / slider_grouper)): axfreq = plt.axes( ( 
sliders_x, sliders_y - (i + row_index) * padding, sliders_width, sliders_height, ) ) slider = Slider( axfreq, "z{}".format(i + 1), -self.max_input_value, self.max_input_value, valinit=float(z[0][i]), initcolor="none", valfmt="%.2f", ) sliders.append(slider) text = "Offset: Add constant value to each latent variable \ \nInput vector: Modify latent values used to generate image \ \nSeed: Initialize new random seed for latent vector \ \nReset: Revert user changes to initial seed values" ax_legend = plt.axes( ( 0.45, 0.19, 0.5, 0.5, ) ) ax_legend.axis("off") ax_legend.text(0.0, 0.0, text, fontsize=8, va="top", linespacing=2) # The function to be called anytime a slider's value changes def update(val): for i, slider in enumerate(sliders): for j in range(10): z[0][i + j] = slider.val if self.conditional: self.condition = condition_slider.val output = self.gen_function( condition=self.condition, input_latent_vector=z ) condition_ax.set_title("Input condition: " + output[0][1]) else: output = self.gen_function(input_latent_vector=z) image, mask = self._unpack_output(output) if mask is not None: display_mask.set_data(mask) display.set_data(image) fig.canvas.draw_idle() # register the update function with each slider for slider in sliders: slider.on_changed(update) if self.conditional: condition_slider.on_changed(update) self.offset_old = 0 def update_offset(val): diff = offset_slider.val - self.offset_old self.offset_old = offset_slider.val for i, slider in enumerate(sliders): if slider.val + diff > self.max_input_value: slider.set_val(self.max_input_value) elif slider.val + diff < -self.max_input_value: slider.set_val(-self.max_input_value) else: slider.set_val(slider.val + diff) for j in range(10): z[0][i + j] = slider.val offset_slider.on_changed(update_offset) # Create a `matplotlib.widgets.Button` to reset the sliders to initial values. 
resetax = plt.axes([0.77, 0.220, 0.1, 0.04]) reset_button = Button(resetax, "Reset", hovercolor="0.975") seedax = plt.axes([0.62, 0.220, 0.1, 0.04]) seed_button = Button(seedax, "Seed", hovercolor="0.975") def reset(event): offset_slider.reset() for slider in sliders: slider.reset() def new_seed(event): z = np.random.randn( self.num_samples, self.input_latent_vector_size, 1, 1 ).astype(np.float32) for slider in sliders: slider.valinit = z[0][sliders.index(slider)] reset(event) reset_button.on_clicked(reset) seed_button.on_clicked(new_seed) update(0) if auto_close: plt.show(block=False) plt.pause(1) plt.close() else: plt.show() def _unpack_output(self, output) -> tuple: """ Unpack the output of the generator function Parameters ---------- output: Union[tuple, np.ndarray] Output of the generator function to unpack into an image and optional mask Returns ---------- tuple[image, mask] Tuple of the image and mask. Mask is None if no mask was available """ mask = None if type(output[0]) is tuple: image = output[0][0].squeeze() if type(output[0][1]) is not str: mask = output[0][1].squeeze() else: image = output[0].squeeze() return image, mask ================================================ FILE: src/medigan/select_model/__init__.py ================================================ ================================================ FILE: src/medigan/select_model/matched_entry.py ================================================ # -*- coding: utf-8 -*- # ! /usr/bin/env python """MatchedEntry class that represents one match of a key value pair of a model's config against a search query. """ # Import python native libs from __future__ import absolute_import import json class MatchedEntry: """`MatchedEntry` class: One target key-value pair that matches with a model's selection config. 
Parameters ---------- key: str string that represents the matched key in model selection dict value represents the key's matched value in the model selection dict matching_element: str string that was used to match the search value Attributes ---------- key: str string that represents the matched key in model selection dict value represents the key's matched value in the model selection dict matching_element: str string that was used to match the search value """ def __init__( self, key: str, value, matching_element: str = None, ): self.key = key self.value = value if matching_element is None: self.matching_element = str(value) else: self.matching_element = matching_element def __str__(self): return json.dumps( { "key": self.key, "value": self.value, "matching_element": self.matching_element, } ) def __repr__(self): return f"MatchedEntry(key={self.key}, value={self.value}, matching_element={self.matching_element})" def __len__(self): raise NotImplementedError def __getitem__(self, idx: int): raise NotImplementedError ================================================ FILE: src/medigan/select_model/model_match_candidate.py ================================================ # -*- coding: utf-8 -*- # ! /usr/bin/env python """ModelMatchCandidate class that holds data to evaluate if a generative model matches against a search query. """ # Import python native libs from __future__ import absolute_import import json import logging # Import library internal modules from .matched_entry import MatchedEntry class ModelMatchCandidate: """`ModelMatchCandidate` class: A prospectively matching model given the target values as model search params. Parameters ---------- model_id: str The generative model's unique id target_values: list list of target values used to evaluate if a `ModelMatchCandidate` instance is a match target_values_operator: str the operator indicating the relationship between `values` in the evaluation of `ModelMatchCandidate` instances. 
Should be either "AND", "OR", or "XOR". is_case_sensitive: bool flag indicating whether the matching of `values` (and) keys should be case-sensitive are_keys_also_matched: bool flag indicating whether, apart from `values`, keys should also be matched is_match: bool flag indicating whether the `ModelMatchCandidate` instance is a match Attributes ---------- model_id: str The generative model's unique id target_values: list list of target values used to evaluate if a `ModelMatchCandidate` instance is a match target_values_operator: str the operator indicating the relationship between `values` in the evaluation of `ModelMatchCandidate` instances. Should be either "AND", "OR", or "XOR". is_case_sensitive: bool flag indicating whether the matching of `values` (and) keys should be case-sensitive are_keys_also_matched: bool flag indicating whether, apart from values, keys should also be matched matched_entries: list contains iteratively added `MatchedEntry` class instances. Each of the `MatchedEntry` instances indicates a match between one of the user specified values in `self.target_values` and the selection config keys or `values` of the model of this `ModelMatchCandidate`. 
is_match: bool flag indicating whether the `ModelMatchCandidate` instance is a match """ def __init__( self, model_id: str, target_values: list, target_values_operator: str = "AND", is_case_sensitive: bool = False, are_keys_also_matched: bool = False, is_match: bool = False, ): # Descriptive variables self.model_id = model_id self.target_values = target_values self.target_values_operator = target_values_operator self.is_case_sensitive = is_case_sensitive self.are_keys_also_matched = are_keys_also_matched # Dynamically filled/changed variables self.matched_entries = [] self.is_match = is_match def add_matched_entry(self, matched_entry: MatchedEntry) -> None: """Add a `MatchedEntry` instance to the `matched_entries` list.""" self.matched_entries.append(matched_entry) def get_all_matching_elements(self) -> list: """Get the matching element from each of the `MatchedEntry` instances in the `matched_entries` list. Returns ------- list list of all matching elements (i.e. string that matched a search value) from each `MatchedEntry` in `matched_entries` """ matching_elements = [] for matched_entry in self.matched_entries: matching_elements.append(matched_entry.matching_element) return matching_elements def check_if_is_match(self) -> bool: """Evaluates whether the model represented by this instance is a match given search values and operator. The matching element from each `MatchEntry` of this instance ('self.matched_entries') are retrieved. To be a match, this instance ('self') needs to fulfill the requirement of the operator, which can be of value 'AND', or 'OR', or 'XOR'. For example, the default 'AND' requires that each search value ('self.target_values') has at least one corresponding `MatchEntry`, while in 'OR' only one of the search values needs to have been matched by a corresponding `MatchedEntry`. Returns ------- bool flag that, only if True, indicates that this instance is a match given the search values and operator. 
""" if self is not None and len(self) > 0: if self.target_values_operator == "OR": self.is_match = True elif self.target_values_operator == "AND": # removing duplicates via set conversion found_target_values = set(self.get_all_matching_elements()) if all(elem in found_target_values for elem in self.target_values): logging.debug( f"values: {self.target_values} AND found_target_values_list: {found_target_values}" ) self.is_match = True elif self.target_values_operator == "XOR": # removing duplicates via set conversion if ( len( list( set(self.get_all_matching_elements()).intersection( self.target_values ) ) ) == 1 ): self.is_match = True logging.debug(f"This ModelMatchCandidate was found to be a match: ({self}).") return self.is_match def __str__(self): matched_entry_dicts = { f"{idx}": json.loads(str(match)) for idx, match in enumerate(self.matched_entries) } return json.dumps( { "model_id": self.model_id, "is_match": self.is_match, "target_values": self.target_values, "operator": self.target_values_operator, "are_keys_also_matched": self.are_keys_also_matched, "is_case_sensitive": self.is_case_sensitive, "matched_entries": matched_entry_dicts, } ) def __repr__(self): return f"ModelMatchCandidate(model_id={self.model_id}, is_match={self.is_match}, operator: {self.target_values_operator}, target_values={self.target_values})" def __len__(self): return len(self.matched_entries) def __getitem__(self, idx: int): return self.matched_entries[idx] ================================================ FILE: src/medigan/select_model/model_selector.py ================================================ # -*- coding: utf-8 -*- # ! /usr/bin/env python """ Model selection class that describes, finds, compares, and ranks generative models. 
""" # Import python native libs from __future__ import absolute_import import logging # Import library internal modules from ..config_manager import ConfigManager from ..constants import CONFIG_FILE_KEY_PERFORMANCE, CONFIG_FILE_KEY_SELECTION, MODEL_ID from ..utils import Utils from .matched_entry import MatchedEntry from .model_match_candidate import ModelMatchCandidate class ModelSelector: """`ModelSelector` class: Given a config dict, gets, searches, and ranks matching models. Parameters ---------- config_manager: ConfigManager Provides the config dictionary, based on which models are selected and compared. Attributes ---------- config_manager: ConfigManager Provides the config dictionary, based on which models are selected and compared. model_selection_dicts: list Contains a dictionary for each model id that consists of the `model_id` and the selection config of that model """ def __init__( self, config_manager: ConfigManager = None, ): if config_manager is None: self.config_manager = ConfigManager() logging.debug(f"Initialized ConfigManager instance: {self.config_manager}") else: self.config_manager = config_manager self.model_selection_dicts = [] self._init_model_selector_data() def _init_model_selector_data(self): """Initialize class data structure: List of dicts containing two keys each: `model_id` and `selection`.""" for model_id in self.config_manager.model_ids: selection_config = self.config_manager.get_config_by_id( model_id=model_id, config_key=CONFIG_FILE_KEY_SELECTION ) model_selector_dict = { MODEL_ID: model_id, CONFIG_FILE_KEY_SELECTION: selection_config, } self.model_selection_dicts.append(model_selector_dict) logging.debug( f"These were the available model selection dicts that were added to the ModelSelector: " f"{self.model_selection_dicts}." ) def get_selection_criteria_by_id( self, model_id: str, is_model_id_removed: bool = True ) -> dict: """Get and return the selection config dict for a specific `model_id`. 
Parameters ---------- model_id: str The generative model's unique id is_model_id_removed: bool flag to to remove the `model_id` from first level of each dictionary. Returns ------- dict a dictionary corresponding to the selection config of a model """ for idx, selection_dict in enumerate(self.model_selection_dicts): if selection_dict[MODEL_ID] == model_id: if is_model_id_removed: logging.debug( f"For model {model_id}, the following selection dicts was found:" f" {selection_dict[CONFIG_FILE_KEY_SELECTION]}." ) return selection_dict[CONFIG_FILE_KEY_SELECTION] else: logging.debug( f"For model {model_id}, the following selection dicts was found:" f" {selection_dict}." ) return selection_dict return None def get_selection_criteria_by_ids( self, model_ids: list = None, are_model_ids_removed: bool = True ) -> list: """Get and return a list of selection config dicts for each of the specified `model_ids`. Parameters ---------- model_ids: list A list of generative models' unique ids are_model_ids_removed: bool flag to remove the `model_ids` from first level of dictionary. Returns ------- list a list of dictionaries each corresponding to the selection config of a model """ # Create list of models that contain a value for the metric of interest selection_dict_list = [] for idx, selection_dict in enumerate(self.model_selection_dicts): if model_ids is None or selection_dict[MODEL_ID] in model_ids: # if model_ids is None, we consider all models if are_model_ids_removed: selection_dict_list.append( selection_dict[CONFIG_FILE_KEY_SELECTION] ) else: selection_dict_list.append(selection_dict) logging.debug( f"The following selection dicts were found: {selection_dict_list}." ) return selection_dict_list def get_selection_keys(self, model_id: str = None) -> list: """Get and return all first level keys from the selection config dict for a specific `model_id`. 
Parameters ---------- model_id: str The generative model's unique id Returns ------- list a list containing the keys as strings of the selection config of the `model_id`. """ key_list = [] if model_id is not None: selection_config = self.get_selection_criteria_by_id(model_id) for key in selection_config: key_list.append(key) else: for selection_dict in self.model_selection_dicts: selection_config = selection_dict[CONFIG_FILE_KEY_SELECTION] for key in selection_config: if key not in key_list: key_list.append(key) logging.debug( f"For model {model_id}, the following selection keys were in its selection config: {key_list}." ) return key_list def get_selection_values_for_key(self, key: str, model_id: str = None) -> list: """Get and return the value of a specified key of the selection dict in the config for a specific `model_id`. The key param can contain '.' (dot) separations to allow for retrieval of nested config keys such as 'execution.generator.name' Parameters ---------- key: str The key in the selection dict model_id: str The generative model's unique id Returns ------- list a list of the values that correspond to the key in the selection config of the `model_id`. """ values_for_key = [] if model_id is not None: selection_config = self.get_selection_criteria_by_id(model_id) values_for_key.append(selection_config[key]) else: for selection_dict in self.model_selection_dicts: selection_config = selection_dict[CONFIG_FILE_KEY_SELECTION] # if applicable, split key by "." and get value in nested dict in selection_config selection_config = Utils.deep_get(base_dict=selection_config, key=key) values_for_key.append(selection_config) logging.debug( f"For key {key}, the following values were found across the models' selection " f"dicts {values_for_key}." 
) return values_for_key def get_models_by_key_value_pair( self, key1: str, value1: str, is_case_sensitive: bool = False ) -> list: """Get and return a list of `model_id` dicts that contain the specified key value pair in their selection config. The key param can contain '.' (dot) separations to allow for retrieval of nested config keys such as 'execution.generator.name'. If `key1` points to a list, any value in the list that matches value1` is accepted. Parameters ---------- key1: str` The key in the selection dict value1: str The value in the selection dict that corresponds to key1 is_case_sensitive: bool flag to evaluate keys and values with case sensitivity if set to True Returns ------- list a list of the dictionaries each containing a model's `model_id` and the found key-value pair in the models config """ model_dict_list = [] for selection_dict in self.model_selection_dicts: is_model_match: bool = False # Now, for each model, we want to get the respective value for the key try: key_value = selection_dict[CONFIG_FILE_KEY_SELECTION] key_value = Utils.deep_get(base_dict=key_value, key=key1) if key_value is not None: # If key value is None, the model is not added to the model if isinstance(key_value, dict): # If the value of the key is a dict, we cannot evaluate a dict and continue the loop. continue if isinstance(key_value, list): # If the value of the key is a list, we check if the provided value1 is in that list. 
# Convert list of arbitrary type to list of strings key_value = list(map(str, key_value)) if not is_case_sensitive: key_value = Utils.list_to_lowercase(key_value) value1 = value1.lower() if str(value1) in key_value: is_model_match = True else: # If the value of the key is something else (str, float, int, etc), we check if equal to value1 if (str(key_value) == str(value1)) or ( not is_case_sensitive and str(key_value).lower() == str(value1).lower() ): is_model_match = True except KeyError as e: logging.debug( f"Model {selection_dict[MODEL_ID]} was discarded as it does not have the specified keys " f"in its selection dict: {selection_dict}" ) pass if is_model_match: model_id = selection_dict[MODEL_ID] model_dict = {MODEL_ID: model_id, key1: value1} logging.debug( f"Model {model_id} was a match for the specified key value pair: {model_dict}" ) model_dict_list.append(model_dict) return model_dict_list def rank_models_by_performance( self, model_ids: list = None, metric: str = "SSIM", order: str = "asc" ) -> list: """Rank model based on a provided metric and return sorted list of model dicts. The metric param can contain '.' (dot) separations to allow for retrieval via nested metric config keys such as 'downstream_task.CLF.accuracy'. If the value found for the metric key is of type list, then the largest value in the list is used for ranking if `order` is descending, while the smallest value is used if `order` is ascending. Parameters ---------- model_ids: list only evaluate the model_ids in this list. If none, evaluate all available `model_ids` metric: str The key in the selection dict that corresponds to the metric of interest order: str the sorting order of the ranked results. Should be either "asc" (ascending) or "desc" (descending) Returns ------- list a list of model dictionaries containing metric and `model_id`, sorted by `metric`. """ model_metric_dict_list = [] if model_ids is not None and len(model_ids) == 0: # empty model_ids list -> return empty list. 
return model_metric_dict_list # First, get all selection criteria for the model_ids selection_dict_list = self.get_selection_criteria_by_ids( model_ids=model_ids, are_model_ids_removed=False ) for selection_dict in selection_dict_list: # Now, for each model, we want to get the respective value for the metric try: # Maybe remove the case-sensitivity for metric here. metric_value = selection_dict[CONFIG_FILE_KEY_SELECTION][ CONFIG_FILE_KEY_PERFORMANCE ] metric_value = Utils.deep_get(base_dict=metric_value, key=metric) if metric_value is not None: # If metric value is None, the model is not added to the model_metric_dict_list # TODO Maybe add further validation of metric_value here, e.g. string to float conversion, etc. if isinstance(metric_value, list) and order == "asc": # Assumption: As order is ascending (smallest item at top of list), we want to get the # smallest (=best) possible value from our metric_value list. metric_value = min(metric_value) elif isinstance(metric_value, list): # Assumption: As order is descending (largest item at top of list), we want to get the # largest (=best) possible value from our metric_value list. 
metric_value = max(metric_value) model_id = selection_dict[MODEL_ID] model_metric_dict = {MODEL_ID: model_id, metric: metric_value} logging.debug( f"Model {model_id} was a match for the specified metric value: {model_metric_dict}" ) model_metric_dict_list.append(model_metric_dict) except KeyError as e: logging.debug( f"Model {selection_dict[MODEL_ID]} was discarded as it does not have the specified keys " f"in its selection dict: {selection_dict}" ) pass if order == "asc": # ascending -> the smallest item appears at the top of the list model_metric_dict_list.sort(key=lambda x: x.get(metric)) else: # descending -> the largest item appears at the top of the list model_metric_dict_list.sort(key=lambda x: x.get(metric), reverse=True) return model_metric_dict_list def find_models_and_rank( self, values: list, target_values_operator: str = "AND", are_keys_also_matched: bool = False, is_case_sensitive: bool = False, metric: str = "SSIM", order: str = "asc", ) -> list: """Search for values (and keys) in model configs, rank results and return sorted list of model dicts. Parameters ---------- values: list list of values used to search and find models corresponding to these `values` target_values_operator: str the operator indicating the relationship between `values` in the evaluation of model search results. Should be either "AND", "OR", or "XOR". are_keys_also_matched: bool flag indicating whether, apart from `values`, the keys in the model config should also be searchable is_case_sensitive: bool flag indicating whether the search for values (and) keys in the model config should be case-sensitive. metric: str The key in the selection dict that corresponds to the `metric` of interest order: str the sorting order of the ranked results. Should be either "asc" (ascending) or "desc" (descending) Returns ------- list a list of the searched and matched model dictionaries containing `metric` and `model_id`, sorted by `metric`. 
""" matching_models = self.find_matching_models_by_values( values=values, target_values_operator=target_values_operator, are_keys_also_matched=are_keys_also_matched, is_case_sensitive=is_case_sensitive, ) matching_model_ids = [model.model_id for model in matching_models] logging.debug(f"matching_model_ids: {matching_model_ids}") return self.rank_models_by_performance( model_ids=matching_model_ids, metric=metric, order=order ) def find_matching_models_by_values( self, values: list, target_values_operator: str = "AND", are_keys_also_matched: bool = False, is_case_sensitive: bool = False, ) -> list: """Search for values (and keys) in model configs and return a list of each matching `ModelMatchCandidate`. Uses `self.recursive_search_for_values` to recursively populate each `ModelMatchCandidate` with `MatchedEntry` instances. After populating, each `ModelMatchCandidate` is evaluated based on the provided `target_values_operator` and `values` list using `ModelMatchCandidate.check_if_is_match`. Parameters ---------- values: list list of values used to search and find models corresponding to these values target_values_operator: str the operator indicating the relationship between `values` in the evaluation of model search results. Should be either "AND", "OR", or "XOR". are_keys_also_matched: bool flag indicating whether, apart from values, the keys in the model config should also be searchable is_case_sensitive: bool flag indicating whether the search for values (and) keys in the model config should be case-sensitive. Returns ------- list a list of `ModelMatchCandidate` class instances each of which was successfully matched against the search values. """ assert ( values is not None and len(values) > 0 ), f"Please specify a list of values to search for. You specified: {values}." 
matching_models = [] if not is_case_sensitive: # Removing case-sensitivity search requirement by replacing with lowercase values list values = Utils.list_to_lowercase(target_list=values) logging.debug(f"Processed search values: {values}") for selection_dict in self.model_selection_dicts: selection_config = selection_dict[CONFIG_FILE_KEY_SELECTION] model_match_candidate = ModelMatchCandidate( model_id=selection_dict[MODEL_ID], target_values_operator=target_values_operator, is_case_sensitive=is_case_sensitive, target_values=values, are_keys_also_matched=are_keys_also_matched, ) model_match_candidate = self.recursive_search_for_values( search_dict=selection_config, model_match_candidate=model_match_candidate, ) if model_match_candidate.check_if_is_match(): logging.debug( f"Found a matching ModelMatchCandidate: {model_match_candidate}" ) matching_models.append(model_match_candidate) return matching_models def recursive_search_for_values( self, search_dict: dict, model_match_candidate: ModelMatchCandidate ): """Do a recursive search to match values in the `search_dict` with values (and keys) in a model's config. The provided `ModelMatchCandidate` instance is recursively populated with `MatchedEntry` instances. A `MatchedEntry` instance contains a key-value pair found in the model's config that matches with one search term of interest. The search terms of interest are stored in `ModelMatchCandidate.target_values`. The model's selection config is provided in the 'search_dict'. To traverse the `search_dict`, the value for each key in the `search_dict` is retrieved. - If that value is of type dictionary, the function calls itself with that nested dictionary as new `search_dict`. - If that value is of type list, each value in the list is compared with each search term of interest in `ModelMatchCandidate.target_values`. - If the value of the `search_dict` is of another type (i.e. str), it is compared with each search term of interest in `ModelMatchCandidate.target_values`. 
Parameters ---------- search_dict: dict contains keys and values from a model's config that are matched against a set of search values. model_match_candidate: ModelMatchCandidate a class instance representing a model to be prepared for evaluation (populated with `MatchedEntry` objects), as to whether it is a match given its search values (`self.target_values`). Returns ------- list a `ModelMatchCandidate` class instance that has been populated with `MatchedEntry` class instances. """ if search_dict is not None: counter = 0 for key in search_dict: if model_match_candidate.are_keys_also_matched and not isinstance( search_dict, list ): # Treating the key as a match due to a matching string in target_values. if ( not model_match_candidate.is_case_sensitive and key.lower() in model_match_candidate.target_values ): matched_entry = MatchedEntry( key="key", value=key, matching_element=key.lower() ) model_match_candidate.add_matched_entry( matched_entry=matched_entry ) elif key in model_match_candidate.target_values: matched_entry = MatchedEntry( key="key", value=key, matching_element=key ) model_match_candidate.add_matched_entry( matched_entry=matched_entry ) if isinstance(search_dict, list): # if we have a list we want the counter to get index position in list key_or_counter = counter else: # if we have something else i.e. 
a dict, we want to get the key to get nested dict key_or_counter = key if isinstance(search_dict[key_or_counter], dict): # The value of the key is of type dict, we thus search recursively inside that dictionary model_match_candidate = self.recursive_search_for_values( search_dict=search_dict[key_or_counter], model_match_candidate=model_match_candidate, ) elif isinstance(search_dict[key_or_counter], list): for item in search_dict[key_or_counter]: if not model_match_candidate.is_case_sensitive: item = str(item).lower() if str(item) in model_match_candidate.target_values: matched_entry = MatchedEntry( key=key, value=item, matching_element=str(item) ) model_match_candidate.add_matched_entry( matched_entry=matched_entry ) else: item = search_dict[key_or_counter] if not model_match_candidate.is_case_sensitive: item = str(item).lower() if str(item) in model_match_candidate.target_values: matched_entry = MatchedEntry( key=key, value=item, matching_element=str(item) ) model_match_candidate.add_matched_entry( matched_entry=matched_entry ) counter += counter return model_match_candidate def __repr__(self): return f"ModelSelector(model_ids={self.config_manager.model_ids})" def __len__(self): raise NotImplementedError def __getitem__(self, idx: int): raise NotImplementedError ================================================ FILE: src/medigan/utils.py ================================================ # -*- coding: utf-8 -*- # ! /usr/bin/env python """ `Utils` class providing generalized reusable functions for I/O, parsing, sorting, type conversions, etc. 
""" # Import python native libs import json import logging import os import shutil import time import zipfile from distutils.dir_util import copy_tree from pathlib import Path from urllib.parse import urlparse # python3 import numpy as np # Import pypi libs import requests from tqdm import tqdm class Utils: """Utils class containing reusable static methods.""" def __init__( self, ): pass @staticmethod def mkdirs(path_as_string: str) -> bool: """create folder in `path_as_string` if not already created.""" if not os.path.exists(path_as_string): try: os.makedirs(path_as_string) return True except Exception as e: logging.error( f"Error while creating folders for path {path_as_string}: {e}" ) return False return True @staticmethod def is_file_located_or_downloaded( path_as_string: str, download_if_not_found: bool = True, download_link: str = None, is_new_download_forced: bool = False, allow_local_path_as_url: bool = True, ) -> bool: """check if is file in `path_as_string` and optionally download the file (again).""" if not path_as_string.is_file() or is_new_download_forced: if not download_if_not_found: # download_if_not_found is prioritized over is_new_download_forced in this case, as users likely # prefer to avoid automated downloads altogether when setting download_if_not_found to False. logging.warning( f"File {path_as_string} was not found ({not path_as_string.is_file()}) or download " f"was forced ({is_new_download_forced}). However, downloading it from {download_link} " f"was not allowed: download_if_not_found == {download_if_not_found}. This may cause an " f"error, as the file might be outdated or missing, while being used in subsequent " f"workflows." 
) return False else: try: if allow_local_path_as_url and not Utils.is_url_valid( the_url=download_link ): Utils.copy( source_path=download_link, target_path=os.path.split(path_as_string)[0], ) else: Utils.download_file( path_as_string=path_as_string, download_link=download_link ) except Exception as e: raise e return True @staticmethod def download_file( download_link: str, path_as_string: str, file_extension: str = ".json" ): """download a file using the `requests` lib and store in `path_as_string`""" logging.debug(f"Now downloading file {path_as_string} from {download_link} ...") try: for i in range(10): response = requests.get( download_link, allow_redirects=True, stream=True ) total_size_in_bytes = int( response.headers.get("content-length", 0) ) # / (32 * 1024) # 32*1024 bytes received by requests. logging.debug(total_size_in_bytes) block_size = 1024 progress_bar = tqdm( total=total_size_in_bytes, unit="B", unit_scale=True, position=0, leave=True, ascii=True, ) progress_bar.set_description(f"Downloading {download_link}") with open(path_as_string, "wb") as file: for data in response.iter_content(block_size): progress_bar.update(len(data)) file.write(data) logging.debug( f"Received response {response}: Retrieved file from {download_link} and wrote it " f"to {path_as_string}." ) try: if not ( download_link.endswith(file_extension) and Path(path_as_string).is_file() and str(path_as_string).endswith(file_extension) ): # If we do not download a json file (global.json), we assume a zip and want to check if the downloaded zip is valid. zipfile.ZipFile(path_as_string, "r") break except Exception as e: print(e) logging.debug( f"Download failed. 
Retrying download from {download_link}" ) except Exception as e: logging.error( f"Error while trying to download/copy from {download_link} to {path_as_string}:{e}" ) raise e @staticmethod def read_in_json(path_as_string) -> dict: """read a .json file and return as dict""" try: with open(path_as_string) as f: json_file = json.load(f) return json_file except Exception as e: logging.error( f"Error while reading in json file from {path_as_string}: {e}" ) raise e @staticmethod def unzip_archive(source_path: Path, target_path: str = "./"): """unzip a .zip archive in the `target_path`""" try: with zipfile.ZipFile(source_path, "r") as zip_ref: zip_ref.extractall(target_path) except Exception as e: logging.error(f"Error while unzipping {source_path}: {e}") raise e @staticmethod def unzip_and_return_unzipped_path(package_path: str): """if not already dir, unzip an archive with `Utils.unzip_archive`. Return path to unzipped dir/file""" if Path(package_path).is_file() and package_path.endswith(".zip"): # Get the source_path without .zip extension to unzip. package_path_unzipped = package_path[0:-4] # We have a zip. Let's unzip and do the same operation (with new path) Utils.unzip_archive( source_path=package_path, target_path_as_string=package_path_unzipped ) return package_path_unzipped elif Path(package_path).is_dir(): logging.info( f"Your package path ({package_path}) does already point to a directory. It was not unzipped." ) return package_path else: raise Exception( f"Your package path ({package_path}) does not point to a zip file nor directory. Please adjust and try again." 
) @staticmethod def copy(source_path: Path, target_path: str = "./"): """copy a folder or file from `source_path` to `target_path`""" try: if Path(source_path).is_file(): shutil.copy2(src=source_path, dst=target_path) else: copy_tree(src=source_path, dst=target_path) except Exception as e: logging.error(f"Error while copying {source_path} to {target_path}: {e}") raise e @staticmethod def dict_to_lowercase(target_dict: dict, string_conversion: bool = True) -> dict: """transform values and keys in dict to lowercase, optionally with string conversion of the values. Warning: Does not convert nested dicts in the `target_dict`, but rather removes them from return object. """ if string_conversion: # keys should always be strings per default. values might differ in type. return dict((k.lower(), str(v).lower()) for k, v in target_dict.items()) else: return dict((k.lower(), v.lower()) for k, v in target_dict.items()) @staticmethod def list_to_lowercase(target_list: list) -> list: """string conversion and lower-casing of values in list. trade-off: String conversion for increased robustness > type failure detection """ return [str(x).lower() for x in target_list] @staticmethod def deep_get(base_dict: dict, key: str): """Split the key by "." to get value in nested dictionary.""" try: key_split = key.split(".") for key_ in key_split: base_dict = base_dict[key_] return base_dict except TypeError as e: logging.debug( f"No key ({key}) found in base_dict ({base_dict}) for this model. Fallback: Returning None." ) return None @staticmethod def is_url_valid(the_url: str) -> bool: """Checks if a url is valid using urllib.parse.urlparse""" try: result = urlparse(the_url) # testing if both result.scheme and result.netloc are non-empty strings (empty strings evaluate to False). 
return all([result.scheme, result.netloc]) except Exception: return False @staticmethod def has_more_than_n_diff_pixel_values(img: np.ndarray, n: int = 4) -> bool: """This function checks whether an image contains more than n different pixel values. This helps to differentiate between segmentation masks and actual images. """ import torch torch_img = torch.from_numpy(img) pixel_values_set = set(torch_img.flatten().tolist()) if len(pixel_values_set) > n: return True else: return False @staticmethod def split_images_masks_and_labels( data: list, num_samples: int, max_nested_arrays: int = 2 ) -> [list, list, list, list]: """Separates the data (sample, mask, other_imaging_data, label) returned by a generative model This functions expects a list of tuples as input `data` and assumes that each tuple contains sample, mask, other_imaging_data, label at index positions [0], [1], [2], and [3] respectively. samples, masks, and imaging data are expected to be of type np.ndarray and labels of type "str". For example, this extendable function assumes that, in data, a mask follows the image that it corresponds to or vice versa. """ samples = [] masks = [] other_imaging_output = [] labels = [] # if data is smaller than the number of samples that should have been generated, then data likely contains a nested array. # We go a maximum of max_nested_arrays deep into the data. 
counter = 0 while len(data) < num_samples and isinstance(data, list): data = data[0] counter = counter + 1 if counter >= max_nested_arrays: break for data_point in data: logging.debug(f"data_point: {data_point}") if isinstance(data_point, tuple): for i, item in enumerate(data_point): if isinstance(item, np.ndarray) and i == 0: samples.append(item) elif isinstance(item, np.ndarray) and i == 1: masks.append(item) elif isinstance(item, np.ndarray) and i == 2: other_imaging_output.append(item) elif isinstance(item, str): labels.append(item) elif isinstance(data_point, np.ndarray): # An image is expected in the case no tuple is returned samples.append(data_point) masks = None if len(masks) == 0 else masks other_imaging_output = ( None if len(other_imaging_output) == 0 else other_imaging_output ) labels = None if len(labels) == 0 else labels return samples, masks, other_imaging_output, labels @staticmethod def split_images_and_masks_no_ordering( data: list, num_samples: int, max_nested_arrays: int = 2 ) -> [np.ndarray, np.ndarray]: """Extracts and separates the masks from the images if a model returns both in the same np.ndarray. This extendable function assumes that, in data, a mask follows the image that it corresponds to or vice versa. - This function is deprecated. Please use `split_images_masks_and_labels` instead. """ images = [] masks = [] # if data is smaller than the number of samples that should have been generated, then data likely contains a nested array. # We go a maximum of max_nested_arrays deep into the data. 
counter = 0 while len(data) < num_samples: data = data[0] counter = counter + 1 if counter >= max_nested_arrays: break for data_point in data: logging.debug(f"data_point {data_point}") if isinstance(data_point, tuple): for i, sample in enumerate(data_point): if ( isinstance(i, np.ndarray) and "int" in str(i.dtype) and not Utils.has_more_than_n_diff_pixel_values(i) ): # Check if numpy array that contains integers instead of floats indicates the presence of a mask masks.append(i) elif Utils.has_more_than_n_diff_pixel_values(i): images.append(i) elif ( isinstance(data_point, np.ndarray) and "int" in str(data_point.dtype) and not Utils.has_more_than_n_diff_pixel_values(data_point) ): masks.append(data_point) else: images.append(data_point) masks = None if len(masks) == 0 else masks return images, masks @staticmethod def order_dict_by_value( dict_list, key: str, order: str = "asc", sort_algorithm="bubbleSort" ) -> list: """Sorting a list of dicts by the values of a specific key in the dict using a sorting algorithm. - This function is deprecated. You may use Python List sort() with key=lambda function instead. 
""" if sort_algorithm == "bubbleSort": for i in range(len(dict_list)): for j in range(len(dict_list) - i - 1): if dict_list[j][key] > dict_list[j + 1][key]: # no need for a temp variable holder dict_list[j][key], dict_list[j + 1][key] = ( dict_list[j + 1][key], dict_list[j][key], ) return dict_list @staticmethod def is_file_in(folder_path: str, filename: str) -> bool: """Checks if a file is inside a folder""" try: if ( Path(folder_path).is_dir() and Path(f"{folder_path}/{filename}").is_file() ): return True except Exception as e: logging.warning(f"File ({filename}) was not found in {folder_path}: {e}") return False @staticmethod def store_dict_as( dictionary, extension: str = ".json", output_path: str = "config/", filename: str = "metadata.json", ): """store a Python dictionary in file system as variable filetype.""" if extension not in output_path: Utils.mkdirs(path_as_string=output_path) if extension not in filename: filename = filename + extension output_path = f"{output_path}/{filename}" json_object = json.dumps(dictionary, indent=4) with open(output_path, "w") as outfile: outfile.write(json_object) @staticmethod def call_without_removable_params( my_callable, removable_param_values: list = [None], **params ): """call a callable without passing parameters that contain any of the removable_param_values as value.""" not_removed_params = params for removable_param_value in removable_param_values: if removable_param_value is None: not_removed_params = { k: v for k, v in not_removed_params.items() if v is not removable_param_value } else: not_removed_params = { k: v for k, v in not_removed_params.items() if v != removable_param_value } logging.debug( f"call_without_removable_params final params: {not_removed_params}" ) return my_callable(**not_removed_params) def __len__(self): raise NotImplementedError def __getitem__(self, idx: int): raise NotImplementedError ================================================ FILE: templates/examples/500.pt.txt 
================================================ Download 500.pt file from: https://drive.google.com/file/d/1C9vVPymsKJ5i5gpwQM6cpX0y1G89vcpk/view?usp=sharing ================================================ FILE: templates/examples/LICENSE ================================================ MIT License Copyright (c) 2021 Richard Osuala, Noussair Lazrak Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
================================================ FILE: templates/examples/__init__.py ================================================ """ authors: Richard Osuala, Zuzanna Szafranowska BCN-AIM 2021 """ import logging import os from pathlib import Path import cv2 import numpy as np import torch import torch.nn as nn import torch.nn.parallel class BaseGenerator(nn.Module): def __init__( self, nz: int, ngf: int, nc: int, ngpu: int, leakiness: float = 0.2, bias: bool = False, ): super(BaseGenerator, self).__init__() self.nz = nz self.ngf = ngf self.nc = nc self.ngpu = ngpu self.leakiness = leakiness self.bias = bias self.main = None def forward(self, input): raise NotImplementedError class Generator(BaseGenerator): def __init__( self, nz: int, ngf: int, nc: int, ngpu: int, image_size: int, conditional: bool, leakiness: float, bias: bool = False, n_cond: int = 10, is_condition_categorical: bool = False, num_embedding_dimensions: int = 50, ): super(Generator, self).__init__( nz=nz, ngf=ngf, nc=nc, ngpu=ngpu, leakiness=leakiness, bias=bias, ) # if is_condition_categorical is False, we model the condition as continous input to the network self.is_condition_categorical = is_condition_categorical # n_cond is only used if is_condition_categorical is True. self.num_embedding_input = n_cond # num_embedding_dimensions is only used if is_condition_categorical is True. # num_embedding_dimensions standard would be dim(z), but atm we have a nn.Linear after # nn.Embedding that upscales the dimension to self.nz. Using same value of num_embedding_dims in D and G. self.num_embedding_dimensions = num_embedding_dimensions # whether the is a conditional input into the GAN i.e. 
cGAN self.conditional: bool = conditional # The image size (supported params should be 128 or 64) self.image_size = image_size if self.image_size == 128: self.first_layers = nn.Sequential( # input is Z, going into a convolution nn.ConvTranspose2d( self.nz * self.nc, self.ngf * 16, 4, 1, 0, bias=self.bias ), nn.BatchNorm2d(self.ngf * 16), nn.ReLU(True), # state size. (ngf*16) x 4 x 4 nn.ConvTranspose2d( self.ngf * 16, self.ngf * 8, 4, 2, 1, bias=self.bias ), nn.BatchNorm2d(self.ngf * 8), nn.ReLU(True), ) elif self.image_size == 64: self.first_layers = nn.Sequential( # input is Z, going into a convolution nn.ConvTranspose2d( self.nz * self.nc, self.ngf * 8, 4, 1, 0, bias=self.bias ), nn.BatchNorm2d(self.ngf * 8), nn.ReLU(True), ) else: raise ValueError( f"Allowed image sizes are 128 and 64. You provided {self.image_size}. Please adjust." ) self.main = nn.Sequential( *self.first_layers.children(), # state size. (ngf*8) x 8 x 8 nn.ConvTranspose2d(self.ngf * 8, self.ngf * 4, 4, 2, 1, bias=self.bias), nn.BatchNorm2d(self.ngf * 4), nn.ReLU(True), # state size. (ngf*4) x 16 x 16 nn.ConvTranspose2d(self.ngf * 4, self.ngf * 2, 4, 2, 1, bias=self.bias), nn.BatchNorm2d(self.ngf * 2), nn.ReLU(True), # state size. (ngf*2) x 32 x 32 nn.ConvTranspose2d(self.ngf * 2, self.ngf, 4, 2, 1, bias=self.bias), nn.BatchNorm2d(self.ngf), nn.ReLU(True), # state size. (ngf) x 64 x 64 # Note that out_channels=1 instead of out_channels=self.nc. # This is due to conditional input channel of our grayscale images nn.ConvTranspose2d( in_channels=self.ngf, out_channels=1, kernel_size=4, stride=2, padding=1, bias=self.bias, ), nn.Tanh(), # state size. (nc) x 128 x 128 ) if self.is_condition_categorical: self.embed_nn = nn.Sequential( # e.g. condition -> int -> embedding -> fcl -> feature map -> concat with image -> conv layers.. 
# embedding layer nn.Embedding( num_embeddings=self.num_embedding_input, embedding_dim=self.num_embedding_dimensions, ), # target output dim of dense layer is batch_size x self.nz x 1 x 1 # input is dimension of the embedding layer output nn.Linear( in_features=self.num_embedding_dimensions, out_features=self.nz ), # nn.BatchNorm1d(self.nz), nn.LeakyReLU(self.leakiness, inplace=True), ) else: self.embed_nn = nn.Sequential( # target output dim of dense layer is: nz x 1 x 1 # input is dimension of the numbers of embedding nn.Linear(in_features=1, out_features=self.nz), # TODO Ablation: How does BatchNorm1d affect the conditional model performance? nn.BatchNorm1d(self.nz), nn.LeakyReLU(self.leakiness, inplace=True), ) def forward(self, x, conditions=None): if self.conditional: # combining condition labels and input images via a new image channel if not self.is_condition_categorical: # If labels are continuous (not modelled as categorical), use floats instead of integers for labels. # Also adjust dimensions to (batch_size x 1) as needed for input into linear layer # labels should already be of type float, no change expected in .float() conversion (it is only a safety check) # Just for testing: conditions *= 0 conditions += 1 conditions = conditions.view(conditions.size(0), -1).float() embedded_conditions = self.embed_nn(conditions) embedded_conditions_with_random_noise_dim = embedded_conditions.view( -1, self.nz, 1, 1 ) x = torch.cat([x, embedded_conditions_with_random_noise_dim], 1) return self.main(x) def interval_mapping(image, from_min, from_max, to_min, to_max): # map values from [from_min, from_max] to [to_min, to_max] # image: input array from_range = from_max - from_min to_range = to_max - to_min # scale to interval [0,1] scaled = np.array((image - from_min) / float(from_range), dtype=float) # multiply by range and add minimum to get interval [min,range+min] return to_min + (scaled * to_range) def image_generator(model_path, device, nz, ngf, nc, ngpu, 
num_samples): # instantiate the model logging.debug("Instantiating model...") netG = Generator( nz=nz, ngf=ngf, nc=nc, ngpu=ngpu, image_size=128, leakiness=0.1, conditional=False, ) if device.type == "cuda": netG.cuda() # load the model's weights from state_dict *'.pt file logging.debug(f"Loading model weights from {model_path} ...") checkpoint = torch.load(model_path, map_location=device) try: netG.load_state_dict(state_dict=checkpoint["generator"]) except KeyError: raise KeyError( f"checkpoint['generator_state_dict'] was not found." ) # checkpoint={checkpoint}") logging.debug(f"Using retrieved model from generator_state_dict checkpoint") netG.eval() # generate the images logging.debug(f"Generating {num_samples} images using {device}...") z = torch.randn(num_samples, nz, 1, 1, device=device) images = netG(z).detach().cpu().numpy() image_list = [] for j, img_ in enumerate(images): image_list.append(img_) return image_list def save_generated_images(image_list, path): logging.debug(f"Saving generated images now in {path}") for i, img_ in enumerate(image_list): Path(path).mkdir(parents=True, exist_ok=True) img_path = f"{path}/{i}.png" img_ = interval_mapping(img_.transpose(1, 2, 0), -1.0, 0.0, 0, 255) img_ = img_.astype("uint8") cv2.imwrite(img_path, img_) logging.debug(f"Saved generated images to {path}") def return_images(image_list): # logging.debug(f"Returning generated images as {type(image_list)}.") processed_image_list = [] for i, img_ in enumerate(image_list): img_ = interval_mapping(img_.transpose(1, 2, 0), -1.0, 0.0, 0, 255) img_ = img_.astype("uint8") processed_image_list.append(img_) return processed_image_list def generate(model_file, num_samples, output_path, save_images: bool): """This function generates synthetic images of mammography regions of interest""" try: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") ngpu = 0 if device == "cuda": ngpu = 1 image_list = image_generator(model_file, device, 100, 64, 1, ngpu, num_samples) if 
save_images: save_generated_images(image_list, output_path) else: return return_images(image_list) except Exception as e: logging.error( f"Error while trying to generate {num_samples} images with model {model_file}: {e}" ) raise e ================================================ FILE: templates/examples/metadata.json ================================================ { "00005_DCGAN_MMG_MASS_ROI": { "execution": { "package_name": "MMG_MASS_BCDR_DCGAN", "package_link": "ADD_ZENODO_OR_LOCAL_URL_HERE", "model_name": "500", "extension": ".pt", "image_size": [ 128, 128 ], "dependencies": [ "numpy", "torch", "opencv-contrib-python-headless" ], "generate_method": { "name": "generate", "args": { "base": [ "model_file", "num_samples", "output_path", "save_images" ], "custom": {} } } }, "selection": { "performance": { "SSIM": null, "MSE": null, "NSME": null, "PSNR": null, "IS": null, "turing_test": null, "FID_no_images":1000, "FID": 67.60, "FID_ratio": 0.497, "FID_RADIMAGENET": 1.27, "FID_RADIMAGENET_ratio": 0.197, "CLF_delta": null, "SEG_delta": null, "CLF": { "trained_on_fake": { "accuracy": 0.9528, "f1": 0.9721, "AUROC": 0.9596, "AUPRC": 0.9908 }, "trained_on_real_and_fake": {}, "trained_on_real": {} }, "SEG": { "trained_on_fake": {}, "trained_on_real_and_fake": {}, "trained_on_real": {} } }, "use_cases": [ "classification" ], "organ": [ "breast", "breasts", "chest" ], "modality": [ "MMG", "Mammography", "Mammogram", "full-field digital", "full-field digital MMG", "full-field MMG", "full-field Mammography", "digital Mammography", "digital MMG", "x-ray mammography" ], "vendors": [], "centres": [], "function": [ "noise to image", "image generation", "unconditional generation", "data augmentation" ], "condition": [], "dataset": [ "BCDR" ], "augmentations": [ "horizontal flip", "vertical flip" ], "generates": [ "mass", "masses", "mass roi", "mass ROI", "mass images", "mass region of interest", "nodule", "nodule", "nodule roi", "nodule ROI", "nodule images", "nodule region of 
interest" ], "height": 128, "width": 128, "depth": null, "type": "DCGAN", "license": "MIT", "dataset_type": "public", "privacy_preservation": null, "tags": [ "Mammogram", "Mammography", "Digital Mammography", "Full field Mammography", "Full-field Mammography", "128x128", "128 x 128", "MammoGANs", "Masses", "Nodules" ], "year": "2021" }, "description": { "title": "DCGAN Model for Mammogram MASS Patch Generation (Trained on BCDR)", "provided_date": "15 Dec 2021", "trained_date": "Nov 2021", "provided_after_epoch": 1500, "version": "0.0.1", "publication": null, "doi": [], "comment": "A deep convolutional generative adversarial network (DCGAN) that generates mass patches of mammograms. Pixel dimensions are 128x128. The DCGAN was trained on MMG patches from the BCDR dataset (Lopez et al, 2012). The uploaded ZIP file contains the files 500.pt (model weight), __init__.py (image generation method and utils), a requirements.txt, and the GAN model architecture (in pytorch) below the /src folder." } }, ================================================ FILE: templates/examples/requirements.txt ================================================ numpy torch opencv-contrib-python-headless ================================================ FILE: templates/examples/test.sh ================================================ #! /bin/bash echo "Running a test: Generate method of this medigan model module." echo "If not done already, please download 500.pt file from: https://drive.google.com/file/d/1C9vVPymsKJ5i5gpwQM6cpX0y1G89vcpk/view?usp=sharing" echo "1. Creating and activating virtual environment called MMG_env." python3 -m venv MMG_env source MMG_env/bin/activate echo "2. Pip install dependencies from requirements.txt" pip install -r requirements.txt echo "3. 
Run the generate function with parameters" python __init__.py python -c "from __init__ import generate; model_file='500.pt'; num_samples=10; output_path='images/'; save_images=True; generate(model_file=model_file, num_samples=num_samples, output_path=output_path, save_images=save_images)" echo "4. Done. Any errors? Have synthetic images been successfully created in folder /images?" ================================================ FILE: templates/raw_examples/LICENSE ================================================ copy license here ================================================ FILE: templates/raw_examples/__init__.py ================================================ ================================================ FILE: templates/raw_examples/metadata.json ================================================ { "MODEL_ID": { "execution": { "package_name": null, "package_link": null, "model_name": null, "extension": null, "image_size": [], "dependencies": [], "generate_method": { "name": null, "args": { "base": [ "model_file", "num_samples", "output_path", "save_images" ], "custom": {} } } }, "selection": { "performance": { "SSIM": null, "MSE": null, "NSME": null, "PSNR": null, "IS": null, "FID": null, "turing_test": null, "downstream_task": { "CLF": { "trained_on_fake": { "accuracy": null, "precision": null, "recall": null, "f1": null, "specificity": null, "AUROC": null, "AUPRC": null }, "trained_on_real_and_fake": {}, "trained_on_real": {} }, "SEG": { "trained_on_fake": { "dice": null, "jaccard": null, "accuracy": null, "precision": null, "recall": null, "f1": null }, "trained_on_real_and_fake": {}, "trained_on_real": {} } } }, "use_cases": [], "organ": [], "modality": [], "vendors": [], "centres": [], "function": [], "condition": [], "dataset": [], "augmentations": [], "generates": [], "height": null, "width": null, "depth": null, "type": null, "license": null, "dataset_type": null, "privacy_preservation": null, "tags": [], "year": null }, "description": { "title": 
null, "provided_date": null, "trained_date": null, "provided_after_epoch": null, "version": null, "publication": null, "doi": [], "comment": null } }, } ================================================ FILE: templates/raw_examples/model.pt ================================================ ================================================ FILE: templates/template.json ================================================ { "MODEL_ID": { "execution": { "package_name": "", "package_link": "", "model_name": "", "extension": "", "image_size": [], "dependencies": [], "generate_method": { "name": "", "args": { "base": [ "model_file", "num_samples", "output_path", "save_images" ], "custom": {} } } }, "selection": { "performance": { "SSIM": null, "MSE": null, "NSME": null, "PSNR": null, "IS": null, "FID": null, "turing_test": "", "downstream_task": { "CLF": { "trained_on_fake": { "accuracy": null, "precision": null, "recall": null, "f1": null, "specificity": null, "AUROC": null, "AUPRC": null }, "trained_on_real_and_fake": {}, "trained_on_real": {} }, "SEG": { "trained_on_fake": { "dice": null, "jaccard": null, "accuracy": null, "precision": null, "recall": null, "f1": null }, "trained_on_real_and_fake": {}, "trained_on_real": {} } } }, "use_cases": [], "organ": [], "modality": [], "vendors": [], "centres": [], "function": [], "condition": [], "dataset": [], "augmentations": [], "generates": [], "height": null, "width": null, "depth": null, "type": "", "license": "", "dataset_type": "", "privacy_preservation": "", "tags": [], "year": null }, "description": { "title": "", "provided_date": "", "trained_date": "", "provided_after_epoch": null, "version": "", "publication": "", "doi": [], "inputs": [], "comment": "" } } } ================================================ FILE: tests/__init__.py ================================================ ================================================ FILE: tests/fid.py ================================================ """ Calculates the Frechet 
Inception Distance between two distributions, using chosen feature extractor model. RadImageNet Model source: https://github.com/BMEII-AI/RadImageNet RadImageNet InceptionV3 weights (original, broken since 11.07.2023): https://drive.google.com/file/d/1p0q9AhG3rufIaaUE1jc2okpS8sdwN6PU RadImageNet InceptionV3 weights (for medigan, updated link 11.07.2023): https://drive.google.com/drive/folders/1lGFiS8_a5y28l4f8zpc7fklwzPJC-gZv Usage: python fid.py dir1 dir2 """ import argparse import os import cv2 import numpy as np import tensorflow as tf import tensorflow_gan as tfgan import wget from tensorflow.keras.applications import InceptionV3 from tensorflow.keras.applications.inception_v3 import preprocess_input img_size = 299 batch_size = 64 num_batches = 1 RADIMAGENET_URL = "https://drive.google.com/uc?id=1uvJHLG1K71Qzl7Km4JMpNOwE7iTjN8g9" RADIMAGENET_WEIGHTS = "RadImageNet-InceptionV3_notop.h5" IMAGENET_TFHUB_URL = "https://tfhub.dev/tensorflow/tfgan/eval/inception/1" def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( description="Calculates the Frechet Inception Distance between two distributions using RadImageNet model." ) parser.add_argument( "dataset_path_1", type=str, help="Path to images from first dataset", ) parser.add_argument( "dataset_path_2", type=str, help="Path to images from second dataset", ) parser.add_argument( "--model", type=str, default="imagenet", help="Use RadImageNet feature extractor for FID calculation", ) parser.add_argument( "--lower_bound", action="store_true", help="Calculate lower bound of FID using the 50/50 split of images from dataset_path_1", ) parser.add_argument( "--normalize_images", action="store_true", help="Normalize images from both datasources using min and max of each sample", ) args = parser.parse_args() return args def load_images(directory, normalize=False, split=False, limit=None): """ Loads images from the given directory. 
If split is True, then half of the images is loaded to one array and the other half to another. """ if split: subset_1 = [] subset_2 = [] else: images = [] for count, filename in enumerate(os.listdir(directory)): if filename.lower().endswith((".png", ".jpg", ".jpeg")): img = cv2.imread(os.path.join(directory, filename)) img = cv2.resize(img, (img_size, img_size), interpolation=cv2.INTER_LINEAR) if normalize: img = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX) if len(img.shape) > 2 and img.shape[2] == 4: img = img[:, :, :3] if len(img.shape) == 2: img = np.stack([img] * 3, axis=2) if split: if count % 2 == 0: subset_1.append(img) else: subset_2.append(img) else: images.append(img) if count == limit: break if split: subset_1 = preprocess_input(np.array(subset_1)) subset_2 = preprocess_input(np.array(subset_2)) return subset_1, subset_2 else: images = preprocess_input(np.array(images)) return images def check_model_weights(model_name): """ Checks if the model weights are available and download them if not. """ model_weights_path = None if model_name == "radimagenet": model_weights_path = RADIMAGENET_WEIGHTS if not os.path.exists(RADIMAGENET_WEIGHTS): print("Downloading RadImageNet InceptionV3 model:") wget.download( RADIMAGENET_URL, model_weights_path, ) print("\n") return model_weights_path def _radimagenet_fn(images): """ Get RadImageNet inception v3 model """ model = InceptionV3( weights=RADIMAGENET_WEIGHTS, input_shape=(img_size, img_size, 3), include_top=False, pooling="avg", ) output = model(images) output = tf.nest.map_structure(tf.keras.layers.Flatten(), output) return output def get_classifier_fn(model_name="imagenet"): """ Get model as TF function for optimized inference. 
""" check_model_weights(model_name) if model_name == "radimagenet": return _radimagenet_fn elif model_name == "imagenet": return tfgan.eval.classifier_fn_from_tfhub(IMAGENET_TFHUB_URL, "pool_3", True) else: raise ValueError("Model {} not recognized".format(model_name)) def calculate_fid( directory_1, directory_2, model_name, lower_bound=False, normalize_images=False, ): """ Calculates the Frechet Inception Distance between two distributions using chosen feature extractor model. """ limit = min(len(os.listdir(directory_1)), len(os.listdir(directory_2))) if lower_bound: images_1, images_2 = load_images(directory_1, split=True, limit=limit) else: images_1 = load_images(directory_1, limit=limit, normalize=normalize_images) images_2 = load_images(directory_2, limit=limit, normalize=normalize_images) fid = tfgan.eval.frechet_classifier_distance( images_1, images_2, get_classifier_fn(model_name) ) return fid if __name__ == "__main__": args = parse_args() directory_1 = args.dataset_path_1 directory_2 = args.dataset_path_2 lower_bound = args.lower_bound normalize_images = args.normalize_images model_name = args.model fid = calculate_fid( directory_1=directory_1, directory_2=directory_2, model_name=model_name, lower_bound=lower_bound, normalize_images=normalize_images, ) if lower_bound: print("Lower bound FID {}: {}".format(model_name, fid)) else: print("FID {}: {}".format(model_name, fid)) ================================================ FILE: tests/model_contribution_test_manual.py ================================================ # -*- coding: utf-8 -*- # ! 
/usr/bin/env python """ script for quick local testing if a new model can be added and works inside medigan.""" # run with python -m tests.model_contribution_test_manual import glob import logging import shutil import sys import unittest try: from src.medigan.generators import Generators LOGGING_LEVEL = "INFO" logger = logging.getLogger() # (__name__) logger.setLevel(LOGGING_LEVEL) stream_handler = logging.StreamHandler(sys.stdout) stream_handler.setLevel(LOGGING_LEVEL) formatter = logging.Formatter( "%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) stream_handler.setFormatter(formatter) logger.addHandler(stream_handler) generators = Generators() # Testing init of contributor with correct params init_py_path = "../models/00012_C-DCGAN_MMG_MASSES/__init__.py" metadata_file_path = "../models/00012_C-DCGAN_MMG_MASSES/metadata.json" model_id = "00012_C-DCGAN_MMG_MASSES" zenodo_access_token = "ACCESS_TOKEN" github_access_token = "ACCESS_TOKEN" creator_name = "John Doe" creator_affiliation = "University of Barcelona" # Testing full model contribution workflow. 
generators.contribute( model_id=model_id, init_py_path=init_py_path, zenodo_access_token=zenodo_access_token, github_access_token=github_access_token, metadata_file_path=metadata_file_path, creator_name=creator_name, creator_affiliation=creator_affiliation, ) # Testing init of contributor with erroneous params # contributor = generators.add_model_contributor(model_id ='Some model id', init_py_path="somePath") # contributor = generators.add_model_contributor(model_id ='00008_WGANGP_MMG_MASS_ROI', init_py_path="somePath") # contributor = generators.add_model_contributor(model_id ='Some model id', init_py_path="init_py_path") # Creating the model contributor # generators.add_model_contributor(model_id=model_id, init_py_path=init_py_path) # Adding the metadata of the model from input # generators.add_metadata_from_file( # model_id=model_id, metadata_file_path=metadata_file_path # ) # Alternatively, Adding the metadata of the model from file # metadata = contributor.add_metadata_from_input( # model_weights_name = "10000", # model_weights_extension=".pt", # generate_method_name = "generate", # dependencies=["numpy", "torch", "opencv-contrib-python-headless"]) # Add metadata to global.json config # generators.test_model(model_id=model_id) # Alternatively, explicitely providing model metadata to add the metadata to config # generators._add_model_to_config(model_id=model_id, metadata=metadata, metadata_file_path=metadata_file_path, # overwrite_existing_metadata=True) # Zenodo upload test # generators.push_to_zenodo( # model_id=model_id, # access_token=zenodo_access_token, # creator_name="test", # creator_affiliation="test affiliation", # ) # Manual Zenodo Test 1 # import requests # r = requests.get('https://zenodo.org/api/deposit/depositions', params = {'access_token': zenodo_access_token}) # print(r.status_code) # print(r.json()) # Manual Zenodo Test 2 # headers = {"Content-Type": "application/json"} # params = {"access_token": zenodo_access_token} # r = requests.post( # 
"https://zenodo.org/api/deposit/depositions", # params=params, # json={}, # headers=headers, # ) # print(r.json()) # print(r.status_code) # Github upload test # generators.push_to_github( # model_id=model_id, # github_access_token=github_access_token, # package_link=None, # creator_name="test", # creator_affiliation="test affiliation", # model_description="test description", # ) except Exception as e: logging.error(f"test_init_generators error: {e}") raise e ================================================ FILE: tests/model_integration_test_manual.py ================================================ # -*- coding: utf-8 -*- # ! /usr/bin/env python """ script for quick local testing if a model works inside medigan.""" # run with python -m tests.model_integration_test_manual import logging MODEL_ID = "YOUR_MODEL_ID_HERE" MODEL_ID = 23 # "00023_PIX2PIXHD_BREAST_DCEMRI" #"00002_DCGAN_MMG_MASS_ROI" # "00007_BEZIERCURVE_TUMOUR_MASK" NUM_SAMPLES = 2 OUTPUT_PATH = f"output/{MODEL_ID}/" try: from src.medigan.generators import Generators generators = Generators() except Exception as e: logging.error(f"test_init_generators error: {e}") raise e generators.generate( model_id=MODEL_ID, num_samples=NUM_SAMPLES, output_path=OUTPUT_PATH, input_path="input/", gpu_id=0, image_size=448, install_dependencies=True, ) data_loader = generators.get_as_torch_dataloader( model_id=MODEL_ID, num_samples=NUM_SAMPLES, output_path=OUTPUT_PATH, input_path="input/", gpu_id=0, image_size=448, # prefetch_factor=2, # debugging with torch v2.0.0: This will raise an error for torch DataLoader if num_workers == None at the same time. ) print(f"len(data_loader): {len(data_loader)}") if len(data_loader) != NUM_SAMPLES: logging.warning( f"{MODEL_ID}: The number of samples in the dataloader (={len(data_loader)}) is not equal the number of samples requested (={NUM_SAMPLES})." 
) #### Get the object at index 0 from the dataloader data_dict = next(iter(data_loader)) print(f"data_dict: {data_dict}") ================================================ FILE: tests/test_model_executor.py ================================================ # -*- coding: utf-8 -*- # ! /usr/bin/env python """ main test script to test the primary functions/classes/methods. """ # run with python -m tests.test_generator import glob import logging import os import shutil import sys import pytest import torch # import unittest # Set the logging level depending on the level of detail you would like to have in the logs while running the tests. LOGGING_LEVEL = logging.INFO # WARNING # logging.INFO models_with_args = [ ( "00001_DCGAN_MMG_CALC_ROI", {}, 100, ), # 100 samples to test automatic batch-wise image generation in model_executor ( "00002", {}, 3, ), # "00002" instead of "00002_DCGAN_MMG_MASS_ROI" to test shortcut model_ids ( "03", {"translate_all_images": False}, 2, ), # "03" instead of "00003_CYCLEGAN_MMG_DENSITY_FULL" to test shortcut model_ids ( 4, # 4 instead of "00004_PIX2PIX_MMG_MASSES_W_MASKS" to test shortcut model_ids { "shapes": ["oval"], "ssim_threshold": 0.18, "image_size": [128, 128], "patch_size": [30, 30], }, 3, ), ("00005_DCGAN_MMG_MASS_ROI", {}, 3), ("00006_WGANGP_MMG_MASS_ROI", {}, 3), ( "00007_INPAINT_BRAIN_MRI", { "image_size": (256, 256), "num_inpaints_per_sample": 2, "randomize_input_image_order": False, "add_variations_to_mask": False, "x_center": 120, "y_center": 140, "radius_1": 8, "radius_2": 12, "radius_3": 24, }, 3, ), ( "00008_C-DCGAN_MMG_MASSES", {"condition": 0, "is_cbisddsm_training_data": False}, 3, ), ("00009_PGGAN_POLYP_PATCHES_W_MASKS", {"save_option": "image_only"}, 3), ("00010_FASTGAN_POLYP_PATCHES_W_MASKS", {"save_option": "image_only"}, 3), # ("00011_SINGAN_POLYP_PATCHES_W_MASKS", {"checkpoint_ids": [999]}, 3), # removed after successful testing due to limited CI pipeline capacity # ("00012_C-DCGAN_MMG_MASSES", {"condition": 0}, 
3), # removed after successful testing due to limited CI pipeline capacity # ("00013_CYCLEGAN_MMG_DENSITY_OPTIMAM_MLO", {"translate_all_images": False}, 2), # removed after successful testing due to limited CI pipeline capacity # ("00014_CYCLEGAN_MMG_DENSITY_OPTIMAM_CC", {"translate_all_images": False}, 2), # removed after successful testing due to limited CI pipeline capacity # ("00015_CYCLEGAN_MMG_DENSITY_CSAW_MLO", {"translate_all_images": False}, 2), # removed after successful testing due to limited CI pipeline capacity # ("00016_CYCLEGAN_MMG_DENSITY_CSAW_CC", {"translate_all_images": False}, 2), # removed after successful testing due to limited CI pipeline capacity ("00017_DCGAN_XRAY_LUNG_NODULES", {}, 3), ("00018_WGANGP_XRAY_LUNG_NODULES", {}, 3), ("00019_PGGAN_CHEST_XRAY", {}, 3), ("00020_PGGAN_CHEST_XRAY", {"resize_pixel_dim": 512, "image_size": 256}, 3), ( "00021_CYCLEGAN_BRAIN_MRI_T1_T2", { "input_path": "models/00021_CYCLEGAN_Brain_MRI_T1_T2/inputs/T2", "gpu_id": 0, "T1_to_T2": False, }, 3, ), ("00022_WGAN_CARDIAC_AGING", {}, 3), ( "00023_PIX2PIXHD_BREAST_DCEMRI", { "input_path": "input", "gpu_id": 0, "image_size": 448, }, 3, ), ] # class TestMediganExecutorMethods(unittest.TestCase): class TestMediganExecutorMethods: def setup_class(self): ## unittest logger config # This logger on root level initialized via logging.getLogger() will also log all log events # from the medigan library. Pass a logger name (e.g. 
__name__) instead if you only want logs from tests.py self.logger = logging.getLogger() # (__name__) self.logger.setLevel(LOGGING_LEVEL) stream_handler = logging.StreamHandler(sys.stdout) stream_handler.setLevel(LOGGING_LEVEL) formatter = logging.Formatter( "%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) stream_handler.setFormatter(formatter) self.logger.addHandler(stream_handler) self.test_output_path = "test_output_path" self.num_samples = 2 self.test_imports_and_init_generators(self) self._remove_dir_and_contents(self) # in case something is left there. self.model_ids = self.generators.config_manager.model_ids def test_imports_and_init_generators(self): from src.medigan.constants import ( CONFIG_FILE_KEY_EXECUTION, CONFIG_FILE_KEY_GENERATE, CONFIG_FILE_KEY_GENERATE_ARGS_INPUT_LATENT_VECTOR_SIZE, ) from src.medigan.generators import Generators self.generators = Generators() self.CONFIG_FILE_KEY_EXECUTION = CONFIG_FILE_KEY_EXECUTION self.CONFIG_FILE_KEY_GENERATE = CONFIG_FILE_KEY_GENERATE self.CONFIG_FILE_KEY_GENERATE_ARGS_INPUT_LATENT_VECTOR_SIZE = ( CONFIG_FILE_KEY_GENERATE_ARGS_INPUT_LATENT_VECTOR_SIZE ) @pytest.mark.parametrize("models_with_args", [models_with_args]) def test_sample_generation_methods(self, models_with_args: list): self.logger.debug(f"models: {models_with_args}") for i, model_id in enumerate(self.model_ids): # if ( # model_id != "00011_SINGAN_POLYP_PATCHES_W_MASKS" # ): ## avoiding full memory on Windows ci test server # continue self.logger.debug(f"Now testing model {model_id}") self._remove_dir_and_contents() # Already done in each test independently, but to be sure, here again. self.test_generate_method(model_id=model_id) # Check if args available fo model_id. 
Note: The models list may not include the latest medigan models for model in models_with_args: if model_id == model[0]: self.test_generate_method_with_additional_args( model_id=model[0], args=model[1], expected_num_samples=model[2] ) self.test_get_generate_method(model_id=model_id) self.test_get_dataloader_method(model_id=model_id) # if i == 16: # just for local testing # self._remove_model_dir_and_zip( # model_ids=[model_id], are_all_models_deleted=False # ) @pytest.mark.parametrize( "values_list, should_sample_be_generated", [ (["dcgan", "mMg", "ClF", "modality", "inbreast"], True), (["dcgan", "mMg", "ClF", "modality", "optimam"], True), (["dcgan", "mMg", "ClF", "modalities"], False), ], ) def test_find_model_and_generate_method( self, values_list, should_sample_be_generated ): self._remove_dir_and_contents() self.generators.find_model_and_generate( values=values_list, target_values_operator="AND", are_keys_also_matched=True, is_case_sensitive=False, num_samples=self.num_samples, output_path=self.test_output_path, ) self._check_if_samples_were_generated( should_sample_be_generated=should_sample_be_generated ) @pytest.mark.parametrize( "values_list, metric", [ (["dcgan", "MMG"], "CLF.trained_on_real_and_fake.f1"), (["dcgan", "MMG"], "turing_test.AUC"), ], ) def test_find_and_rank_models_then_generate_method(self, values_list, metric): self._remove_dir_and_contents() # TODO This test needs the respective metrics for any of these models to be available in config/global.json. # These values would need to find at least two models. 
self.generators.find_models_rank_and_generate( values=values_list, target_values_operator="AND", are_keys_also_matched=True, is_case_sensitive=False, metric=metric, order="asc", num_samples=self.num_samples, output_path=self.test_output_path, ) self._check_if_samples_were_generated() # @pytest.mark.parametrize("model_id", [model[0] for model in models_with_args]) @pytest.mark.skip def test_generate_method(self, model_id): self._remove_dir_and_contents() self.generators.generate( model_id=model_id, num_samples=self.num_samples, output_path=self.test_output_path, install_dependencies=True, ) self._check_if_samples_were_generated(model_id=model_id) # @pytest.mark.parametrize("model_id, args, expected_num_samples", models_with_args) @pytest.mark.skip def test_generate_method_with_additional_args( self, model_id, args, expected_num_samples ): self._remove_dir_and_contents() self.generators.generate( model_id=model_id, num_samples=expected_num_samples, output_path=self.test_output_path, **args, ) self._check_if_samples_were_generated( model_id=model_id, num_samples=expected_num_samples ) # @pytest.mark.parametrize("model_id", [model[0] for model in models_with_args]) @pytest.mark.skip def test_get_generate_method(self, model_id): self._remove_dir_and_contents() gen_function = self.generators.get_generate_function( model_id=model_id, num_samples=self.num_samples, output_path=self.test_output_path, ) gen_function() self._check_if_samples_were_generated(model_id=model_id) del gen_function # @pytest.mark.parametrize("model_id", [model[0] for model in models_with_args]) @pytest.mark.skip def test_get_dataloader_method(self, model_id): self._remove_dir_and_contents() data_loader = self.generators.get_as_torch_dataloader( model_id=model_id, num_samples=self.num_samples ) self.logger.debug(f"{model_id}: len(data_loader): {len(data_loader)}") if len(data_loader) != self.num_samples: logging.warning( f"{model_id}: The number of samples in the dataloader (={len(data_loader)}) is 
not equal the number of samples requested (={self.num_samples}). " f"Hint: Revise if the model's internal generate() function returned tuples as required in get_as_torch_dataloader()." ) #### Get the object at index 0 from the dataloader data_dict = next(iter(data_loader)) # Test if the items at index [0] of the aforementioned object is of type torch tensor (e.g. torch.uint8) and not None, as expected by data structure design decision. assert torch.is_tensor(data_dict.get("sample")) # Test if the items at index [1], [2] of the aforementioned object are None and, if not, whether they are of type torch tensor, as expected assert data_dict.get("mask") is None or torch.is_tensor(data_dict.get("mask")) assert data_dict.get("other_imaging_output") is None or torch.is_tensor( data_dict.get("other_imaging_output") ) # Test if the items at index [3] of the aforementioned object is None and, if not, whether it is of type list of strings, as expected. assert data_dict.get("label") is None or ( isinstance(data_dict.get("label"), list) and isinstance(data_dict.get("label")[0], str) ) del data_dict del data_loader # @pytest.mark.parametrize("model_id", [model[0] for model in models_with_args]) @pytest.mark.skip def test_visualize_method(self, model_id): if ( self.CONFIG_FILE_KEY_GENERATE_ARGS_INPUT_LATENT_VECTOR_SIZE in self.generators.config_manager.config_dict[model_id][ self.CONFIG_FILE_KEY_EXECUTION ][self.CONFIG_FILE_KEY_GENERATE] ): self.generators.visualize(model_id, auto_close=True) else: with pytest.raises(Exception) as e: self.generators.visualize(model_id, auto_close=True) assert e.type == ValueError @pytest.mark.skip def _check_if_samples_were_generated( self, model_id=None, num_samples=None, should_sample_be_generated: bool = True ): # check if the number of generated samples of model_id_1 is as expected. 
file_list = glob.glob(self.test_output_path + "/*") self.logger.debug(f"{model_id}: {len(file_list)} == {self.num_samples} ?") if num_samples is None: num_samples = self.num_samples if should_sample_be_generated: assert ( len(file_list) == num_samples or len(file_list) == num_samples * 2 * 6 # 00007_INPAINT_BRAIN_MRI: 2 inpaints per sample, 6 outputs per sample or len(file_list) == num_samples * 2 # Temporary fix for different outputs per model. or len(file_list) == num_samples + 1 ), f"Model {model_id} generated {len(file_list)} samples instead of the expected {num_samples}, {num_samples*2*6}, or {num_samples + 1}." # Some models are balanced per label by default: If num_samples is odd, then len(file_list)==num_samples +1 else: assert len(file_list) == 0 # @pytest.mark.skip def _remove_dir_and_contents(self): """After each test, empty the created folders and files to avoid corrupting a new test.""" try: shutil.rmtree(self.test_output_path) except OSError as e: # This may give an error if the folders are not created. self.logger.debug( f"Exception while trying to delete folder. Likely it simply had not yet been created: {e}" ) except Exception as e2: self.logger.error(f"Error while trying to delete folder: {e2}") @pytest.mark.skip def _remove_model_dir_and_zip( self, model_ids=[], are_all_models_deleted: bool = False ): """After a specific model folders, model_executor, and model zip file to avoid running out-of-disk space.""" try: for i, model_executor in enumerate(self.generators.model_executors): if are_all_models_deleted or ( model_ids is not None and model_executor.model_id in model_ids ): try: # Delete the folder containing the model model_path = os.path.dirname( model_executor.deserialized_model_as_lib.__file__ ) shutil.rmtree(model_path) self.logger.info( f"Deleted directory of model {model_executor.model_id}. 
({model_path})" ) except OSError as e: # This may give an error if the FOLDER is not present self.logger.warning( f"Exception while trying to delete the model folder of model {model_executor.model_id}: {e}" ) try: # If the downloaded zip package of the model was not deleted inside the model_path, we explicitely delete it now. if model_executor.package_path.is_file(): os.remove(model_executor.package_path) self.logger.info( f"Deleted zip file of model {model_executor.model_id}. ({model_executor.package_path})" ) except Exception as e: self.logger.warning( f"Exception while trying to delete the ZIP file ({model_executor.package_path}) of model {model_executor.model_id}: {e}" ) # Deleting the stateful model_executors instantiated by the generators module, after deleting folders and zips if are_all_models_deleted: self.generators.model_executors.clear() else: if model_ids is not None: for model_id in model_ids: model_executor = self.generators.find_model_executor_by_id( model_id ) if model_executor is not None: self.generators.model_executors.remove(model_executor) del model_executor except Exception as e2: self.logger.error( f"Error while trying to delete model folders and zips: {e2}" ) # @pytest.fixture(scope="session", autouse=True) def teardown_class(self): """After all tests, empty the large model folders, model_executors, and zip files to avoid running out-of-disk space.""" # yield is at test-time, signaling that things after yield are run after the execution of the last test has terminated # https://docs.pytest.org/en/7.1.x/reference/reference.html?highlight=fixture#pytest.fixture # yield None # Remove all test outputs in test_output_path self._remove_dir_and_contents(self) # Remove all model folders, zip files and model executors # self._remove_model_dir_and_zip( # self, model_ids=["00006_WGANGP_MMG_MASS_ROI"], are_all_models_deleted=False # ) # just for local testing # self._remove_model_dir_and_zip( # self, model_ids=None, are_all_models_deleted=True # ) 
================================================ FILE: tests/test_model_selector.py ================================================ # -*- coding: utf-8 -*- # ! /usr/bin/env python """ main test script to test the primary functions/classes/methods. """ # run with python -m tests.test_model_selector import logging import sys import pytest # import unittest # Set the logging level depending on the level of detail you would like to have in the logs while running the tests. LOGGING_LEVEL = logging.INFO # WARNING # logging.INFO models = [ ( "00001_DCGAN_MMG_CALC_ROI", {}, 100, ), ("00002_DCGAN_MMG_MASS_ROI", {}, 3), ("00003_CYCLEGAN_MMG_DENSITY_FULL", {"translate_all_images": False}, 2), ("00005_DCGAN_MMG_MASS_ROI", {}, 3), # Further models can be added here if/when needed. ] # class TestMediganSelectorMethods(unittest.TestCase): class TestMediganSelectorMethods: def setup_method(self): ## unittest logger config # This logger on root level initialized via logging.getLogger() will also log all log events # from the medigan library. Pass a logger name (e.g. 
__name__) instead if you only want logs from tests.py self.logger = logging.getLogger() # (__name__) self.logger.setLevel(LOGGING_LEVEL) stream_handler = logging.StreamHandler(sys.stdout) stream_handler.setLevel(LOGGING_LEVEL) formatter = logging.Formatter( "%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) stream_handler.setFormatter(formatter) self.logger.addHandler(stream_handler) self.test_init_generators() def test_init_generators(self): from src.medigan.generators import Generators self.generators = Generators() @pytest.mark.parametrize( "values_list", [ (["dcgan", "mMg", "ClF", "modality"]), (["DCGAN", "Mammography"]), ], ) def test_search_for_models_method(self, values_list): found_models = self.generators.find_matching_models_by_values( values=values_list, target_values_operator="AND", are_keys_also_matched=True, is_case_sensitive=False, ) self.logger.debug( f"For value {values_list}, these models were found: {found_models}" ) assert len(found_models) > 0 @pytest.mark.parametrize( "models, values_list, metric", [ ( models, ["dcgan", "MMG"], "CLF.trained_on_real_and_fake.f1", ), (models, ["dcgan", "MMG"], "turing_test.AUC"), ], ) def test_find_and_rank_models_by_performance(self, models, values_list, metric): # These values would need to find at least two models. See metrics and values in the config/global.json file. 
found_ranked_models = self.generators.find_models_and_rank( values=values_list, target_values_operator="AND", are_keys_also_matched=True, is_case_sensitive=False, metric=metric, order="desc", ) assert ( len(found_ranked_models) > 0 # some models were found as is expected and found_ranked_models[0]["model_id"] is not None # has a model id and ( len(found_ranked_models) < 2 or found_ranked_models[0][metric] > found_ranked_models[1][metric] ) # descending order (the higher a model's value, the lower its index in the list) is working ) @pytest.mark.parametrize( "models, metric, order", [ ( models, "FID", "asc", ), # Note: normally a lower FID is better, therefore asc (model with lowest FID has lowest result list index). ( models, "FID_RADIMAGENET_ratio", "desc", # descending, as the higher the FID ratio the better. ), # Note: normally a lower FID is better, therefore asc (model with lowest FID has lowest result list index). (models, "CLF.trained_on_real_and_fake.f1", "desc"), (models, "turing_test.AUC", "desc"), ], ) def test_rank_models_by_performance(self, models, metric, order): """Ranking according to metrics in the config/global.json file.""" ranked_models = self.generators.rank_models_by_performance( model_ids=None, metric=metric, order=order, ) assert ( len(ranked_models) > 0 # at least one model was found and ( len(ranked_models) >= 21 or metric != "FID" ) # we should find at least 21 models with FID in medigan and ranked_models[0]["model_id"] is not None # found model has a model id (i.e. correctly formatted results) and ( len(ranked_models) == 1 or ( ranked_models[0][metric] > ranked_models[1][metric] or metric == "FID" ) ) # descending order (the higher a model's value, the lower its index in the list) is working. In case of FID it is the other way around (ascending order is better). 
) @pytest.mark.parametrize( "models, metric, order", [ (models, "CLF.trained_on_real_and_fake.f1", "desc"), (models, "turing_test.AUC", "desc"), ], ) def test_rank_models_by_performance_with_given_ids(self, models, metric, order): """Ranking a specified set of models according to metrics in the config/global.json file.""" ranked_models = self.generators.rank_models_by_performance( model_ids=[models[1][0], models[2][0]], metric=metric, order=order, ) assert 0 < len(ranked_models) <= 2 and ( len(ranked_models) < 2 or (ranked_models[0][metric] > ranked_models[1][metric]) ) # checking if descending order (the higher a model's value, the lower its index in the list) is working. @pytest.mark.parametrize( "key1, value1, expected", [ ("modality", "Full-Field Mammography", 2), ("license", "BSD", 2), ("performance.CLF.trained_on_real_and_fake.f1", "0.96", 0), ("performance.turing_test.AUC", "0.56", 0), ], ) def test_get_models_by_key_value_pair(self, key1, value1, expected): found_models = self.generators.get_models_by_key_value_pair( key1=key1, value1=value1, is_case_sensitive=False ) assert len(found_models) >= expected