[
  {
    "path": ".github/CODEOWNERS",
    "content": "# See: https://help.github.com/en/articles/about-code-owners\n#\n# Owners will be requested for review when someone opens a pull request.\n*       @jantic @alexandrevicenzi\n"
  },
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# pyenv\n.python-version\n\n# celery beat schedule file\ncelerybeat-schedule\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n\n# DeOldify\ndata\n*SymbolicLinks.sh\n*.ipynb_checkpoints*\nColorizeTraining*[0-9]*.ipynb\n*Colorizer[0-9]*.ipynb\nlesson7-superres*.ipynb\ntest.py\nresult_images\n*.prof\n*.pth\nvideo\ntest_images/*.jpg\ntest_images/*.JPG\ntest_images/*.PNG\ntest_images/*.png\ntest_images/*.jpeg\ntest_images/*.JPEG \ndeoldify/.ipynb_checkpoints/*-checkpoint.py\ntmp*\n"
  },
  {
    "path": ".pre-commit-config.yaml",
    "content": "repos:\n-   repo: https://github.com/ambv/black\n    rev: stable\n    hooks:\n    - id: black\n      args: [-S]\n      language_version: python3.6\n"
  },
  {
    "path": ".pylintrc",
    "content": "[MASTER]\n\n# A comma-separated list of package or module names from where C extensions may\n# be loaded. Extensions are loading into the active Python interpreter and may\n# run arbitrary code.\nextension-pkg-whitelist=\n\n# Add files or directories to the blacklist. They should be base names, not\n# paths.\nignore=CVS\n\n# Add files or directories matching the regex patterns to the blacklist. The\n# regex matches against base names, not paths.\nignore-patterns=\n\n# Python code to execute, usually for sys.path manipulation such as\n# pygtk.require().\n#init-hook='import sys; sys.path.append(\"./venv/lib/python3.7/site-packages\")'\n\n# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the\n# number of processors available to use.\njobs=1\n\n# Control the amount of potential inferred values when inferring a single\n# object. This can help the performance when dealing with large functions or\n# complex, nested conditions.\nlimit-inference-results=100\n\n# List of plugins (as comma separated values of python modules names) to load,\n# usually to register additional checkers.\nload-plugins=\n\n# Pickle collected data for later comparisons.\npersistent=yes\n\n# Specify a configuration file.\n#rcfile=\n\n# When enabled, pylint would attempt to guess common misconfiguration and emit\n# user-friendly hints instead of false-positive error messages.\nsuggestion-mode=yes\n\n# Allow loading of arbitrary C extensions. Extensions are imported into the\n# active Python interpreter and may run arbitrary code.\nunsafe-load-any-extension=no\n\n\n[MESSAGES CONTROL]\n\n# Only show warnings with the listed confidence levels. Leave empty to show\n# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED.\nconfidence=\n\n# Disable the message, report, category or checker with the given id(s). 
You\n# can either give multiple identifiers separated by comma (,) or put this\n# option multiple times (only on the command line, not in the configuration\n# file where it should appear only once). You can also use \"--disable=all\" to\n# disable everything first and then reenable specific checks. For example, if\n# you want to run only the similarities checker, you can use \"--disable=all\n# --enable=similarities\". If you want to run only the classes checker, but have\n# no Warning level messages displayed, use \"--disable=all --enable=classes\n# --disable=W\".\ndisable=print-statement,\n        parameter-unpacking,\n        unpacking-in-except,\n        old-raise-syntax,\n        backtick,\n        long-suffix,\n        old-ne-operator,\n        old-octal-literal,\n        import-star-module-level,\n        non-ascii-bytes-literal,\n        raw-checker-failed,\n        bad-inline-option,\n        locally-disabled,\n        locally-enabled,\n        file-ignored,\n        suppressed-message,\n        useless-suppression,\n        deprecated-pragma,\n        use-symbolic-message-instead,\n        apply-builtin,\n        basestring-builtin,\n        buffer-builtin,\n        cmp-builtin,\n        coerce-builtin,\n        execfile-builtin,\n        file-builtin,\n        long-builtin,\n        raw_input-builtin,\n        reduce-builtin,\n        standarderror-builtin,\n        unicode-builtin,\n        xrange-builtin,\n        coerce-method,\n        delslice-method,\n        getslice-method,\n        setslice-method,\n        no-absolute-import,\n        old-division,\n        dict-iter-method,\n        dict-view-method,\n        next-method-called,\n        metaclass-assignment,\n        indexing-exception,\n        raising-string,\n        reload-builtin,\n        oct-method,\n        hex-method,\n        nonzero-method,\n        cmp-method,\n        input-builtin,\n        round-builtin,\n        intern-builtin,\n        unichr-builtin,\n        
map-builtin-not-iterating,\n        zip-builtin-not-iterating,\n        range-builtin-not-iterating,\n        filter-builtin-not-iterating,\n        using-cmp-argument,\n        eq-without-hash,\n        div-method,\n        idiv-method,\n        rdiv-method,\n        exception-message-attribute,\n        invalid-str-codec,\n        sys-max-int,\n        bad-python3-import,\n        deprecated-string-function,\n        deprecated-str-translate-call,\n        deprecated-itertools-function,\n        deprecated-types-field,\n        next-method-defined,\n        dict-items-not-iterating,\n        dict-keys-not-iterating,\n        dict-values-not-iterating,\n        deprecated-operator-function,\n        deprecated-urllib-function,\n        xreadlines-attribute,\n        deprecated-sys-function,\n        exception-escape,\n        comprehension-escape,\n        # Disabled due Black\n        bad-continuation,\n        bad-whitespace,\n        # We don't care about these\n        redundant-keyword-arg,\n\n# Enable the message, report, category or checker with the given id(s). You can\n# either give multiple identifier separated by comma (,) or put this option\n# multiple time (only on the command line, not in the configuration file where\n# it should appear only once). See also the \"--disable\" option for examples.\nenable=c-extension-no-member\n\n\n[REPORTS]\n\n# Python expression which should return a note less than 10 (10 is the highest\n# note). You have access to the variables errors warning, statement which\n# respectively contain the number of errors / warnings messages and the total\n# number of statements analyzed. This is used by the global evaluation report\n# (RP0004).\nevaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)\n\n# Template used to display messages. This is a python new-style format string\n# used to format the message information. See doc for all details.\n#msg-template=\n\n# Set the output format. 
Available formats are text, parseable, colorized, json\n# and msvs (visual studio). You can also give a reporter class, e.g.\n# mypackage.mymodule.MyReporterClass.\noutput-format=text\n\n# Tells whether to display a full report or only the messages.\nreports=no\n\n# Activate the evaluation score.\nscore=yes\n\n\n[REFACTORING]\n\n# Maximum number of nested blocks for function / method body\nmax-nested-blocks=5\n\n# Complete name of functions that never returns. When checking for\n# inconsistent-return-statements if a never returning function is called then\n# it will be considered as an explicit return statement and no message will be\n# printed.\nnever-returning-functions=sys.exit\n\n\n[LOGGING]\n\n# Logging modules to check that the string format arguments are in logging\n# function parameter format.\nlogging-modules=logging\n\n\n[SIMILARITIES]\n\n# Ignore comments when computing similarities.\nignore-comments=yes\n\n# Ignore docstrings when computing similarities.\nignore-docstrings=yes\n\n# Ignore imports when computing similarities.\nignore-imports=no\n\n# Minimum lines number of a similarity.\nmin-similarity-lines=4\n\n\n[MISCELLANEOUS]\n\n# List of note tags to take in consideration, separated by a comma.\nnotes=FIXME,\n      XXX,\n      TODO\n\n\n[FORMAT]\n\n# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.\nexpected-line-ending-format=\n\n# Regexp for a line that is allowed to be longer than the limit.\nignore-long-lines=^\\s*(# )?<?https?://\\S+>?$\n\n# Number of spaces of indent required inside a hanging  or continued line.\nindent-after-paren=4\n\n# String used as indentation unit. This is usually \"    \" (4 spaces) or \"\\t\" (1\n# tab).\nindent-string='    '\n\n# Maximum number of characters on a single line.\nmax-line-length=100\n\n# Maximum number of lines in a module.\nmax-module-lines=1000\n\n# List of optional constructs for which whitespace checking is disabled. 
`dict-\n# separator` is used to allow tabulation in dicts, etc.: {1  : 1,\\n222: 2}.\n# `trailing-comma` allows a space between comma and closing bracket: (a, ).\n# `empty-line` allows space-only lines.\nno-space-check=trailing-comma,\n               dict-separator\n\n# Allow the body of a class to be on the same line as the declaration if body\n# contains single statement.\nsingle-line-class-stmt=no\n\n# Allow the body of an if to be on the same line as the test if there is no\n# else.\nsingle-line-if-stmt=no\n\n\n[BASIC]\n\n# Naming style matching correct argument names.\nargument-naming-style=snake_case\n\n# Regular expression matching correct argument names. Overrides argument-\n# naming-style.\n#argument-rgx=\n\n# Naming style matching correct attribute names.\nattr-naming-style=snake_case\n\n# Regular expression matching correct attribute names. Overrides attr-naming-\n# style.\n#attr-rgx=\n\n# Bad variable names which should always be refused, separated by a comma.\nbad-names=foo,\n          bar,\n          baz,\n          toto,\n          tutu,\n          tata\n\n# Naming style matching correct class attribute names.\nclass-attribute-naming-style=any\n\n# Regular expression matching correct class attribute names. Overrides class-\n# attribute-naming-style.\n#class-attribute-rgx=\n\n# Naming style matching correct class names.\nclass-naming-style=PascalCase\n\n# Regular expression matching correct class names. Overrides class-naming-\n# style.\n#class-rgx=\n\n# Naming style matching correct constant names.\nconst-naming-style=UPPER_CASE\n\n# Regular expression matching correct constant names. Overrides const-naming-\n# style.\n#const-rgx=\n\n# Minimum line length for functions/classes that require docstrings, shorter\n# ones are exempt.\ndocstring-min-length=-1\n\n# Naming style matching correct function names.\nfunction-naming-style=snake_case\n\n# Regular expression matching correct function names. 
Overrides function-\n# naming-style.\n#function-rgx=\n\n# Good variable names which should always be accepted, separated by a comma.\ngood-names=f,\n           i,\n           j,\n           k,\n           s,\n           t,\n           ex,\n           Run,\n           _\n\n# Include a hint for the correct naming format with invalid-name.\ninclude-naming-hint=no\n\n# Naming style matching correct inline iteration names.\ninlinevar-naming-style=any\n\n# Regular expression matching correct inline iteration names. Overrides\n# inlinevar-naming-style.\n#inlinevar-rgx=\n\n# Naming style matching correct method names.\nmethod-naming-style=snake_case\n\n# Regular expression matching correct method names. Overrides method-naming-\n# style.\n#method-rgx=\n\n# Naming style matching correct module names.\nmodule-naming-style=snake_case\n\n# Regular expression matching correct module names. Overrides module-naming-\n# style.\n#module-rgx=\n\n# Colon-delimited sets of names that determine each other's naming style when\n# the name regexes allow several styles.\nname-group=\n\n# Regular expression which should only match function or class names that do\n# not require a docstring.\nno-docstring-rgx=^_\n\n# List of decorators that produce properties, such as abc.abstractproperty. Add\n# to this list to register other decorators that produce valid properties.\n# These decorators are taken in consideration only for invalid-name.\nproperty-classes=abc.abstractproperty\n\n# Naming style matching correct variable names.\nvariable-naming-style=snake_case\n\n# Regular expression matching correct variable names. Overrides variable-\n# naming-style.\nvariable-rgx=_?[a-z][A-Za-z0-9_]{0,30}$\nargument-rgx=_?[a-z][A-Za-z0-9_]{0,30}$\n\n\n[TYPECHECK]\n\n# List of decorators that produce context managers, such as\n# contextlib.contextmanager. 
Add to this list to register other decorators that\n# produce valid context managers.\ncontextmanager-decorators=contextlib.contextmanager\n\n# List of members which are set dynamically and missed by pylint inference\n# system, and so shouldn't trigger E1101 when accessed. Python regular\n# expressions are accepted.\ngenerated-members=torch.mm,\n                  torch.diag,\n                  torch.symeig,\n                  torch.sqrt,\n                  torch.cat,\n                  cv2.cvtColor,\n                  cv2.COLOR_BGR2YUV,\n                  cv2.COLOR_YUV2BGR,\n\n# Tells whether missing members accessed in mixin class should be ignored. A\n# mixin class is detected if its name ends with \"mixin\" (case insensitive).\nignore-mixin-members=yes\n\n# Tells whether to warn about missing members when the owner of the attribute\n# is inferred to be None.\nignore-none=yes\n\n# This flag controls whether pylint should warn about no-member and similar\n# checks whenever an opaque object is returned when inferring. The inference\n# can return multiple potential results while evaluating a Python object, but\n# some branches might not be evaluated, which results in partial inference. In\n# that case, it might be useful to still emit no-member and other checks for\n# the rest of the inferred objects.\nignore-on-opaque-inference=yes\n\n# List of class names for which member attributes should not be checked (useful\n# for classes with dynamically set attributes). This supports the use of\n# qualified names.\nignored-classes=optparse.Values,thread._local,_thread._local\n\n# List of module names for which member attributes should not be checked\n# (useful for modules/projects where namespaces are manipulated during runtime\n# and thus existing member attributes cannot be deduced by static analysis. It\n# supports qualified module names, as well as Unix pattern matching.\nignored-modules=\n\n# Show a hint with possible names when a member name was not found. 
The aspect\n# of finding the hint is based on edit distance.\nmissing-member-hint=yes\n\n# The minimum edit distance a name should have in order to be considered a\n# similar match for a missing member name.\nmissing-member-hint-distance=1\n\n# The total number of similar names that should be taken in consideration when\n# showing a hint for a missing member.\nmissing-member-max-choices=1\n\n\n[VARIABLES]\n\n# List of additional names supposed to be defined in builtins. Remember that\n# you should avoid to define new builtins when possible.\nadditional-builtins=\n\n# Tells whether unused global variables should be treated as a violation.\nallow-global-unused-variables=yes\n\n# List of strings which can identify a callback function by name. A callback\n# name must start or end with one of those strings.\ncallbacks=cb_,\n          _cb\n\n# A regular expression matching the name of dummy variables (i.e. expected to\n# not be used).\ndummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_\n\n# Argument names that match this expression will be ignored. Default to name\n# with leading underscore.\nignored-argument-names=_.*|^ignored_|^unused_\n\n# Tells whether we should check for unused import in __init__ files.\ninit-import=no\n\n# List of qualified module names which can have objects that can redefine\n# builtins.\nredefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io\n\n\n[SPELLING]\n\n# Limits count of emitted suggestions for spelling mistakes.\nmax-spelling-suggestions=4\n\n# Spelling dictionary name. 
Available dictionaries: en_IE (myspell), en_ZM\n# (myspell), en_GB (myspell), en_HK (myspell), en_BZ (myspell), en_PH\n# (myspell), en_ZA (myspell), en_MW (myspell), en_AU (myspell), en_CA\n# (myspell), en_JM (myspell), en_GH (myspell), en_TT (myspell), en_SG\n# (myspell), en_BW (myspell), en_US (myspell), en_NZ (myspell), en_AG\n# (myspell), en_ZW (myspell), en_NA (myspell), en_IN (myspell), en_BS\n# (myspell), en_DK (myspell), en_NG (myspell)..\nspelling-dict=\n\n# List of comma separated words that should not be checked.\nspelling-ignore-words=\n\n# A path to a file that contains private dictionary; one word per line.\nspelling-private-dict-file=\n\n# Tells whether to store unknown words to indicated private dictionary in\n# --spelling-private-dict-file option instead of raising a message.\nspelling-store-unknown-words=no\n\n\n[IMPORTS]\n\n# Allow wildcard imports from modules that define __all__.\nallow-wildcard-with-all=no\n\n# Analyse import fallback blocks. This can be used to support both Python 2 and\n# 3 compatible code, which means that the block might have code that exists\n# only in one or another interpreter, leading to false positives when analysed.\nanalyse-fallback-blocks=no\n\n# Deprecated modules which should not be used, separated by a comma.\ndeprecated-modules=optparse,tkinter.tix\n\n# Create a graph of external dependencies in the given file (report RP0402 must\n# not be disabled).\next-import-graph=\n\n# Create a graph of every (i.e. 
internal and external) dependencies in the\n# given file (report RP0402 must not be disabled).\nimport-graph=\n\n# Create a graph of internal dependencies in the given file (report RP0402 must\n# not be disabled).\nint-import-graph=\n\n# Force import order to recognize a module as part of the standard\n# compatibility libraries.\nknown-standard-library=\n\n# Force import order to recognize a module as part of a third party library.\nknown-third-party=enchant\n\n\n[CLASSES]\n\n# List of method names used to declare (i.e. assign) instance attributes.\ndefining-attr-methods=__init__,\n                      __new__,\n                      setUp\n\n# List of member names, which should be excluded from the protected access\n# warning.\nexclude-protected=_asdict,\n                  _fields,\n                  _replace,\n                  _source,\n                  _make\n\n# List of valid names for the first argument in a class method.\nvalid-classmethod-first-arg=cls\n\n# List of valid names for the first argument in a metaclass class method.\nvalid-metaclass-classmethod-first-arg=cls\n\n\n[DESIGN]\n\n# Maximum number of arguments for function / method.\nmax-args=5\n\n# Maximum number of attributes for a class (see R0902).\nmax-attributes=7\n\n# Maximum number of boolean expressions in an if statement.\nmax-bool-expr=5\n\n# Maximum number of branch for function / method body.\nmax-branches=12\n\n# Maximum number of locals for function / method body.\nmax-locals=15\n\n# Maximum number of parents for a class (see R0901).\nmax-parents=7\n\n# Maximum number of public methods for a class (see R0904).\nmax-public-methods=20\n\n# Maximum number of return / yield for function / method body.\nmax-returns=6\n\n# Maximum number of statements in function / method body.\nmax-statements=50\n\n# Minimum number of public methods for a class (see R0903).\nmin-public-methods=2\n\n\n[EXCEPTIONS]\n\n# Exceptions that will emit a warning when being caught. 
Defaults to\n# \"Exception\".\novergeneral-exceptions=Exception\n"
  },
  {
    "path": ".travis.yml",
    "content": "sudo: false\nlanguage: python\ninstall: pip install tox\nmatrix:\n  include:\n  - python: \"3.6\"\n    env: TOX_ENV=static\n  - python: \"3.6\"\n    env: TOX_ENV=format\nscript: tox -e $TOX_ENV\n"
  },
  {
    "path": "ColorFIDBenchmarkArtistic.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Color FID Benchmark (HQ)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import os\\n\",\n    \"os.environ['CUDA_VISIBLE_DEVICES']='1'\\n\",\n    \"os.environ['OMP_NUM_THREADS']='1'\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import statistics\\n\",\n    \"from fastai import *\\n\",\n    \"from deoldify.visualize import *\\n\",\n    \"import cv2\\n\",\n    \"from fid.fid_score import *\\n\",\n    \"from fid.inception import *\\n\",\n    \"import imageio\\n\",\n    \"plt.style.use('dark_background')\\n\",\n    \"torch.backends.cudnn.benchmark=True\\n\",\n    \"import warnings\\n\",\n    \"warnings.filterwarnings(\\\"ignore\\\", category=UserWarning, module=\\\"torch.nn.functional\\\")\\n\",\n    \"warnings.filterwarnings(\\\"ignore\\\", category=UserWarning, message='.*?retrieve source code for container of type.*?')\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Setup\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"#NOTE:  Data should come from here:  'https://datasets.figure-eight.com/figure_eight_datasets/open-images/test_challenge.zip'\\n\",\n    \"#NOTE:  Minimum recommmended number of samples is 10K.  
Source:  https://github.com/bioinf-jku/TTUR\\n\",\n    \"\\n\",\n    \"path = Path('data/ColorBenchmark')\\n\",\n    \"path_hr = path/'source'\\n\",\n    \"path_lr = path/'bandw'\\n\",\n    \"path_results = Path('./result_images/ColorBenchmarkFID/artistic')\\n\",\n    \"path_rendered = path_results/'rendered'\\n\",\n    \"\\n\",\n    \"#path = Path('data/DeOldifyColor')\\n\",\n    \"#path_hr = path\\n\",\n    \"#path_lr = path/'bandw'\\n\",\n    \"#path_results = Path('./result_images/ColorBenchmark/edge')\\n\",\n    \"#path_rendered = path_results/'rendered'\\n\",\n    \"\\n\",\n    \"#num_images = 2048\\n\",\n    \"#num_images = 15000\\n\",\n    \"num_images = 50000\\n\",\n    \"render_factor=35\\n\",\n    \"fid_batch_size = 4\\n\",\n    \"eval_size=299\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def inception_model(dims:int):\\n\",\n    \"    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]\\n\",\n    \"    model = InceptionV3([block_idx])\\n\",\n    \"    model.cuda()\\n\",\n    \"    return model\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def create_before_images(fn,i):\\n\",\n    \"    dest = path_lr/fn.relative_to(path_hr)\\n\",\n    \"    dest.parent.mkdir(parents=True, exist_ok=True)\\n\",\n    \"    img = PIL.Image.open(fn).convert('LA').convert('RGB')\\n\",\n    \"    img.save(dest)  \"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def render_images(colorizer, source_dir:Path, filtered_dir:Path, target_dir:Path, render_factor:int, num_images:int)->[(Path, Path, Path)]:\\n\",\n    \"    results = []\\n\",\n    \"    bandw_list = ImageList.from_folder(path_lr)\\n\",\n    \"    bandw_list = bandw_list[:num_images]\\n\",\n    \"\\n\",\n    \"    if 
len(bandw_list.items) == 0: return results\\n\",\n    \"\\n\",\n    \"    results = []\\n\",\n    \"    img_iterator = progress_bar(bandw_list.items)\\n\",\n    \"\\n\",\n    \"    for bandw_path in img_iterator:\\n\",\n    \"        target_path = target_dir/bandw_path.relative_to(source_dir)\\n\",\n    \"\\n\",\n    \"        try:\\n\",\n    \"            result_image = colorizer.get_transformed_image(path=bandw_path, render_factor=render_factor)\\n\",\n    \"            result_path = Path(str(path_results) + '/' + bandw_path.parent.name + '/' + bandw_path.name)\\n\",\n    \"            if not result_path.parent.exists():\\n\",\n    \"                result_path.parent.mkdir(parents=True, exist_ok=True)\\n\",\n    \"            result_image.save(result_path)\\n\",\n    \"            results.append((result_path, bandw_path, target_path))\\n\",\n    \"        except Exception as err:\\n\",\n    \"            print('Failed to render image.  Skipping.  Details: {0}'.format(err))\\n\",\n    \"    \\n\",\n    \"    return results \"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def calculate_fid_score(render_results, bs:int, eval_size:int):\\n\",\n    \"    dims = 2048\\n\",\n    \"    cuda = True\\n\",\n    \"    model = inception_model(dims=dims)\\n\",\n    \"    rendered_paths = []\\n\",\n    \"    target_paths = []\\n\",\n    \"    \\n\",\n    \"    for render_result in render_results:\\n\",\n    \"        rendered_path, _, target_path = render_result\\n\",\n    \"        rendered_paths.append(str(rendered_path))\\n\",\n    \"        target_paths.append(str(target_path))\\n\",\n    \"        \\n\",\n    \"    rendered_m, rendered_s = calculate_activation_statistics(files=rendered_paths, model=model, batch_size=bs, dims=dims, cuda=cuda)\\n\",\n    \"    target_m, target_s = calculate_activation_statistics(files=target_paths, model=model, batch_size=bs, dims=dims, 
cuda=cuda)\\n\",\n    \"    fid_score = calculate_frechet_distance(rendered_m, rendered_s, target_m, target_s)\\n\",\n    \"    del model\\n\",\n    \"    return fid_score\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Create black and whites source images\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Only runs if the directory isn't already created.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"if not path_lr.exists():\\n\",\n    \"    il = ImageList.from_folder(path_hr)\\n\",\n    \"    parallel(create_before_images, il.items)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"path_results.parent.mkdir(parents=True, exist_ok=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Rendering\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"colorizer = get_image_colorizer(artistic=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"render_results = render_images(colorizer=colorizer, source_dir=path_lr, target_dir=path_hr, filtered_dir=path_results, render_factor=render_factor, num_images=num_images)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Colorizaton Scoring\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"fid_score = calculate_fid_score(render_results, bs=fid_batch_size, eval_size=eval_size)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   
\"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"print('FID Score: ' + str(fid_score))\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.7.0\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "ColorizeTrainingArtistic.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Artistic Model Training\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### NOTES:  \\n\",\n    \"* This is \\\"NoGAN\\\" based training, described in the DeOldify readme.\\n\",\n    \"* This model prioritizes colorful renderings.  It has higher variation in renderings at different resolutions compared to the \\\"stable\\\" model\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"#NOTE:  This must be the first call in order to work properly!\\n\",\n    \"from deoldify import device\\n\",\n    \"from deoldify.device_id import DeviceId\\n\",\n    \"#choices:  CPU, GPU0...GPU7\\n\",\n    \"device.set(device=DeviceId.GPU0)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import os\\n\",\n    \"import fastai\\n\",\n    \"from fastai import *\\n\",\n    \"from fastai.vision import *\\n\",\n    \"from fastai.callbacks.tensorboard import *\\n\",\n    \"from fastai.vision.gan import *\\n\",\n    \"from deoldify.generators import *\\n\",\n    \"from deoldify.critics import *\\n\",\n    \"from deoldify.dataset import *\\n\",\n    \"from deoldify.loss import *\\n\",\n    \"from deoldify.save import *\\n\",\n    \"from PIL import Image, ImageDraw, ImageFont\\n\",\n    \"from PIL import ImageFile\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Setup\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"path = Path('data/imagenet/ILSVRC/Data/CLS-LOC')\\n\",\n    \"path_hr = path\\n\",\n    \"path_lr = path/'bandw'\\n\",\n    \"\\n\",\n    \"proj_id = 'ArtisticModel'\\n\",\n   
 \"\\n\",\n    \"gen_name = proj_id + '_gen'\\n\",\n    \"pre_gen_name = gen_name + '_0'\\n\",\n    \"crit_name = proj_id + '_crit'\\n\",\n    \"\\n\",\n    \"name_gen = proj_id + '_image_gen'\\n\",\n    \"path_gen = path/name_gen\\n\",\n    \"\\n\",\n    \"TENSORBOARD_PATH = Path('data/tensorboard/' + proj_id)\\n\",\n    \"\\n\",\n    \"nf_factor = 1.5\\n\",\n    \"pct_start = 1e-8\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def get_data(bs:int, sz:int, keep_pct:float):\\n\",\n    \"    return get_colorize_data(sz=sz, bs=bs, crappy_path=path_lr, good_path=path_hr, \\n\",\n    \"                             random_seed=None, keep_pct=keep_pct)\\n\",\n    \"\\n\",\n    \"def get_crit_data(classes, bs, sz):\\n\",\n    \"    src = ImageList.from_folder(path, include=classes, recurse=True).split_by_rand_pct(0.1, seed=42)\\n\",\n    \"    ll = src.label_from_folder(classes=classes)\\n\",\n    \"    data = (ll.transform(get_transforms(max_zoom=2.), size=sz)\\n\",\n    \"           .databunch(bs=bs).normalize(imagenet_stats))\\n\",\n    \"    return data\\n\",\n    \"\\n\",\n    \"def create_training_images(fn,i):\\n\",\n    \"    dest = path_lr/fn.relative_to(path_hr)\\n\",\n    \"    dest.parent.mkdir(parents=True, exist_ok=True)\\n\",\n    \"    img = PIL.Image.open(fn).convert('LA').convert('RGB')\\n\",\n    \"    img.save(dest)  \\n\",\n    \"    \\n\",\n    \"def save_preds(dl):\\n\",\n    \"    i=0\\n\",\n    \"    names = dl.dataset.items\\n\",\n    \"    \\n\",\n    \"    for b in dl:\\n\",\n    \"        preds = learn_gen.pred_batch(batch=b, reconstruct=True)\\n\",\n    \"        for o in preds:\\n\",\n    \"            o.save(path_gen/names[i].name)\\n\",\n    \"            i += 1\\n\",\n    \"    \\n\",\n    \"def save_gen_images():\\n\",\n    \"    if path_gen.exists(): shutil.rmtree(path_gen)\\n\",\n    \"    path_gen.mkdir(exist_ok=True)\\n\",\n    \"    
data_gen = get_data(bs=bs, sz=sz, keep_pct=0.085)\\n\",\n    \"    save_preds(data_gen.fix_dl)\\n\",\n    \"    PIL.Image.open(path_gen.ls()[0])\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Create black and white training images\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Only runs if the directory isn't already created.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"if not path_lr.exists():\\n\",\n    \"    il = ImageList.from_folder(path_hr)\\n\",\n    \"    parallel(create_training_images, il.items)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Pre-train generator\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### NOTE\\n\",\n    \"Most of the training takes place here in pretraining for NoGAN.  
The goal here is to take the generator as far as possible with conventional training, as that is much easier to control and obtain glitch-free results compared to GAN training.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### 64px\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"bs=88\\n\",\n    \"sz=64\\n\",\n    \"keep_pct=1.0\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"data_gen = get_data(bs=bs, sz=sz, keep_pct=keep_pct)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen = gen_learner_deep(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.callback_fns.append(partial(ImageGenTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GenPre'))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.fit_one_cycle(1, pct_start=0.8, max_lr=slice(1e-3))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.save(pre_gen_name)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.unfreeze()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.fit_one_cycle(1, pct_start=pct_start,  max_lr=slice(3e-7, 3e-4))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n  
 \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.save(pre_gen_name)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### 128px\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"bs=22\\n\",\n    \"sz=128\\n\",\n    \"keep_pct=1.0\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.data = get_data(sz=sz, bs=bs, keep_pct=keep_pct)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.unfreeze()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.fit_one_cycle(1, pct_start=pct_start, max_lr=slice(1e-7,1e-4))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.save(pre_gen_name)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### 192px\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"bs=11\\n\",\n    \"sz=192\\n\",\n    \"keep_pct=0.50\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.data = get_data(sz=sz, bs=bs, keep_pct=keep_pct)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.unfreeze()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   
\"outputs\": [],\n   \"source\": [\n    \"learn_gen.fit_one_cycle(1, pct_start=pct_start, max_lr=slice(5e-8,5e-5))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.save(pre_gen_name)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Repeatable GAN Cycle\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### NOTE\\n\",\n    \"Best results so far have been based on repeating the cycle below a few times (about 5-8?), until diminishing returns are hit (no improvement in image quality).  Each time you repeat the cycle, you want to increment that old_checkpoint_num by 1 so that new check points don't overwrite the old.  \"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"old_checkpoint_num = 0\\n\",\n    \"checkpoint_num = old_checkpoint_num + 1\\n\",\n    \"gen_old_checkpoint_name = gen_name + '_' + str(old_checkpoint_num)\\n\",\n    \"gen_new_checkpoint_name = gen_name + '_' + str(checkpoint_num)\\n\",\n    \"crit_old_checkpoint_name = crit_name + '_' + str(old_checkpoint_num)\\n\",\n    \"crit_new_checkpoint_name= crit_name + '_' + str(checkpoint_num)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Save Generated Images\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"bs=8\\n\",\n    \"sz=192\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen = gen_learner_deep(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n  
 \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"save_gen_images()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Pretrain Critic\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"##### Only need full pretraining of critic when starting from scratch.  Otherwise, just finetune!\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"if old_checkpoint_num == 0:\\n\",\n    \"    bs=64\\n\",\n    \"    sz=128\\n\",\n    \"    learn_gen=None\\n\",\n    \"    gc.collect()\\n\",\n    \"    data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)\\n\",\n    \"    data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)\\n\",\n    \"    learn_critic = colorize_crit_learner(data=data_crit, nf=256)\\n\",\n    \"    learn_critic.callback_fns.append(partial(LearnerTensorboardWriter, base_dir=TENSORBOARD_PATH, name='CriticPre'))\\n\",\n    \"    learn_critic.fit_one_cycle(6, 1e-3)\\n\",\n    \"    learn_critic.save(crit_old_checkpoint_name)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"bs=16\\n\",\n    \"sz=192\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_critic = colorize_crit_learner(data=data_crit, 
nf=256).load(crit_old_checkpoint_name, with_opt=False)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_critic.callback_fns.append(partial(LearnerTensorboardWriter, base_dir=TENSORBOARD_PATH, name='CriticPre'))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_critic.fit_one_cycle(4, 1e-4)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_critic.save(crit_new_checkpoint_name)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### GAN\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_crit=None\\n\",\n    \"learn_critic=None\\n\",\n    \"learn_gen=None\\n\",\n    \"gc.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"lr=1e-5\\n\",\n    \"sz=192\\n\",\n    \"bs=9\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_crit = colorize_crit_learner(data=data_crit, nf=256).load(crit_new_checkpoint_name, with_opt=False)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen = gen_learner_deep(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)\"\n   ]\n  },\n  {\n   
\"cell_type\": \"code\",\n   
\"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)\\n\",\n    \"learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.0,2.0), show_img=False, switcher=switcher,\\n\",\n    \"                                 opt_func=partial(optim.Adam, betas=(0.,0.9)), wd=1e-3)\\n\",\n    \"learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))\\n\",\n    \"learn.callback_fns.append(partial(GANTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GanLearner', visual_iters=100))\\n\",\n    \"learn.callback_fns.append(partial(GANSaveCallback, learn_gen=learn_gen, filename=gen_new_checkpoint_name, save_iters=100))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Instructions:  \\n\",\n    \"Find the checkpoint just before where glitches start to be introduced.  This is all very new so you may need to play around with just how far you go here with keep_pct.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn.data = get_data(sz=sz, bs=bs, keep_pct=0.03)\\n\",\n    \"learn_gen.freeze_to(-1)\\n\",\n    \"learn.fit(1,lr)\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.7.0\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 4\n}\n"
  },
  {
    "path": "ColorizeTrainingStable.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Stable Model Training\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### NOTES:  \\n\",\n    \"* This is \\\"NoGAN\\\" based training, described in the DeOldify readme.\\n\",\n    \"* This model prioritizes stable and reliable renderings.  It does particularly well on portraits and landscapes.  It's not as colorful as the artistic model.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"#NOTE:  This must be the first call in order to work properly!\\n\",\n    \"from deoldify import device\\n\",\n    \"from deoldify.device_id import DeviceId\\n\",\n    \"#choices:  CPU, GPU0...GPU7\\n\",\n    \"device.set(device=DeviceId.GPU0)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import os\\n\",\n    \"import fastai\\n\",\n    \"from fastai import *\\n\",\n    \"from fastai.vision import *\\n\",\n    \"from fastai.callbacks.tensorboard import *\\n\",\n    \"from fastai.vision.gan import *\\n\",\n    \"from deoldify.generators import *\\n\",\n    \"from deoldify.critics import *\\n\",\n    \"from deoldify.dataset import *\\n\",\n    \"from deoldify.loss import *\\n\",\n    \"from deoldify.save import *\\n\",\n    \"from PIL import Image, ImageDraw, ImageFont\\n\",\n    \"from PIL import ImageFile\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Setup\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"path = Path('data/imagenet/ILSVRC/Data/CLS-LOC')\\n\",\n    \"path_hr = path\\n\",\n    \"path_lr = path/'bandw'\\n\",\n    \"\\n\",\n    \"proj_id = 
'StableModel'\\n\",\n    \"\\n\",\n    \"gen_name = proj_id + '_gen'\\n\",\n    \"pre_gen_name = gen_name + '_0'\\n\",\n    \"crit_name = proj_id + '_crit'\\n\",\n    \"\\n\",\n    \"name_gen = proj_id + '_image_gen'\\n\",\n    \"path_gen = path/name_gen\\n\",\n    \"\\n\",\n    \"TENSORBOARD_PATH = Path('data/tensorboard/' + proj_id)\\n\",\n    \"\\n\",\n    \"nf_factor = 2\\n\",\n    \"pct_start = 1e-8\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def get_data(bs:int, sz:int, keep_pct:float):\\n\",\n    \"    return get_colorize_data(sz=sz, bs=bs, crappy_path=path_lr, good_path=path_hr, \\n\",\n    \"                             random_seed=None, keep_pct=keep_pct)\\n\",\n    \"\\n\",\n    \"def get_crit_data(classes, bs, sz):\\n\",\n    \"    src = ImageList.from_folder(path, include=classes, recurse=True).split_by_rand_pct(0.1, seed=42)\\n\",\n    \"    ll = src.label_from_folder(classes=classes)\\n\",\n    \"    data = (ll.transform(get_transforms(max_zoom=2.), size=sz)\\n\",\n    \"           .databunch(bs=bs).normalize(imagenet_stats))\\n\",\n    \"    return data\\n\",\n    \"\\n\",\n    \"def create_training_images(fn,i):\\n\",\n    \"    dest = path_lr/fn.relative_to(path_hr)\\n\",\n    \"    dest.parent.mkdir(parents=True, exist_ok=True)\\n\",\n    \"    img = PIL.Image.open(fn).convert('LA').convert('RGB')\\n\",\n    \"    img.save(dest)  \\n\",\n    \"    \\n\",\n    \"def save_preds(dl):\\n\",\n    \"    i=0\\n\",\n    \"    names = dl.dataset.items\\n\",\n    \"    \\n\",\n    \"    for b in dl:\\n\",\n    \"        preds = learn_gen.pred_batch(batch=b, reconstruct=True)\\n\",\n    \"        for o in preds:\\n\",\n    \"            o.save(path_gen/names[i].name)\\n\",\n    \"            i += 1\\n\",\n    \"    \\n\",\n    \"def save_gen_images():\\n\",\n    \"    if path_gen.exists(): shutil.rmtree(path_gen)\\n\",\n    \"    
path_gen.mkdir(exist_ok=True)\\n\",\n    \"    data_gen = get_data(bs=bs, sz=sz, keep_pct=0.085)\\n\",\n    \"    save_preds(data_gen.fix_dl)\\n\",\n    \"    PIL.Image.open(path_gen.ls()[0])\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Create black and white training images\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Only runs if the directory isn't already created.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"if not path_lr.exists():\\n\",\n    \"    il = ImageList.from_folder(path_hr)\\n\",\n    \"    parallel(create_training_images, il.items)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Pre-train generator\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### NOTE\\n\",\n    \"Most of the training takes place here in pretraining for NoGAN.  
The goal here is to take the generator as far as possible with conventional training, as that is much easier to control and obtain glitch-free results compared to GAN training.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### 64px\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"bs=88\\n\",\n    \"sz=64\\n\",\n    \"keep_pct=1.0\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"data_gen = get_data(bs=bs, sz=sz, keep_pct=keep_pct)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.callback_fns.append(partial(ImageGenTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GenPre'))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.fit_one_cycle(1, pct_start=0.8, max_lr=slice(1e-3))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.save(pre_gen_name)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.unfreeze()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.fit_one_cycle(1, pct_start=pct_start,  max_lr=slice(3e-7, 3e-4))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n  
 \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.save(pre_gen_name)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### 128px\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"bs=20\\n\",\n    \"sz=128\\n\",\n    \"keep_pct=1.0\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.data = get_data(sz=sz, bs=bs, keep_pct=keep_pct)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.unfreeze()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.fit_one_cycle(1, pct_start=pct_start, max_lr=slice(1e-7,1e-4))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.save(pre_gen_name)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### 192px\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"bs=8\\n\",\n    \"sz=192\\n\",\n    \"keep_pct=0.50\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.data = get_data(sz=sz, bs=bs, keep_pct=keep_pct)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.unfreeze()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   
\"outputs\": [],\n   \"source\": [\n    \"learn_gen.fit_one_cycle(1, pct_start=pct_start, max_lr=slice(5e-8,5e-5))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.save(pre_gen_name)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Repeatable GAN Cycle\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### NOTE\\n\",\n    \"Best results so far have been based on repeating the cycle below a few times (about 5-8?), until diminishing returns are hit (no improvement in image quality).  Each time you repeat the cycle, you want to increment that old_checkpoint_num by 1 so that new check points don't overwrite the old.  \"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"old_checkpoint_num = 0\\n\",\n    \"checkpoint_num = old_checkpoint_num + 1\\n\",\n    \"gen_old_checkpoint_name = gen_name + '_' + str(old_checkpoint_num)\\n\",\n    \"gen_new_checkpoint_name = gen_name + '_' + str(checkpoint_num)\\n\",\n    \"crit_old_checkpoint_name = crit_name + '_' + str(old_checkpoint_num)\\n\",\n    \"crit_new_checkpoint_name= crit_name + '_' + str(checkpoint_num)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Save Generated Images\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"bs=8\\n\",\n    \"sz=192\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n  
 \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"save_gen_images()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Pretrain Critic\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"##### Only need full pretraining of critic when starting from scratch.  Otherwise, just finetune!\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"if old_checkpoint_num == 0:\\n\",\n    \"    bs=64\\n\",\n    \"    sz=128\\n\",\n    \"    learn_gen=None\\n\",\n    \"    gc.collect()\\n\",\n    \"    data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)\\n\",\n    \"    data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)\\n\",\n    \"    learn_critic = colorize_crit_learner(data=data_crit, nf=256)\\n\",\n    \"    learn_critic.callback_fns.append(partial(LearnerTensorboardWriter, base_dir=TENSORBOARD_PATH, name='CriticPre'))\\n\",\n    \"    learn_critic.fit_one_cycle(6, 1e-3)\\n\",\n    \"    learn_critic.save(crit_old_checkpoint_name)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"bs=16\\n\",\n    \"sz=192\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_critic = colorize_crit_learner(data=data_crit, 
nf=256).load(crit_old_checkpoint_name, with_opt=False)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_critic.callback_fns.append(partial(LearnerTensorboardWriter, base_dir=TENSORBOARD_PATH, name='CriticPre'))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_critic.fit_one_cycle(4, 1e-4)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_critic.save(crit_new_checkpoint_name)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### GAN\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_crit=None\\n\",\n    \"learn_critic=None\\n\",\n    \"learn_gen=None\\n\",\n    \"gc.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"lr=2e-5\\n\",\n    \"sz=192\\n\",\n    \"bs=5\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_crit = colorize_crit_learner(data=data_crit, nf=256).load(crit_new_checkpoint_name, with_opt=False)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)\"\n   ]\n  },\n  {\n   
\"cell_type\": \"code\",\n   
\"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)\\n\",\n    \"learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.0,1.5), show_img=False, switcher=switcher,\\n\",\n    \"                                 opt_func=partial(optim.Adam, betas=(0.,0.9)), wd=1e-3)\\n\",\n    \"learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))\\n\",\n    \"learn.callback_fns.append(partial(GANTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GanLearner', visual_iters=100))\\n\",\n    \"learn.callback_fns.append(partial(GANSaveCallback, learn_gen=learn_gen, filename=gen_new_checkpoint_name, save_iters=100))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Instructions:  \\n\",\n    \"Find the checkpoint just before where glitches start to be introduced.  This is all very new so you may need to play around with just how far you go here with keep_pct.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn.data = get_data(sz=sz, bs=bs, keep_pct=0.03)\\n\",\n    \"learn_gen.freeze_to(-1)\\n\",\n    \"learn.fit(1,lr)\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.7.6\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 4\n}\n"
  },
  {
    "path": "ColorizeTrainingStableLargeBatch.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Stable Model Training (Large Batch/Limited GPU Memory Support)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## IMPORTANT: Training has -not- been verified by myself for this notebook ~jantic\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### NOTES:  \\n\",\n    \"* This is \\\"NoGAN\\\" based training, described in the DeOldify readme.\\n\",\n    \"* This model prioritizes stable and reliable renderings.  It does particularly well on portraits and landscapes.  It's not as colorful as the artistic model.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import os\\n\",\n    \"os.environ['CUDA_VISIBLE_DEVICES']='0' \"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import fastai\\n\",\n    \"from fastai import *\\n\",\n    \"from fastai.vision import *\\n\",\n    \"from fastai.callbacks.tensorboard import *\\n\",\n    \"from fastai.vision.gan import *\\n\",\n    \"from deoldify.generators import *\\n\",\n    \"from deoldify.critics import *\\n\",\n    \"from deoldify.dataset import *\\n\",\n    \"from deoldify.loss import *\\n\",\n    \"from deoldify.save import *\\n\",\n    \"from PIL import Image, ImageDraw, ImageFont\\n\",\n    \"from PIL import ImageFile\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Setup\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Activate Large Model Support for PyTorch\\n\",\n    \"This will allow us to fit the model within a GPU with smaller memory capacity (e.g. 
GTX 1070 8GB).\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Large Model Support (LMS) is a feature provided in IBM Watson Machine Learning Community Edition (WML-CE) PyTorch V1.1.0 that allows the successful training of deep learning models that would otherwise exhaust GPU memory and abort with “out-of-memory” errors. LMS manages this oversubscription of GPU memory by temporarily swapping tensors to host memory when they are not needed. One or more elements of a deep learning model can lead to GPU memory exhaustion.\\n\",\n    \"\\n\",\n    \"Requires the use of IBM WML-CE (Available here: https://www.ibm.com/support/knowledgecenter/en/SS5SF7_1.6.1/welcome/welcome.html)\\n\",\n    \"\\n\",\n    \"Further Reading on PyTorch with Large Model Support: https://www.ibm.com/support/knowledgecenter/en/SS5SF7_1.6.1/navigation/wmlce_getstarted_pytorch.html\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import shutil\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Set limit (in GB) of GPU memory used before swapping tensors to host memory\\n\",\n    \"max_gpu_mem = 7\\n\",\n    \"\\n\",\n    \"def gb_to_bytes(gb):\\n\",\n    \"    return gb*1024*1024*1024\\n\",\n    \"\\n\",\n    \"# Enable PyTorch LMS\\n\",\n    \"torch.cuda.set_enabled_lms(True)\\n\",\n    \"# Set LMS limit\\n\",\n    \"torch.cuda.set_limit_lms(gb_to_bytes(max_gpu_mem))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Check LMS is enabled\\n\",\n    \"torch.cuda.get_enabled_lms()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Check LMS Limit has been set\\n\",\n    
\"torch.cuda.get_limit_lms()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \" \"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Path to Training Data\\n\",\n    \"path = Path('data/imagenet/ILSVRC/Data/CLS-LOC')\\n\",\n    \"path_hr = path\\n\",\n    \"\\n\",\n    \"# Path to Black and White images\\n\",\n    \"path_bandw = Path('/training/DeOldify')\\n\",\n    \"path_lr = path_bandw/'bandw'\\n\",\n    \"\\n\",\n    \"# Name of Model\\n\",\n    \"proj_id = 'StableModel'\\n\",\n    \"\\n\",\n    \"# Name of Generator\\n\",\n    \"gen_name = proj_id + '_gen'\\n\",\n    \"pre_gen_name = gen_name + '_0'\\n\",\n    \"\\n\",\n    \"# Name of Critic\\n\",\n    \"crit_name = proj_id + '_crit'\\n\",\n    \"\\n\",\n    \"# Name of Generated Images folder, located within the Training Data folder (get_crit_data reads it from path)\\n\",\n    \"name_gen = proj_id + '_image_gen'\\n\",\n    \"path_gen = path/name_gen\\n\",\n    \"\\n\",\n    \"# Path to tensorboard data\\n\",\n    \"TENSORBOARD_PATH = Path('data/tensorboard/' + proj_id)\\n\",\n    \"\\n\",\n    \"nf_factor = 2\\n\",\n    \"pct_start = 1e-8\\n\",\n    \"\\n\",\n    \"# Number of workers for DataLoader\\n\",\n    \"num_works = 2\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def get_data(bs:int, sz:int, keep_pct:float):\\n\",\n    \"    return get_colorize_data(sz=sz, bs=bs, crappy_path=path_lr, good_path=path_hr, \\n\",\n    \"                             random_seed=None, keep_pct=keep_pct, num_workers=num_works)\\n\",\n    \"\\n\",\n    \"def get_crit_data(classes, bs, sz):\\n\",\n    \"    src = ImageList.from_folder(path, include=classes, recurse=True).split_by_rand_pct(0.1, seed=42)\\n\",\n    \"    ll = src.label_from_folder(classes=classes)\\n\",\n    \"    data = 
(ll.transform(get_transforms(max_zoom=2.), size=sz)\\n\",\n    \"           .databunch(bs=bs).normalize(imagenet_stats))\\n\",\n    \"    return data\\n\",\n    \"\\n\",\n    \"def create_training_images(fn,i):\\n\",\n    \"    dest = path_lr/fn.relative_to(path_hr)\\n\",\n    \"    dest.parent.mkdir(parents=True, exist_ok=True)\\n\",\n    \"    img = PIL.Image.open(fn).convert('LA').convert('RGB')\\n\",\n    \"    img.save(dest)  \\n\",\n    \"    \\n\",\n    \"def save_preds(dl):\\n\",\n    \"    i=0\\n\",\n    \"    names = dl.dataset.items\\n\",\n    \"    \\n\",\n    \"    for b in dl:\\n\",\n    \"        preds = learn_gen.pred_batch(batch=b, reconstruct=True)\\n\",\n    \"        for o in preds:\\n\",\n    \"            o.save(path_gen/names[i].name)\\n\",\n    \"            i += 1\\n\",\n    \"    \\n\",\n    \"def save_gen_images():\\n\",\n    \"    if path_gen.exists(): shutil.rmtree(path_gen)\\n\",\n    \"    path_gen.mkdir(exist_ok=True)\\n\",\n    \"    data_gen = get_data(bs=bs, sz=sz, keep_pct=0.085)\\n\",\n    \"    save_preds(data_gen.fix_dl)\\n\",\n    \"    PIL.Image.open(path_gen.ls()[0])\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Create black and white training images\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Only runs if the directory isn't already created.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"if not path_lr.exists():\\n\",\n    \"    il = ImageList.from_folder(path_hr)\\n\",\n    \"    parallel(create_training_images, il.items)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Pre-train generator\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### NOTE\\n\",\n    \"Most of the training takes place here in pretraining 
for NoGAN.  The goal here is to take the generator as far as possible with conventional training, as that is much easier to control and obtain glitch-free results compared to GAN training.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### 64px\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"bs=88 # This can be increased if using PyTorch LMS, training could be slower.\\n\",\n    \"sz=64\\n\",\n    \"keep_pct=1.0\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"data_gen = get_data(bs=bs, sz=sz, keep_pct=keep_pct)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.callback_fns.append(partial(ImageGenTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GenPre'))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.fit_one_cycle(1, pct_start=0.8, max_lr=slice(1e-3))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.save(pre_gen_name)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.unfreeze()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.fit_one_cycle(1, 
pct_start=pct_start,  max_lr=slice(3e-7, 3e-4))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.save(pre_gen_name)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### 128px\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"bs=40 # This can be increased if using PyTorch LMS, training could be slower.\\n\",\n    \"sz=128\\n\",\n    \"keep_pct=1.0\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.data = get_data(sz=sz, bs=bs, keep_pct=keep_pct)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.unfreeze()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.fit_one_cycle(1, pct_start=pct_start, max_lr=slice(1e-7,1e-4))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.save(pre_gen_name)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### 192px\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"bs=16 # This can be increased if using PyTorch LMS, training could be slower.\\n\",\n    \"sz=192\\n\",\n    \"keep_pct=0.50\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.data = get_data(sz=sz, bs=bs, keep_pct=keep_pct)\"\n   ]\n  },\n  {\n   \"cell_type\": 
\"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.unfreeze()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.fit_one_cycle(1, pct_start=pct_start, max_lr=slice(5e-8,5e-5))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.save(pre_gen_name)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### 256px\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"bs=8 # This can be increased if using PyTorch LMS, training could be slower.\\n\",\n    \"sz=256\\n\",\n    \"keep_pct=0.50\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.data = get_data(sz=sz, bs=bs, keep_pct=keep_pct)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.unfreeze()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.fit_one_cycle(1, pct_start=pct_start, max_lr=slice(5e-8,5e-5))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.save(pre_gen_name)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Repeatable GAN Cycle\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### NOTE\\n\",\n    \"Best results so far have been based on repeating the cycle below a few times 
(about 5-8?), until diminishing returns are hit (no improvement in image quality).  Each time you repeat the cycle, you want to increment that old_checkpoint_num by 1 so that new check points don't overwrite the old.  \"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"old_checkpoint_num = 0\\n\",\n    \"checkpoint_num = old_checkpoint_num + 1\\n\",\n    \"gen_old_checkpoint_name = gen_name + '_' + str(old_checkpoint_num)\\n\",\n    \"gen_new_checkpoint_name = gen_name + '_' + str(checkpoint_num)\\n\",\n    \"crit_old_checkpoint_name = crit_name + '_' + str(old_checkpoint_num)\\n\",\n    \"crit_new_checkpoint_name= crit_name + '_' + str(checkpoint_num)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Save Generated Images\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"bs=8\\n\",\n    \"sz=256\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"save_gen_images()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Pretrain Critic\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"##### Only need full pretraining of critic when starting from scratch.  
Otherwise, just finetune!\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"if old_checkpoint_num == 0:\\n\",\n    \"    bs=64\\n\",\n    \"    sz=128\\n\",\n    \"    learn_gen=None\\n\",\n    \"    gc.collect()\\n\",\n    \"    data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)\\n\",\n    \"    data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)\\n\",\n    \"    learn_critic = colorize_crit_learner(data=data_crit, nf=256)\\n\",\n    \"    learn_critic.callback_fns.append(partial(LearnerTensorboardWriter, base_dir=TENSORBOARD_PATH, name='CriticPre'))\\n\",\n    \"    learn_critic.fit_one_cycle(6, 1e-3)\\n\",\n    \"    learn_critic.save(crit_old_checkpoint_name)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"bs=8\\n\",\n    \"sz=256\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_critic = colorize_crit_learner(data=data_crit, nf=256).load(crit_old_checkpoint_name, with_opt=False)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_critic.callback_fns.append(partial(LearnerTensorboardWriter, base_dir=TENSORBOARD_PATH, name='CriticPre'))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": 
{},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_critic.fit_one_cycle(4, 1e-4)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_critic.save(crit_new_checkpoint_name)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### GAN\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_crit=None\\n\",\n    \"learn_gen=None\\n\",\n    \"gc.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"lr=2e-5\\n\",\n    \"sz=256\\n\",\n    \"bs=5\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_crit = colorize_crit_learner(data=data_crit, nf=256).load(crit_new_checkpoint_name, with_opt=False)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)\\n\",\n    \"learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.0,1.5), show_img=False, switcher=switcher,\\n\",\n    \"                                 opt_func=partial(optim.Adam, betas=(0.,0.9)), wd=1e-3)\\n\",\n    
\"learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))\\n\",\n    \"learn.callback_fns.append(partial(GANTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GanLearner', visual_iters=100))\\n\",\n    \"learn.callback_fns.append(partial(GANSaveCallback, learn_gen=learn_gen, filename=gen_new_checkpoint_name, save_iters=100))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Instructions:  \\n\",\n    \"Find the checkpoint just before where glitches start to be introduced.  This is all very new so you may need to play around with just how far you go here with keep_pct.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn.data = get_data(sz=sz, bs=bs, keep_pct=0.03)\\n\",\n    \"learn_gen.freeze_to(-1)\\n\",\n    \"learn.fit(1,lr)\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.7.0\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "ColorizeTrainingVideo.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Video Model Training\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### NOTES:  \\n\",\n    \"* It's assumed that there's a pretrained generator from the ColorizeTrainingStable notebook available at the specified path.\\n\",\n    \"* This is \\\"NoGAN\\\" based training, described in the DeOldify readme.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"#NOTE:  This must be the first call in order to work properly!\\n\",\n    \"from deoldify import device\\n\",\n    \"from deoldify.device_id import DeviceId\\n\",\n    \"#choices:  CPU, GPU0...GPU7\\n\",\n    \"device.set(device=DeviceId.GPU0)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import os\\n\",\n    \"import fastai\\n\",\n    \"from fastai import *\\n\",\n    \"from fastai.vision import *\\n\",\n    \"from fastai.callbacks.tensorboard import *\\n\",\n    \"from fastai.vision.gan import *\\n\",\n    \"from deoldify.generators import *\\n\",\n    \"from deoldify.critics import *\\n\",\n    \"from deoldify.dataset import *\\n\",\n    \"from deoldify.loss import *\\n\",\n    \"from deoldify.save import *\\n\",\n    \"from deoldify.augs import noisify \\n\",\n    \"from PIL import Image, ImageDraw, ImageFont\\n\",\n    \"from PIL import ImageFile\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Setup\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"path = Path('data/imagenet/ILSVRC/Data/CLS-LOC')\\n\",\n    \"path_hr = path\\n\",\n    \"path_lr = path/'bandw'\\n\",\n    \"\\n\",\n    \"proj_id = 
'VideoModel'\\n\",\n    \"gen_name = proj_id + '_gen'\\n\",\n    \"pre_gen_name = gen_name + '_0'\\n\",\n    \"crit_name = proj_id + '_crit'\\n\",\n    \"\\n\",\n    \"name_gen = proj_id + '_image_gen'\\n\",\n    \"path_gen = path/name_gen\\n\",\n    \"\\n\",\n    \"TENSORBOARD_PATH = Path('data/tensorboard/' + proj_id)\\n\",\n    \"\\n\",\n    \"nf_factor = 2\\n\",\n    \"xtra_tfms=[noisify(p=0.8)]\\n\",\n    \"pct_start = 1e-8\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def get_data(bs:int, sz:int, keep_pct:float):\\n\",\n    \"    return get_colorize_data(sz=sz, bs=bs, crappy_path=path_lr, good_path=path_hr, \\n\",\n    \"                             random_seed=None, keep_pct=keep_pct, xtra_tfms=xtra_tfms)\\n\",\n    \"\\n\",\n    \"def get_crit_data(classes, bs, sz):\\n\",\n    \"    src = ImageList.from_folder(path, include=classes, recurse=True).split_by_rand_pct(0.1, seed=42)\\n\",\n    \"    ll = src.label_from_folder(classes=classes)\\n\",\n    \"    data = (ll.transform(get_transforms(max_zoom=2.), size=sz)\\n\",\n    \"           .databunch(bs=bs).normalize(imagenet_stats))\\n\",\n    \"    return data\\n\",\n    \"\\n\",\n    \"def create_training_images(fn,i):\\n\",\n    \"    dest = path_lr/fn.relative_to(path_hr)\\n\",\n    \"    dest.parent.mkdir(parents=True, exist_ok=True)\\n\",\n    \"    img = PIL.Image.open(fn).convert('LA').convert('RGB')\\n\",\n    \"    img.save(dest)  \\n\",\n    \"    \\n\",\n    \"def save_preds(dl):\\n\",\n    \"    i=0\\n\",\n    \"    names = dl.dataset.items\\n\",\n    \"    \\n\",\n    \"    for b in dl:\\n\",\n    \"        preds = learn_gen.pred_batch(batch=b, reconstruct=True)\\n\",\n    \"        for o in preds:\\n\",\n    \"            o.save(path_gen/names[i].name)\\n\",\n    \"            i += 1\\n\",\n    \"            \\n\",\n    \"def save_gen_images():\\n\",\n    \"    if path_gen.exists(): 
shutil.rmtree(path_gen)\\n\",\n    \"    path_gen.mkdir(exist_ok=True)\\n\",\n    \"    data_gen = get_data(bs=bs, sz=sz, keep_pct=0.085)\\n\",\n    \"    save_preds(data_gen.fix_dl)\\n\",\n    \"    PIL.Image.open(path_gen.ls()[0])\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Create black and white training images\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Only runs if the directory isn't already created.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"if not path_lr.exists():\\n\",\n    \"    il = ImageList.from_folder(path_hr)\\n\",\n    \"    parallel(create_training_images, il.items)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Finetune Generator With Noise Augmented Images.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"##### This helps the generator better deal with noisy/grainy video (which is pretty normal).\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"bs=8\\n\",\n    \"sz=192\\n\",\n    \"keep_pct=0.25\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"data_gen = get_data(bs=bs, sz=sz, keep_pct=keep_pct)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.callback_fns.append(partial(ImageGenTensorboardWriter, 
base_dir=TENSORBOARD_PATH, name='GenPre'))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen = learn_gen.load(pre_gen_name, with_opt=False)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.unfreeze()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.fit_one_cycle(1, pct_start=pct_start, max_lr=slice(5e-8,5e-5))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.save(pre_gen_name)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Repeatable GAN Cycle\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### NOTE\\n\",\n    \"Best results so far have been based on doing only a single run of the cells below (otherwise glitches are introduced that are visible in video).  
\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"old_checkpoint_num = 0\\n\",\n    \"checkpoint_num = old_checkpoint_num + 1\\n\",\n    \"gen_old_checkpoint_name = gen_name + '_' + str(old_checkpoint_num)\\n\",\n    \"gen_new_checkpoint_name = gen_name + '_' + str(checkpoint_num)\\n\",\n    \"crit_old_checkpoint_name = crit_name + '_' + str(old_checkpoint_num)\\n\",\n    \"crit_new_checkpoint_name= crit_name + '_' + str(checkpoint_num)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Save Generated Images\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"bs=8\\n\",\n    \"sz=192\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"save_gen_images()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Pretrain Critic\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"bs=16\\n\",\n    \"sz=192\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen=None\\n\",\n    \"gc.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)\"\n   ]\n  },\n  {\n   
\"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_critic = colorize_crit_learner(data=data_crit, nf=256).load(crit_old_checkpoint_name, with_opt=False)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_critic.callback_fns.append(partial(LearnerTensorboardWriter, base_dir=TENSORBOARD_PATH, name='CriticPre'))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_critic.fit_one_cycle(4, 1e-4)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_critic.save(crit_new_checkpoint_name)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### GAN\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_crit=None\\n\",\n    \"learn_gen=None\\n\",\n    \"gc.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"lr=5e-6\\n\",\n    \"sz=192\\n\",\n    \"bs=5\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_crit = 
colorize_crit_learner(data=data_crit, nf=256).load(crit_new_checkpoint_name, with_opt=False)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)\\n\",\n    \"learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.0,1.5), show_img=False, switcher=switcher,\\n\",\n    \"                                 opt_func=partial(optim.Adam, betas=(0.,0.9)), wd=1e-3)\\n\",\n    \"learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))\\n\",\n    \"learn.callback_fns.append(partial(GANTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GanLearner', visual_iters=100, stats_iters=10, loss_iters=1))\\n\",\n    \"learn.callback_fns.append(partial(GANSaveCallback, learn_gen=learn_gen, filename=gen_new_checkpoint_name, save_iters=100))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Instructions:  \\n\",\n    \"Find the checkpoint just before where glitches start to be introduced.  
So far this has been found at the point of iterating through 1.4% of the data when using learning rate of 1e-5, and at 2.2% of the data for 5e-6.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn.data = get_data(sz=sz, bs=bs, keep_pct=0.03)\\n\",\n    \"learn_gen.freeze_to(-1)\\n\",\n    \"learn.fit(1,lr)\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.7.6\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 4\n}\n"
  },
  {
    "path": "ColorizeTrainingWandb.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Stable Model Training with monitoring through Weights & Biases\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### NOTES:  \\n\",\n    \"* This is \\\"NoGAN\\\" based training, described in the DeOldify readme.\\n\",\n    \"* This model prioritizes stable and reliable renderings.  It does particularly well on portraits and landscapes.  It's not as colorful as the artistic model.\\n\",\n    \"* Training with this notebook has been logged and monitored through [Weights & Biases](https://www.wandb.com/). Refer to [W&B Report](https://app.wandb.ai/borisd13/DeOldify/reports?view=borisd13%2FDeOldify).\\n\",\n    \"* It is **highly** recommended to use a 11 Go GPU to run this notebook. Anything lower will require to reduce the batch size (leading to moro instability) or use of \\\"Large Model Support\\\" from IBM WML-CE (not so easy to setup). 
An alternative is to rent resources online.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Install W&B Callback\\n\",\n    \"#!pip install wandb\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"#NOTE:  This must be the first call in order to work properly!\\n\",\n    \"from deoldify import device\\n\",\n    \"from deoldify.device_id import DeviceId\\n\",\n    \"#choices:  CPU, GPU0...GPU7\\n\",\n    \"device.set(device=DeviceId.GPU0)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import os\\n\",\n    \"import fastai\\n\",\n    \"from fastai import *\\n\",\n    \"from fastai.vision import *\\n\",\n    \"from fastai.vision.gan import *\\n\",\n    \"from deoldify.generators import *\\n\",\n    \"from deoldify.critics import *\\n\",\n    \"from deoldify.dataset import *\\n\",\n    \"from deoldify.loss import *\\n\",\n    \"from deoldify.save import *\\n\",\n    \"from PIL import Image, ImageDraw, ImageFont\\n\",\n    \"from PIL import ImageFile\\n\",\n    \"from torch.utils.data.sampler import RandomSampler, SequentialSampler\\n\",\n    \"from tqdm import tqdm\\n\",\n    \"import wandb\\n\",\n    \"from wandb.fastai import WandbCallback\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Setup\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Set up W&B: checks user can connect to W&B servers\\n\",\n    \"# Note: set up API key the first time\\n\",\n    \"wandb.login()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Dataset can 
be downloaded from https://www.kaggle.com/c/imagenet-object-localization-challenge/data\\n\",\n    \"path = Path('data/imagenet/ILSVRC/Data/CLS-LOC')\\n\",\n    \"path_hr = path\\n\",\n    \"path_lr = path/'bandw'\\n\",\n    \"\\n\",\n    \"proj_id = 'StableModel'\\n\",\n    \"\\n\",\n    \"gen_name = proj_id + '_gen'\\n\",\n    \"pre_gen_name = gen_name + '_0'\\n\",\n    \"crit_name = proj_id + '_crit'\\n\",\n    \"\\n\",\n    \"name_gen = proj_id + '_image_gen'\\n\",\n    \"path_gen = path/name_gen\\n\",\n    \"\\n\",\n    \"nf_factor = 2\\n\",\n    \"pct_start = 1e-8\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Iterating through the dataset\\n\",\n    \"\\n\",\n    \"The dataset is very large and it would take a long time to iterate through all the samples at each epoch.\\n\",\n    \"\\n\",\n    \"We use custom samplers in order to limit epochs to subsets of data while still iterating slowly through the entire dataset (epoch after epoch). This lets us run the validation loop more often, which is where we log metrics as well as prediction samples on validation data.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Reduce quantity of samples per training epoch\\n\",\n    \"# Adapted from https://forums.fast.ai/t/epochs-of-arbitrary-length/27777/10\\n\",\n    \"\\n\",\n    \"@classmethod\\n\",\n    \"def create(cls, train_ds:Dataset, valid_ds:Dataset, test_ds:Optional[Dataset]=None, path:PathOrStr='.', bs:int=64,\\n\",\n    \"            val_bs:int=None, num_workers:int=defaults.cpus, dl_tfms:Optional[Collection[Callable]]=None,\\n\",\n    \"            device:torch.device=None, collate_fn:Callable=data_collate, no_check:bool=False, sampler=None, **dl_kwargs)->'DataBunch':\\n\",\n    \"    \\\"Create a `DataBunch` from `train_ds`, `valid_ds` and maybe `test_ds` with a batch size of `bs`. 
Passes `**dl_kwargs` to `DataLoader()`\\\"\\n\",\n    \"    datasets = cls._init_ds(train_ds, valid_ds, test_ds)\\n\",\n    \"    val_bs = ifnone(val_bs, bs)\\n\",\n    \"    if sampler is None: sampler = [RandomSampler] + 3*[SequentialSampler]\\n\",\n    \"    dls = [DataLoader(d, b, sampler=sa(d), drop_last=sh, num_workers=num_workers, **dl_kwargs) for d,b,sh,sa in\\n\",\n    \"            zip(datasets, (bs,val_bs,val_bs,val_bs), (True,False,False,False), sampler) if d is not None]\\n\",\n    \"    return cls(*dls, path=path, device=device, dl_tfms=dl_tfms, collate_fn=collate_fn, no_check=no_check)\\n\",\n    \"\\n\",\n    \"ImageDataBunch.create = create\\n\",\n    \"ImageImageList._bunch = ImageDataBunch\\n\",\n    \"\\n\",\n    \"class FixedLenRandomSampler(RandomSampler):\\n\",\n    \"    def __init__(self, data_source, epoch_size):\\n\",\n    \"        super().__init__(data_source)\\n\",\n    \"        self.epoch_size = epoch_size\\n\",\n    \"        self.not_sampled = np.array([True]*len(data_source))\\n\",\n    \"    \\n\",\n    \"    @property\\n\",\n    \"    def reset_state(self): self.not_sampled[:] = True\\n\",\n    \"        \\n\",\n    \"    def __iter__(self):\\n\",\n    \"        ns = sum(self.not_sampled)\\n\",\n    \"        idx_last = []\\n\",\n    \"        if ns >= len(self):\\n\",\n    \"            idx = np.random.choice(np.where(self.not_sampled)[0], size=len(self), replace=False).tolist()\\n\",\n    \"            if ns == len(self): self.reset_state\\n\",\n    \"        else:\\n\",\n    \"            idx_last = np.where(self.not_sampled)[0].tolist()\\n\",\n    \"            self.reset_state\\n\",\n    \"            idx = np.random.choice(np.where(self.not_sampled)[0], size=len(self)-len(idx_last), replace=False).tolist()\\n\",\n    \"        self.not_sampled[idx] = False\\n\",\n    \"        idx = [*idx_last, *idx]\\n\",\n    \"        return iter(idx)\\n\",\n    \"    \\n\",\n    \"    def __len__(self):\\n\",\n    \"        return 
self.epoch_size\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def get_data(bs:int, sz:int, keep_pct=1.0, random_seed=None, valid_pct=0.2, epoch_size=1000):\\n\",\n    \"    \\n\",\n    \"    # Create samplers\\n\",\n    \"    train_sampler = partial(FixedLenRandomSampler, epoch_size=epoch_size)\\n\",\n    \"    samplers = [train_sampler, SequentialSampler, SequentialSampler, SequentialSampler]\\n\",\n    \"\\n\",\n    \"    return get_colorize_data(sz=sz, bs=bs, crappy_path=path_lr, good_path=path_hr, random_seed=random_seed,\\n\",\n    \"                             keep_pct=keep_pct, samplers=samplers, valid_pct=valid_pct)\\n\",\n    \"\\n\",\n    \"# Function modified to allow use of custom samplers\\n\",\n    \"def get_colorize_data(sz:int, bs:int, crappy_path:Path, good_path:Path, random_seed:int=None,\\n\",\n    \"        keep_pct:float=1.0, num_workers:int=8, samplers=None, valid_pct=0.2, xtra_tfms=[])->ImageDataBunch:\\n\",\n    \"    src = (ImageImageList.from_folder(crappy_path, convert_mode='RGB')\\n\",\n    \"        .use_partial_data(sample_pct=keep_pct, seed=random_seed)\\n\",\n    \"        .split_by_rand_pct(valid_pct, seed=random_seed))\\n\",\n    \"    data = (src.label_from_func(lambda x: good_path/x.relative_to(crappy_path))\\n\",\n    \"        .transform(get_transforms(max_zoom=1.2, max_lighting=0.5, max_warp=0.25, xtra_tfms=xtra_tfms), size=sz, tfm_y=True)\\n\",\n    \"        .databunch(bs=bs, num_workers=num_workers, sampler=samplers, no_check=True)\\n\",\n    \"        .normalize(imagenet_stats, do_y=True))\\n\",\n    \"    data.c = 3\\n\",\n    \"    return data\\n\",\n    \"\\n\",\n    \"# Function to limit amount of data in critic\\n\",\n    \"def filter_data(pct=1.0):\\n\",\n    \"    def _f(fname):\\n\",\n    \"        if 'test' in str(fname):\\n\",\n    \"            if np.random.random_sample() > pct:\\n\",\n    \"                
return False\\n\",\n    \"        return True\\n\",\n    \"    return _f\\n\",\n    \"\\n\",\n    \"def get_crit_data(classes, bs, sz, pct=1.0):\\n\",\n    \"    src = ImageList.from_folder(path, include=classes, recurse=True).filter_by_func(filter_data(pct)).split_by_rand_pct(0.1)\\n\",\n    \"    ll = src.label_from_folder(classes=classes)\\n\",\n    \"    data = (ll.transform(get_transforms(max_zoom=2.), size=sz)\\n\",\n    \"           .databunch(bs=bs).normalize(imagenet_stats))\\n\",\n    \"    return data\\n\",\n    \"\\n\",\n    \"def create_training_images(fn,i):\\n\",\n    \"    dest = path_lr/fn.relative_to(path_hr)\\n\",\n    \"    dest.parent.mkdir(parents=True, exist_ok=True)\\n\",\n    \"    img = PIL.Image.open(fn).convert('LA').convert('RGB')\\n\",\n    \"    img.save(dest)  \\n\",\n    \"    \\n\",\n    \"def save_preds(dl):\\n\",\n    \"    i=0\\n\",\n    \"    names = dl.dataset.items    \\n\",\n    \"    for b in tqdm(dl):\\n\",\n    \"        preds = learn_gen.pred_batch(batch=b, reconstruct=True)\\n\",\n    \"        for o in preds:\\n\",\n    \"            o.save(path_gen/names[i].name)\\n\",\n    \"            i += 1\\n\",\n    \"    \\n\",\n    \"def save_gen_images(keep_pct):\\n\",\n    \"    if path_gen.exists(): shutil.rmtree(path_gen)\\n\",\n    \"    path_gen.mkdir(exist_ok=True)\\n\",\n    \"    data_gen = get_data(bs=bs, sz=sz, keep_pct=keep_pct)\\n\",\n    \"    save_preds(data_gen.fix_dl)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Create black and white training images\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Only runs if the directory isn't already created.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"if not path_lr.exists():\\n\",\n    \"    il = ImageList.from_folder(path_hr)\\n\",\n    \"    
parallel(create_training_images, il.items)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Number of black & white images\\n\",\n    \"data_size = len(list(path_lr.rglob('*.*')))\\n\",\n    \"print('Number of black & white images:', data_size)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Pre-train generator\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### NOTE\\n\",\n    \"Most of the training takes place here in pretraining for NoGAN.  The goal here is to take the generator as far as possible with conventional training, as that is much easier to control and obtain glitch-free results compared to GAN training.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### 64px\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Init logging of a new run\\n\",\n    \"wandb.init(tags=['Pre-train Gen'])  # tags are optional\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"bs=88\\n\",\n    \"sz=64\\n\",\n    \"\\n\",\n    \"# Define target number of training/validation samples as well as number of epochs\\n\",\n    \"epoch_train_size = 100 * bs\\n\",\n    \"epoch_valid_size = 10 * bs\\n\",\n    \"valid_pct = epoch_valid_size / data_size\\n\",\n    \"number_epochs = (data_size - epoch_valid_size) // epoch_train_size\\n\",\n    \"\\n\",\n    \"# Log hyper parameters\\n\",\n    \"wandb.config.update({\\\"Step 1 - batch size\\\": bs, \\\"Step 1 - image size\\\": sz,\\n\",\n    \"                     \\\"Step 1 - epoch size\\\": epoch_train_size, \\\"Step 1 - number epochs\\\": number_epochs})\"\n   ]\n  },\n  {\n   \"cell_type\": 
\"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"data_gen = get_data(bs=bs, sz=sz, random_seed=123, valid_pct=valid_pct, epoch_size=100*bs)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.callback_fns.append(partial(WandbCallback,\\n\",\n    \"                                      input_type='images'))  # log prediction samples\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.fit_one_cycle(number_epochs, pct_start=0.8, max_lr=slice(1e-3))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.save(pre_gen_name)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.unfreeze()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.fit_one_cycle(number_epochs, pct_start=pct_start,  max_lr=slice(3e-7, 3e-4))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.save(pre_gen_name)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### 128px\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"bs=20\\n\",\n    \"sz=128\\n\",\n 
   \"\\n\",\n    \"# Define target number of training/validation samples as well as number of epochs\\n\",\n    \"epoch_train_size = 100 * bs\\n\",\n    \"epoch_valid_size = 10 * bs\\n\",\n    \"valid_pct = epoch_valid_size / data_size\\n\",\n    \"number_epochs = (data_size - epoch_valid_size) // epoch_train_size\\n\",\n    \"\\n\",\n    \"# Log hyper parameters\\n\",\n    \"wandb.config.update({\\\"Step 2 - batch size\\\": bs, \\\"Step 2 - image size\\\": sz,\\n\",\n    \"                     \\\"Step 2 - epoch size\\\": epoch_train_size, \\\"Step 2 - number epochs\\\": number_epochs})\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.data = get_data(bs=bs, sz=sz, random_seed=123, valid_pct=valid_pct, epoch_size=100*bs)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.unfreeze()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.fit_one_cycle(number_epochs, pct_start=pct_start, max_lr=slice(1e-7,1e-4))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.save(pre_gen_name)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### 192px\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"bs=8\\n\",\n    \"sz=192\\n\",\n    \"\\n\",\n    \"# Define target number of training/validation samples as well as number of epochs\\n\",\n    \"epoch_train_size = 100 * bs\\n\",\n    \"epoch_valid_size = 10 * bs\\n\",\n    \"valid_pct = epoch_valid_size / data_size\\n\",\n    \"number_epochs = (data_size - epoch_valid_size) 
// epoch_train_size // 2  # Training is long - we use half of data\\n\",\n    \"\\n\",\n    \"# Log hyper parameters\\n\",\n    \"wandb.config.update({\\\"Step 3 - batch size\\\": bs, \\\"Step 3 - image size\\\": sz,\\n\",\n    \"                     \\\"Step 3 - epoch size\\\": epoch_train_size, \\\"Step 3 - number epochs\\\": number_epochs})\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.data = get_data(bs=bs, sz=sz, random_seed=123, valid_pct=valid_pct, epoch_size=100*bs)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.unfreeze()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.fit_one_cycle(number_epochs, pct_start=pct_start, max_lr=slice(5e-8,5e-5))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen.save(pre_gen_name)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# End logging of current session run\\n\",\n    \"# Note: this is optional and would be automatically triggered when stopping the kernel\\n\",\n    \"wandb.join()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Repeatable GAN Cycle\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### NOTE\\n\",\n    \"Best results so far have been based on repeating the cycle below a few times (about 5-8?), until diminishing returns are hit (no improvement in image quality).  
Each time you repeat the cycle, you want to increment that old_checkpoint_num by 1 so that new check points don't overwrite the old.  \"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"old_checkpoint_num = 0\\n\",\n    \"checkpoint_num = old_checkpoint_num + 1\\n\",\n    \"gen_old_checkpoint_name = gen_name + '_' + str(old_checkpoint_num)\\n\",\n    \"gen_new_checkpoint_name = gen_name + '_' + str(checkpoint_num)\\n\",\n    \"crit_old_checkpoint_name = crit_name + '_' + str(old_checkpoint_num)\\n\",\n    \"crit_new_checkpoint_name= crit_name + '_' + str(checkpoint_num)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Save Generated Images\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"bs=8\\n\",\n    \"sz=192\\n\",\n    \"\\n\",\n    \"# Define target number of training/validation samples as well as number of epochs\\n\",\n    \"epoch_train_size = 100 * bs\\n\",\n    \"epoch_valid_size = 10 * bs\\n\",\n    \"valid_pct = epoch_valid_size / data_size\\n\",\n    \"number_epochs = (data_size - epoch_valid_size) // epoch_train_size\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"data_gen = get_data(bs=bs, sz=sz, random_seed=123, valid_pct=valid_pct, epoch_size=100*bs)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"save_gen_images(0.1)\"\n   ]\n  },\n  
{\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Pretrain Critic\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"##### Only need full pretraining of critic when starting from scratch.  Otherwise, just finetune!\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"if old_checkpoint_num == 0:\\n\",\n    \"    \\n\",\n    \"    # Init logging of a new run\\n\",\n    \"    wandb.init(tags=['Pre-train Crit'])  # tags are optional\\n\",\n    \"    \\n\",\n    \"    bs=64\\n\",\n    \"    sz=128\\n\",\n    \"    learn_gen=None\\n\",\n    \"    \\n\",\n    \"    # Log hyper parameters\\n\",\n    \"    wandb.config.update({\\\"Step 1 - batch size\\\": bs, \\\"Step 1 - image size\\\": sz})\\n\",\n    \"\\n\",\n    \"    gc.collect()    \\n\",\n    \"    data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)\\n\",\n    \"    data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)\\n\",\n    \"    learn_crit = colorize_crit_learner(data=data_crit, nf=256)\\n\",\n    \"    learn_crit.callback_fns.append(partial(WandbCallback))  # log prediction samples\\n\",\n    \"    learn_crit.fit_one_cycle(6, 1e-3)\\n\",\n    \"    learn_crit.save(crit_old_checkpoint_name)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"bs=16\\n\",\n    \"sz=192\\n\",\n    \"\\n\",\n    \"# Log hyper parameters\\n\",\n    \"wandb.config.update({\\\"Step 2 - batch size\\\": bs, \\\"Step 2 - image size\\\": sz})\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   
\"outputs\": [],\n   \"source\": [\n    \"data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_crit = colorize_crit_learner(data=data_crit, nf=256).load(crit_old_checkpoint_name, with_opt=False)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_crit.fit_one_cycle(4, 1e-4)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"learn_crit.save(crit_new_checkpoint_name)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### GAN\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# free up memory\\n\",\n    \"learn_crit=None\\n\",\n    \"learn_gen=None\\n\",\n    \"learn=None\\n\",\n    \"gc.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Set old_checkpoint_num to last iteration\\n\",\n    \"old_checkpoint_num = 0\\n\",\n    \"save_checkpoints = False\\n\",\n    \"batch_per_epoch = 200\\n\",\n    \"\\n\",\n    \"checkpoint_num = old_checkpoint_num + 1\\n\",\n    \"gen_old_checkpoint_name = gen_name + '_' + str(old_checkpoint_num)\\n\",\n    \"gen_new_checkpoint_name = gen_name + '_' + str(checkpoint_num)\\n\",\n    \"crit_old_checkpoint_name = crit_name + '_' + str(old_checkpoint_num)\\n\",\n    \"crit_new_checkpoint_name= crit_name + '_' + str(checkpoint_num)   \\n\",\n    \"\\n\",\n    \"if False:   # need only to do it once\\n\",\n    \"        \\n\",\n    \"    # Generate data\\n\",\n    \"    print('Generating data…')\\n\",\n    \"    bs=8\\n\",\n    \"    
sz=192\\n\",\n    \"    epoch_train_size = batch_per_epoch * bs\\n\",\n    \"    epoch_valid_size = batch_per_epoch * bs // 10\\n\",\n    \"    valid_pct = epoch_valid_size / data_size\\n\",\n    \"    data_gen = get_data(bs=bs, sz=sz, epoch_size=epoch_train_size, valid_pct=valid_pct)\\n\",\n    \"    learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)\\n\",\n    \"    save_gen_images(0.02)\\n\",\n    \"\\n\",\n    \"    # Pre-train critic\\n\",\n    \"    print('Pre-training critic…')\\n\",\n    \"    bs=16\\n\",\n    \"    sz=192\\n\",\n    \"\\n\",\n    \"    len_test = len(list((path / 'test').rglob('*.*')))\\n\",\n    \"    len_gen = len(list((path / name_gen).rglob('*.*')))\\n\",\n    \"    keep_test_pct = len_gen / len_test * 2\\n\",\n    \"\\n\",\n    \"    data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz, pct=keep_test_pct)\\n\",\n    \"    learn_crit = colorize_crit_learner(data=data_crit, nf=256).load(crit_old_checkpoint_name, with_opt=False)\\n\",\n    \"    learn_crit.fit_one_cycle(1, 1e-4)\\n\",\n    \"    learn_crit.save(crit_new_checkpoint_name)\\n\",\n    \"\\n\",\n    \"# Creating GAN\\n\",\n    \"print('Creating GAN…')\\n\",\n    \"sz=192\\n\",\n    \"bs=8\\n\",\n    \"lr_GAN=2e-5\\n\",\n    \"epoch_train_size = batch_per_epoch * bs\\n\",\n    \"epoch_valid_size = batch_per_epoch * bs // 10\\n\",\n    \"valid_pct = epoch_valid_size / data_size\\n\",\n    \"len_test = len(list((path / 'test').rglob('*.*')))\\n\",\n    \"len_gen = len(list((path / name_gen).rglob('*.*')))\\n\",\n    \"keep_test_pct = len_gen / len_test * 2\\n\",\n    \"\\n\",\n    \"data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz, pct=keep_test_pct)\\n\",\n    \"learn_crit = colorize_crit_learner(data=data_crit, nf=256).load(crit_new_checkpoint_name, with_opt=False)\\n\",\n    \"data_gen = get_data(bs=bs, sz=sz, epoch_size=epoch_train_size, valid_pct=valid_pct)\\n\",\n    
\"learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)\\n\",\n    \"switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)\\n\",\n    \"learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.0,1.5), show_img=False, switcher=switcher,\\n\",\n    \"                                 opt_func=partial(optim.Adam, betas=(0.,0.9)), wd=1e-3)\\n\",\n    \"learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))\\n\",\n    \"learn.callback_fns.append(partial(WandbCallback, input_type='images', seed=None, save_model=False))\\n\",\n    \"learn.data = get_data(bs=bs, sz=sz, epoch_size=epoch_train_size, valid_pct=valid_pct)\\n\",\n    \"\\n\",\n    \"# Start logging to W&B\\n\",\n    \"wandb.init(tags=['GAN'])\\n\",\n    \"wandb.config.update({\\\"learning rate\\\": lr_GAN})  \\n\",\n    \"\\n\",\n    \"# Run the loop until satisfied with the results\\n\",\n    \"while True:\\n\",\n    \"\\n\",\n    \"    # Current loop\\n\",\n    \"    checkpoint_num = old_checkpoint_num + 1\\n\",\n    \"    gen_old_checkpoint_name = gen_name + '_' + str(old_checkpoint_num)\\n\",\n    \"    gen_new_checkpoint_name = gen_name + '_' + str(checkpoint_num)\\n\",\n    \"    crit_old_checkpoint_name = crit_name + '_' + str(old_checkpoint_num)\\n\",\n    \"    crit_new_checkpoint_name= crit_name + '_' + str(checkpoint_num)      \\n\",\n    \"    \\n\",\n    \"    \\n\",\n    \"    # GAN for 10 epochs between each checkpoint\\n\",\n    \"    try:\\n\",\n    \"        learn.fit(1, lr_GAN)\\n\",\n    \"    except:\\n\",\n    \"        # Sometimes we get an error for some unknown reason during callbacks\\n\",\n    \"        learn.callback_fns[-1](learn).on_epoch_end(old_checkpoint_num, None, [])\\n\",\n    \"        \\n\",\n    \"    if save_checkpoints:\\n\",\n    \"        learn_crit.save(crit_new_checkpoint_name)\\n\",\n    \"        learn_gen.save(gen_new_checkpoint_name)\\n\",\n    \"    
\\n\",\n    \"    old_checkpoint_num += 1\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# End logging of current session run\\n\",\n    \"# Note: this is optional and would be automatically triggered when stopping the kernel\\n\",\n    \"wandb.join()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.7.6\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 4\n}\n"
  },
  {
    "path": "ImageColorizer.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"#NOTE:  This must be the first call in order to work properly!\\n\",\n    \"from deoldify import device\\n\",\n    \"from deoldify.device_id import DeviceId\\n\",\n    \"#choices:  CPU, GPU0...GPU7\\n\",\n    \"device.set(device=DeviceId.GPU0)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from deoldify.visualize import *\\n\",\n    \"plt.style.use('dark_background')\\n\",\n    \"torch.backends.cudnn.benchmark=True\\n\",\n    \"import warnings\\n\",\n    \"warnings.filterwarnings(\\\"ignore\\\", category=UserWarning, message=\\\".*?Your .*? set is empty.*?\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"NOTE:  Set artistic to False if you're having trouble getting a good render.  Chances are it will work with the Stable model. \"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"colorizer = get_image_colorizer(artistic=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Instructions\\n\",\n    \"\\n\",\n    \"### source_url\\n\",\n    \"Type in a url to a direct link of an image.  Usually that means they'll end in .png, .jpg, etc.  NOTE: If you want to use your own image, you can set source_url to None and just upload the image to /test_images/ in Jupyter.  Just make sure that the source_path parameter matches the file you uploaded.\\n\",\n    \"\\n\",\n    \"### source_path\\n\",\n    \"Name this whatever sensible image path (plus extension of jpg/png/ext) you want!  
Sensible means the path exists and the file exists if source_url=None.\\n\",\n    \"\\n\",\n    \"### render_factor\\n\",\n    \"The default value of 35 has been carefully chosen and should work -ok- for most scenarios (but probably won't be the -best-). This determines resolution at which the color portion of the image is rendered. Lower resolution will render faster, and colors also tend to look more vibrant. Older and lower quality images in particular will generally benefit by lowering the render factor. Higher render factors are often better for higher quality images, but the colors may get slightly washed out. \\n\",\n    \"\\n\",\n    \"### result_path\\n\",\n    \"Ditto- don't change.\\n\",\n    \"\\n\",\n    \"### How to Download a Copy\\n\",\n    \"Simply shift+right click on the displayed image and click \\\"Save Image As...\\\"!\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"## Pro Tips\\n\",\n    \"1. You can evaluate how well the image is rendered at each render_factor by using the code at the bottom (that cell under \\\"See how well render_factor values perform on a frame here\\\"). \\n\",\n    \"2. Keep in mind again that you can go up top and set artistic to False for the colorizer to use the 'Stable' model instead.  This will often tend to do better on portraits, and natural landscapes.  \\n\",\n    \"\\n\",\n    \"\\n\",\n    \"## Troubleshooting\\n\",\n    \"If you get a 'CUDA out of memory' error, you probably have the render_factor too high.  The max is 45 on 11GB video cards.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Colorize!!\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"#NOTE:  Max is 45 with 11GB video cards. 
35 is a good default\\n\",\n    \"render_factor=35\\n\",\n    \"#NOTE:  Make source_url None to just read from file at source_path directly without modification\\n\",\n    \"source_url='https://upload.wikimedia.org/wikipedia/commons/e/e4/Raceland_Louisiana_Beer_Drinkers_Russell_Lee.jpg'\\n\",\n    \"source_path = 'test_images/image.png'\\n\",\n    \"result_path = None\\n\",\n    \"\\n\",\n    \"if source_url is not None:\\n\",\n    \"    result_path = colorizer.plot_transformed_image_from_url(url=source_url, path=source_path, render_factor=render_factor, compare=True)\\n\",\n    \"else:\\n\",\n    \"    result_path = colorizer.plot_transformed_image(path=source_path, render_factor=render_factor, compare=True)\\n\",\n    \"\\n\",\n    \"show_image_in_notebook(result_path)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## See how well render_factor values perform on the image here\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"#for i in range(10,46):\\n\",\n    \"    #colorizer.plot_transformed_image(source_path, render_factor=i, display_render_factor=True, figsize=(10,10))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.7.6\"\n  },\n  \"toc\": {\n   \"colors\": {\n    \"hover_highlight\": \"#DAA520\",\n    \"navigate_num\": \"#000000\",\n    \"navigate_text\": 
\"#333333\",\n    \"running_highlight\": \"#FF0000\",\n    \"selected_highlight\": \"#FFD700\",\n    \"sidebar_border\": \"#EEEEEE\",\n    \"wrapper_background\": \"#FFFFFF\"\n   },\n   \"moveMenuLeft\": true,\n   \"nav_menu\": {\n    \"height\": \"67px\",\n    \"width\": \"252px\"\n   },\n   \"navigate_menu\": true,\n   \"number_sections\": true,\n   \"sideBar\": true,\n   \"threshold\": 4,\n   \"toc_cell\": false,\n   \"toc_section_display\": \"block\",\n   \"toc_window_display\": false,\n   \"widenNotebook\": false\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 4\n}\n"
  },
  {
    "path": "ImageColorizerArtisticTests.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"#NOTE:  This must be the first call in order to work properly!\\n\",\n    \"from deoldify import device\\n\",\n    \"from deoldify.device_id import DeviceId\\n\",\n    \"#choices:  CPU, GPU0...GPU7\\n\",\n    \"device.set(device=DeviceId.GPU0)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from deoldify.visualize import *\\n\",\n    \"plt.style.use('dark_background')\\n\",\n    \"import warnings\\n\",\n    \"warnings.filterwarnings(\\\"ignore\\\", category=UserWarning, message=\\\".*?Your .*? set is empty.*?\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"#Adjust render_factor (int) if image doesn't look quite right (max 64 on 11GB GPU).  The default here works for most photos.  \\n\",\n    \"#It literally just is a number multiplied by 16 to get the square render resolution.  \\n\",\n    \"#Note that this doesn't affect the resolution of the final output- the output is the same resolution as the input.\\n\",\n    \"#Example:  render_factor=21 => color is rendered at 16x21 = 336x336 px.  
\\n\",\n    \"render_factor=35\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis = get_image_colorizer(render_factor=render_factor, artistic=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/poolparty.jpg\\\", render_factor=38, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1852GatekeepersWindsor.jpg\\\", render_factor=45, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Chief.jpg\\\", render_factor=14, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1850SchoolForGirls.jpg\\\", render_factor=46, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/AtlanticCityBeach1905.jpg\\\", render_factor=30, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/CottonMillWorkers1913.jpg\\\", render_factor=45, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/BrooklynNavyYardHospital.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   
\"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/FinnishPeasant1867.jpg\\\", render_factor=30, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/AtlanticCity1905.png\\\", render_factor=25, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/PushingCart.jpg\\\", render_factor=21, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Drive1905.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/IronLung.png\\\", render_factor=21, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/FamilyWithDog.jpg\\\", render_factor=21, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/DayAtSeaBelgium.jpg\\\", render_factor=30, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/marilyn_woods.jpg\\\", render_factor=29, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n  
 \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/OldWomanSweden1904.jpg\\\", render_factor=36, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/WomenTapingPlanes.jpg\\\", render_factor=32, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/overmiller.jpg\\\", render_factor=13, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/BritishDispatchRider.jpg\\\", render_factor=19, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/MuseauNacionalDosCoches.jpg\\\", render_factor=17, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/abe.jpg\\\", render_factor=15, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/RossCorbettHouseCork.jpg\\\", render_factor=30, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/HPLabelleOfficeMontreal.jpg\\\", render_factor=40, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   
\"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/einstein_beach.jpg\\\", render_factor=29, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/airmen1943.jpg\\\", render_factor=25, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/20sWoman.jpg\\\", render_factor=22, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/egypt-1.jpg\\\", render_factor=15, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Rutherford_Hayes.jpg\\\", render_factor=15, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/einstein_portrait.jpg\\\", render_factor=15, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/pinkerton.jpg\\\", render_factor=13, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/WaltWhitman.jpg\\\", render_factor=12, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   
\"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/dorothea-lange.jpg\\\", render_factor=25, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Hemmingway2.jpg\\\", render_factor=15, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/hemmingway.jpg\\\", render_factor=9, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/smoking_kid.jpg\\\", render_factor=30, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/teddy_rubble.jpg\\\", render_factor=45, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/dustbowl_2.jpg\\\", render_factor=16, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/camera_man.jpg\\\", render_factor=23, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/migrant_mother.jpg\\\", render_factor=35, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    
\"vis.plot_transformed_image(\\\"test_images/marktwain.jpg\\\", render_factor=10, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/HelenKeller.jpg\\\", render_factor=45, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Evelyn_Nesbit.jpg\\\", render_factor=21, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Eddie-Adams.jpg\\\", render_factor=22, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/soldier_kids.jpg\\\", render_factor=18, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/AnselAdamsYosemite.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/unnamed.jpg\\\", render_factor=40, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/workers_canyon.jpg\\\", render_factor=48, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/CottonMill.jpg\\\", 
render_factor=16, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/JudyGarland.jpeg\\\", render_factor=25, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/kids_pit.jpg\\\", render_factor=35, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/last_samurai.jpg\\\", render_factor=15, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/AnselAdamsWhiteChurch.jpg\\\", render_factor=21, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/opium.jpg\\\", render_factor=30, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/dorothea_lange_2.jpg\\\", render_factor=22, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/rgs.jpg\\\", render_factor=46, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/wh-auden.jpg\\\", render_factor=24, compare=True)\"\n   ]\n  },\n  {\n   
\"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/w-b-yeats.jpg\\\", render_factor=16, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/marilyn_portrait.jpg\\\", render_factor=30, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/wilson-slaverevivalmeeting.jpg\\\", render_factor=38, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/ww1_trench.jpg\\\", render_factor=18, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/women-bikers.png\\\", render_factor=47, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Unidentified1855.jpg\\\", render_factor=32, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/skycrapper_lunch.jpg\\\", render_factor=32, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/sioux.jpg\\\", render_factor=35, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   
\"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/school_kids.jpg\\\", render_factor=26, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/royal_family.jpg\\\", render_factor=33, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/redwood_lumberjacks.jpg\\\", render_factor=47, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/poverty.jpg\\\", render_factor=26, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/paperboy.jpg\\\", render_factor=40, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/NativeAmericans.jpg\\\", render_factor=22, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/helmut_newton-.jpg\\\", render_factor=43, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Greece1911.jpg\\\", render_factor=26, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": 
{},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/FatMenClub.jpg\\\", render_factor=35, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/EgyptColosus.jpg\\\", render_factor=35, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/egypt-2.jpg\\\", render_factor=22, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/dustbowl_sd.jpg\\\", render_factor=12, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/dustbowl_people.jpg\\\", render_factor=24, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/dustbowl_5.jpg\\\", render_factor=18, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/dustbowl_1.jpg\\\", render_factor=15, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/DriveThroughGiantTree.jpg\\\", render_factor=39, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    
\"vis.plot_transformed_image(\\\"test_images/covered-wagons-traveling.jpg\\\", render_factor=18, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/civil-war_2.jpg\\\", render_factor=12, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/civil_war_4.jpg\\\", render_factor=15, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/civil_war_3.jpg\\\", render_factor=46, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/civil_war.jpg\\\", render_factor=45, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/BritishSlum.jpg\\\", render_factor=45, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/bicycles.jpg\\\", render_factor=33, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/brooklyn_girls_1940s.jpg\\\", render_factor=35, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    
\"vis.plot_transformed_image(\\\"test_images/40sCouple.jpg\\\", render_factor=20, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1946Wedding.jpg\\\", render_factor=30, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Dolores1920s.jpg\\\", render_factor=35, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/TitanicGym.jpg\\\", render_factor=31, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/FrenchVillage1950s.jpg\\\", render_factor=38, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/ClassDivide1930sBrittain.jpg\\\", render_factor=30, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1870sSphinx.jpg\\\", render_factor=15, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1890Surfer.png\\\", render_factor=30, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  },\n  {\n   \"cell_type\": \"code\",\n  
 \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/TV1930s.jpg\\\", render_factor=30, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1864UnionSoldier.jpg\\\", render_factor=13, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1890sMedStudents.jpg\\\", render_factor=23, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/BellyLaughWWI.jpg\\\", render_factor=13, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/PiggyBackRide.jpg\\\", render_factor=20, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/HealingTree.jpg\\\", render_factor=13, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/ManPile.jpg\\\", render_factor=30, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1910Bike.jpg\\\", 
render_factor=20, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/FreeportIL.jpg\\\", render_factor=36, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/DutchBabyCoupleEllis.jpg\\\", render_factor=25, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/InuitWoman1903.png\\\", render_factor=33, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1920sDancing.jpg\\\", render_factor=16, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/AirmanDad.jpg\\\", render_factor=16, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1910Racket.png\\\", render_factor=34, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1880Paris.jpg\\\", render_factor=30, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Deadwood1860s.jpg\\\", render_factor=38, compare=True)\"\n   ]\n  
},\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1860sSamauris.jpg\\\", render_factor=34, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/LondonUnderground1860.jpg\\\", render_factor=40, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Mid1800sSisters.jpg\\\", render_factor=22, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1860Girls.jpg\\\", render_factor=45, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/SanFran1851.jpg\\\", render_factor=22, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Kabuki1870s.png\\\", render_factor=25, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Mormons1870s.jpg\\\", render_factor=47, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/EgyptianWomenLate1800s.jpg\\\", render_factor=7, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": 
\"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/PicadillyLate1800s.jpg\\\", render_factor=46, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/SutroBaths1880s.jpg\\\", render_factor=18, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1880sBrooklynBridge.jpg\\\", render_factor=18, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/ChinaOpiumc1880.jpg\\\", render_factor=43, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Locomotive1880s.jpg\\\", render_factor=10, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/ViennaBoys1880s.png\\\", render_factor=19, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/VictorianDragQueen1880s.png\\\", render_factor=13, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Sami1880s.jpg\\\", render_factor=39, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   
\"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Ballet1890Russia.jpg\\\", render_factor=32, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Rottindean1890s.png\\\", render_factor=22, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1890sPingPong.jpg\\\", render_factor=15, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/London1937.png\\\", render_factor=36, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Harlem1932.jpg\\\", render_factor=27, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/OregonTrail1870s.jpg\\\", render_factor=25, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/EasterNyc1911.jpg\\\", render_factor=20, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1899NycBlizzard.jpg\\\", render_factor=20, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   
\"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Edinburgh1920s.jpg\\\", render_factor=21, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1890sShoeShopOhio.jpg\\\", render_factor=46, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1890sTouristsEgypt.png\\\", render_factor=40, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1938Reading.jpg\\\", render_factor=27, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1850Geography.jpg\\\", render_factor=22, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1901Electrophone.jpg\\\", render_factor=7, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Texas1938Woman.png\\\", render_factor=35, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/MaioreWoman1895NZ.jpg\\\", render_factor=43, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   
\"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/WestVirginiaHouse.jpg\\\", render_factor=30, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1920sGuadalope.jpg\\\", render_factor=33, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1909Chicago.jpg\\\", render_factor=14, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1920sFarmKid.jpg\\\", render_factor=12, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/ParisLate1800s.jpg\\\", render_factor=18, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1900sDaytonaBeach.png\\\", render_factor=24, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1930sGeorgia.jpg\\\", render_factor=17, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/NorwegianBride1920s.jpg\\\", render_factor=40, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   
\"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Depression.jpg\\\", render_factor=15, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1888Slum.jpg\\\", render_factor=32, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/LivingRoom1920Sweden.jpg\\\", render_factor=46, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1896NewsBoyGirl.jpg\\\", render_factor=21, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/PetDucks1927.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1899SodaFountain.jpg\\\", render_factor=46, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/TimesSquare1955.jpg\\\", render_factor=42, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/PuppyGify.jpg\\\", render_factor=22, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    
\"vis.plot_transformed_image(\\\"test_images/1890CliffHouseSF.jpg\\\", render_factor=30, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1908FamilyPhoto.jpg\\\", render_factor=35, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1900sSaloon.jpg\\\", render_factor=30, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1890BostonHospital.jpg\\\", render_factor=19, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1870Girl.jpg\\\", render_factor=9, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/AustriaHungaryWomen1890s.jpg\\\", render_factor=15, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Shack.jpg\\\",render_factor=43, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Apsaroke1908.png\\\", render_factor=15, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    
\"vis.plot_transformed_image(\\\"test_images/1948CarsGrandma.jpg\\\", render_factor=14, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/PlanesManhattan1931.jpg\\\", render_factor=11, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/WorriedKid1940sNyc.jpg\\\", render_factor=25, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1920sFamilyPhoto.jpg\\\", render_factor=13, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/CatWash1931.jpg\\\", render_factor=34, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1940sBeerRiver.jpg\\\", render_factor=46, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/VictorianLivingRoom.jpg\\\", render_factor=47, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1897BlindmansBluff.jpg\\\", render_factor=23, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    
\"vis.plot_transformed_image(\\\"test_images/1874Mexico.png\\\", render_factor=25, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/MadisonSquare1900.jpg\\\", render_factor=45, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1867MusicianConstantinople.jpg\\\", render_factor=11, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1925Girl.jpg\\\", render_factor=20, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1907Cowboys.jpg\\\", render_factor=22, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/WWIIPeeps.jpg\\\", render_factor=26, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/BabyBigBoots.jpg\\\", render_factor=17, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1895BikeMaidens.jpg\\\", render_factor=8, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    
\"vis.plot_transformed_image(\\\"test_images/IrishLate1800s.jpg\\\", render_factor=13, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/LibraryOfCongress1910.jpg\\\", render_factor=33, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1875Olds.jpg\\\", render_factor=15, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/SenecaNative1908.jpg\\\", render_factor=22, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/WWIHospital.jpg\\\", render_factor=40, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1892WaterLillies.jpg\\\", render_factor=45, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/GreekImmigrants1905.jpg\\\", render_factor=25, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/FatMensShop.jpg\\\", render_factor=24, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    
\"vis.plot_transformed_image(\\\"test_images/KidCage1930s.png\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/FarmWomen1895.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/NewZealand1860s.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/JerseyShore1905.jpg\\\", render_factor=43, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/LondonKidsEarly1900s.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/NYStreetClean1906.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Boston1937.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Cork1905.jpg\\\", render_factor=37, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/BoxedBedEarly1900s.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   
\"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/ZoologischerGarten1898.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/EmpireState1930.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Agamemnon1919.jpg\\\", render_factor=40, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/AppalachianLoggers1901.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/WWISikhs.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/MementoMori1865.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/RepBrennanRadio1922.jpg\\\", render_factor=43, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Late1800sNative.jpg\\\", render_factor=20, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    
\"vis.plot_transformed_image(\\\"test_images/GasPrices1939.jpg\\\", render_factor=30, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1933RockefellerCenter.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Scotland1919.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1920CobblersShopLondon.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1909ParisFirstFemaleTaxisDriver.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/HoovervilleSeattle1932.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/ElephantLondon1934.png\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Jane_Addams.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/AnselAdamsAdobe.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   
\"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/CricketLondon1930.jpg\\\", render_factor=45, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Donegal1907Yarn.jpg\\\", render_factor=32, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/AnselAdamsChurch.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/BreadDelivery1920sIreland.jpg\\\", render_factor=20, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/BritishTeaBombay1890s.png\\\", render_factor=30, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/CafeParis1928.jpg\\\", render_factor=45, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/BigManTavern1908NYC.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Cars1890sIreland.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 
null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/GalwayIreland1902.jpg\\\", render_factor=47, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/HomeIreland1924.jpg\\\", render_factor=40, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/HydeParkLondon1920s.jpg\\\", render_factor=30, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1929LondonOverFleetSt.jpg\\\", render_factor=25, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/AccordianKid1900Paris.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/AnselAdamsBuildings.jpg\\\", render_factor=45, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/AthleticClubParis1913.jpg\\\", render_factor=42, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/BombedLibraryLondon1940.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": 
{},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Boston1937.jpg\\\", render_factor=30, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/BoulevardDuTemple1838.jpg\\\", render_factor=25, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/BumperCarsParis1930.jpg\\\", render_factor=25, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/CafeTerrace1925Paris.jpg\\\", render_factor=35, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/CoalDeliveryParis1915.jpg\\\", render_factor=37, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/CorkKids1910.jpg\\\", render_factor=32, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/DeepSeaDiver1915.png\\\", render_factor=16, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/EastEndLondonStreetKids1901.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   
\"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/FreightTrainTeens1934.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/HarrodsLondon1920.jpg\\\", render_factor=45, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/HerbSeller1899Paris.jpg\\\", render_factor=17, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/CalcuttaPoliceman1920.jpg\\\", render_factor=20, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/ElectricScooter1915.jpeg\\\", render_factor=20, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/GreatGrandparentsIrelandEarly1900s.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/HalloweenEarly1900s.jpg\\\", render_factor=11, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/IceManLondon1919.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   
\"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/LeBonMarcheParis1875.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/LittleAirplane1934.jpg\\\", render_factor=47, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/RoyalUniversityMedStudent1900Ireland.jpg\\\", render_factor=24, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/LewisTomalinLondon1895.png\\\", render_factor=35, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/SunHelmetsLondon1933.jpg\\\", render_factor=40, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Killarney1910.jpg\\\", render_factor=45, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/LondonSheep1920s.png\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/PostOfficeVermont1914.png\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    
\"vis.plot_transformed_image(\\\"test_images/ServantsBessboroughHouse1908Ireland.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/WaterfordIreland1909.jpg\\\", render_factor=47, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Lisbon1919.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/London1918WartimeClothesManufacture.jpg\\\", render_factor=45, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/LondonHeatWave1935.png\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/LondonsSmallestShop1900.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/MetropolitanDistrictRailway1869London.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/NativeWoman1926.jpg\\\", render_factor=43, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    
\"vis.plot_transformed_image(\\\"test_images/PaddysMarketCork1900s.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"for i in range(10,40,2):\\n\",\n    \"    vis.plot_transformed_image(\\\"test_images/PaddysMarketCork1900s.jpg\\\", render_factor=i, display_render_factor=True, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Paris1920Cart.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/ParisLadies1910.jpg\\\", render_factor=38, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/ParisLadies1930s.jpg\\\", render_factor=18, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Sphinx.jpeg\\\") \"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/TheatreGroupBombay1875.jpg\\\", render_factor=45, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/WorldsFair1900Paris.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/London1850Coach.jpg\\\", render_factor=25, 
compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/London1900EastEndBlacksmith.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/London1930sCheetah.jpg\\\", render_factor=42, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/LondonFireBrigadeMember1926.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/LondonGarbageTruck1910.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/LondonRailwayWork1931.jpg\\\", render_factor=45, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/LondonStreets1900.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/MuffinManlLondon1910.jpg\\\", render_factor=40, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/NativeCouple1912.jpg\\\", render_factor=21, compare=True)\"\n   ]\n  },\n  {\n   
\"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/NewspaperCivilWar1863.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/PaddingtonStationLondon1907.jpg\\\", render_factor=45, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Paris1899StreetDig.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Paris1926.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/ParisWomenFurs1920s.jpg\\\", render_factor=15, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/PeddlerParis1899.jpg\\\", render_factor=35, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/SchoolKidsConnemaraIreland1901.jpg\\\", render_factor=18, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/SecondHandClothesLondonLate1800s.jpg\\\", render_factor=44, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": 
\"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/SoapBoxRacerParis1920s.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/SoccerMotorcycles1923London.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/WalkingLibraryLondon1930.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/LondonStreetDoctor1877.png\\\", render_factor=19, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/jacksonville.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/ZebraCarriageLondon1900.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/StreetGramaphonePlayerLondon1920s.png\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/YaleBranchBarnardsExpress.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": 
[],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/SynagogueInterior.PNG\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/ArmisticeDay1918.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/FlyingMachinesParis1909.jpg\\\", render_factor=25, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/GreatAunt1920.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/NewBrunswick1915.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/ShoeMakerLate1800s.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/SpottedBull1908.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/TouristsGermany1904.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/TunisianStudents1914.jpg\\\", compare=True)\"\n   ]\n  },\n  
{\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Yorktown1862.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/LondonFashion1911.png\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1939GypsyKids.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1936OpiumShanghai.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1923HollandTunnel.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1939YakimaWAGirl.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/GoldenGateConstruction.jpg\\\", render_factor=45, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/PostCivilWarAncestors.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    
\"vis.plot_transformed_image(\\\"test_images/1939SewingBike.png\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1930MaineSchoolBus.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1913NewYorkConstruction.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1945HiroshimaChild.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1941GeorgiaFarmhouse.jpg\\\", render_factor=47, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1934UmbriaItaly.jpg\\\", render_factor=21, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1900sLadiesTeaParty.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1919WWIAviationOxygenMask.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1900NJThanksgiving.jpg\\\", compare=True)\"\n   ]\n  
},\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1940Connecticut.jpg\\\", render_factor=42, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1911ThanksgivingMaskers.jpg\\\", render_factor=36, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1910ThanksgivingMaskersII.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1936PetToad.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1908RookeriesLondon.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1890sChineseImmigrants.jpg\\\", render_factor=36, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1897VancouverAmberlamps.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1929VictorianCosplayLondon.jpg\\\", render_factor=30, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   
\"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1959ParisFriends.png\\\", render_factor=45, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1925GypsyCampMaryland.jpg\\\", render_factor=45, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1941PoolTableGeorgia.jpg\\\", render_factor=47, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1900ParkDog.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1886Hoop.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1950sLondonPoliceChild.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1886ProspectPark.jpg\\\", render_factor=45, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1930sRooftopPoland.jpg\\\", render_factor=37, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": 
[],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1919RevereBeach.jpg\\\", render_factor=20, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1936ParisCafe.jpg\\\", render_factor=47, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1902FrenchYellowBellies.jpg\\\", render_factor=35, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1940PAFamily.jpg\\\", render_factor=34, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1910Finland.jpg\\\", render_factor=40, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/ZebraCarriageLondon1900.jpg\\\", render_factor=21, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1904ChineseMan.jpg\\\", render_factor=14, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/CrystalPalaceLondon1854.PNG\\\", render_factor=15, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   
\"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/James1.jpg\\\", render_factor=15, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/James2.jpg\\\", render_factor=20, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/James3.jpg\\\", render_factor=19, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/James4.jpg\\\", render_factor=30, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/James5.jpg\\\", render_factor=32, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/James6.jpg\\\", render_factor=28, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": 
\"ipython3\",\n   \"version\": \"3.7.6\"\n  },\n  \"toc\": {\n   \"colors\": {\n    \"hover_highlight\": \"#DAA520\",\n    \"navigate_num\": \"#000000\",\n    \"navigate_text\": \"#333333\",\n    \"running_highlight\": \"#FF0000\",\n    \"selected_highlight\": \"#FFD700\",\n    \"sidebar_border\": \"#EEEEEE\",\n    \"wrapper_background\": \"#FFFFFF\"\n   },\n   \"moveMenuLeft\": true,\n   \"nav_menu\": {\n    \"height\": \"67px\",\n    \"width\": \"252px\"\n   },\n   \"navigate_menu\": true,\n   \"number_sections\": true,\n   \"sideBar\": true,\n   \"threshold\": 4,\n   \"toc_cell\": false,\n   \"toc_section_display\": \"block\",\n   \"toc_window_display\": false,\n   \"widenNotebook\": false\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 4\n}\n"
  },
  {
    "path": "ImageColorizerColab.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"colab_type\": \"text\",\n    \"id\": \"view-in-github\"\n   },\n   \"source\": [\n    \"<a href=\\\"https://colab.research.google.com/github/jantic/DeOldify/blob/master/ImageColorizerColab.ipynb\\\" target=\\\"_parent\\\"><img src=\\\"https://colab.research.google.com/assets/colab-badge.svg\\\" alt=\\\"Open In Colab\\\"/></a>\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### **<font color='blue'> Artistic Colorizer </font>**\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"colab_type\": \"text\",\n    \"id\": \"663IVxfrpIAb\"\n   },\n   \"source\": [\n    \"#◢ DeOldify - Colorize your own photos!\\n\",\n    \"\\n\",\n    \"####**Credits:**\\n\",\n    \"\\n\",\n    \"Special thanks to:\\n\",\n    \"\\n\",\n    \"Matt Robinson and María Benavente for pioneering the DeOldify image colab notebook.  \\n\",\n    \"\\n\",\n    \"Dana Kelley for doing things, breaking stuff & having an opinion on everything.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"colab_type\": \"text\",\n    \"id\": \"ZjPqTBNoohK9\"\n   },\n   \"source\": [\n    \"\\n\",\n    \"\\n\",\n    \"---\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"#◢ Verify Correct Runtime Settings\\n\",\n    \"\\n\",\n    \"**<font color='#FF0000'> IMPORTANT </font>**\\n\",\n    \"\\n\",\n    \"In the \\\"Runtime\\\" menu for the notebook window, select \\\"Change runtime type.\\\" Ensure that the following are selected:\\n\",\n    \"* Runtime Type = Python 3\\n\",\n    \"* Hardware Accelerator = GPU \\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"colab_type\": \"text\",\n    \"id\": \"gaEJBGDlptEo\"\n   },\n   \"source\": [\n    \"#◢ Git clone and install DeOldify\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {},\n 
   \"colab_type\": \"code\",\n    \"id\": \"-T-svuHytJ-8\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"!git clone https://github.com/jantic/DeOldify.git DeOldify \"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"cd DeOldify\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"colab_type\": \"text\",\n    \"id\": \"BDFjbNxaadNK\"\n   },\n   \"source\": [\n    \"#◢ Setup\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {},\n    \"colab_type\": \"code\",\n    \"id\": \"00_GcC_trpdE\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"#NOTE:  This must be the first call in order to work properly!\\n\",\n    \"from deoldify import device\\n\",\n    \"from deoldify.device_id import DeviceId\\n\",\n    \"#choices:  CPU, GPU0...GPU7\\n\",\n    \"device.set(device=DeviceId.GPU0)\\n\",\n    \"\\n\",\n    \"import torch\\n\",\n    \"\\n\",\n    \"if not torch.cuda.is_available():\\n\",\n    \"    print('GPU not available.')\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {},\n    \"colab_type\": \"code\",\n    \"id\": \"Lsx7xCXNSVt6\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"!pip install -r requirements-colab.txt\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {},\n    \"colab_type\": \"code\",\n    \"id\": \"MsJa69CMwj3l\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"import fastai\\n\",\n    \"from deoldify.visualize import *\\n\",\n    \"import warnings\\n\",\n    \"warnings.filterwarnings(\\\"ignore\\\", category=UserWarning, message=\\\".*?Your .*? 
set is empty.*?\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"!mkdir 'models'\\n\",\n    \"!wget https://data.deepai.org/deoldify/ColorizeArtistic_gen.pth -O ./models/ColorizeArtistic_gen.pth\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {},\n    \"colab_type\": \"code\",\n    \"id\": \"tzHVnegp21hC\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"colorizer = get_image_colorizer(artistic=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"colab_type\": \"text\",\n    \"id\": \"BDFjbNxaadNJ\"\n   },\n   \"source\": [\n    \"#◢ Instructions\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### source_url\\n\",\n    \"Type in a url to a direct link of an image.  Usually that means they'll end in .png, .jpg, etc. NOTE: If you want to use your own image, upload it first to a site like Imgur. \\n\",\n    \"\\n\",\n    \"### render_factor\\n\",\n    \"The default value of 35 has been carefully chosen and should work -ok- for most scenarios (but probably won't be the -best-). This determines resolution at which the color portion of the image is rendered. Lower resolution will render faster, and colors also tend to look more vibrant. Older and lower quality images in particular will generally benefit by lowering the render factor. Higher render factors are often better for higher quality images, but the colors may get slightly washed out. \\n\",\n    \"\\n\",\n    \"### watermarked\\n\",\n    \"Selected by default, this places a watermark icon of a palette at the bottom left corner of the image.  This is intended to be a standard way to convey to others viewing the image that it is colorized by AI. 
We want to help promote this as a standard, especially as the technology continues to improve and the distinction between real and fake becomes harder to discern. This palette watermark practice was initiated and led by the company MyHeritage in the MyHeritage In Color feature (which uses a newer version of DeOldify than what you're using here).\\n\",\n    \"\\n\",\n    \"#### How to Download a Copy\\n\",\n    \"Simply right click on the displayed image and click \\\"Save image as...\\\"!\\n\",\n    \"\\n\",\n    \"## Pro Tips\\n\",\n    \"\\n\",\n    \"You can evaluate how well the image is rendered at each render_factor by using the code at the bottom (that cell under \\\"See how well render_factor values perform on the image here\\\"). \"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"colab_type\": \"text\",\n    \"id\": \"sUQrbSYipiJn\"\n   },\n   \"source\": [\n    \"#◢ Colorize!!\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"source_url = '' #@param {type:\\\"string\\\"}\\n\",\n    \"render_factor = 35  #@param {type: \\\"slider\\\", min: 7, max: 40}\\n\",\n    \"watermarked = True #@param {type:\\\"boolean\\\"}\\n\",\n    \"\\n\",\n    \"if source_url is not None and source_url !='':\\n\",\n    \"    image_path = colorizer.plot_transformed_image_from_url(url=source_url, render_factor=render_factor, compare=True, watermarked=watermarked)\\n\",\n    \"    show_image_in_notebook(image_path)\\n\",\n    \"else:\\n\",\n    \"    print('Provide an image url and try again.')\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## See how well render_factor values perform on the image here\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"for i in range(10,40,2):\\n\",\n    \"    
colorizer.plot_transformed_image('test_images/image.png', render_factor=i, display_render_factor=True, figsize=(8,8))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"colab_type\": \"text\",\n    \"id\": \"X7Ycv_Y9xAHp\"\n   },\n   \"source\": [\n    \"---\\n\",\n    \"#⚙ Recommended image sources \\n\",\n    \"* [/r/TheWayWeWere](https://www.reddit.com/r/TheWayWeWere/)\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"accelerator\": \"GPU\",\n  \"colab\": {\n   \"collapsed_sections\": [],\n   \"name\": \"ImageColorizerColab.ipynb\",\n   \"provenance\": [],\n   \"toc_visible\": true,\n   \"version\": \"0.3.2\"\n  },\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.7.6\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 4\n}\n"
  },
  {
    "path": "ImageColorizerColabStable.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"colab_type\": \"text\",\n    \"id\": \"view-in-github\"\n   },\n   \"source\": [\n    \"<a href=\\\"https://colab.research.google.com/github/jantic/DeOldify/blob/master/ImageColorizerColabStable.ipynb\\\" target=\\\"_parent\\\"><img src=\\\"https://colab.research.google.com/assets/colab-badge.svg\\\" alt=\\\"Open In Colab\\\"/></a>\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### **<font color='blue'> Stable Colorizer </font>**\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"colab_type\": \"text\",\n    \"id\": \"663IVxfrpIAb\"\n   },\n   \"source\": [\n    \"#◢ DeOldify - Colorize your own photos!\\n\",\n    \"\\n\",\n    \"####**Credits:**\\n\",\n    \"\\n\",\n    \"Special thanks to:\\n\",\n    \"\\n\",\n    \"Matt Robinson and María Benavente for pioneering the DeOldify image colab notebook.  \\n\",\n    \"\\n\",\n    \"Dana Kelley for doing things, breaking stuff & having an opinion on everything.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"colab_type\": \"text\",\n    \"id\": \"ZjPqTBNoohK9\"\n   },\n   \"source\": [\n    \"\\n\",\n    \"\\n\",\n    \"---\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"#◢ Verify Correct Runtime Settings\\n\",\n    \"\\n\",\n    \"**<font color='#FF0000'> IMPORTANT </font>**\\n\",\n    \"\\n\",\n    \"In the \\\"Runtime\\\" menu for the notebook window, select \\\"Change runtime type.\\\" Ensure that the following are selected:\\n\",\n    \"* Runtime Type = Python 3\\n\",\n    \"* Hardware Accelerator = GPU \\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"colab_type\": \"text\",\n    \"id\": \"gaEJBGDlptEo\"\n   },\n   \"source\": [\n    \"#◢ Git clone and install DeOldify\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": 
{},\n    \"colab_type\": \"code\",\n    \"id\": \"-T-svuHytJ-8\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"!git clone https://github.com/jantic/DeOldify.git DeOldify \"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"cd DeOldify\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"colab_type\": \"text\",\n    \"id\": \"BDFjbNxaadNK\"\n   },\n   \"source\": [\n    \"#◢ Setup\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {},\n    \"colab_type\": \"code\",\n    \"id\": \"00_GcC_trpdE\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"#NOTE:  This must be the first call in order to work properly!\\n\",\n    \"from deoldify import device\\n\",\n    \"from deoldify.device_id import DeviceId\\n\",\n    \"#choices:  CPU, GPU0...GPU7\\n\",\n    \"device.set(device=DeviceId.GPU0)\\n\",\n    \"\\n\",\n    \"import torch\\n\",\n    \"\\n\",\n    \"if not torch.cuda.is_available():\\n\",\n    \"    print('GPU not available.')\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {},\n    \"colab_type\": \"code\",\n    \"id\": \"Lsx7xCXNSVt6\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"!pip install -r requirements-colab.txt\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {},\n    \"colab_type\": \"code\",\n    \"id\": \"MsJa69CMwj3l\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"import fastai\\n\",\n    \"from deoldify.visualize import *\\n\",\n    \"\\n\",\n    \"torch.backends.cudnn.benchmark = True\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"!mkdir 'models'\\n\",\n    \"!wget 
https://www.dropbox.com/s/axsd2g85uyixaho/ColorizeStable_gen.pth?dl=0 -O ./models/ColorizeStable_gen.pth\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {},\n    \"colab_type\": \"code\",\n    \"id\": \"tzHVnegp21hC\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"colorizer = get_image_colorizer(artistic=False)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"colab_type\": \"text\",\n    \"id\": \"BDFjbNxaadNJ\"\n   },\n   \"source\": [\n    \"#◢ Instructions\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### source_url\\n\",\n    \"Type in a url to a direct link of an image.  Usually that means they'll end in .png, .jpg, etc. NOTE: If you want to use your own image, upload it first to a site like Imgur. \\n\",\n    \"\\n\",\n    \"### render_factor\\n\",\n    \"The default value of 35 has been carefully chosen and should work -ok- for most scenarios (but probably won't be the -best-). This determines resolution at which the color portion of the image is rendered. Lower resolution will render faster, and colors also tend to look more vibrant. Older and lower quality images in particular will generally benefit by lowering the render factor. Higher render factors are often better for higher quality images, but the colors may get slightly washed out. \\n\",\n    \"\\n\",\n    \"### watermarked\\n\",\n    \"Selected by default, this places a watermark icon of a palette at the bottom left corner of the image.  This is intended to be a standard way to convey to others viewing the image that it is colorized by AI. We want to help promote this as a standard, especially as the technology continues to improve and the distinction between real and fake becomes harder to discern. 
This palette watermark practice was initiated and led by the company MyHeritage in the MyHeritage In Color feature (which uses a newer version of DeOldify than what you're using here).\\n\",\n    \"\\n\",\n    \"#### How to Download a Copy\\n\",\n    \"Simply right click on the displayed image and click \\\"Save image as...\\\"!\\n\",\n    \"\\n\",\n    \"## Pro Tips\\n\",\n    \"\\n\",\n    \"You can evaluate how well the image is rendered at each render_factor by using the code at the bottom (that cell under \\\"See how well render_factor values perform on the image here\\\"). \"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"colab_type\": \"text\",\n    \"id\": \"sUQrbSYipiJn\"\n   },\n   \"source\": [\n    \"#◢ Colorize!!\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"source_url = '' #@param {type:\\\"string\\\"}\\n\",\n    \"render_factor = 35  #@param {type: \\\"slider\\\", min: 7, max: 40}\\n\",\n    \"watermarked = True #@param {type:\\\"boolean\\\"}\\n\",\n    \"\\n\",\n    \"if source_url is not None and source_url !='':\\n\",\n    \"    image_path = colorizer.plot_transformed_image_from_url(url=source_url, render_factor=render_factor, compare=True, watermarked=watermarked)\\n\",\n    \"    show_image_in_notebook(image_path)\\n\",\n    \"else:\\n\",\n    \"    print('Provide an image url and try again.')\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## See how well render_factor values perform on the image here\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"for i in range(10,40,2):\\n\",\n    \"    colorizer.plot_transformed_image('test_images/image.png', render_factor=i, display_render_factor=True, figsize=(8,8))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   
\"metadata\": {\n    \"colab_type\": \"text\",\n    \"id\": \"X7Ycv_Y9xAHp\"\n   },\n   \"source\": [\n    \"---\\n\",\n    \"#⚙ Recommended image sources \\n\",\n    \"* [/r/TheWayWeWere](https://www.reddit.com/r/TheWayWeWere/)\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"accelerator\": \"GPU\",\n  \"colab\": {\n   \"collapsed_sections\": [],\n   \"name\": \"ImageColorizerColabStable.ipynb\",\n   \"provenance\": [],\n   \"toc_visible\": true,\n   \"version\": \"0.3.2\"\n  },\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.7.6\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 4\n}\n"
  },
  {
    "path": "ImageColorizerStableTests.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"#NOTE:  This must be the first call in order to work properly!\\n\",\n    \"from deoldify import device\\n\",\n    \"from deoldify.device_id import DeviceId\\n\",\n    \"#choices:  CPU, GPU0...GPU7\\n\",\n    \"device.set(device=DeviceId.GPU0)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from deoldify.visualize import *\\n\",\n    \"plt.style.use('dark_background')\\n\",\n    \"import warnings\\n\",\n    \"warnings.filterwarnings(\\\"ignore\\\", category=UserWarning, message=\\\".*?Your .*? set is empty.*?\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"#Adjust render_factor (int) if image doesn't look quite right (max 64 on 11GB GPU).  The default here works for most photos.  \\n\",\n    \"#It literally just is a number multiplied by 16 to get the square render resolution.  \\n\",\n    \"#Note that this doesn't affect the resolution of the final output- the output is the same resolution as the input.\\n\",\n    \"#Example:  render_factor=21 => color is rendered at 16x21 = 336x336 px.  
\\n\",\n    \"render_factor=35\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis = get_image_colorizer(render_factor=render_factor, artistic=False)\\n\",\n    \"#vis = get_video_colorizer(render_factor=render_factor).vis\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/poolparty.jpg\\\", render_factor=45, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1852GatekeepersWindsor.jpg\\\", render_factor=44, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Chief.jpg\\\", render_factor=10, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1850SchoolForGirls.jpg\\\", render_factor=42, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/AtlanticCityBeach1905.jpg\\\", render_factor=32, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/CottonMillWorkers1913.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    
\"vis.plot_transformed_image(\\\"test_images/BrooklynNavyYardHospital.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/FinnishPeasant1867.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/AtlanticCity1905.png\\\", render_factor=40, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/PushingCart.jpg\\\", render_factor=24, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Drive1905.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/IronLung.png\\\", render_factor=26, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/FamilyWithDog.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/DayAtSeaBelgium.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/marilyn_woods.jpg\\\", render_factor=16, compare=True)\"\n   ]\n  },\n  
{\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/OldWomanSweden1904.jpg\\\", render_factor=20, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/WomenTapingPlanes.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/overmiller.jpg\\\", render_factor=30, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/BritishDispatchRider.jpg\\\", render_factor=16, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/MuseauNacionalDosCoches.jpg\\\", render_factor=19, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/abe.jpg\\\", render_factor=13, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/RossCorbettHouseCork.jpg\\\", render_factor=40, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/HPLabelleOfficeMontreal.jpg\\\", render_factor=44, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": 
\"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/einstein_beach.jpg\\\", render_factor=32, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/airmen1943.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/20sWoman.jpg\\\", render_factor=24, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/egypt-1.jpg\\\", render_factor=18, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Rutherford_Hayes.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/einstein_portrait.jpg\\\", render_factor=15, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/pinkerton.jpg\\\", render_factor=7, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/WaltWhitman.jpg\\\", render_factor=9, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   
\"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/dorothea-lange.jpg\\\", render_factor=18, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Hemmingway2.jpg\\\", render_factor=22, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/hemmingway.jpg\\\", render_factor=14, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/smoking_kid.jpg\\\", render_factor=35, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/teddy_rubble.jpg\\\", render_factor=42, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/dustbowl_2.jpg\\\", render_factor=16, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/camera_man.jpg\\\", render_factor=25, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/migrant_mother.jpg\\\", render_factor=32, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    
\"vis.plot_transformed_image(\\\"test_images/marktwain.jpg\\\", render_factor=14, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/HelenKeller.jpg\\\", render_factor=35, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Evelyn_Nesbit.jpg\\\", render_factor=25, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Eddie-Adams.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/soldier_kids.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/AnselAdamsYosemite.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/unnamed.jpg\\\", render_factor=28, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/workers_canyon.jpg\\\", render_factor=45, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/CottonMill.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   
\"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/JudyGarland.jpeg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/kids_pit.jpg\\\", render_factor=30, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/last_samurai.jpg\\\", render_factor=22, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/AnselAdamsWhiteChurch.jpg\\\", render_factor=25, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/opium.jpg\\\", render_factor=30, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/dorothea_lange_2.jpg\\\", render_factor=42, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/rgs.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/wh-auden.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   
\"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/w-b-yeats.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/marilyn_portrait.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/wilson-slaverevivalmeeting.jpg\\\", render_factor=45, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/ww1_trench.jpg\\\", render_factor=18, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/women-bikers.png\\\", render_factor=23, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Unidentified1855.jpg\\\", render_factor=19, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/skycrapper_lunch.jpg\\\", render_factor=25, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/sioux.jpg\\\", render_factor=28, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    
\"vis.plot_transformed_image(\\\"test_images/school_kids.jpg\\\", render_factor=20, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/royal_family.jpg\\\", render_factor=42, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/redwood_lumberjacks.jpg\\\", render_factor=45, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/poverty.jpg\\\", render_factor=40, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/paperboy.jpg\\\", render_factor=45, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/NativeAmericans.jpg\\\", render_factor=21, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/helmut_newton-.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Greece1911.jpg\\\", render_factor=44, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/FatMenClub.jpg\\\", 
render_factor=18, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/EgyptColosus.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/egypt-2.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/dustbowl_sd.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/dustbowl_people.jpg\\\", render_factor=24, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/dustbowl_5.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/dustbowl_1.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/DriveThroughGiantTree.jpg\\\", render_factor=21, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/covered-wagons-traveling.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   
\"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/civil-war_2.jpg\\\", render_factor=42, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/civil_war_4.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/civil_war_3.jpg\\\", render_factor=28, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/civil_war.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/BritishSlum.jpg\\\", render_factor=30, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/bicycles.jpg\\\", render_factor=27, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/brooklyn_girls_1940s.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/40sCouple.jpg\\\", render_factor=21, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1946Wedding.jpg\\\", 
compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Dolores1920s.jpg\\\", render_factor=18, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/TitanicGym.jpg\\\", render_factor=26, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/FrenchVillage1950s.jpg\\\", render_factor=41, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/FrenchVillage1950s.jpg\\\", render_factor=32, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/ClassDivide1930sBrittain.jpg\\\", render_factor=45, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1870sSphinx.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1890Surfer.png\\\", render_factor=37, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/TV1930s.jpg\\\", render_factor=43, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": 
\"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1864UnionSoldier.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1890sMedStudents.jpg\\\", render_factor=18, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/BellyLaughWWI.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/PiggyBackRide.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/HealingTree.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/ManPile.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1910Bike.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/FreeportIL.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    
\"vis.plot_transformed_image(\\\"test_images/DutchBabyCoupleEllis.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/InuitWoman1903.png\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1920sDancing.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/AirmanDad.jpg\\\", render_factor=13, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1910Racket.png\\\", render_factor=30, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1880Paris.jpg\\\", render_factor=16, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Deadwood1860s.jpg\\\", render_factor=13, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1860sSamauris.jpg\\\", render_factor=43, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/LondonUnderground1860.jpg\\\", render_factor=45, 
compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Mid1800sSisters.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1860Girls.jpg\\\", render_factor=45, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/SanFran1851.jpg\\\", render_factor=44, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Kabuki1870s.png\\\", render_factor=8, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Mormons1870s.jpg\\\", render_factor=44, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/EgyptianWomenLate1800s.jpg\\\", render_factor=44, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/PicadillyLate1800s.jpg\\\", render_factor=26, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/SutroBaths1880s.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   
\"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1880sBrooklynBridge.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/ChinaOpiumc1880.jpg\\\", render_factor=30, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Locomotive1880s.jpg\\\", render_factor=9, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/ViennaBoys1880s.png\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/VictorianDragQueen1880s.png\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Sami1880s.jpg\\\", render_factor=44, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/ArkansasCowboys1880s.jpg\\\", render_factor=22, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Ballet1890Russia.jpg\\\", render_factor=40, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": 
[],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Rottindean1890s.png\\\", render_factor=20, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1890sPingPong.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/London1937.png\\\", render_factor=45, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Harlem1932.jpg\\\", render_factor=37, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/OregonTrail1870s.jpg\\\", render_factor=40, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/EasterNyc1911.jpg\\\", render_factor=19, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1899NycBlizzard.jpg\\\", render_factor=45, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Edinburgh1920s.jpg\\\", render_factor=17, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    
\"vis.plot_transformed_image(\\\"test_images/1890sShoeShopOhio.jpg\\\", render_factor=46, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1890sTouristsEgypt.png\\\", render_factor=40, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1938Reading.jpg\\\", render_factor=19, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1850Geography.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1901Electrophone.jpg\\\", render_factor=10, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"for i in range(8, 47):\\n\",\n    \"    vis.plot_transformed_image(\\\"test_images/1901Electrophone.jpg\\\", render_factor=i, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Texas1938Woman.png\\\", render_factor=38, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/MaioreWoman1895NZ.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    
\"vis.plot_transformed_image(\\\"test_images/WestVirginiaHouse.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1920sGuadalope.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1909Chicago.jpg\\\", render_factor=45, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1920sFarmKid.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/ParisLate1800s.jpg\\\", render_factor=45, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1900sDaytonaBeach.png\\\", render_factor=23, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1930sGeorgia.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/NorwegianBride1920s.jpg\\\", render_factor=30, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Depression.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n 
  \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1888Slum.jpg\\\", render_factor=30, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/LivingRoom1920Sweden.jpg\\\", render_factor=45, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1896NewsBoyGirl.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/PetDucks1927.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1899SodaFountain.jpg\\\", render_factor=46, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/TimesSquare1955.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/PuppyGify.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1890CliffHouseSF.jpg\\\", render_factor=30, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   
\"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1908FamilyPhoto.jpg\\\", render_factor=45, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1900sSaloon.jpg\\\", render_factor=43, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1890BostonHospital.jpg\\\", render_factor=40, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1870Girl.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/AustriaHungaryWomen1890s.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Shack.jpg\\\",render_factor=42, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Apsaroke1908.png\\\", render_factor=35, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1948CarsGrandma.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    
\"vis.plot_transformed_image(\\\"test_images/PlanesManhattan1931.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/WorriedKid1940sNyc.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1920sFamilyPhoto.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/CatWash1931.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1940sBeerRiver.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/VictorianLivingRoom.jpg\\\", render_factor=45, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1897BlindmansBluff.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1874Mexico.png\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/MadisonSquare1900.jpg\\\", render_factor=46, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": 
\"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1867MusicianConstantinople.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1925Girl.jpg\\\", render_factor=25, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1907Cowboys.jpg\\\", render_factor=28, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/WWIIPeeps.jpg\\\", render_factor=37, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/BabyBigBoots.jpg\\\", render_factor=40, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1895BikeMaidens.jpg\\\", render_factor=25, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/IrishLate1800s.jpg\\\", render_factor=25, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/LibraryOfCongress1910.jpg\\\", render_factor=21, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   
\"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1875Olds.jpg\\\", render_factor=16, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/SenecaNative1908.jpg\\\", render_factor=30, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/WWIHospital.jpg\\\", render_factor=40, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1892WaterLillies.jpg\\\", render_factor=45, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/GreekImmigrants1905.jpg\\\", render_factor=25, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/FatMensShop.jpg\\\", render_factor=21, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/KidCage1930s.png\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/FarmWomen1895.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    
\"vis.plot_transformed_image(\\\"test_images/NewZealand1860s.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/JerseyShore1905.jpg\\\", render_factor=45, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/LondonKidsEarly1900s.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/NYStreetClean1906.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Boston1937.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Cork1905.jpg\\\", render_factor=28, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/BoxedBedEarly1900s.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/ZoologischerGarten1898.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/EmpireState1930.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": 
\"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Agamemnon1919.jpg\\\", render_factor=40, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/AppalachianLoggers1901.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/WWISikhs.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/MementoMori1865.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/RepBrennanRadio1922.jpg\\\", render_factor=43, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Late1800sNative.jpg\\\", render_factor=20, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/GasPrices1939.jpg\\\", render_factor=30, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1933RockefellerCenter.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   
\"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Scotland1919.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1920CobblersShopLondon.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1909ParisFirstFemaleTaxisDriver.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/HoovervilleSeattle1932.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/ElephantLondon1934.png\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Jane_Addams.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/AnselAdamsAdobe.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/CricketLondon1930.jpg\\\", render_factor=45, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Donegal1907Yarn.jpg\\\", render_factor=32, 
compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/AnselAdamsChurch.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/BreadDelivery1920sIreland.jpg\\\", render_factor=20, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/BritishTeaBombay1890s.png\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/CafeParis1928.jpg\\\", render_factor=35, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/BigManTavern1908NYC.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Cars1890sIreland.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/GalwayIreland1902.jpg\\\", render_factor=35, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/HomeIreland1924.jpg\\\", render_factor=40, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   
\"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/HydeParkLondon1920s.jpg\\\", render_factor=30, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1929LondonOverFleetSt.jpg\\\", render_factor=25, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/AccordianKid1900Paris.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/AnselAdamsBuildings.jpg\\\", render_factor=45, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/AthleticClubParis1913.jpg\\\", render_factor=42, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/BombedLibraryLondon1940.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Boston1937.jpg\\\", render_factor=30, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/BoulevardDuTemple1838.jpg\\\", render_factor=25, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 
null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/BumperCarsParis1930.jpg\\\", render_factor=25, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/CafeTerrace1925Paris.jpg\\\", render_factor=24, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/CoalDeliveryParis1915.jpg\\\", render_factor=37, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/CorkKids1910.jpg\\\", render_factor=32, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/DeepSeaDiver1915.png\\\", render_factor=16, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/EastEndLondonStreetKids1901.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/FreightTrainTeens1934.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/HarrodsLondon1920.jpg\\\", render_factor=45, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": 
{},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/HerbSeller1899Paris.jpg\\\", render_factor=17, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/CalcuttaPoliceman1920.jpg\\\", render_factor=20, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/ElectricScooter1915.jpeg\\\", render_factor=20, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/GreatGrandparentsIrelandEarly1900s.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/HalloweenEarly1900s.jpg\\\", render_factor=11, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/IceManLondon1919.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/LeBonMarcheParis1875.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/LittleAirplane1934.jpg\\\", render_factor=35, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   
\"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/RoyalUniversityMedStudent1900Ireland.jpg\\\", render_factor=45, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/LewisTomalinLondon1895.png\\\", render_factor=25, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/SunHelmetsLondon1933.jpg\\\", render_factor=40, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Killarney1910.jpg\\\", render_factor=45, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/LondonSheep1920s.png\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/PostOfficeVermont1914.png\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/ServantsBessboroughHouse1908Ireland.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/WaterfordIreland1909.jpg\\\", render_factor=35, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": 
[\n    \"vis.plot_transformed_image(\\\"test_images/Lisbon1919.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/London1918WartimeClothesManufacture.jpg\\\", render_factor=45, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/LondonHeatWave1935.png\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/LondonsSmallestShop1900.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/MetropolitanDistrictRailway1869London.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/NativeWoman1926.jpg\\\", render_factor=21, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/PaddysMarketCork1900s.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Paris1920Cart.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/ParisLadies1910.jpg\\\", 
render_factor=20, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/ParisLadies1930s.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Sphinx.jpeg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/TheatreGroupBombay1875.jpg\\\", render_factor=45, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/WorldsFair1900Paris.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/London1850Coach.jpg\\\", render_factor=25, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/London1900EastEndBlacksmith.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/London1930sCheetah.jpg\\\", render_factor=42, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/LondonFireBrigadeMember1926.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n  
 \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/LondonGarbageTruck1910.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/LondonRailwayWork1931.jpg\\\", render_factor=45, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/LondonStreets1900.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/MuffinManlLondon1910.jpg\\\", render_factor=45, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/NativeCouple1912.jpg\\\", render_factor=21, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/NewspaperCivilWar1863.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/PaddingtonStationLondon1907.jpg\\\", render_factor=45, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Paris1899StreetDig.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   
\"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Paris1926.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/ParisWomenFurs1920s.jpg\\\", render_factor=21, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/PeddlerParis1899.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/SchoolKidsConnemaraIreland1901.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/SecondHandClothesLondonLate1800s.jpg\\\", render_factor=33, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/SoapBoxRacerParis1920s.jpg\\\", render_factor=40, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/SoccerMotorcycles1923London.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/WalkingLibraryLondon1930.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    
\"vis.plot_transformed_image(\\\"test_images/LondonStreetDoctor1877.png\\\", render_factor=38, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/jacksonville.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/ZebraCarriageLondon1900.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/StreetGramaphonePlayerLondon1920s.png\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/YaleBranchBarnardsExpress.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/SynagogueInterior.PNG\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/ArmisticeDay1918.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/FlyingMachinesParis1909.jpg\\\", render_factor=25, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/GreatAunt1920.jpg\\\", 
compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/NewBrunswick1915.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/ShoeMakerLate1800s.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/SpottedBull1908.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/TouristsGermany1904.jpg\\\", render_factor=35, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/TunisianStudents1914.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/Yorktown1862.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/LondonFashion1911.png\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1939GypsyKids.jpg\\\", render_factor=37, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   
\"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1936OpiumShanghai.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1923HollandTunnel.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1939YakimaWAGirl.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/GoldenGateConstruction.jpg\\\", render_factor=35, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/PostCivilWarAncestors.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1939SewingBike.png\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1930MaineSchoolBus.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1913NewYorkConstruction.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1945HiroshimaChild.jpg\\\", 
compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1941GeorgiaFarmhouse.jpg\\\", render_factor=43, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1934UmbriaItaly.jpg\\\", render_factor=21, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1900sLadiesTeaParty.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1919WWIAviationOxygenMask.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1900NJThanksgiving.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1940Connecticut.jpg\\\", render_factor=43, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"for i in range(10, 46, 5):\\n\",\n    \"    vis.plot_transformed_image(\\\"test_images/1940Connecticut.jpg\\\", render_factor=i, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1911ThanksgivingMaskers.jpg\\\", render_factor=35, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": 
\"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1910ThanksgivingMaskersII.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1936PetToad.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1908RookeriesLondon.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1890sChineseImmigrants.jpg\\\", render_factor=25, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1897VancouverAmberlamps.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1929VictorianCosplayLondon.jpg\\\", render_factor=35, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1959ParisFriends.png\\\", render_factor=40, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1925GypsyCampMaryland.jpg\\\", render_factor=40, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   
\"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1941PoolTableGeorgia.jpg\\\", render_factor=45, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1900ParkDog.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1886Hoop.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1950sLondonPoliceChild.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1886ProspectPark.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1930sRooftopPoland.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1919RevereBeach.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1936ParisCafe.jpg\\\", render_factor=46, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    
\"vis.plot_transformed_image(\\\"test_images/1902FrenchYellowBellies.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1940PAFamily.jpg\\\", render_factor=42, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1910Finland.jpg\\\", render_factor=40, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/ZebraCarriageLondon1900.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/1904ChineseMan.jpg\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/CrystalPalaceLondon1854.PNG\\\", compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/James1.jpg\\\", render_factor=15, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/James2.jpg\\\", render_factor=20, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/James3.jpg\\\", render_factor=19, compare=True)\"\n  
 ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/James4.jpg\\\", render_factor=30, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/James5.jpg\\\", render_factor=32, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis.plot_transformed_image(\\\"test_images/James6.jpg\\\", render_factor=28, compare=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.7.6\"\n  },\n  \"toc\": {\n   \"colors\": {\n    \"hover_highlight\": \"#DAA520\",\n    \"navigate_num\": \"#000000\",\n    \"navigate_text\": \"#333333\",\n    \"running_highlight\": \"#FF0000\",\n    \"selected_highlight\": \"#FFD700\",\n    \"sidebar_border\": \"#EEEEEE\",\n    \"wrapper_background\": \"#FFFFFF\"\n   },\n   \"moveMenuLeft\": true,\n   \"nav_menu\": {\n    \"height\": \"67px\",\n    \"width\": \"252px\"\n   },\n   \"navigate_menu\": true,\n   \"number_sections\": true,\n   \"sideBar\": true,\n   \"threshold\": 4,\n   \"toc_cell\": false,\n   
\"toc_section_display\": \"block\",\n   \"toc_window_display\": false,\n   \"widenNotebook\": false\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 4\n}\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2018 Jason Antic\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE."
  },
  {
    "path": "MANIFEST.in",
    "content": "include README.md\ninclude LICENSE\ninclude requirements.txt\n"
  },
  {
    "path": "README.md",
    "content": "\n# DeOldify\n\n**This Repository is Archived**  This project was a wild ride since I started it back in 2018, six years ago as of this writing (October 19, 2024)!  It's time for me to move on and put this repo in the archives as I simply don't have the time to attend to it anymore, and frankly it's ancient as far as deep-learning projects go at this point! ~Jason\n\n**Quick Start**: The easiest way to colorize images using open source DeOldify\n(for free!) is here: [DeOldify Image Colorization on DeepAI](https://deepai.org/machine-learning-model/colorizer)\n\n**Desktop**: Want to run open source DeOldify for photos and videos on the desktop?\n* Stable Diffusion Web UI Plugin - Photos and video, cross-platform (NEW!). <https://github.com/SpenserCai/sd-webui-deoldify>\n* ColorfulSoft Windows GUI - No GPU required! Photos/Windows only. <https://github.com/ColorfulSoft/DeOldify.NET>\n\n**In Browser (new!)**  Check out this ONNX-based in-browser implementation:  https://github.com/akbartus/DeOldify-on-Browser\n\nThe **most advanced** version of DeOldify image colorization is available here,\nexclusively.  Try a few images for free! 
[MyHeritage In Color](https://www.myheritage.com/incolor)\n\n**Replicate:** Image: <a href=\"https://replicate.com/arielreplicate/deoldify_image\"><img src=\"https://replicate.com/arielreplicate/deoldify_image/badge\"></a> | Video: <a href=\"https://replicate.com/arielreplicate/deoldify_video\"><img src=\"https://replicate.com/arielreplicate/deoldify_video/badge\"></a>\n\n----------------------------\n\nImage (artistic) [![Colab for images](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jantic/DeOldify/blob/master/ImageColorizerColab.ipynb)\n| Video [![Colab for video](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jantic/DeOldify/blob/master/VideoColorizerColab.ipynb)\n\nHaving trouble with the default image colorizer, aka \"artistic\"?  Try the\n\"stable\" one below.  It generally won't produce colors that are as interesting as\n\"artistic\", but the glitches are noticeably reduced.\n\nImage (stable) [![Colab for stable model](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jantic/DeOldify/blob/master/ImageColorizerColabStable.ipynb)\n\nInstructions on how to use the Colabs above have been kindly provided in video\ntutorial form by Old Ireland in Colour's John Breslin.  It's great! 
Click video\nimage below to watch.\n\n[![DeOldify Tutorial](http://img.youtube.com/vi/VaEl0faDw38/0.jpg)](http://www.youtube.com/watch?v=VaEl0faDw38)\n\nGet more updates on [Twitter\n![Twitter logo](resource_images/twitter.svg)](https://twitter.com/DeOldify).\n\n## Table of Contents\n\n- [About DeOldify](#about-deoldify)\n- [Example Videos](#example-videos)\n- [Example Images](#example-images)\n- [Stuff That Should Probably Be In A Paper](#stuff-that-should-probably-be-in-a-paper)\n  - [How to Achieve Stable Video](#how-to-achieve-stable-video)\n  - [What is NoGAN?](#what-is-nogan)\n- [Why Three Models?](#why-three-models)\n- [Technical Details](#the-technical-details)\n- [Going Forward](#this-project-going-forward)\n- [Getting Started Yourself](#getting-started-yourself)\n  - [Easiest Approach](#easiest-approach)\n  - [Your Own Machine](#your-own-machine-not-as-easy)\n- [Pretrained Weights](#pretrained-weights)\n\n## About DeOldify\n\nSimply put, the mission of this project is to colorize and restore old images and\nfilm footage. We'll get into the details in a bit, but first let's see some\npretty pictures and videos!\n\n### New and Exciting Stuff in DeOldify\n\n- Glitches and artifacts are almost entirely eliminated\n- Better skin (less zombies)\n- More highly detailed and photorealistic renders\n- Much less \"blue bias\"\n- **Video** - it actually looks good!  
\n- **NoGAN** - a new and weird but highly effective way to do GAN training for\n  image to image.\n\n## Example Videos\n\n**Note:**  Click images to watch\n\n### Facebook F8 Demo\n\n[![DeOldify Facebook F8 Movie Colorization Demo](http://img.youtube.com/vi/l3UXXid04Ys/0.jpg)](http://www.youtube.com/watch?v=l3UXXid04Ys)\n\n### Silent Movie Examples\n\n[![DeOldify Silent Movie Examples](http://img.youtube.com/vi/EXn-n2iqEjI/0.jpg)](http://www.youtube.com/watch?v=EXn-n2iqEjI)\n\n## Example Images\n\n\"Migrant Mother\" by Dorothea Lange (1936)\n\n![Migrant Mother](https://i.imgur.com/Bt0vnke.jpg)\n\nWoman relaxing in her livingroom in Sweden (1920)\n\n![Sweden Living Room](https://i.imgur.com/158d0oU.jpg)\n\n\"Toffs and Toughs\" by Jimmy Sime (1937)\n\n![Class Divide](https://i.imgur.com/VYuav4I.jpg)\n\nThanksgiving Maskers (1911)\n\n![Thanksgiving Maskers](https://i.imgur.com/n8qVJ5c.jpg)\n\nGlen Echo Madame Careta Gypsy Camp in Maryland (1925)\n\n![Gypsy Camp](https://i.imgur.com/1oYrJRI.jpg)\n\n\"Mr. and Mrs. Lemuel Smith and their younger children in their farm house,\nCarroll County, Georgia.\" (1941)\n\n![Georgia Farmhouse](https://i.imgur.com/I2j8ynm.jpg)\n\n\"Building the Golden Gate Bridge\" (est 1937)\n\n![Golden Gate Bridge](https://i.imgur.com/6SbFjfq.jpg)\n\n> **Note:**  What you might be wondering is while this render looks cool, are the\n> colors accurate? The original photo certainly makes it look like the towers of\n> the bridge could be white. We looked into this and it turns out the answer is\n> no - the towers were already covered in red primer by this time. 
So that's\n> something to keep in mind- historical accuracy remains a huge challenge!\n\n\"Terrasse de café, Paris\" (1925)\n\n![Cafe Paris](https://i.imgur.com/WprQwP5.jpg)\n\nNorwegian Bride (est late 1890s)\n\n![Norwegian Bride](https://i.imgur.com/MmtvrZm.jpg)\n\nZitkála-Šá (Lakota: Red Bird), also known as Gertrude Simmons Bonnin (1898)\n\n![Native Woman](https://i.imgur.com/zIGM043.jpg)\n\nChinese Opium Smokers (1880)\n\n![Opium Real](https://i.imgur.com/lVGq8Vq.jpg)\n\n## Stuff That Should Probably Be In A Paper\n\n### How to Achieve Stable Video\n\nNoGAN training is crucial to getting the kind of stable and colorful images seen\nin this iteration of DeOldify. NoGAN training combines the benefits of GAN\ntraining (wonderful colorization) while eliminating the nasty side effects\n(like flickering objects in video). Believe it or not, video is rendered using\nisolated image generation without any sort of temporal modeling tacked on. The\nprocess performs 30-60 minutes of the GAN portion of \"NoGAN\" training, using 1%\nto 3% of imagenet data once.  Then, as with still image colorization, we\n\"DeOldify\" individual frames before rebuilding the video.\n\nIn addition to improved video stability, there is an interesting thing going on\nhere worth mentioning. It turns out the models I run, even different ones and\nwith different training structures, keep arriving at more or less the same\nsolution.  That's even the case for the colorization of things you may think\nwould be arbitrary and unknowable, like the color of clothing, cars, and even\nspecial effects (as seen in \"Metropolis\").\n\n![Metropolis Special FX](https://thumbs.gfycat.com/HeavyLoneBlowfish-size_restricted.gif)\n\nMy best guess is that the models are learning some interesting rules about how to\ncolorize based on subtle cues present in the black and white images that I\ncertainly wouldn't expect to exist.  
This result leads to nicely deterministic and\nconsistent results, and that means you don't have to track model colorization\ndecisions because they're not arbitrary.  Additionally, they seem remarkably\nrobust so that even in moving scenes the renders are very consistent.\n\n![Moving Scene Example](https://thumbs.gfycat.com/FamiliarJubilantAsp-size_restricted.gif)\n\nOther ways to stabilize video add up as well. First, generally speaking rendering\nat a higher resolution (higher render_factor) will increase stability of\ncolorization decisions.  This stands to reason because the model has higher\nfidelity image information to work with and will have a greater chance of making\nthe \"right\" decision consistently.  Closely related to this is the use of\nresnet101 instead of resnet34 as the backbone of the generator- objects are\ndetected more consistently and correctly with this. This is especially important\nfor getting good, consistent skin rendering.  It can be particularly visually\njarring if you wind up with \"zombie hands\", for example.\n\n![Zombie Hand Example](https://thumbs.gfycat.com/ThriftyInferiorIsabellinewheatear-size_restricted.gif)\n\nAdditionally, gaussian noise augmentation during training appears to help but at\nthis point the conclusions as to just how much are a bit more tenuous (I just\nhaven't formally measured this yet).  This is loosely based on work done in style\ntransfer video, described here:\n <https://medium.com/element-ai-research-lab/stabilizing-neural-style-transfer-for-video-62675e203e42>.\n\nSpecial thanks go to Rani Horev for his contributions in implementing this noise\naugmentation.\n\n### What is NoGAN?\n\nThis is a new type of GAN training that I've developed to solve some key problems\nin the previous DeOldify model. It provides the benefits of GAN training while\nspending minimal time doing direct GAN training.  
Instead, most of the training\ntime is spent pretraining the generator and critic separately with more\nstraight-forward, fast and reliable conventional methods.  A key insight here is\nthat those more \"conventional\" methods generally get you most of the results you\nneed, and that GANs can be used to close the gap on realism. During the very\nshort amount of actual GAN training the generator not only gets the full\nrealistic colorization capabilities that used to take days of progressively\nresized GAN training, but it also doesn't accrue nearly as much of the artifacts\nand other ugly baggage of GANs. In fact, you can pretty much eliminate glitches\nand artifacts almost entirely depending on your approach. As far as I know this\nis a new technique. And it's incredibly effective.\n\n#### Original DeOldify Model\n\n![Before Flicker](https://thumbs.gfycat.com/CoordinatedVeneratedHogget-size_restricted.gif)\n\n#### NoGAN-Based DeOldify Model\n\n![After Flicker](https://thumbs.gfycat.com/OilyBlackArctichare-size_restricted.gif)\n\nThe steps are as follows: First train the generator in a conventional way by\nitself with just the feature loss. Next, generate images from that, and train\nthe critic on distinguishing between those outputs and real images as a basic\nbinary classifier. Finally, train the generator and critic together in a GAN\nsetting (starting right at the target size of 192px in this case).  Now for\nthe weird part:  All the useful GAN training here only takes place within a very\nsmall window of time.  There's an inflection point where it appears the critic\nhas transferred everything it can that is useful to the generator. Past this\npoint, image quality oscillates between the best that you can get at the\ninflection point, or bad in a predictable way (orangish skin, overly red lips,\netc).  
There appears to be no productive training after the inflection point.\nAnd this point lies within training on just 1% to 3% of the Imagenet Data!\nThat amounts to about 30-60 minutes of training at 192px.\n\nThe hard part is finding this inflection point.  So far, I've accomplished this\nby making a whole bunch of model save checkpoints (every 0.1% of data iterated\non) and then just looking for the point where images look great before they go\ntotally bonkers with orange skin (always the first thing to go). Additionally,\ngenerator rendering starts immediately getting glitchy and inconsistent at this\npoint, which is no good particularly for video. What I'd really like to figure\nout is what the tell-tale sign of the inflection point is that can be easily\nautomated as an early stopping point.  Unfortunately, nothing definitive is\njumping out at me yet.  For one, it's happening in the middle of training loss\ndecreasing- not when it flattens out, which would seem more reasonable on the surface.\n\nAnother key thing about NoGAN training is you can repeat pretraining the critic\non generated images after the initial GAN training, then repeat the GAN training\nitself in the same fashion.  This is how I was able to get extra colorful results\nwith the \"artistic\" model.  But this does come at a cost currently- the output of\nthe generator becomes increasingly inconsistent and you have to experiment with\nrender resolution (render_factor) to get the best result.  But the renders are\nstill glitch free and way more consistent than I was ever able to achieve with\nthe original DeOldify model. You can do about five of these repeat cycles, give\nor take, before you get diminishing returns, as far as I can tell.\n\nKeep in mind- I haven't been entirely rigorous in figuring out what all is going\non in NoGAN- I'll save that for a paper. That means there's a good chance I'm\nwrong about something.  
But I think it's definitely worth putting out there now\nbecause I'm finding it very useful- it's solving basically much of my remaining\nproblems I had in DeOldify.\n\nThis builds upon a technique developed in collaboration with Jeremy Howard and\nSylvain Gugger for Fast.AI's Lesson 7 in version 3 of Practical Deep Learning\nfor Coders Part I. The particular lesson notebook can be found here:\n  <https://github.com/fastai/course-v3/blob/master/nbs/dl1/lesson7-superres-gan.ipynb>\n\n## Why Three Models?\n\nThere are now three models to choose from in DeOldify. Each of these has key\nstrengths and weaknesses, and so have different use cases.  Video is for video\nof course.  But stable and artistic are both for images, and sometimes one will\ndo images better than the other.\n\nMore details:\n\n- **Artistic** - This model achieves the highest quality results in image\ncoloration, in terms of interesting details and vibrance. The most notable\ndrawback however is that it's a bit of a pain to fiddle around with to get the\nbest results (you have to adjust the rendering resolution or render_factor to\nachieve this).  Additionally, the model does not do as well as stable in a few\nkey common scenarios- nature scenes and portraits.  The model uses a resnet34\nbackbone on a UNet with an emphasis on depth of layers on the decoder side.\nThis model was trained with 5 critic pretrain/GAN cycle repeats via NoGAN, in\naddition to the initial generator/critic pretrain/GAN NoGAN training, at 192px.\nThis adds up to a total of 32% of Imagenet data trained once (12.5 hours of\ndirect GAN training).\n\n- **Stable** - This model achieves the best results with landscapes and\nportraits.  Notably, it produces less \"zombies\"- where faces or limbs stay gray\nrather than being colored in properly.  It generally has less weird\nmiscolorations than artistic, but it's also less colorful in general.  
This\nmodel uses a resnet101 backbone on a UNet with an emphasis on width of layers on\nthe decoder side.  This model was trained with 3 critic pretrain/GAN cycle\nrepeats via NoGAN, in addition to the initial generator/critic pretrain/GAN\nNoGAN training, at 192px.  This adds up to a total of 7% of Imagenet data\ntrained once (3 hours of direct GAN training).\n\n- **Video** - This model is optimized for smooth, consistent and flicker-free\nvideo.  This would definitely be the least colorful of the three models, but\nit's honestly not too far off from \"stable\". The model is the same as \"stable\"\nin terms of architecture, but differs in training.  It's trained for a mere 2.2%\nof Imagenet data once at 192px, using only the initial generator/critic\npretrain/GAN NoGAN training (1 hour of direct GAN training).\n\nBecause the training of the artistic and stable models was done before the\n\"inflection point\" of NoGAN training described in \"What is NoGAN?\" was\ndiscovered, I believe this amount of training on them can be knocked down\nconsiderably. As far as I can tell, the models were stopped at \"good points\"\nthat were well beyond where productive training was taking place.  I'll be\nlooking into this in the future.\n\nIdeally, eventually these three models will be consolidated into one that has all\nthese desirable qualities unified.  I think there's a path there, but it's going to\nrequire more work!  So for now, the most practical solution appears to be to\nmaintain multiple models.\n\n## The Technical Details\n\nThis is a deep learning based model.  More specifically, what I've done is\ncombine the following approaches:\n\n### [Self-Attention Generative Adversarial Network](https://arxiv.org/abs/1805.08318)\n\nExcept the generator is a **pretrained U-Net**, and I've just modified it to\nhave the spectral normalization and self-attention.  
It's a pretty\nstraightforward translation.\n\n### [Two Time-Scale Update Rule](https://arxiv.org/abs/1706.08500)\n\nThis is also very straightforward – it's just one-to-one generator/critic\niterations and a higher critic learning rate.\nThis is modified to incorporate a \"threshold\" critic loss that makes sure that\nthe critic is \"caught up\" before moving on to generator training.\nThis is particularly useful for the \"NoGAN\" method described below.\n\n### NoGAN\n\nThere's no paper here! This is a new type of GAN training that I've developed to\nsolve some key problems in the previous DeOldify model.\nThe gist is that you get the benefits of GAN training while spending minimal time\ndoing direct GAN training.\nMore details are in the [What is NoGAN?](#what-is-nogan) section (it's a doozy).\n\n### Generator Loss\n\nLoss during NoGAN learning has two parts:  One is a basic Perceptual Loss (or\nFeature Loss) based on VGG16 – this just biases the generator model to replicate\nthe input image.\nThe second is the loss score from the critic.  For the curious – Perceptual Loss\nisn't sufficient by itself to produce good results.\nIt tends to just encourage a bunch of brown/green/blue – you know, cheating to\nthe test, basically, which neural networks are really good at doing!\nKey thing to realize here is that GANs essentially are learning the loss function\nfor you – which is really one big step closer toward the ideal that we're\nshooting for in machine learning.\nAnd of course you generally get much better results when you get the machine to\nlearn something you were previously hand coding.\nThat's certainly the case here.\n\n**Of note:**  There's no longer any \"Progressive Growing of GANs\" type training\ngoing on here.  
It's just not needed in light of the superior results obtained\nby the \"NoGAN\" technique described above.\n\nThe beauty of this model is that it should be generally useful for all sorts of\nimage modification, and it should do it quite well.\nWhat you're seeing above are the results of the colorization model, but that's\njust one component in a pipeline that I'm developing with the exact same approach.\n\n## This Project, Going Forward\n\nSo that's the gist of this project – I'm looking to make old photos and film\nlook reeeeaaally good with GANs, and more importantly, make the project *useful*.\nIn the meantime though this is going to be my baby and I'll be actively updating\nand improving the code over the foreseeable future.\nI'll try to make this as user-friendly as possible, but I'm sure there are going\nto be hiccups along the way.\n\nOh and I swear I'll document the code properly...eventually.  Admittedly I'm\n*one of those* people who believe in \"self documenting code\" (LOL).\n\n## Getting Started Yourself\n\n### Easiest Approach\n\nThe easiest way to get started is to go straight to the Colab notebooks:\n\nImage [![Colab for images](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jantic/DeOldify/blob/master/ImageColorizerColab.ipynb)\n| Video [![Colab for video](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jantic/DeOldify/blob/master/VideoColorizerColab.ipynb)\n\nSpecial thanks to Matt Robinson and María Benavente for their image Colab notebook\ncontributions, and Robert Bell for the video Colab notebook work!\n\n### Your Own Machine (not as easy)\n\n#### Hardware and Operating System Requirements\n\n- **(Training Only) BEEFY Graphics card**.  I'd really like to have more memory\n  than the 11 GB in my GeForce 1080TI.  You'll have a tough time with less.\n  The Generators and Critic are ridiculously large.  
\n- **(Colorization Alone) A decent graphics card**. Approximately 4GB+ memory\n  video cards should be sufficient.\n- **Linux**.  I'm using Ubuntu 18.04, and I know 16.04 works fine too.  **Windows\n  is not supported and any issues brought up related to this will not be investigated.**\n\n#### Easy Install\n\nYou should now be able to do a simple install with Anaconda. Here are the steps:\n\nOpen the command line and navigate to the root folder you wish to install into.  Then\ntype the following commands:\n\n```console\ngit clone https://github.com/jantic/DeOldify.git DeOldify\ncd DeOldify\nconda env create -f environment.yml\n```\n\nThen start running with these commands:\n\n```console\nsource activate deoldify\njupyter lab\n```\n\nFrom there you can start running the notebooks in Jupyter Lab, via the url they\nprovide you in the console.\n\n> **Note:** You can also now do \"conda activate deoldify\" if you have the latest\nversion of conda and in fact that's now recommended. But a lot of people don't\nhave that yet so I'm not going to make it the default instruction here yet.\n\n**Alternative Install:** User daddyparodz has kindly created an installer script\nfor Ubuntu, and in particular Ubuntu on WSL, that may make things easier:\n  <https://github.com/daddyparodz/AutoDeOldifyLocal>\n\n#### Note on test_images Folder\n\nThe images in the `test_images` folder have been removed because they were using\nGit LFS and that costs a lot of money when GitHub actually charges for bandwidth\non a popular open source project (they had a billing bug for a while that was\nrecently fixed).  The notebooks that use them (the image test ones) still point\nto images in that directory that I (Jason) have personally and I'd like to keep\nit that way because, after all, I'm by far the primary and most active developer.\nBut they won't work for you.  
Still, those notebooks are a convenient template\nfor making your own tests if you're so inclined.\n\n#### Typical training\n\nThe notebook `ColorizeTrainingWandb` has been created to log and monitor results\nthrough [Weights & Biases](https://www.wandb.com/). You can find a description of\ntypical training by consulting [W&B Report](https://app.wandb.ai/borisd13/DeOldify/reports?view=borisd13%2FDeOldify).\n\n## Pretrained Weights\n\nTo start right away on your own machine with your own images or videos without\ntraining the models yourself, you'll need to download the \"Completed Generator\nWeights\" listed below and drop them in the /models/ folder.\n\nThe colorization inference notebooks should be able to guide you from here. The\nnotebooks to use are named ImageColorizerArtistic.ipynb,\nImageColorizerStable.ipynb, and VideoColorizer.ipynb.\n\n### Completed Generator Weights\n\n- [Artistic](https://data.deepai.org/deoldify/ColorizeArtistic_gen.pth)\n- [Stable](https://www.dropbox.com/s/axsd2g85uyixaho/ColorizeStable_gen.pth?dl=0)\n- [Video](https://data.deepai.org/deoldify/ColorizeVideo_gen.pth)\n\n### Completed Critic Weights\n\n- [Artistic](https://www.dropbox.com/s/xpq2ip9occuzgen/ColorizeArtistic_crit.pth?dl=0)\n- [Stable](https://www.dropbox.com/s/s53699e9n84q6sp/ColorizeStable_crit.pth?dl=0)\n- [Video](https://www.dropbox.com/s/xnq1z1oppvgpgtn/ColorizeVideo_crit.pth?dl=0)\n\n### Pretrain Only Generator Weights\n\n- [Artistic](https://www.dropbox.com/s/h782d1zar3vdblw/ColorizeArtistic_PretrainOnly_gen.pth?dl=0)\n- [Stable](https://www.dropbox.com/s/mz5n9hiq6hmwjq7/ColorizeStable_PretrainOnly_gen.pth?dl=0)\n- [Video](https://www.dropbox.com/s/ix993ci6ve7crlk/ColorizeVideo_PretrainOnly_gen.pth?dl=0)\n\n### Pretrain Only Critic Weights\n\n- [Artistic](https://www.dropbox.com/s/gr81b3pkidwlrc7/ColorizeArtistic_PretrainOnly_crit.pth?dl=0)\n- [Stable](https://www.dropbox.com/s/007qj0kkkxt5gb4/ColorizeStable_PretrainOnly_crit.pth?dl=0)\n- 
[Video](https://www.dropbox.com/s/wafc1uogyjuy4zq/ColorizeVideo_PretrainOnly_crit.pth?dl=0)\n\n## Want the Old DeOldify?\n\nWe suspect some of you are going to want access to the original DeOldify model\nfor various reasons.  We have that archived here:  <https://github.com/dana-kelley/DeOldify>\n\n## Want More?\n\nFollow [#DeOldify](https://twitter.com/search?q=%23Deoldify) on Twitter.\n\n## License\n\nAll code in this repository is under the MIT license as specified by the LICENSE\nfile.\n\nThe model weights listed in this readme under the \"Pretrained Weights\" section\nare trained by ourselves and are released under the MIT license.\n\n## A Statement on Open Source Support\n\nWe believe that open source has done a lot of good for the world.  After all,\nDeOldify simply wouldn't exist without it. But we also believe that there needs\nto be boundaries on just how much is reasonable to be expected from an open\nsource project maintained by just two developers.\n\nOur stance is that we're providing the code and documentation on research that\nwe believe is beneficial to the world.  What we have provided are novel takes\non colorization, GANs, and video that are hopefully somewhat friendly for\ndevelopers and researchers to learn from and adopt. This is the culmination of\nwell over a year of continuous work, free for you. What wasn't free was\nshouldered by us, the developers.  We left our jobs, bought expensive GPUs, and\nhad huge electric bills as a result of dedicating ourselves to this.\n\nWhat we haven't provided here is a ready to use free \"product\" or \"app\", and we\ndon't ever intend on providing that.  It's going to remain a Linux based project\nwithout Windows support, coded in Python, and requiring people to have some extra\ntechnical background to be comfortable using it.  
Others have stepped in with\ntheir own apps made with DeOldify, some paid and some free, which is what we want!\nWe're instead focusing on what we believe we can do best- making better\ncommercial models that people will pay for.\nDoes that mean you're not getting the very best for free?  Of course. We simply\ndon't believe that we're obligated to provide that, nor is it feasible! We\ncompete on research and sell that.  Not a GUI or web service that wraps said\nresearch- that part isn't something we're going to be great at anyways. We're not\nabout to shoot ourselves in the foot by giving away our actual competitive\nadvantage for free, quite frankly.\n\nWe're also not willing to go down the rabbit hole of providing endless, open\nended and personalized support on this open source project.  Our position is\nthis:  If you have the proper background and resources, the project provides\nmore than enough to get you started. We know this because we've seen plenty of\npeople using it and making money off of their own projects with it.\n\nThus, if you have an issue come up and it happens to be an actual bug that\nhaving it be fixed will benefit users generally, then great- that's something\nwe'll be happy to look into.\n\nIn contrast, if you're asking about something that really amounts to asking for\npersonalized and time consuming support that won't benefit anybody else, we're\nnot going to help. It's simply not in our interest to do that. We have bills to\npay, after all. And if you're asking for help on something that can already be\nderived from the documentation or code?  That's simply annoying, and we're not\ngoing to pretend to be ok with that.\n"
  },
  {
    "path": "VideoColorizer.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"#NOTE:  This must be the first call in order to work properly!\\n\",\n    \"from deoldify import device\\n\",\n    \"from deoldify.device_id import DeviceId\\n\",\n    \"#choices:  CPU, GPU0...GPU7\\n\",\n    \"device.set(device=DeviceId.GPU0)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from deoldify.visualize import *\\n\",\n    \"plt.style.use('dark_background')\\n\",\n    \"import warnings\\n\",\n    \"warnings.filterwarnings(\\\"ignore\\\", category=UserWarning, message=\\\".*?Your .*? set is empty.*?\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"colorizer = get_video_colorizer()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Instructions\\n\",\n    \"\\n\",\n    \"### source_url\\n\",\n    \"Type in a url hosting a video from YouTube, Imgur, Twitter, Reddit, Vimeo, etc.  Many sources work!  GIFs also work.  Full list here: https://ytdl-org.github.io/youtube-dl/supportedsites.html NOTE: If you want to use your own video, you can set source_url to None and just upload the file to video/source/ in Jupyter.  Just make sure that the file_name parameter matches the file you uploaded.\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"### file_name\\n\",\n    \"Name this whatever sensible file name you want (minus extension)! It should actually exist in video/source if source_url=None\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"### render_factor\\n\",\n    \"The default value of 21 has been carefully chosen and should work -ok- for most scenarios (but probably won't be the -best-). 
This determines the resolution at which the color portion of the video is rendered. Lower resolution will render faster, and colors also tend to look more vibrant. Older and lower quality film in particular will generally benefit by lowering the render factor. Higher render factors are often better for higher quality videos and inconsistencies (flashy render) will generally be reduced, but the colors may get slightly washed out. \\n\",\n    \"\\n\",\n    \"\\n\",\n    \"### file_name_ext\\n\",\n    \"There's no reason to change this.\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"### result_path\\n\",\n    \"Ditto- don't change.\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"### How to Download a Copy\\n\",\n    \"Simply shift+right click on the displayed video and click \\\"Save video as...\\\"!\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"## Pro Tips\\n\",\n    \"1. If a video takes a long time to render and you're wondering how well the frames will actually be colorized, you can preview how well the frames will be rendered at each render_factor by using the code at the bottom. Just stop the video rendering by hitting the stop button on the cell, then run that bottom cell under \\\"See how well render_factor values perform on a frame here\\\". It's not perfect and you may still need to experiment a bit especially when it comes to figuring out how to reduce frame inconsistency.  But it'll go a long way in narrowing down what actually works.\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"## Troubleshooting\\n\",\n    \"The video player may wind up not showing up, in which case- make sure to wait for the Jupyter cell to complete processing first (the play button will stop spinning).  Then follow these alternative download instructions\\n\",\n    \"\\n\",\n    \"1. In the menu to the left, click the Home icon.\\n\",\n    \"2. 
By default, rendered video will be in /video/result/\\n\",\n    \"\\n\",\n    \"If a video you downloaded doesn't play, it's probably because the cell didn't complete processing and the video is in a half-finished state.\\n\",\n    \"If you get a 'CUDA out of memory' error, you probably have the render_factor too high.  The max is 44 on 11GB video cards.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Colorize!!\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"#NOTE:  Max is 44 with 11GB video cards.  21 is a good default\\n\",\n    \"render_factor=21\\n\",\n    \"#NOTE:  Make source_url None to just read from file at ./video/source/[file_name] directly without modification\\n\",\n    \"source_url='https://twitter.com/silentmoviegifs/status/1116751583386034176'\\n\",\n    \"file_name = 'DogShy1926'\\n\",\n    \"file_name_ext = file_name + '.mp4'\\n\",\n    \"result_path = None\\n\",\n    \"\\n\",\n    \"if source_url is not None:\\n\",\n    \"    result_path = colorizer.colorize_from_url(source_url, file_name_ext, render_factor=render_factor)\\n\",\n    \"else:\\n\",\n    \"    result_path = colorizer.colorize_from_file_name(file_name_ext, render_factor=render_factor)\\n\",\n    \"\\n\",\n    \"show_video_in_notebook(result_path)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## See how well render_factor values perform on a frame here\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"for i in range(10,45,2):\\n\",\n    \"    colorizer.vis.plot_transformed_image('video/bwframes/' + file_name + '/00001.jpg', render_factor=i, display_render_factor=True, figsize=(8,8))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   
\"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.7.6\"\n  },\n  \"toc\": {\n   \"colors\": {\n    \"hover_highlight\": \"#DAA520\",\n    \"navigate_num\": \"#000000\",\n    \"navigate_text\": \"#333333\",\n    \"running_highlight\": \"#FF0000\",\n    \"selected_highlight\": \"#FFD700\",\n    \"sidebar_border\": \"#EEEEEE\",\n    \"wrapper_background\": \"#FFFFFF\"\n   },\n   \"moveMenuLeft\": true,\n   \"nav_menu\": {\n    \"height\": \"67px\",\n    \"width\": \"252px\"\n   },\n   \"navigate_menu\": true,\n   \"number_sections\": true,\n   \"sideBar\": true,\n   \"threshold\": 4,\n   \"toc_cell\": false,\n   \"toc_section_display\": \"block\",\n   \"toc_window_display\": false,\n   \"widenNotebook\": false\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 4\n}\n"
  },
  {
    "path": "VideoColorizerColab.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"colab_type\": \"text\",\n    \"id\": \"view-in-github\"\n   },\n   \"source\": [\n    \"<a href=\\\"https://colab.research.google.com/github/jantic/DeOldify/blob/master/VideoColorizerColab.ipynb\\\" target=\\\"_parent\\\"><img src=\\\"https://colab.research.google.com/assets/colab-badge.svg\\\" alt=\\\"Open In Colab\\\"/></a>\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### **<font color='blue'> Video Colorizer </font>**\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"colab_type\": \"text\",\n    \"id\": \"663IVxfrpIAb\"\n   },\n   \"source\": [\n    \"#◢ DeOldify - Colorize your own videos!\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"_FYI: This notebook is intended as a tool to colorize GIFs and short videos; if you are trying to convert a longer video, you may hit the limit on processing space. Running the Jupyter notebook on your own machine is recommended (and faster) for larger video sizes._\\n\",\n    \"\\n\",\n    \"####**Credits:**\\n\",\n    \"\\n\",\n    \"Big special thanks to:\\n\",\n    \"\\n\",\n    \"Robert Bell for all his work on the video Colab notebook, and paving the way to video in DeOldify!\\n\",\n    \"\\n\",\n    \"Dana Kelley for doing things, breaking stuff & having an opinion on everything.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"colab_type\": \"text\",\n    \"id\": \"ZjPqTBNoohK9\"\n   },\n   \"source\": [\n    \"\\n\",\n    \"\\n\",\n    \"---\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"#◢ Verify Correct Runtime Settings\\n\",\n    \"\\n\",\n    \"**<font color='#FF0000'> IMPORTANT </font>**\\n\",\n    \"\\n\",\n    \"In the \\\"Runtime\\\" menu for the notebook window, select \\\"Change runtime type.\\\" Ensure that the following are selected:\\n\",\n    \"* Runtime Type = Python 3\\n\",\n    \"* Hardware Accelerator = GPU 
\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"colab_type\": \"text\",\n    \"id\": \"gaEJBGDlptEo\"\n   },\n   \"source\": [\n    \"#◢ Git clone and install DeOldify\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {},\n    \"colab_type\": \"code\",\n    \"id\": \"-T-svuHytJ-8\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"!git clone https://github.com/jantic/DeOldify.git DeOldify\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"cd DeOldify\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"colab_type\": \"text\",\n    \"id\": \"BDFjbNxaadNJ\"\n   },\n   \"source\": [\n    \"#◢ Setup\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {},\n    \"colab_type\": \"code\",\n    \"id\": \"00_GcC_trpdE\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"#NOTE:  This must be the first call in order to work properly!\\n\",\n    \"from deoldify import device\\n\",\n    \"from deoldify.device_id import DeviceId\\n\",\n    \"#choices:  CPU, GPU0...GPU7\\n\",\n    \"device.set(device=DeviceId.GPU0)\\n\",\n    \"\\n\",\n    \"import torch\\n\",\n    \"\\n\",\n    \"if not torch.cuda.is_available():\\n\",\n    \"    print('GPU not available.')\\n\",\n    \"\\n\",\n    \"from os import path\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {},\n    \"colab_type\": \"code\",\n    \"id\": \"Lsx7xCXNSVt6\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"!pip install -r requirements-colab.txt\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {},\n    \"colab_type\": \"code\",\n    \"id\": \"MsJa69CMwj3l\"\n   },\n   \"outputs\": [],\n   \"source\": [\n   
 \"import fastai\\n\",\n    \"from deoldify.visualize import *\\n\",\n    \"from pathlib import Path\\n\",\n    \"torch.backends.cudnn.benchmark=True\\n\",\n    \"import warnings\\n\",\n    \"warnings.filterwarnings(\\\"ignore\\\", category=UserWarning, message=\\\".*?Your .*? set is empty.*?\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"!mkdir 'models'\\n\",\n    \"!wget https://data.deepai.org/deoldify/ColorizeVideo_gen.pth -O ./models/ColorizeVideo_gen.pth\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {},\n    \"colab_type\": \"code\",\n    \"id\": \"tzHVnegp21hC\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"colorizer = get_video_colorizer()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#◢ Instructions\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### source_url\\n\",\n    \"Type in a url hosting a video from YouTube, Imgur, Twitter, Reddit, Vimeo, etc.  Many sources work!  GIFs also work.  Full list here: https://ytdl-org.github.io/youtube-dl/supportedsites.html NOTE: If you want to use your own video, upload it first to a site like YouTube. \\n\",\n    \"\\n\",\n    \"### render_factor\\n\",\n    \"The default value of 21 has been carefully chosen and should work -ok- for most scenarios (but probably won't be the -best-). This determines resolution at which the color portion of the video is rendered. Lower resolution will render faster, and colors also tend to look more vibrant. Older and lower quality film in particular will generally benefit by lowering the render factor. 
Higher render factors are often better for higher quality videos and inconsistencies (flashy render) will generally be reduced, but the colors may get slightly washed out.\\n\",\n    \"\\n\",\n    \"### watermarked\\n\",\n    \"Selected by default, this places a watermark icon of a palette at the bottom left corner of the image.  This is intended to be a standard way to convey to others viewing the image that it is colorized by AI. We want to help promote this as a standard, especially as the technology continues to improve and the distinction between real and fake becomes harder to discern. This palette watermark practice was initiated and led by the company MyHeritage in the MyHeritage In Color feature (which uses a newer version of DeOldify than what you're using here).\\n\",\n    \"\\n\",\n    \"### How to Download a Copy\\n\",\n    \"Simply right-click on the displayed video and click \\\"Save video as...\\\"!\\n\",\n    \"\\n\",\n    \"## Pro Tips\\n\",\n    \"1. If a video takes a long time to render and you're wondering how well the frames will actually be colorized, you can preview how well the frames will be rendered at each render_factor by using the code at the bottom. Just stop the video rendering by hitting the stop button on the cell, then run that bottom cell under \\\"See how well render_factor values perform on a frame here\\\". It's not perfect, and you may still need to experiment a bit, especially when it comes to figuring out how to reduce frame inconsistency.  But it'll go a long way in narrowing down what actually works.\\n\",\n    \"2. If videos are taking way too much time for your liking, running the Jupyter notebook VideoColorizer.ipynb on your own machine (with DeOldify installed) will generally be much faster (as long as you have the hardware for it).   \\n\",\n    \"3. Longer videos (running multiple minutes) are going to have a rough time on Colab.
 You'll be much better off using a local install of DeOldify instead in this case.\\n\",\n    \"\\n\",\n    \"## Troubleshooting\\n\",\n    \"If the video player doesn't show up, make sure to wait for the Jupyter cell to complete processing first (the play button will stop spinning).  Then follow these alternative download instructions:\\n\",\n    \"\\n\",\n    \"1. In the menu to the left, click Files\\n\",\n    \"2. If you don't see the 'DeOldify' folder, click \\\"Refresh\\\"\\n\",\n    \"3. By default, rendered video will be in /DeOldify/video/result/\\n\",\n    \"\\n\",\n    \"If a video you downloaded doesn't play, it's probably because the cell didn't complete processing and the video is in a half-finished state.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"colab_type\": \"text\",\n    \"id\": \"sUQrbSYipiJn\"\n   },\n   \"source\": [\n    \"#◢ Colorize!!\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"source_url = '' #@param {type:\\\"string\\\"}\\n\",\n    \"render_factor = 21  #@param {type: \\\"slider\\\", min: 5, max: 40}\\n\",\n    \"watermarked = True #@param {type:\\\"boolean\\\"}\\n\",\n    \"\\n\",\n    \"if source_url is not None and source_url !='':\\n\",\n    \"    video_path = colorizer.colorize_from_url(source_url, 'video.mp4', render_factor, watermarked=watermarked)\\n\",\n    \"    show_video_in_notebook(video_path)\\n\",\n    \"else:\\n\",\n    \"    print('Provide a video url and try again.')\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## See how well render_factor values perform on a frame here\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"for i in range(10,40,2):\\n\",\n    \"
    colorizer.vis.plot_transformed_image('video/bwframes/video/00001.jpg', render_factor=i, display_render_factor=True, figsize=(8,8))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"colab_type\": \"text\",\n    \"id\": \"X7Ycv_Y9xAHp\"\n   },\n   \"source\": [\n    \"---\\n\",\n    \"#⚙ Recommended video and gif sources \\n\",\n    \"* [r/Nickelodeons](https://www.reddit.com/r/Nickelodeons/)\\n\",\n    \"* [r/silentmoviegifs](https://www.reddit.com/r/silentmoviegifs/)\\n\",\n    \"* https://twitter.com/silentmoviegifs \"\n   ]\n  }\n ],\n \"metadata\": {\n  \"accelerator\": \"GPU\",\n  \"colab\": {\n   \"collapsed_sections\": [],\n   \"name\": \"VideoColorizerColab.ipynb\",\n   \"provenance\": [],\n   \"toc_visible\": true,\n   \"version\": \"0.3.2\"\n  },\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.7.6\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 4\n}\n"
  },
  {
    "path": "deoldify/__init__.py",
    "content": "import sys\nimport logging\nlogging.getLogger().addHandler(logging.StreamHandler(sys.stdout))\nlogging.getLogger().setLevel(logging.INFO)\n\nfrom deoldify._device import _Device\n\ndevice = _Device()"
  },
  {
    "path": "deoldify/_device.py",
    "content": "import os\nfrom enum import Enum\nfrom .device_id import DeviceId\n\n#NOTE:  This must be called first before any torch imports in order to work properly!\n\nclass DeviceException(Exception):\n    pass\n\nclass _Device:\n    def __init__(self):\n        self.set(DeviceId.CPU)\n\n    def is_gpu(self):\n        ''' Returns `True` if the current device is GPU, `False` otherwise. '''\n        return self.current() is not DeviceId.CPU\n  \n    def current(self):\n        return self._current_device\n\n    def set(self, device:DeviceId):     \n        if device == DeviceId.CPU:\n            os.environ['CUDA_VISIBLE_DEVICES']=''\n        else:\n            os.environ['CUDA_VISIBLE_DEVICES']=str(device.value)\n            import torch\n            torch.backends.cudnn.benchmark=False\n        \n        self._current_device = device    \n        return device"
  },
  {
    "path": "deoldify/augs.py",
    "content": "import random\n\nfrom fastai.vision.image import TfmPixel\n\n# Contributed by Rani Horev. Thank you!\ndef _noisify(\n    x, pct_pixels_min: float = 0.001, pct_pixels_max: float = 0.4, noise_range: int = 30\n):\n    if noise_range > 255 or noise_range < 0:\n        raise Exception(\"noise_range must be between 0 and 255, inclusively.\")\n\n    h, w = x.shape[1:]\n    img_size = h * w\n    mult = 10000.0\n    pct_pixels = (\n        random.randrange(int(pct_pixels_min * mult), int(pct_pixels_max * mult)) / mult\n    )\n    noise_count = int(img_size * pct_pixels)\n\n    for ii in range(noise_count):\n        yy = random.randrange(h)\n        xx = random.randrange(w)\n        noise = random.randrange(-noise_range, noise_range) / 255.0\n        x[:, yy, xx].add_(noise)\n\n    return x\n\n\nnoisify = TfmPixel(_noisify)\n"
  },
  {
    "path": "deoldify/critics.py",
    "content": "from fastai.basic_train import Learner\nfrom fastai.core import *\nfrom fastai.layers import NormType, conv_layer\nfrom fastai.torch_core import *\nfrom fastai.vision import *\nfrom fastai.vision.data import ImageDataBunch\nfrom fastai.vision.gan import AdaptiveLoss, accuracy_thresh_expand\n\n_conv_args = dict(leaky=0.2, norm_type=NormType.Spectral)\n\n\ndef _conv(ni: int, nf: int, ks: int = 3, stride: int = 1, **kwargs):\n    return conv_layer(ni, nf, ks=ks, stride=stride, **_conv_args, **kwargs)\n\n\ndef custom_gan_critic(\n    n_channels: int = 3, nf: int = 256, n_blocks: int = 3, p: int = 0.15\n):\n    \"Critic to train a `GAN`.\"\n    layers = [_conv(n_channels, nf, ks=4, stride=2), nn.Dropout2d(p / 2)]\n    for i in range(n_blocks):\n        layers += [\n            _conv(nf, nf, ks=3, stride=1),\n            nn.Dropout2d(p),\n            _conv(nf, nf * 2, ks=4, stride=2, self_attention=(i == 0)),\n        ]\n        nf *= 2\n    layers += [\n        _conv(nf, nf, ks=3, stride=1),\n        _conv(nf, 1, ks=4, bias=False, padding=0, use_activ=False),\n        Flatten(),\n    ]\n    return nn.Sequential(*layers)\n\n\ndef colorize_crit_learner(\n    data: ImageDataBunch,\n    loss_critic=AdaptiveLoss(nn.BCEWithLogitsLoss()),\n    nf: int = 256,\n) -> Learner:\n    return Learner(\n        data,\n        custom_gan_critic(nf=nf),\n        metrics=accuracy_thresh_expand,\n        loss_func=loss_critic,\n        wd=1e-3,\n    )\n"
  },
  {
    "path": "deoldify/dataset.py",
    "content": "from fastai import *\nfrom fastai.core import *\nfrom fastai.vision.transform import get_transforms\nfrom fastai.vision.data import ImageImageList, ImageDataBunch, imagenet_stats\n\n\ndef get_colorize_data(\n    sz: int,\n    bs: int,\n    crappy_path: Path,\n    good_path: Path,\n    random_seed: int = None,\n    keep_pct: float = 1.0,\n    num_workers: int = 8,\n    stats: tuple = imagenet_stats,\n    xtra_tfms=[],\n) -> ImageDataBunch:\n    \n    src = (\n        ImageImageList.from_folder(crappy_path, convert_mode='RGB')\n        .use_partial_data(sample_pct=keep_pct, seed=random_seed)\n        .split_by_rand_pct(0.1, seed=random_seed)\n    )\n\n    data = (\n        src.label_from_func(lambda x: good_path / x.relative_to(crappy_path))\n        .transform(\n            get_transforms(\n                max_zoom=1.2, max_lighting=0.5, max_warp=0.25, xtra_tfms=xtra_tfms\n            ),\n            size=sz,\n            tfm_y=True,\n        )\n        .databunch(bs=bs, num_workers=num_workers, no_check=True)\n        .normalize(stats, do_y=True)\n    )\n\n    data.c = 3\n    return data\n\n\ndef get_dummy_databunch() -> ImageDataBunch:\n    path = Path('./dummy/')\n    return get_colorize_data(\n        sz=1, bs=1, crappy_path=path, good_path=path, keep_pct=0.001\n    )\n"
  },
  {
    "path": "deoldify/device_id.py",
    "content": "from enum import IntEnum\n\nclass DeviceId(IntEnum):\n    GPU0 = 0,\n    GPU1 = 1,\n    GPU2 = 2,\n    GPU3 = 3,\n    GPU4 = 4,\n    GPU5 = 5,\n    GPU6 = 6,\n    GPU7 = 7,\n    CPU = 99\n"
  },
  {
    "path": "deoldify/filters.py",
    "content": "from fastai.basic_data import DatasetType\nfrom fastai.basic_train import Learner\nfrom abc import ABC, abstractmethod\nfrom fastai.core import *\nfrom fastai.vision import *\nfrom fastai.vision.image import *\nfrom fastai.vision.data import *\nfrom fastai import *\nimport cv2\nfrom PIL import Image as PilImage\nfrom deoldify import device as device_settings\nimport logging\n\n\nclass IFilter(ABC):\n    @abstractmethod\n    def filter(\n        self, orig_image: PilImage, filtered_image: PilImage, render_factor: int\n    ) -> PilImage:\n        pass\n\n\nclass BaseFilter(IFilter):\n    def __init__(self, learn: Learner, stats: tuple = imagenet_stats):\n        super().__init__()\n        self.learn = learn\n        \n        if not device_settings.is_gpu():\n            self.learn.model = self.learn.model.cpu()\n        \n        self.device = next(self.learn.model.parameters()).device\n        self.norm, self.denorm = normalize_funcs(*stats)\n\n    def _transform(self, image: PilImage) -> PilImage:\n        return image\n\n    def _scale_to_square(self, orig: PilImage, targ: int) -> PilImage:\n        # a simple stretch to fit a square really makes a big difference in rendering quality/consistency.\n        # I've tried padding to the square as well (reflect, symetric, constant, etc).  
Not as good!\n        targ_sz = (targ, targ)\n        return orig.resize(targ_sz, resample=PIL.Image.BILINEAR)\n\n    def _get_model_ready_image(self, orig: PilImage, sz: int) -> PilImage:\n        result = self._scale_to_square(orig, sz)\n        result = self._transform(result)\n        return result\n\n    def _model_process(self, orig: PilImage, sz: int) -> PilImage:\n        model_image = self._get_model_ready_image(orig, sz)\n        x = pil2tensor(model_image, np.float32)\n        x = x.to(self.device)\n        x.div_(255)\n        x, y = self.norm((x, x), do_x=True)\n        \n        try:\n            result = self.learn.pred_batch(\n                ds_type=DatasetType.Valid, batch=(x[None], y[None]), reconstruct=True\n            )\n        except RuntimeError as rerr:\n            if 'memory' not in str(rerr):\n                raise rerr\n            logging.warn('Warning: render_factor was set too high, and out of memory error resulted. Returning original image.')\n            return model_image\n            \n        out = result[0]\n        out = self.denorm(out.px, do_x=False)\n        out = image2np(out * 255).astype(np.uint8)\n        return PilImage.fromarray(out)\n\n    def _unsquare(self, image: PilImage, orig: PilImage) -> PilImage:\n        targ_sz = orig.size\n        image = image.resize(targ_sz, resample=PIL.Image.BILINEAR)\n        return image\n\n\nclass ColorizerFilter(BaseFilter):\n    def __init__(self, learn: Learner, stats: tuple = imagenet_stats):\n        super().__init__(learn=learn, stats=stats)\n        self.render_base = 16\n\n    def filter(\n        self, orig_image: PilImage, filtered_image: PilImage, render_factor: int, post_process: bool = True) -> PilImage:\n        render_sz = render_factor * self.render_base\n        model_image = self._model_process(orig=filtered_image, sz=render_sz)\n        raw_color = self._unsquare(model_image, orig_image)\n\n        if post_process:\n            return 
self._post_process(raw_color, orig_image)\n        else:\n            return raw_color\n\n    def _transform(self, image: PilImage) -> PilImage:\n        return image.convert('LA').convert('RGB')\n\n    # This takes advantage of the fact that human eyes are much less sensitive to\n    # imperfections in chrominance compared to luminance.  This means we can\n    # save a lot on memory and processing in the model, yet get a great high\n    # resolution result at the end.  This is primarily intended just for\n    # inference\n    def _post_process(self, raw_color: PilImage, orig: PilImage) -> PilImage:\n        color_np = np.asarray(raw_color)\n        orig_np = np.asarray(orig)\n        color_yuv = cv2.cvtColor(color_np, cv2.COLOR_RGB2YUV)\n        # do a black and white transform first to get better luminance values\n        orig_yuv = cv2.cvtColor(orig_np, cv2.COLOR_RGB2YUV)\n        hires = np.copy(orig_yuv)\n        hires[:, :, 1:3] = color_yuv[:, :, 1:3]\n        final = cv2.cvtColor(hires, cv2.COLOR_YUV2RGB)\n        final = PilImage.fromarray(final)\n        return final\n\n\nclass MasterFilter(BaseFilter):\n    def __init__(self, filters: List[IFilter], render_factor: int):\n        self.filters = filters\n        self.render_factor = render_factor\n\n    def filter(\n        self, orig_image: PilImage, filtered_image: PilImage, render_factor: int = None, post_process: bool = True) -> PilImage:\n        render_factor = self.render_factor if render_factor is None else render_factor\n        for filter in self.filters:\n            filtered_image = filter.filter(orig_image, filtered_image, render_factor, post_process)\n\n        return filtered_image\n"
  },
  {
    "path": "deoldify/generators.py",
    "content": "from fastai.basic_data import DataBunch\nfrom fastai.basic_train import Learner\nfrom fastai.layers import NormType\nfrom fastai.torch_core import SplitFuncOrIdxList, apply_init, to_device\nfrom fastai.vision import *\nfrom fastai.vision.learner import cnn_config, create_body\nfrom torch import nn\nfrom .unet import DynamicUnetWide, DynamicUnetDeep\nfrom .dataset import *\n\n# Weights are implicitly read from ./models/ folder\ndef gen_inference_wide(\n    root_folder: Path, weights_name: str, nf_factor: int = 2, arch=models.resnet101) -> Learner:\n    data = get_dummy_databunch()\n    learn = gen_learner_wide(\n        data=data, gen_loss=F.l1_loss, nf_factor=nf_factor, arch=arch\n    )\n    learn.path = root_folder\n    learn.load(weights_name)\n    learn.model.eval()\n    return learn\n\n\ndef gen_learner_wide(\n    data: ImageDataBunch, gen_loss, arch=models.resnet101, nf_factor: int = 2\n) -> Learner:\n    return unet_learner_wide(\n        data,\n        arch=arch,\n        wd=1e-3,\n        blur=True,\n        norm_type=NormType.Spectral,\n        self_attention=True,\n        y_range=(-3.0, 3.0),\n        loss_func=gen_loss,\n        nf_factor=nf_factor,\n    )\n\n\n# The code below is meant to be merged into fastaiv1 ideally\ndef unet_learner_wide(\n    data: DataBunch,\n    arch: Callable,\n    pretrained: bool = True,\n    blur_final: bool = True,\n    norm_type: Optional[NormType] = NormType,\n    split_on: Optional[SplitFuncOrIdxList] = None,\n    blur: bool = False,\n    self_attention: bool = False,\n    y_range: Optional[Tuple[float, float]] = None,\n    last_cross: bool = True,\n    bottle: bool = False,\n    nf_factor: int = 1,\n    **kwargs: Any\n) -> Learner:\n    \"Build Unet learner from `data` and `arch`.\"\n    meta = cnn_config(arch)\n    body = create_body(arch, pretrained)\n    model = to_device(\n        DynamicUnetWide(\n            body,\n            n_classes=data.c,\n            blur=blur,\n            
blur_final=blur_final,\n            self_attention=self_attention,\n            y_range=y_range,\n            norm_type=norm_type,\n            last_cross=last_cross,\n            bottle=bottle,\n            nf_factor=nf_factor,\n        ),\n        data.device,\n    )\n    learn = Learner(data, model, **kwargs)\n    learn.split(ifnone(split_on, meta['split']))\n    if pretrained:\n        learn.freeze()\n    apply_init(model[2], nn.init.kaiming_normal_)\n    return learn\n\n\n# ----------------------------------------------------------------------\n\n# Weights are implicitly read from ./models/ folder\ndef gen_inference_deep(\n    root_folder: Path, weights_name: str, arch=models.resnet34, nf_factor: float = 1.5) -> Learner:\n    data = get_dummy_databunch()\n    learn = gen_learner_deep(\n        data=data, gen_loss=F.l1_loss, arch=arch, nf_factor=nf_factor\n    )\n    learn.path = root_folder\n    learn.load(weights_name)\n    learn.model.eval()\n    return learn\n\n\ndef gen_learner_deep(\n    data: ImageDataBunch, gen_loss, arch=models.resnet34, nf_factor: float = 1.5\n) -> Learner:\n    return unet_learner_deep(\n        data,\n        arch,\n        wd=1e-3,\n        blur=True,\n        norm_type=NormType.Spectral,\n        self_attention=True,\n        y_range=(-3.0, 3.0),\n        loss_func=gen_loss,\n        nf_factor=nf_factor,\n    )\n\n\n# The code below is meant to be merged into fastaiv1 ideally\ndef unet_learner_deep(\n    data: DataBunch,\n    arch: Callable,\n    pretrained: bool = True,\n    blur_final: bool = True,\n    norm_type: Optional[NormType] = NormType,\n    split_on: Optional[SplitFuncOrIdxList] = None,\n    blur: bool = False,\n    self_attention: bool = False,\n    y_range: Optional[Tuple[float, float]] = None,\n    last_cross: bool = True,\n    bottle: bool = False,\n    nf_factor: float = 1.5,\n    **kwargs: Any\n) -> Learner:\n    \"Build Unet learner from `data` and `arch`.\"\n    meta = cnn_config(arch)\n    body = 
create_body(arch, pretrained)\n    model = to_device(\n        DynamicUnetDeep(\n            body,\n            n_classes=data.c,\n            blur=blur,\n            blur_final=blur_final,\n            self_attention=self_attention,\n            y_range=y_range,\n            norm_type=norm_type,\n            last_cross=last_cross,\n            bottle=bottle,\n            nf_factor=nf_factor,\n        ),\n        data.device,\n    )\n    learn = Learner(data, model, **kwargs)\n    learn.split(ifnone(split_on, meta['split']))\n    if pretrained:\n        learn.freeze()\n    apply_init(model[2], nn.init.kaiming_normal_)\n    return learn\n\n\n# -----------------------------\n"
  },
  {
    "path": "deoldify/layers.py",
    "content": "from fastai.layers import *\nfrom fastai.torch_core import *\n\n\n# The code below is meant to be merged into fastaiv1 ideally\n\n\ndef custom_conv_layer(\n    ni: int,\n    nf: int,\n    ks: int = 3,\n    stride: int = 1,\n    padding: int = None,\n    bias: bool = None,\n    is_1d: bool = False,\n    norm_type: Optional[NormType] = NormType.Batch,\n    use_activ: bool = True,\n    leaky: float = None,\n    transpose: bool = False,\n    init: Callable = nn.init.kaiming_normal_,\n    self_attention: bool = False,\n    extra_bn: bool = False,\n):\n    \"Create a sequence of convolutional (`ni` to `nf`), ReLU (if `use_activ`) and batchnorm (if `bn`) layers.\"\n    if padding is None:\n        padding = (ks - 1) // 2 if not transpose else 0\n    bn = norm_type in (NormType.Batch, NormType.BatchZero) or extra_bn == True\n    if bias is None:\n        bias = not bn\n    conv_func = nn.ConvTranspose2d if transpose else nn.Conv1d if is_1d else nn.Conv2d\n    conv = init_default(\n        conv_func(ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding),\n        init,\n    )\n    if norm_type == NormType.Weight:\n        conv = weight_norm(conv)\n    elif norm_type == NormType.Spectral:\n        conv = spectral_norm(conv)\n    layers = [conv]\n    if use_activ:\n        layers.append(relu(True, leaky=leaky))\n    if bn:\n        layers.append((nn.BatchNorm1d if is_1d else nn.BatchNorm2d)(nf))\n    if self_attention:\n        layers.append(SelfAttention(nf))\n    return nn.Sequential(*layers)\n"
  },
  {
    "path": "deoldify/loss.py",
    "content": "from fastai import *\nfrom fastai.core import *\nfrom fastai.torch_core import *\nfrom fastai.callbacks import hook_outputs\nimport torchvision.models as models\n\n\nclass FeatureLoss(nn.Module):\n    def __init__(self, layer_wgts=[20, 70, 10]):\n        super().__init__()\n\n        self.m_feat = models.vgg16_bn(True).features.cuda().eval()\n        requires_grad(self.m_feat, False)\n        blocks = [\n            i - 1\n            for i, o in enumerate(children(self.m_feat))\n            if isinstance(o, nn.MaxPool2d)\n        ]\n        layer_ids = blocks[2:5]\n        self.loss_features = [self.m_feat[i] for i in layer_ids]\n        self.hooks = hook_outputs(self.loss_features, detach=False)\n        self.wgts = layer_wgts\n        self.metric_names = ['pixel'] + [f'feat_{i}' for i in range(len(layer_ids))]\n        self.base_loss = F.l1_loss\n\n    def _make_features(self, x, clone=False):\n        self.m_feat(x)\n        return [(o.clone() if clone else o) for o in self.hooks.stored]\n\n    def forward(self, input, target):\n        out_feat = self._make_features(target, clone=True)\n        in_feat = self._make_features(input)\n        self.feat_losses = [self.base_loss(input, target)]\n        self.feat_losses += [\n            self.base_loss(f_in, f_out) * w\n            for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)\n        ]\n\n        self.metrics = dict(zip(self.metric_names, self.feat_losses))\n        return sum(self.feat_losses)\n\n    def __del__(self):\n        self.hooks.remove()\n\n\n# Refactored code, originally from https://github.com/VinceMarron/style_transfer\nclass WassFeatureLoss(nn.Module):\n    def __init__(self, layer_wgts=[5, 15, 2], wass_wgts=[3.0, 0.7, 0.01]):\n        super().__init__()\n        self.m_feat = models.vgg16_bn(True).features.cuda().eval()\n        requires_grad(self.m_feat, False)\n        blocks = [\n            i - 1\n            for i, o in enumerate(children(self.m_feat))\n            
if isinstance(o, nn.MaxPool2d)\n        ]\n        layer_ids = blocks[2:5]\n        self.loss_features = [self.m_feat[i] for i in layer_ids]\n        self.hooks = hook_outputs(self.loss_features, detach=False)\n        self.wgts = layer_wgts\n        self.wass_wgts = wass_wgts\n        self.metric_names = (\n            ['pixel']\n            + [f'feat_{i}' for i in range(len(layer_ids))]\n            + [f'wass_{i}' for i in range(len(layer_ids))]\n        )\n        self.base_loss = F.l1_loss\n\n    def _make_features(self, x, clone=False):\n        self.m_feat(x)\n        return [(o.clone() if clone else o) for o in self.hooks.stored]\n\n    def _calc_2_moments(self, tensor):\n        chans = tensor.shape[1]\n        tensor = tensor.view(1, chans, -1)\n        n = tensor.shape[2]\n        mu = tensor.mean(2)\n        tensor = (tensor - mu[:, :, None]).squeeze(0)\n        # Prevents nasty bug that happens very occassionally- divide by zero.  Why such things happen?\n        if n == 0:\n            return None, None\n        cov = torch.mm(tensor, tensor.t()) / float(n)\n        return mu, cov\n\n    def _get_style_vals(self, tensor):\n        mean, cov = self._calc_2_moments(tensor)\n        if mean is None:\n            return None, None, None\n        eigvals, eigvects = torch.symeig(cov, eigenvectors=True)\n        eigroot_mat = torch.diag(torch.sqrt(eigvals.clamp(min=0)))\n        root_cov = torch.mm(torch.mm(eigvects, eigroot_mat), eigvects.t())\n        tr_cov = eigvals.clamp(min=0).sum()\n        return mean, tr_cov, root_cov\n\n    def _calc_l2wass_dist(\n        self, mean_stl, tr_cov_stl, root_cov_stl, mean_synth, cov_synth\n    ):\n        tr_cov_synth = torch.symeig(cov_synth, eigenvectors=True)[0].clamp(min=0).sum()\n        mean_diff_squared = (mean_stl - mean_synth).pow(2).sum()\n        cov_prod = torch.mm(torch.mm(root_cov_stl, cov_synth), root_cov_stl)\n        var_overlap = torch.sqrt(\n            torch.symeig(cov_prod, 
eigenvectors=True)[0].clamp(min=0) + 1e-8\n        ).sum()\n        dist = mean_diff_squared + tr_cov_stl + tr_cov_synth - 2 * var_overlap\n        return dist\n\n    def _single_wass_loss(self, pred, targ):\n        mean_test, tr_cov_test, root_cov_test = targ\n        mean_synth, cov_synth = self._calc_2_moments(pred)\n        loss = self._calc_l2wass_dist(\n            mean_test, tr_cov_test, root_cov_test, mean_synth, cov_synth\n        )\n        return loss\n\n    def forward(self, input, target):\n        out_feat = self._make_features(target, clone=True)\n        in_feat = self._make_features(input)\n        self.feat_losses = [self.base_loss(input, target)]\n        self.feat_losses += [\n            self.base_loss(f_in, f_out) * w\n            for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)\n        ]\n\n        styles = [self._get_style_vals(i) for i in out_feat]\n\n        if styles[0][0] is not None:\n            self.feat_losses += [\n                self._single_wass_loss(f_pred, f_targ) * w\n                for f_pred, f_targ, w in zip(in_feat, styles, self.wass_wgts)\n            ]\n\n        self.metrics = dict(zip(self.metric_names, self.feat_losses))\n        return sum(self.feat_losses)\n\n    def __del__(self):\n        self.hooks.remove()\n"
  },
  {
    "path": "deoldify/save.py",
    "content": "from fastai.basic_train import Learner, LearnerCallback\nfrom fastai.vision.gan import GANLearner\n\n\nclass GANSaveCallback(LearnerCallback):\n    \"\"\"A `LearnerCallback` that saves the generator learner `learn_gen` every `save_iters` iterations (skipping iteration 0) while training `learn`, under the name `{filename}_{epoch}_{iteration}`.\"\"\"\n\n    def __init__(\n        self,\n        learn: GANLearner,\n        learn_gen: Learner,\n        filename: str,\n        save_iters: int = 1000,\n    ):\n        super().__init__(learn)\n        self.learn_gen = learn_gen\n        self.filename = filename\n        self.save_iters = save_iters\n\n    def on_batch_end(self, iteration: int, epoch: int, **kwargs) -> None:\n        if iteration == 0:\n            return\n\n        if iteration % self.save_iters == 0:\n            self._save_gen_learner(iteration=iteration, epoch=epoch)\n\n    def _save_gen_learner(self, iteration: int, epoch: int):\n        filename = '{}_{}_{}'.format(self.filename, epoch, iteration)\n        self.learn_gen.save(filename)\n"
  },
  {
    "path": "deoldify/unet.py",
    "content": "from fastai.layers import *\nfrom .layers import *\nfrom fastai.torch_core import *\nfrom fastai.callbacks.hooks import *\nfrom fastai.vision import *\n\n\n# The code below is meant to be merged into fastaiv1 ideally\n\n__all__ = ['DynamicUnetDeep', 'DynamicUnetWide']\n\n\ndef _get_sfs_idxs(sizes: Sizes) -> List[int]:\n    \"Get the indexes of the layers where the size of the activation changes.\"\n    feature_szs = [size[-1] for size in sizes]\n    sfs_idxs = list(\n        np.where(np.array(feature_szs[:-1]) != np.array(feature_szs[1:]))[0]\n    )\n    if feature_szs[0] != feature_szs[1]:\n        sfs_idxs = [0] + sfs_idxs\n    return sfs_idxs\n\n\nclass CustomPixelShuffle_ICNR(nn.Module):\n    \"Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle`, `icnr` init, and `weight_norm`.\"\n\n    def __init__(\n        self,\n        ni: int,\n        nf: int = None,\n        scale: int = 2,\n        blur: bool = False,\n        leaky: float = None,\n        **kwargs\n    ):\n        super().__init__()\n        nf = ifnone(nf, ni)\n        self.conv = custom_conv_layer(\n            ni, nf * (scale ** 2), ks=1, use_activ=False, **kwargs\n        )\n        icnr(self.conv[0].weight)\n        self.shuf = nn.PixelShuffle(scale)\n        # Blurring over (h*w) kernel\n        # \"Super-Resolution using Convolutional Neural Networks without Any Checkerboard Artifacts\"\n        # - https://arxiv.org/abs/1806.02658\n        self.pad = nn.ReplicationPad2d((1, 0, 1, 0))\n        self.blur = nn.AvgPool2d(2, stride=1)\n        self.relu = relu(True, leaky=leaky)\n\n    def forward(self, x):\n        x = self.shuf(self.relu(self.conv(x)))\n        return self.blur(self.pad(x)) if self.blur else x\n\n\nclass UnetBlockDeep(nn.Module):\n    \"A quasi-UNet block, using `PixelShuffle_ICNR upsampling`.\"\n\n    def __init__(\n        self,\n        up_in_c: int,\n        x_in_c: int,\n        hook: Hook,\n        final_div: bool = True,\n  
      blur: bool = False,\n        leaky: float = None,\n        self_attention: bool = False,\n        nf_factor: float = 1.0,\n        **kwargs\n    ):\n        super().__init__()\n        self.hook = hook\n        self.shuf = CustomPixelShuffle_ICNR(\n            up_in_c, up_in_c // 2, blur=blur, leaky=leaky, **kwargs\n        )\n        self.bn = batchnorm_2d(x_in_c)\n        ni = up_in_c // 2 + x_in_c\n        nf = int((ni if final_div else ni // 2) * nf_factor)\n        self.conv1 = custom_conv_layer(ni, nf, leaky=leaky, **kwargs)\n        self.conv2 = custom_conv_layer(\n            nf, nf, leaky=leaky, self_attention=self_attention, **kwargs\n        )\n        self.relu = relu(leaky=leaky)\n\n    def forward(self, up_in: Tensor) -> Tensor:\n        s = self.hook.stored\n        up_out = self.shuf(up_in)\n        ssh = s.shape[-2:]\n        if ssh != up_out.shape[-2:]:\n            up_out = F.interpolate(up_out, s.shape[-2:], mode='nearest')\n        cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1))\n        return self.conv2(self.conv1(cat_x))\n\n\nclass DynamicUnetDeep(SequentialEx):\n    \"Create a U-Net from a given architecture.\"\n\n    def __init__(\n        self,\n        encoder: nn.Module,\n        n_classes: int,\n        blur: bool = False,\n        blur_final=True,\n        self_attention: bool = False,\n        y_range: Optional[Tuple[float, float]] = None,\n        last_cross: bool = True,\n        bottle: bool = False,\n        norm_type: Optional[NormType] = NormType.Batch,\n        nf_factor: float = 1.0,\n        **kwargs\n    ):\n        extra_bn = norm_type == NormType.Spectral\n        imsize = (256, 256)\n        sfs_szs = model_sizes(encoder, size=imsize)\n        sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))\n        self.sfs = hook_outputs([encoder[i] for i in sfs_idxs], detach=False)\n        x = dummy_eval(encoder, imsize).detach()\n\n        ni = sfs_szs[-1][1]\n        middle_conv = nn.Sequential(\n            
custom_conv_layer(\n                ni, ni * 2, norm_type=norm_type, extra_bn=extra_bn, **kwargs\n            ),\n            custom_conv_layer(\n                ni * 2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs\n            ),\n        ).eval()\n        x = middle_conv(x)\n        layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv]\n\n        for i, idx in enumerate(sfs_idxs):\n            not_final = i != len(sfs_idxs) - 1\n            up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])\n            do_blur = blur and (not_final or blur_final)\n            sa = self_attention and (i == len(sfs_idxs) - 3)\n            unet_block = UnetBlockDeep(\n                up_in_c,\n                x_in_c,\n                self.sfs[i],\n                final_div=not_final,\n                blur=blur,\n                self_attention=sa,\n                norm_type=norm_type,\n                extra_bn=extra_bn,\n                nf_factor=nf_factor,\n                **kwargs\n            ).eval()\n            layers.append(unet_block)\n            x = unet_block(x)\n\n        ni = x.shape[1]\n        if imsize != sfs_szs[0][-2:]:\n            layers.append(PixelShuffle_ICNR(ni, **kwargs))\n        if last_cross:\n            layers.append(MergeLayer(dense=True))\n            ni += in_channels(encoder)\n            layers.append(res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs))\n        layers += [\n            custom_conv_layer(ni, n_classes, ks=1, use_activ=False, norm_type=norm_type)\n        ]\n        if y_range is not None:\n            layers.append(SigmoidRange(*y_range))\n        super().__init__(*layers)\n\n    def __del__(self):\n        if hasattr(self, \"sfs\"):\n            self.sfs.remove()\n\n\n# ------------------------------------------------------\nclass UnetBlockWide(nn.Module):\n    \"A quasi-UNet block, using `PixelShuffle_ICNR upsampling`.\"\n\n    def __init__(\n        self,\n        up_in_c: int,\n        x_in_c: int,\n 
       n_out: int,\n        hook: Hook,\n        final_div: bool = True,\n        blur: bool = False,\n        leaky: float = None,\n        self_attention: bool = False,\n        **kwargs\n    ):\n        super().__init__()\n        self.hook = hook\n        up_out = x_out = n_out // 2\n        self.shuf = CustomPixelShuffle_ICNR(\n            up_in_c, up_out, blur=blur, leaky=leaky, **kwargs\n        )\n        self.bn = batchnorm_2d(x_in_c)\n        ni = up_out + x_in_c\n        self.conv = custom_conv_layer(\n            ni, x_out, leaky=leaky, self_attention=self_attention, **kwargs\n        )\n        self.relu = relu(leaky=leaky)\n\n    def forward(self, up_in: Tensor) -> Tensor:\n        s = self.hook.stored\n        up_out = self.shuf(up_in)\n        ssh = s.shape[-2:]\n        if ssh != up_out.shape[-2:]:\n            up_out = F.interpolate(up_out, s.shape[-2:], mode='nearest')\n        cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1))\n        return self.conv(cat_x)\n\n\nclass DynamicUnetWide(SequentialEx):\n    \"Create a U-Net from a given architecture.\"\n\n    def __init__(\n        self,\n        encoder: nn.Module,\n        n_classes: int,\n        blur: bool = False,\n        blur_final=True,\n        self_attention: bool = False,\n        y_range: Optional[Tuple[float, float]] = None,\n        last_cross: bool = True,\n        bottle: bool = False,\n        norm_type: Optional[NormType] = NormType.Batch,\n        nf_factor: int = 1,\n        **kwargs\n    ):\n\n        nf = 512 * nf_factor\n        extra_bn = norm_type == NormType.Spectral\n        imsize = (256, 256)\n        sfs_szs = model_sizes(encoder, size=imsize)\n        sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))\n        self.sfs = hook_outputs([encoder[i] for i in sfs_idxs], detach=False)\n        x = dummy_eval(encoder, imsize).detach()\n\n        ni = sfs_szs[-1][1]\n        middle_conv = nn.Sequential(\n            custom_conv_layer(\n                ni, ni * 2, 
norm_type=norm_type, extra_bn=extra_bn, **kwargs\n            ),\n            custom_conv_layer(\n                ni * 2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs\n            ),\n        ).eval()\n        x = middle_conv(x)\n        layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv]\n\n        for i, idx in enumerate(sfs_idxs):\n            not_final = i != len(sfs_idxs) - 1\n            up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])\n            do_blur = blur and (not_final or blur_final)\n            sa = self_attention and (i == len(sfs_idxs) - 3)\n\n            n_out = nf if not_final else nf // 2\n\n            unet_block = UnetBlockWide(\n                up_in_c,\n                x_in_c,\n                n_out,\n                self.sfs[i],\n                final_div=not_final,\n                blur=blur,\n                self_attention=sa,\n                norm_type=norm_type,\n                extra_bn=extra_bn,\n                **kwargs\n            ).eval()\n            layers.append(unet_block)\n            x = unet_block(x)\n\n        ni = x.shape[1]\n        if imsize != sfs_szs[0][-2:]:\n            layers.append(PixelShuffle_ICNR(ni, **kwargs))\n        if last_cross:\n            layers.append(MergeLayer(dense=True))\n            ni += in_channels(encoder)\n            layers.append(res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs))\n        layers += [\n            custom_conv_layer(ni, n_classes, ks=1, use_activ=False, norm_type=norm_type)\n        ]\n        if y_range is not None:\n            layers.append(SigmoidRange(*y_range))\n        super().__init__(*layers)\n\n    def __del__(self):\n        if hasattr(self, \"sfs\"):\n            self.sfs.remove()\n"
  },
  {
    "path": "deoldify/visualize.py",
    "content": "from fastai.core import *\nfrom fastai.vision import *\nfrom matplotlib.axes import Axes\nfrom .filters import IFilter, MasterFilter, ColorizerFilter\nfrom .generators import gen_inference_deep, gen_inference_wide\nfrom PIL import Image\nimport ffmpeg\nimport yt_dlp as youtube_dl\nimport gc\nimport requests\nfrom io import BytesIO\nimport base64\nfrom IPython import display as ipythondisplay\nfrom IPython.display import HTML\nfrom IPython.display import Image as ipythonimage\nimport cv2\nimport logging\n\n# adapted from https://www.pyimagesearch.com/2016/04/25/watermarking-images-with-opencv-and-python/\ndef get_watermarked(pil_image: Image) -> Image:\n    try:\n        image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)\n        (h, w) = image.shape[:2]\n        image = np.dstack([image, np.ones((h, w), dtype=\"uint8\") * 255])\n        pct = 0.05\n        full_watermark = cv2.imread(\n            './resource_images/watermark.png', cv2.IMREAD_UNCHANGED\n        )\n        (fwH, fwW) = full_watermark.shape[:2]\n        wH = int(pct * h)\n        wW = int((pct * h / fwH) * fwW)\n        watermark = cv2.resize(full_watermark, (wH, wW), interpolation=cv2.INTER_AREA)\n        overlay = np.zeros((h, w, 4), dtype=\"uint8\")\n        (wH, wW) = watermark.shape[:2]\n        overlay[h - wH - 10 : h - 10, 10 : 10 + wW] = watermark\n        # blend the two images together using transparent overlays\n        output = image.copy()\n        cv2.addWeighted(overlay, 0.5, output, 1.0, 0, output)\n        rgb_image = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)\n        final_image = Image.fromarray(rgb_image)\n        return final_image\n    except:\n        # Don't want this to crash everything, so let's just not watermark the image for now.\n        return pil_image\n\n\nclass ModelImageVisualizer:\n    def __init__(self, filter: IFilter, results_dir: str = None):\n        self.filter = filter\n        self.results_dir = None if results_dir is None else 
Path(results_dir)\n        self.results_dir.mkdir(parents=True, exist_ok=True)\n\n    def _clean_mem(self):\n        torch.cuda.empty_cache()\n        # gc.collect()\n\n    def _open_pil_image(self, path: Path) -> Image:\n        return PIL.Image.open(path).convert('RGB')\n\n    def _get_image_from_url(self, url: str) -> Image:\n        response = requests.get(url, timeout=30, headers={'user-agent':'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/62.0.3202.94 Safari/537.36'})\n        img = PIL.Image.open(BytesIO(response.content)).convert('RGB')\n        return img\n\n    def plot_transformed_image_from_url(\n        self,\n        url: str,\n        path: str = 'test_images/image.png',\n        results_dir:Path = None,\n        figsize: Tuple[int, int] = (20, 20),\n        render_factor: int = None,\n        \n        display_render_factor: bool = False,\n        compare: bool = False,\n        post_process: bool = True,\n        watermarked: bool = True,\n    ) -> Path:\n        img = self._get_image_from_url(url)\n        img.save(path)\n        return self.plot_transformed_image(\n            path=path,\n            results_dir=results_dir,\n            figsize=figsize,\n            render_factor=render_factor,\n            display_render_factor=display_render_factor,\n            compare=compare,\n            post_process = post_process,\n            watermarked=watermarked,\n        )\n\n    def plot_transformed_image(\n        self,\n        path: str,\n        results_dir:Path = None,\n        figsize: Tuple[int, int] = (20, 20),\n        render_factor: int = None,\n        display_render_factor: bool = False,\n        compare: bool = False,\n        post_process: bool = True,\n        watermarked: bool = True,\n    ) -> Path:\n        path = Path(path)\n        if results_dir is None:\n            results_dir = Path(self.results_dir)\n        result = self.get_transformed_image(\n            path, render_factor, 
post_process=post_process,watermarked=watermarked\n        )\n        orig = self._open_pil_image(path)\n        if compare:\n            self._plot_comparison(\n                figsize, render_factor, display_render_factor, orig, result\n            )\n        else:\n            self._plot_solo(figsize, render_factor, display_render_factor, result)\n\n        orig.close()\n        result_path = self._save_result_image(path, result, results_dir=results_dir)\n        result.close()\n        return result_path\n\n    def _plot_comparison(\n        self,\n        figsize: Tuple[int, int],\n        render_factor: int,\n        display_render_factor: bool,\n        orig: Image,\n        result: Image,\n    ):\n        fig, axes = plt.subplots(1, 2, figsize=figsize)\n        self._plot_image(\n            orig,\n            axes=axes[0],\n            figsize=figsize,\n            render_factor=render_factor,\n            display_render_factor=False,\n        )\n        self._plot_image(\n            result,\n            axes=axes[1],\n            figsize=figsize,\n            render_factor=render_factor,\n            display_render_factor=display_render_factor,\n        )\n\n    def _plot_solo(\n        self,\n        figsize: Tuple[int, int],\n        render_factor: int,\n        display_render_factor: bool,\n        result: Image,\n    ):\n        fig, axes = plt.subplots(1, 1, figsize=figsize)\n        self._plot_image(\n            result,\n            axes=axes,\n            figsize=figsize,\n            render_factor=render_factor,\n            display_render_factor=display_render_factor,\n        )\n\n    def _save_result_image(self, source_path: Path, image: Image, results_dir = None) -> Path:\n        if results_dir is None:\n            results_dir = Path(self.results_dir)\n        result_path = results_dir / source_path.name\n        image.save(result_path)\n        return result_path\n\n    def get_transformed_image(\n        self, path: Path, render_factor: 
int = None, post_process: bool = True,\n        watermarked: bool = True,\n    ) -> Image:\n        self._clean_mem()\n        orig_image = self._open_pil_image(path)\n        filtered_image = self.filter.filter(\n            orig_image, orig_image, render_factor=render_factor,post_process=post_process\n        )\n\n        if watermarked:\n            return get_watermarked(filtered_image)\n\n        return filtered_image\n\n    def _plot_image(\n        self,\n        image: Image,\n        render_factor: int,\n        axes: Axes = None,\n        figsize=(20, 20),\n        display_render_factor = False,\n    ):\n        if axes is None:\n            _, axes = plt.subplots(figsize=figsize)\n        axes.imshow(np.asarray(image) / 255)\n        axes.axis('off')\n        if render_factor is not None and display_render_factor:\n            plt.text(\n                10,\n                10,\n                'render_factor: ' + str(render_factor),\n                color='white',\n                backgroundcolor='black',\n            )\n\n    def _get_num_rows_columns(self, num_images: int, max_columns: int) -> Tuple[int, int]:\n        columns = min(num_images, max_columns)\n        rows = num_images // columns\n        rows = rows if rows * columns == num_images else rows + 1\n        return rows, columns\n\n\nclass VideoColorizer:\n    def __init__(self, vis: ModelImageVisualizer):\n        self.vis = vis\n        workfolder = Path('./video')\n        self.source_folder = workfolder / \"source\"\n        self.bwframes_root = workfolder / \"bwframes\"\n        self.audio_root = workfolder / \"audio\"\n        self.colorframes_root = workfolder / \"colorframes\"\n        self.result_folder = workfolder / \"result\"\n\n    def _purge_images(self, dir):\n        for f in os.listdir(dir):\n            if re.search('.*?\\.jpg', f):\n                os.remove(os.path.join(dir, f))\n\n    def _get_ffmpeg_probe(self, path:Path):\n        try:\n            probe = 
ffmpeg.probe(str(path))\n            return probe\n        except ffmpeg.Error as e:\n            logging.error(\"ffmpeg error: {0}\".format(e), exc_info=True)\n            logging.error('stdout:' + e.stdout.decode('UTF-8'))\n            logging.error('stderr:' + e.stderr.decode('UTF-8'))\n            raise e\n        except Exception as e:\n            logging.error('Failed to instantiate ffmpeg.probe.  Details: {0}'.format(e), exc_info=True)   \n            raise e\n\n    def _get_fps(self, source_path: Path) -> str:\n        probe = self._get_ffmpeg_probe(source_path)\n        stream_data = next(\n            (stream for stream in probe['streams'] if stream['codec_type'] == 'video'),\n            None,\n        )\n        return stream_data['avg_frame_rate']\n\n    def _download_video_from_url(self, source_url, source_path: Path):\n        if source_path.exists():\n            source_path.unlink()\n\n        ydl_opts = {\n            'format': 'bestvideo[ext=mp4]+bestaudio[ext=m4a]/mp4',\n            'outtmpl': str(source_path),\n            'retries': 30,\n            'fragment-retries': 30\n        }\n        with youtube_dl.YoutubeDL(ydl_opts) as ydl:\n            ydl.download([source_url])\n\n    def _extract_raw_frames(self, source_path: Path):\n        bwframes_folder = self.bwframes_root / (source_path.stem)\n        bwframe_path_template = str(bwframes_folder / '%5d.jpg')\n        bwframes_folder.mkdir(parents=True, exist_ok=True)\n        self._purge_images(bwframes_folder)\n\n        process = (\n            ffmpeg\n                .input(str(source_path))\n                .output(str(bwframe_path_template), format='image2', vcodec='mjpeg', **{'q:v':'0'})\n                .global_args('-hide_banner')\n                .global_args('-nostats')\n                .global_args('-loglevel', 'error')\n        )\n\n        try:\n            process.run()\n        except ffmpeg.Error as e:\n            logging.error(\"ffmpeg error: {0}\".format(e), 
exc_info=True)\n            logging.error('stdout:' + e.stdout.decode('UTF-8'))\n            logging.error('stderr:' + e.stderr.decode('UTF-8'))\n            raise e\n        except Exception as e:\n            logging.error('Errror while extracting raw frames from source video.  Details: {0}'.format(e), exc_info=True)   \n            raise e\n\n    def _colorize_raw_frames(\n        self, source_path: Path, render_factor: int = None, post_process: bool = True,\n        watermarked: bool = True,\n    ):\n        colorframes_folder = self.colorframes_root / (source_path.stem)\n        colorframes_folder.mkdir(parents=True, exist_ok=True)\n        self._purge_images(colorframes_folder)\n        bwframes_folder = self.bwframes_root / (source_path.stem)\n\n        for img in progress_bar(os.listdir(str(bwframes_folder))):\n            img_path = bwframes_folder / img\n\n            if os.path.isfile(str(img_path)):\n                color_image = self.vis.get_transformed_image(\n                    str(img_path), render_factor=render_factor, post_process=post_process,watermarked=watermarked\n                )\n                color_image.save(str(colorframes_folder / img))\n\n    def _build_video(self, source_path: Path) -> Path:\n        colorized_path = self.result_folder / (\n            source_path.name.replace('.mp4', '_no_audio.mp4')\n        )\n        colorframes_folder = self.colorframes_root / (source_path.stem)\n        colorframes_path_template = str(colorframes_folder / '%5d.jpg')\n        colorized_path.parent.mkdir(parents=True, exist_ok=True)\n        if colorized_path.exists():\n            colorized_path.unlink()\n        fps = self._get_fps(source_path)\n\n        process = (\n            ffmpeg \n                .input(str(colorframes_path_template), format='image2', vcodec='mjpeg', framerate=fps) \n                .output(str(colorized_path), crf=17, vcodec='libx264')\n                .global_args('-hide_banner')\n                
.global_args('-nostats')\n                .global_args('-loglevel', 'error')\n        )\n\n        try:\n            process.run()\n        except ffmpeg.Error as e:\n            logging.error(\"ffmpeg error: {0}\".format(e), exc_info=True)\n            logging.error('stdout:' + e.stdout.decode('UTF-8'))\n            logging.error('stderr:' + e.stderr.decode('UTF-8'))\n            raise e\n        except Exception as e:\n            logging.error('Errror while building output video.  Details: {0}'.format(e), exc_info=True)   \n            raise e\n\n        result_path = self.result_folder / source_path.name\n        if result_path.exists():\n            result_path.unlink()\n        # making copy of non-audio version in case adding back audio doesn't apply or fails.\n        shutil.copyfile(str(colorized_path), str(result_path))\n\n        # adding back sound here\n        audio_file = Path(str(source_path).replace('.mp4', '.aac'))\n        if audio_file.exists():\n            audio_file.unlink()\n\n        os.system(\n            'ffmpeg -y -i \"'\n            + str(source_path)\n            + '\" -vn -acodec copy \"'\n            + str(audio_file)\n            + '\"'\n            + ' -hide_banner'\n            + ' -nostats'\n            + ' -loglevel error'\n        )\n\n        if audio_file.exists():\n            os.system(\n                'ffmpeg -y -i \"'\n                + str(colorized_path)\n                + '\" -i \"'\n                + str(audio_file)\n                + '\" -shortest -c:v copy -c:a aac -b:a 256k \"'\n                + str(result_path)\n                + '\"'\n                + ' -hide_banner'\n                + ' -nostats'\n                + ' -loglevel error'\n            )\n        logging.info('Video created here: ' + str(result_path))\n        return result_path\n\n    def colorize_from_url(\n        self,\n        source_url,\n        file_name: str,\n        render_factor: int = None,\n        post_process: bool = True,\n        
watermarked: bool = True,\n\n    ) -> Path:\n        source_path = self.source_folder / file_name\n        self._download_video_from_url(source_url, source_path)\n        return self._colorize_from_path(\n            source_path, render_factor=render_factor, post_process=post_process,watermarked=watermarked\n        )\n\n    def colorize_from_file_name(\n        self, file_name: str, render_factor: int = None,  watermarked: bool = True, post_process: bool = True,\n    ) -> Path:\n        source_path = self.source_folder / file_name\n        return self._colorize_from_path(\n            source_path, render_factor=render_factor,  post_process=post_process,watermarked=watermarked\n        )\n\n    def _colorize_from_path(\n        self, source_path: Path, render_factor: int = None,  watermarked: bool = True, post_process: bool = True\n    ) -> Path:\n        if not source_path.exists():\n            raise Exception(\n                'Video at path specfied, ' + str(source_path) + ' could not be found.'\n            )\n        self._extract_raw_frames(source_path)\n        self._colorize_raw_frames(\n            source_path, render_factor=render_factor,post_process=post_process,watermarked=watermarked\n        )\n        return self._build_video(source_path)\n\n\ndef get_video_colorizer(render_factor: int = 21) -> VideoColorizer:\n    return get_stable_video_colorizer(render_factor=render_factor)\n\n\ndef get_artistic_video_colorizer(\n    root_folder: Path = Path('./'),\n    weights_name: str = 'ColorizeArtistic_gen',\n    results_dir='result_images',\n    render_factor: int = 35\n) -> VideoColorizer:\n    learn = gen_inference_deep(root_folder=root_folder, weights_name=weights_name)\n    filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)\n    vis = ModelImageVisualizer(filtr, results_dir=results_dir)\n    return VideoColorizer(vis)\n\n\ndef get_stable_video_colorizer(\n    root_folder: Path = Path('./'),\n    weights_name: str = 
'ColorizeVideo_gen',\n    results_dir='result_images',\n    render_factor: int = 21\n) -> VideoColorizer:\n    learn = gen_inference_wide(root_folder=root_folder, weights_name=weights_name)\n    filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)\n    vis = ModelImageVisualizer(filtr, results_dir=results_dir)\n    return VideoColorizer(vis)\n\n\ndef get_image_colorizer(\n    root_folder: Path = Path('./'), render_factor: int = 35, artistic: bool = True\n) -> ModelImageVisualizer:\n    if artistic:\n        return get_artistic_image_colorizer(root_folder=root_folder, render_factor=render_factor)\n    else:\n        return get_stable_image_colorizer(root_folder=root_folder, render_factor=render_factor)\n\n\ndef get_stable_image_colorizer(\n    root_folder: Path = Path('./'),\n    weights_name: str = 'ColorizeStable_gen',\n    results_dir='result_images',\n    render_factor: int = 35\n) -> ModelImageVisualizer:\n    learn = gen_inference_wide(root_folder=root_folder, weights_name=weights_name)\n    filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)\n    vis = ModelImageVisualizer(filtr, results_dir=results_dir)\n    return vis\n\n\ndef get_artistic_image_colorizer(\n    root_folder: Path = Path('./'),\n    weights_name: str = 'ColorizeArtistic_gen',\n    results_dir='result_images',\n    render_factor: int = 35\n) -> ModelImageVisualizer:\n    learn = gen_inference_deep(root_folder=root_folder, weights_name=weights_name)\n    filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)\n    vis = ModelImageVisualizer(filtr, results_dir=results_dir)\n    return vis\n\n\ndef show_image_in_notebook(image_path: Path):\n    ipythondisplay.display(ipythonimage(str(image_path)))\n\n\ndef show_video_in_notebook(video_path: Path):\n    video = io.open(video_path, 'r+b').read()\n    encoded = base64.b64encode(video)\n    ipythondisplay.display(\n        HTML(\n            data='''<video 
alt=\"test\" autoplay \n                loop controls style=\"height: 400px;\">\n                <source src=\"data:video/mp4;base64,{0}\" type=\"video/mp4\" />\n             </video>'''.format(\n                encoded.decode('ascii')\n            )\n        )\n    )\n"
  },
  {
    "path": "environment.yml",
    "content": "name: deoldify\nchannels:\n- fastai\n- conda-forge\n- defaults\ndependencies:\n- pip\n- fastai=1.0.60\n- mkl=2024.0\n- python=3.10\n- pytorch::pytorch=1.11.0\n- pytorch::torchvision\n- pytorch::torchaudio\n- tensorboardX\n- jupyterlab\n- pillow>=9.0.0\n- ipywidgets\n- ffmpeg\n- pip:\n  - ffmpeg-python\n  - opencv-python>=4.2.0.32\n  - wandb\n  - yt-dlp\n"
  },
  {
    "path": "fastai/LICENSE",
    "content": "Apache License, Version 2.0 Apache License Version 2.0, January 2004 http://www.apache.org/licenses/\n\nTERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n1. Definitions.\n\n\"License\" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.\n\n\"Licensor\" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.\n\n\"Legal Entity\" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, \"control\" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.\n\n\"You\" (or \"Your\") shall mean an individual or Legal Entity exercising permissions granted by this License.\n\n\"Source\" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.\n\n\"Object\" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.\n\n\"Work\" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).\n\n\"Derivative Works\" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. 
For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.\n\n\"Contribution\" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, \"submitted\" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as \"Not a Contribution.\"\n\n\"Contributor\" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.\n\n2. Grant of Copyright License.\n\nSubject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.\n\n3. 
Grant of Patent License.\n\nSubject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.\n\n4. Redistribution.\n\nYou may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:\n\nYou must give any other recipients of the Work or Derivative Works a copy of this License; and You must cause any modified files to carry prominent notices stating that You changed the files; and You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and If the Work includes a \"NOTICE\" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file 
distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.\n\n5. Submission of Contributions.\n\nUnless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.\n\n6. Trademarks.\n\nThis License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.\n\n7. 
Disclaimer of Warranty.\n\nUnless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.\n\n8. Limitation of Liability.\n\nIn no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.\n\n9. Accepting Warranty or Additional Liability.\n\nWhile redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.\n\n"
  },
  {
    "path": "fastai/__init__.py",
    "content": "from .version import __version__\n\n"
  },
  {
    "path": "fastai/basic_data.py",
    "content": "\"`fastai.data` loads and manages datasets with `DataBunch`\"\nfrom .torch_core import *\nfrom torch.utils.data.dataloader import default_collate\n\nDatasetType = Enum('DatasetType', 'Train Valid Test Single Fix')\n__all__ = ['DataBunch', 'DeviceDataLoader', 'DatasetType', 'load_data']\n\nold_dl_init = torch.utils.data.DataLoader.__init__\n\ndef intercept_args(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,\n                 num_workers=0, collate_fn=default_collate, pin_memory=True, drop_last=False,\n                 timeout=0, worker_init_fn=None):\n    self.init_kwargs = {'batch_size':batch_size, 'shuffle':shuffle, 'sampler':sampler, 'batch_sampler':batch_sampler,\n                        'num_workers':num_workers, 'collate_fn':collate_fn, 'pin_memory':pin_memory,\n                        'drop_last': drop_last, 'timeout':timeout, 'worker_init_fn':worker_init_fn}\n    old_dl_init(self, dataset, **self.init_kwargs)\n\ntorch.utils.data.DataLoader.__init__ = intercept_args\n\ndef DataLoader___getattr__(dl, k:str)->Any: return getattr(dl.dataset, k)\nDataLoader.__getattr__ = DataLoader___getattr__\n\ndef DataLoader___setstate__(dl, data:Any): dl.__dict__.update(data)\nDataLoader.__setstate__ = DataLoader___setstate__\n\n@dataclass\nclass DeviceDataLoader():\n    \"Bind a `DataLoader` to a `torch.device`.\"\n    dl: DataLoader\n    device: torch.device\n    tfms: List[Callable]=None\n    collate_fn: Callable=data_collate\n    def __post_init__(self):\n        self.dl.collate_fn=self.collate_fn\n        self.tfms = listify(self.tfms)\n\n    def __len__(self)->int: return len(self.dl)\n    def __getattr__(self,k:str)->Any: return getattr(self.dl, k)\n    def __setstate__(self,data:Any): self.__dict__.update(data)\n\n    @property\n    def batch_size(self):   return self.dl.batch_size\n    @batch_size.setter\n    def batch_size(self,v):\n        new_kwargs = {**self.dl.init_kwargs, 'batch_size':v, 
'collate_fn':self.collate_fn}\n        self.dl = self.dl.__class__(self.dl.dataset, **new_kwargs)\n        if hasattr(self.dl.dataset, 'bs'): self.dl.dataset.bs = v\n\n    @property\n    def num_workers(self):   return self.dl.num_workers\n    @num_workers.setter\n    def num_workers(self,v): self.dl.num_workers = v\n\n    def add_tfm(self,tfm:Callable)->None:\n        \"Add `tfm` to `self.tfms`.\"\n        self.tfms.append(tfm)\n    def remove_tfm(self,tfm:Callable)->None:\n        \"Remove `tfm` from `self.tfms`.\"\n        if tfm in self.tfms: self.tfms.remove(tfm)\n\n    def new(self, **kwargs):\n        \"Create a new copy of `self` with `kwargs` replacing current values.\"\n        new_kwargs = {**self.dl.init_kwargs, **kwargs}\n        return DeviceDataLoader(self.dl.__class__(self.dl.dataset, **new_kwargs), self.device, self.tfms,\n                                self.collate_fn)\n\n    def proc_batch(self,b:Tensor)->Tensor:\n        \"Process batch `b` of `TensorImage`.\"\n        b = to_device(b, self.device)\n        for f in listify(self.tfms): b = f(b)\n        return b\n\n    def __iter__(self):\n        \"Process and returns items from `DataLoader`.\"\n        for b in self.dl: yield self.proc_batch(b)\n\n    @classmethod\n    def create(cls, dataset:Dataset, bs:int=64, shuffle:bool=False, device:torch.device=defaults.device,\n               tfms:Collection[Callable]=tfms, num_workers:int=defaults.cpus, collate_fn:Callable=data_collate, **kwargs:Any):\n        \"Create DeviceDataLoader from `dataset` with `bs` and `shuffle`: process using `num_workers`.\"\n        return cls(DataLoader(dataset, batch_size=bs, shuffle=shuffle, num_workers=num_workers, **kwargs),\n                   device=device, tfms=tfms, collate_fn=collate_fn)\n\nclass DataBunch():\n    \"Bind `train_dl`,`valid_dl` and `test_dl` in a data object.\"\n\n    def __init__(self, train_dl:DataLoader, valid_dl:DataLoader, fix_dl:DataLoader=None, test_dl:Optional[DataLoader]=None,\n        
         device:torch.device=None, dl_tfms:Optional[Collection[Callable]]=None, path:PathOrStr='.',\n                 collate_fn:Callable=data_collate, no_check:bool=False):\n        self.dl_tfms = listify(dl_tfms)\n        self.device = defaults.device if device is None else device\n        assert not isinstance(train_dl,DeviceDataLoader)\n        def _create_dl(dl, **kwargs):\n            if dl is None: return None\n            return DeviceDataLoader(dl, self.device, self.dl_tfms, collate_fn, **kwargs)\n        self.train_dl,self.valid_dl,self.fix_dl,self.test_dl = map(_create_dl, [train_dl,valid_dl,fix_dl,test_dl])\n        if fix_dl is None: self.fix_dl = self.train_dl.new(shuffle=False, drop_last=False)\n        self.single_dl = _create_dl(DataLoader(valid_dl.dataset, batch_size=1, num_workers=0))\n        self.path = Path(path)\n        if not no_check: self.sanity_check()\n\n    def __repr__(self)->str:\n        return f'{self.__class__.__name__};\\n\\nTrain: {self.train_ds};\\n\\nValid: {self.valid_ds};\\n\\nTest: {self.test_ds}'\n\n    @staticmethod\n    def _init_ds(train_ds:Dataset, valid_ds:Dataset, test_ds:Optional[Dataset]=None):\n        # train_ds, but without training tfms\n        fix_ds = valid_ds.new(train_ds.x, train_ds.y) if hasattr(valid_ds,'new') else train_ds\n        return [o for o in (train_ds,valid_ds,fix_ds,test_ds) if o is not None]\n\n    @classmethod\n    def create(cls, train_ds:Dataset, valid_ds:Dataset, test_ds:Optional[Dataset]=None, path:PathOrStr='.', bs:int=64,\n               val_bs:int=None, num_workers:int=defaults.cpus, dl_tfms:Optional[Collection[Callable]]=None,\n               device:torch.device=None, collate_fn:Callable=data_collate, no_check:bool=False, **dl_kwargs)->'DataBunch':\n        \"Create a `DataBunch` from `train_ds`, `valid_ds` and maybe `test_ds` with a batch size of `bs`. 
Passes `**dl_kwargs` to `DataLoader()`\"\n        datasets = cls._init_ds(train_ds, valid_ds, test_ds)\n        val_bs = ifnone(val_bs, bs)\n        dls = [DataLoader(d, b, shuffle=s, drop_last=s, num_workers=num_workers, **dl_kwargs) for d,b,s in\n               zip(datasets, (bs,val_bs,val_bs,val_bs), (True,False,False,False)) if d is not None]\n        return cls(*dls, path=path, device=device, dl_tfms=dl_tfms, collate_fn=collate_fn, no_check=no_check)\n\n    def __getattr__(self,k:int)->Any: return getattr(self.train_dl, k)\n    def __setstate__(self,data:Any): self.__dict__.update(data)\n\n    def dl(self, ds_type:DatasetType=DatasetType.Valid)->DeviceDataLoader:\n        \"Returns appropriate `Dataset` for validation, training, or test (`ds_type`).\"\n        #TODO: refactor\n        return (self.train_dl if ds_type == DatasetType.Train else\n                self.test_dl if ds_type == DatasetType.Test else\n                self.valid_dl if ds_type == DatasetType.Valid else\n                self.single_dl if ds_type == DatasetType.Single else\n                self.fix_dl)\n\n    @property\n    def dls(self)->List[DeviceDataLoader]:\n        \"Returns a list of all DeviceDataLoaders. 
If you need a specific DeviceDataLoader, access via the relevant property (`train_dl`, `valid_dl`, etc) as the index of DLs in this list is not guaranteed to remain constant.\"\n        res = [self.train_dl, self.fix_dl, self.single_dl]\n        # Preserve the original ordering of Train, Valid, Fix, Single, Test Data Loaders\n        # (Unknown/not verified as of 1.0.47 whether there are other methods explicitly using DLs their list index)\n        if self.valid_dl: res.insert(1, self.valid_dl)\n        return res if not self.test_dl else res + [self.test_dl]\n\n    def add_tfm(self,tfm:Callable)->None:\n        for dl in self.dls: dl.add_tfm(tfm)\n\n    def remove_tfm(self,tfm:Callable)->None:\n        for dl in self.dls: dl.remove_tfm(tfm)\n\n    def save(self, file:PathLikeOrBinaryStream= 'data_save.pkl')->None:\n        \"Save the `DataBunch` in `self.path/file`. `file` can be file-like (file or buffer)\"\n        if not getattr(self, 'label_list', False):\n            warn(\"Serializing the `DataBunch` only works when you created it using the data block API.\")\n            return\n        try_save(self.label_list, self.path, file)\n\n    def add_test(self, items:Iterator, label:Any=None, tfms=None, tfm_y=None)->None:\n        \"Add the `items` as a test set. Pass along `label` otherwise label them with `EmptyLabel`.\"\n        self.label_list.add_test(items, label=label, tfms=tfms, tfm_y=tfm_y)\n        vdl = self.valid_dl\n        dl = DataLoader(self.label_list.test, vdl.batch_size, shuffle=False, drop_last=False, num_workers=vdl.num_workers)\n        self.test_dl = DeviceDataLoader(dl, vdl.device, vdl.tfms, vdl.collate_fn)\n\n    def one_batch(self, ds_type:DatasetType=DatasetType.Train, detach:bool=True, denorm:bool=True, cpu:bool=True)->Collection[Tensor]:\n        \"Get one batch from the data loader of `ds_type`. 
Optionally `detach` and `denorm`.\"\n        dl = self.dl(ds_type)\n        w = self.num_workers\n        self.num_workers = 0\n        try:     x,y = next(iter(dl))\n        finally: self.num_workers = w\n        if detach: x,y = to_detach(x,cpu=cpu),to_detach(y,cpu=cpu)\n        norm = getattr(self,'norm',False)\n        if denorm and norm:\n            x = self.denorm(x)\n            if norm.keywords.get('do_y',False): y = self.denorm(y, do_x=True)\n        return x,y\n\n    def one_item(self, item, detach:bool=False, denorm:bool=False, cpu:bool=False):\n        \"Get `item` into a batch. Optionally `detach` and `denorm`.\"\n        ds = self.single_ds\n        with ds.set_item(item):\n            return self.one_batch(ds_type=DatasetType.Single, detach=detach, denorm=denorm, cpu=cpu)\n\n    def show_batch(self, rows:int=5, ds_type:DatasetType=DatasetType.Train, reverse:bool=False, **kwargs)->None:\n        \"Show a batch of data in `ds_type` on a few `rows`.\"\n        x,y = self.one_batch(ds_type, True, True)\n        if reverse: x,y = x.flip(0),y.flip(0)\n        n_items = rows **2 if self.train_ds.x._square_show else rows\n        if self.dl(ds_type).batch_size < n_items: n_items = self.dl(ds_type).batch_size\n        xs = [self.train_ds.x.reconstruct(grab_idx(x, i)) for i in range(n_items)]\n        #TODO: get rid of has_arg if possible\n        if has_arg(self.train_ds.y.reconstruct, 'x'):\n            ys = [self.train_ds.y.reconstruct(grab_idx(y, i), x=x) for i,x in enumerate(xs)]\n        else : ys = [self.train_ds.y.reconstruct(grab_idx(y, i)) for i in range(n_items)]\n        self.train_ds.x.show_xys(xs, ys, **kwargs)\n \n    def export(self, file:PathLikeOrBinaryStream='export.pkl'):\n        \"Export the minimal state of `self` for inference in `self.path/file`. 
`file` can be file-like (file or buffer)\"\n        xtra = dict(normalize=self.norm.keywords) if getattr(self, 'norm', False) else {}\n        try_save(self.valid_ds.get_state(**xtra), self.path, file)\n\n    def _grab_dataset(self, dl:DataLoader):\n        ds = dl.dl.dataset\n        while hasattr(ds, 'dataset'): ds = ds.dataset\n        return ds\n\n    @property\n    def train_ds(self)->Dataset: return self._grab_dataset(self.train_dl)\n    @property\n    def valid_ds(self)->Dataset: return self._grab_dataset(self.valid_dl)\n    @property\n    def single_ds(self)->Dataset: return self._grab_dataset(self.single_dl)\n    @property\n    def loss_func(self)->OptLossFunc:\n        return getattr(self.train_ds.y, 'loss_func', F.nll_loss) if hasattr(self.train_ds, 'y') else F.nll_loss\n\n    @property\n    def test_ds(self)->Dataset:\n        return self._grab_dataset(self.test_dl) if self.test_dl is not None else None\n\n    @property\n    def empty_val(self)->bool:\n        if not hasattr(self, 'valid_dl') or self.valid_dl is None:            return True\n        if hasattr(self.valid_ds, 'items') and len(self.valid_ds.items) == 0: return True\n        return (len(self.valid_ds) == 0)\n\n    @property\n    def is_empty(self)->bool:\n        return not ((self.train_dl and len(self.train_ds.items) != 0) or \n                    (self.valid_dl and len(self.valid_ds.items) != 0) or \n                    (self.test_dl  and len(self.test_ds.items)  != 0))\n    \n    @property\n    def batch_size(self):   return self.train_dl.batch_size\n    @batch_size.setter\n    def batch_size(self,v):\n        self.train_dl.batch_size,self.valid_dl.batch_size = v,v\n        if self.test_dl is not None: self.test_dl.batch_size = v\n\n    def sanity_check(self):\n        \"Check the underlying data in the training set can be properly loaded.\"\n        final_message = \"You can deactivate this warning by passing `no_check=True`.\"\n        if not hasattr(self.train_ds, 'items') or 
len(self.train_ds.items) == 0 or not hasattr(self.train_dl, 'batch_sampler'): return\n        if len(self.train_dl) == 0:\n            warn(f\"\"\"Your training dataloader is empty, you have only {len(self.train_dl.dataset)} items in your training set.\n                 Your batch size is {self.train_dl.batch_size}, you should lower it.\"\"\")\n            print(final_message)\n            return\n        idx = next(iter(self.train_dl.batch_sampler))\n        samples,fails = [],[]\n        for i in idx:\n            try:    samples.append(self.train_dl.dataset[i])\n            except: fails.append(i)\n        if len(fails) > 0:\n            warn_msg = \"There seems to be something wrong with your dataset, for example, in the first batch can't access\"\n            if len(fails) == len(idx):\n                warn_msg += f\" any element of self.train_ds.\\nTried: {show_some(idx)}\"\n            else:\n                warn_msg += f\" these elements in self.train_ds: {show_some(fails)}\"\n            warn(warn_msg)\n            print(final_message)\n            return\n        try: batch = self.collate_fn(samples)\n        except:\n            message = \"It's not possible to collate samples of your dataset together in a batch.\"\n            try:\n                shapes = [[o[i].data.shape for o in samples] for i in range(2)]\n                message += f'\\nShapes of the inputs/targets:\\n{shapes}'\n            except: pass\n            warn(message)\n            print(final_message)\n\ndef load_data(path:PathOrStr, file:PathLikeOrBinaryStream='data_save.pkl', bs:int=64, val_bs:int=None, num_workers:int=defaults.cpus,\n              dl_tfms:Optional[Collection[Callable]]=None, device:torch.device=None, collate_fn:Callable=data_collate,\n              no_check:bool=False, **kwargs)->DataBunch:\n    \"Load a saved `DataBunch` from `path/file`. 
`file` can be file-like (file or buffer)\"\n    source = Path(path)/file if is_pathlike(file) else file\n    ll = torch.load(source, map_location='cpu') if defaults.device == torch.device('cpu') else torch.load(source)\n    return ll.databunch(path=path, bs=bs, val_bs=val_bs, num_workers=num_workers, dl_tfms=dl_tfms, device=device,\n                        collate_fn=collate_fn, no_check=no_check, **kwargs)\n"
  },
  {
    "path": "fastai/basic_train.py",
    "content": "\"Provides basic training and validation with `Learner`\"\nfrom .torch_core import *\nfrom .basic_data import *\nfrom .callback import *\nfrom .data_block import *\nfrom .utils.ipython import gpu_mem_restore\nimport inspect\nfrom fastprogress.fastprogress import format_time, IN_NOTEBOOK\nfrom time import time\nfrom fastai.sixel import plot_sixel\n\n__all__ = ['Learner', 'LearnerCallback', 'Recorder', 'RecordOnCPU', 'fit', 'loss_batch', 'train_epoch', 'validate',\n           'get_preds', 'load_learner']\n\ndefaults.lr = slice(3e-3)\ndefaults.wd = 1e-2\ndefaults.extra_callbacks    = None\ndefaults.extra_callback_fns = None\n\ndef loss_batch(model:nn.Module, xb:Tensor, yb:Tensor, loss_func:OptLossFunc=None, opt:OptOptimizer=None,\n               cb_handler:Optional[CallbackHandler]=None, count:[int]=[1], batch_multiplier:int=1)->Tuple[Union[Tensor,int,float,str]]:\n    \"Calculate loss and metrics for a batch, call out to callbacks as necessary.\"\n    cb_handler = ifnone(cb_handler, CallbackHandler())\n    if not is_listy(xb): xb = [xb]\n    if not is_listy(yb): yb = [yb]\n    out = model(*xb)\n\n    if not loss_func: return to_detach(out), yb[0].detach()\n    out = cb_handler.on_loss_begin(out)\n    loss = loss_func(out, *yb)/batch_multiplier\n    count[0]-=1\n\n    if opt is not None:\n        loss,skip_bwd = cb_handler.on_backward_begin(loss)\n        if not skip_bwd:                     loss.backward()\n        if count[0] == 0:\n            if not cb_handler.on_backward_end(): opt.step()\n            if not cb_handler.on_step_end():     opt.zero_grad()\n            count[0] = batch_multiplier\n\n    return loss.detach().cpu()\n\ndef get_preds(model:nn.Module, dl:DataLoader, pbar:Optional[PBar]=None, cb_handler:Optional[CallbackHandler]=None,\n              activ:nn.Module=None, loss_func:OptLossFunc=None, n_batch:Optional[int]=None) -> List[Tensor]:\n    \"Tuple of predictions and targets, and optional losses (if `loss_func`) using `dl`, max 
batches `n_batch`.\"\n    res = [torch.cat(o).cpu() for o in\n           zip(*validate(model, dl, cb_handler=cb_handler, pbar=pbar, average=False, n_batch=n_batch))]\n    if loss_func is not None:\n        with NoneReduceOnCPU(loss_func) as lf: res.append(lf(res[0], res[1]))\n    if activ is not None: res[0] = activ(res[0])\n    return res\n\ndef validate(model:nn.Module, dl:DataLoader, loss_func:OptLossFunc=None, cb_handler:Optional[CallbackHandler]=None,\n             pbar:Optional[PBar]=None, average=True, n_batch:Optional[int]=None)->Iterator[Tuple[Union[Tensor,int],...]]:\n    \"Calculate `loss_func` of `model` on `dl` in evaluation mode.\"\n    model.eval()\n    with torch.no_grad():\n        val_losses,nums = [],[]\n        if cb_handler: cb_handler.set_dl(dl)\n        for xb,yb in progress_bar(dl, parent=pbar, leave=(pbar is not None)):\n            if cb_handler: xb, yb = cb_handler.on_batch_begin(xb, yb, train=False)\n            val_loss = loss_batch(model, xb, yb, loss_func, cb_handler=cb_handler)\n            val_losses.append(val_loss)\n            if not is_listy(yb): yb = [yb]\n            nums.append(first_el(yb).shape[0])\n            if cb_handler and cb_handler.on_batch_end(val_losses[-1]): break\n            if n_batch and (len(nums)>=n_batch): break\n        nums = np.array(nums, dtype=np.float32)\n        if average: return (to_np(torch.stack(val_losses)) * nums).sum() / nums.sum()\n        else:       return val_losses\n\ndef train_epoch(model:nn.Module, dl:DataLoader, opt:optim.Optimizer, loss_func:LossFunction)->None:\n    \"Simple training of `model` for 1 epoch of `dl` using optim `opt` and loss function `loss_func`.\"\n    model.train()\n    for xb,yb in dl:\n        loss = loss_func(model(xb), yb)\n        loss.backward()\n        opt.step()\n        opt.zero_grad()\n\n@dataclass\nclass BasicLearner():\n    model:nn.Module\n    loss_func:LossFunction\n    opt:optim.Optimizer\n    data:DataBunch\n\ndef fit(epochs:int, 
learn:BasicLearner, callbacks:Optional[CallbackList]=None, metrics:OptMetrics=None, batch_multiplier:int=1)->None:\n    \"Fit the `model` on `data` and learn using `loss_func` and `opt`.\"\n    assert len(learn.data.train_dl) != 0, f\"\"\"Your training dataloader is empty, can't train a model.\n        Use a smaller batch size (batch size={learn.data.train_dl.batch_size} for {len(learn.data.train_dl.dataset)} elements).\"\"\"\n    cb_handler = CallbackHandler(callbacks, metrics)\n    pbar = master_bar(range(epochs))\n    cb_handler.on_train_begin(epochs, pbar=pbar, metrics=metrics)\n\n    exception=False\n    try:\n        for epoch in pbar:\n            learn.model.train()\n            cb_handler.set_dl(learn.data.train_dl)\n            cb_handler.on_epoch_begin()\n            count = [batch_multiplier]\n            for xb,yb in progress_bar(learn.data.train_dl, parent=pbar):\n                xb, yb = cb_handler.on_batch_begin(xb, yb)\n                loss = loss_batch(learn.model, xb, yb, learn.loss_func, learn.opt, cb_handler, count=count, batch_multiplier=batch_multiplier)\n                if cb_handler.on_batch_end(loss): break\n\n            if not cb_handler.skip_validate and not learn.data.empty_val:\n                val_loss = validate(learn.model, learn.data.valid_dl, loss_func=learn.loss_func,\n                                       cb_handler=cb_handler, pbar=pbar)\n            else: val_loss=None\n            if cb_handler.on_epoch_end(val_loss): break\n    except Exception as e:\n        exception = e\n        raise\n    finally: cb_handler.on_train_end(exception)\n\nloss_func_name2activ = {'cross_entropy_loss': F.softmax, 'nll_loss': torch.exp, 'poisson_nll_loss': torch.exp,\n    'kl_div_loss': torch.exp, 'bce_with_logits_loss': torch.sigmoid, 'cross_entropy': F.softmax,\n    'kl_div': torch.exp, 'binary_cross_entropy_with_logits': torch.sigmoid,\n}\n\ndef _loss_func_name2activ(name:str, axis:int=-1):\n    res = loss_func_name2activ[name]\n    if res 
== F.softmax: res = partial(F.softmax, dim=axis)\n    return res\n\ndef _loss_func2activ(loss_func):\n    if getattr(loss_func,'keywords',None):\n        if not loss_func.keywords.get('log_input', True): return\n    axis = getattr(loss_func, 'axis', -1)\n    # flattened loss\n    loss_func = getattr(loss_func, 'func', loss_func)\n    # could have a partial inside flattened loss! Duplicate on purpose.\n    loss_func = getattr(loss_func, 'func', loss_func)\n    cls_name = camel2snake(loss_func.__class__.__name__)\n    if cls_name == 'mix_up_loss':\n        loss_func = loss_func.crit\n        cls_name = camel2snake(loss_func.__class__.__name__)\n    if cls_name in loss_func_name2activ:\n        if cls_name == 'poisson_nll_loss' and (not getattr(loss_func, 'log_input', True)): return\n        return _loss_func_name2activ(cls_name, axis)\n    if getattr(loss_func,'__name__','') in loss_func_name2activ:\n        return _loss_func_name2activ(loss_func.__name__, axis)\n    return noop\n\n@dataclass\nclass Learner():\n    \"Trainer for `model` using `data` to minimize `loss_func` with optimizer `opt_func`.\"\n    data:DataBunch\n    model:nn.Module\n    opt_func:Callable=AdamW\n    loss_func:Callable=None\n    metrics:Collection[Callable]=None\n    true_wd:bool=True\n    bn_wd:bool=True\n    wd:Floats=defaults.wd\n    train_bn:bool=True\n    path:str = None\n    model_dir:PathOrStr = 'models'\n    callback_fns:Collection[Callable]=None\n    callbacks:Collection[Callback]=field(default_factory=list)\n    layer_groups:Collection[nn.Module]=None\n    add_time:bool=True\n    silent:bool=None\n    def __post_init__(self)->None:\n        \"Setup path,metrics, callbacks and ensure model directory exists.\"\n        self.path = Path(ifnone(self.path, self.data.path))\n        self.model = self.model.to(self.data.device)\n        self.loss_func = self.loss_func or self.data.loss_func\n        self.metrics=listify(self.metrics)\n        if not self.layer_groups: self.layer_groups = 
[nn.Sequential(*flatten_model(self.model))]\n        self.callbacks = listify(self.callbacks)\n        if self.silent is None: self.silent = defaults.silent\n        self.callback_fns = [partial(Recorder, add_time=self.add_time, silent=self.silent)] + listify(self.callback_fns)\n\n    def init(self, init): apply_init(self.model, init)\n\n    def _test_writeable_path(self):\n        path = self.path/self.model_dir\n        try:\n            path.mkdir(parents=True, exist_ok=True)\n            tmp_file = get_tmp_file(path)\n        except OSError as e:\n            raise Exception(f\"{e}\\nCan't write to '{path}', set `learn.model_dir` attribute in Learner to a full libpath path that is writable\") from None\n        os.remove(tmp_file)\n\n    def lr_range(self, lr:Union[float,slice])->np.ndarray:\n        \"Build differential learning rates from `lr`.\"\n        if not isinstance(lr,slice): return lr\n        if lr.start: res = even_mults(lr.start, lr.stop, len(self.layer_groups))\n        else: res = [lr.stop/10]*(len(self.layer_groups)-1) + [lr.stop]\n        return np.array(res)\n\n    def fit(self, epochs:int, lr:Union[Floats,slice]=defaults.lr,\n            wd:Floats=None, callbacks:Collection[Callback]=None, batch_multiplier:int=1)->None:\n        \"Fit the model on this learner with `lr` learning rate, `wd` weight decay for `epochs` with `callbacks`.\"\n        lr = self.lr_range(lr)\n        if wd is None: wd = self.wd\n        if not getattr(self, 'opt', False): self.create_opt(lr, wd)\n        else: self.opt.lr,self.opt.wd = lr,wd\n        callbacks = [cb(self) for cb in self.callback_fns + listify(defaults.extra_callback_fns)] + listify(callbacks)\n        if defaults.extra_callbacks is not None: callbacks += defaults.extra_callbacks\n        fit(epochs, self, metrics=self.metrics, callbacks=self.callbacks+callbacks, batch_multiplier=batch_multiplier)\n\n    def create_opt(self, lr:Floats, wd:Floats=0.)->None:\n        \"Create optimizer with `lr` 
learning rate and `wd` weight decay.\"\n        self.opt = OptimWrapper.create(self.opt_func, lr, self.layer_groups, wd=wd, true_wd=self.true_wd, bn_wd=self.bn_wd)\n\n    def split(self, split_on:SplitFuncOrIdxList)->None:\n        \"Split the model at `split_on`.\"\n        if isinstance(split_on,Callable): split_on = split_on(self.model)\n        self.layer_groups = split_model(self.model, split_on)\n        return self\n\n    def freeze_to(self, n:int)->None:\n        \"Freeze layers up to layer group `n`.\"\n        for g in self.layer_groups[:n]:\n            for l in g:\n                if not self.train_bn or not isinstance(l, bn_types): requires_grad(l, False)\n        for g in self.layer_groups[n:]: requires_grad(g, True)\n        self.create_opt(defaults.lr)\n\n    def freeze(self)->None:\n        \"Freeze up to last layer group.\"\n        assert(len(self.layer_groups)>1)\n        self.freeze_to(-1)\n\n    def unfreeze(self):\n        \"Unfreeze entire model.\"\n        self.freeze_to(0)\n\n    def export(self, file:PathLikeOrBinaryStream='export.pkl', destroy=False):\n        \"Export the state of the `Learner` in `self.path/file`. 
`file` can be file-like (file or buffer)\"\n        if rank_distrib(): return # don't save if slave proc\n        args = ['opt_func', 'loss_func', 'metrics', 'true_wd', 'bn_wd', 'wd', 'train_bn', 'model_dir', 'callback_fns']\n        state = {a:getattr(self,a) for a in args}\n        state['cb_state'] = {cb.__class__:cb.get_state() for cb in self.callbacks}\n        #layer_groups -> need to find a way\n        #TO SEE: do we save model structure and weights separately?\n        with ModelOnCPU(self.model) as m:\n            state['model'] = m\n            xtra = dict(normalize=self.data.norm.keywords) if getattr(self.data, 'norm', False) else {}\n            state['data'] = self.data.valid_ds.get_state(**xtra)\n            state['cls'] = self.__class__\n            try_save(state, self.path, file)\n        if destroy: self.destroy()\n\n    def save(self, file:PathLikeOrBinaryStream=None, return_path:bool=False, with_opt:bool=True):\n        \"Save model and optimizer state (if `with_opt`) with `file` to `self.model_dir`. 
`file` can be file-like (file or buffer)\"\n        if is_pathlike(file): self._test_writeable_path()\n        if rank_distrib(): return # don't save if slave proc\n        target = self.path/self.model_dir/f'{file}.pth' if is_pathlike(file) else file\n        if not hasattr(self, 'opt'): with_opt=False\n        if not with_opt: state = get_model(self.model).state_dict()\n        else: state = {'model': get_model(self.model).state_dict(), 'opt':self.opt.state_dict()}\n        torch.save(state, target)\n        if return_path: return target\n\n    def dl(self, ds_type:DatasetType=DatasetType.Valid):\n        \"Return DataLoader for DatasetType `ds_type`.\"\n        return self.data.dl(ds_type)\n\n    def load(self, file:PathLikeOrBinaryStream=None, device:torch.device=None, strict:bool=True,\n             with_opt:bool=None, purge:bool=True, remove_module:bool=False):\n        \"Load model and optimizer state (if `with_opt`) `file` from `self.model_dir` using `device`. `file` can be file-like (file or buffer)\"\n        if purge: self.purge(clear_opt=ifnone(with_opt, False))\n        if device is None: device = self.data.device\n        elif isinstance(device, int): device = torch.device('cuda', device)\n        source = self.path/self.model_dir/f'{file}.pth' if is_pathlike(file) else file\n        state = torch.load(source, map_location=device)\n        if set(state.keys()) == {'model', 'opt'}:\n            model_state = state['model']\n            if remove_module: model_state = remove_module_load(model_state)\n            get_model(self.model).load_state_dict(model_state, strict=strict)\n            if ifnone(with_opt,True):\n                if not hasattr(self, 'opt'): self.create_opt(defaults.lr, self.wd)\n                try:    self.opt.load_state_dict(state['opt'])\n                except: pass\n        else:\n            if with_opt: warn(\"Saved filed doesn't contain an optimizer state.\")\n            if remove_module: state = remove_module_load(state)\n  
          get_model(self.model).load_state_dict(state, strict=strict)\n        del state\n        gc.collect()\n        return self\n\n    def destroy(self):\n        \"Free the Learner internals, leaving just an empty shell that consumes no memory\"\n\n        class ZombieLearner(Learner):\n            msg = \"this object has been destroyed\"\n            def __getattr__(self, item):    print(ZombieLearner.msg); return None\n            def destroyed(*args, **kwargs): print(ZombieLearner.msg)\n\n        attrs = [k for k in self.__dict__.keys() if not k.startswith(\"__\")]\n        for a in attrs: delattr(self, a)\n        # the instance methods can still be called, but will just give a message\n        methods = [k for k in dir(self) if not k.startswith(\"__\") and inspect.isroutine(getattr(self, k))]\n        for m in methods: setattr(self, m, ZombieLearner.destroyed)\n        self.__class__ = ZombieLearner\n        gc.collect()\n        print(\"this Learner object self-destroyed - it still exists, but no longer usable\")\n\n    def purge(self, clear_opt:bool=True):\n        \"Purge the `Learner` of all cached attributes to release some GPU memory.\"\n        self._test_writeable_path()\n        attrs_all = [k for k in self.__dict__.keys() if not k.startswith(\"__\")]\n        attrs_pkl = ['bn_wd', 'callback_fns', 'layer_groups', 'loss_func', 'metrics', 'model',\n                     'model_dir', 'opt_func', 'path', 'train_bn', 'true_wd', 'wd']\n        # +callbacks: get pickled too, but not directly\n        attrs_keep = ['data', 'recorder']\n        attrs_del = list(set(attrs_all) - set(attrs_keep))\n        state = {a:getattr(self, a) for a in attrs_pkl}\n        state['cb_state'] = {cb.__class__:cb.get_state() for cb in self.callbacks}\n        if hasattr(self, 'opt'): state['opt'] = self.opt.get_state()\n\n        tmp_file = get_tmp_file(self.path/self.model_dir)\n        torch.save(state, open(tmp_file, 'wb'))\n        for a in attrs_del: delattr(self, a)\n 
       gc.collect()\n        state = torch.load(tmp_file)\n        os.remove(tmp_file)\n\n        for a in attrs_pkl: setattr(self, a, state[a])\n        cb_state = state.pop('cb_state')\n        self.callbacks = [load_callback(c,s, self) for c,s in cb_state.items()]\n        if not clear_opt and 'opt' in state:\n            try: self.opt = OptimWrapper.load_with_state_and_layer_group(state['opt'], self.layer_groups)\n            except: warn(\"Wasn't able to properly load the optimizer state again.\")\n        del state\n        gc.collect()\n        return self\n\n    def get_preds(self, ds_type:DatasetType=DatasetType.Valid, with_loss:bool=False, n_batch:Optional[int]=None,\n                  pbar:Optional[PBar]=None) -> List[Tensor]:\n        \"Return predictions and targets on `ds_type` dataset.\"\n        lf = self.loss_func if with_loss else None\n        return get_preds(self.model, self.dl(ds_type), cb_handler=CallbackHandler(self.callbacks),\n                         activ=_loss_func2activ(self.loss_func), loss_func=lf, n_batch=n_batch, pbar=pbar)\n\n    def pred_batch(self, ds_type:DatasetType=DatasetType.Valid, batch:Tuple=None, reconstruct:bool=False, with_dropout:bool=False) -> List[Tensor]:\n        with torch.no_grad():\n            training = self.model.training\n            self.model.train(False)\n            \"Return output of the model on one batch from `ds_type` dataset.\"\n            if batch is not None: xb,yb = batch\n            else: xb,yb = self.data.one_batch(ds_type, detach=False, denorm=False)\n            cb_handler = CallbackHandler(self.callbacks)\n            xb,yb = cb_handler.on_batch_begin(xb,yb, train=False)\n            if not with_dropout: \n                preds = loss_batch(self.model.eval(), xb, yb, cb_handler=cb_handler)\n            else: \n                preds = loss_batch(self.model.eval().apply(self.apply_dropout), xb, yb, cb_handler=cb_handler)\n            res = _loss_func2activ(self.loss_func)(preds[0])\n        
    self.model.train(training)\n            if not reconstruct: return res\n            res = res.detach().cpu()\n            ds = self.dl(ds_type).dataset\n            norm = getattr(self.data, 'norm', False)\n            if norm and norm.keywords.get('do_y',False):\n                res = self.data.denorm(res, do_x=True)\n            return [ds.reconstruct(o) for o in res]\n\n    def backward(self, item):\n        \"Pass `item` through the model and computes the gradient. Useful if `backward_hooks` are attached.\"\n        xb,yb = self.data.one_item(item)\n        loss = loss_batch(self.model.eval(), xb, yb, self.loss_func, opt=FakeOptimizer(),\n                          cb_handler=CallbackHandler(self.callbacks))\n        return loss\n\n    def predict(self, item:ItemBase, return_x:bool=False, batch_first:bool=True, with_dropout:bool=False, **kwargs):\n        \"Return predicted class, label and probabilities for `item`.\"\n        batch = self.data.one_item(item)\n        res = self.pred_batch(batch=batch, with_dropout=with_dropout)\n        raw_pred,x = grab_idx(res,0,batch_first=batch_first),batch[0]\n        norm = getattr(self.data,'norm',False)\n        if norm:\n            x = self.data.denorm(x)\n            if norm.keywords.get('do_y',False): raw_pred = self.data.denorm(raw_pred)\n        ds = self.data.single_ds\n        pred = ds.y.analyze_pred(raw_pred, **kwargs)\n        x = ds.x.reconstruct(grab_idx(x, 0))\n        y = ds.y.reconstruct(pred, x) if has_arg(ds.y.reconstruct, 'x') else ds.y.reconstruct(pred)\n        return (x, y, pred, raw_pred) if return_x else (y, pred, raw_pred)\n\n    def validate(self, dl=None, callbacks=None, metrics=None):\n        \"Validate on `dl` with potential `callbacks` and `metrics`.\"\n        dl = ifnone(dl, self.data.valid_dl)\n        metrics = ifnone(metrics, self.metrics)\n        cb_handler = CallbackHandler(self.callbacks + ifnone(callbacks, []), metrics)\n        cb_handler.on_epoch_begin()\n        
val_metrics = validate(self.model, dl, self.loss_func, cb_handler)\n        cb_handler.on_epoch_end(val_metrics)\n        return cb_handler.state_dict['last_metrics']\n\n    def show_results(self, ds_type=DatasetType.Valid, rows:int=5, **kwargs):\n        \"Show `rows` result of predictions on `ds_type` dataset.\"\n        #TODO: get read of has_arg x and split_kwargs_by_func if possible\n        #TODO: simplify this and refactor with pred_batch(...reconstruct=True)\n        n_items = rows ** 2 if self.data.train_ds.x._square_show_res else rows\n        if self.dl(ds_type).batch_size < n_items: n_items = self.dl(ds_type).batch_size\n        ds = self.dl(ds_type).dataset\n        self.callbacks.append(RecordOnCPU())\n        preds = self.pred_batch(ds_type)\n        *self.callbacks,rec_cpu = self.callbacks\n        x,y = rec_cpu.input,rec_cpu.target\n        norm = getattr(self.data,'norm',False)\n        if norm:\n            x = self.data.denorm(x)\n            if norm.keywords.get('do_y',False):\n                y     = self.data.denorm(y, do_x=True)\n                preds = self.data.denorm(preds, do_x=True)\n        analyze_kwargs,kwargs = split_kwargs_by_func(kwargs, ds.y.analyze_pred)\n        preds = [ds.y.analyze_pred(grab_idx(preds, i), **analyze_kwargs) for i in range(n_items)]\n        xs = [ds.x.reconstruct(grab_idx(x, i)) for i in range(n_items)]\n        if has_arg(ds.y.reconstruct, 'x'):\n            ys = [ds.y.reconstruct(grab_idx(y, i), x=x) for i,x in enumerate(xs)]\n            zs = [ds.y.reconstruct(z, x=x) for z,x in zip(preds,xs)]\n        else :\n            ys = [ds.y.reconstruct(grab_idx(y, i)) for i in range(n_items)]\n            zs = [ds.y.reconstruct(z) for z in preds]\n        ds.x.show_xyzs(xs, ys, zs, **kwargs)\n\n    def apply_dropout(self, m):\n        \"If a module contains 'dropout' in it's name, it will be switched to .train() mode.\"\n        if 'dropout' in m.__class__.__name__.lower(): m.train()\n\n    def 
predict_with_mc_dropout(self, item:ItemBase, with_dropout:bool=True, n_times=10, **kwargs):\n        \"Make predictions with dropout turned on for n_times (default 10).\"\n        return [self.predict(item, with_dropout=with_dropout) for _ in range(n_times)]\n\nclass RecordOnCPU(Callback):\n    \"Store the `input` and `target` going through the model on the CPU.\"\n    def on_batch_begin(self, last_input,last_target,**kwargs):\n        self.input,self.target = to_cpu(last_input),to_cpu(last_target)\n\nclass LearnerCallback(Callback):\n    \"Base class for creating callbacks for a `Learner`.\"\n    def __init__(self, learn):\n        self._learn = weakref.ref(learn)\n        self.exclude,self.not_min = ['_learn'],[]\n        setattr(self.learn, self.cb_name, self)\n\n    def __getattr__(self,k): return getattr(self.learn, k)\n    def __setstate__(self,data:Any): self.__dict__.update(data)\n\n    @property\n    def learn(self) -> Learner: return self._learn()\n    @learn.setter\n    def learn(self, learn: Learner) -> None: self._learn = weakref.ref(learn)\n\n    @property\n    def cb_name(self): return camel2snake(self.__class__.__name__)\n\nclass Recorder(LearnerCallback):\n    \"A `LearnerCallback` that records epoch, loss, opt and metric data during training.\"\n    _order=-10\n    def __init__(self, learn:Learner, add_time:bool=True, silent:bool=False):\n        super().__init__(learn)\n        self.opt = self.learn.opt\n        self.train_dl = self.learn.data.train_dl\n        self.no_val,self.silent,self.add_time = False,silent,add_time\n\n    def on_train_begin(self, pbar:PBar, metrics_names:Collection[str], **kwargs:Any)->None:\n        \"Initialize recording status at beginning of training.\"\n        self.pbar = pbar\n        self.names = ['epoch', 'train_loss'] if self.no_val else ['epoch', 'train_loss', 'valid_loss']\n        self.metrics_names = metrics_names\n        if hasattr(self, '_added_met_names'): self.metrics_names += self._added_met_names\n     
   self.names += self.metrics_names\n        if self.add_time: self.names.append('time')\n        if not self.silent: self.pbar.write(self.names, table=True)\n        self.losses,self.val_losses,self.lrs,self.moms,self.metrics,self.nb_batches = [],[],[],[],[],[]\n\n    def on_epoch_begin(self, **kwargs:Any)->None:\n        if self.add_time: self.start_epoch = time()\n\n    def on_batch_begin(self, train, **kwargs:Any)->None:\n        \"Record learning rate and momentum at beginning of batch.\"\n        if train:\n            self.lrs.append(self.opt.lr)\n            self.moms.append(self.opt.mom)\n\n    def on_backward_begin(self, smooth_loss:Tensor, **kwargs:Any)->None:\n        \"Record the loss before any other callback has a chance to modify it.\"\n        self.losses.append(smooth_loss)\n        if self.pbar is not None and hasattr(self.pbar,'child'):\n            self.pbar.child.comment = f'{smooth_loss:.4f}'\n\n    def on_epoch_end(self, epoch:int, num_batch:int, smooth_loss:Tensor,\n                     last_metrics=MetricsList, **kwargs:Any)->bool:\n        \"Save epoch info: num_batch, smooth_loss, metrics.\"\n        self.nb_batches.append(num_batch)\n        if last_metrics is not None: self.val_losses.append(last_metrics[0])\n        else: last_metrics = [] if self.no_val else [None]\n        if len(last_metrics) > 1: self.metrics.append(last_metrics[1:])\n        self.format_stats([epoch, smooth_loss] + last_metrics)\n\n    def format_stats(self, stats:TensorOrNumList)->None:\n        \"Format stats before printing.\"\n        str_stats = []\n        for name,stat in zip(self.names,stats):\n            str_stats.append('#na#' if stat is None else str(stat) if isinstance(stat, int) else f'{stat:.6f}')\n        if self.add_time: str_stats.append(format_time(time() - self.start_epoch))\n        if not self.silent: self.pbar.write(str_stats, table=True)\n\n    def add_metric_names(self, names):\n        \"Add `names` to the inner metric names.\"\n        
if hasattr(self, '_added_met_names'): self._added_met_names += names\n        else:                                 self._added_met_names  = names\n\n    def plot_lr(self, show_moms=False, skip_start:int=0, skip_end:int=0, return_fig:bool=None)->Optional[plt.Figure]:\n        \"Plot learning rate, `show_moms` to include momentum.\"\n        lrs = self._split_list(self.lrs, skip_start, skip_end)\n        iterations = self._split_list(range_of(self.lrs), skip_start, skip_end)\n        if show_moms:\n            moms = self._split_list(self.moms, skip_start, skip_end)\n            fig, axs = plt.subplots(1,2, figsize=(12,4))\n            axs[0].plot(iterations, lrs)\n            axs[0].set_xlabel('Iterations')\n            axs[0].set_ylabel('Learning Rate')\n            axs[1].plot(iterations, moms)\n            axs[1].set_xlabel('Iterations')\n            axs[1].set_ylabel('Momentum')\n        else:\n            fig, ax = plt.subplots()\n            ax.plot(iterations, lrs)\n            ax.set_xlabel('Iterations')\n            ax.set_ylabel('Learning Rate')\n        if ifnone(return_fig, defaults.return_fig): return fig\n        if not IN_NOTEBOOK: plot_sixel(fig)\n\n    @staticmethod\n    def smoothen_by_spline(xs, ys, **kwargs):\n        xs = np.arange(len(ys))\n        spl = scipy.interpolate.UnivariateSpline(xs, ys, **kwargs)\n        ys = spl(xs)\n        return ys\n\n    def plot(self, skip_start:int=10, skip_end:int=5, suggestion:bool=False, return_fig:bool=None,\n             **kwargs)->Optional[plt.Figure]:\n        \"Plot learning rate and losses, trimmed between `skip_start` and `skip_end`. 
Optionally plot and return min gradient\"\n        lrs = self._split_list(self.lrs, skip_start, skip_end)\n        losses = self._split_list(self.losses, skip_start, skip_end)\n        losses = [x.item() for x in losses]\n        if 'k' in kwargs: losses = self.smoothen_by_spline(lrs, losses, **kwargs)\n        fig, ax = plt.subplots(1,1)\n        ax.plot(lrs, losses)\n        ax.set_ylabel(\"Loss\")\n        ax.set_xlabel(\"Learning Rate\")\n        ax.set_xscale('log')\n        ax.xaxis.set_major_formatter(plt.FormatStrFormatter('%.0e'))\n        if suggestion:\n            try: mg = (np.gradient(np.array(losses))).argmin()\n            except:\n                print(\"Failed to compute the gradients, there might not be enough points.\")\n                return\n            print(f\"Min numerical gradient: {lrs[mg]:.2E}\")\n            ax.plot(lrs[mg],losses[mg],markersize=10,marker='o',color='red')\n            self.min_grad_lr = lrs[mg]\n            ml = np.argmin(losses)\n            print(f\"Min loss divided by 10: {lrs[ml]/10:.2E}\")\n        if ifnone(return_fig, defaults.return_fig): return fig\n        if not IN_NOTEBOOK: plot_sixel(fig)\n\n    def plot_losses(self, skip_start:int=0, skip_end:int=0, return_fig:bool=None)->Optional[plt.Figure]:\n        \"Plot training and validation losses.\"\n        fig, ax = plt.subplots(1,1)\n        losses = self._split_list(self.losses, skip_start, skip_end)\n        iterations = self._split_list(range_of(self.losses), skip_start, skip_end)\n        ax.plot(iterations, losses, label='Train')\n        val_iter = self._split_list_val(np.cumsum(self.nb_batches), skip_start, skip_end)\n        val_losses = self._split_list_val(self.val_losses, skip_start, skip_end)\n        ax.plot(val_iter, val_losses, label='Validation')\n        ax.set_ylabel('Loss')\n        ax.set_xlabel('Batches processed')\n        ax.legend()\n        if ifnone(return_fig, defaults.return_fig): return fig\n        if not IN_NOTEBOOK: 
plot_sixel(fig)\n\n    def plot_metrics(self, skip_start:int=0, skip_end:int=0, return_fig:bool=None)->Optional[plt.Figure]:\n        \"Plot metrics collected during training.\"\n        assert len(self.metrics) != 0, \"There are no metrics to plot.\"\n        fig, axes = plt.subplots(len(self.metrics[0]),1,figsize=(6, 4*len(self.metrics[0])))\n        val_iter = self._split_list_val(np.cumsum(self.nb_batches), skip_start, skip_end)\n        axes = axes.flatten() if len(self.metrics[0]) != 1 else [axes]\n        for i, ax in enumerate(axes):\n            values = [met[i] for met in self.metrics]\n            values = self._split_list_val(values, skip_start, skip_end)\n            ax.plot(val_iter, values)\n            ax.set_ylabel(str(self.metrics_names[i]))\n            ax.set_xlabel('Batches processed')             \n        if ifnone(return_fig, defaults.return_fig): return fig\n        if not IN_NOTEBOOK: plot_sixel(fig)\n\n    def _split_list(self, vals:Collection[float], skip_start:int, skip_end:int):\n        return vals[skip_start:-skip_end] if skip_end > 0 else vals[skip_start:]\n\n    def _split_list_val(self, vals:Collection[float], skip_start:int, skip_end:int):\n        val_iter = np.cumsum(self.nb_batches)\n        start_val = (val_iter - skip_start >= 0).nonzero()[0].min()\n        end_val = (val_iter[-1] - val_iter - skip_end >= 0).nonzero()[0].max()+1\n        return vals[start_val:end_val] if skip_end > 0 else vals[start_val:]\n\nclass FakeOptimizer():\n    def step(self): pass\n    def zero_grad(self): pass\n\ndef load_callback(class_func, state, learn:Learner):\n    init_kwargs, others = split_kwargs_by_func(state, class_func.__init__)\n    res = class_func(learn, **init_kwargs) if issubclass(class_func, LearnerCallback) else class_func(**init_kwargs)\n    for k,v in others.items(): setattr(res, k, v)\n    return res\n\ndef load_learner(path:PathOrStr, file:PathLikeOrBinaryStream='export.pkl', test:ItemList=None, **db_kwargs):\n    \"Load a 
`Learner` object saved with `export_state` in `path/file` with empty data, optionally add `test` and load on `cpu`. `file` can be file-like (file or buffer)\"\n    source = Path(path)/file if is_pathlike(file) else file\n    state = torch.load(source, map_location='cpu') if defaults.device == torch.device('cpu') else torch.load(source)\n    model = state.pop('model')\n    src = LabelLists.load_state(path, state.pop('data'))\n    if test is not None: src.add_test(test)\n    data = src.databunch(**db_kwargs)\n    cb_state = state.pop('cb_state')\n    clas_func = state.pop('cls')\n    res = clas_func(data, model, **state)\n    res.callback_fns = state['callback_fns'] #to avoid duplicates\n    res.callbacks = [load_callback(c,s, res) for c,s in cb_state.items()]\n    return res\n"
  },
  {
    "path": "fastai/basics.py",
    "content": "from .basic_train import *\nfrom .callback import *\nfrom .core import *\nfrom .basic_data import *\nfrom .data_block import *\nfrom .layers import *\nfrom .metrics import *\nfrom .torch_core import *\nfrom .train import *\nfrom .datasets import *\nfrom .version import *\nfrom . import callbacks\n\n\"\"\"\nfrom . import core,torch_core,basic_data,basic_train,callback,data_block,layers,metrics,train,datasets,callbacks\n\n__all__  = [o for o in dir(core) if not o.startswith('_')]\n__all__ += [o for o in dir(torch_core) if not o.startswith('_')]\n__all__ += [*basic_train.__all__, *callback.__all__, 'core', 'torch_core', 'callbacks',\n           *basic_data.__all__, *data_block.__all__, *layers.__all__, *metrics.__all__,\n           *train.__all__, *datasets.__all__, '__version__']\n\"\"\"\n\ntry: from .gen_doc.nbdoc import doc\nexcept: pass  # Optional if jupyter is present\n    #__all__.append('doc')\n\n__all__ = [o for o in dir(sys.modules[__name__]) if not o.startswith('_')] + ['__version__']\n\n"
  },
  {
    "path": "fastai/callback.py",
    "content": "\"Callbacks provides extensibility to the `basic_train` loop. See `train` for examples of custom callbacks.\"\nfrom .basic_data import *\nfrom .torch_core import *\nimport torch.distributed as dist\n\n__all__ = ['AverageMetric', 'Callback', 'CallbackHandler', 'OptimWrapper', 'SmoothenValue', 'Scheduler', 'annealing_cos', 'CallbackList',\n           'annealing_exp', 'annealing_linear', 'annealing_no', 'annealing_poly']\n\nclass OptimWrapper():\n    \"Basic wrapper around `opt` to simplify hyper-parameters changes.\"\n    def __init__(self, opt:optim.Optimizer, wd:Floats=0., true_wd:bool=False, bn_wd:bool=True):\n        assert not isinstance(opt, OptimWrapper)\n        self.opt,self.true_wd,self.bn_wd = opt,true_wd,bn_wd\n        self.opt_keys = list(self.opt.param_groups[0].keys())\n        self.opt_keys.remove('params')\n        self.read_defaults()\n        self.wd = wd\n\n    @classmethod\n    def create(cls, opt_func:Union[type,Callable], lr:Union[float,Tuple,List], layer_groups:ModuleList, wd:Floats=0., \n               true_wd:bool=False, bn_wd:bool=True)->optim.Optimizer:\n        \"Create an `optim.Optimizer` from `opt_func` with `lr`. 
Set lr on `layer_groups`.\"\n        split_params = split_no_wd_params(layer_groups)\n        opt = opt_func([{'params': p, 'lr':0} for p in split_params])\n        opt = cls(opt, wd=wd, true_wd=true_wd, bn_wd=bn_wd)\n        opt.lr,opt.opt_func = listify(lr, layer_groups),opt_func\n        return opt\n\n    def new(self, layer_groups:Collection[nn.Module], split_no_wd:bool=True):\n        \"Create a new `OptimWrapper` from `self` with another `layer_groups` but the same hyper-parameters.\"\n        opt_func = getattr(self, 'opt_func', self.opt.__class__)\n        res = self.create(opt_func, self.lr, layer_groups, wd=self.wd, true_wd=self.true_wd, bn_wd=self.bn_wd)\n        res.mom,res.beta = self.mom,self.beta\n        return res\n\n    def new_with_params(self, param_groups:Collection[Collection[nn.Parameter]]):\n        \"Create a new `OptimWrapper` from `self` with another `layer_groups` but the same hyper-parameters.\"\n        opt_func = getattr(self, 'opt_func', self.opt.__class__)\n        opt = opt_func([{'params': p, 'lr':0} for p in param_groups])\n        opt = self.__class__(opt, wd=self.wd, true_wd=self.true_wd, bn_wd=self.bn_wd)\n        opt.lr,opt.opt_func,opt.mom,opt.beta = self.lr,opt_func,self.mom,self.beta\n        return opt\n\n    def __repr__(self)->str:\n        return f'OptimWrapper over {repr(self.opt)}.\\nTrue weight decay: {self.true_wd}'\n\n    #Pytorch optimizer methods\n    def step(self)->None:\n        \"Set weight decay and step optimizer.\"\n        # weight decay outside of optimizer step (AdamW)\n        if self.true_wd:\n            for lr,wd,pg1,pg2 in zip(self._lr,self._wd,self.opt.param_groups[::2],self.opt.param_groups[1::2]):\n                for p in pg1['params']: p.data.mul_(1 - wd*lr)\n                if self.bn_wd:\n                    for p in pg2['params']: p.data.mul_(1 - wd*lr)\n            self.set_val('weight_decay', listify(0, self._wd))\n        self.opt.step()\n\n    def zero_grad(self)->None:\n        
\"Clear optimizer gradients.\"\n        self.opt.zero_grad()\n\n    #Passthrough to the inner opt.\n    def __getattr__(self, k:str)->Any: return getattr(self.opt, k, None)\n    def __setstate__(self,data:Any): self.__dict__.update(data)\n\n    def clear(self):\n        \"Reset the state of the inner optimizer.\"\n        sd = self.state_dict()\n        sd['state'] = {}\n        self.load_state_dict(sd)\n\n    @property\n    def n_params(self): return sum([len(pg['params']) for pg in self.opt.param_groups])\n\n    #Hyperparameters as properties\n    @property\n    def lr(self)->float: return self._lr[-1]\n    @lr.setter\n    def lr(self, val:float)->None:\n        self._lr = self.set_val('lr', listify(val, self._lr))\n\n    @property\n    def mom(self)->float:return self._mom[-1]\n    @mom.setter\n    def mom(self, val:float)->None:\n        if 'momentum' in self.opt_keys: self.set_val('momentum', listify(val, self._mom))\n        elif 'betas' in self.opt_keys:  self.set_val('betas', (listify(val, self._mom), self._beta))\n        self._mom = listify(val, self._mom)\n\n    @property\n    def beta(self)->float: return None if self._beta is None else self._beta[-1]\n    @beta.setter\n    def beta(self, val:float)->None:\n        \"Set beta (or alpha as makes sense for given optimizer).\"\n        if val is None: return\n        if 'betas' in self.opt_keys:    self.set_val('betas', (self._mom, listify(val, self._beta)))\n        elif 'alpha' in self.opt_keys:  self.set_val('alpha', listify(val, self._beta))\n        self._beta = listify(val, self._beta)\n\n    @property\n    def wd(self)->float: return self._wd[-1]\n    @wd.setter\n    def wd(self, val:float)->None:\n        \"Set weight decay.\"\n        if not self.true_wd: self.set_val('weight_decay', listify(val, self._wd), bn_groups=self.bn_wd)\n        self._wd = listify(val, self._wd)\n\n    #Helper functions\n    def read_defaults(self)->None:\n        \"Read the values inside the optimizer for the 
hyper-parameters.\"\n        self._beta = None\n        if 'lr' in self.opt_keys: self._lr = self.read_val('lr')\n        if 'momentum' in self.opt_keys: self._mom = self.read_val('momentum')\n        if 'alpha' in self.opt_keys: self._beta = self.read_val('alpha')\n        if 'betas' in self.opt_keys: self._mom,self._beta = self.read_val('betas')\n        if 'weight_decay' in self.opt_keys: self._wd = self.read_val('weight_decay')\n        reserved_names = ['params', 'lr', 'momentum', 'alpha', 'betas', 'weight_decay']\n        stat_names = [n for n in self.opt_keys if n not in reserved_names]\n        self._stats = {n:self.read_val(n) for n in stat_names}\n\n    def get_stat(self, name:str)->float: \n        if name in ['lr', 'mom', 'beta', 'wd']: return getattr(self, name)\n        else: return self._stats[name][-1]\n    def set_stat(self, name:str, value:Union[float, Collection[float]])->None:\n        if name in ['lr', 'mom', 'beta', 'wd']: setattr(self, name, value)\n        else:\n            val = listify(value, self._stats[name])\n            self.set_val(name, val)\n            self._stats[name] = val\n\n    def set_val(self, key:str, val:Any, bn_groups:bool=True)->Any:\n        \"Set `val` inside the optimizer dictionary at `key`.\"\n        if is_tuple(val): val = [(v1,v2) for v1,v2 in zip(*val)]\n        for v,pg1,pg2 in zip(val,self.opt.param_groups[::2],self.opt.param_groups[1::2]):\n            pg1[key] = v\n            if bn_groups: pg2[key] = v\n        return val\n\n    def read_val(self, key:str) -> Union[List[float],Tuple[List[float],List[float]]]:\n        \"Read a hyperparameter `key` in the optimizer dictionary.\"\n        val = [pg[key] for pg in self.opt.param_groups[::2]]\n        if is_tuple(val[0]): val = [o[0] for o in val], [o[1] for o in val]\n        return val\n    \n    def get_state(self):\n        \"Return the inner state minus the layer groups.\"\n        return {'opt_state':self.opt.state_dict(), 'lr':self._lr, 'wd':self._wd, 
'beta':self._beta, 'mom':self._mom,\n                'opt_func':self.opt_func, 'true_wd':self.true_wd, 'bn_wd':self.bn_wd}\n\n    @classmethod\n    def load_with_state_and_layer_group(cls, state:dict, layer_groups:Collection[nn.Module]):\n        res = cls.create(state['opt_func'], state['lr'], layer_groups, wd=state['wd'], true_wd=state['true_wd'], \n                     bn_wd=state['bn_wd'])\n        res._mom,res._beta = state['mom'],state['beta']\n        res.load_state_dict(state['opt_state'])\n        return res\n\nclass Callback():\n    \"Base class for callbacks that want to record values, dynamically change learner params, etc.\"\n    _order=0\n    def on_train_begin(self, **kwargs:Any)->None:\n        \"To initialize constants in the callback.\"\n        pass\n    def on_epoch_begin(self, **kwargs:Any)->None:\n        \"At the beginning of each epoch.\"\n        pass\n    def on_batch_begin(self, **kwargs:Any)->None:\n        \"Set HP before the output and loss are computed.\"\n        pass\n    def on_loss_begin(self, **kwargs:Any)->None:\n        \"Called after forward pass but before loss has been computed.\"\n        pass\n    def on_backward_begin(self, **kwargs:Any)->None:\n        \"Called after the forward pass and the loss has been computed, but before backprop.\"\n        pass\n    def on_backward_end(self, **kwargs:Any)->None:\n        \"Called after backprop but before optimizer step. 
Useful for true weight decay in AdamW.\"\n        pass\n    def on_step_end(self, **kwargs:Any)->None:\n        \"Called after the step of the optimizer but before the gradients are zeroed.\"\n        pass\n    def on_batch_end(self, **kwargs:Any)->None:\n        \"Called at the end of the batch.\"\n        pass\n    def on_epoch_end(self, **kwargs:Any)->None:\n        \"Called at the end of an epoch.\"\n        pass\n    def on_train_end(self, **kwargs:Any)->None:\n        \"Useful for cleaning up things and saving files/models.\"\n        pass\n    def jump_to_epoch(self, epoch)->None:\n        \"To resume training at `epoch` directly.\"\n        pass\n\n    def get_state(self, minimal:bool=True):\n        \"Return the inner state of the `Callback`, `minimal` or not.\"\n        to_remove = ['exclude', 'not_min'] + getattr(self, 'exclude', []).copy()\n        if minimal: to_remove += getattr(self, 'not_min', []).copy()\n        return {k:v for k,v in self.__dict__.items() if k not in to_remove}\n\n    def  __repr__(self):\n        attrs = func_args(self.__init__)\n        to_remove = getattr(self, 'exclude', [])\n        list_repr = [self.__class__.__name__] + [f'{k}: {getattr(self, k)}' for k in attrs if k != 'self' and k not in to_remove]\n        return '\\n'.join(list_repr)\n\nclass SmoothenValue():\n    \"Create a smooth moving average for a value (loss, etc) using `beta`.\"\n    def __init__(self, beta:float):\n        self.beta,self.n,self.mov_avg = beta,0,0\n\n    def add_value(self, val:float)->None:\n        \"Add `val` to calculate updated smoothed value.\"\n        self.n += 1\n        self.mov_avg = self.beta * self.mov_avg + (1 - self.beta) * val\n        self.smooth = self.mov_avg / (1 - self.beta ** self.n)\n\nCallbackList = Collection[Callback]\n\ndef _get_init_state(): return {'epoch':0, 'iteration':0, 'num_batch':0, 'skip_validate': False}\n\n@dataclass\nclass CallbackHandler():\n    \"Manage all of the registered `callbacks` and `metrics`, 
smoothing loss by momentum `beta`.\"\n    callbacks:CallbackList=None\n    metrics:CallbackList=None\n    beta:float=0.98\n\n    def __post_init__(self)->None:\n        \"Initialize smoother and learning stats.\"\n        self.callbacks = ifnone(self.callbacks, [])\n        self.metrics = ifnone(self.metrics, [])\n        self.metrics = [(met if isinstance(met, Callback) else AverageMetric(met)) for met in self.metrics]\n        self.callbacks = sorted(self.callbacks, key=lambda o: getattr(o, '_order', 0))\n        self.smoothener = SmoothenValue(self.beta)\n        self.state_dict:Dict[str,Union[int,float,Tensor]]=_get_init_state()\n\n    def _call_and_update(self, cb, cb_name, **kwargs)->None:\n        \"Call `cb_name` on `cb` and update the inner state.\"\n        new = ifnone(getattr(cb, f'on_{cb_name}')(**self.state_dict, **kwargs), dict())\n        for k,v in new.items():\n            if k not in self.state_dict:\n                raise Exception(f\"{k} isn't a valid key in the state of the callbacks.\")\n            else: self.state_dict[k] = v\n    \n    def __call__(self, cb_name, call_mets=True, **kwargs)->None:\n        \"Call through to all of the `CallbakHandler` functions.\"\n        if call_mets: \n            for met in self.metrics: self._call_and_update(met, cb_name, **kwargs)\n        for cb in self.callbacks: self._call_and_update(cb, cb_name, **kwargs)\n\n    def set_dl(self, dl:DataLoader):\n        \"Set the current `dl` used.\"\n        if hasattr(self, 'cb_dl'): self.callbacks.remove(self.cb_dl)\n        if isinstance(dl.dataset, Callback):\n            self.callbacks.append(dl.dataset)\n            self.cb_dl = dl.dataset\n\n    def on_train_begin(self, epochs:int, pbar:PBar, metrics:MetricFuncList)->None:\n        \"About to start learning.\"\n        self.state_dict = _get_init_state()\n        self.state_dict.update(dict(n_epochs=epochs, pbar=pbar, metrics=metrics))\n        names = [(met.name if hasattr(met, 'name') else 
camel2snake(met.__class__.__name__)) for met in self.metrics]\n        self('train_begin', metrics_names=names)\n        if self.state_dict['epoch'] != 0:\n            self.state_dict['pbar'].first_bar.total -= self.state_dict['epoch']\n            for cb in self.callbacks: cb.jump_to_epoch(self.state_dict['epoch'])\n\n    def on_epoch_begin(self)->None:\n        \"Handle new epoch.\"\n        self.state_dict['num_batch'],self.state_dict['stop_training'] = 0,False\n        self('epoch_begin')\n\n    def on_batch_begin(self, xb:Tensor, yb:Tensor, train:bool=True)->Tuple[Any,Any]:\n        \"Handle new batch `xb`,`yb` in `train` or validation.\"\n        self.state_dict.update(dict(last_input=xb, last_target=yb, train=train, \n            stop_epoch=False, skip_step=False, skip_zero=False, skip_bwd=False))\n        self('batch_begin', mets = not self.state_dict['train'])\n        return self.state_dict['last_input'], self.state_dict['last_target']\n\n    def on_loss_begin(self, out:Tensor)->Any:\n        \"Handle start of loss calculation with model output `out`.\"\n        self.state_dict['last_output'] = out\n        self('loss_begin', call_mets=False)\n        return self.state_dict['last_output']\n\n    def on_backward_begin(self, loss:Tensor)->Tuple[Any,Any]:\n        \"Handle gradient calculation on `loss`.\"\n        self.smoothener.add_value(loss.detach().cpu())\n        self.state_dict['last_loss'], self.state_dict['smooth_loss'] = loss, self.smoothener.smooth\n        self('backward_begin', call_mets=False)\n        return self.state_dict['last_loss'], self.state_dict['skip_bwd']\n\n    def on_backward_end(self)->Any:\n        \"Handle end of gradient calculation.\"\n        self('backward_end', call_mets=False)\n        return self.state_dict['skip_step']\n\n    def on_step_end(self)->Any:\n        \"Handle end of optimization step.\"\n        self('step_end', call_mets=False)\n        return self.state_dict['skip_zero']\n\n    def on_batch_end(self, 
loss:Tensor)->Any:\n        \"Handle end of processing one batch with `loss`.\"\n        self.state_dict['last_loss'] = loss\n        self('batch_end', call_mets = not self.state_dict['train'])\n        if self.state_dict['train']:\n            self.state_dict['iteration'] += 1\n            self.state_dict['num_batch'] += 1\n        return self.state_dict['stop_epoch']\n\n    def on_epoch_end(self, val_loss:Tensor)->bool:\n        \"Epoch is done, process `val_loss`.\"\n        self.state_dict['last_metrics'] = [val_loss] if val_loss is not None else [None]\n        self('epoch_end', call_mets = val_loss is not None)\n        self.state_dict['epoch'] += 1\n        return self.state_dict['stop_training']\n\n    def on_train_end(self, exception:Union[bool,Exception])->None:\n        \"Handle end of training, `exception` is an `Exception` or False if no exceptions during training.\"\n        self('train_end', exception=exception)\n        \n    @property\n    def skip_validate(self): return self.state_dict['skip_validate']\n\nclass AverageMetric(Callback):\n    \"Wrap a `func` in a callback for metrics computation.\"\n    def __init__(self, func):\n        # If func has a __name__ use this one else it should be a partial\n        name = func.__name__ if hasattr(func, '__name__') else func.func.__name__\n        self.func, self.name = func, name\n        self.world = num_distrib()\n\n    def on_epoch_begin(self, **kwargs):\n        \"Set the inner value to 0.\"\n        self.val, self.count = 0.,0\n\n    def on_batch_end(self, last_output, last_target, **kwargs):\n        \"Update metric computation with `last_output` and `last_target`.\"\n        if not is_listy(last_target): last_target=[last_target]\n        self.count += first_el(last_target).size(0)\n        val = self.func(last_output, *last_target)\n        if self.world:\n            val = val.clone()\n            dist.all_reduce(val, op=dist.ReduceOp.SUM)\n            val /= self.world\n        self.val += 
first_el(last_target).size(0) * val.detach().cpu()\n\n    def on_epoch_end(self, last_metrics, **kwargs):\n        \"Set the final result in `last_metrics`.\"\n        return add_metrics(last_metrics, self.val/self.count)\n\ndef annealing_no(start:Number, end:Number, pct:float)->Number:\n    \"No annealing, always return `start`.\"\n    return start\ndef annealing_linear(start:Number, end:Number, pct:float)->Number:\n    \"Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0.\"\n    return start + pct * (end-start)\ndef annealing_exp(start:Number, end:Number, pct:float)->Number:\n    \"Exponentially anneal from `start` to `end` as pct goes from 0.0 to 1.0.\"\n    return start * (end/start) ** pct\ndef annealing_cos(start:Number, end:Number, pct:float)->Number:\n    \"Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0.\"\n    cos_out = np.cos(np.pi * pct) + 1\n    return end + (start-end)/2 * cos_out\n\ndef do_annealing_poly(start:Number, end:Number, pct:float, degree:Number)->Number:\n    \"Helper function for `anneal_poly`.\"\n    return end + (start-end) * (1-pct)**degree\ndef annealing_poly(degree:Number)->Number:\n    \"Anneal polynomically from `start` to `end` as pct goes from 0.0 to 1.0.\"\n    return functools.partial(do_annealing_poly, degree=degree)\n\nclass Scheduler():\n    \"Used to \\\"step\\\" from start,end (`vals`) over `n_iter` iterations on a schedule defined by `func`\"\n    def __init__(self, vals:StartOptEnd, n_iter:int, func:Optional[AnnealFunc]=None):\n        self.start,self.end = (vals[0],vals[1]) if is_tuple(vals) else (vals,0)\n        self.n_iter = max(1,n_iter)\n        if func is None: self.func = annealing_linear if is_tuple(vals) else annealing_no\n        else:          self.func = func\n        self.n = 0\n        \n    def restart(self): self.n = 0\n\n    def step(self)->Number:\n        \"Return next value along annealed schedule.\"\n        self.n += 1\n        return self.func(self.start, self.end, 
self.n/self.n_iter)\n\n    @property\n    def is_done(self)->bool:\n        \"Return `True` if schedule completed.\"\n        return self.n >= self.n_iter\n\n"
  },
  {
    "path": "fastai/callbacks/__init__.py",
    "content": "from .lr_finder import *\nfrom .one_cycle import *\nfrom .fp16 import *\nfrom .general_sched import *\nfrom .hooks import *\nfrom .mixup import *\nfrom .rnn import *\nfrom .tracker import *\nfrom .csv_logger import *\nfrom .loss_metrics import *\nfrom .oversampling import *\n"
  },
  {
    "path": "fastai/callbacks/csv_logger.py",
    "content": "\"A `Callback` that saves tracked metrics into a persistent file.\"\n#Contribution from devforfu: https://nbviewer.jupyter.org/gist/devforfu/ea0b3fcfe194dad323c3762492b05cae\nfrom ..torch_core import *\nfrom ..basic_data import DataBunch\nfrom ..callback import *\nfrom ..basic_train import Learner, LearnerCallback\nfrom time import time\nfrom fastprogress.fastprogress import format_time\n\n__all__ = ['CSVLogger']\n\nclass CSVLogger(LearnerCallback):\n    \"A `LearnerCallback` that saves history of metrics while training `learn` into CSV `filename`.\"\n    def __init__(self, learn:Learner, filename: str = 'history', append: bool = False): \n        super().__init__(learn)\n        self.filename,self.path,self.append = filename,self.learn.path/f'{filename}.csv',append\n        self.add_time = True\n\n    def read_logged_file(self):  \n        \"Read the content of saved file\"\n        return pd.read_csv(self.path)\n\n    def on_train_begin(self, **kwargs: Any) -> None:\n        \"Prepare file with metric names.\"\n        self.path.parent.mkdir(parents=True, exist_ok=True)      \n        self.file = self.path.open('a') if self.append else self.path.open('w')\n        self.file.write(','.join(self.learn.recorder.names[:(None if self.add_time else -1)]) + '\\n')\n    \n    def on_epoch_begin(self, **kwargs:Any)->None:\n        if self.add_time: self.start_epoch = time()\n        \n    def on_epoch_end(self, epoch: int, smooth_loss: Tensor, last_metrics: MetricsList, **kwargs: Any) -> bool:\n        \"Add a line with `epoch` number, `smooth_loss` and `last_metrics`.\"\n        last_metrics = ifnone(last_metrics, [])\n        stats = [str(stat) if isinstance(stat, int) else '#na#' if stat is None else f'{stat:.6f}'\n                 for name, stat in zip(self.learn.recorder.names, [epoch, smooth_loss] + last_metrics)]\n        if self.add_time: stats.append(format_time(time() - self.start_epoch))\n        str_stats = ','.join(stats)\n        
self.file.write(str_stats + '\\n')\n\n    def on_train_end(self, **kwargs: Any) -> None:  \n        \"Close the file.\"\n        self.file.close()\n"
  },
  {
    "path": "fastai/callbacks/fp16.py",
    "content": "\"Callback support for half precision (fp16) training. Increases training speed.\"\nfrom ..torch_core import *\nfrom ..callback import *\nfrom ..basic_train import *\nfrom torch._utils import _unflatten_dense_tensors\nfrom torch.nn.utils import parameters_to_vector\n\n__all__ = ['MixedPrecision']\n\ndef get_master(layer_groups:ModuleList, flat_master:bool=False) -> Tuple[List[List[Tensor]], List[List[Tensor]]]:\n    \"Return two lists, one for the model parameters in FP16 and one for the master parameters in FP32.\"\n    split_params = split_no_wd_params(layer_groups)\n    model_params = [[param for param in pg if param.requires_grad] for pg in split_params]\n    if flat_master:\n        master_params = []\n        for lg in model_params:\n            if len(lg) !=0 :\n                mp = parameters_to_vector([param.data.float() for param in lg])\n                mp = torch.nn.Parameter(mp, requires_grad=True)\n                if mp.grad is None: mp.grad = mp.new(*mp.size())\n                master_params.append([mp])\n            else: master_params.append([])\n        return model_params, master_params\n    else:\n        master_params = [[param.clone().float().detach() for param in lg] for lg in model_params]\n        for mp in master_params:\n            for param in mp: param.requires_grad = True\n        return model_params, master_params\n\ndef model_g2master_g(model_params:Sequence[Tensor], master_params:Sequence[Tensor], flat_master:bool=False)->None:\n    \"Copy the `model_params` gradients to `master_params` for the optimizer step.\"\n    if flat_master:\n        for model_group,master_group in zip(model_params,master_params):\n            if len(master_group) != 0:\n                if master_group[0].grad is None: master_group[0].grad = master_group[0].data.new(*master_group[0].data.size())\n                master_group[0].grad.data.copy_(parameters_to_vector([p.grad.data.float() for p in model_group]))\n    else:\n        for 
model_group,master_group in zip(model_params,master_params):\n            for model, master in zip(model_group, master_group):\n                if model.grad is not None:\n                    if master.grad is None: master.grad = master.data.new(*master.data.size())\n                    master.grad.data.copy_(model.grad.data)\n                else: master.grad = None\n\ndef master2model(model_params:Sequence[Tensor], master_params:Sequence[Tensor], flat_master:bool=False)->None:\n    \"Copy `master_params` to `model_params`.\"\n    if flat_master:\n        for model_group,master_group in zip(model_params,master_params):\n            if len(model_group) != 0:\n                for model, master in zip(model_group, _unflatten_dense_tensors(master_group[0].data, model_group)):\n                    model.data.copy_(master)\n    else:\n        for model_group,master_group in zip(model_params,master_params):\n            for model, master in zip(model_group, master_group): model.data.copy_(master.data)\n\ndef grad_overflow(param_group):\n    for group in param_group:\n        for p in group:\n            if p.grad is not None:\n                s = float(p.grad.data.float().sum())\n                if s == float('inf') or s == float('-inf') or s != s: return True\n    return False\n\nclass MixedPrecision(LearnerCallback):\n    _order = 999 #Need to run after things that could call on_backward_begin and change the loss\n    \"Callback that handles mixed-precision training.\"\n    def __init__(self, learn:Learner, loss_scale:float=None, max_noskip:int=1000, dynamic:bool=True, clip:float=None,\n                 flat_master:bool=False, max_scale:float=2**24):\n        super().__init__(learn)\n        self.flat_master,self.dynamic,self.max_noskip,self.clip,self.max_scale = flat_master,dynamic,max_noskip,clip,max_scale\n        self.loss_scale = ifnone(loss_scale, 2**16 if dynamic else 512)\n        self.not_min += ['model_params', 'master_params']\n        assert 
torch.backends.cudnn.enabled, \"Mixed precision training requires cudnn.\"\n        self.opt = None\n\n    def on_train_begin(self, **kwargs:Any)->None:\n        \"Prepare the master model.\"\n        #Get a copy of the model params in FP32\n        self.model_params, self.master_params = get_master(self.learn.layer_groups, self.flat_master)\n        #Changes the optimizer so that the optimization step is done in FP32.\n        new_opt = self.learn.opt.new_with_params(self.master_params)\n        if self.opt is not None:\n            self.opt.lr,self.opt.wd = self.learn.opt.lr,self.learn.opt.wd\n            new_opt.load_state_dict(self.opt)\n        self.learn.opt.opt = new_opt.opt\n        self.noskip = 0\n\n    def on_loss_begin(self, last_output:Tensor, **kwargs:Any) -> Tensor:\n        \"Convert half precision output to FP32 to avoid reduction overflow.\"\n        return {'last_output': to_float(last_output)}\n\n    def on_backward_begin(self, last_loss:Rank0Tensor, **kwargs:Any) -> Rank0Tensor:\n        \"Scale gradients up by `self.loss_scale` to prevent underflow.\"\n        #To avoid gradient underflow, we scale the gradients\n        ret_loss = last_loss * self.loss_scale\n        return {'last_loss': ret_loss}\n\n    def on_backward_end(self, **kwargs:Any)->None:\n        \"Convert the gradients back to FP32 and divide them by the scale.\"\n        if self.dynamic and grad_overflow(self.model_params) and self.loss_scale > 1:\n            self.loss_scale /= 2\n            self.noskip = 0\n            #The step will be skipped since we don't update the master grads so they are all None or zero\n        else:\n            model_g2master_g(self.model_params, self.master_params, self.flat_master)\n            for group in self.master_params:\n                for param in group:\n                    if param.grad is not None: param.grad.div_(self.loss_scale)\n            if self.clip is not None:\n                for group in self.master_params: 
nn.utils.clip_grad_norm_(group, self.clip)\n            if not self.dynamic: return\n            self.noskip += 1\n            if self.noskip >= self.max_noskip and self.loss_scale < self.max_scale:\n                self.loss_scale *= 2\n                self.noskip = 0\n\n    def on_step_end(self, **kwargs:Any)->None:\n        \"Update the params from master to model and zero grad.\"\n        #Zeros the gradients of the model since the optimizer is disconnected.\n        self.learn.model.zero_grad()\n        #Update the params from master to model.\n        master2model(self.model_params, self.master_params, self.flat_master)\n"
  },
  {
    "path": "fastai/callbacks/general_sched.py",
    "content": "from ..core import *\nfrom ..callback import *\nfrom ..basic_train import Learner, LearnerCallback\n\n__all__ = ['GeneralScheduler', 'TrainingPhase']\n\n@dataclass\nclass TrainingPhase():\n    \"Schedule hyper-parameters for a phase of `length` iterations.\"\n    length:int\n    \n    def __post_init__(self): self.scheds = dict()\n    def schedule_hp(self, name, vals, anneal=None):\n        \"Adds a schedule for `name` between `vals` using `anneal`.\"\n        self.scheds[name] = Scheduler(vals, self.length, anneal)\n        return self\n\nclass GeneralScheduler(LearnerCallback):\n    \"Schedule multiple `TrainingPhase` for a `Learner`.\"\n    def __init__(self, learn:Learner, phases:Collection[TrainingPhase], start_epoch:int=None):\n        super().__init__(learn)\n        self.phases,self.start_epoch = phases,start_epoch\n\n    def on_train_begin(self, epoch:int, **kwargs:Any)->None:\n        \"Initialize the schedulers for training.\"\n        res = {'epoch':self.start_epoch} if self.start_epoch is not None else None\n        self.start_epoch = ifnone(self.start_epoch, epoch)\n        self.scheds = [p.scheds for p in self.phases]\n        self.opt = self.learn.opt\n        for k,v in self.scheds[0].items(): \n            v.restart()\n            self.opt.set_stat(k, v.start)\n        self.idx_s = 0\n        return res\n    \n    def jump_to_epoch(self, epoch:int)->None:\n        for _ in range(len(self.learn.data.train_dl) * epoch):\n            self.on_batch_end(True)\n\n    def on_batch_end(self, train, **kwargs:Any)->None:\n        \"Take a step in lr,mom sched, start next stepper when the current one is complete.\"\n        if train:\n            if self.idx_s >= len(self.scheds): return {'stop_training': True, 'stop_epoch': True}\n            sched = self.scheds[self.idx_s]\n            for k,v in sched.items(): self.opt.set_stat(k, v.step())\n            if list(sched.values())[0].is_done: self.idx_s += 1"
  },
  {
    "path": "fastai/callbacks/hooks.py",
    "content": "\"Hooks provide extensibility at the model level.\"\nfrom ..torch_core import *\nfrom ..callback import *\nfrom ..basic_train import *\nfrom ..basic_data import *\n\n__all__ = ['ActivationStats', 'Hook', 'HookCallback', 'Hooks', 'hook_output', 'hook_outputs',\n           'model_sizes', 'num_features_model', 'model_summary', 'dummy_eval', 'dummy_batch']\n\nclass Hook():\n    \"Create a hook on `m` with `hook_func`.\"\n    def __init__(self, m:nn.Module, hook_func:HookFunc, is_forward:bool=True, detach:bool=True):\n        self.hook_func,self.detach,self.stored = hook_func,detach,None\n        f = m.register_forward_hook if is_forward else m.register_backward_hook\n        self.hook = f(self.hook_fn)\n        self.removed = False\n\n    def hook_fn(self, module:nn.Module, input:Tensors, output:Tensors):\n        \"Applies `hook_func` to `module`, `input`, `output`.\"\n        if self.detach:\n            input  = (o.detach() for o in input ) if is_listy(input ) else input.detach()\n            output = (o.detach() for o in output) if is_listy(output) else output.detach()\n        self.stored = self.hook_func(module, input, output)\n\n    def remove(self):\n        \"Remove the hook from the model.\"\n        if not self.removed:\n            self.hook.remove()\n            self.removed=True\n\n    def __enter__(self, *args): return self\n    def __exit__(self, *args): self.remove()\n\nclass Hooks():\n    \"Create several hooks on the modules in `ms` with `hook_func`.\"\n    def __init__(self, ms:Collection[nn.Module], hook_func:HookFunc, is_forward:bool=True, detach:bool=True):\n        self.hooks = [Hook(m, hook_func, is_forward, detach) for m in ms]\n\n    def __getitem__(self,i:int)->Hook: return self.hooks[i]\n    def __len__(self)->int: return len(self.hooks)\n    def __iter__(self): return iter(self.hooks)\n    @property\n    def stored(self): return [o.stored for o in self]\n\n    def remove(self):\n        \"Remove the hooks from the 
model.\"\n        for h in self.hooks: h.remove()\n\n    def __enter__(self, *args): return self\n    def __exit__ (self, *args): self.remove()\n\ndef _hook_inner(m,i,o): return o if isinstance(o,Tensor) else o if is_listy(o) else list(o)\n\ndef hook_output (module:nn.Module, detach:bool=True, grad:bool=False)->Hook:\n    \"Return a `Hook` that stores activations of `module` in `self.stored`\"\n    return Hook(module, _hook_inner, detach=detach, is_forward=not grad)\n\ndef hook_outputs(modules:Collection[nn.Module], detach:bool=True, grad:bool=False)->Hooks:\n    \"Return `Hooks` that store activations of all `modules` in `self.stored`\"\n    return Hooks(modules, _hook_inner, detach=detach, is_forward=not grad)\n\nclass HookCallback(LearnerCallback):\n    \"Callback that can be used to register hooks on `modules`. Implement the corresponding function in `self.hook`.\"\n    def __init__(self, learn:Learner, modules:Sequence[nn.Module]=None, do_remove:bool=True):\n        super().__init__(learn)\n        self.modules,self.do_remove = modules,do_remove\n\n    def on_train_begin(self, **kwargs):\n        \"Register the `Hooks` on `self.modules`.\"\n        if not self.modules:\n            self.modules = [m for m in flatten_model(self.learn.model)\n                            if hasattr(m, 'weight')]\n        self.hooks = Hooks(self.modules, self.hook)\n\n    def on_train_end(self, **kwargs):\n        \"Remove the `Hooks`.\"\n        if self.do_remove: self.remove()\n\n    def remove(self): \n        if getattr(self, 'hooks', None): self.hooks.remove()\n    def __del__(self): self.remove()\n\nclass ActivationStats(HookCallback):\n    \"Callback that record the mean and std of activations.\"\n\n    def on_train_begin(self, **kwargs):\n        \"Initialize stats.\"\n        super().on_train_begin(**kwargs)\n        self.stats = []\n\n    def hook(self, m:nn.Module, i:Tensors, o:Tensors)->Tuple[Rank0Tensor,Rank0Tensor]:\n        \"Take the mean and std of `o`.\"\n        
return o.mean().item(),o.std().item()\n    def on_batch_end(self, train, **kwargs):\n        \"Take the stored results and puts it in `self.stats`\"\n        if train: self.stats.append(self.hooks.stored)\n    def on_train_end(self, **kwargs):\n        \"Polish the final result.\"\n        super().on_train_end(**kwargs)\n        self.stats = tensor(self.stats).permute(2,1,0)\n\ndef dummy_batch(m: nn.Module, size:tuple=(64,64))->Tensor:\n    \"Create a dummy batch to go through `m` with `size`.\"\n    ch_in = in_channels(m)\n    return one_param(m).new(1, ch_in, *size).requires_grad_(False).uniform_(-1.,1.)\n\ndef dummy_eval(m:nn.Module, size:tuple=(64,64)):\n    \"Pass a `dummy_batch` in evaluation mode in `m` with `size`.\"\n    m.eval()\n    return m(dummy_batch(m, size))\n    #return m.eval()(dummy_batch(m, size))\n\ndef model_sizes(m:nn.Module, size:tuple=(64,64))->Tuple[Sizes,Tensor,Hooks]:\n    \"Pass a dummy input through the model `m` to get the various sizes of activations.\"\n    with hook_outputs(m) as hooks:\n        x = dummy_eval(m, size)\n        return [o.stored.shape for o in hooks]\n\ndef num_features_model(m:nn.Module)->int:\n    \"Return the number of output features for `model`.\"\n    sz = 64\n    while True:\n        try: return model_sizes(m, size=(sz,sz))[-1][1]\n        except Exception as e:\n            sz *= 2\n            if sz > 2048: raise\n\ndef total_params(m:nn.Module)->int:\n    params, trainable = 0, False\n    if hasattr(m, \"weight\") and hasattr(m.weight, \"size\"):\n         params += m.weight.numel()\n         trainable = m.weight.requires_grad\n    if hasattr(m, \"bias\") and hasattr(m.bias, \"size\"): params += m.bias.numel()\n    return params, trainable\n\ndef hook_params(modules:Collection[nn.Module])->Hooks:\n    return Hooks(modules, lambda m, i, o: total_params(m))\n\ndef params_size(m: Union[nn.Module,Learner], size: tuple = (3, 64, 64))->Tuple[Sizes, Tensor, Hooks]:\n    \"Pass a dummy input through the model to 
get the various sizes. Returns (res,x,hooks) if `full`\"\n    if isinstance(m, Learner):\n        if m.data.is_empty:\n            raise Exception(\"This is an empty `Learner` and `Learner.summary` requires some data to pass through the model.\")\n        ds_type = DatasetType.Train if m.data.train_dl else (DatasetType.Valid if m.data.valid_dl else DatasetType.Test)\n        x = m.data.one_batch(ds_type=ds_type, detach=False, denorm=False)[0]\n        x = [o[:1] for o in x]  if is_listy(x) else x[:1]\n        m = m.model\n    elif isinstance(m, nn.Module): x = next(m.parameters()).new(1, *size)\n    else: raise TypeError('You should either pass in a Learner or nn.Module')\n    with hook_outputs(flatten_model(m)) as hook_o:\n        with hook_params(flatten_model(m))as hook_p:\n            x = m.eval()(*x) if is_listy(x) else m.eval()(x)\n            output_size = [((o.stored.shape[1:]) if o.stored is not None else None) for o in hook_o]\n            params = [(o.stored if o.stored is not None else (None,None)) for o in hook_p]\n    params, trainables = map(list,zip(*params))\n    return output_size, params, trainables\n\ndef get_layer_name(layer:nn.Module)->str:\n    return str(layer.__class__).split(\".\")[-1].split(\"'\")[0]\n\ndef layers_info(m:Collection[nn.Module]) -> Collection[namedtuple]:\n    func = lambda m:list(map(get_layer_name, flatten_model(m)))\n    layers_names = func(m.model) if isinstance(m, Learner) else func(m)\n    layers_sizes, layers_params, layers_trainable = params_size(m)\n    layer_info = namedtuple('Layer_Information', ['Layer', 'OutputSize', 'Params', 'Trainable'])\n    return list(map(layer_info, layers_names, layers_sizes, layers_params, layers_trainable))\n\ndef model_summary(m:Learner, n:int=70):\n    \"Print a summary of `m` using a output text width of `n` chars\"\n    info = layers_info(m)\n    header = [\"Layer (type)\", \"Output Shape\", \"Param #\", \"Trainable\"]\n    res = m.model.__class__.__name__ + \"\\n\"\n    res += 
\"=\" * n + \"\\n\"\n    res += f\"{header[0]:<20} {header[1]:<20} {header[2]:<10} {header[3]:<10}\\n\"\n    res += \"=\" * n + \"\\n\"\n    total_params = 0\n    total_trainable_params = 0\n    for layer, size, params, trainable in info:\n        if size is None: continue\n        total_params += int(params)\n        total_trainable_params += int(params) * trainable\n        size, trainable = str(list(size)), str(trainable)\n        res += f\"{layer:<20} {size:<20} {int(params):<10,} {trainable:<10}\\n\"\n        res += \"_\" * n + \"\\n\"\n    res += f\"\\nTotal params: {total_params:,}\\n\"\n    res += f\"Total trainable params: {total_trainable_params:,}\\n\"\n    res += f\"Total non-trainable params: {total_params - total_trainable_params:,}\\n\"\n           \n    res += f\"Optimized with {str(m.opt_func)[25:-1].replace('>', '')}\\n\"\n    if m.true_wd: res += f\"Using true weight decay as discussed in https://www.fast.ai/2018/07/02/adam-weight-decay/ \\n\"\n    if \"wd\" in str(m.opt_func) or \"weight_decay\" in str(m.opt_func): res += f\"\\x1b[1;31m Specifying weight decay in the optimizer has no effect, Learner will overwrite \\x1b[0m \\n\"\n    if \"lr\" in str(m.opt_func) or \"learning_rate\" in str(m.opt_func): res += f\"\\x1b[1;31m Specifying lr in the optimizer has no effect, pass it to fit or the defaults.lr will apply \\x1b[0m \\n\" \n    res += f\"Loss function : {m.loss_func.__class__.__name__}\\n\"\n    res += \"=\" * n + \"\\n\"\n    res += \"Callbacks functions applied \\n\"\n    res += \"\\n\".join([f\"    {cbs.__class__.__name__}\" for cbs in m.callbacks])\n\n    return PrettyString(res)\n\nLearner.summary = model_summary\n"
  },
  {
    "path": "fastai/callbacks/loss_metrics.py",
    "content": "from ..torch_core import *\nfrom ..callback import *\nfrom ..basic_train import Learner, LearnerCallback\n\n__all__ = ['LossMetrics']\n\nclass LossMetrics(LearnerCallback):\n    \"Add `loss_func.metrics` to metrics named by `loss_func.metric_names`\"\n    _order = -20 #Needs to run before the recorder\n\n    def on_train_begin(self, **kwargs):\n        \"Add the metrics names to the `Recorder`.\"\n        self.names = ifnone(self.learn.loss_func.metric_names, [])\n        if not self.names: warn('LossMetrics requested but no loss_func.metric_names provided')\n        self.learn.recorder.add_metric_names(self.names)\n\n    def on_epoch_begin(self, **kwargs):\n        \"Initialize the metrics for this epoch.\"\n        self.metrics = {name:0. for name in self.names}\n        self.nums = 0\n\n    def on_batch_end(self, last_target, train, **kwargs):\n        \"Update the metrics if not `train`\"\n        if train: return\n        bs = last_target.size(0)\n        for name in self.names:\n            self.metrics[name] += bs * self.learn.loss_func.metrics[name].detach().cpu()\n        self.nums += bs\n\n    def on_epoch_end(self, last_metrics, **kwargs):\n        \"Finish the computation and sends the result to the Recorder.\"\n        if not self.nums: return\n        metrics = [self.metrics[name]/self.nums for name in self.names]\n        return {'last_metrics': last_metrics+metrics}\n"
  },
  {
    "path": "fastai/callbacks/lr_finder.py",
    "content": "\"Tools to help find the optimal learning rate for training\"\nfrom ..torch_core import *\nfrom ..basic_data import DataBunch\nfrom ..callback import *\nfrom ..basic_train import Learner, LearnerCallback\n\n__all__ = ['LRFinder']\n\nclass LRFinder(LearnerCallback):\n    \"Causes `learn` to go on a mock training from `start_lr` to `end_lr` for `num_it` iterations.\"\n    def __init__(self, learn:Learner, start_lr:float=1e-7, end_lr:float=10, num_it:int=100, stop_div:bool=True):\n        super().__init__(learn)\n        self.data,self.stop_div = learn.data,stop_div\n        self.sched = Scheduler((start_lr, end_lr), num_it, annealing_exp)\n\n    def on_train_begin(self, pbar, **kwargs:Any)->None:\n        \"Initialize optimizer and learner hyperparameters.\"\n        setattr(pbar, 'clean_on_interrupt', True)\n        self.learn.save('tmp')\n        self.opt = self.learn.opt\n        self.opt.lr = self.sched.start\n        self.stop,self.best_loss = False,0.\n        return {'skip_validate': True}\n\n    def on_batch_end(self, iteration:int, smooth_loss:TensorOrNumber, **kwargs:Any)->None:\n        \"Determine if loss has runaway and we should stop.\"\n        if iteration==0 or smooth_loss < self.best_loss: self.best_loss = smooth_loss\n        self.opt.lr = self.sched.step()\n        if self.sched.is_done or (self.stop_div and (smooth_loss > 4*self.best_loss or torch.isnan(smooth_loss))):\n            #We use the smoothed loss to decide on the stopping since it's less shaky.\n            return {'stop_epoch': True, 'stop_training': True}\n\n    def on_train_end(self, **kwargs:Any)->None:\n        \"Cleanup learn model weights disturbed during LRFinder exploration.\"\n        self.learn.load('tmp', purge=False)\n        if hasattr(self.learn.model, 'reset'): self.learn.model.reset()\n        for cb in self.callbacks:\n            if hasattr(cb, 'reset'): cb.reset()\n        print('LR Finder is complete, type {learner_name}.recorder.plot() to see the 
graph.')\n"
  },
  {
    "path": "fastai/callbacks/mem.py",
    "content": "\" Memory profiling callbacks \"\n\nimport tracemalloc, threading, torch, time\nfrom ..utils.mem import *\nfrom ..basic_train import *\nfrom ..torch_core import *\nfrom ..utils.pynvml_gate import *\n\nif use_gpu: pynvml = load_pynvml_env()\n\nclass PeakMemMetric(LearnerCallback):\n    \"Callback that measures used and peaked general and GPU memory.\"\n\n    _order=-20 # Needs to run before the recorder\n\n    def __init__(self, learn:Learner):\n        super().__init__(learn)\n        assert torch.cuda.is_available(), \"pytorch CUDA is required\"\n        preload_pytorch()\n\n    def peak_monitor_start(self):\n        self.peak_monitoring = True\n\n        # start RAM tracing\n        tracemalloc.start()\n\n        # this thread samples RAM usage as long as the current epoch of the fit loop is running\n        peak_monitor_thread = threading.Thread(target=self.peak_monitor_func)\n        peak_monitor_thread.daemon = True\n        peak_monitor_thread.start()\n\n    def peak_monitor_stop(self):\n        tracemalloc.stop()\n        self.peak_monitoring = False\n\n    def peak_monitor_func(self):\n        self.gpu_mem_used_peak = -1\n\n        gpu_id = torch.cuda.current_device()\n        gpu_handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_id)\n\n        while True:\n            gpu_mem_used = gpu_mem_get_used_fast(gpu_handle)\n            self.gpu_mem_used_peak = max(gpu_mem_used, self.gpu_mem_used_peak)\n            if not self.peak_monitoring: break\n            time.sleep(0.001) # 1msec\n\n    def on_train_begin(self, **kwargs):\n        self.learn.recorder.add_metric_names(['cpu used',  'peak', 'gpu used',  'peak'])\n\n    def on_epoch_begin(self, **kwargs):\n        self.peak_monitor_start()\n        self.gpu_before = gpu_mem_get_used_no_cache()\n\n    def on_epoch_end(self, last_metrics, **kwargs):\n        cpu_used, cpu_peak =  list(map(lambda x: int(x/2**20), tracemalloc.get_traced_memory()))\n        self.peak_monitor_stop()\n        gpu_used = 
gpu_mem_get_used_no_cache() - self.gpu_before\n        gpu_peak = self.gpu_mem_used_peak      - self.gpu_before\n        # can be negative, due to unreliable peak monitor thread\n        if gpu_peak < 0:   gpu_peak = 0\n        # since we want the overhead only, subtract delta used if it's positive\n        elif gpu_used > 0: gpu_peak -= gpu_used\n        # The numbers are deltas in MBs (beginning of the epoch and the end)\n        return add_metrics(last_metrics, [cpu_used, cpu_peak, gpu_used, gpu_peak])\n"
  },
  {
    "path": "fastai/callbacks/misc.py",
    "content": "\" Miscellaneous callbacks \"\n\nfrom fastai.callback import Callback\n\nclass StopAfterNBatches(Callback):\n    \"Stop training after n batches of the first epoch.\"\n    def __init__(self, n_batches:int=2):\n        self.stop,self.n_batches = False,n_batches-1 # iteration starts from 0\n\n    def on_batch_end(self, iteration, **kwargs):\n        if iteration == self.n_batches:\n            return {'stop_epoch': True, 'stop_training': True, 'skip_validate': True}\n"
  },
  {
    "path": "fastai/callbacks/mixup.py",
    "content": "\"Implements [mixup](https://arxiv.org/abs/1710.09412) training method\"\nfrom ..torch_core import *\nfrom ..callback import *\nfrom ..basic_train import Learner, LearnerCallback\n\nclass MixUpCallback(LearnerCallback):\n    \"Callback that creates the mixed-up input and target.\"\n    def __init__(self, learn:Learner, alpha:float=0.4, stack_x:bool=False, stack_y:bool=True):\n        super().__init__(learn)\n        self.alpha,self.stack_x,self.stack_y = alpha,stack_x,stack_y\n    \n    def on_train_begin(self, **kwargs):\n        if self.stack_y: self.learn.loss_func = MixUpLoss(self.learn.loss_func)\n        \n    def on_batch_begin(self, last_input, last_target, train, **kwargs):\n        \"Applies mixup to `last_input` and `last_target` if `train`.\"\n        if not train: return\n        lambd = np.random.beta(self.alpha, self.alpha, last_target.size(0))\n        lambd = np.concatenate([lambd[:,None], 1-lambd[:,None]], 1).max(1)\n        lambd = last_input.new(lambd)\n        shuffle = torch.randperm(last_target.size(0)).to(last_input.device)\n        x1, y1 = last_input[shuffle], last_target[shuffle]\n        if self.stack_x:\n            new_input = [last_input, last_input[shuffle], lambd]\n        else: \n            out_shape = [lambd.size(0)] + [1 for _ in range(len(x1.shape) - 1)]\n            new_input = (last_input * lambd.view(out_shape) + x1 * (1-lambd).view(out_shape))\n        if self.stack_y:\n            new_target = torch.cat([last_target[:,None].float(), y1[:,None].float(), lambd[:,None].float()], 1)\n        else:\n            if len(last_target.shape) == 2:\n                lambd = lambd.unsqueeze(1).float()\n            new_target = last_target.float() * lambd + y1.float() * (1-lambd)\n        return {'last_input': new_input, 'last_target': new_target}  \n    \n    def on_train_end(self, **kwargs):\n        if self.stack_y: self.learn.loss_func = self.learn.loss_func.get_old()\n        \n\nclass MixUpLoss(Module):\n    
\"Adapt the loss function `crit` to go with mixup.\"\n    \n    def __init__(self, crit, reduction='mean'):\n        super().__init__()\n        if hasattr(crit, 'reduction'): \n            self.crit = crit\n            self.old_red = crit.reduction\n            setattr(self.crit, 'reduction', 'none')\n        else: \n            self.crit = partial(crit, reduction='none')\n            self.old_crit = crit\n        self.reduction = reduction\n        \n    def forward(self, output, target):\n        if len(target.size()) == 2:\n            loss1, loss2 = self.crit(output,target[:,0].long()), self.crit(output,target[:,1].long())\n            d = (loss1 * target[:,2] + loss2 * (1-target[:,2])).mean()\n        else:  d = self.crit(output, target)\n        if self.reduction == 'mean': return d.mean()\n        elif self.reduction == 'sum':            return d.sum()\n        return d\n    \n    def get_old(self):\n        if hasattr(self, 'old_crit'):  return self.old_crit\n        elif hasattr(self, 'old_red'): \n            setattr(self.crit, 'reduction', self.old_red)\n            return self.crit\n"
  },
  {
    "path": "fastai/callbacks/mlflow.py",
    "content": "\"A `Callback` that saves tracked metrics and notebook file into MLflow server.\"\nfrom ..torch_core import *\nfrom ..callback import *\nfrom ..basic_train import Learner, LearnerCallback\n#This is an optional dependency in fastai.  Must install separately.\ntry: import mlflow\nexcept: print(\"To use this tracker, please run 'pip install mlflow'\")\n\nclass MLFlowTracker(LearnerCallback):\n    \"A `TrackerCallback` that tracks the loss and metrics into MLFlow\"\n    def __init__(self, learn:Learner, exp_name: str, params: dict, nb_path: str, uri: str = \"http://localhost:5000\"):\n        super().__init__(learn)\n        self.learn,self.exp_name,self.params,self.nb_path,self.uri = learn,exp_name,params,nb_path,uri\n        self.metrics_names = ['train_loss', 'valid_loss'] + [o.__name__ for o in learn.metrics]\n\n    def on_train_begin(self, **kwargs: Any) -> None:\n        \"Prepare MLflow experiment and log params\"\n        self.client = mlflow.tracking.MlflowClient(self.uri)\n        exp = self.client.get_experiment_by_name(self.exp_name)\n        self.exp_id = self.client.create_experiment(self.exp_name) if exp is None else exp.experiment_id\n        run = self.client.create_run(experiment_id=self.exp_id)\n        self.run = run.info.run_uuid\n        for k,v in self.params.items():\n            self.client.log_param(run_id=self.run, key=k, value=v)\n\n    def on_epoch_end(self, epoch, **kwargs:Any)->None:\n        \"Send loss and metrics values to MLFlow after each epoch\"\n        if kwargs['smooth_loss'] is None or kwargs[\"last_metrics\"] is None: return\n        metrics = [kwargs['smooth_loss']] + kwargs[\"last_metrics\"]\n        for name, val in zip(self.metrics_names, metrics):\n            self.client.log_metric(self.run, name, np.float(val), step=epoch)\n        \n    def on_train_end(self, **kwargs: Any) -> None:  \n        \"Store the notebook and stop run\"\n        self.client.log_artifact(run_id=self.run, 
local_path=self.nb_path)\n        self.client.set_terminated(run_id=self.run)\n"
  },
  {
    "path": "fastai/callbacks/one_cycle.py",
    "content": "\"Supports 1-Cycle style training\"\nfrom ..core import *\nfrom ..callback import *\nfrom ..basic_train import Learner,LearnerCallback\n\n__all__ = ['OneCycleScheduler']\n\nclass OneCycleScheduler(LearnerCallback):\n    \"Manage 1-Cycle style training as outlined in Leslie Smith's [paper](https://arxiv.org/pdf/1803.09820.pdf).\"\n    def __init__(self, learn:Learner, lr_max:float, moms:Floats=(0.95,0.85), div_factor:float=25., pct_start:float=0.3,\n                 final_div:float=None, tot_epochs:int=None, start_epoch:int=None):\n        super().__init__(learn)\n        self.lr_max,self.div_factor,self.pct_start,self.final_div = lr_max,div_factor,pct_start,final_div\n        if self.final_div is None: self.final_div = div_factor*1e4\n        self.moms=tuple(listify(moms,2))\n        if is_listy(self.lr_max): self.lr_max = np.array(self.lr_max)\n        self.start_epoch, self.tot_epochs = start_epoch, tot_epochs\n\n    def steps(self, *steps_cfg:StartOptEnd):\n        \"Build anneal schedule for all of the parameters.\"\n        return [Scheduler(step, n_iter, func=func)\n                for (step,(n_iter,func)) in zip(steps_cfg, self.phases)]\n\n    def on_train_begin(self, n_epochs:int, epoch:int, **kwargs:Any)->None:\n        \"Initialize our optimization params based on our annealing schedule.\"\n        res = {'epoch':self.start_epoch} if self.start_epoch is not None else None\n        self.start_epoch = ifnone(self.start_epoch, epoch)\n        self.tot_epochs = ifnone(self.tot_epochs, n_epochs)\n        n = len(self.learn.data.train_dl) * self.tot_epochs\n        a1 = int(n * self.pct_start)\n        a2 = n-a1\n        self.phases = ((a1, annealing_cos), (a2, annealing_cos))\n        low_lr = self.lr_max/self.div_factor\n        self.lr_scheds = self.steps((low_lr, self.lr_max), (self.lr_max, self.lr_max/self.final_div))\n        self.mom_scheds = self.steps(self.moms, (self.moms[1], self.moms[0]))\n        self.opt = self.learn.opt\n        
self.opt.lr,self.opt.mom = self.lr_scheds[0].start,self.mom_scheds[0].start\n        self.idx_s = 0\n        return res\n    \n    def jump_to_epoch(self, epoch:int)->None:\n        for _ in range(len(self.learn.data.train_dl) * epoch):\n            self.on_batch_end(True)\n\n    def on_batch_end(self, train, **kwargs:Any)->None:\n        \"Take one step forward on the annealing schedule for the optim params.\"\n        if train:\n            if self.idx_s >= len(self.lr_scheds): return {'stop_training': True, 'stop_epoch': True}\n            self.opt.lr = self.lr_scheds[self.idx_s].step()\n            self.opt.mom = self.mom_scheds[self.idx_s].step()\n            # when the current schedule is complete we move onto the next\n            # schedule. (in 1-cycle there are two schedules)\n            if self.lr_scheds[self.idx_s].is_done:\n                self.idx_s += 1\n\n    def on_epoch_end(self, epoch, **kwargs:Any)->None:\n        \"Tell Learner to stop if the cycle is finished.\"\n        if epoch > self.tot_epochs: return {'stop_training': True}\n"
  },
  {
    "path": "fastai/callbacks/oversampling.py",
    "content": "from ..torch_core import *\nfrom ..basic_data import DataBunch\nfrom ..callback import *\nfrom ..basic_train import Learner,LearnerCallback\nfrom torch.utils.data.sampler import WeightedRandomSampler\n\n__all__ = ['OverSamplingCallback']\n\n\n\nclass OverSamplingCallback(LearnerCallback):\n    def __init__(self,learn:Learner,weights:torch.Tensor=None):\n        super().__init__(learn)\n        self.labels = self.learn.data.train_dl.dataset.y.items\n        _, counts = np.unique(self.labels,return_counts=True)\n        self.weights = (weights if weights is not None else\n                        torch.DoubleTensor((1/counts)[self.labels]))\n        self.label_counts = np.bincount([self.learn.data.train_dl.dataset.y[i].data for i in range(len(self.learn.data.train_dl.dataset))])\n        self.total_len_oversample = int(self.learn.data.c*np.max(self.label_counts))\n        \n    def on_train_begin(self, **kwargs):\n        self.learn.data.train_dl.dl.batch_sampler = BatchSampler(WeightedRandomSampler(self.weights,self.total_len_oversample), self.learn.data.train_dl.batch_size,False)"
  },
  {
    "path": "fastai/callbacks/rnn.py",
    "content": "\"Regroups lr adjustment to seq_len, AR and TAR\"\nfrom ..torch_core import *\nfrom ..callback import *\nfrom ..basic_train import Learner, LearnerCallback\n\n__all__ = ['RNNTrainer']\n\nclass RNNTrainer(LearnerCallback):\n    \"`Callback` that regroups lr adjustment to seq_len, AR and TAR.\"\n    def __init__(self, learn:Learner, alpha:float=0., beta:float=0.):\n        super().__init__(learn)\n        self.not_min += ['raw_out', 'out']\n        self.alpha,self.beta = alpha,beta\n        \n    def on_epoch_begin(self, **kwargs):\n        \"Reset the hidden state of the model.\"\n        self.learn.model.reset()\n\n    def on_loss_begin(self, last_output:Tuple[Tensor,Tensor,Tensor], **kwargs):\n        \"Save the extra outputs for later and only returns the true output.\"\n        self.raw_out,self.out = last_output[1],last_output[2]\n        return {'last_output': last_output[0]}\n\n    def on_backward_begin(self, last_loss:Rank0Tensor, last_input:Tensor, **kwargs):\n        \"Apply AR and TAR to `last_loss`.\"\n        #AR and TAR\n        if self.alpha != 0.:  last_loss += self.alpha * self.out[-1].float().pow(2).mean()\n        if self.beta != 0.:\n            h = self.raw_out[-1]\n            if len(h)>1: last_loss += self.beta * (h[:,1:] - h[:,:-1]).float().pow(2).mean()\n        return {'last_loss': last_loss}\n"
  },
  {
    "path": "fastai/callbacks/tensorboard.py",
    "content": "\"Provides convenient callbacks for Learners that write model images, metrics/losses, stats and histograms to Tensorboard\"\nfrom ..basic_train import Learner\nfrom ..basic_data import DatasetType, DataBunch\nfrom ..vision import Image\nfrom ..vision.gan import GANLearner\nfrom ..callbacks import LearnerCallback\nfrom ..core import *\nfrom ..torch_core import *\nfrom threading import Thread, Event\nfrom time import sleep\nfrom queue import Queue\nimport statistics\nimport torchvision.utils as vutils\nfrom abc import ABC\n#This is an optional dependency in fastai.  Must install separately.\ntry: from tensorboardX import SummaryWriter\nexcept: print(\"To use this tracker, please run 'pip install tensorboardx'. Also you must have Tensorboard running to see results\")\n\n__all__=['LearnerTensorboardWriter', 'GANTensorboardWriter', 'ImageGenTensorboardWriter']\n\n#---Example usage (applies to any of the callbacks)--- \n# proj_id = 'Colorize'\n# tboard_path = Path('data/tensorboard/' + proj_id)\n# learn.callback_fns.append(partial(GANTensorboardWriter, base_dir=tboard_path, name='GanLearner'))\n\nclass LearnerTensorboardWriter(LearnerCallback):\n    \"Broadly useful callback for Learners that writes to Tensorboard.  
Writes model histograms, losses/metrics, and gradient stats.\"\n    def __init__(self, learn:Learner, base_dir:Path, name:str, loss_iters:int=25, hist_iters:int=500, stats_iters:int=100):\n        super().__init__(learn=learn)\n        self.base_dir,self.name,self.loss_iters,self.hist_iters,self.stats_iters  = base_dir,name,loss_iters,hist_iters,stats_iters\n        log_dir = base_dir/name\n        self.tbwriter = SummaryWriter(str(log_dir))\n        self.hist_writer = HistogramTBWriter()\n        self.stats_writer = ModelStatsTBWriter()\n        #self.graph_writer = GraphTBWriter()\n        self.data = None\n        self.metrics_root = '/metrics/'\n        self._update_batches_if_needed()\n\n    def _get_new_batch(self, ds_type:DatasetType)->Collection[Tensor]:\n        \"Retrieves new batch of DatasetType, and detaches it.\"\n        return self.learn.data.one_batch(ds_type=ds_type, detach=True, denorm=False, cpu=False)\n\n    def _update_batches_if_needed(self)->None:\n        \"one_batch function is extremely slow with large datasets.  
This is caching the result as an optimization.\"\n        if self.learn.data.valid_dl is None: return # Running learning rate finder, so return\n        update_batches = self.data is not self.learn.data\n        if not update_batches: return\n        self.data = self.learn.data\n        self.trn_batch = self._get_new_batch(ds_type=DatasetType.Train)\n        self.val_batch = self._get_new_batch(ds_type=DatasetType.Valid)\n\n    def _write_model_stats(self, iteration:int)->None:\n        \"Writes gradient statistics to Tensorboard.\"\n        self.stats_writer.write(model=self.learn.model, iteration=iteration, tbwriter=self.tbwriter)\n\n    def _write_training_loss(self, iteration:int, last_loss:Tensor)->None:\n        \"Writes training loss to Tensorboard.\"\n        scalar_value = to_np(last_loss)\n        tag = self.metrics_root + 'train_loss'\n        self.tbwriter.add_scalar(tag=tag, scalar_value=scalar_value, global_step=iteration)\n\n    def _write_weight_histograms(self, iteration:int)->None:\n        \"Writes model weight histograms to Tensorboard.\"\n        self.hist_writer.write(model=self.learn.model, iteration=iteration, tbwriter=self.tbwriter)\n\n    def _write_scalar(self, name:str, scalar_value, iteration:int)->None:\n        \"Writes single scalar value to Tensorboard.\"\n        tag = self.metrics_root + name\n        self.tbwriter.add_scalar(tag=tag, scalar_value=scalar_value, global_step=iteration)\n\n    #TODO:  Relying on a specific hardcoded start_idx here isn't great.  
Is there a better solution?\n    def _write_metrics(self, iteration:int, last_metrics:MetricsList, start_idx:int=2)->None:\n        \"Writes training metrics to Tensorboard.\"\n        recorder = self.learn.recorder\n        for i, name in enumerate(recorder.names[start_idx:]):\n            if last_metrics is None or len(last_metrics) < i+1: return\n            scalar_value = last_metrics[i]\n            self._write_scalar(name=name, scalar_value=scalar_value, iteration=iteration)\n\n    def on_train_begin(self, **kwargs: Any) -> None:\n        #self.graph_writer.write(model=self.learn.model, tbwriter=self.tbwriter,\n                                #input_to_model=next(iter(self.learn.data.dl(DatasetType.Single)))[0])\n        return\n\n    def on_batch_end(self, last_loss:Tensor, iteration:int, **kwargs)->None:\n        \"Callback function that writes batch end appropriate data to Tensorboard.\"\n        if iteration == 0: return\n        self._update_batches_if_needed()\n        if iteration % self.loss_iters == 0: self._write_training_loss(iteration=iteration, last_loss=last_loss)\n        if iteration % self.hist_iters == 0: self._write_weight_histograms(iteration=iteration)\n\n    # Doing stuff here that requires gradient info, because they get zeroed out afterwards in training loop\n    def on_backward_end(self, iteration:int, **kwargs)->None:\n        \"Callback function that writes backward end appropriate data to Tensorboard.\"\n        if iteration == 0: return\n        self._update_batches_if_needed()\n        if iteration % self.stats_iters == 0: self._write_model_stats(iteration=iteration)\n\n    def on_epoch_end(self, last_metrics:MetricsList, iteration:int, **kwargs)->None:\n        \"Callback function that writes epoch end appropriate data to Tensorboard.\"\n        self._write_metrics(iteration=iteration, last_metrics=last_metrics)\n\n# TODO:  We're overriding almost everything here.  
Seems like a good idea to question that (\"is a\" vs \"has a\")\nclass GANTensorboardWriter(LearnerTensorboardWriter):\n    \"Callback for GANLearners that writes to Tensorboard.  Extends LearnerTensorboardWriter and adds output image writes.\"\n    def __init__(self, learn:GANLearner, base_dir:Path, name:str, loss_iters:int=25, hist_iters:int=500, \n                stats_iters:int=100, visual_iters:int=100):\n        super().__init__(learn=learn, base_dir=base_dir, name=name, loss_iters=loss_iters, hist_iters=hist_iters, stats_iters=stats_iters)\n        self.visual_iters = visual_iters\n        self.img_gen_vis = ImageTBWriter()\n        self.gen_stats_updated = True\n        self.crit_stats_updated = True\n\n    def _write_weight_histograms(self, iteration:int)->None:\n        \"Writes model weight histograms to Tensorboard.\"\n        generator, critic = self.learn.gan_trainer.generator, self.learn.gan_trainer.critic\n        self.hist_writer.write(model=generator, iteration=iteration, tbwriter=self.tbwriter, name='generator')\n        self.hist_writer.write(model=critic,    iteration=iteration, tbwriter=self.tbwriter, name='critic')\n\n    def _write_gen_model_stats(self, iteration:int)->None:\n        \"Writes gradient statistics for generator to Tensorboard.\"\n        generator = self.learn.gan_trainer.generator\n        self.stats_writer.write(model=generator, iteration=iteration, tbwriter=self.tbwriter, name='gen_model_stats')\n        self.gen_stats_updated = True\n\n    def _write_critic_model_stats(self, iteration:int)->None:\n        \"Writes gradient statistics for critic to Tensorboard.\"\n        critic = self.learn.gan_trainer.critic\n        self.stats_writer.write(model=critic, iteration=iteration, tbwriter=self.tbwriter, name='crit_model_stats')\n        self.crit_stats_updated = True\n\n    def _write_model_stats(self, iteration:int)->None:\n        \"Writes gradient statistics to Tensorboard.\"\n        # We don't want to write stats when 
model is not iterated on and hence has zeroed out gradients\n        gen_mode = self.learn.gan_trainer.gen_mode\n        if gen_mode and not self.gen_stats_updated: self._write_gen_model_stats(iteration=iteration)\n        if not gen_mode and not self.crit_stats_updated: self._write_critic_model_stats(iteration=iteration)\n\n    def _write_training_loss(self, iteration:int, last_loss:Tensor)->None:\n        \"Writes training loss to Tensorboard.\"\n        recorder = self.learn.gan_trainer.recorder\n        if len(recorder.losses) == 0: return\n        scalar_value = to_np((recorder.losses[-1:])[0])\n        tag = self.metrics_root + 'train_loss'\n        self.tbwriter.add_scalar(tag=tag, scalar_value=scalar_value, global_step=iteration)\n\n    def _write_images(self, iteration:int)->None:\n        \"Writes model generated, original and real images to Tensorboard.\"\n        trainer = self.learn.gan_trainer\n        #TODO:  Switching gen_mode temporarily seems a bit hacky here.  Certainly not a good side-effect.  
Is there a better way?\n        gen_mode = trainer.gen_mode\n        try:\n            trainer.switch(gen_mode=True)\n            self.img_gen_vis.write(learn=self.learn, trn_batch=self.trn_batch, val_batch=self.val_batch, \n                                    iteration=iteration, tbwriter=self.tbwriter)\n        finally: trainer.switch(gen_mode=gen_mode)\n\n    def on_batch_end(self, iteration:int, **kwargs)->None:\n        \"Callback function that writes batch end appropriate data to Tensorboard.\"\n        super().on_batch_end(iteration=iteration, **kwargs)\n        if iteration == 0: return\n        if iteration % self.visual_iters == 0: self._write_images(iteration=iteration)\n\n    def on_backward_end(self, iteration:int, **kwargs)->None:\n        \"Callback function that writes backward end appropriate data to Tensorboard.\"\n        if iteration == 0: return\n        self._update_batches_if_needed()\n        #TODO:  This could perhaps be implemented as queues of requests instead but that seemed like overkill. \n        # But I'm not the biggest fan of maintaining these boolean flags either... Review pls.\n        if iteration % self.stats_iters == 0: self.gen_stats_updated, self.crit_stats_updated = False, False\n        if not (self.gen_stats_updated and self.crit_stats_updated): self._write_model_stats(iteration=iteration)\n\nclass ImageGenTensorboardWriter(LearnerTensorboardWriter):\n    \"Callback for non-GAN image generating Learners that writes to Tensorboard.  
Extends LearnerTensorboardWriter and adds output image writes.\"\n    def __init__(self, learn:Learner, base_dir:Path, name:str, loss_iters:int=25, hist_iters:int=500, stats_iters:int=100, \n                 visual_iters:int=100):\n        super().__init__(learn=learn, base_dir=base_dir, name=name, loss_iters=loss_iters, hist_iters=hist_iters, \n                         stats_iters=stats_iters)\n        self.visual_iters = visual_iters\n        self.img_gen_vis = ImageTBWriter()\n\n    def _write_images(self, iteration:int)->None:\n        \"Writes model generated, original and real images to Tensorboard\"\n        self.img_gen_vis.write(learn=self.learn, trn_batch=self.trn_batch, val_batch=self.val_batch, iteration=iteration, \n                               tbwriter=self.tbwriter)\n\n    def on_batch_end(self, iteration:int, **kwargs)->None:\n        \"Callback function that writes batch end appropriate data to Tensorboard.\"\n        super().on_batch_end(iteration=iteration, **kwargs)\n        if iteration == 0: return\n        if iteration % self.visual_iters == 0: \n            self._write_images(iteration=iteration)\n\nclass TBWriteRequest(ABC):\n    \"A request object for Tensorboard writes.  Useful for queuing up and executing asynchronous writes.\"\n    def __init__(self, tbwriter: SummaryWriter, iteration:int):\n        super().__init__()\n        self.tbwriter = tbwriter\n        self.iteration = iteration\n\n    @abstractmethod\n    def write(self)->None: pass   \n\n# SummaryWriter writes tend to block quite a bit.  This gets around that and greatly boosts performance.\n# Not all tensorboard writes are using this- just the ones that take a long time.  Note that the \n# SummaryWriter does actually use a threadsafe consumer/producer design ultimately to write to Tensorboard, \n# so writes done outside of this async loop should be fine.\nclass AsyncTBWriter():\n    \"Callback for GANLearners that writes to Tensorboard.  
Extends LearnerTensorboardWriter and adds output image writes.\"\n    def __init__(self):\n        super().__init__()\n        self.stop_request = Event()\n        self.queue = Queue()\n        self.thread = Thread(target=self._queue_processor, daemon=True)\n        self.thread.start()\n\n    def request_write(self, request: TBWriteRequest)->None:\n        \"Queues up an asynchronous write request to Tensorboard.\"\n        if self.stop_request.isSet(): return\n        self.queue.put(request)\n\n    def _queue_processor(self)->None:\n        \"Processes queued up write requests asynchronously to Tensorboard.\"\n        while not self.stop_request.isSet():\n            while not self.queue.empty():\n                if self.stop_request.isSet(): return\n                request = self.queue.get()\n                request.write()\n            sleep(0.2)\n\n    #Provided this to stop thread explicitly or by context management (with statement) but thread should end on its own \n    # upon program exit, due to being a daemon.  So using this is probably unecessary.\n    def close(self)->None:\n        \"Stops asynchronous request queue processing thread.\"\n        self.stop_request.set()\n        self.thread.join()\n\n    # Nothing to do, thread already started.  
Could start thread here to enforce use of context manager \n    # (but that sounds like a pain and a bit unweildy and unecessary for actual usage)\n    def __enter__(self): pass\n\n    def __exit__(self, exc_type, exc_value, traceback): self.close()\n\nasyncTBWriter = AsyncTBWriter() \n\nclass ModelImageSet():\n    \"Convenience object that holds the original, real(target) and generated versions of a single image fed to a model.\"\n    @staticmethod\n    def get_list_from_model(learn:Learner, ds_type:DatasetType, batch:Tuple)->[]:\n        \"Factory method to convert a batch of model images to a list of ModelImageSet.\"\n        image_sets = []\n        x,y = batch[0],batch[1]\n        preds=[]\n        preds = learn.pred_batch(ds_type=ds_type, batch=(x,y), reconstruct=True)  \n        for orig_px, real_px, gen in zip(x,y,preds):\n            orig, real = Image(px=orig_px), Image(px=real_px)\n            image_set = ModelImageSet(orig=orig, real=real, gen=gen)\n            image_sets.append(image_set)\n        return image_sets  \n\n    def __init__(self, orig:Image, real:Image, gen:Image): self.orig, self.real, self.gen = orig, real, gen\n\nclass HistogramTBRequest(TBWriteRequest):\n    \"Request object for model histogram writes to Tensorboard.\"\n    def __init__(self, model:nn.Module, iteration:int, tbwriter:SummaryWriter, name:str):\n        super().__init__(tbwriter=tbwriter, iteration=iteration)\n        self.params = [(name, values.clone().detach().cpu()) for (name, values) in model.named_parameters()]\n        self.name = name\n\n    def _write_histogram(self, param_name:str, values)->None:\n        \"Writes single model histogram to Tensorboard.\"\n        tag = self.name + '/weights/' + param_name\n        self.tbwriter.add_histogram(tag=tag, values=values, global_step=self.iteration)\n\n    def write(self)->None:\n        \"Writes model histograms to Tensorboard.\"\n        for param_name, values in self.params: 
self._write_histogram(param_name=param_name, values=values)\n\n#If this isn't done async then this is sloooooow\nclass HistogramTBWriter():\n    \"Writes model histograms to Tensorboard.\"\n    def __init__(self): super().__init__()\n\n    def write(self, model:nn.Module, iteration:int, tbwriter:SummaryWriter, name:str='model')->None:\n        \"Writes model histograms to Tensorboard.\"\n        request = HistogramTBRequest(model=model, iteration=iteration, tbwriter=tbwriter, name=name)\n        asyncTBWriter.request_write(request)\n\nclass ModelStatsTBRequest(TBWriteRequest):\n    \"Request object for model gradient statistics writes to Tensorboard.\"\n    def __init__(self, model:nn.Module, iteration:int, tbwriter:SummaryWriter, name:str):\n        super().__init__(tbwriter=tbwriter, iteration=iteration)\n        self.gradients = [x.grad.clone().detach().cpu() for x in model.parameters() if x.grad is not None]\n        self.name = name\n\n    def _add_gradient_scalar(self, name:str, scalar_value)->None:\n        \"Writes a single scalar value for a gradient statistic to Tensorboard.\"\n        tag = self.name + '/gradients/' + name\n        self.tbwriter.add_scalar(tag=tag, scalar_value=scalar_value, global_step=self.iteration)\n\n    def _write_avg_norm(self, norms:[])->None:\n        \"Writes the average norm of the gradients to Tensorboard.\"\n        avg_norm = sum(norms)/len(self.gradients)\n        self._add_gradient_scalar('avg_norm', scalar_value=avg_norm)\n\n    def _write_median_norm(self, norms:[])->None:\n        \"Writes the median norm of the gradients to Tensorboard.\"\n        median_norm = statistics.median(norms)\n        self._add_gradient_scalar('median_norm', scalar_value=median_norm)\n\n    def _write_max_norm(self, norms:[])->None:\n        \"Writes the maximum norm of the gradients to Tensorboard.\"\n        max_norm = max(norms)\n        self._add_gradient_scalar('max_norm', scalar_value=max_norm)\n\n    def _write_min_norm(self, 
norms:[])->None:\n        \"Writes the minimum norm of the gradients to Tensorboard.\"\n        min_norm = min(norms)\n        self._add_gradient_scalar('min_norm', scalar_value=min_norm)\n\n    def _write_num_zeros(self)->None:\n        \"Writes the number of zeroes in the gradients to Tensorboard.\"\n        gradient_nps = [to_np(x.data) for x in self.gradients]\n        num_zeros = sum((np.asarray(x) == 0.0).sum() for x in gradient_nps)\n        self._add_gradient_scalar('num_zeros', scalar_value=num_zeros)\n\n    def _write_avg_gradient(self)->None:\n        \"Writes the average of the gradients to Tensorboard.\"\n        avg_gradient = sum(x.data.mean() for x in self.gradients)/len(self.gradients)\n        self._add_gradient_scalar('avg_gradient', scalar_value=avg_gradient)\n\n    def _write_median_gradient(self)->None:\n        \"Writes the median of the gradients to Tensorboard.\"\n        median_gradient = statistics.median(x.data.median() for x in self.gradients)\n        self._add_gradient_scalar('median_gradient', scalar_value=median_gradient)\n\n    def _write_max_gradient(self)->None:\n        \"Writes the maximum of the gradients to Tensorboard.\"\n        max_gradient = max(x.data.max() for x in self.gradients)\n        self._add_gradient_scalar('max_gradient', scalar_value=max_gradient)\n\n    def _write_min_gradient(self)->None:\n        \"Writes the minimum of the gradients to Tensorboard.\"\n        min_gradient = min(x.data.min() for x in self.gradients)\n        self._add_gradient_scalar('min_gradient', scalar_value=min_gradient)\n\n    def write(self)->None:\n        \"Writes model gradient statistics to Tensorboard.\"\n        if len(self.gradients) == 0: return\n        norms = [x.data.norm() for x in self.gradients]\n        self._write_avg_norm(norms=norms)\n        self._write_median_norm(norms=norms)\n        self._write_max_norm(norms=norms)\n        self._write_min_norm(norms=norms)\n        self._write_num_zeros()\n        
self._write_avg_gradient()\n        self._write_median_gradient()\n        self._write_max_gradient()\n        self._write_min_gradient()\n\nclass ModelStatsTBWriter():\n    \"Writes model gradient statistics to Tensorboard.\"\n    def write(self, model:nn.Module, iteration:int, tbwriter:SummaryWriter, name:str='model_stats')->None:\n        \"Writes model gradient statistics to Tensorboard.\"\n        request = ModelStatsTBRequest(model=model, iteration=iteration, tbwriter=tbwriter, name=name)\n        asyncTBWriter.request_write(request)\n\nclass ImageTBRequest(TBWriteRequest):\n    \"Request object for model image output writes to Tensorboard.\"\n    def __init__(self, learn:Learner, batch:Tuple, iteration:int, tbwriter:SummaryWriter, ds_type:DatasetType):\n        super().__init__(tbwriter=tbwriter, iteration=iteration)\n        self.image_sets = ModelImageSet.get_list_from_model(learn=learn, batch=batch, ds_type=ds_type)\n        self.ds_type = ds_type\n\n    def _write_images(self, name:str, images:[Tensor])->None:\n        \"Writes list of images as tensors to Tensorboard.\"\n        tag = self.ds_type.name + ' ' + name\n        self.tbwriter.add_image(tag=tag, img_tensor=vutils.make_grid(images, normalize=True), global_step=self.iteration)\n\n    def _get_image_tensors(self)->([Tensor], [Tensor], [Tensor]):\n        \"Gets list of image tensors from lists of Image objects, as a tuple of original, generated and real(target) images.\"\n        orig_images, gen_images, real_images = [], [], []\n        for image_set in self.image_sets:\n            orig_images.append(image_set.orig.px)\n            gen_images.append(image_set.gen.px)\n            real_images.append(image_set.real.px) \n        return orig_images, gen_images, real_images  \n\n    def write(self)->None:\n        \"Writes original, generated and real(target) images to Tensorboard.\"\n        orig_images, gen_images, real_images = self._get_image_tensors()\n        self._write_images(name='orig 
images', images=orig_images)\n        self._write_images(name='gen images',  images=gen_images)\n        self._write_images(name='real images', images=real_images)\n\n#If this isn't done async then this is noticeably slower\nclass ImageTBWriter():\n    \"Writes model image output to Tensorboard.\"\n    def __init__(self): super().__init__()\n\n    def write(self, learn:Learner, trn_batch:Tuple, val_batch:Tuple, iteration:int, tbwriter:SummaryWriter)->None:\n        \"Writes training and validation batch images to Tensorboard.\"\n        self._write_for_dstype(learn=learn, batch=val_batch, iteration=iteration, tbwriter=tbwriter, ds_type=DatasetType.Valid)\n        self._write_for_dstype(learn=learn, batch=trn_batch, iteration=iteration, tbwriter=tbwriter, ds_type=DatasetType.Train)\n\n    def _write_for_dstype(self, learn:Learner, batch:Tuple, iteration:int, tbwriter:SummaryWriter, ds_type:DatasetType)->None:\n        \"Writes batch images of specified DatasetType to Tensorboard.\"\n        request = ImageTBRequest(learn=learn, batch=batch, iteration=iteration, tbwriter=tbwriter, ds_type=ds_type)\n        asyncTBWriter.request_write(request)\n\nclass GraphTBRequest(TBWriteRequest):\n    \"Request object for model histogram writes to Tensorboard.\"\n    def __init__(self, model:nn.Module, tbwriter:SummaryWriter, input_to_model:torch.Tensor):\n        super().__init__(tbwriter=tbwriter, iteration=0)\n        self.model,self.input_to_model = model,input_to_model\n\n    def write(self)->None:\n        \"Writes single model graph to Tensorboard.\"\n        self.tbwriter.add_graph(model=self.model, input_to_model=self.input_to_model)\n\nclass GraphTBWriter():\n    \"Writes model network graph to Tensorboard.\"\n    def write(self, model:nn.Module, tbwriter:SummaryWriter, input_to_model:torch.Tensor)->None:\n        \"Writes model graph to Tensorboard.\"\n        request = GraphTBRequest(model=model, tbwriter=tbwriter, input_to_model=input_to_model)\n        
asyncTBWriter.request_write(request)\n"
  },
  {
    "path": "fastai/callbacks/tracker.py",
    "content": "# Contribution from @fredguth, https://github.com/fredguth/fastai_playground.\n\nfrom fastai.torch_core import *\nfrom fastai.callback import *\nfrom fastai.basic_train import *\n\n__all__ = ['TerminateOnNaNCallback', 'EarlyStoppingCallback', 'SaveModelCallback', 'TrackerCallback',\n        'ReduceLROnPlateauCallback', 'TrackEpochCallback']\n\nclass TerminateOnNaNCallback(Callback):\n    \"A `Callback` that terminates training if loss is NaN.\"\n\n    def __init__(self):\n        # NOTE(review): nothing in this module ever sets `stop` to True, so the early return in `on_batch_end` is currently inert.\n        self.stop = False\n\n    def on_batch_end(self, last_loss, epoch, num_batch, **kwargs:Any)->None:\n        \"Test if `last_loss` is NaN and interrupts training.\"\n        if self.stop: return True #to skip validation after stopping during training\n        if torch.isnan(last_loss):\n            print(f'Epoch/Batch ({epoch}/{num_batch}): Invalid loss, terminating training.')\n            return {'stop_epoch': True, 'stop_training': True, 'skip_validate': True}\n\nclass TrackerCallback(LearnerCallback):\n    \"A `LearnerCallback` that keeps track of the best value in `monitor`.\"\n    def __init__(self, learn:Learner, monitor:str='valid_loss', mode:str='auto'):\n        super().__init__(learn)\n        self.monitor,self.mode = monitor,mode\n        if self.mode not in ['auto', 'min', 'max']:\n            warn(f'{self.__class__} mode {self.mode} is invalid, falling back to \"auto\" mode.')\n            self.mode = 'auto'\n        mode_dict = {'min': np.less, 'max':np.greater}\n        # 'auto' minimizes any monitored quantity whose name contains 'loss' and maximizes everything else.\n        mode_dict['auto'] = np.less if 'loss' in self.monitor else np.greater\n        self.operator = mode_dict[self.mode]\n\n    def on_train_begin(self, **kwargs:Any)->None:\n        \"Initializes the best value.\"\n        self.best = float('inf') if self.operator == np.less else -float('inf')\n\n    def get_monitor_value(self):\n        \"Pick the monitored value, or `None` when it is not available yet.\"\n        # The key must match the one used in `values` below ('train_loss'); otherwise this guard never fires and `losses[-1]` can fail.\n        if self.monitor=='train_loss' and len(self.learn.recorder.losses) == 0: return None\n        elif len(self.learn.recorder.val_losses) == 0: return None\n        values = {'train_loss':self.learn.recorder.losses[-1].cpu().numpy(),\n                  'valid_loss':self.learn.recorder.val_losses[-1]}\n        if values['valid_loss'] is None: return\n        if self.learn.recorder.metrics:\n            # Assumes `recorder.names` has 3 leading non-metric columns and 1 trailing one (hence `[3:-1]`) -- confirm against `Recorder`.\n            for m, n in zip(self.learn.recorder.metrics[-1],self.learn.recorder.names[3:-1]):\n                values[n] = m\n        if values.get(self.monitor) is None:\n            warn(f'{self.__class__} conditioned on metric `{self.monitor}` which is not available. Available metrics are: {\", \".join(map(str, self.learn.recorder.names[1:-1]))}')\n        return values.get(self.monitor)\n\nclass EarlyStoppingCallback(TrackerCallback):\n    \"A `TrackerCallback` that terminates training when monitored quantity stops improving.\"\n    def __init__(self, learn:Learner, monitor:str='valid_loss', mode:str='auto', min_delta:float=0, patience:int=0):\n        super().__init__(learn, monitor=monitor, mode=mode)\n        self.min_delta,self.patience = min_delta,patience\n        # Flip the sign for minimized quantities so `current - min_delta` must beat `best` by at least `min_delta`.\n        if self.operator == np.less:  self.min_delta *= -1\n\n    def on_train_begin(self, **kwargs:Any)->None:\n        \"Initialize inner arguments.\"\n        self.wait = 0\n        super().on_train_begin(**kwargs)\n\n    def on_epoch_end(self, epoch, **kwargs:Any)->None:\n        \"Compare the value monitored to its best score and maybe stop training.\"\n        current = self.get_monitor_value()\n        if current is None: return\n        if self.operator(current - self.min_delta, self.best):\n            self.best,self.wait = current,0\n        else:\n            self.wait += 1\n            if self.wait > self.patience:\n                print(f'Epoch {epoch}: early stopping')\n                return {\"stop_training\":True}\n\nclass SaveModelCallback(TrackerCallback):\n    \"A `TrackerCallback` that saves the model when monitored quantity is best.\"\n    def __init__(self, learn:Learner, monitor:str='valid_loss', mode:str='auto', every:str='improvement', name:str='bestmodel'):\n        super().__init__(learn, monitor=monitor, mode=mode)\n        self.every,self.name = every,name\n        if self.every not in ['improvement', 'epoch']:\n            warn(f'SaveModel every {self.every} is invalid, falling back to \"improvement\".')\n            self.every = 'improvement'\n\n    def jump_to_epoch(self, epoch:int)->None:\n        \"Load the per-epoch checkpoint `{name}_{epoch-1}` (as written when `every='epoch'`), if it can be loaded.\"\n        try:\n            self.learn.load(f'{self.name}_{epoch-1}', purge=False)\n            print(f\"Loaded {self.name}_{epoch-1}\")\n        # Best-effort: report and continue, but do not swallow KeyboardInterrupt/SystemExit.\n        except Exception: print(f'Model {self.name}_{epoch-1} not found.')\n\n    def on_epoch_end(self, epoch:int, **kwargs:Any)->None:\n        \"Compare the value monitored to its best score and maybe save the model.\"\n        if self.every==\"epoch\": self.learn.save(f'{self.name}_{epoch}')\n        else: #every=\"improvement\"\n            current = self.get_monitor_value()\n            if current is not None and self.operator(current, self.best):\n                print(f'Better model found at epoch {epoch} with {self.monitor} value: {current}.')\n                self.best = current\n                self.learn.save(f'{self.name}')\n\n    def on_train_end(self, **kwargs):\n        \"Load the best model.\"\n        if self.every==\"improvement\" and (self.learn.path/f'{self.learn.model_dir}/{self.name}.pth').is_file():\n            self.learn.load(f'{self.name}', purge=False)\n\nclass ReduceLROnPlateauCallback(TrackerCallback):\n    \"A `TrackerCallback` that reduces learning rate when a metric has stopped improving.\"\n    def __init__(self, learn:Learner, monitor:str='valid_loss', mode:str='auto', patience:int=0, factor:float=0.2,\n                 min_delta:float=0):\n        super().__init__(learn, monitor=monitor, mode=mode)\n        self.patience,self.factor,self.min_delta = patience,factor,min_delta\n        # Flip the sign for minimized quantities so `current - min_delta` must beat `best` by at least `min_delta`.\n        if self.operator == np.less:  self.min_delta *= -1\n\n    def on_train_begin(self, **kwargs:Any)->None:\n        \"Initialize inner arguments.\"\n        self.wait, self.opt = 0, self.learn.opt\n        super().on_train_begin(**kwargs)\n\n    def on_epoch_end(self, epoch, **kwargs:Any)->None:\n        \"Compare the value monitored to its best and maybe reduce lr.\"\n        current = self.get_monitor_value()\n        if current is None: return\n        if self.operator(current - self.min_delta, self.best): self.best,self.wait = current,0\n        else:\n            self.wait += 1\n            if self.wait > self.patience:\n                self.opt.lr *= self.factor\n                self.wait = 0\n                print(f'Epoch {epoch}: reducing lr to {self.opt.lr}')\n\n\nclass TrackEpochCallback(LearnerCallback):\n    \"Store completed epoch number in `learn.model_dir/name` so a later run can resume after it.\"\n    _order = -20 #Need to run before fit_one_cycle\n    def __init__(self, learn:Learner, name:str='epoch', epoch_offset:int=None):\n        \"Store completed epoch number in `learn.model_dir/name`; start at `epoch_offset` if given, else one past the stored epoch (or 0).\"\n        super().__init__(learn)\n        learn._test_writeable_path()\n        self.path = learn.path/learn.model_dir/name\n        # Previously `start_epoch` was left unset when `epoch_offset` was given, making `on_train_begin` raise AttributeError.\n        if epoch_offset is not None: self.start_epoch = epoch_offset\n        elif os.path.isfile(self.path):\n            with self.path.open('r') as f:\n                try:    self.start_epoch = int(f.read())+1\n                except ValueError: self.start_epoch = 0\n        else: self.start_epoch = 0\n\n    def on_train_begin(self, **kwargs:Any):\n        \"Ask the callback handler to set `epoch` to `start_epoch`.\"\n        return {'epoch': self.start_epoch}\n\n    def on_epoch_end(self, epoch, **kwargs:Any)->None:\n        \"Overwrite the stored file with the just-completed `epoch`.\"\n        with self.path.open('w') as f: f.write(f'{epoch}')\n\n    def restart(self):\n        \"Delete the stored epoch file; `os.remove` raises if it does not exist.\"\n        os.remove(self.path)\n"
  },
  {
    "path": "fastai/collab.py",
    "content": "\"Module support for Collaborative Filtering\"\nfrom .tabular import *\nfrom . import tabular\n\n__all__ = [*tabular.__all__, 'EmbeddingDotBias', 'EmbeddingNN', 'collab_learner', 'CollabDataBunch', 'CollabLine',\n           'CollabList', 'CollabLearner']\n\nclass CollabProcessor(TabularProcessor):\n    \"Subclass `TabularProcessor` for `process_one`.\"\n    def process_one(self, item):\n        \"Process `item` with `TabularProcessor.process_one` and wrap the result in a `CollabLine`.\"\n        res = super().process_one(item)\n        return CollabLine(res.cats,res.conts,res.classes,res.names)\n\nclass CollabLine(TabularLine):\n    \"Base item for collaborative filtering, subclasses `TabularLine`.\"\n    def __init__(self, cats, conts, classes, names):\n        super().__init__(cats, conts, classes, names)\n        # Keep the first two codes of `data[0]` (user then item, the `cat_names` order used by `CollabDataBunch.from_df`).\n        # Assumes `TabularLine` stores the categorical codes in `data[0]` -- confirm.\n        self.data = [self.data[0][0],self.data[0][1]]\n\nclass CollabList(TabularList):\n    \"Base `ItemList` for collaborative filtering, subclasses `TabularList`.\"\n    _item_cls,_label_cls,_processor = CollabLine,FloatList,CollabProcessor\n\n    def reconstruct(self, t:Tensor):\n        \"Rebuild a `CollabLine` from the categorical codes in `t`, with an empty continuous part.\"\n        return CollabLine(tensor(t), tensor([]), self.classes, self.col_names)\n\nclass EmbeddingNN(TabularModel):\n    \"Subclass `TabularModel` to create a NN suitable for collaborative filtering.\"\n    def __init__(self, emb_szs:ListSizes, layers:Collection[int]=None, ps:Collection[float]=None,\n                 emb_drop:float=0., y_range:OptRange=None, use_bn:bool=True, bn_final:bool=False):\n        super().__init__(emb_szs=emb_szs, n_cont=0, out_sz=1, layers=layers, ps=ps, emb_drop=emb_drop, y_range=y_range,\n                         use_bn=use_bn, bn_final=bn_final)\n\n    def forward(self, users:LongTensor, items:LongTensor) -> Tensor:\n        \"Stack `users` and `items` along dim 1 as the categorical input; no continuous input is passed.\"\n        return super().forward(torch.stack([users,items], dim=1), None)\n\nclass EmbeddingDotBias(Module):\n    \"Base dot model for collaborative filtering.\"\n    def __init__(self, n_factors:int, n_users:int, n_items:int, y_range:Tuple[float,float]=None):\n        self.y_range = y_range\n        (self.u_weight, self.i_weight, self.u_bias, self.i_bias) = [embedding(*o) for o in [\n            (n_users, n_factors), (n_items, n_factors), (n_users,1), (n_items,1)\n        ]]\n\n    def forward(self, users:LongTensor, items:LongTensor) -> Tensor:\n        \"Dot product of user and item factors plus both biases, rescaled into `y_range` through a sigmoid when it is set.\"\n        dot = self.u_weight(users)* self.i_weight(items)\n        res = dot.sum(1) + self.u_bias(users).squeeze() + self.i_bias(items).squeeze()\n        if self.y_range is None: return res\n        return torch.sigmoid(res) * (self.y_range[1]-self.y_range[0]) + self.y_range[0]\n\nclass CollabDataBunch(DataBunch):\n    \"Base `DataBunch` for collaborative filtering.\"\n    @classmethod\n    def from_df(cls, ratings:DataFrame, valid_pct:float=0.2, user_name:Optional[str]=None, item_name:Optional[str]=None,\n                rating_name:Optional[str]=None, test:DataFrame=None, seed:int=None, path:PathOrStr='.', bs:int=64,\n                val_bs:int=None, num_workers:int=defaults.cpus, dl_tfms:Optional[Collection[Callable]]=None,\n                device:torch.device=None, collate_fn:Callable=data_collate, no_check:bool=False) -> 'CollabDataBunch':\n        \"Create a `DataBunch` suitable for collaborative filtering from `ratings`.\"\n        # Unnamed columns default to the first three columns of `ratings`: user, item, rating.\n        user_name   = ifnone(user_name,  ratings.columns[0])\n        item_name   = ifnone(item_name,  ratings.columns[1])\n        rating_name = ifnone(rating_name,ratings.columns[2])\n        cat_names = [user_name,item_name]\n        src = (CollabList.from_df(ratings, cat_names=cat_names, procs=Categorify)\n               .split_by_rand_pct(valid_pct=valid_pct, seed=seed).label_from_df(cols=rating_name))\n        if test is not None: src.add_test(CollabList.from_df(test, cat_names=cat_names))\n        # `dl_tfms` used to be accepted but silently dropped; forward it like the other loader options.\n        return src.databunch(path=path, bs=bs, val_bs=val_bs, num_workers=num_workers, dl_tfms=dl_tfms, device=device,\n                             collate_fn=collate_fn, no_check=no_check)\n\nclass CollabLearner(Learner):\n    \"`Learner` suitable for collaborative filtering.\"\n    def get_idx(self, arr:Collection, is_item:bool=True):\n        \"Fetch item or user (based on `is_item`) for all in `arr`. (Set model to `cpu` and no grad.) Returns `None` if an element of `arr` is not a training class.\"\n        m = self.model.eval().cpu()\n        requires_grad(m,False)\n        u_class,i_class = self.data.train_ds.x.classes.values()\n        classes = i_class if is_item else u_class\n        c2i = {v:k for k,v in enumerate(classes)}\n        try: return tensor([c2i[o] for o in arr])\n        except KeyError:\n            # A `KeyError` means some element of `arr` is not a training class: report it and return `None`.\n            print(f\"\"\"You're trying to access {'an item' if is_item else 'a user'} that isn't in the training data.\n                  If it was in your original data, it may have been split such that it's only in the validation set now.\"\"\")\n\n    def bias(self, arr:Collection, is_item:bool=True):\n        \"Bias for item or user (based on `is_item`) for all in `arr`. (Set model to `cpu` and no grad.)\"\n        idx = self.get_idx(arr, is_item)\n        m = self.model\n        # `i_bias`/`u_bias` are defined by `EmbeddingDotBias`; presumably unavailable when built with `use_nn=True` -- confirm.\n        layer = m.i_bias if is_item else m.u_bias\n        return layer(idx).squeeze()\n\n    def weight(self, arr:Collection, is_item:bool=True):\n        \"Weight for item or user (based on `is_item`) for all in `arr`. (Set model to `cpu` and no grad.)\"\n        idx = self.get_idx(arr, is_item)\n        m = self.model\n        # `i_weight`/`u_weight` are defined by `EmbeddingDotBias`; presumably unavailable when built with `use_nn=True` -- confirm.\n        layer = m.i_weight if is_item else m.u_weight\n        return layer(idx)\n\ndef collab_learner(data, n_factors:int=None, use_nn:bool=False, emb_szs:Dict[str,int]=None, layers:Collection[int]=None,\n                   ps:Collection[float]=None, emb_drop:float=0., y_range:OptRange=None, use_bn:bool=True,\n                   bn_final:bool=False, **learn_kwargs)->Learner:\n    \"Create a Learner for collaborative filtering on `data`.\"\n    emb_szs = data.get_emb_szs(ifnone(emb_szs, {}))\n    u,m = data.train_ds.x.classes.values()\n    # `learn_kwargs` belong to the `Learner` only: `EmbeddingNN.__init__` accepts no extra keyword arguments.\n    if use_nn: model = EmbeddingNN(emb_szs=emb_szs, layers=layers, ps=ps, emb_drop=emb_drop, y_range=y_range,\n                                   use_bn=use_bn, bn_final=bn_final)\n    else:      model = EmbeddingDotBias(n_factors, len(u), len(m), y_range=y_range)\n    return CollabLearner(data, model, **learn_kwargs)\n\n"
  },
  {
    "path": "fastai/core.py",
    "content": "\"`fastai.core` contains essential util functions to format and split data\"\nfrom .imports.core import *\n\nwarnings.filterwarnings(\"ignore\", message=\"numpy.dtype size changed\")\nwarnings.filterwarnings(\"ignore\", message=\"numpy.ufunc size changed\")\n\nAnnealFunc = Callable[[Number,Number,float], Number]\nArgStar = Collection[Any]\nBatchSamples = Collection[Tuple[Collection[int], int]]\nDataFrameOrChunks = Union[DataFrame, pd.io.parsers.TextFileReader]\nFilePathList = Collection[Path]\nFloats = Union[float, Collection[float]]\nImgLabel = str\nImgLabels = Collection[ImgLabel]\nIntsOrStrs = Union[int, Collection[int], str, Collection[str]]\nKeyFunc = Callable[[int], int]\nKWArgs = Dict[str,Any]\nListOrItem = Union[Collection[Any],int,float,str]\nListRules = Collection[Callable[[str],str]]\nListSizes = Collection[Tuple[int,int]]\nNPArrayableList = Collection[Union[np.ndarray, list]]\nNPArrayList = Collection[np.ndarray]\nNPArrayMask = np.ndarray\nNPImage = np.ndarray\nOptDataFrame = Optional[DataFrame]\nOptListOrItem = Optional[ListOrItem]\nOptRange = Optional[Tuple[float,float]]\nOptStrTuple = Optional[Tuple[str,str]]\nOptStats = Optional[Tuple[np.ndarray, np.ndarray]]\nPathOrStr = Union[Path,str]\nPathLikeOrBinaryStream = Union[PathOrStr, BufferedWriter, BytesIO]\nPBar = Union[MasterBar, ProgressBar]\nPoint=Tuple[float,float]\nPoints=Collection[Point]\nSizes = List[List[int]]\nSplitArrayList = List[Tuple[np.ndarray,np.ndarray]]\nStartOptEnd=Union[float,Tuple[float,float]]\nStrList = Collection[str]\nTokens = Collection[Collection[str]]\nOptStrList = Optional[StrList]\nnp.set_printoptions(precision=6, threshold=50, edgeitems=4, linewidth=120)\n\ndef num_cpus()->int:\n    \"Get number of cpus\"\n    try:                   return len(os.sched_getaffinity(0))\n    except AttributeError: return os.cpu_count()\n\n_default_cpus = min(16, num_cpus())\ndefaults = SimpleNamespace(cpus=_default_cpus, cmap='viridis', return_fig=False, 
silent=False)\n\ndef is_listy(x:Any)->bool: return isinstance(x, (tuple,list))\ndef is_tuple(x:Any)->bool: return isinstance(x, tuple)\ndef is_dict(x:Any)->bool: return isinstance(x, dict)\ndef is_pathlike(x:Any)->bool: return isinstance(x, (str,Path))\ndef noop(x): return x\n\nclass PrePostInitMeta(type):\n    \"A metaclass that calls optional `__pre_init__` and `__post_init__` methods\"\n    def __new__(cls, name, bases, dct):\n        x = super().__new__(cls, name, bases, dct)\n        old_init = x.__init__\n        def _pass(self): pass\n        @functools.wraps(old_init)\n        def _init(self,*args,**kwargs):\n            self.__pre_init__()\n            old_init(self, *args,**kwargs)\n            self.__post_init__()\n        x.__init__ = _init\n        if not hasattr(x,'__pre_init__'):  x.__pre_init__  = _pass\n        if not hasattr(x,'__post_init__'): x.__post_init__ = _pass\n        return x\n\ndef chunks(l:Collection, n:int)->Iterable:\n    \"Yield successive `n`-sized chunks from `l`.\"\n    for i in range(0, len(l), n): yield l[i:i+n]\n\ndef recurse(func:Callable, x:Any, *args, **kwargs)->Any:\n    if is_listy(x): return [recurse(func, o, *args, **kwargs) for o in x]\n    if is_dict(x):  return {k: recurse(func, v, *args, **kwargs) for k,v in x.items()}\n    return func(x, *args, **kwargs)\n\ndef first_el(x: Any)->Any:\n    \"Recursively get the first element of `x`.\"\n    if is_listy(x): return first_el(x[0])\n    if is_dict(x):  return first_el(x[list(x.keys())[0]])\n    return x\n\ndef to_int(b:Any)->Union[int,List[int]]:\n    \"Recursively convert `b` to an int or list/dict of ints; raises exception if not convertible.\"\n    return recurse(lambda x: int(x), b)\n\ndef ifnone(a:Any,b:Any)->Any:\n    \"`a` if `a` is not None, otherwise `b`.\"\n    return b if a is None else a\n\ndef is1d(a:Collection)->bool:\n    \"Return `True` if `a` is one-dimensional\"\n    return len(a.shape) == 1 if hasattr(a, 'shape') else len(np.array(a).shape) == 1\n\ndef 
uniqueify(x:Series, sort:bool=False)->List:\n    \"Return sorted unique values of `x`.\"\n    res = list(OrderedDict.fromkeys(x).keys())\n    if sort: res.sort()\n    return res\n\ndef idx_dict(a):\n    \"Create a dictionary value to index from `a`.\"\n    return {v:k for k,v in enumerate(a)}\n\ndef find_classes(folder:Path)->FilePathList:\n    \"List of label subdirectories in imagenet-style `folder`.\"\n    classes = [d for d in folder.iterdir()\n               if d.is_dir() and not d.name.startswith('.')]\n    assert(len(classes)>0)\n    return sorted(classes, key=lambda d: d.name)\n\ndef arrays_split(mask:NPArrayMask, *arrs:NPArrayableList)->SplitArrayList:\n    \"Given `arrs` is [a,b,...] and `mask`index - return[(a[mask],a[~mask]),(b[mask],b[~mask]),...].\"\n    assert all([len(arr)==len(arrs[0]) for arr in arrs]), 'All arrays should have same length'\n    mask = array(mask)\n    return list(zip(*[(a[mask],a[~mask]) for a in map(np.array, arrs)]))\n\ndef random_split(valid_pct:float, *arrs:NPArrayableList)->SplitArrayList:\n    \"Randomly split `arrs` with `valid_pct` ratio. 
good for creating validation set.\"\n    assert (valid_pct>=0 and valid_pct<=1), 'Validation set percentage should be between 0 and 1'\n    is_train = np.random.uniform(size=(len(arrs[0]),)) > valid_pct\n    return arrays_split(is_train, *arrs)\n\ndef listify(p:OptListOrItem=None, q:OptListOrItem=None):\n    \"Make `p` listy and the same length as `q`.\"\n    if p is None: p=[]\n    elif isinstance(p, str):          p = [p]\n    elif not isinstance(p, Iterable): p = [p]\n    #Rank 0 tensors in PyTorch are Iterable but don't have a length.\n    else:\n        try: a = len(p)\n        except: p = [p]\n    n = q if type(q)==int else len(p) if q is None else len(q)\n    if len(p)==1: p = p * n\n    assert len(p)==n, f'List len mismatch ({len(p)} vs {n})'\n    return list(p)\n\n_camel_re1 = re.compile('(.)([A-Z][a-z]+)')\n_camel_re2 = re.compile('([a-z0-9])([A-Z])')\ndef camel2snake(name:str)->str:\n    \"Change `name` from camel to snake style.\"\n    s1 = re.sub(_camel_re1, r'\\1_\\2', name)\n    return re.sub(_camel_re2, r'\\1_\\2', s1).lower()\n\ndef even_mults(start:float, stop:float, n:int)->np.ndarray:\n    \"Build log-stepped array from `start` to `stop` in `n` steps.\"\n    mult = stop/start\n    step = mult**(1/(n-1))\n    return np.array([start*(step**i) for i in range(n)])\n\ndef extract_kwargs(names:Collection[str], kwargs:KWArgs):\n    \"Extract the keys in `names` from the `kwargs`.\"\n    new_kwargs = {}\n    for arg_name in names:\n        if arg_name in kwargs:\n            arg_val = kwargs.pop(arg_name)\n            new_kwargs[arg_name] = arg_val\n    return new_kwargs, kwargs\n\ndef partition(a:Collection, sz:int)->List[Collection]:\n    \"Split iterables `a` in equal parts of size `sz`\"\n    return [a[i:i+sz] for i in range(0, len(a), sz)]\n\ndef partition_by_cores(a:Collection, n_cpus:int)->List[Collection]:\n    \"Split data in `a` equally among `n_cpus` cores\"\n    return partition(a, len(a)//n_cpus + 1)\n\ndef series2cat(df:DataFrame, 
*col_names):\n    \"Categorifies the columns `col_names` in `df`.\"\n    for c in listify(col_names): df[c] = df[c].astype('category').cat.as_ordered()\n\nTfmList = Union[Callable, Collection[Callable]]\n\nclass ItemBase():\n    \"Base item type in the fastai library.\"\n    def __init__(self, data:Any): self.data=self.obj=data\n    def __repr__(self)->str: return f'{self.__class__.__name__} {str(self)}'\n    def show(self, ax:plt.Axes, **kwargs):\n        \"Subclass this method if you want to customize the way this `ItemBase` is shown on `ax`.\"\n        ax.set_title(str(self))\n    def apply_tfms(self, tfms:Collection, **kwargs):\n        \"Subclass this method if you want to apply data augmentation with `tfms` to this `ItemBase`.\"\n        if tfms: raise Exception(f\"Not implemented: you can't apply transforms to this type of item ({self.__class__.__name__})\")\n        return self\n    def __eq__(self, other): return recurse_eq(self.data, other.data)\n\ndef recurse_eq(arr1, arr2):\n    if is_listy(arr1): return is_listy(arr2) and len(arr1) == len(arr2) and np.all([recurse_eq(x,y) for x,y in zip(arr1,arr2)])\n    else:              return np.all(np.atleast_1d(arr1 == arr2))\n        \ndef download_url(url:str, dest:str, overwrite:bool=False, pbar:ProgressBar=None,\n                 show_progress=True, chunk_size=1024*1024, timeout=4, retries=5)->None:\n    \"Download `url` to `dest` unless it exists and not `overwrite`.\"\n    if os.path.exists(dest) and not overwrite: return\n\n    s = requests.Session()\n    s.mount('http://',requests.adapters.HTTPAdapter(max_retries=retries))\n    u = s.get(url, stream=True, timeout=timeout)\n    try: file_size = int(u.headers[\"Content-Length\"])\n    except: show_progress = False\n\n    with open(dest, 'wb') as f:\n        nbytes = 0\n        if show_progress: pbar = progress_bar(range(file_size), auto_update=False, leave=False, parent=pbar)\n        try:\n            for chunk in u.iter_content(chunk_size=chunk_size):\n   
             nbytes += len(chunk)\n                if show_progress: pbar.update(nbytes)\n                f.write(chunk)\n        except requests.exceptions.ConnectionError as e:\n            fname = url.split('/')[-1]\n            from fastai.datasets import Config\n            data_dir = Config().data_path()\n            timeout_txt =(f'\\n Download of {url} has failed after {retries} retries\\n'\n                          f' Fix the download manually:\\n'\n                          f'$ mkdir -p {data_dir}\\n'\n                          f'$ cd {data_dir}\\n'\n                          f'$ wget -c {url}\\n'\n                          f'$ tar -zxvf {fname}\\n\\n'\n                          f'And re-run your code once the download is successful\\n')\n            print(timeout_txt)\n            import sys;sys.exit(1)\n\ndef range_of(x):\n    \"Create a range from 0 to `len(x)`.\"\n    return list(range(len(x)))\ndef arange_of(x):\n    \"Same as `range_of` but returns an array.\"\n    return np.arange(len(x))\n\nPath.ls = lambda x: list(x.iterdir())\n\ndef join_path(fname:PathOrStr, path:PathOrStr='.')->Path:\n    \"Return `Path(path)/Path(fname)`, `path` defaults to current dir.\"\n    return Path(path)/Path(fname)\n\ndef join_paths(fnames:FilePathList, path:PathOrStr='.')->Collection[Path]:\n    \"Join `path` to every file name in `fnames`.\"\n    path = Path(path)\n    return [join_path(o,path) for o in fnames]\n\ndef loadtxt_str(path:PathOrStr)->np.ndarray:\n    \"Return `ndarray` of `str` of lines of text from `path`.\"\n    with open(path, 'r') as f: lines = f.readlines()\n    return np.array([l.strip() for l in lines])\n\ndef save_texts(fname:PathOrStr, texts:Collection[str]):\n    \"Save in `fname` the content of `texts`.\"\n    with open(fname, 'w') as f:\n        for t in texts: f.write(f'{t}\\n')\n\ndef df_names_to_idx(names:IntsOrStrs, df:DataFrame):\n    \"Return the column indexes of `names` in `df`.\"\n    if not is_listy(names): names = [names]\n    if 
isinstance(names[0], int): return names\n    return [df.columns.get_loc(c) for c in names]\n\ndef one_hot(x:Collection[int], c:int):\n    \"One-hot encode `x` with `c` classes.\"\n    res = np.zeros((c,), np.float32)\n    res[listify(x)] = 1.\n    return res\n\ndef index_row(a:Union[Collection,pd.DataFrame,pd.Series], idxs:Collection[int])->Any:\n    \"Return the slice of `a` corresponding to `idxs`.\"\n    if a is None: return a\n    if isinstance(a,(pd.DataFrame,pd.Series)):\n        res = a.iloc[idxs]\n        if isinstance(res,(pd.DataFrame,pd.Series)): return res.copy()\n        return res\n    return a[idxs]\n\ndef func_args(func)->bool:\n    \"Return the arguments of `func`.\"\n    code = func.__code__\n    return code.co_varnames[:code.co_argcount]\n\ndef has_arg(func, arg)->bool:\n    \"Check if `func` accepts `arg`.\"\n    return arg in func_args(func)\n\ndef split_kwargs_by_func(kwargs, func):\n    \"Split `kwargs` between those expected by `func` and the others.\"\n    args = func_args(func)\n    func_kwargs = {a:kwargs.pop(a) for a in args if a in kwargs}\n    return func_kwargs, kwargs\n\ndef array(a, dtype:type=None, **kwargs)->np.ndarray:\n    \"Same as `np.array` but also handles generators. 
`kwargs` are passed to `np.array` with `dtype`.\"\n    if not isinstance(a, collections.abc.Sized) and not getattr(a,'__array_interface__',False):\n        a = list(a)\n    if np.int_==np.int32 and dtype is None and is_listy(a) and len(a) and isinstance(a[0],int):\n        dtype=np.int64\n    return np.array(a, dtype=dtype, **kwargs)\n\nclass EmptyLabel(ItemBase):\n    \"Should be used for a dummy label.\"\n    def __init__(self): self.obj,self.data = 0,0\n    def __str__(self):  return ''\n    def __hash__(self): return hash(str(self))\n\nclass Category(ItemBase):\n    \"Basic class for single classification labels.\"\n    def __init__(self,data,obj): self.data,self.obj = data,obj\n    def __int__(self):  return int(self.data)\n    def __str__(self):  return str(self.obj)\n    def __hash__(self): return hash(str(self))\n\nclass MultiCategory(ItemBase):\n    \"Basic class for multi-classification labels.\"\n    def __init__(self,data,obj,raw): self.data,self.obj,self.raw = data,obj,raw\n    def __str__(self):  return ';'.join([str(o) for o in self.obj])\n    def __hash__(self): return hash(str(self))\n\nclass FloatItem(ItemBase):\n    \"Basic class for float items.\"\n    def __init__(self,obj): self.data,self.obj = np.array(obj).astype(np.float32),obj\n    def __str__(self):  return str(self.obj)\n    def __hash__(self): return hash(str(self))\n\ndef _treat_html(o:str)->str:\n    o = str(o)\n    to_replace = {'\\n':'\\\\n', '<':'&lt;', '>':'&gt;', '&':'&amp;'}\n    for k,v in to_replace.items(): o = o.replace(k, v)\n    return o\n\ndef text2html_table(items:Collection[Collection[str]])->str:\n    \"Put the texts in `items` in an HTML table, `widths` are the widths of the columns in %.\"\n    html_code = f\"\"\"<table border=\"1\" class=\"dataframe\">\"\"\"\n    html_code += f\"\"\"  <thead>\\n    <tr style=\"text-align: right;\">\\n\"\"\"\n    for i in items[0]: html_code += f\"      <th>{_treat_html(i)}</th>\"\n    html_code += f\"    </tr>\\n  </thead>\\n  
<tbody>\"\n    html_code += \"  <tbody>\"\n    for line in items[1:]:\n        html_code += \"    <tr>\"\n        for i in line: html_code += f\"      <td>{_treat_html(i)}</td>\"\n        html_code += \"    </tr>\"\n    html_code += \"  </tbody>\\n</table>\"\n    return html_code\n\ndef parallel(func, arr:Collection, max_workers:int=None, leave=False):\n    \"Call `func` on every element of `arr` in parallel using `max_workers`.\"\n    max_workers = ifnone(max_workers, defaults.cpus)\n    if max_workers<2: results = [func(o,i) for i,o in progress_bar(enumerate(arr), total=len(arr), leave=leave)]\n    else:\n        with ProcessPoolExecutor(max_workers=max_workers) as ex:\n            futures = [ex.submit(func,o,i) for i,o in enumerate(arr)]\n            results = []\n            for f in progress_bar(concurrent.futures.as_completed(futures), total=len(arr), leave=leave): \n                results.append(f.result())\n    if any([o is not None for o in results]): return results\n\ndef subplots(rows:int, cols:int, imgsize:int=4, figsize:Optional[Tuple[int,int]]=None, title=None, **kwargs):\n    \"Like `plt.subplots` but with consistent axs shape, `kwargs` passed to `fig.suptitle` with `title`\"\n    figsize = ifnone(figsize, (imgsize*cols, imgsize*rows))\n    fig, axs = plt.subplots(rows,cols,figsize=figsize)\n    if rows==cols==1: axs = [[axs]] # subplots(1,1) returns Axes, not [Axes]\n    elif (rows==1 and cols!=1) or (cols==1 and rows!=1): axs = [axs]\n    if title is not None: fig.suptitle(title, **kwargs)\n    return array(axs)\n\ndef show_some(items:Collection, n_max:int=5, sep:str=','):\n    \"Return the representation of the first  `n_max` elements in `items`.\"\n    if items is None or len(items) == 0: return ''\n    res = sep.join([f'{o}' for o in items[:n_max]])\n    if len(items) > n_max: res += '...'\n    return res\n\ndef get_tmp_file(dir=None):\n    \"Create and return a tmp filename, optionally at a specific path. 
`os.remove` when done with it.\"\n    with tempfile.NamedTemporaryFile(delete=False, dir=dir) as f: return f.name\n\ndef compose(funcs:List[Callable])->Callable:\n    \"Compose `funcs`\"\n    def compose_(funcs, x, *args, **kwargs):\n        for f in listify(funcs): x = f(x, *args, **kwargs)\n        return x\n    return partial(compose_, funcs)\n\nclass PrettyString(str):\n    \"Little hack to get strings to show properly in Jupyter.\"\n    def __repr__(self): return self\n\ndef float_or_x(x):\n    \"Tries to convert to float, returns x if it can't\"\n    try:   return float(x)\n    except:return x\n\ndef bunzip(fn:PathOrStr):\n    \"bunzip `fn`, raising exception if output already exists\"\n    fn = Path(fn)\n    assert fn.exists(), f\"{fn} doesn't exist\"\n    out_fn = fn.with_suffix('')\n    assert not out_fn.exists(), f\"{out_fn} already exists\"\n    with bz2.BZ2File(fn, 'rb') as src, out_fn.open('wb') as dst:\n        for d in iter(lambda: src.read(1024*1024), b''): dst.write(d)\n\n@contextmanager\ndef working_directory(path:PathOrStr):\n    \"Change working directory to `path` and return to previous on exit.\"\n    prev_cwd = Path.cwd()\n    os.chdir(path)\n    try: yield\n    finally: os.chdir(prev_cwd)\n\n"
  },
  {
    "path": "fastai/data_block.py",
    "content": "from .torch_core import *\nfrom .basic_data import *\nfrom .layers import *\nfrom numbers import Integral\n\n__all__ = ['ItemList', 'CategoryList', 'MultiCategoryList', 'MultiCategoryProcessor', 'LabelList', 'ItemLists', 'get_files',\n           'PreProcessor', 'LabelLists', 'FloatList', 'CategoryProcessor', 'EmptyLabelList', 'MixedItem', 'MixedProcessor',\n           'MixedItemList']\n\ndef _decode(df):\n    return np.array([[df.columns[i] for i,t in enumerate(x) if t==1] for x in df.values], dtype=np.object)\n\ndef _maybe_squeeze(arr): return (arr if is1d(arr) else np.squeeze(arr))\n\ndef _path_to_same_str(p_fn):\n    \"path -> str, but same on nt+posix, for alpha-sort only\"\n    s_fn = str(p_fn)\n    s_fn = s_fn.replace('\\\\','.')\n    s_fn = s_fn.replace('/','.')\n    return s_fn\n\ndef _get_files(parent, p, f, extensions):\n    p = Path(p)#.relative_to(parent)\n    if isinstance(extensions,str): extensions = [extensions]\n    low_extensions = [e.lower() for e in extensions] if extensions is not None else None\n    res = [p/o for o in f if not o.startswith('.')\n           and (extensions is None or f'.{o.split(\".\")[-1].lower()}' in low_extensions)]\n    return res\n\ndef get_files(path:PathOrStr, extensions:Collection[str]=None, recurse:bool=False,\n              include:Optional[Collection[str]]=None, presort:bool=False)->FilePathList:\n    \"Return list of files in `path` that have a suffix in `extensions`; optionally `recurse`.\"\n    if recurse:\n        res = []\n        for i,(p,d,f) in enumerate(os.walk(path)):\n            # skip hidden dirs\n            if include is not None and i==0:  d[:] = [o for o in d if o in include]\n            else:                             d[:] = [o for o in d if not o.startswith('.')]\n            res += _get_files(path, p, f, extensions)\n        if presort: res = sorted(res, key=lambda p: _path_to_same_str(p), reverse=False)\n        return res\n    else:\n        f = [o.name for o in 
os.scandir(path) if o.is_file()]\n        res = _get_files(path, path, f, extensions)\n        if presort: res = sorted(res, key=lambda p: _path_to_same_str(p), reverse=False)\n        return res\n\nclass PreProcessor():\n    \"Basic class for a processor that will be applied to items at the end of the data block API.\"\n    def __init__(self, ds:Collection=None):  self.ref_ds = ds\n    def process_one(self, item:Any):         return item\n    def process(self, ds:Collection):        ds.items = array([self.process_one(item) for item in ds.items])\n\nPreProcessors = Union[PreProcessor, Collection[PreProcessor]]\nfastai_types[PreProcessors] = 'PreProcessors'\n\nclass ItemList():\n    \"A collection of items with `__len__` and `__getitem__` with `ndarray` indexing semantics.\"\n    _bunch,_processor,_label_cls,_square_show,_square_show_res = DataBunch,None,None,False,False\n\n    def __init__(self, items:Iterator, path:PathOrStr='.', label_cls:Callable=None, inner_df:Any=None,\n                 processor:PreProcessors=None, x:'ItemList'=None, ignore_empty:bool=False):\n        self.path = Path(path)\n        self.num_parts = len(self.path.parts)\n        self.items,self.x,self.ignore_empty = items,x,ignore_empty\n        if not isinstance(self.items,np.ndarray): self.items = array(self.items, dtype=object)\n        self.label_cls,self.inner_df,self.processor = ifnone(label_cls,self._label_cls),inner_df,processor\n        self._label_list,self._split = LabelList,ItemLists\n        self.copy_new = ['x', 'label_cls', 'path']\n\n    def __len__(self)->int: return len(self.items) or 1\n    def get(self, i)->Any:\n        \"Subclass if you want to customize how to create item `i` from `self.items`.\"\n        return self.items[i]\n    def __repr__(self)->str:\n        items = [self[i] for i in range(min(5,len(self.items)))]\n        return f'{self.__class__.__name__} ({len(self.items)} items)\\n{show_some(items)}\\nPath: {self.path}'\n\n    def process(self, 
processor:PreProcessors=None):\n        \"Apply `processor` or `self.processor` to `self`.\"\n        if processor is not None: self.processor = processor\n        self.processor = listify(self.processor)\n        for p in self.processor: p.process(self)\n        return self\n\n    def process_one(self, item:ItemBase, processor:PreProcessors=None):\n        \"Apply `processor` or `self.processor` to `item`.\"\n        if processor is not None: self.processor = processor\n        self.processor = listify(self.processor)\n        for p in self.processor: item = p.process_one(item)\n        return item\n\n    def analyze_pred(self, pred:Tensor):\n        \"Called on `pred` before `reconstruct` for additional preprocessing.\"\n        return pred\n\n    def reconstruct(self, t:Tensor, x:Tensor=None):\n        \"Reconstruct one of the underlying item for its data `t`.\"\n        return self[0].reconstruct(t,x) if has_arg(self[0].reconstruct, 'x') else self[0].reconstruct(t)\n\n    def new(self, items:Iterator, processor:PreProcessors=None, **kwargs)->'ItemList':\n        \"Create a new `ItemList` from `items`, keeping the same attributes.\"\n        processor = ifnone(processor, self.processor)\n        copy_d = {o:getattr(self,o) for o in self.copy_new}\n        kwargs = {**copy_d, **kwargs}\n        return self.__class__(items=items, processor=processor, **kwargs)\n\n    def add(self, items:'ItemList'):\n        self.items = np.concatenate([self.items, items.items], 0)\n        if self.inner_df is not None and items.inner_df is not None:\n            self.inner_df = pd.concat([self.inner_df, items.inner_df])\n        else: self.inner_df = self.inner_df or items.inner_df\n        return self\n\n    def __getitem__(self,idxs:int)->Any:\n        \"returns a single item based if `idxs` is an integer or a new `ItemList` object if `idxs` is a range.\"\n        idxs = try_int(idxs)\n        if isinstance(idxs, Integral): return self.get(idxs)\n        else: return 
self.new(self.items[idxs], inner_df=index_row(self.inner_df, idxs))\n\n    @classmethod\n    def from_folder(cls, path:PathOrStr, extensions:Collection[str]=None, recurse:bool=True,\n                    include:Optional[Collection[str]]=None, processor:PreProcessors=None, presort:Optional[bool]=False, **kwargs)->'ItemList':\n        \"\"\"Create an `ItemList` in `path` from the filenames that have a suffix in `extensions`.\n        `recurse` determines if we search subfolders.\"\"\"\n        path = Path(path)\n        return cls(get_files(path, extensions, recurse=recurse, include=include, presort=presort), path=path, processor=processor, **kwargs)\n\n    @classmethod\n    def from_df(cls, df:DataFrame, path:PathOrStr='.', cols:IntsOrStrs=0, processor:PreProcessors=None, **kwargs)->'ItemList':\n        \"Create an `ItemList` in `path` from the inputs in the `cols` of `df`.\"\n        inputs = df.iloc[:,df_names_to_idx(cols, df)]\n        assert not inputs.isna().any().any(), f\"You have NaN values in column(s) {cols} of your dataframe, please fix it.\"\n        res = cls(items=_maybe_squeeze(inputs.values), path=path, inner_df=df, processor=processor, **kwargs)\n        return res\n\n    @classmethod\n    def from_csv(cls, path:PathOrStr, csv_name:str, cols:IntsOrStrs=0, delimiter:str=None, header:str='infer',\n                 processor:PreProcessors=None, **kwargs)->'ItemList':\n        \"\"\"Create an `ItemList` in `path` from the inputs in the `cols` of `path/csv_name`\"\"\"\n        df = pd.read_csv(Path(path)/csv_name, delimiter=delimiter, header=header)\n        return cls.from_df(df, path=path, cols=cols, processor=processor, **kwargs)\n\n    def _relative_item_path(self, i): return self.items[i].relative_to(self.path)\n    def _relative_item_paths(self):   return [self._relative_item_path(i) for i in range_of(self.items)]\n\n    def use_partial_data(self, sample_pct:float=0.01, seed:int=None)->'ItemList':\n        \"Use only a sample of `sample_pct`of the 
full dataset and an optional `seed`.\"\n        if seed is not None: np.random.seed(seed)\n        rand_idx = np.random.permutation(range_of(self))\n        cut = int(sample_pct * len(self))\n        return self[rand_idx[:cut]]\n\n    def to_text(self, fn:str):\n        \"Save `self.items` to `fn` in `self.path`.\"\n        with open(self.path/fn, 'w') as f: f.writelines([f'{o}\\n' for o in self._relative_item_paths()])\n\n    def filter_by_func(self, func:Callable)->'ItemList':\n        \"Only keep elements for which `func` returns `True`.\"\n        self.items = array([o for o in self.items if func(o)])\n        return self\n\n    def filter_by_folder(self, include=None, exclude=None):\n        \"Only keep filenames in `include` folder or reject the ones in `exclude`.\"\n        include,exclude = listify(include),listify(exclude)\n        def _inner(o):\n            if isinstance(o, Path): n = o.relative_to(self.path).parts[0]\n            else: n = o.split(os.path.sep)[len(str(self.path).split(os.path.sep))]\n            if include and not n in include: return False\n            if exclude and     n in exclude: return False\n            return True\n        return self.filter_by_func(_inner)\n\n    def filter_by_rand(self, p:float, seed:int=None):\n        \"Keep random sample of `items` with probability `p` and an optional `seed`.\"\n        if seed is not None: set_all_seed(seed)\n        return self.filter_by_func(lambda o: rand_bool(p))\n\n    def no_split(self):\n        warn(\"`no_split` is deprecated, please use `split_none`.\")\n        return self.split_none()\n\n    def split_none(self):\n        \"Don't split the data and create an empty validation set.\"\n        val = self[[]]\n        val.ignore_empty = True\n        return self._split(self.path, self, val)\n\n    def split_by_list(self, train, valid):\n        \"Split the data between `train` and `valid`.\"\n        return self._split(self.path, train, valid)\n\n    def split_by_idxs(self, 
train_idx, valid_idx):\n        \"Split the data between `train_idx` and `valid_idx`.\"\n        return self.split_by_list(self[train_idx], self[valid_idx])\n\n    def split_by_idx(self, valid_idx:Collection[int])->'ItemLists':\n        \"Split the data according to the indexes in `valid_idx`.\"\n        #train_idx = [i for i in range_of(self.items) if i not in valid_idx]\n        train_idx = np.setdiff1d(arange_of(self.items), valid_idx)\n        return self.split_by_idxs(train_idx, valid_idx)\n\n    def _get_by_folder(self, name):\n        return [i for i in range_of(self) if (self.items[i].parts[self.num_parts] if isinstance(self.items[i], Path)\n                else self.items[i].split(os.path.sep)[0]) == name ]\n\n    def split_by_folder(self, train:str='train', valid:str='valid')->'ItemLists':\n        \"Split the data depending on the folder (`train` or `valid`) in which the filenames are.\"\n        return self.split_by_idxs(self._get_by_folder(train), self._get_by_folder(valid))\n\n    def random_split_by_pct(self, valid_pct:float=0.2, seed:int=None):\n        warn(\"`random_split_by_pct` is deprecated, please use `split_by_rand_pct`.\")\n        return self.split_by_rand_pct(valid_pct=valid_pct, seed=seed)\n\n    def split_by_rand_pct(self, valid_pct:float=0.2, seed:int=None)->'ItemLists':\n        \"Split the items randomly by putting `valid_pct` in the validation set, optional `seed` can be passed.\"\n        if valid_pct==0.: return self.split_none()\n        if seed is not None: np.random.seed(seed)\n        rand_idx = np.random.permutation(range_of(self))\n        cut = int(valid_pct * len(self))\n        return self.split_by_idx(rand_idx[:cut])\n\n    def split_subsets(self, train_size:float, valid_size:float, seed=None) -> 'ItemLists':\n        \"Split the items into train set with size `train_size * n` and valid set with size `valid_size * n`.\"\n        assert 0 < train_size < 1\n        assert 0 < valid_size < 1\n        assert train_size + 
valid_size <= 1.\n        if seed is not None: np.random.seed(seed)\n        n = len(self.items)\n        rand_idx = np.random.permutation(range(n))\n        train_cut, valid_cut = int(train_size * n), int(valid_size * n)\n        return self.split_by_idxs(rand_idx[:train_cut], rand_idx[-valid_cut:])\n\n    def split_by_valid_func(self, func:Callable)->'ItemLists':\n        \"Split the data by result of `func` (which returns `True` for validation set).\"\n        valid_idx = [i for i,o in enumerate(self.items) if func(o)]\n        return self.split_by_idx(valid_idx)\n\n    def split_by_files(self, valid_names:'ItemList')->'ItemLists':\n        \"Split the data by using the names in `valid_names` for validation.\"\n        if isinstance(self.items[0], Path): return self.split_by_valid_func(lambda o: o.name in valid_names)\n        else: return self.split_by_valid_func(lambda o: os.path.basename(o) in valid_names)\n\n    def split_by_fname_file(self, fname:PathOrStr, path:PathOrStr=None)->'ItemLists':\n        \"Split the data by using the names in `fname` for the validation set. 
`path` will override `self.path`.\"\n        path = Path(ifnone(path, self.path))\n        valid_names = loadtxt_str(path/fname)\n        return self.split_by_files(valid_names)\n\n    def split_from_df(self, col:IntsOrStrs=2):\n        \"Split the data from the `col` in the dataframe in `self.inner_df`.\"\n        valid_idx = np.where(self.inner_df.iloc[:,df_names_to_idx(col, self.inner_df)])[0]\n        return self.split_by_idx(valid_idx)\n\n    def get_label_cls(self, labels, label_cls:Callable=None, label_delim:str=None, **kwargs):\n        \"Return `label_cls` or guess one from the first element of `labels`.\"\n        if label_cls is not None:               return label_cls\n        if self.label_cls is not None:          return self.label_cls\n        if label_delim is not None:             return MultiCategoryList\n        it = index_row(labels,0)\n        if isinstance(it, (float, np.float32)): return FloatList\n        if isinstance(try_int(it), (str, Integral)):  return CategoryList\n        if isinstance(it, Collection):          return MultiCategoryList\n        return ItemList #self.__class__\n\n    def _label_from_list(self, labels:Iterator, label_cls:Callable=None, from_item_lists:bool=False, **kwargs)->'LabelList':\n        \"Label `self.items` with `labels`.\"\n        if not from_item_lists:\n            raise Exception(\"Your data isn't split, if you don't want a validation set, please use `split_none`.\")\n        labels = array(labels, dtype=object)\n        label_cls = self.get_label_cls(labels, label_cls=label_cls, **kwargs)\n        y = label_cls(labels, path=self.path, **kwargs)\n        res = self._label_list(x=self, y=y)\n        return res\n\n    def label_from_df(self, cols:IntsOrStrs=1, label_cls:Callable=None, **kwargs):\n        \"Label `self.items` from the values in `cols` in `self.inner_df`.\"\n        labels = self.inner_df.iloc[:,df_names_to_idx(cols, self.inner_df)]\n        assert labels.isna().sum().sum() == 0, f\"You have 
NaN values in column(s) {cols} of your dataframe, please fix it.\"\n        if is_listy(cols) and len(cols) > 1 and (label_cls is None or label_cls == MultiCategoryList):\n            new_kwargs,label_cls = dict(one_hot=True, classes= cols),MultiCategoryList\n            kwargs = {**new_kwargs, **kwargs}\n        return self._label_from_list(_maybe_squeeze(labels), label_cls=label_cls, **kwargs)\n\n    def label_const(self, const:Any=0, label_cls:Callable=None, **kwargs)->'LabelList':\n        \"Label every item with `const`.\"\n        return self.label_from_func(func=lambda o: const, label_cls=label_cls, **kwargs)\n\n    def label_empty(self, **kwargs):\n        \"Label every item with an `EmptyLabel`.\"\n        kwargs['label_cls'] = EmptyLabelList\n        return self.label_from_func(func=lambda o: 0., **kwargs)\n\n    def label_from_func(self, func:Callable, label_cls:Callable=None, **kwargs)->'LabelList':\n        \"Apply `func` to every input to get its label.\"\n        return self._label_from_list([func(o) for o in self.items], label_cls=label_cls, **kwargs)\n\n    def label_from_folder(self, label_cls:Callable=None, **kwargs)->'LabelList':\n        \"Give a label to each filename depending on its folder.\"\n        return self.label_from_func(func=lambda o: (o.parts if isinstance(o, Path) else o.split(os.path.sep))[-2],\n                                    label_cls=label_cls, **kwargs)\n\n    def label_from_re(self, pat:str, full_path:bool=False, label_cls:Callable=None, **kwargs)->'LabelList':\n        \"Apply the re in `pat` to determine the label of every filename.  
If `full_path`, search in the full name.\"\n        pat = re.compile(pat)\n        def _inner(o):\n            s = str((os.path.join(self.path,o) if full_path else o).as_posix())\n            res = pat.search(s)\n            assert res,f'Failed to find \"{pat}\" in \"{s}\"'\n            return res.group(1)\n        return self.label_from_func(_inner, label_cls=label_cls, **kwargs)\n\n    def databunch(self, **kwargs):\n        \"To throw a clear error message when the data wasn't split and labeled.\"\n        raise Exception(\"Your data is neither split nor labeled, can't turn it into a `DataBunch` yet.\")\n\nclass EmptyLabelList(ItemList):\n    \"Basic `ItemList` for dummy labels.\"\n    def get(self, i): return EmptyLabel()\n    def reconstruct(self, t:Tensor, x:Tensor=None):\n        if len(t.size()) == 0: return EmptyLabel()\n        return self.x.reconstruct(t,x) if has_arg(self.x.reconstruct, 'x') else self.x.reconstruct(t)\n\nclass CategoryProcessor(PreProcessor):\n    \"`PreProcessor` that create `classes` from `ds.items` and handle the mapping.\"\n    def __init__(self, ds:ItemList):\n        self.create_classes(ds.classes)\n        self.state_attrs,self.warns = ['classes'],[]\n\n    def create_classes(self, classes):\n        self.classes = classes\n        if classes is not None: self.c2i = {v:k for k,v in enumerate(classes)}\n\n    def generate_classes(self, items):\n        \"Generate classes from `items` by taking the sorted unique values.\"\n        return uniqueify(items, sort=True)\n\n    def process_one(self,item):\n        if isinstance(item, EmptyLabel): return item\n        res = self.c2i.get(item,None)\n        if res is None: self.warns.append(str(item))\n        return res\n\n    def process(self, ds):\n        if self.classes is None: self.create_classes(self.generate_classes(ds.items))\n        ds.classes = self.classes\n        ds.c2i = self.c2i\n        super().process(ds)\n\n    def __getstate__(self): return {n:getattr(self,n) for n in 
self.state_attrs}\n    def __setstate__(self, state:dict):\n        self.create_classes(state['classes'])\n        self.state_attrs = state.keys()\n        for n in state.keys():\n            if n!='classes': setattr(self, n, state[n])\n\nclass CategoryListBase(ItemList):\n    \"Basic `ItemList` for classification.\"\n    def __init__(self, items:Iterator, classes:Collection=None, **kwargs):\n        self.classes=classes\n        self.filter_missing_y = True\n        super().__init__(items, **kwargs)\n        self.copy_new.append('classes')\n\n    @property\n    def c(self): return len(self.classes)\n\nclass CategoryList(CategoryListBase):\n    \"Basic `ItemList` for single classification labels.\"\n    _processor=CategoryProcessor\n    def __init__(self, items:Iterator, classes:Collection=None, label_delim:str=None, **kwargs):\n        super().__init__(items, classes=classes, **kwargs)\n        self.loss_func = CrossEntropyFlat()\n\n    def get(self, i):\n        o = self.items[i]\n        if o is None: return None\n        return Category(o, self.classes[o])\n\n    def analyze_pred(self, pred, thresh:float=0.5): return pred.argmax()\n\n    def reconstruct(self, t):\n        return Category(t, self.classes[t])\n\nclass MultiCategoryProcessor(CategoryProcessor):\n    \"`PreProcessor` that create `classes` from `ds.items` and handle the mapping.\"\n    def __init__(self, ds:ItemList, one_hot:bool=False):\n        super().__init__(ds)\n        self.one_hot = one_hot\n        self.state_attrs.append('one_hot')\n\n    def process_one(self,item):\n        if self.one_hot or isinstance(item, EmptyLabel): return item\n        res = [super(MultiCategoryProcessor, self).process_one(o) for o in item]\n        return [r for r in res if r is not None]\n\n    def generate_classes(self, items):\n        \"Generate classes from `items` by taking the sorted unique values.\"\n        classes = set()\n        for c in items: classes = classes.union(set(c))\n        classes = 
list(classes)\n        classes.sort()\n        return classes\n\nclass MultiCategoryList(CategoryListBase):\n    \"Basic `ItemList` for multi-classification labels.\"\n    _processor=MultiCategoryProcessor\n    def __init__(self, items:Iterator, classes:Collection=None, label_delim:str=None, one_hot:bool=False, **kwargs):\n        if label_delim is not None: items = array(csv.reader(items.astype(str), delimiter=label_delim))\n        super().__init__(items, classes=classes, **kwargs)\n        if one_hot:\n            assert classes is not None, \"Please provide class names with `classes=...`\"\n            self.processor = [MultiCategoryProcessor(self, one_hot=True)]\n        self.loss_func = BCEWithLogitsFlat()\n        self.one_hot = one_hot\n        self.copy_new += ['one_hot']\n\n    def get(self, i):\n        o = self.items[i]\n        if o is None: return None\n        if self.one_hot: return self.reconstruct(o.astype(np.float32))\n        return MultiCategory(one_hot(o, self.c), [self.classes[p] for p in o], o)\n\n    def analyze_pred(self, pred, thresh:float=0.5):\n        return (pred >= thresh).float()\n\n    def reconstruct(self, t):\n        o = [i for i in range(self.c) if t[i] == 1.]\n        return MultiCategory(t, [self.classes[p] for p in o], o)\n\nclass FloatList(ItemList):\n    \"`ItemList` suitable for storing the floats in items for regression. 
Will add a `log` if this flag is `True`.\"\n    def __init__(self, items:Iterator, log:bool=False, classes:Collection=None, **kwargs):\n        super().__init__(np.array(items, dtype=np.float32), **kwargs)\n        self.log = log\n        self.copy_new.append('log')\n        self.c = self.items.shape[1] if len(self.items.shape) > 1 else 1\n        self.loss_func = MSELossFlat()\n\n    def get(self, i):\n        o = super().get(i)\n        return FloatItem(np.log(o) if self.log else o)\n\n    def reconstruct(self,t): return FloatItem(t.numpy())\n\nclass ItemLists():\n    \"An `ItemList` for each of `train` and `valid` (optional `test`).\"\n    def __init__(self, path:PathOrStr, train:ItemList, valid:ItemList):\n        self.path,self.train,self.valid,self.test = Path(path),train,valid,None\n        if not self.train.ignore_empty and len(self.train.items) == 0:\n            warn(\"Your training set is empty. If this is by design, pass `ignore_empty=True` to remove this warning.\")\n        if not self.valid.ignore_empty and len(self.valid.items) == 0:\n            warn(\"\"\"Your validation set is empty. 
If this is by design, use `split_none()`\n                 or pass `ignore_empty=True` when labelling to remove this warning.\"\"\")\n        if isinstance(self.train, LabelList): self.__class__ = LabelLists\n\n    def __dir__(self)->List[str]:\n        default_dir = dir(type(self)) + list(self.__dict__.keys())\n        add_ons = ['label_const', 'label_empty', 'label_from_df', 'label_from_folder', 'label_from_func',\n                   'label_from_list', 'label_from_re']\n        return default_dir + add_ons\n\n    def __repr__(self)->str:\n        return f'{self.__class__.__name__};\\n\\nTrain: {self.train};\\n\\nValid: {self.valid};\\n\\nTest: {self.test}'\n\n    def __getattr__(self, k):\n        ft = getattr(self.train, k)\n        if not isinstance(ft, Callable): return ft\n        fv = getattr(self.valid, k)\n        assert isinstance(fv, Callable)\n        def _inner(*args, **kwargs):\n            self.train = ft(*args, from_item_lists=True, **kwargs)\n            assert isinstance(self.train, LabelList)\n            kwargs['label_cls'] = self.train.y.__class__\n            self.valid = fv(*args, from_item_lists=True, **kwargs)\n            self.__class__ = LabelLists\n            self.process()\n            return self\n        return _inner\n\n    def __setstate__(self,data:Any): self.__dict__.update(data)\n\n    @property\n    def lists(self):\n        res = [self.train,self.valid]\n        if self.test is not None: res.append(self.test)\n        return res\n\n    def label_from_lists(self, train_labels:Iterator, valid_labels:Iterator, label_cls:Callable=None, **kwargs)->'LabelList':\n        \"Use the labels in `train_labels` and `valid_labels` to label the data. 
`label_cls` will overwrite the default.\"\n        label_cls = self.train.get_label_cls(train_labels, label_cls)\n        self.train = self.train._label_list(x=self.train, y=label_cls(train_labels, **kwargs))\n        self.valid = self.valid._label_list(x=self.valid, y=self.train.y.new(valid_labels, **kwargs))\n        self.__class__ = LabelLists\n        self.process()\n        return self\n\n    def transform(self, tfms:Optional[Tuple[TfmList,TfmList]]=(None,None), **kwargs):\n        \"Set `tfms` to be applied to the xs of the train and validation set.\"\n        if not tfms: tfms=(None,None)\n        assert is_listy(tfms) and len(tfms) == 2, \"Please pass a list of two lists of transforms (train and valid).\"\n        self.train.transform(tfms[0], **kwargs)\n        self.valid.transform(tfms[1], **kwargs)\n        if self.test: self.test.transform(tfms[1], **kwargs)\n        return self\n\n    def transform_y(self, tfms:Optional[Tuple[TfmList,TfmList]]=(None,None), **kwargs):\n        \"Set `tfms` to be applied to the ys of the train and validation set.\"\n        if not tfms: tfms=(None,None)\n        self.train.transform_y(tfms[0], **kwargs)\n        self.valid.transform_y(tfms[1], **kwargs)\n        if self.test: self.test.transform_y(tfms[1], **kwargs)\n        return self\n\n    def databunch(self, **kwargs):\n        \"To throw a clear error message when the data wasn't labeled.\"\n        raise Exception(\"Your data isn't labeled, can't turn it into a `DataBunch` yet!\")\n\nclass LabelLists(ItemLists):\n    \"A `LabelList` for each of `train` and `valid` (optional `test`).\"\n    def get_processors(self):\n        \"Read the default class processors if none have been set.\"\n        procs_x,procs_y = listify(self.train.x._processor),listify(self.train.y._processor)\n        xp = ifnone(self.train.x.processor, [p(ds=self.train.x) for p in procs_x])\n        yp = ifnone(self.train.y.processor, [p(ds=self.train.y) for p in procs_y])\n        return 
xp,yp\n\n    def process(self):\n        \"Process the inner datasets.\"\n        xp,yp = self.get_processors()\n        for ds,n in zip(self.lists, ['train','valid','test']): ds.process(xp, yp, name=n)\n        #progress_bar clear the outputs so in some case warnings issued during processing disappear.\n        for ds in self.lists:\n            if getattr(ds, 'warn', False): warn(ds.warn)\n        return self\n\n    def filter_by_func(self, func:Callable):\n        for ds in self.lists: ds.filter_by_func(func)\n        return self\n\n    def databunch(self, path:PathOrStr=None, bs:int=64, val_bs:int=None, num_workers:int=defaults.cpus,\n                  dl_tfms:Optional[Collection[Callable]]=None, device:torch.device=None, collate_fn:Callable=data_collate,\n                  no_check:bool=False, **kwargs)->'DataBunch':\n        \"Create an `DataBunch` from self, `path` will override `self.path`, `kwargs` are passed to `DataBunch.create`.\"\n        path = Path(ifnone(path, self.path))\n        data = self.x._bunch.create(self.train, self.valid, test_ds=self.test, path=path, bs=bs, val_bs=val_bs,\n                                    num_workers=num_workers, dl_tfms=dl_tfms, device=device, collate_fn=collate_fn, no_check=no_check, **kwargs)\n        if getattr(self, 'normalize', False):#In case a normalization was serialized\n            norm = self.normalize\n            data.normalize((norm['mean'], norm['std']), do_x=norm['do_x'], do_y=norm['do_y'])\n        data.label_list = self\n        return data\n\n    def add_test(self, items:Iterator, label:Any=None, tfms=None, tfm_y=None):\n        \"Add test set containing `items` with an arbitrary `label`.\"\n        # if no label passed, use label of first training item\n        if label is None: labels = EmptyLabelList([0] * len(items))\n        else: labels = self.valid.y.new([label] * len(items)).process()\n        if isinstance(items, MixedItemList): items = self.valid.x.new(items.item_lists, 
inner_df=items.inner_df).process()\n        elif isinstance(items, ItemList): items = self.valid.x.new(items.items, inner_df=items.inner_df).process()\n        else: items = self.valid.x.new(items).process()\n        self.test = self.valid.new(items, labels, tfms=tfms, tfm_y=tfm_y)\n        return self\n\n    def add_test_folder(self, test_folder:str='test', label:Any=None, tfms=None, tfm_y=None):\n        \"Add test set containing items from `test_folder` and an arbitrary `label`.\"\n        # note: labels will be ignored if available in the test dataset\n        items = self.x.__class__.from_folder(self.path/test_folder)\n        return self.add_test(items.items, label=label, tfms=tfms, tfm_y=tfm_y)\n\n    @classmethod\n    def load_state(cls, path:PathOrStr, state:dict):\n        \"Create a `LabelLists` with empty sets from the serialized `state`.\"\n        path = Path(path)\n        train_ds = LabelList.load_state(path, state)\n        valid_ds = LabelList.load_state(path, state)\n        return LabelLists(path, train=train_ds, valid=valid_ds)\n\n    @classmethod\n    def load_empty(cls, path:PathOrStr, fn:PathOrStr='export.pkl'):\n        \"Create a `LabelLists` with empty sets from the serialized file in `path/fn`.\"\n        path = Path(path)\n        state = torch.load(open(path/fn, 'rb'))\n        return LabelLists.load_state(path, state)\n\ndef _check_kwargs(ds:ItemList, tfms:TfmList, **kwargs):\n    tfms = listify(tfms)\n    if (tfms is None or len(tfms) == 0) and len(kwargs) == 0: return\n    if len(ds.items) >= 1:\n        x = ds[0]\n        try: x.apply_tfms(tfms, **kwargs)\n        except Exception as e:\n            raise Exception(f\"It's not possible to apply those transforms to your dataset:\\n {e}\")\n\nclass LabelList(Dataset):\n    \"A list of inputs `x` and labels `y` with optional `tfms`.\"\n    def __init__(self, x:ItemList, y:ItemList, tfms:TfmList=None, tfm_y:bool=False, **kwargs):\n        self.x,self.y,self.tfm_y = x,y,tfm_y\n        
self.y.x = x\n        self.item=None\n        self.transform(tfms, **kwargs)\n\n    def __len__(self)->int: return len(self.x) if self.item is None else 1\n\n    @contextmanager\n    def set_item(self,item):\n        \"For inference, will briefly replace the dataset with one that only contains `item`.\"\n        self.item = self.x.process_one(item)\n        yield None\n        self.item = None\n\n    def __repr__(self)->str:\n        items = [self[i] for i in range(min(5,len(self.items)))]\n        res = f'{self.__class__.__name__} ({len(self.items)} items)\\n'\n        res += f'x: {self.x.__class__.__name__}\\n{show_some([i[0] for i in items])}\\n'\n        res += f'y: {self.y.__class__.__name__}\\n{show_some([i[1] for i in items])}\\n'\n        return res + f'Path: {self.path}'\n\n    def predict(self, res):\n        \"Delegates predict call on `res` to `self.y`.\"\n        return self.y.predict(res)\n\n    @property\n    def c(self): return self.y.c\n\n    def new(self, x, y, tfms=None, tfm_y=None, **kwargs)->'LabelList':\n        tfms,tfm_y = ifnone(tfms, self.tfms),ifnone(tfm_y, self.tfm_y)\n        if isinstance(x, ItemList):\n            return self.__class__(x, y, tfms=tfms, tfm_y=tfm_y, **self.tfmargs)\n        else:\n            return self.new(self.x.new(x, **kwargs), self.y.new(y, **kwargs), tfms=tfms, tfm_y=tfm_y).process()\n\n    def __getattr__(self,k:str)->Any:\n        x = super().__getattribute__('x')\n        res = getattr(x, k, None)\n        if res is not None and k not in ['classes', 'c']: return res\n        y = super().__getattribute__('y')\n        res = getattr(y, k, None)\n        if res is not None: return res\n        raise AttributeError(k)\n\n    def __setstate__(self,data:Any): self.__dict__.update(data)\n\n    def __getitem__(self,idxs:Union[int,np.ndarray])->'LabelList':\n        \"return a single (x, y) if `idxs` is an integer or a new `LabelList` object if `idxs` is a range.\"\n        idxs = try_int(idxs)\n        if 
isinstance(idxs, Integral):\n            if self.item is None: x,y = self.x[idxs],self.y[idxs]\n            else:                 x,y = self.item   ,0\n            if self.tfms or self.tfmargs:\n                x = x.apply_tfms(self.tfms, is_x=True, **self.tfmargs)\n            if hasattr(self, 'tfms_y') and self.tfm_y and self.item is None:\n                y = y.apply_tfms(self.tfms_y, is_x=False, **{**self.tfmargs_y, 'do_resolve':False})\n            if y is None: y=0\n            return x,y\n        else: return self.new(self.x[idxs], self.y[idxs])\n\n    def to_df(self)->None:\n        \"Create `pd.DataFrame` containing `items` from `self.x` and `self.y`.\"\n        return pd.DataFrame(dict(x=self.x._relative_item_paths(), y=[str(o) for o in self.y]))\n\n    def to_csv(self, dest:str)->None:\n        \"Save `self.to_df()` to a CSV file in `self.path`/`dest`.\"\n        self.to_df().to_csv(self.path/dest, index=False)\n\n    def get_state(self, **kwargs):\n        \"Return the minimal state for export.\"\n        state = {'x_cls':self.x.__class__, 'x_proc':self.x.processor,\n                 'y_cls':self.y.__class__, 'y_proc':self.y.processor,\n                 'tfms':self.tfms, 'tfm_y':self.tfm_y, 'tfmargs':self.tfmargs}\n        if hasattr(self, 'tfms_y'):    state['tfms_y']    = self.tfms_y\n        if hasattr(self, 'tfmargs_y'): state['tfmargs_y'] = self.tfmargs_y\n        return {**state, **kwargs}\n\n    def export(self, fn:PathOrStr, **kwargs):\n        \"Export the minimal state and save it in `fn` to load an empty version for inference.\"\n        pickle.dump(self.get_state(**kwargs), open(fn, 'wb'))\n\n    @classmethod\n    def load_empty(cls, path:PathOrStr, fn:PathOrStr):\n        \"Load the state in `fn` to create an empty `LabelList` for inference.\"\n        return cls.load_state(path, pickle.load(open(Path(path)/fn, 'rb')))\n\n    @classmethod\n    def load_state(cls, path:PathOrStr, state:dict) -> 'LabelList':\n        \"Create a `LabelList` 
from `state`.\"\n        x = state['x_cls']([], path=path, processor=state['x_proc'], ignore_empty=True)\n        y = state['y_cls']([], path=path, processor=state['y_proc'], ignore_empty=True)\n        res = cls(x, y, tfms=state['tfms'], tfm_y=state['tfm_y'], **state['tfmargs']).process()\n        if state.get('tfms_y', False):    res.tfms_y    = state['tfms_y']\n        if state.get('tfmargs_y', False): res.tfmargs_y = state['tfmargs_y']\n        if state.get('normalize', False): res.normalize = state['normalize']\n        return res\n\n    def process(self, xp:PreProcessor=None, yp:PreProcessor=None, name:str=None):\n        \"Launch the processing on `self.x` and `self.y` with `xp` and `yp`.\"\n        self.y.process(yp)\n        if getattr(self.y, 'filter_missing_y', False):\n            filt = array([o is None for o in self.y.items])\n            if filt.sum()>0:\n                #Warnings are given later since progress_bar might make them disappear.\n                self.warn = f\"You are labelling your items with {self.y.__class__.__name__}.\\n\"\n                self.warn += f\"Your {name} set contained the following unknown labels, the corresponding items have been discarded.\\n\"\n                for p in self.y.processor:\n                    if len(getattr(p, 'warns', [])) > 0:\n                        warnings = list(set(p.warns))\n                        self.warn += ', '.join(warnings[:5])\n                        if len(warnings) > 5: self.warn += \"...\"\n                    p.warns = []\n                self.x,self.y = self.x[~filt],self.y[~filt]\n        self.x.process(xp)\n        return self\n\n    def filter_by_func(self, func:Callable):\n        filt = array([func(x,y) for x,y in zip(self.x.items, self.y.items)])\n        self.x,self.y = self.x[~filt],self.y[~filt]\n        return self\n\n    def transform(self, tfms:TfmList, tfm_y:bool=None, **kwargs):\n        \"Set the `tfms` and `tfm_y` value to be applied to the inputs and targets.\"\n  
      _check_kwargs(self.x, tfms, **kwargs)\n        if tfm_y is None: tfm_y = self.tfm_y\n        tfms_y = None if tfms is None else list(filter(lambda t: getattr(t, 'use_on_y', True), listify(tfms)))\n        if tfm_y: _check_kwargs(self.y, tfms_y, **kwargs)\n        self.tfms,self.tfmargs  = tfms,kwargs\n        self.tfm_y,self.tfms_y,self.tfmargs_y = tfm_y,tfms_y,kwargs\n        return self\n\n    def transform_y(self, tfms:TfmList=None, **kwargs):\n        \"Set `tfms` to be applied to the targets only.\"\n        tfms_y = list(filter(lambda t: getattr(t, 'use_on_y', True), listify(self.tfms if tfms is None else tfms)))\n        tfmargs_y = {**self.tfmargs, **kwargs} if tfms is None else kwargs\n        _check_kwargs(self.y, tfms_y, **tfmargs_y)\n        self.tfm_y,self.tfms_y,self.tfmargs_y=True,tfms_y,tfmargs_y\n        return self\n\n    def databunch(self, **kwargs):\n        \"To throw a clear error message when the data wasn't split.\"\n        raise Exception(\"Your data isn't split, if you don't want a validation set, please use `split_none`\")\n\n@classmethod\ndef _databunch_load_empty(cls, path, fname:str='export.pkl'):\n    \"Load an empty `DataBunch` from the exported file in `path/fname` with optional `tfms`.\"\n    sd = LabelLists.load_empty(path, fn=fname)\n    return sd.databunch()\n\nDataBunch.load_empty = _databunch_load_empty\n\nclass MixedProcessor(PreProcessor):\n    def __init__(self, procs:Collection[Union[PreProcessor, Collection[PreProcessor]]]):\n        self.procs = procs\n\n    def process_one(self, item:Any):\n        res = []\n        for procs, i in zip(self.procs, item):\n            for p in procs: i = p.process_one(i)\n            res.append(i)\n        return res\n\n    def process(self, ds:Collection):\n        for procs, il in zip(self.procs, ds.item_lists):\n            for p in procs: p.process(il)\n\nclass MixedItem(ItemBase):\n    def __init__(self, items):\n        self.obj = items\n        self.data = [item.data for 
item in items]\n\n    def __repr__(self): return '\\n'.join([f'{self.__class__.__name__}'] + [repr(item) for item in self.obj])\n\n    def apply_tfms(self, tfms:Collection, **kwargs):\n        self.obj = [item.apply_tfms(t, **kwargs) for item,t in zip(self.obj, tfms)]\n        self.data = [item.data for item in self.obj]\n        return self\n\nclass MixedItemList(ItemList):\n\n    def __init__(self, item_lists, path:PathOrStr=None, label_cls:Callable=None, inner_df:Any=None,\n                 x:'ItemList'=None, ignore_empty:bool=False, processor=None):\n        self.item_lists = item_lists\n        if processor is None:\n            default_procs = [[p(ds=il) for p in listify(il._processor)] for il in item_lists]\n            processor = MixedProcessor([ifnone(il.processor, dp) for il,dp in zip(item_lists, default_procs)])\n        items = range_of(item_lists[0]) if len(item_lists) >= 1 else []\n        if path is None and len(item_lists) >= 1: path = item_lists[0].path\n        super().__init__(items, processor=processor, path=path,\n                         label_cls=label_cls, inner_df=inner_df, x=x, ignore_empty=ignore_empty)\n\n    def new(self, item_lists, processor:PreProcessor=None, **kwargs)->'ItemList':\n        \"Create a new `ItemList` from `items`, keeping the same attributes.\"\n        processor = ifnone(processor, self.processor)\n        copy_d = {o:getattr(self,o) for o in self.copy_new}\n        kwargs = {**copy_d, **kwargs}\n        return self.__class__(item_lists, processor=processor, **kwargs)\n\n    def get(self, i):\n        return MixedItem([il.get(i) for il in self.item_lists])\n\n    def __getitem__(self,idxs:int)->Any:\n        idxs = try_int(idxs)\n        if isinstance(idxs, Integral): return self.get(idxs)\n        else:\n            item_lists = [il.new(il.items[idxs], inner_df=index_row(il.inner_df, idxs)) for il in self.item_lists]\n            return self.new(item_lists, inner_df=index_row(self.inner_df, idxs))\n"
  },
  {
    "path": "fastai/datasets.py",
    "content": "from .core import *\nimport hashlib\n\n__all__ = ['URLs', 'Config', 'untar_data', 'download_data', 'datapath4file', 'url2name', 'url2path']\n\nMODEL_URL = 'http://files.fast.ai/models/'\nURL = 'http://files.fast.ai/data/examples/'\nclass URLs():\n    \"Global constants for dataset and model URLs.\"\n    LOCAL_PATH = Path.cwd()\n    S3 = 'https://s3.amazonaws.com/fast-ai-'\n\n    S3_IMAGE    = f'{S3}imageclas/'\n    S3_IMAGELOC = f'{S3}imagelocal/'\n    S3_NLP      = f'{S3}nlp/'\n    S3_COCO     = f'{S3}coco/'\n    S3_MODEL    = f'{S3}modelzoo/'\n\n    # main datasets\n    ADULT_SAMPLE        = f'{URL}adult_sample'\n    BIWI_SAMPLE         = f'{URL}biwi_sample'\n    CIFAR               = f'{URL}cifar10'\n    COCO_SAMPLE         = f'{S3_COCO}coco_sample'\n    COCO_TINY           = f'{URL}coco_tiny'\n    HUMAN_NUMBERS       = f'{URL}human_numbers'\n    IMDB                = f'{S3_NLP}imdb'\n    IMDB_SAMPLE         = f'{URL}imdb_sample'\n    ML_SAMPLE           = f'{URL}movie_lens_sample'\n    MNIST_SAMPLE        = f'{URL}mnist_sample'\n    MNIST_TINY          = f'{URL}mnist_tiny'\n    MNIST_VAR_SIZE_TINY = f'{S3_IMAGE}mnist_var_size_tiny'\n    PLANET_SAMPLE       = f'{URL}planet_sample'\n    PLANET_TINY         = f'{URL}planet_tiny'\n    IMAGENETTE          = f'{S3_IMAGE}imagenette'\n    IMAGENETTE_160      = f'{S3_IMAGE}imagenette-160'\n    IMAGENETTE_320      = f'{S3_IMAGE}imagenette-320'\n    IMAGEWOOF           = f'{S3_IMAGE}imagewoof'\n    IMAGEWOOF_160       = f'{S3_IMAGE}imagewoof-160'\n    IMAGEWOOF_320       = f'{S3_IMAGE}imagewoof-320'\n\n    # kaggle competitions download dogs-vs-cats -p {DOGS.absolute()}\n    DOGS = f'{URL}dogscats'\n\n    # image classification datasets\n    CALTECH_101  = f'{S3_IMAGE}caltech_101'\n    CARS         = f'{S3_IMAGE}stanford-cars'\n    CIFAR_100    = f'{S3_IMAGE}cifar100'\n    CUB_200_2011 = f'{S3_IMAGE}CUB_200_2011'\n    FLOWERS      = f'{S3_IMAGE}oxford-102-flowers'\n    FOOD         = 
f'{S3_IMAGE}food-101'\n    MNIST        = f'{S3_IMAGE}mnist_png'\n    PETS         = f'{S3_IMAGE}oxford-iiit-pet'\n\n    # NLP datasets\n    AG_NEWS                 = f'{S3_NLP}ag_news_csv'\n    AMAZON_REVIEWS          = f'{S3_NLP}amazon_review_full_csv'\n    AMAZON_REVIEWS_POLARITY = f'{S3_NLP}amazon_review_polarity_csv'\n    DBPEDIA                 = f'{S3_NLP}dbpedia_csv'\n    MT_ENG_FRA              = f'{S3_NLP}giga-fren'\n    SOGOU_NEWS              = f'{S3_NLP}sogou_news_csv'\n    WIKITEXT                = f'{S3_NLP}wikitext-103'\n    WIKITEXT_TINY           = f'{S3_NLP}wikitext-2'\n    YAHOO_ANSWERS           = f'{S3_NLP}yahoo_answers_csv'\n    YELP_REVIEWS            = f'{S3_NLP}yelp_review_full_csv'\n    YELP_REVIEWS_POLARITY   = f'{S3_NLP}yelp_review_polarity_csv'\n\n    # Image localization datasets\n    BIWI_HEAD_POSE     = f\"{S3_IMAGELOC}biwi_head_pose\"\n    CAMVID             = f'{S3_IMAGELOC}camvid'\n    CAMVID_TINY        = f'{URL}camvid_tiny'\n    LSUN_BEDROOMS      = f'{S3_IMAGE}bedroom'\n    PASCAL_2007        = f'{S3_IMAGELOC}pascal_2007'\n    PASCAL_2012        = f'{S3_IMAGELOC}pascal_2012'\n\n    #Pretrained models\n    OPENAI_TRANSFORMER = f'{S3_MODEL}transformer'\n    WT103_FWD          = f'{S3_MODEL}wt103-fwd'\n    WT103_BWD          = f'{S3_MODEL}wt103-bwd'\n\n# to create/update a checksum for ./mnist_var_size_tiny.tgz, run:\n# python -c 'import fastai.datasets; print(fastai.datasets._check_file(\"mnist_var_size_tiny.tgz\"))'\n_checks = {\n    URLs.ADULT_SAMPLE:(968212, '64eb9d7e23732de0b138f7372d15492f'),\n    URLs.AG_NEWS:(11784419, 'b86f328f4dbd072486591cb7a5644dcd'),\n    URLs.AMAZON_REVIEWS_POLARITY:(688339454, '676f7e5208ec343c8274b4bb085bc938'),\n    URLs.AMAZON_REVIEWS:(643695014, '4a1196cf0adaea22f4bc3f592cddde90'),\n    URLs.BIWI_HEAD_POSE:(452316199, '00f4ccf66e8cba184bc292fdc08fb237'),\n    URLs.BIWI_SAMPLE:(593774, '9179f4c1435f4b291f0d5b072d60c2c9'),\n    URLs.CALTECH_101:(131740031, 'd673425306e98ee4619fcdeef8a0e876'),\n   
 URLs.CAMVID:(598913237, '648371e4f3a833682afb39b08a3ce2aa'),\n    URLs.CAMVID_TINY:(2314212, '2cf6daf91b7a2083ecfa3e9968e9d915'),\n    URLs.CARS:(1957803273, '9045d6673c9ced0889f41816f6bf2f9f'),\n    URLs.CIFAR:(168168549, 'a5f8c31371b63a406b23368042812d3c'),\n    URLs.CIFAR_100:(169168619, 'e5e65dcb54b9d3913f7b8a9ad6607e62'),\n    URLs.COCO_SAMPLE:(3245877008, '006cd55d633d94b36ecaf661467830ec'),\n    URLs.COCO_TINY:(801038, '367467451ac4fba79a647753c2c66d3a'),\n    URLs.CUB_200_2011:(1150585339, 'd2acaa99439dff0483c7bbac1bfe2a92'),\n    URLs.DBPEDIA:(68341743, '239c7837b9e79db34486f3de6a00e38e'),\n    URLs.DOGS:(839285364, '3e483c8d6ef2175e9d395a6027eb92b7'),\n    URLs.FLOWERS:(345236087, '5666e01c1311b4c67fcf20d2b3850a88'),\n    URLs.FOOD:(5686607260, '1a540ebf1fb40b2bf3f2294234ba7907'),\n    URLs.HUMAN_NUMBERS:(30252, '8a19c3bfa2bcb08cd787e741261f3ea2'),\n    URLs.IMDB:(144440600, '90f9b1c4ff43a90d67553c9240dc0249'),\n    URLs.IMDB_SAMPLE:(571827, '0842e61a9867caa2e6fbdb14fa703d61'),\n    URLs.LSUN_BEDROOMS:(4579163978, '35d84f38f8a15fe47e66e460c8800d68'),\n    URLs.ML_SAMPLE:(51790, '10961384dfe7c5181460390a460c1f77'),\n    URLs.MNIST:(15683414, '03639f83c4e3d19e0a3a53a8a997c487'),\n    URLs.MNIST_SAMPLE:(3214948, '2dbc7ec6f9259b583af0072c55816a88'),\n    URLs.MNIST_TINY:(342207, '56143e8f24db90d925d82a5a74141875'),\n    URLs.MNIST_VAR_SIZE_TINY:(565372, 'b71a930f4eb744a4a143a6c7ff7ed67f'),\n    URLs.MT_ENG_FRA:(2598183296, '69573f58e2c850b90f2f954077041d8c'),\n    URLs.OPENAI_TRANSFORMER:(432848315, '024b0d2203ebb0cd1fc64b27cf8af18e'),\n    URLs.PASCAL_2007:(1636130334, 'a70574e9bc592bd3b253f5bf46ce12e3'),\n    URLs.PASCAL_2012:(2611715776, '2ae7897038383836f86ce58f66b09e31'),\n    URLs.PETS:(811706944, 'e4db5c768afd933bb91f5f594d7417a4'),\n    URLs.PLANET_SAMPLE:(15523994, '8bfb174b3162f07fbde09b54555bdb00'),\n    URLs.PLANET_TINY:(997569, '490873c5683454d4b2611fb1f00a68a9'),\n    URLs.SOGOU_NEWS:(384269937, '950f1366d33be52f5b944f8a8b680902'),\n    
URLs.WIKITEXT:(190200704, '2dd8cf8693b3d27e9c8f0a7df054b2c7'),\n    URLs.WIKITEXT_TINY:(4070055, '2a82d47a7b85c8b6a8e068dc4c1d37e7'),\n    URLs.WT103_FWD:(105067061, '7d1114cd9684bf9d1ca3c9f6a54da6f9'),\n    URLs.WT103_BWD:(105205312, '20b06f5830fd5a891d21044c28d3097f'),\n    URLs.YAHOO_ANSWERS:(319476345, '0632a0d236ef3a529c0fa4429b339f68'),\n    URLs.YELP_REVIEWS_POLARITY:(166373201, '48c8451c1ad30472334d856b5d294807'),\n    URLs.YELP_REVIEWS:(196146755, '1efd84215ea3e30d90e4c33764b889db'),\n}\n\n#TODO: This can probably be coded more shortly and nicely.\nclass Config():\n    \"Creates a default config file 'config.yml' in $FASTAI_HOME (default `~/.fastai/`)\"\n    DEFAULT_CONFIG_LOCATION = os.path.expanduser(os.getenv('FASTAI_HOME', '~/.fastai'))\n    DEFAULT_CONFIG_PATH = DEFAULT_CONFIG_LOCATION + '/config.yml'\n    DEFAULT_CONFIG = {\n        'data_path': DEFAULT_CONFIG_LOCATION + '/data',\n        'data_archive_path': DEFAULT_CONFIG_LOCATION + '/data',\n        'model_path': DEFAULT_CONFIG_LOCATION + '/models'\n    }\n\n    @classmethod\n    def get_key(cls, key):\n        \"Get the path to `key` in the config file.\"\n        return cls.get().get(key, cls.DEFAULT_CONFIG.get(key,None))\n\n    @classmethod\n    def get_path(cls, path):\n        \"Get the `path` in the config file.\"\n        return _expand_path(cls.get_key(path))\n\n    @classmethod\n    def data_path(cls):\n        \"Get the path to data in the config file.\"\n        return cls.get_path('data_path')\n\n    @classmethod\n    def data_archive_path(cls):\n        \"Get the path to data archives in the config file.\"\n        return cls.get_path('data_archive_path')\n\n    @classmethod\n    def model_path(cls):\n        \"Get the path to fastai pretrained models in the config file.\"\n        return cls.get_path('model_path')\n\n    @classmethod\n    def get(cls, fpath=None, create_missing=True):\n        \"Retrieve the `Config` in `fpath`.\"\n        fpath = _expand_path(fpath or 
cls.DEFAULT_CONFIG_PATH)\n        if not fpath.exists() and create_missing: cls.create(fpath)\n        assert fpath.exists(), f'Could not find config at: {fpath}. Please create'\n        with open(fpath, 'r') as yaml_file: return yaml.safe_load(yaml_file)\n\n    @classmethod\n    def create(cls, fpath):\n        \"Creates a `Config` from `fpath`.\"\n        fpath = _expand_path(fpath)\n        assert(fpath.suffix == '.yml')\n        if fpath.exists(): return\n        fpath.parent.mkdir(parents=True, exist_ok=True)\n        with open(fpath, 'w') as yaml_file:\n            yaml.dump(cls.DEFAULT_CONFIG, yaml_file, default_flow_style=False)\n\ndef _expand_path(fpath): return Path(fpath).expanduser()\ndef url2name(url): return url.split('/')[-1]\n\n#TODO: simplify this mess\ndef url2path(url, data=True, ext:str='.tgz'):\n    \"Change `url` to a path.\"\n    name = url2name(url)\n    return datapath4file(name, ext=ext, archive=False) if data else modelpath4file(name, ext=ext)\ndef _url2tgz(url, data=True, ext:str='.tgz'):\n    return datapath4file(f'{url2name(url)}{ext}', ext=ext) if data else modelpath4file(f'{url2name(url)}{ext}', ext=ext)\n\ndef modelpath4file(filename, ext:str='.tgz'):\n    \"Return model path to `filename`, checking locally first then in the config file.\"\n    local_path = URLs.LOCAL_PATH/'models'/filename\n    if local_path.exists() or local_path.with_suffix(ext).exists(): return local_path\n    else: return Config.model_path()/filename\n\ndef datapath4file(filename, ext:str='.tgz', archive=True):\n    \"Return data path to `filename`, checking locally first then in the config file.\"\n    local_path = URLs.LOCAL_PATH/'data'/filename\n    if local_path.exists() or local_path.with_suffix(ext).exists(): return local_path\n    elif archive: return Config.data_archive_path() / filename\n    else: return Config.data_path() / filename\n\ndef download_data(url:str, fname:PathOrStr=None, data:bool=True, ext:str='.tgz') -> Path:\n    \"Download `url` to 
destination `fname`.\"\n    fname = Path(ifnone(fname, _url2tgz(url, data, ext=ext)))\n    os.makedirs(fname.parent, exist_ok=True)\n    if not fname.exists():\n        print(f'Downloading {url}')\n        download_url(f'{url}{ext}', fname)\n    return fname\n\ndef _check_file(fname):\n    size = os.path.getsize(fname)\n    with open(fname, \"rb\") as f:\n        hash_nb = hashlib.md5(f.read(2**20)).hexdigest()\n    return size,hash_nb\n\ndef untar_data(url:str, fname:PathOrStr=None, dest:PathOrStr=None, data=True, force_download=False) -> Path:\n    \"Download `url` to `fname` if `dest` doesn't exist, and un-tgz to folder `dest`.\"\n    dest = url2path(url, data) if dest is None else Path(dest)/url2name(url)\n    fname = Path(ifnone(fname, _url2tgz(url, data)))\n    if force_download or (fname.exists() and url in _checks and _check_file(fname) != _checks[url]):\n        print(f\"A new version of the {'dataset' if data else 'model'} is available.\")\n        if fname.exists(): os.remove(fname)\n        if dest.exists(): shutil.rmtree(dest)\n    if not dest.exists():\n        fname = download_data(url, fname=fname, data=data)\n        if url in _checks:\n            assert _check_file(fname) == _checks[url], f\"Downloaded file {fname} does not match checksum expected! Remove that file from {Config().data_archive_path()} and try your code again.\"\n        tarfile.open(fname, 'r:gz').extractall(dest.parent)\n    return dest\n"
  },
  {
    "path": "fastai/distributed.py",
    "content": "from .torch_core import *\nfrom .basic_train import Learner,LearnerCallback\nfrom torch.nn.parallel import DistributedDataParallel, DataParallel\nfrom torch.utils.data.distributed import DistributedSampler\n\nfrom fastai.text import TextLMDataBunch\n\n__all__ = ['DistributedRecorder', 'DistributedTrainer', 'read_metrics', 'setup_distrib']\n\ndef rnn_reset(self):\n    if hasattr(self.module, 'reset'): self.module.reset()\nDistributedDataParallel.reset = rnn_reset\n\nclass ParallelTrainer(LearnerCallback):\n    _order = -20\n    def on_train_begin(self, **kwargs): self.learn.model = DataParallel(self.learn.model)\n    def on_train_end  (self, **kwargs): self.learn.model = self.learn.model.module\n\nclass DistributedTrainer(LearnerCallback):\n    _order = -20 # Needs to run before the recorder\n    def __init__(self, learn:Learner, cuda_id:int=0):\n        super().__init__(learn)\n        self.cuda_id,self.train_sampler = cuda_id,None\n\n    def _change_dl(self, dl, shuffle):\n        old_dl = dl\n        sampler = OurDistributedSampler(dl.dataset, shuffle=shuffle)\n        new_dl = dl.new(shuffle=False, sampler=sampler)\n        return old_dl,new_dl,sampler\n\n    def on_train_begin(self, **kwargs):\n        self.learn.model = DistributedDataParallel(self.model, device_ids=[self.cuda_id], output_device=self.cuda_id)\n        shuffle = self.data.train_dl.init_kwargs['shuffle'] if hasattr(self.data.train_dl, 'init_kwargs') else True\n        self.old_train_dl,self.data.train_dl,self.train_sampler = self._change_dl(self.data.train_dl, shuffle)\n        if hasattr(self.data, 'valid_dl') and self.data.valid_dl is not None:\n            self.old_valid_dl,self.data.valid_dl,self.valid_sampler = self._change_dl(self.data.valid_dl, shuffle)\n        self.rank = rank_distrib()\n        self.recorder.silent = (self.rank != 0)\n\n    def on_epoch_begin(self, epoch, **kwargs): self.train_sampler.set_epoch(epoch)\n\n    def on_train_end(self, **kwargs):\n        
self.learn.model = self.learn.model.module\n        self.learn.data.train_dl = self.old_train_dl\n        if hasattr(self.learn.data, 'valid_dl') and self.learn.data.valid_dl is not None:\n            self.learn.data.valid_dl = self.old_valid_dl\n\nclass DistributedRecorder(LearnerCallback):\n    def __init__(self, learn:Learner, cuda_id:int=0, cache_dir:PathOrStr='tmp'):\n        super().__init__(learn)\n        self.cuda_id,self.cache_dir = cuda_id,cache_dir\n\n    def on_train_begin(self, **kwargs):\n        os.makedirs(self.learn.path/self.cache_dir, exist_ok=True)\n\n    def on_epoch_end(self, **kwargs): self.save_stats()\n    def on_train_end(self, **kwargs): self.save_stats()\n\n    def save_stats(self):\n        cache_path,recorder = self.learn.path/self.cache_dir,self.learn.recorder\n        np.save(cache_path/f'losses_{self.cuda_id}', np.array(recorder.losses))\n        stats = np.array([[v] + m for v,m in zip(recorder.val_losses,recorder.metrics)])\n        np.save(cache_path/f'metrics_{self.cuda_id}', stats)\n\ndef _learner_parallel(learn:Learner):\n    \"Use nn.DataParallel when training and remove when done\"\n    if not torch.cuda.is_available(): warnings.warn('CUDA is not available, check your drivers - training will continue on CPU', ResourceWarning) \n    learn.callbacks.append(ParallelTrainer(learn))\n    return learn\n\ndef _learner_distributed(learn:Learner, cuda_id:int, cache_dir:PathOrStr='tmp'):\n    \"Put `learn` on distributed training with `cuda_id`.\"\n    learn.callbacks.append(DistributedTrainer(learn, cuda_id))\n    learn.callbacks.append(DistributedRecorder(learn, cuda_id, cache_dir))\n    return learn\n\nLearner.to_distributed = _learner_distributed\nLearner.to_parallel = _learner_parallel\n\ndef read_metrics(cache_path:PathOrStr, n_gpus:int, reduce:bool=True):\n    losses,metrics = [],[]\n    for i in range(n_gpus):\n        losses.append(np.load(cache_path/f'losses_{i}.npy')[None])\n        
metrics.append(np.load(cache_path/f'metrics_{i}.npy')[None])\n    if reduce:\n        losses,metrics = np.concatenate(losses,0),np.concatenate(metrics,0)\n        return losses.mean(0),metrics.mean(0)\n    return losses,metrics\n\ndef setup_distrib(gpu:Any=None):\n    if gpu is None: return gpu\n    gpu = int(gpu)\n    torch.cuda.set_device(int(gpu))\n    if num_distrib() > 1:\n        torch.distributed.init_process_group(backend='nccl', init_method='env://')\n    return gpu\n\nclass OurDistributedSampler(DistributedSampler):\n    \"A sampler for language models with the option to not shuffle.\"\n    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):\n            super().__init__(dataset, num_replicas=num_replicas, rank=rank)\n            self.shuffle = shuffle\n    \n    def __iter__(self):\n        if self.shuffle:\n            g = torch.Generator()\n            g.manual_seed(self.epoch)\n            indices = torch.randperm(len(self.dataset), generator=g).tolist()\n        else: indices = torch.arange(len(self.dataset)).tolist()\n\n        # add extra samples to make it evenly divisible\n        indices += indices[:(self.total_size - len(indices))]\n        assert len(indices) == self.total_size\n\n        # subsample\n        indices = indices[self.rank:self.total_size:self.num_replicas]\n        assert len(indices) == self.num_samples\n\n        return iter(indices)\n"
  },
  {
    "path": "fastai/gen_doc/__init__.py",
    "content": "from . import gen_notebooks, nbdoc, core, doctest, nbtest\n"
  },
  {
    "path": "fastai/gen_doc/autogen.tpl",
    "content": "<!--\n\n\n#################################################\n### THIS FILE WAS AUTOGENERATED! DO NOT EDIT! ###\n#################################################\n# file to edit: {{ resources.nb_path }}\n# instructions: https://docs.fast.ai/gen_doc_main.html\n\n-->\n"
  },
  {
    "path": "fastai/gen_doc/convert2html.py",
    "content": "import os.path, re, nbformat, jupyter_contrib_nbextensions\nfrom nbconvert.preprocessors import Preprocessor\nfrom nbconvert import HTMLExporter\nfrom traitlets.config import Config\nfrom pathlib import Path\n\n__all__ = ['read_nb', 'convert_nb', 'convert_all']\n\nexporter = HTMLExporter(Config())\nexporter.exclude_input_prompt=True\nexporter.exclude_output_prompt=True\n#Loads the template to deal with hidden cells.\nexporter.template_file = 'jekyll.tpl'\npath = Path(__file__).parent\nexporter.template_path.append(str(path))\n\ndef read_nb(fname):\n    \"Read the notebook in `fname`.\"\n    with open(fname,'r') as f: return nbformat.reads(f.read(), as_version=4)\n\ndef convert_nb(fname, dest_path='.'):\n    \"Convert a notebook `fname` to html file in `dest_path`.\"\n    from .gen_notebooks import remove_undoc_cells, remove_code_cell_jupyter_widget_state_elem\n    nb = read_nb(fname)\n    nb['cells'] = remove_undoc_cells(nb['cells'])\n    nb['cells'] = remove_code_cell_jupyter_widget_state_elem(nb['cells'])\n    fname = Path(fname).absolute()\n    dest_name = fname.with_suffix('.html').name\n    meta = nb['metadata']\n    meta_jekyll = meta['jekyll'] if 'jekyll' in meta else {'title': fname.with_suffix('').name}\n    meta_jekyll['nb_path'] = f'{fname.parent.name}/{fname.name}'\n    with open(f'{dest_path}/{dest_name}','w') as f:\n        f.write(exporter.from_notebook_node(nb, resources=meta_jekyll)[0])\n\ndef convert_all(folder, dest_path='.', force_all=False):\n    \"Convert modified notebooks in `folder` to html pages in `dest_path`.\"\n    path = Path(folder)\n\n    changed_cnt = 0\n    for fname in path.glob(\"*.ipynb\"):\n        # only rebuild modified files\n        fname_out = Path(dest_path)/fname.with_suffix('.html').name\n        if not force_all and fname_out.exists():\n            in_mod  = os.path.getmtime(fname)\n            out_mod = os.path.getmtime(fname_out)\n            if in_mod < out_mod: continue\n\n        
print(f\"converting: {fname} => {fname_out}\")\n        changed_cnt += 1\n        convert_nb(fname, dest_path=dest_path)\n    if not changed_cnt: print(\"No notebooks were modified\")\n"
  },
  {
    "path": "fastai/gen_doc/core.py",
    "content": "from ..core import *\nimport re\n\ndef strip_fastai(s):  return re.sub(r'^fastai\\.', '', s)\n\n"
  },
  {
    "path": "fastai/gen_doc/docstrings.py",
    "content": "# https://github.com/openstack/rally/blob/master/rally/common/plugin/info.py\n# Copyright 2015: Mirantis Inc.\n# All Rights Reserved.\n#\n#    Licensed under the Apache License, Version 2.0 (the \"License\"); you may\n#    not use this file except in compliance with the License. You may obtain\n#    a copy of the License at\n#\n#         http://www.apache.org/licenses/LICENSE-2.0\n#\n#    Unless required by applicable law or agreed to in writing, software\n#    distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT\n#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the\n#    License for the specific language governing permissions and limitations\n#    under the License.\n\nimport re\nimport sys\n\n__all__ = ['parse_docstring']\n\n\nFIELDS = 'param|val' # supported fields\nPARAM_OR_RETURN_REGEX = re.compile(f\":(?:{FIELDS}|return)\")\nRETURN_REGEX = re.compile(\":return: (?P<doc>.*)\", re.S)\nNEW_REGEX = re.compile(f\":(?P<field>{FIELDS}) (?P<name>[\\*\\w]+): (?P<doc>.*?)\"\n                         f\"(?:(?=:(?:{FIELDS}|return|raises))|\\Z)\", re.S)\n\ndef trim(docstring):\n    \"\"\"trim function from PEP-257\"\"\"\n    if not docstring:\n        return \"\"\n    # Convert tabs to spaces (following the normal Python rules)\n    # and split into a list of lines:\n    lines = docstring.expandtabs().splitlines()\n    # Determine minimum indentation (first line doesn't count):\n    indent = sys.maxsize\n    for line in lines[1:]:\n        stripped = line.lstrip()\n        if stripped:\n            indent = min(indent, len(line) - len(stripped))\n    # Remove indentation (first line is special):\n    trimmed = [lines[0].strip()]\n    if indent < sys.maxsize:\n        for line in lines[1:]:\n            trimmed.append(line[indent:].rstrip())\n    # Strip off trailing and leading blank lines:\n    while trimmed and not trimmed[-1]:\n        trimmed.pop()\n    while trimmed and not trimmed[0]:\n        
trimmed.pop(0)\n\n    # Current code/unittests expects a line return at\n    # end of multiline docstrings\n    # workaround expected behavior from unittests\n    if \"\\n\" in docstring:\n        trimmed.append(\"\")\n\n    # Return a single string:\n    return \"\\n\".join(trimmed)\n\n\ndef reindent(string):\n    return \"\\n\".join(l.strip() for l in string.strip().split(\"\\n\"))\n\n\ndef parse_docstring(docstring):\n    \"\"\"Parse the docstring into its components.\n\n    :return: a dictionary of form\n              {\n                  \"short_description\": ...,\n                  \"long_description\": ...,\n                  \"params\": [{\"name\": ..., \"doc\": ...}, ...],\n                  \"vals\": [{\"name\": ..., \"doc\": ...}, ...],\n                  \"return\": ...\n              }\n    \"\"\"\n\n    short_description = long_description = return_str = \"\"\n    args = []\n\n    if docstring:\n        docstring = trim(docstring.lstrip(\"\\n\"))\n\n        lines = docstring.split(\"\\n\", 1)\n        short_description = lines[0]\n\n        if len(lines) > 1:\n            long_description = lines[1].strip()\n\n            params_return_desc = None\n\n            match = PARAM_OR_RETURN_REGEX.search(long_description)\n            if match:\n                long_desc_end = match.start()\n                params_return_desc = long_description[long_desc_end:].strip()\n                long_description = long_description[:long_desc_end].rstrip()\n\n            if params_return_desc:\n                args = [\n                    {\"name\": name, \"doc\": trim(doc), \"field\": field}\n                    for field, name, doc in NEW_REGEX.findall(params_return_desc)\n                ]\n                match = RETURN_REGEX.search(params_return_desc)\n                if match:\n                    return_str = reindent(match.group(\"doc\"))\n    comments = {p['name']: p['doc'] for p in args}\n    return {\n        \"short_description\": short_description,\n     
   \"long_description\": long_description,\n        \"args\": args,\n        \"comments\": comments,\n        \"return\": return_str\n    }\n\n\nclass InfoMixin(object):\n\n    @classmethod\n    def _get_doc(cls):\n        \"\"\"Return documentary of class\n\n        By default it returns docstring of class, but it can be overridden\n        for example for cases like merging own docstring with parent\n        \"\"\"\n        return cls.__doc__\n\n    @classmethod\n    def get_info(cls):\n        doc = parse_docstring(cls._get_doc())\n\n        return {\n            \"name\": cls.get_name(),\n            \"platform\": cls.get_platform(),\n            \"module\": cls.__module__,\n            \"title\": doc[\"short_description\"],\n            \"description\": doc[\"long_description\"],\n            \"parameters\": doc[\"params\"],\n            \"schema\": getattr(cls, \"CONFIG_SCHEMA\", None),\n            \"return\": doc[\"return\"]\n        }\n"
  },
  {
    "path": "fastai/gen_doc/doctest.py",
    "content": "import sys, re, json, pprint\nfrom pathlib import Path\nfrom collections import defaultdict\nfrom inspect import currentframe, getframeinfo, ismodule\n\n__all__ = ['this_tests']\n\nDB_NAME = 'test_registry.json'\n\ndef _json_set_default(obj):\n    if isinstance(obj, set): return list(obj)\n    raise TypeError\n\nclass TestRegistry:\n    \"Tests register which API they validate using this class.\"\n    registry = defaultdict(list)\n    this_tests_check = None\n    missing_this_tests = set()\n\n    # logic for checking whether each test calls `this_tests`:\n    # 1. `this_tests_check` is set to True during test's 'setup' stage if it wasn't skipped\n    # 2. if the test is dynamically skipped `this_tests_check` is set to False\n    # 3. `this_tests` sets this flag to False when it's successfully completes\n    # 4. if during the 'teardown' stage `this_tests_check` is still True then we\n    # know that this test needs `this_tests_check`\n\n    @staticmethod\n    def this_tests(*funcs):\n        prev_frame = currentframe().f_back.f_back\n        file_name, lineno, test_name, _, _ = getframeinfo(prev_frame)\n        parent_func_lineno, _ = get_parent_func(lineno, get_lines(file_name))\n        entry = {'file': relative_test_path(file_name), 'test': test_name , 'line': parent_func_lineno}\n        for func in funcs:\n            if func == 'na':\n                # special case when we can't find a function to declare, e.g.\n                # when attributes are tested\n                continue\n            try:\n                func_fq = get_func_fq_name(func)\n            except:\n                raise Exception(f\"'{func}' is not a function\") from None\n            if re.match(r'fastai\\.', func_fq):\n                if entry not in TestRegistry.registry[func_fq]:\n                    TestRegistry.registry[func_fq].append(entry)\n            else:\n                raise Exception(f\"'{func}' is not in the fastai API\") from None\n        
TestRegistry.this_tests_check = False\n\n    def this_tests_check_on():\n        TestRegistry.this_tests_check = True\n\n    def this_tests_check_off():\n        TestRegistry.this_tests_check = False\n\n    def this_tests_check_run(file_name, test_name):\n        if TestRegistry.this_tests_check:\n            TestRegistry.missing_this_tests.add(f\"{file_name}::{test_name}\")\n\n    def registry_save():\n        if TestRegistry.registry:\n            path = Path(__file__).parent.parent.resolve()/DB_NAME\n            if path.exists():\n                #print(\"\\n*** Merging with the existing test registry\")\n                with open(path, 'r') as f: old_registry = json.load(f)\n                TestRegistry.registry = merge_registries(old_registry, TestRegistry.registry)\n            #print(f\"\\n*** Saving test registry @ {path}\")\n            with open(path, 'w') as f:\n                json.dump(obj=TestRegistry.registry, fp=f, indent=4, sort_keys=True, default=_json_set_default)\n\n    def missing_this_tests_alert():\n        if TestRegistry.missing_this_tests:\n            tests = '\\n  '.join(sorted(TestRegistry.missing_this_tests))\n            print(f\"\"\"\n*** Attention ***\nPlease include `this_tests` call in each of the following tests:\n  {tests}\nFor details see: https://docs.fast.ai/dev/test.html#test-registry\"\"\")\n\n# merge_registries helpers\n# merge dict of lists of dict\ndef a2k(a): return '::'.join([a['file'], a['test']]), a['line']\ndef k2a(k, v): f,t = k.split('::'); return {\"file\": f, \"line\": v, \"test\": t}\n# merge by key that is a combination of 2 values: test, file\ndef merge_lists(a, b):\n    x = dict(map(a2k, [*a, *b]))            # pack + merge\n    return [k2a(k, v) for k,v in x.items()] # unpack\ndef merge_registries(a, b):\n    for i in b: a[i] = merge_lists(a[i], b[i]) if i in a else b[i]\n    return a\n\ndef this_tests(*funcs): TestRegistry.this_tests(*funcs)\n\ndef str2func(name):\n    \"Converts 'fastai.foo.bar' into an 
function 'object' if such exists\"\n    if isinstance(name, str): subpaths = name.split('.')\n    else:                     return None\n\n    module = subpaths.pop(0)\n    if module in sys.modules: obj = sys.modules[module]\n    else:                     return None\n\n    for subpath in subpaths:\n        obj = getattr(obj, subpath, None)\n        if obj == None: return None\n    return obj\n\ndef get_func_fq_name(func):\n    if ismodule(func): return func.__name__\n    if isinstance(func, str): func = str2func(func)\n    name = None\n    if   hasattr(func, '__qualname__'): name = func.__qualname__\n    elif hasattr(func, '__name__'):     name = func.__name__\n    elif hasattr(func, '__wrapped__'):  return get_func_fq_name(func.__wrapped__)\n    elif hasattr(func, '__class__'):    name = func.__class__.__name__\n    else: raise Exception(f\"'{func}' is not a func or class\")\n    return f'{func.__module__}.{name}'\n\ndef get_parent_func(lineno, lines, ignore_missing=False):\n    \"Find any lines where `elt` is called and return the parent test function\"\n    for idx,l in enumerate(reversed(lines[:lineno])):\n        if re.match(f'\\s*def test', l):  return (lineno - idx), l # 1 based index for github\n        if re.match(f'\\w+', l):  break # top level indent - out of function scope\n    if ignore_missing: return None\n    raise LookupError('Could not find parent function for line:', lineno, lines[:lineno])\n\ndef relative_test_path(test_file:Path)->str:\n    \"Path relative to the `fastai` parent directory\"\n    test_file = Path(test_file)\n    testdir_idx = list(reversed(test_file.parts)).index('tests')\n    return '/'.join(test_file.parts[-(testdir_idx+1):])\n\ndef get_lines(file):\n    with open(file, 'r') as f: return f.readlines()\n"
  },
  {
    "path": "fastai/gen_doc/gen_notebooks.py",
    "content": "\"`gen_doc.nbdoc` generates notebook documentation from module functions and links to correct places\"\nimport pkgutil, inspect, sys,os, importlib,json,enum,warnings,nbformat,re\nfrom IPython.core.display import display, Markdown\nfrom nbconvert.preprocessors import ExecutePreprocessor\nfrom nbformat.sign import NotebookNotary\nfrom pathlib import Path\nfrom .core import *\nfrom .nbdoc import *\n\n__all__ = ['create_module_page', 'update_module_page', 'import_mod',\n           'link_nb', 'update_notebooks', 'generate_missing_metadata', 'update_nb_metadata']\n\ndef get_empty_notebook():\n    \"Default notbook with the minimum metadata.\"\n    #TODO: check python version and nbformat\n    return {'metadata': {'kernelspec': {'display_name': 'Python 3',\n                                        'language': 'python',\n                                        'name': 'python3'},\n                         'language_info': {'codemirror_mode': {'name': 'ipython', 'version': 3},\n                         'file_extension': '.py',\n                         'mimetype': 'text/x-python',\n                         'name': 'python',\n                         'nbconvert_exporter': 'python',\n                         'pygments_lexer': 'ipython3',\n                         'version': '3.6.6'}},\n            'nbformat': 4,\n            'nbformat_minor': 2}\n\ndef get_md_cell(source, metadata=None):\n    \"Markdown cell containing `source` with `metadata`.\"\n    return {'cell_type': 'markdown',\n            'metadata': {} if metadata is None else metadata,\n            'source': source}\n\ndef get_empty_cell(ctype='markdown'):\n    \"Empty cell of type `ctype`.\"\n    return {'cell_type': ctype, 'metadata': {}, 'source': []}\n\ndef get_code_cell(code, hidden=False):\n    \"Code cell containing `code` that may be `hidden`.\"\n    return {'cell_type' : 'code',\n            'execution_count': 0,\n            'metadata' : {'hide_input': hidden, 'trusted':True},\n            
'source' : code,\n            'outputs': []}\n\ndef get_doc_cell(func_name):\n    \"Code cell with the command to show the doc of `func_name`.\"\n    code = f\"show_doc({func_name})\"\n    return get_code_cell(code, True)\n\ndef get_global_vars(mod):\n    \"Return globally assigned variables.\"\n    # https://stackoverflow.com/questions/8820276/docstring-for-variable/31764368#31764368\n    import ast,re\n    with open(mod.__file__, 'r') as f: fstr = f.read()\n    flines = fstr.splitlines()\n    d = {}\n    for node in ast.walk(ast.parse(fstr)):\n        if isinstance(node,ast.Assign) and hasattr(node.targets[0], 'id'):\n            key,lineno = node.targets[0].id,node.targets[0].lineno\n            codestr = flines[lineno]\n            match = re.match(f\"^({key})\\s*=\\s*.*\", codestr)\n            if match and match.group(1) != '__all__': # only top level assignment\n                d[key] = f'`{codestr}` {get_source_link(mod, lineno)}'\n    return d\n\ndef write_nb(nb, nb_path, mode='w'):\n    with open(nb_path, mode) as f: f.write(nbformat.writes(nbformat.from_dict(nb), version=4))\n\nclass ExecuteShowDocPreprocessor(ExecutePreprocessor):\n    \"An ExecutePreprocessor that only executes show_doc cells\"\n    def preprocess_cell(self, cell, resources, index):\n        if 'source' in cell and cell.cell_type == \"code\":\n            if IMPORT_RE.search(cell['source']) or SHOW_DOC_RE.search(cell['source']):\n                return super().preprocess_cell(cell, resources, index)\n        return cell, resources\n\ndef execute_nb(fname, metadata=None, save=True, show_doc_only=False):\n    \"Execute notebook `fname` with `metadata` for preprocessing.\"\n    # Any module used in the notebook that isn't inside must be in the same directory as this script\n    with open(fname) as f: nb = nbformat.read(f, as_version=4)\n    ep_class = ExecuteShowDocPreprocessor if show_doc_only else ExecutePreprocessor\n    ep = ep_class(timeout=600, kernel_name='python3')\n    metadata = 
metadata or {}\n    ep.preprocess(nb, metadata)\n    if save:\n        with open(fname, 'wt') as f: nbformat.write(nb, f)\n        NotebookNotary().sign(nb)\n\ndef _symbol_skeleton(name): return [get_doc_cell(name), get_md_cell(f\"`{name}`\")]\n\ndef create_module_page(mod, dest_path, force=False):\n    \"Create the documentation notebook for module `mod_name` in path `dest_path`\"\n    nb = get_empty_notebook()\n    mod_name = mod.__name__\n    strip_name = strip_fastai(mod_name)\n    init_cell = [get_md_cell(f'## Title for {strip_name} (use plain english, not module name!)'), get_md_cell('Type an introduction of the package here.')]\n    cells = [get_code_cell(f'from fastai.gen_doc.nbdoc import *\\nfrom {mod_name} import * ', True)]\n\n    gvar_map = get_global_vars(mod)\n    if gvar_map: cells.append(get_md_cell('### Global Variable Definitions:'))\n    for name in get_exports(mod):\n        if name in gvar_map: cells.append(get_md_cell(gvar_map[name]))\n\n    for ft_name in get_ft_names(mod, include_inner=True):\n        if not hasattr(mod, ft_name):\n            warnings.warn(f\"Module {strip_name} doesn't have a function named {ft_name}.\")\n            continue\n        cells += _symbol_skeleton(ft_name)\n        elt = getattr(mod, ft_name)\n    nb['cells'] = init_cell + cells + [get_md_cell(UNDOC_HEADER)]\n\n    doc_path = get_doc_path(mod, dest_path)\n    write_nb(nb, doc_path, 'w' if force else 'x')\n    execute_nb(doc_path)\n    return doc_path\n\n_default_exclude = ['.ipynb_checkpoints', '__pycache__', '__init__.py', 'imports']\n\ndef get_module_names(path_dir, exclude=None):\n    if exclude is None: exclude = _default_exclude\n    \"Search a given `path_dir` and return all the modules contained inside except those in `exclude`\"\n    files = sorted(path_dir.glob('*'), key=lambda x: (x.is_dir(), x.name), reverse=True) # directories first\n    res = [f'{path_dir.name}']\n    for f in files:\n        if f.is_dir() and f.name in exclude: continue # exclude 
directories\n        if any([f.name.endswith(ex) for ex in exclude]): continue # exclude extensions\n\n        if f.suffix == '.py': res.append(f'{path_dir.name}.{f.stem}')\n        elif f.is_dir(): res += [f'{path_dir.name}.{name}' for name in get_module_names(f)]\n    return res\n\ndef read_nb(fname):\n    \"Read a notebook in `fname` and return its corresponding json\"\n    with open(fname,'r') as f: return nbformat.reads(f.read(), as_version=4)\n\nSHOW_DOC_RE = re.compile(r\"show_doc\\(([\\w\\.]*)\")\ndef read_nb_content(cells, mod_name):\n    \"Build a dictionary containing the position of the `cells`.\"\n    doc_fns = {}\n    for i, cell in enumerate(cells):\n        if cell['cell_type'] == 'code':\n            for match in SHOW_DOC_RE.findall(cell['source']):\n                doc_fns[match] = i\n    return doc_fns\n\ndef read_nb_types(cells):\n    doc_fns = {}\n    for i, cell in enumerate(cells):\n        if cell['cell_type'] == 'markdown':\n            match = re.match(r\"^(?:<code>|`)?(\\w*)\\s*=\\s*\", cell['source'])\n            if match is not None: doc_fns[match.group(1)] = i\n    return doc_fns\n\ndef link_markdown_cells(cells, modules):\n    \"Create documentation links for all cells in markdown with backticks.\"\n    for i, cell in enumerate(cells):\n        if cell['cell_type'] == 'markdown':\n            cell['source'] = link_docstring(modules, cell['source'])\n\ndef get_insert_idx(pos_dict, name):\n    \"Return the position to insert a given function doc in a notebook.\"\n    keys,i = list(pos_dict.keys()),0\n    while i < len(keys) and str.lower(keys[i]) < str.lower(name): i+=1\n    if i == len(keys): return -1\n    else:              return pos_dict[keys[i]]\n\ndef update_pos(pos_dict, start_key, nbr=2):\n    \"Update the `pos_dict` by moving all positions after `start_key` by `nbr`.\"\n    for key,idx in pos_dict.items():\n        if str.lower(key) >= str.lower(start_key): pos_dict[key] += nbr\n    return pos_dict\n\ndef insert_cells(cells, 
pos_dict, ft_name, append=False):\n    \"Insert the function doc `cells` at their correct position and updates `pos_dict`.\"\n    idx = get_insert_idx(pos_dict, ft_name)\n    if append or idx == -1: cells += [get_doc_cell(ft_name), get_empty_cell()]\n    else:\n        cells.insert(idx, get_doc_cell(ft_name))\n        cells.insert(idx+1, get_empty_cell())\n        pos_dict = update_pos(pos_dict, ft_name, 2)\n    return cells, pos_dict\n\ndef get_doc_path(mod, dest_path):\n    strip_name = strip_fastai(mod.__name__)\n    return os.path.join(dest_path,f'{strip_name}.ipynb')\n\ndef generate_missing_metadata(dest_file):\n    fn = Path(dest_file)\n    meta_fn = fn.parent/'jekyll_metadata.ipynb'\n    if not fn.exists() or not meta_fn.exists(): return print('Could not find notebooks:', fn, meta_fn)\n    metadata_nb = read_nb(meta_fn)\n\n    if has_metadata_cell(metadata_nb['cells'], fn.name): return\n    nb = read_nb(fn)\n    jmd = nb['metadata'].get('jekyll', {})\n    fmt_params = ''\n    for k,v in jmd.items(): fmt_params += f',\\n    {k}={stringify(v)}'\n    metadata_cell = get_code_cell(f\"update_nb_metadata('{Path(fn).name}'{fmt_params})\", hidden=False)\n    metadata_nb['cells'].append(metadata_cell)\n    write_nb(metadata_nb, meta_fn)\n\ndef update_nb_metadata(nb_path=None, title=None, summary=None, keywords='fastai', overwrite=True, **kwargs):\n    \"Creates jekyll metadata for given notebook path.\"\n    nb = read_nb(nb_path)\n    data = {'title': title, 'summary': summary, 'keywords': keywords, **kwargs}\n    data = {k:v for (k,v) in data.items() if v is not None} # remove none values\n    if not data: return\n    nb['metadata']['jekyll'] = data\n    write_nb(nb, nb_path)\n    NotebookNotary().sign(nb)\n\ndef has_metadata_cell(cells, fn):\n    for c in cells:\n        if re.search(f\"update_nb_metadata\\('{fn}'\", c['source']): return c\n\ndef stringify(s): return f'\\'{s}\\'' if isinstance(s, str) else s\n\nIMPORT_RE = re.compile(r\"from 
(fastai[\\.\\w_]*)\")\ndef get_imported_modules(cells, nb_module_name=''):\n    \"Finds all submodules of notebook - sorted by submodules > top level modules > manual imports. This gives notebook imports priority\"\n    module_names = get_top_level_modules()\n    nb_imports = [match.group(1) for cell in cells for match in IMPORT_RE.finditer(cell['source']) if cell['cell_type'] == 'code']\n    parts = nb_module_name.split('.')\n    parent_modules = ['.'.join(parts[:(x+1)]) for x in range_of(parts)] # Imports parent modules - a.b.c = [a, a.b, a.b.c]\n    all_modules = module_names + nb_imports + parent_modules\n    mods = [import_mod(m, ignore_errors=True) for m in all_modules]\n    return [m for m in mods if m is not None]\n\ndef get_top_level_modules(num_levels=1):\n    mod_dir = Path(import_mod('fastai').__file__).parent\n    filtered_n = filter(lambda x: x.count('.')<=num_levels, get_module_names(mod_dir))\n    return sorted(filtered_n, key=lambda s: s.count('.'), reverse=True) # Submodules first (sorted by periods)\n\nNEW_FT_HEADER = '## New Methods - Please document or move to the undocumented section'\nUNDOC_HEADER = '## Undocumented Methods - Methods moved below this line will intentionally be hidden'\ndef parse_sections(cells):\n    old_cells, undoc_cells, new_cells = [], [], []\n    current_section = old_cells\n    for cell in cells:\n        if cell['cell_type'] == 'markdown':\n            if re.match(UNDOC_HEADER, cell['source']): current_section = undoc_cells\n            if re.match(NEW_FT_HEADER, cell['source']): current_section = new_cells\n        current_section.append(cell)\n    undoc_cells = undoc_cells or [get_md_cell(UNDOC_HEADER)]\n    new_cells = new_cells or [get_md_cell(NEW_FT_HEADER)]\n    return old_cells, undoc_cells, new_cells\n\ndef remove_undoc_cells(cells):\n    old, _, _ = parse_sections(cells)\n    return old\n\n# currently code vbox sub-cells mainly\ndef remove_code_cell_jupyter_widget_state_elem(cells):\n    for c in cells:\n      
  if c['cell_type'] == 'code':\n            if 'outputs' in c:\n                c['outputs'] = [l for l in c['outputs'] if not ('data' in l and 'application/vnd.jupyter.widget-view+json' in l.data)]\n    return cells\n\ndef update_module_page(mod, dest_path='.'):\n    \"Update the documentation notebook of a given module.\"\n    doc_path = get_doc_path(mod, dest_path)\n    strip_name = strip_fastai(mod.__name__)\n    nb = read_nb(doc_path)\n    cells = nb['cells']\n\n    link_markdown_cells(cells, get_imported_modules(cells, mod.__name__))\n\n    type_dict = read_nb_types(cells)\n    gvar_map = get_global_vars(mod)\n    for name in get_exports(mod):\n        if name not in gvar_map: continue\n        code = gvar_map[name]\n        if name in type_dict: cells[type_dict[name]] = get_md_cell(code)\n        else: cells.append(get_md_cell(code))\n\n    pos_dict = read_nb_content(cells, strip_name)\n    ft_names = get_ft_names(mod, include_inner=True)\n    new_fts = list(set(ft_names) - set(pos_dict.keys()))\n    if new_fts: print(f'Found new fuctions for {mod}. Please document:\\n{new_fts}')\n    existing, undoc_cells, new_cells = parse_sections(cells)\n    for ft_name in new_fts: new_cells.extend([get_doc_cell(ft_name), get_empty_cell()])\n    if len(new_cells) > 1: nb['cells'] = existing + undoc_cells + new_cells\n\n    write_nb(nb, doc_path)\n    return doc_path\n\ndef link_nb(nb_path):\n    nb = read_nb(nb_path)\n    cells = nb['cells']\n    link_markdown_cells(cells, get_imported_modules(cells, Path(nb_path).stem))\n    write_nb(nb, nb_path)\n    NotebookNotary().sign(read_nb(nb_path))\n\ndef get_module_from_notebook(doc_path):\n    \"Find module given a source path. 
Assume it belongs to fastai directory\"\n    return f'fastai.{Path(doc_path).stem}'\n\ndef check_nbconvert_version():\n    import nbconvert\n    assert nbconvert.version_info >= (5,4,0), \"Please update nbconvert to >=5.4 for consistent .html output\"\n\ndef update_notebooks(source_path, dest_path=None, update_html=True, document_new_fns=False,\n                     update_nb_links=True, html_path=None, force=False):\n    \"`source_path` can be a directory or a file. Assume all modules reside in the fastai directory.\"\n    from .convert2html import convert_nb\n    source_path = Path(source_path)\n\n    if source_path.is_file():\n        dest_path = source_path.parent if dest_path is None else Path(dest_path)\n        html_path = dest_path/'..'/'docs' if html_path is None else Path(html_path)\n        doc_path = source_path\n        assert source_path.suffix == '.ipynb', 'Must update from notebook or module'\n        if document_new_fns:\n            mod = import_mod(get_module_from_notebook(source_path))\n            if not mod: print('Could not find module for path:', source_path)\n            elif mod.__file__.endswith('__init__.py'): pass\n            else: update_module_page(mod, dest_path)\n        generate_missing_metadata(doc_path)\n        if update_nb_links:\n            print(f'Updating notebook {doc_path}. 
Please wait...')\n            link_nb(doc_path)\n            execute_nb(doc_path, {'metadata': {'path': doc_path.parent}}, show_doc_only=True)\n        if update_html:\n            check_nbconvert_version()\n            html_fn = html_path/doc_path.with_suffix('.html').name\n            if not force and html_fn.is_file():\n                in_mod  = os.path.getmtime(doc_path)\n                out_mod = os.path.getmtime(html_fn)\n                if in_mod < out_mod: return\n            convert_nb(doc_path, html_path)\n\n    elif (source_path.name.startswith('fastai.')):\n        # Do module update\n        assert dest_path is not None, 'To update a module, you must specify a destination folder for where notebook resides'\n        mod = import_mod(source_path.name)\n        if not mod: return print('Could not find module for:', source_path)\n        doc_path = Path(dest_path)/(strip_fastai(mod.__name__)+'.ipynb')\n        if not doc_path.exists():\n            print('Notebook does not exist. Creating:', doc_path)\n            create_module_page(mod, dest_path)\n        update_notebooks(doc_path, dest_path=dest_path, update_html=update_html, document_new_fns=document_new_fns,\n                         update_nb_links=update_nb_links, html_path=html_path)\n    elif source_path.is_dir():\n        for f in sorted(Path(source_path).glob('*.ipynb')):\n            update_notebooks(f, dest_path=dest_path, update_html=update_html, document_new_fns=document_new_fns,\n                             update_nb_links=update_nb_links, html_path=html_path)\n    else: print('Could not resolve source file:', source_path)\n"
  },
  {
    "path": "fastai/gen_doc/hide.tpl",
    "content": "{%- extends 'basic.tpl' -%}\n\n{% block input_group -%}\n{%- if cell.metadata.hide_input or nb.metadata.hide_input -%}\n{%- else -%}\n    {{ super()  }}\n{%- endif -%}\n{% endblock input_group %}\n\n{% block output_group -%}\n{%- if cell.metadata.hide_output -%}\n{%- else -%}\n    {{ super()  }}\n{%- endif -%}\n{% endblock output_group %}\n\n{% block output_area_prompt %}\n{%- if cell.metadata.hide_input or nb.metadata.hide_input -%}\n   <div class=\"prompt\"> </div>\n{%- else -%}\n    {{ super()  }}\n{%- endif -%}\n{% endblock output_area_prompt %}\n"
  },
  {
    "path": "fastai/gen_doc/jekyll.tpl",
    "content": "{%- extends 'hide.tpl' -%}{% block body %}---\n{% if resources.toc != \"\" and resources.toc != nil %}toc: {{resources.toc}}{% endif %}\n{% if resources.title != \"\" and resources.title != nil %}title: {{resources.title}}{% endif %}\nkeywords: {{resources.keywords}}\nsidebar: home_sidebar\n{% if resources.tags != \"\" and resources.tags != nil %}tags: {{resources.tags}}{% endif %}\n{% if resources.summary != \"\" and resources.summary != nil %}summary: \"{{resources.summary}}\"{% endif %}\n---\n{% include 'autogen.tpl' %}\n\n<div class=\"container\" id=\"notebook-container\">\n    {{ super()  }}\n</div>\n{%- endblock body %}\n"
  },
  {
    "path": "fastai/gen_doc/nbdoc.py",
    "content": "\"`gen_doc.nbdoc` generates notebook documentation from module functions and links to correct places\"\n\nimport inspect,importlib,enum,os,re,nbconvert\nfrom IPython.core.display import display, Markdown, HTML\nfrom nbconvert import HTMLExporter\nfrom IPython.core import page\nfrom IPython import get_ipython\nfrom typing import Dict, Any, AnyStr, List, Sequence, TypeVar, Tuple, Optional, Union\nfrom .docstrings import *\nfrom .core import *\nfrom ..torch_core import *\nfrom .nbtest import get_pytest_html\nfrom ..utils.ipython import IS_IN_COLAB\n\n__all__ = ['get_fn_link', 'link_docstring', 'show_doc', 'get_ft_names', 'md2html',\n           'get_exports', 'show_video', 'show_video_from_youtube', 'import_mod', 'get_source_link',\n           'is_enum', 'jekyll_note', 'jekyll_warn', 'jekyll_important', 'doc']\n\nMODULE_NAME = 'fastai'\nSOURCE_URL = 'https://github.com/fastai/fastai/blob/master/'\nPYTORCH_DOCS = 'https://pytorch.org/docs/stable/'\nFASTAI_DOCS = 'https://docs.fast.ai'\nuse_relative_links = True\n\n_typing_names = {t:n for t,n in fastai_types.items() if t.__module__=='typing'}\narg_prefixes = {inspect._VAR_POSITIONAL: '\\*', inspect._VAR_KEYWORD:'\\*\\*'}\n\n\ndef is_enum(cls): return cls == enum.Enum or cls == enum.EnumMeta\n\ndef link_type(arg_type, arg_name=None, include_bt:bool=True):\n    \"Create link to documentation.\"\n    arg_name = arg_name or fn_name(arg_type)\n    if include_bt: arg_name = code_esc(arg_name)\n    if belongs_to_module(arg_type, 'torch') and ('Tensor' not in arg_name): return f'[{arg_name}]({get_pytorch_link(arg_type)})'\n    if is_fastai_class(arg_type): return f'[{arg_name}]({get_fn_link(arg_type)})'\n    return arg_name\n\ndef is_fastai_class(t): return belongs_to_module(t, MODULE_NAME)\n\ndef belongs_to_module(t, module_name):\n    \"Check if `t` belongs to `module_name`.\"\n    if hasattr(t, '__func__'): return belongs_to_module(t.__func__, module_name)\n    if not inspect.getmodule(t): return False\n    
return inspect.getmodule(t).__name__.startswith(module_name)\n\ndef code_esc(s): return f'`{s}`'\n\ndef type_repr(t):\n    if t in _typing_names: return link_type(t, _typing_names[t])\n    if isinstance(t, partial): return partial_repr(t)\n    if hasattr(t, '__forward_arg__'): return link_type(t.__forward_arg__)\n    elif getattr(t, '__args__', None):\n        args = t.__args__\n        if len(args)==2 and args[1] == type(None):\n            return f'`Optional`\\[{type_repr(args[0])}\\]'\n        reprs = ', '.join([type_repr(o) for o in args])\n        return f'{link_type(t)}\\[{reprs}\\]'\n    else: return link_type(t)\n\ndef partial_repr(t):\n    args = (t.func,) + t.args + tuple([f'{k}={v}' for k,v in t.keywords.items()])\n    reprs = ', '.join([link_type(o) for o in args])\n    return f'<code>partial(</code>{reprs}<code>)</code>'\n\ndef anno_repr(a): return type_repr(a)\n\ndef format_param(p):\n    \"Formats function param to `param1:Type=val`. Font weights: param1=bold, val=bold+italic\"\n    arg_prefix = arg_prefixes.get(p.kind, '') # asterisk prefix for *args and **kwargs\n    res = f\"**{arg_prefix}{code_esc(p.name)}**\"\n    if hasattr(p, 'annotation') and p.annotation != p.empty: res += f':{anno_repr(p.annotation)}'\n    if p.default != p.empty:\n        default = getattr(p.default, 'func', p.default)\n        default = getattr(default, '__name__', default)\n        res += f'=***`{repr(default)}`***'\n    return res\n\ndef format_ft_def(func, full_name:str=None)->str:\n    \"Format and link `func` definition to show in documentation\"\n    sig = inspect.signature(func)\n    name = f'<code>{full_name or func.__name__}</code>'\n    fmt_params = [format_param(param) for name,param\n                  in sig.parameters.items() if name not in ('self','cls')]\n    arg_str = f\"({', '.join(fmt_params)})\"\n    if sig.return_annotation and (sig.return_annotation != sig.empty): arg_str += f\" → {anno_repr(sig.return_annotation)}\"\n    if 
is_fastai_class(type(func)):        arg_str += f\" :: {link_type(type(func))}\"\n    f_name = f\"<code>class</code> {name}\" if inspect.isclass(func) else name\n    return f'{f_name}',f'{name}{arg_str}'\n\ndef get_enum_doc(elt, full_name:str)->str:\n    \"Formatted enum documentation.\"\n    vals = ', '.join(elt.__members__.keys())\n    return f'{code_esc(full_name)}',f'<code>Enum</code> = [{vals}]'\n\ndef get_cls_doc(elt, full_name:str)->str:\n    \"Class definition.\"\n    parent_class = inspect.getclasstree([elt])[-1][0][1][0]\n    name,args = format_ft_def(elt, full_name)\n    if parent_class != object: args += f' :: {link_type(parent_class, include_bt=True)}'\n    return name,args\n\ndef show_doc(elt, doc_string:bool=True, full_name:str=None, arg_comments:dict=None, title_level=None, alt_doc_string:str='',\n             ignore_warn:bool=False, markdown=True, show_tests=True):\n    \"Show documentation for element `elt`. Supported types: class, Callable, and enum.\"\n    arg_comments = ifnone(arg_comments, {})\n    anchor_id = get_anchor(elt)\n    elt = getattr(elt, '__func__', elt)\n    full_name = full_name or fn_name(elt)\n    if inspect.isclass(elt):\n        if is_enum(elt.__class__):   name,args = get_enum_doc(elt, full_name)\n        else:                        name,args = get_cls_doc(elt, full_name)\n    elif isinstance(elt, Callable):  name,args = format_ft_def(elt, full_name)\n    else: raise Exception(f'doc definition not supported for {full_name}')\n    source_link = get_function_source(elt) if is_fastai_class(elt) else \"\"\n    test_link, test_modal = get_pytest_html(elt, anchor_id=anchor_id) if show_tests else ('', '')\n    title_level = ifnone(title_level, 2 if inspect.isclass(elt) else 4)\n    doc =  f'<h{title_level} id=\"{anchor_id}\" class=\"doc_header\">{name}{source_link}{test_link}</h{title_level}>'\n    doc += f'\\n\\n> {args}\\n\\n'\n    doc += f'{test_modal}'\n    if doc_string and (inspect.getdoc(elt) or arg_comments):\n        doc 
+= format_docstring(elt, arg_comments, alt_doc_string, ignore_warn) + ' '\n    if markdown: display(Markdown(doc))\n    else: return doc\n\ndef md2html(md):\n    if nbconvert.__version__ < '5.5.0': return HTMLExporter().markdown2html(md)\n    else: return HTMLExporter().markdown2html(defaultdict(lambda: defaultdict(dict)), md)\n    \ndef doc(elt):\n    \"Show `show_doc` info in preview window along with link to full docs.\"\n    global use_relative_links\n    use_relative_links = False\n    elt = getattr(elt, '__func__', elt)\n    md = show_doc(elt, markdown=False)\n    if is_fastai_class(elt):\n        md += f'\\n\\n<a href=\"{get_fn_link(elt)}\" target=\"_blank\" rel=\"noreferrer noopener\">Show in docs</a>'\n    output = md2html(md)\n    use_relative_links = True\n    if IS_IN_COLAB: get_ipython().run_cell_magic(u'html', u'', output)\n    else:\n        try: page.page({'text/html': output})\n        except: display(Markdown(md))\n\ndef format_docstring(elt, arg_comments:dict={}, alt_doc_string:str='', ignore_warn:bool=False)->str:\n    \"Merge and format the docstring definition with `arg_comments` and `alt_doc_string`.\"\n    parsed = \"\"\n    doc = parse_docstring(inspect.getdoc(elt))\n    description = alt_doc_string or f\"{doc['short_description']} {doc['long_description']}\"\n    if description: parsed += f'\\n\\n{link_docstring(inspect.getmodule(elt), description)}'\n\n    resolved_comments = {**doc.get('comments', {}), **arg_comments} # arg_comments takes priority\n    args = inspect.getfullargspec(elt).args if not is_enum(elt.__class__) else elt.__members__.keys()\n    if resolved_comments: parsed += '\\n'\n    for a in resolved_comments:\n        parsed += f'\\n- *{a}*: {resolved_comments[a]}'\n        if a not in args and not ignore_warn: warn(f'Doc arg mismatch: {a}')\n\n    return_comment = arg_comments.get('return') or doc.get('return')\n    if return_comment: parsed += f'\\n\\n*return*: {return_comment}'\n    return parsed\n\n_modvars = {}\n\ndef 
replace_link(m):\n    keyword = m.group(1) or m.group(2)\n    elt = find_elt(_modvars, keyword)\n    if elt is None: return m.group()\n    return link_type(elt, arg_name=keyword)\n\n# Finds all places with a backtick but only if it hasn't already been linked\nBT_REGEX = re.compile(\"\\[`([^`]*)`\\](?:\\([^)]*\\))|`([^`]*)`\") # matches [`key`](link) or `key`\ndef link_docstring(modules, docstring:str, overwrite:bool=False)->str:\n    \"Search `docstring` for backticks and attempt to link those functions to respective documentation.\"\n    mods = listify(modules)\n    for mod in mods: _modvars.update(mod.__dict__) # concat all module definitions\n    return re.sub(BT_REGEX, replace_link, docstring)\n\ndef find_elt(modvars, keyword, match_last=False):\n    \"Attempt to resolve keywords such as Learner.lr_find. `match_last` starts matching from last component.\"\n    keyword = strip_fastai(keyword)\n    if keyword in modvars: return modvars[keyword]\n    comps = keyword.split('.')\n    comp_elt = modvars.get(comps[0])\n    if hasattr(comp_elt, '__dict__'): return find_elt(comp_elt.__dict__, '.'.join(comps[1:]), match_last=match_last)\n\ndef import_mod(mod_name:str, ignore_errors=False):\n    \"Return module from `mod_name`.\"\n    splits = str.split(mod_name, '.')\n    try:\n        if len(splits) > 1 : mod = importlib.import_module('.' 
+ '.'.join(splits[1:]), splits[0])\n        else: mod = importlib.import_module(mod_name)\n        return mod\n    except:\n        if not ignore_errors: print(f\"Module {mod_name} doesn't exist.\")\n\ndef show_doc_from_name(mod_name, ft_name:str, doc_string:bool=True, arg_comments:dict={}, alt_doc_string:str=''):\n    \"Show documentation for `ft_name`, see `show_doc`.\"\n    mod = import_mod(mod_name)\n    splits = str.split(ft_name, '.')\n    assert hasattr(mod, splits[0]), print(f\"Module {mod_name} doesn't have a function named {splits[0]}.\")\n    elt = getattr(mod, splits[0])\n    for i,split in enumerate(splits[1:]):\n        assert hasattr(elt, split), print(f\"Class {'.'.join(splits[:i+1])} doesn't have a function named {split}.\")\n        elt = getattr(elt, split)\n    show_doc(elt, doc_string, ft_name, arg_comments, alt_doc_string)\n\ndef get_exports(mod):\n    public_names = mod.__all__ if hasattr(mod, '__all__') else dir(mod)\n    #public_names.sort(key=str.lower)\n    return [o for o in public_names if not o.startswith('_')]\n\ndef get_ft_names(mod, include_inner=False)->List[str]:\n    \"Return all the functions of module `mod`.\"\n    # If the module has an attribute __all__, it picks those.\n    # Otherwise, it returns all the functions defined inside a module.\n    fn_names = []\n    for elt_name in get_exports(mod):\n        elt = getattr(mod,elt_name)\n        #This removes the files imported from elsewhere\n        try:    fname = inspect.getfile(elt)\n        except: continue\n        if mod.__file__.endswith('__init__.py'):\n            if inspect.ismodule(elt): fn_names.append(elt_name)\n            else: continue\n        else:\n            if (fname != mod.__file__): continue\n            if inspect.isclass(elt) or inspect.isfunction(elt): fn_names.append(elt_name)\n            else: continue\n        if include_inner and inspect.isclass(elt) and not is_enum(elt.__class__):\n            fn_names.extend(get_inner_fts(elt))\n    return 
fn_names\n\ndef get_inner_fts(elt)->List[str]:\n    \"List the inner functions of a class.\"\n    fts = []\n    for ft_name in elt.__dict__.keys():\n        if ft_name.startswith('_'): continue\n        ft = getattr(elt, ft_name)\n        if inspect.isfunction(ft): fts.append(f'{elt.__name__}.{ft_name}')\n        if inspect.ismethod(ft): fts.append(f'{elt.__name__}.{ft_name}')\n        if inspect.isclass(ft): fts += [f'{elt.__name__}.{n}' for n in get_inner_fts(ft)]\n    return fts\n\ndef get_module_toc(mod_name):\n    \"Display table of contents for given `mod_name`.\"\n    mod = import_mod(mod_name)\n    ft_names = mod.__all__ if hasattr(mod,'__all__') else get_ft_names(mod)\n    ft_names.sort(key = str.lower)\n    tabmat = ''\n    for ft_name in ft_names:\n        tabmat += f'- [{ft_name}](#{ft_name})\\n'\n        elt = getattr(mod, ft_name)\n        if inspect.isclass(elt) and not is_enum(elt.__class__):\n            in_ft_names = get_inner_fts(elt)\n            for name in in_ft_names:\n                tabmat += f'  - [{name}](#{name})\\n'\n    display(Markdown(tabmat))\n\ndef show_video(url):\n    \"Display video in `url`.\"\n    data = f'<iframe width=\"560\" height=\"315\" src=\"{url}\" frameborder=\"0\" allowfullscreen></iframe>'\n    return display(HTML(data))\n\ndef show_video_from_youtube(code, start=0):\n    \"Display video from Youtube with a `code` and a `start` time.\"\n    url = f'https://www.youtube.com/embed/{code}?start={start}&amp;rel=0&amp;controls=0&amp;showinfo=0'\n    return show_video(url)\n\ndef get_anchor(fn)->str:\n    if hasattr(fn,'__qualname__'): return fn.__qualname__\n    if inspect.ismethod(fn): return fn_name(fn.__self__) + '.' 
+ fn_name(fn)\n    return fn_name(fn)\n\ndef fn_name(ft)->str:\n    if ft.__hash__ and ft in _typing_names: return _typing_names[ft]\n    if hasattr(ft, '__name__'):   return ft.__name__\n    elif hasattr(ft,'_name') and ft._name: return ft._name\n    elif hasattr(ft,'__origin__'): return str(ft.__origin__).split('.')[-1]\n    else:                          return str(ft).split('.')[-1]\n\ndef get_fn_link(ft)->str:\n    \"Return function link to notebook documentation of `ft`. Private functions link to source code\"\n    ft = getattr(ft, '__func__', ft)\n    anchor = strip_fastai(get_anchor(ft))\n    module_name = strip_fastai(get_module_name(ft))\n    base = '' if use_relative_links else FASTAI_DOCS\n    return f'{base}/{module_name}.html#{anchor}'\n\ndef get_module_name(ft)->str: return inspect.getmodule(ft).__name__\n\ndef get_pytorch_link(ft)->str:\n    \"Returns link to pytorch docs of `ft`.\"\n    name = ft.__name__\n    ext = '.html'\n    if name == 'device': return f'{PYTORCH_DOCS}tensor_attributes{ext}#torch-device'\n    if name == 'Tensor': return f'{PYTORCH_DOCS}tensors{ext}#torch-tensor'\n    if name.startswith('torchvision'):\n        doc_path = get_module_name(ft).replace('.', '/')\n        if inspect.ismodule(ft): name = name.replace('.', '-')\n        return f'{PYTORCH_DOCS}{doc_path}{ext}#{name}'\n    if name.startswith('torch.nn') and inspect.ismodule(ft): # nn.functional is special case\n        nn_link = name.replace('.', '-')\n        return f'{PYTORCH_DOCS}nn{ext}#{nn_link}'\n    paths = get_module_name(ft).split('.')\n    if len(paths) == 1: return f'{PYTORCH_DOCS}{paths[0]}{ext}#{paths[0]}.{name}'\n\n    offset = 1 if paths[1] == 'utils' else 0 # utils is a pytorch special case\n    doc_path = paths[1+offset]\n    if inspect.ismodule(ft): return f'{PYTORCH_DOCS}{doc_path}{ext}#module-{name}'\n    fnlink = '.'.join(paths[:(2+offset)]+[name])\n    return f'{PYTORCH_DOCS}{doc_path}{ext}#{fnlink}'\n\ndef get_source_link(file, line, 
display_text=\"[source]\", **kwargs)->str:\n    \"Returns github link for given file\"\n    link = f\"{SOURCE_URL}{file}#L{line}\"\n    if display_text is None: return link\n    return f'<a href=\"{link}\" class=\"source_link\" style=\"float:right\">{display_text}</a>'\n\ndef get_function_source(ft, **kwargs)->str:\n    \"Returns link to `ft` in source code.\"\n    try: line = inspect.getsourcelines(ft)[1]\n    except Exception: return ''\n    mod_path = get_module_name(ft).replace('.', '/') + '.py'\n    return get_source_link(mod_path, line, **kwargs)\n\ndef title_md(s:str, title_level:int, markdown=True):\n    res = '#' * title_level\n    if title_level: res += ' '\n    return Markdown(res+s) if markdown else (res+s)\n\ndef jekyll_div(s,c,h,icon=None):\n    icon = ifnone(icon,c)\n    res = f'<div markdown=\"span\" class=\"alert alert-{c}\" role=\"alert\"><i class=\"fa fa-{c}-circle\"></i> <b>{h}: </b>{s}</div>'\n    display(Markdown(res))\n\ndef jekyll_note(s): return jekyll_div(s,'info','Note')\ndef jekyll_warn(s): return jekyll_div(s,'danger','Warning', 'exclamation')\ndef jekyll_important(s): return jekyll_div(s,'warning','Important')\n"
  },
  {
    "path": "fastai/gen_doc/nbtest.py",
    "content": "\"`gen_doc.nbtest` shows pytest documentation for module functions\"\n\nimport inspect, os, re\nfrom os.path import abspath, dirname, join\nfrom collections import namedtuple\n\nfrom fastai.gen_doc import nbdoc\nfrom ..imports.core import *\nfrom .core import ifnone\nfrom .doctest import get_parent_func, relative_test_path, get_func_fq_name, DB_NAME\n\nfrom nbconvert import HTMLExporter\nfrom IPython.core import page\nfrom IPython.core.display import display, Markdown, HTML\n\n__all__ = ['show_test', 'doctest', 'find_related_tests', 'lookup_db', 'find_test_matches', 'find_test_files', 'fuzzy_test_match', 'get_pytest_html']\n\nTestFunctionMatch = namedtuple('TestFunctionMatch', ['line_number', 'line'])\n\ndef show_test(elt)->str:\n    \"Show associated tests for a fastai function/class\"\n    md = build_tests_markdown(elt)\n    display(Markdown(md))\n\ndef doctest(elt):\n    \"Inline notebook popup for `show_test`\"\n    md = build_tests_markdown(elt)\n    output = nbdoc.md2html(md)\n    try:    page.page({'text/html': output})\n    except: display(Markdown(md))\n\ndef build_tests_markdown(elt):\n    fn_name = nbdoc.fn_name(elt)\n    md = ''\n    db_matches = [get_links(t) for t in lookup_db(elt)]\n    md += tests2md(db_matches, '')\n    try:\n        related = [get_links(t) for t in find_related_tests(elt)]\n        other_tests = [k for k in OrderedDict.fromkeys(related) if k not in db_matches]\n        md += tests2md(other_tests, f'Some other tests where `{fn_name}` is used:')\n    except OSError as e: pass\n\n    if len(md.strip())==0:\n        return (f'No tests found for `{fn_name}`.'\n                ' To contribute a test please refer to [this guide](/dev/test.html)'\n                ' and [this discussion](https://forums.fast.ai/t/improving-expanding-functional-tests/32929).')\n    return (f'Tests found for `{fn_name}`: {md}'\n            '\\n\\nTo run tests please refer to this [guide](/dev/test.html#quick-guide).')\n\ndef tests2md(tests, 
type_label:str):\n    if not tests: return ''\n    md = [f'\\n\\n{type_label}'] + [f'* `{cmd}` {link}' for link,cmd in sorted(tests, key=lambda k: k[1])]\n    return '\\n'.join(md)\n\ndef get_pytest_html(elt, anchor_id:str)->Tuple[str,str]:\n    md = build_tests_markdown(elt)\n    html = nbdoc.md2html(md).replace('\\n','') # nbconverter fails to parse markdown if it has both html and '\\n'\n    anchor_id = anchor_id.replace('.', '-') + '-pytest'\n    link, body = get_pytest_card(html, anchor_id)\n    return link, body\n\ndef get_pytest_card(html, anchor_id):\n    \"creates a collapsible bootstrap card for `show_test`\"\n    link = f'<a class=\"source_link\" data-toggle=\"collapse\" data-target=\"#{anchor_id}\" style=\"float:right; padding-right:10px\">[test]</a>'\n    body = (f'<div class=\"collapse\" id=\"{anchor_id}\"><div class=\"card card-body pytest_card\">'\n                f'<a type=\"button\" data-toggle=\"collapse\" data-target=\"#{anchor_id}\" class=\"close\" aria-label=\"Close\"><span aria-hidden=\"true\">&times;</span></a>'\n                f'{html}'\n            '</div></div>')\n    return link, body\n\ndef lookup_db(elt)->List[Dict]:\n    \"Finds `this_test` entries from test_registry.json\"\n    db_file = Path(abspath(join(dirname( __file__ ), '..')))/DB_NAME\n    if not db_file.exists():\n        raise Exception(f'Could not find {db_file}. 
Please make sure it exists at \"{db_file}\" or run `make test`')\n    with open(db_file, 'r') as f:\n        db = json.load(f)\n    key = get_func_fq_name(elt)\n    return db.get(key, [])\n\ndef find_related_tests(elt)->Tuple[List[Dict],List[Dict]]:\n    \"Searches `fastai/tests` folder for any test functions related to `elt`\"\n    related_matches = []\n    for test_file in find_test_files(elt):\n        fuzzy_matches = find_test_matches(elt, test_file)\n        related_matches.extend(fuzzy_matches)\n    return related_matches\n\ndef get_tests_dir(elt)->Path:\n    \"Absolute path of `fastai/tests` directory\"\n    test_dir = Path(__file__).parent.parent.parent.resolve()/'tests'\n    if not test_dir.exists(): raise OSError('Could not find test directory at this location:', test_dir)\n    return test_dir\n\ndef get_file(elt)->str:\n    if hasattr(elt, '__wrapped__'): elt = elt.__wrapped__\n    if not nbdoc.is_fastai_class(elt): return None\n    return inspect.getfile(elt)\n\ndef find_test_files(elt, exact_match:bool=False)->List[Path]:\n    \"Searches in `fastai/tests` directory for module tests\"\n    test_dir = get_tests_dir(elt)\n    matches = [test_dir/o.name for o in os.scandir(test_dir) if _is_file_match(elt, o.name)]\n    # if len(matches) != 1: raise Error('Could not find exact file match:', matches)\n    return matches\n\ndef _is_file_match(elt, file_name:str, exact_match:bool=False)->bool:\n    fp = get_file(elt)\n    if fp is None: return False\n    subdir = ifnone(_submodule_name(elt), '')\n    exact_re = '' if exact_match else '\\w*'\n    return re.match(f'test_{subdir}\\w*{Path(fp).stem}{exact_re}\\.py', file_name)\n\ndef _submodule_name(elt)->str:\n    \"Returns submodule - utils, text, vision, imports, etc.\"\n    if inspect.ismodule(elt): return None\n    modules = elt.__module__.split('.')\n    if len(modules) > 2:\n        return modules[1]\n    return None\n\ndef find_test_matches(elt, test_file:Path)->Tuple[List[Dict],List[Dict]]:\n    \"Find 
all functions in `test_file` related to `elt`\"\n    lines = get_lines(test_file)\n    rel_path = relative_test_path(test_file)\n    fn_name = get_qualname(elt) if not inspect.ismodule(elt) else ''\n    return fuzzy_test_match(fn_name, lines, rel_path)\n\ndef get_qualname(elt):\n    return elt.__qualname__ if hasattr(elt, '__qualname__') else fn_name(elt)\n\ndef separate_comp(qualname:str):\n    if not isinstance(qualname, str): qualname = get_qualname(qualname)\n    parts = qualname.split('.')\n    parts[-1] = remove_underscore(parts[-1])\n    if len(parts) == 1: return [], parts[0]\n    return parts[:-1], parts[-1]\n\ndef remove_underscore(fn_name):\n    if fn_name and fn_name[0] == '_': return fn_name[1:] # remove private method underscore prefix\n    return fn_name\n\ndef fuzzy_test_match(fn_name:str, lines:List[Dict], rel_path:str)->List[TestFunctionMatch]:\n    \"Find any lines where `fn_name` is invoked and return the parent test function\"\n    fuzzy_line_matches = _fuzzy_line_match(fn_name, lines)\n    fuzzy_matches = [get_parent_func(lno, lines, ignore_missing=True) for lno,_ in fuzzy_line_matches]\n    fuzzy_matches = list(filter(None.__ne__, fuzzy_matches))\n    return [map_test(rel_path, lno, l) for lno,l in fuzzy_matches]\n\ndef _fuzzy_line_match(fn_name:str, lines)->List[TestFunctionMatch]:\n    \"Find any lines where `fn_name` is called\"\n    result = []\n    _,fn_name = separate_comp(fn_name)\n    for idx,line in enumerate(lines):\n        if re.match(f'.*[\\s\\.\\(]{fn_name}[\\.\\(]', line):\n            result.append((idx,line))\n    return result\n\ndef get_lines(file:Path)->List[str]:\n    with open(file, 'r') as f: return f.readlines()\n\ndef map_test(test_file, line, line_text):\n    \"Creates dictionary test format to match doctest api\"\n    test_name = re.match(f'\\s*def (test_\\w*)', line_text).groups(0)[0]\n    return { 'file': test_file, 'line': line, 'test': test_name }\n\ndef get_links(metadata)->Tuple[str,str]:\n    \"Returns source 
code link and pytest command\"\n    return nbdoc.get_source_link(**metadata), pytest_command(**metadata)\n\ndef pytest_command(file:str, test:str, **kwargs)->str:\n    \"Returns CLI command to run specific test function\"\n    return f'pytest -sv {file}::{test}'\n"
  },
  {
    "path": "fastai/general_optimizer.py",
    "content": "from .torch_core import *\nfrom torch.optim import Optimizer\nimport types\n\n__all__ = ['StatScope', 'Statistic', 'ConstStatistic', 'AvgStatistic', 'AvgSquare', 'GeneralOptimizer']\n\nStatScope = Enum('StatScope', 'Global Group Layer Channel Weight')\n\n@dataclass\nclass Statistic():\n    name:str\n    param:float=0.9  # e.g. for exp moving average\n    scope:StatScope=StatScope.Weight\n    init:float=0.  # starting value\n\n    @property\n    def buf(self): return f'{self.name}_buffer'\n\n    def new_step(self):\n        \"Set state when computing statistics for Global or Group\"\n        raise NotImplementedError\n\n    def accumulate(self, val):\n        \"Add `val` to statistic\"\n        raise NotImplementedError\n\n    def update(self, state, param, val=None, step=None):\n        \"Update state with accumlated, or `val` (if `Weight` or `Layer` scope)\"\n        raise NotImplementedError\n\nclass ConstStatistic(Statistic):\n    @property\n    def buf(self): return None\n    def new_step(self):   pass\n    def accumulate(self): pass\n    def update(self, state, param, val=None, step=None): return param\n\n@dataclass\nclass CounterStat(Statistic):\n    def __post_init__(self): self.init,self._buf,self.name = 0,self.name,None\n    @property\n    def buf(self): return self._buf\n    def new_step(self): pass\n    def accumulate(self, val): pass\n    def update(self, state, param, val=None, step=None): return state + 1\n\n@dataclass\nclass AvgStatistic(Statistic):\n    decay:bool=False\n    debias:bool=False\n    def new_step(self): self.val,self.count = 0.,0\n\n    def accumulate(self, val):\n        self.count += 1\n        self.val += self._get_val1(val)\n\n    def _get_val1(self, val): return val.mean()\n    def _get_val2(self, state, val, param): return state.add_(1-param, val) if self.decay else state.add_(val)\n    def _get_val3(self, state, val, param): \n        v = val.view(val.size(0), -1).mean(1)\n        return state.add_(1-param, v) if 
self.decay else state.add_(v)\n\n    def update(self, state, param, val=None, step=None):\n        if self.scope == StatScope.Weight:\n            # `state` is a tensor\n            res = self._get_val2(state.mul_(param), val, param)\n        elif self.scope == StatScope.Channel:\n            # `state` is a tensor of size n_channels\n            res = self._get_val3(state.mul_(param), val, param)\n        # For everything else, `state` is a scalar\n        elif self.scope == StatScope.Layer:  res = state*param + self._get_val1(val) * (1-param if self.decay else 1.)\n        elif self.count != 0:                res = state*param + self.val/self.count * (1-param if self.decay else 1.)\n        else: return state\n        if self.debias and step is not None: res /= (1 - param ** step)\n        return res\n\nclass AvgSquare(AvgStatistic):\n\n    def __init__(self, name:str, param:float=0.9, scope=StatScope.Weight, init:float=0., decay:bool=True, debias:bool=False):\n        super().__init__(name, param=param, scope=scope, init=init, decay=decay, debias=debias)\n\n    def _get_val1(self, val): return torch.norm(val).pow(2)/val.numel()\n    def _get_val2(self, state, val, param): \n        return state.addcmul_(1-param, val, val) if self.decay else state.addcmul_(val, val)\n    def _get_val3(self, state, val, param):\n        v = val.view(val.size(0), -1).mean(1)\n        return state.addcmul_(1-param, v, v) if self.decay else state.addcmul_(v, v)\n\nclass GeneralOptimizer(Optimizer):\n    def __init__(self, params, stats=None, on_step:Callable=None):\n        defaults = {s.name:s.param for s in listify(stats) if s.name is not None}\n        super().__init__(params, defaults)\n        self.global_stats,self.group_stats,self.layer_stats,self.channel_stats,self.weight_stats = self._split_stats(stats)\n        self.init_stats()\n        if on_step is not None: self.on_step = types.MethodType(on_step, self)\n\n    def step(self, closure=None):\n        self.update_stats()\n  
      for i,pg in enumerate(self.param_groups):\n            for p in pg['params']:\n                if p.grad is not None: self.on_step(p, pg, i)\n\n    def on_step(self, p, group, group_idx): p.data.add_(-group['lr'], p.grad.data)\n\n    def _split_stats(self, stats):\n        splits = [[stat for stat in listify(stats) if stat.scope==scope] for scope in StatScope]\n        for split,s in zip([splits[0], splits[1], splits[2]+splits[3]+splits[4]], StatScope):\n            if np.any([getattr(s, 'debias', False) for s in split]): split.insert(0, CounterStat('step', scope=s))\n        return splits\n\n    def _init_stats(self, stats, data=None):\n        return {stat.buf: stat.init if data is None\n                else torch.zeros_like(data) + stat.init for stat in stats if stat.buf is not None}\n\n    def init_stats(self):\n        self.state['global'] = self._init_stats(self.global_stats)\n        for i,pg in enumerate(self.param_groups):\n            self.state[f'group{i}'] = self._init_stats(self.group_stats)\n            for p in pg['params']:\n                self.state[p] = self._init_stats(self.layer_stats)\n                self.state[p].update(self._init_stats(self.channel_stats, p.data.view(p.data.size(0), -1).mean(1)))\n                self.state[p].update(self._init_stats(self.weight_stats, p.data))\n\n    def _set_bufs(self, p, stats, pg, val=None):\n        d = self.state[p]\n        for stat in stats:\n            if stat.buf is not None: d[stat.buf] = stat.update(d[stat.buf], pg[stat.name], val=val, step=d.get('step', None))\n\n    def update_stats(self):\n        for stat in self.global_stats: stat.new_step()\n        for i,pg in enumerate(self.param_groups):\n            for stat in self.group_stats: stat.new_step()\n            for p in pg['params']:\n                if p.grad is not None:\n                    for stat in self.global_stats + self.group_stats: stat.accumulate(p.grad.data)\n                    self._set_bufs(p, 
self.layer_stats+self.channel_stats+self.weight_stats, pg, p.grad.data)\n            self._set_bufs(f'group{i}', self.group_stats, pg)\n        self._set_bufs('global', self.global_stats, self.param_groups[0])\n\n"
  },
  {
    "path": "fastai/imports/__init__.py",
    "content": "from .core import *\nfrom .torch import *\n"
  },
  {
    "path": "fastai/imports/core.py",
    "content": "import csv, gc, gzip, os, pickle, shutil, sys, warnings, yaml, io, subprocess\nimport math, matplotlib.pyplot as plt, numpy as np, pandas as pd, random\nimport scipy.stats, scipy.special\nimport abc, collections, hashlib, itertools, json, operator, pathlib\nimport mimetypes, inspect, typing, functools, importlib, weakref\nimport html, re, requests, tarfile, numbers, tempfile, bz2\n\nfrom abc import abstractmethod, abstractproperty\nfrom collections import abc,  Counter, defaultdict, namedtuple, OrderedDict\nfrom collections.abc import Iterable\nimport concurrent\nfrom concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor\nfrom copy import copy, deepcopy\nfrom dataclasses import dataclass, field, InitVar\nfrom enum import Enum, IntEnum\nfrom functools import partial, reduce\nfrom pdb import set_trace\nfrom matplotlib import patches, patheffects\nfrom numpy import array, cos, exp, log, sin, tan, tanh\nfrom operator import attrgetter, itemgetter\nfrom pathlib import Path\nfrom warnings import warn\nfrom contextlib import contextmanager\nfrom fastprogress.fastprogress import MasterBar, ProgressBar\nfrom matplotlib.patches import Patch\nfrom pandas import Series, DataFrame\nfrom io import BufferedWriter, BytesIO\n\nimport pkg_resources\npkg_resources.require(\"fastprogress>=0.1.19\")\nfrom fastprogress.fastprogress import master_bar, progress_bar\n\n#for type annotations\nfrom numbers import Number\nfrom typing import Any, AnyStr, Callable, Collection, Dict, Hashable, Iterator, List, Mapping, NewType, Optional\nfrom typing import Sequence, Tuple, TypeVar, Union\nfrom types import SimpleNamespace\n\ndef try_import(module):\n    \"Try to import `module`. Returns module's object on success, None on failure\"\n    try: return importlib.import_module(module)\n    except: return None\n\ndef have_min_pkg_version(package, version):\n    \"Check whether we have at least `version` of `package`. 
Returns True on success, False otherwise.\"\n    try:\n        pkg_resources.require(f\"{package}>={version}\")\n        return True\n    except:\n        return False\n"
  },
  {
    "path": "fastai/imports/torch.py",
    "content": "import torch, torch.nn.functional as F\nfrom torch import ByteTensor, DoubleTensor, FloatTensor, HalfTensor, LongTensor, ShortTensor, Tensor\nfrom torch import nn, optim, as_tensor\nfrom torch.utils.data import BatchSampler, DataLoader, Dataset, Sampler, TensorDataset\nfrom torch.nn.utils import weight_norm, spectral_norm\n"
  },
  {
    "path": "fastai/launch.py",
    "content": "import subprocess, torch\nfrom fastai.script import *\n\n@call_parse\ndef main(\n    gpus:Param(\"The GPUs to use for distributed training\", str)='all',\n    script:Param(\"Script to run\", str, opt=False)='',\n    args:Param(\"Args to pass to script\", nargs='...', opt=False)=''\n):\n    \"PyTorch distributed training launch helper that spawns multiple distributed processes\"\n    # Loosely based on torch.distributed.launch\n    current_env = os.environ.copy()\n    gpus = list(range(torch.cuda.device_count())) if gpus=='all' else list(gpus)\n    current_env[\"WORLD_SIZE\"] = str(len(gpus))\n    current_env[\"MASTER_ADDR\"] = '127.0.0.1'\n    current_env[\"MASTER_PORT\"] = '29500'\n\n    processes = []\n    for i,gpu in enumerate(gpus):\n        current_env[\"RANK\"] = str(i)\n        cmd = [sys.executable, \"-u\", script, f\"--gpu={gpu}\"] + args\n        process = subprocess.Popen(cmd, env=current_env)\n        processes.append(process)\n\n    for process in processes: process.wait()\n\n"
  },
  {
    "path": "fastai/layers.py",
    "content": "\"`fastai.layers` provides essential functions to building and modifying `model` architectures\"\nfrom .torch_core import *\n\n__all__ = ['AdaptiveConcatPool2d', 'BCEWithLogitsFlat', 'BCEFlat', 'MSELossFlat', 'CrossEntropyFlat', 'Debugger',\n           'Flatten', 'Lambda', 'PoolFlatten', 'View', 'ResizeBatch', 'bn_drop_lin', 'conv2d', 'conv2d_trans', 'conv_layer',\n           'embedding', 'simple_cnn', 'NormType', 'relu', 'batchnorm_2d', 'trunc_normal_', 'PixelShuffle_ICNR', 'icnr',\n           'NoopLoss', 'WassersteinLoss', 'SelfAttention', 'SequentialEx', 'MergeLayer', 'res_block', 'sigmoid_range',\n           'SigmoidRange', 'PartialLayer', 'FlattenedLoss', 'BatchNorm1dFlat', 'LabelSmoothingCrossEntropy', 'PooledSelfAttention2d']\n\nclass Lambda(Module):\n    \"Create a layer that simply calls `func` with `x`\"\n    def __init__(self, func:LambdaFunc): self.func=func\n    def forward(self, x): return self.func(x)\n\nclass View(Module):\n    \"Reshape `x` to `size`\"\n    def __init__(self, *size:int): self.size = size\n    def forward(self, x): return x.view(self.size)\n\nclass ResizeBatch(Module):\n    \"Reshape `x` to `size`, keeping batch dim the same size\"\n    def __init__(self, *size:int): self.size = size\n    def forward(self, x): return x.view((x.size(0),) + self.size)\n\nclass Flatten(Module):\n    \"Flatten `x` to a single dimension, often used at the end of a model. 
`full` for rank-1 tensor\"\n    def __init__(self, full:bool=False): self.full = full\n    def forward(self, x): return x.view(-1) if self.full else x.view(x.size(0), -1)\n\ndef PoolFlatten()->nn.Sequential:\n    \"Apply `nn.AdaptiveAvgPool2d` to `x` and then flatten the result.\"\n    return nn.Sequential(nn.AdaptiveAvgPool2d(1), Flatten())\n\nNormType = Enum('NormType', 'Batch BatchZero Weight Spectral Group Instance SpectralGN')\n\ndef batchnorm_2d(nf:int, norm_type:NormType=NormType.Batch):\n    \"A batchnorm2d layer with `nf` features initialized depending on `norm_type`.\"\n    bn = nn.BatchNorm2d(nf)\n    with torch.no_grad():\n        bn.bias.fill_(1e-3)\n        bn.weight.fill_(0. if norm_type==NormType.BatchZero else 1.)\n    return bn\n\ndef bn_drop_lin(n_in:int, n_out:int, bn:bool=True, p:float=0., actn:Optional[nn.Module]=None):\n    \"Sequence of batchnorm (if `bn`), dropout (with `p`) and linear (`n_in`,`n_out`) layers followed by `actn`.\"\n    layers = [nn.BatchNorm1d(n_in)] if bn else []\n    if p != 0: layers.append(nn.Dropout(p))\n    layers.append(nn.Linear(n_in, n_out))\n    if actn is not None: layers.append(actn)\n    return layers\n\ndef conv1d(ni:int, no:int, ks:int=1, stride:int=1, padding:int=0, bias:bool=False):\n    \"Create and initialize a `nn.Conv1d` layer with spectral normalization.\"\n    conv = nn.Conv1d(ni, no, ks, stride=stride, padding=padding, bias=bias)\n    nn.init.kaiming_normal_(conv.weight)\n    if bias: conv.bias.data.zero_()\n    return spectral_norm(conv)\n\nclass PooledSelfAttention2d(Module):\n    \"Pooled self attention layer for 2d.\"\n    def __init__(self, n_channels:int):\n        self.n_channels = n_channels\n        self.theta = spectral_norm(conv2d(n_channels, n_channels//8, 1)) # query\n        self.phi   = spectral_norm(conv2d(n_channels, n_channels//8, 1)) # key\n        self.g     = spectral_norm(conv2d(n_channels, n_channels//2, 1)) # value\n        self.o     = spectral_norm(conv2d(n_channels//2, 
n_channels, 1))\n        self.gamma = nn.Parameter(tensor([0.]))\n\n    def forward(self, x):\n        # code borrowed from https://github.com/ajbrock/BigGAN-PyTorch/blob/7b65e82d058bfe035fc4e299f322a1f83993e04c/layers.py#L156\n        theta = self.theta(x)\n        phi = F.max_pool2d(self.phi(x), [2,2])\n        g = F.max_pool2d(self.g(x), [2,2])    \n        theta = theta.view(-1, self.n_channels // 8, x.shape[2] * x.shape[3])\n        phi = phi.view(-1, self.n_channels // 8, x.shape[2] * x.shape[3] // 4)\n        g = g.view(-1, self.n_channels // 2, x.shape[2] * x.shape[3] // 4)\n        beta = F.softmax(torch.bmm(theta.transpose(1, 2), phi), -1)\n        o = self.o(torch.bmm(g, beta.transpose(1,2)).view(-1, self.n_channels // 2, x.shape[2], x.shape[3]))\n        return self.gamma * o + x\n\nclass SelfAttention(Module):\n    \"Self attention layer for nd.\"\n    def __init__(self, n_channels:int):\n        self.query = conv1d(n_channels, n_channels//8)\n        self.key   = conv1d(n_channels, n_channels//8)\n        self.value = conv1d(n_channels, n_channels)\n        self.gamma = nn.Parameter(tensor([0.]))\n\n    def forward(self, x):\n        #Notation from https://arxiv.org/pdf/1805.08318.pdf\n        size = x.size()\n        x = x.view(*size[:2],-1)\n        f,g,h = self.query(x),self.key(x),self.value(x)\n        beta = F.softmax(torch.bmm(f.permute(0,2,1).contiguous(), g), dim=1)\n        o = self.gamma * torch.bmm(h, beta) + x\n        return o.view(*size).contiguous()\n\ndef conv2d(ni:int, nf:int, ks:int=3, stride:int=1, padding:int=None, bias=False, init:LayerFunc=nn.init.kaiming_normal_) -> nn.Conv2d:\n    \"Create and initialize `nn.Conv2d` layer. 
`padding` defaults to `ks//2`.\"\n    if padding is None: padding = ks//2\n    return init_default(nn.Conv2d(ni, nf, kernel_size=ks, stride=stride, padding=padding, bias=bias), init)\n\ndef conv2d_trans(ni:int, nf:int, ks:int=2, stride:int=2, padding:int=0, bias=False) -> nn.ConvTranspose2d:\n    \"Create `nn.ConvTranspose2d` layer.\"\n    return nn.ConvTranspose2d(ni, nf, kernel_size=ks, stride=stride, padding=padding, bias=bias)\n\ndef relu(inplace:bool=False, leaky:float=None):\n    \"Return a relu activation, maybe `leaky` and `inplace`.\"\n    return nn.LeakyReLU(inplace=inplace, negative_slope=leaky) if leaky is not None else nn.ReLU(inplace=inplace)\n\ndef conv_layer(ni:int, nf:int, ks:int=3, stride:int=1, padding:int=None, bias:bool=None, is_1d:bool=False,\n               norm_type:Optional[NormType]=NormType.Batch,  use_activ:bool=True, leaky:float=None,\n               transpose:bool=False, init:Callable=nn.init.kaiming_normal_, self_attention:bool=False):\n    \"Create a sequence of convolutional (`ni` to `nf`), ReLU (if `use_activ`) and batchnorm (if `bn`) layers.\"\n    if padding is None: padding = (ks-1)//2 if not transpose else 0\n    bn = norm_type in (NormType.Batch, NormType.BatchZero)\n    if bias is None: bias = not bn\n    conv_func = nn.ConvTranspose2d if transpose else nn.Conv1d if is_1d else nn.Conv2d\n    conv = init_default(conv_func(ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding), init)\n    if   norm_type==NormType.Weight:   conv = weight_norm(conv)\n    elif norm_type==NormType.Spectral: conv = spectral_norm(conv)\n    layers = [conv]\n    if use_activ: layers.append(relu(True, leaky=leaky))\n    if bn: layers.append((nn.BatchNorm1d if is_1d else nn.BatchNorm2d)(nf))\n    if self_attention: layers.append(SelfAttention(nf))\n    return nn.Sequential(*layers)\n\nclass SequentialEx(Module):\n    \"Like `nn.Sequential`, but with ModuleList semantics, and can access module input\"\n    def __init__(self, *layers): 
self.layers = nn.ModuleList(layers)\n\n    def forward(self, x):\n        res = x\n        for l in self.layers:\n            res.orig = x\n            nres = l(res)\n            #print(l. + ' mean: ' + str(nres.abs().mean()))\n            #print(' max: ' + str(nres.abs().max()))\n            # We have to remove res.orig to avoid hanging refs and therefore memory leaks\n            res.orig = None\n            res = nres\n        return res\n\n    def __getitem__(self,i): return self.layers[i]\n    def append(self,l): return self.layers.append(l)\n    def extend(self,l): return self.layers.extend(l)\n    def insert(self,i,l): return self.layers.insert(i,l)\n\nclass MergeLayer(Module):\n    \"Merge a shortcut with the result of the module by adding them or concatenating them if `dense=True`.\"\n    def __init__(self, dense:bool=False): self.dense=dense\n    def forward(self, x): return torch.cat([x,x.orig], dim=1) if self.dense else (x+x.orig)\n\ndef res_block(nf, dense:bool=False, norm_type:Optional[NormType]=NormType.Batch, bottle:bool=False, **conv_kwargs):\n    \"Resnet block of `nf` features. 
`conv_kwargs` are passed to `conv_layer`.\"\n    norm2 = norm_type\n    if not dense and (norm_type==NormType.Batch): norm2 = NormType.BatchZero\n    nf_inner = nf//2 if bottle else nf\n    return SequentialEx(conv_layer(nf, nf_inner, norm_type=norm_type, **conv_kwargs),\n                      conv_layer(nf_inner, nf, norm_type=norm2, **conv_kwargs),\n                      MergeLayer(dense))\n\ndef sigmoid_range(x:Tensor, low:int, high:int):\n    \"Sigmoid function with range `(low, high)`\"\n    return torch.sigmoid(x) * (high - low) + low\n\nclass SigmoidRange(Module):\n    \"Sigmoid module with range `(low, high)`\"\n    def __init__(self, low:int, high:int): self.low,self.high = low,high\n    def forward(self, x): return sigmoid_range(x, self.low, self.high)\n\nclass PartialLayer(Module):\n    \"Layer that applies `partial(func, **kwargs)`.\"\n    def __init__(self, func, **kwargs): self.repr,self.func = f'{func}({kwargs})', partial(func, **kwargs)\n    def forward(self, x): return self.func(x)\n    def __repr__(self): return self.repr\n\nclass AdaptiveConcatPool2d(Module):\n    \"Layer that concats `AdaptiveAvgPool2d` and `AdaptiveMaxPool2d`.\"\n    def __init__(self, sz:Optional[int]=None):\n        \"Output will be 2*sz or 2 if sz is None\"\n        self.output_size = sz or 1\n        self.ap = nn.AdaptiveAvgPool2d(self.output_size)\n        self.mp = nn.AdaptiveMaxPool2d(self.output_size)\n\n    def forward(self, x): return torch.cat([self.mp(x), self.ap(x)], 1)\n\nclass Debugger(Module):\n    \"A module to debug inside a model.\"\n    def forward(self,x:Tensor) -> Tensor:\n        set_trace()\n        return x\n\ndef icnr(x, scale=2, init=nn.init.kaiming_normal_):\n    \"ICNR init of `x`, with `scale` and `init` function.\"\n    ni,nf,h,w = x.shape\n    ni2 = int(ni/(scale**2))\n    k = init(torch.zeros([ni2,nf,h,w])).transpose(0, 1)\n    k = k.contiguous().view(ni2, nf, -1)\n    k = k.repeat(1, 1, scale**2)\n    k = 
k.contiguous().view([nf,ni,h,w]).transpose(0, 1)\n    x.data.copy_(k)\n\nclass PixelShuffle_ICNR(Module):\n    \"Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle`, `icnr` init, and `weight_norm`.\"\n    def __init__(self, ni:int, nf:int=None, scale:int=2, blur:bool=False, norm_type=NormType.Weight, leaky:float=None):\n        nf = ifnone(nf, ni)\n        self.conv = conv_layer(ni, nf*(scale**2), ks=1, norm_type=norm_type, use_activ=False)\n        icnr(self.conv[0].weight)\n        self.shuf = nn.PixelShuffle(scale)\n        # Blurring over (h*w) kernel\n        # \"Super-Resolution using Convolutional Neural Networks without Any Checkerboard Artifacts\"\n        # - https://arxiv.org/abs/1806.02658\n        self.pad = nn.ReplicationPad2d((1,0,1,0))\n        self.blur = nn.AvgPool2d(2, stride=1)\n        self.relu = relu(True, leaky=leaky)\n\n    def forward(self,x):\n        x = self.shuf(self.relu(self.conv(x)))\n        return self.blur(self.pad(x)) if self.blur else x\n\nclass FlattenedLoss():\n    \"Same as `func`, but flattens input and target.\"\n    def __init__(self, func, *args, axis:int=-1, floatify:bool=False, is_2d:bool=True, **kwargs):\n        self.func,self.axis,self.floatify,self.is_2d = func(*args,**kwargs),axis,floatify,is_2d\n        functools.update_wrapper(self, self.func)\n\n    def __repr__(self): return f\"FlattenedLoss of {self.func}\"\n    @property\n    def reduction(self): return self.func.reduction\n    @reduction.setter\n    def reduction(self, v): self.func.reduction = v\n\n    def __call__(self, input:Tensor, target:Tensor, **kwargs)->Rank0Tensor:\n        input = input.transpose(self.axis,-1).contiguous()\n        target = target.transpose(self.axis,-1).contiguous()\n        if self.floatify: target = target.float()\n        input = input.view(-1,input.shape[-1]) if self.is_2d else input.view(-1)\n        return self.func.__call__(input, target.view(-1), **kwargs)\n\ndef CrossEntropyFlat(*args, 
axis:int=-1, **kwargs):\n    \"Same as `nn.CrossEntropyLoss`, but flattens input and target.\"\n    return FlattenedLoss(nn.CrossEntropyLoss, *args, axis=axis, **kwargs)\n\ndef BCEWithLogitsFlat(*args, axis:int=-1, floatify:bool=True, **kwargs):\n    \"Same as `nn.BCEWithLogitsLoss`, but flattens input and target.\"\n    return FlattenedLoss(nn.BCEWithLogitsLoss, *args, axis=axis, floatify=floatify, is_2d=False, **kwargs)\n\ndef BCEFlat(*args, axis:int=-1, floatify:bool=True, **kwargs):\n    \"Same as `nn.BCELoss`, but flattens input and target.\"\n    return FlattenedLoss(nn.BCELoss, *args, axis=axis, floatify=floatify, is_2d=False, **kwargs)\n\ndef MSELossFlat(*args, axis:int=-1, floatify:bool=True, **kwargs):\n    \"Same as `nn.MSELoss`, but flattens input and target.\"\n    return FlattenedLoss(nn.MSELoss, *args, axis=axis, floatify=floatify, is_2d=False, **kwargs)\n\nclass NoopLoss(Module):\n    \"Just returns the mean of the `output`.\"\n    def forward(self, output, *args): return output.mean()\n\nclass WassersteinLoss(Module):\n    \"For WGAN.\"\n    def forward(self, real, fake): return real.mean() - fake.mean()\n\ndef simple_cnn(actns:Collection[int], kernel_szs:Collection[int]=None,\n               strides:Collection[int]=None, bn=False) -> nn.Sequential:\n    \"CNN with `conv_layer` defined by `actns`, `kernel_szs` and `strides`, plus batchnorm if `bn`.\"\n    nl = len(actns)-1\n    kernel_szs = ifnone(kernel_szs, [3]*nl)\n    strides    = ifnone(strides   , [2]*nl)\n    layers = [conv_layer(actns[i], actns[i+1], kernel_szs[i], stride=strides[i],\n              norm_type=(NormType.Batch if bn and i<(len(strides)-1) else None)) for i in range_of(strides)]\n    layers.append(PoolFlatten())\n    return nn.Sequential(*layers)\n\ndef trunc_normal_(x:Tensor, mean:float=0., std:float=1.) 
-> Tensor:\n    \"Truncated normal initialization.\"\n    # From https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/12\n    return x.normal_().fmod_(2).mul_(std).add_(mean)\n\ndef embedding(ni:int,nf:int) -> nn.Module:\n    \"Create an embedding layer.\"\n    emb = nn.Embedding(ni, nf)\n    # See https://arxiv.org/abs/1711.09160\n    with torch.no_grad(): trunc_normal_(emb.weight, std=0.01)\n    return emb\n\nclass BatchNorm1dFlat(nn.BatchNorm1d):\n    \"`nn.BatchNorm1d`, but first flattens leading dimensions\"\n    def forward(self, x):\n        if x.dim()==2: return super().forward(x)\n        *f,l = x.shape\n        x = x.contiguous().view(-1,l)\n        return super().forward(x).view(*f,l)\n\nclass LabelSmoothingCrossEntropy(Module):\n    def __init__(self, eps:float=0.1, reduction='mean'): self.eps,self.reduction = eps,reduction\n\n    def forward(self, output, target):\n        c = output.size()[-1]\n        log_preds = F.log_softmax(output, dim=-1)\n        if self.reduction=='sum': loss = -log_preds.sum()\n        else:\n            loss = -log_preds.sum(dim=-1)\n            if self.reduction=='mean':  loss = loss.mean()\n        return loss*self.eps/c + (1-self.eps) * F.nll_loss(log_preds, target, reduction=self.reduction)\n"
  },
  {
    "path": "fastai/metrics.py",
    "content": "\"Implements various metrics to measure training accuracy\"\nfrom .torch_core import *\nfrom .callback import *\nfrom .layers import *\nfrom .basic_train import LearnerCallback\n\n__all__ = ['error_rate', 'accuracy', 'accuracy_thresh', 'dice', 'exp_rmspe', 'fbeta','FBeta', 'mse', 'mean_squared_error',\n            'mae', 'mean_absolute_error', 'rmse', 'root_mean_squared_error', 'msle', 'mean_squared_logarithmic_error',\n            'explained_variance', 'r2_score', 'top_k_accuracy', 'KappaScore', 'ConfusionMatrix', 'MatthewsCorreff',\n            'Precision', 'Recall', 'R2Score', 'ExplainedVariance', 'ExpRMSPE', 'RMSE', 'Perplexity', 'AUROC', 'auc_roc_score', \n            'roc_curve', 'MultiLabelFbeta', 'foreground_acc']\n\ndef fbeta(y_pred:Tensor, y_true:Tensor, thresh:float=0.2, beta:float=2, eps:float=1e-9, sigmoid:bool=True)->Rank0Tensor:\n    \"Computes the f_beta between `preds` and `targets`\"\n    beta2 = beta ** 2\n    if sigmoid: y_pred = y_pred.sigmoid()\n    y_pred = (y_pred>thresh).float()\n    y_true = y_true.float()\n    TP = (y_pred*y_true).sum(dim=1)\n    prec = TP/(y_pred.sum(dim=1)+eps)\n    rec = TP/(y_true.sum(dim=1)+eps)\n    res = (prec*rec)/(prec*beta2+rec+eps)*(1+beta2)\n    return res.mean()\n\ndef accuracy(input:Tensor, targs:Tensor)->Rank0Tensor:\n    \"Computes accuracy with `targs` when `input` is bs * n_classes.\"\n    n = targs.shape[0]\n    input = input.argmax(dim=-1).view(n,-1)\n    targs = targs.view(n,-1)\n    return (input==targs).float().mean()\n\ndef accuracy_thresh(y_pred:Tensor, y_true:Tensor, thresh:float=0.5, sigmoid:bool=True)->Rank0Tensor:\n    \"Computes accuracy when `y_pred` and `y_true` are the same size.\"\n    if sigmoid: y_pred = y_pred.sigmoid()\n    return ((y_pred>thresh)==y_true.byte()).float().mean()\n\ndef top_k_accuracy(input:Tensor, targs:Tensor, k:int=5)->Rank0Tensor:\n    \"Computes the Top-k accuracy (target is in the top k predictions).\"\n    input = input.topk(k=k, dim=-1)[1]\n    
targs = targs.unsqueeze(dim=-1).expand_as(input)\n    return (input == targs).max(dim=-1)[0].float().mean()\n\ndef foreground_acc(input, target, void_code):\n    \"Computes non-background accuracy, e.g. camvid for multiclass segmentation\"\n    target = target.squeeze(1)\n    mask = target != void_code\n    return (input.argmax(dim=1)[mask]==target[mask]).float().mean()\n\ndef error_rate(input:Tensor, targs:Tensor)->Rank0Tensor:\n    \"1 - `accuracy`\"\n    return 1 - accuracy(input, targs)\n\ndef dice(input:Tensor, targs:Tensor, iou:bool=False, eps:float=1e-8)->Rank0Tensor:\n    \"Dice coefficient metric for binary target. If iou=True, returns iou metric, classic for segmentation problems.\"\n    n = targs.shape[0]\n    input = input.argmax(dim=1).view(n,-1)\n    targs = targs.view(n,-1)\n    intersect = (input * targs).sum().float()\n    union = (input+targs).sum().float()\n    if not iou: return (2. * intersect / union if union > 0 else union.new([1.]).squeeze())\n    else: return (intersect / (union-intersect+eps) if union > 0 else union.new([1.]).squeeze())\n\ndef psnr(input:Tensor, targs:Tensor)->Rank0Tensor:\n    return 10 * (1. 
/ mean_squared_error(input, targs)).log10()\n\ndef exp_rmspe(pred:Tensor, targ:Tensor)->Rank0Tensor:\n    \"Exp RMSE between `pred` and `targ`.\"\n    pred,targ = flatten_check(pred,targ)\n    pred, targ = torch.exp(pred), torch.exp(targ)\n    pct_var = (targ - pred)/targ\n    return torch.sqrt((pct_var**2).mean())\n\ndef mean_absolute_error(pred:Tensor, targ:Tensor)->Rank0Tensor:\n    \"Mean absolute error between `pred` and `targ`.\"\n    pred,targ = flatten_check(pred,targ)\n    return torch.abs(targ - pred).mean()\n\ndef mean_squared_error(pred:Tensor, targ:Tensor)->Rank0Tensor:\n    \"Mean squared error between `pred` and `targ`.\"\n    pred,targ = flatten_check(pred,targ)\n    return F.mse_loss(pred, targ)\n\ndef root_mean_squared_error(pred:Tensor, targ:Tensor)->Rank0Tensor:\n    \"Root mean squared error between `pred` and `targ`.\"\n    pred,targ = flatten_check(pred,targ)\n    return torch.sqrt(F.mse_loss(pred, targ))\n\ndef mean_squared_logarithmic_error(pred:Tensor, targ:Tensor)->Rank0Tensor:\n    \"Mean squared logarithmic error between `pred` and `targ`.\"\n    pred,targ = flatten_check(pred,targ)\n    return F.mse_loss(torch.log(1 + pred), torch.log(1 + targ))\n\ndef explained_variance(pred:Tensor, targ:Tensor)->Rank0Tensor:\n    \"Explained variance between `pred` and `targ`.\"\n    pred,targ = flatten_check(pred,targ)\n    var_pct = torch.var(targ - pred) / torch.var(targ)\n    return 1 - var_pct\n\ndef r2_score(pred:Tensor, targ:Tensor)->Rank0Tensor:\n    \"R2 score (coefficient of determination) between `pred` and `targ`.\"\n    pred,targ = flatten_check(pred,targ)\n    u = torch.sum((targ - pred) ** 2)\n    d = torch.sum((targ - targ.mean()) ** 2)\n    return 1 - u / d\n\nclass RegMetrics(Callback):\n    \"Stores predictions and targets to perform calculations on epoch end.\"\n    def on_epoch_begin(self, **kwargs):\n        self.targs, self.preds = Tensor([]), Tensor([])\n\n    def on_batch_end(self, last_output:Tensor, last_target:Tensor, 
**kwargs):\n        assert last_output.numel() == last_target.numel(), \"Expected same numbers of elements in pred & targ\"\n        self.preds = torch.cat((self.preds, last_output.cpu()))\n        self.targs = torch.cat((self.targs, last_target.cpu()))\n\nclass R2Score(RegMetrics):\n    \"Computes the R2 score (coefficient of determination).\"\n    def on_epoch_end(self, last_metrics, **kwargs):\n        return add_metrics(last_metrics, r2_score(self.preds, self.targs))\n\nclass ExplainedVariance(RegMetrics):\n    \"Computes the explained variance.\"\n    def on_epoch_end(self, last_metrics, **kwargs):\n        return add_metrics(last_metrics, explained_variance(self.preds, self.targs))\n\nclass RMSE(RegMetrics):\n    \"Computes the root mean squared error.\"\n    def on_epoch_end(self, last_metrics, **kwargs):\n        return add_metrics(last_metrics, root_mean_squared_error(self.preds, self.targs))\n\nclass ExpRMSPE(RegMetrics):\n    \"Computes the exponential of the root mean square error.\"\n    def on_epoch_end(self, last_metrics, **kwargs):\n        return add_metrics(last_metrics, exp_rmspe(self.preds, self.targs))\n\n# Aliases\nmse = mean_squared_error\nmae = mean_absolute_error\nmsle = mean_squared_logarithmic_error\nrmse = root_mean_squared_error\n\nclass ConfusionMatrix(Callback):\n    \"Computes the confusion matrix.\"\n\n    def on_train_begin(self, **kwargs):\n        self.n_classes = 0\n\n    def on_epoch_begin(self, **kwargs):\n        self.cm = None\n\n    def on_batch_end(self, last_output:Tensor, last_target:Tensor, **kwargs):\n        preds = last_output.argmax(-1).view(-1).cpu()\n        targs = last_target.cpu()\n        if self.n_classes == 0:\n            self.n_classes = last_output.shape[-1]\n            self.x = torch.arange(0, self.n_classes)\n        cm = ((preds==self.x[:, None]) & (targs==self.x[:, None, None])).sum(dim=2, dtype=torch.float32)\n        if self.cm is None: self.cm =  cm\n        else:               self.cm += cm\n\n   
 def on_epoch_end(self, **kwargs):\n        self.metric = self.cm\n\n@dataclass\nclass CMScores(ConfusionMatrix):\n    \"Base class for metrics which rely on the calculation of the precision and/or recall score.\"\n    average:Optional[str]=\"binary\"      # `binary`, `micro`, `macro`, `weigthed` or None\n    pos_label:int=1                     # 0 or 1\n    eps:float=1e-9\n\n    def _recall(self):\n        rec = torch.diag(self.cm) / self.cm.sum(dim=1)\n        if self.average is None: return rec\n        else:\n            if self.average == \"micro\": weights = self._weights(avg=\"weighted\")\n            else: weights = self._weights(avg=self.average)\n            return (rec * weights).sum()\n\n    def _precision(self):\n        prec = torch.diag(self.cm) / self.cm.sum(dim=0)\n        if self.average is None: return prec\n        else:\n            weights = self._weights(avg=self.average)\n            return (prec * weights).sum()\n\n    def _weights(self, avg:str):\n        if self.n_classes != 2 and avg == \"binary\":\n            avg = self.average = \"macro\"\n            warn(\"average=`binary` was selected for a non binary case. Value for average has now been set to `macro` instead.\")\n        if avg == \"binary\":\n            if self.pos_label not in (0, 1):\n                self.pos_label = 1\n                warn(\"Invalid value for pos_label. 
It has now been set to 1.\")\n            if self.pos_label == 1: return Tensor([0,1])\n            else: return Tensor([1,0])\n        elif avg == \"micro\": return self.cm.sum(dim=0) / self.cm.sum()\n        elif avg == \"macro\": return torch.ones((self.n_classes,)) / self.n_classes\n        elif avg == \"weighted\": return self.cm.sum(dim=1) / self.cm.sum()\n\n\nclass Recall(CMScores):\n    \"Computes the Recall.\"\n    def on_epoch_end(self, last_metrics, **kwargs): \n        return add_metrics(last_metrics, self._recall())\n\nclass Precision(CMScores):\n    \"Computes the Precision.\"\n    def on_epoch_end(self, last_metrics, **kwargs): \n        return add_metrics(last_metrics, self._precision())\n\n@dataclass\nclass FBeta(CMScores):\n    \"Computes the F`beta` score.\"\n    beta:float=2\n\n    def on_train_begin(self, **kwargs):\n        self.n_classes = 0\n        self.beta2 = self.beta ** 2\n        self.avg = self.average\n        if self.average != \"micro\": self.average = None\n\n    def on_epoch_end(self, last_metrics, **kwargs):\n        prec = self._precision()\n        rec = self._recall()\n        metric = (1 + self.beta2) * prec * rec / (prec * self.beta2 + rec + self.eps)\n        metric[metric != metric] = 0  # removing potential \"nan\"s\n        if self.avg: metric = (self._weights(avg=self.avg) * metric).sum()\n        return add_metrics(last_metrics, metric)\n\n    def on_train_end(self, **kwargs): self.average = self.avg\n\n@dataclass\nclass KappaScore(ConfusionMatrix):\n    \"Computes the rate of agreement (Cohens Kappa).\"\n    weights:Optional[str]=None      # None, `linear`, or `quadratic`\n\n    def on_epoch_end(self, last_metrics, **kwargs):\n        sum0 = self.cm.sum(dim=0)\n        sum1 = self.cm.sum(dim=1)\n        expected = torch.einsum('i,j->ij', (sum0, sum1)) / sum0.sum()\n        if self.weights is None:\n            w = torch.ones((self.n_classes, self.n_classes))\n            w[self.x, self.x] = 0\n        elif 
self.weights == \"linear\" or self.weights == \"quadratic\":\n            w = torch.zeros((self.n_classes, self.n_classes))\n            w += torch.arange(self.n_classes, dtype=torch.float)\n            w = torch.abs(w - torch.t(w)) if self.weights == \"linear\" else (w - torch.t(w)) ** 2\n        else: raise ValueError('Unknown weights. Expected None, \"linear\", or \"quadratic\".')\n        k = torch.sum(w * self.cm) / torch.sum(w * expected)\n        return add_metrics(last_metrics, 1-k)\n\n@dataclass\nclass MatthewsCorreff(ConfusionMatrix):\n    \"Computes the Matthews correlation coefficient.\"\n    def on_epoch_end(self, last_metrics, **kwargs):\n        t_sum = self.cm.sum(dim=1)\n        p_sum = self.cm.sum(dim=0)\n        n_correct = torch.trace(self.cm)\n        n_samples = p_sum.sum()\n        cov_ytyp = n_correct * n_samples - torch.dot(t_sum, p_sum)\n        cov_ypyp = n_samples ** 2 - torch.dot(p_sum, p_sum)\n        cov_ytyt = n_samples ** 2 - torch.dot(t_sum, t_sum)\n        return add_metrics(last_metrics, cov_ytyp / torch.sqrt(cov_ytyt * cov_ypyp))\n\nclass Perplexity(Callback):\n    \"Perplexity metric for language models.\"\n    def on_epoch_begin(self, **kwargs): self.loss,self.len = 0.,0\n\n    def on_batch_end(self, last_output, last_target, **kwargs):\n        self.loss += last_target.size(1) * CrossEntropyFlat()(last_output, last_target)\n        self.len += last_target.size(1)\n\n    def on_epoch_end(self, last_metrics, **kwargs): \n        return add_metrics(last_metrics, torch.exp(self.loss / self.len))\n\ndef auc_roc_score(input:Tensor, targ:Tensor):\n    \"Computes the area under the receiver operator characteristic (ROC) curve using the trapezoid method. 
Restricted to binary classification tasks.\"\n    fpr, tpr = roc_curve(input, targ)\n    d = fpr[1:] - fpr[:-1]\n    sl1, sl2 = [slice(None)], [slice(None)]\n    sl1[-1], sl2[-1] = slice(1, None), slice(None, -1)\n    return (d * (tpr[tuple(sl1)] + tpr[tuple(sl2)]) / 2.).sum(-1)\n\ndef roc_curve(input:Tensor, targ:Tensor):\n    \"Computes the receiver operating characteristic (ROC) curve by determining the true positive ratio (TPR) and false positive ratio (FPR) for various classification thresholds. Restricted to binary classification tasks.\"\n    targ = (targ == 1)\n    desc_score_indices = torch.flip(input.argsort(-1), [-1])\n    input = input[desc_score_indices]\n    targ = targ[desc_score_indices]\n    d = input[1:] - input[:-1]\n    distinct_value_indices = torch.nonzero(d).transpose(0,1)[0]\n    threshold_idxs = torch.cat((distinct_value_indices, LongTensor([len(targ) - 1]).to(targ.device)))\n    tps = torch.cumsum(targ * 1, dim=-1)[threshold_idxs]\n    fps = (1 + threshold_idxs - tps)\n    if tps[0] != 0 or fps[0] != 0:\n        fps = torch.cat((LongTensor([0]), fps))\n        tps = torch.cat((LongTensor([0]), tps))\n    fpr, tpr = fps.float() / fps[-1], tps.float() / tps[-1]\n    return fpr, tpr\n\n@dataclass\nclass AUROC(Callback):\n    \"Computes the area under the curve (AUC) score based on the receiver operating characteristic (ROC) curve. 
Restricted to binary classification tasks.\"\n    def on_epoch_begin(self, **kwargs):\n        self.targs, self.preds = LongTensor([]), Tensor([])\n        \n    def on_batch_end(self, last_output:Tensor, last_target:Tensor, **kwargs):\n        last_output = F.softmax(last_output, dim=1)[:,-1]\n        self.preds = torch.cat((self.preds, last_output.cpu()))\n        self.targs = torch.cat((self.targs, last_target.cpu().long()))\n    \n    def on_epoch_end(self, last_metrics, **kwargs):\n        return add_metrics(last_metrics, auc_roc_score(self.preds, self.targs))\n\nclass MultiLabelFbeta(LearnerCallback):\n    \"Computes the fbeta score for multilabel classification\"\n    # https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html\n    _order = -20 \n    def __init__(self, learn, beta=2, eps=1e-15, thresh=0.3, sigmoid=True, average=\"micro\"):\n        super().__init__(learn)\n        self.eps, self.thresh, self.sigmoid, self.average, self.beta2 = \\\n            eps, thresh, sigmoid, average, beta**2\n\n    def on_train_begin(self, **kwargs):\n        self.c = self.learn.data.c\n        if self.average != \"none\": self.learn.recorder.add_metric_names([f'{self.average}_fbeta'])\n        else: self.learn.recorder.add_metric_names([f\"fbeta_{c}\" for c in self.learn.data.classes])\n\n    def on_epoch_begin(self, **kwargs):\n        dvc = self.learn.data.device\n        self.tp = torch.zeros(self.c).to(dvc)\n        self.total_pred = torch.zeros(self.c).to(dvc)\n        self.total_targ = torch.zeros(self.c).to(dvc)\n    \n    def on_batch_end(self, last_output, last_target, **kwargs):\n        pred, targ = (last_output.sigmoid() if self.sigmoid else last_output) > self.thresh, last_target.byte()\n        m = pred*targ\n        self.tp += m.sum(0).float()\n        self.total_pred += pred.sum(0).float()\n        self.total_targ += targ.sum(0).float()\n    \n    def fbeta_score(self, precision, recall):\n        return (1 + 
self.beta2)*(precision*recall)/((self.beta2*precision + recall) + self.eps)\n\n    def on_epoch_end(self, last_metrics, **kwargs):\n        self.total_pred += self.eps\n        self.total_targ += self.eps\n        if self.average == \"micro\":\n            precision, recall = self.tp.sum() / self.total_pred.sum(), self.tp.sum() / self.total_targ.sum()\n            res = self.fbeta_score(precision, recall)\n        elif self.average == \"macro\":\n            res = self.fbeta_score((self.tp / self.total_pred), (self.tp / self.total_targ)).mean()\n        elif self.average == \"weighted\":\n            scores = self.fbeta_score((self.tp / self.total_pred), (self.tp / self.total_targ))\n            res = (scores*self.total_targ).sum() / self.total_targ.sum()\n        elif self.average == \"none\":\n            res = listify(self.fbeta_score((self.tp / self.total_pred), (self.tp / self.total_targ)))\n        else:\n            raise Exception(\"Choose one of the average types: [micro, macro, weighted, none]\")\n        \n        return add_metrics(last_metrics, res)\n"
  },
  {
    "path": "fastai/script.py",
    "content": "import os, sys, subprocess, inspect\nfrom dataclasses import dataclass\nfrom typing import Any\nfrom argparse import ArgumentParser\n\n\n@dataclass\nclass Param():\n    \"A parameter in a function used in `anno_parser` or `call_parse`\"\n    help:str=None\n    type:type=None\n    opt:bool=True\n    action:str=None\n    nargs:str=None\n    const:str=None\n    choices:str=None\n    required:bool=None\n\n    @property\n    def pre(self): return '--' if self.opt else ''\n    @property\n    def kwargs(self): return {k:v for k,v in self.__dict__.items()\n                              if v is not None and k!='opt'}\n\ndef anno_parser(func):\n    \"Look at params (annotated with `Param`) in func and return an `ArgumentParser`\"\n    p = ArgumentParser(description=func.__doc__)\n    for k,v in inspect.signature(func).parameters.items():\n        param = func.__annotations__.get(k, Param())\n        kwargs = param.kwargs\n        if v.default != inspect.Parameter.empty: kwargs['default'] = v.default\n        p.add_argument(f\"{param.pre}{k}\", **kwargs)\n    return p\n\ndef call_parse(func):\n    \"Decorator to create a simple CLI from `func` using `anno_parser`\"\n    name = inspect.currentframe().f_back.f_globals['__name__']\n    if name == \"__main__\":\n        args = anno_parser(func).parse_args()\n        func(**args.__dict__)\n    else: return func\n\ndef call_plac(f):\n    \"Decorator to create a simple CLI from `func` using `plac`\"\n    name = inspect.currentframe().f_back.f_globals['__name__']\n    if name == '__main__':\n        import plac\n        res = plac.call(f)\n        if callable(res): res()\n    else: return f\n\n"
  },
  {
    "path": "fastai/sixel.py",
    "content": "from .core import *\n\nlibsixel = try_import('libsixel')\n\ndef _sixel_encode(data, width, height):\n    s = io.BytesIO()\n    output = libsixel.sixel_output_new(lambda data, s: s.write(data), s)\n    dither = libsixel.sixel_dither_new(256)\n    w,h = int(width),int(height)\n    libsixel.sixel_dither_initialize(dither, data, w, h, libsixel.SIXEL_PIXELFORMAT_RGBA8888)\n    libsixel.sixel_encode(data, w, h, 1, dither, output)\n    return s.getvalue().decode('ascii')\n\ndef plot_sixel(fig=None):\n    if not libsixel:\n        warn(\"You could see this plot with `libsixel`. See https://github.com/saitoha/libsixel\")\n        return\n    if fig is None: fig = plt.gcf()\n    fig.canvas.draw()\n    dpi = fig.get_dpi()\n    res = _sixel_encode(fig.canvas.buffer_rgba(), fig.get_figwidth()* dpi, fig.get_figheight() * dpi)\n    print(res)\n\n"
  },
  {
    "path": "fastai/tabular/__init__.py",
    "content": "from .. import basics\nfrom ..basics import *\nfrom .data import *\nfrom .transform import *\nfrom .models import *\nfrom .. import tabular\n\n__all__ = [*basics.__all__, *data.__all__, *transform.__all__, *models.__all__, 'tabular']\n\n"
  },
  {
    "path": "fastai/tabular/data.py",
    "content": "\"Data loading pipeline for structured data support. Loads from pandas DataFrame\"\nfrom ..torch_core import *\nfrom .transform import *\nfrom ..basic_data import *\nfrom ..data_block import *\nfrom ..basic_train import *\nfrom .models import *\nfrom pandas.api.types import is_numeric_dtype, is_categorical_dtype\n\n__all__ = ['TabularDataBunch', 'TabularLine', 'TabularList', 'TabularProcessor', 'tabular_learner']\n\nOptTabTfms = Optional[Collection[TabularProc]]\n\n#def emb_sz_rule(n_cat:int)->int: return min(50, (n_cat//2)+1)\ndef emb_sz_rule(n_cat:int)->int: return min(600, round(1.6 * n_cat**0.56))\n\ndef def_emb_sz(classes, n, sz_dict=None):\n    \"Pick an embedding size for `n` depending on `classes` if not given in `sz_dict`.\"\n    sz_dict = ifnone(sz_dict, {})\n    n_cat = len(classes[n])\n    sz = sz_dict.get(n, int(emb_sz_rule(n_cat)))  # rule of thumb\n    return n_cat,sz\n\nclass TabularLine(ItemBase):\n    \"Basic item for tabular data.\"\n    def __init__(self, cats, conts, classes, names):\n        self.cats,self.conts,self.classes,self.names = cats,conts,classes,names\n        self.data = [tensor(cats), tensor(conts)]\n\n    def __str__(self):\n        res = ''\n        for c, n in zip(self.cats, self.names[:len(self.cats)]):\n            res += f\"{n} {(self.classes[n][c])}; \"\n        for c,n in zip(self.conts, self.names[len(self.cats):]):\n            res += f'{n} {c:.4f}; '\n        return res\n\nclass TabularProcessor(PreProcessor):\n    \"Regroup the `procs` in one `PreProcessor`.\"\n    def __init__(self, ds:ItemBase=None, procs=None):\n        procs = ifnone(procs, ds.procs if ds is not None else None)\n        self.procs = listify(procs)\n\n    def process_one(self, item):\n        df = pd.DataFrame([item,item])\n        for proc in self.procs: proc(df, test=True)\n        if len(self.cat_names) != 0:\n            codes = np.stack([c.cat.codes.values for n,c in df[self.cat_names].items()], 1).astype(np.int64) + 1\n        
else: codes = [[]]\n        if len(self.cont_names) != 0:\n            conts = np.stack([c.astype('float32').values for n,c in df[self.cont_names].items()], 1)\n        else: conts = [[]]\n        classes = None\n        col_names = list(df[self.cat_names].columns.values) + list(df[self.cont_names].columns.values)\n        return TabularLine(codes[0], conts[0], classes, col_names)\n\n    def process(self, ds):\n        if ds.inner_df is None:\n            ds.classes,ds.cat_names,ds.cont_names = self.classes,self.cat_names,self.cont_names\n            ds.col_names = self.cat_names + self.cont_names\n            ds.preprocessed = True\n            return\n        for i,proc in enumerate(self.procs):\n            if isinstance(proc, TabularProc): proc(ds.inner_df, test=True)\n            else:\n                #cat and cont names may have been changed by transform (like Fill_NA)\n                proc = proc(ds.cat_names, ds.cont_names)\n                proc(ds.inner_df)\n                ds.cat_names,ds.cont_names = proc.cat_names,proc.cont_names\n                self.procs[i] = proc\n        self.cat_names,self.cont_names = ds.cat_names,ds.cont_names\n        if len(ds.cat_names) != 0:\n            ds.codes = np.stack([c.cat.codes.values for n,c in ds.inner_df[ds.cat_names].items()], 1).astype(np.int64) + 1\n            self.classes = ds.classes = OrderedDict({n:np.concatenate([['#na#'],c.cat.categories.values])\n                                      for n,c in ds.inner_df[ds.cat_names].items()})\n            cat_cols = list(ds.inner_df[ds.cat_names].columns.values)\n        else: ds.codes,ds.classes,self.classes,cat_cols = None,None,None,[]\n        if len(ds.cont_names) != 0:\n            ds.conts = np.stack([c.astype('float32').values for n,c in ds.inner_df[ds.cont_names].items()], 1)\n            cont_cols = list(ds.inner_df[ds.cont_names].columns.values)\n        else: ds.conts,cont_cols = None,[]\n        ds.col_names = cat_cols + cont_cols\n        
ds.preprocessed = True\n\nclass TabularDataBunch(DataBunch):\n    \"Create a `DataBunch` suitable for tabular data.\"\n    @classmethod\n    def from_df(cls, path, df:DataFrame, dep_var:str, valid_idx:Collection[int], procs:OptTabTfms=None,\n                cat_names:OptStrList=None, cont_names:OptStrList=None, classes:Collection=None, \n                test_df=None, bs:int=64, val_bs:int=None, num_workers:int=defaults.cpus, dl_tfms:Optional[Collection[Callable]]=None, \n                device:torch.device=None, collate_fn:Callable=data_collate, no_check:bool=False)->DataBunch:\n        \"Create a `DataBunch` from `df` and `valid_idx` with `dep_var`. `kwargs` are passed to `DataBunch.create`.\"\n        cat_names = ifnone(cat_names, []).copy()\n        cont_names = ifnone(cont_names, list(set(df)-set(cat_names)-{dep_var}))\n        procs = listify(procs)\n        src = (TabularList.from_df(df, path=path, cat_names=cat_names, cont_names=cont_names, procs=procs)\n                           .split_by_idx(valid_idx))\n        src = src.label_from_df(cols=dep_var) if classes is None else src.label_from_df(cols=dep_var, classes=classes)\n        if test_df is not None: src.add_test(TabularList.from_df(test_df, cat_names=cat_names, cont_names=cont_names,\n                                                                 processor = src.train.x.processor))\n        return src.databunch(path=path, bs=bs, val_bs=val_bs, num_workers=num_workers, device=device, \n                             collate_fn=collate_fn, no_check=no_check)\n\nclass TabularList(ItemList):\n    \"Basic `ItemList` for tabular data.\"\n    _item_cls=TabularLine\n    _processor=TabularProcessor\n    _bunch=TabularDataBunch\n    def __init__(self, items:Iterator, cat_names:OptStrList=None, cont_names:OptStrList=None,\n                 procs=None, **kwargs)->'TabularList':\n        super().__init__(range_of(items), **kwargs)\n        #dataframe is in inner_df, items is just a range of index\n        if 
cat_names is None:  cat_names = []\n        if cont_names is None: cont_names = []\n        self.cat_names,self.cont_names,self.procs = cat_names,cont_names,procs\n        self.copy_new += ['cat_names', 'cont_names', 'procs']\n        self.preprocessed = False\n\n    @classmethod\n    def from_df(cls, df:DataFrame, cat_names:OptStrList=None, cont_names:OptStrList=None, procs=None, **kwargs)->'ItemList':\n        \"Get the list of inputs in the `col` of `path/csv_name`.\"\n        return cls(items=range(len(df)), cat_names=cat_names, cont_names=cont_names, procs=procs, inner_df=df.copy(), **kwargs)\n\n    def get(self, o):\n        if not self.preprocessed: return self.inner_df.iloc[o] if hasattr(self, 'inner_df') else self.items[o]\n        codes = [] if self.codes is None else self.codes[o]\n        conts = [] if self.conts is None else self.conts[o]\n        return self._item_cls(codes, conts, self.classes, self.col_names)\n\n    def get_emb_szs(self, sz_dict=None):\n        \"Return the default embedding sizes suitable for this data or takes the ones in `sz_dict`.\"\n        return [def_emb_sz(self.classes, n, sz_dict) for n in self.cat_names]\n\n    def reconstruct(self, t:Tensor):\n        return self._item_cls(t[0], t[1], self.classes, self.col_names)\n\n    def show_xys(self, xs, ys)->None:\n        \"Show the `xs` (inputs) and `ys` (targets).\"\n        from IPython.display import display, HTML\n        items,names = [], xs[0].names + ['target']\n        for i, (x,y) in enumerate(zip(xs,ys)):\n            res = []\n            cats = x.cats if len(x.cats.size()) > 0 else []\n            conts = x.conts if len(x.conts.size()) > 0 else []\n            for c, n in zip(cats, x.names[:len(cats)]):\n                res.append(x.classes[n][c])\n            res += [f'{c:.4f}' for c in conts] + [y]\n            items.append(res)\n        items = np.array(items)\n        df = pd.DataFrame({n:items[:,i] for i,n in enumerate(names)}, columns=names)\n        with 
pd.option_context('display.max_colwidth', -1):\n            display(HTML(df.to_html(index=False)))\n\n    def show_xyzs(self, xs, ys, zs):\n        \"Show `xs` (inputs), `ys` (targets) and `zs` (predictions).\"\n        from IPython.display import display, HTML\n        items,names = [], xs[0].names + ['target', 'prediction']\n        for i, (x,y,z) in enumerate(zip(xs,ys,zs)):\n            res = []\n            cats = x.cats if len(x.cats.size()) > 0 else []\n            conts = x.conts if len(x.conts.size()) > 0 else []\n            for c, n in zip(cats, x.names[:len(cats)]):\n                res.append(str(x.classes[n][c]))\n            res += [f'{c:.4f}' for c in conts] + [y, z]\n            items.append(res)\n        items = np.array(items)\n        df = pd.DataFrame({n:items[:,i] for i,n in enumerate(names)}, columns=names)\n        with pd.option_context('display.max_colwidth', -1):\n            display(HTML(df.to_html(index=False)))\n\ndef tabular_learner(data:DataBunch, layers:Collection[int], emb_szs:Dict[str,int]=None, metrics=None,\n        ps:Collection[float]=None, emb_drop:float=0., y_range:OptRange=None, use_bn:bool=True, **learn_kwargs):\n    \"Get a `Learner` using `data`, with `metrics`, including a `TabularModel` created using the remaining params.\"\n    emb_szs = data.get_emb_szs(ifnone(emb_szs, {}))\n    model = TabularModel(emb_szs, len(data.cont_names), out_sz=data.c, layers=layers, ps=ps, emb_drop=emb_drop,\n                         y_range=y_range, use_bn=use_bn)\n    return Learner(data, model, metrics=metrics, **learn_kwargs)\n\n"
  },
  {
    "path": "fastai/tabular/models.py",
    "content": "from ..torch_core import *\nfrom ..layers import *\nfrom ..basic_data import *\nfrom ..basic_train import *\nfrom ..train import ClassificationInterpretation\n\n__all__ = ['TabularModel']\n\nclass TabularModel(Module):\n    \"Basic model for tabular data.\"\n    def __init__(self, emb_szs:ListSizes, n_cont:int, out_sz:int, layers:Collection[int], ps:Collection[float]=None,\n                 emb_drop:float=0., y_range:OptRange=None, use_bn:bool=True, bn_final:bool=False):\n        super().__init__()\n        ps = ifnone(ps, [0]*len(layers))\n        ps = listify(ps, layers)\n        self.embeds = nn.ModuleList([embedding(ni, nf) for ni,nf in emb_szs])\n        self.emb_drop = nn.Dropout(emb_drop)\n        self.bn_cont = nn.BatchNorm1d(n_cont)\n        n_emb = sum(e.embedding_dim for e in self.embeds)\n        self.n_emb,self.n_cont,self.y_range = n_emb,n_cont,y_range\n        sizes = self.get_sizes(layers, out_sz)\n        actns = [nn.ReLU(inplace=True) for _ in range(len(sizes)-2)] + [None]\n        layers = []\n        for i,(n_in,n_out,dp,act) in enumerate(zip(sizes[:-1],sizes[1:],[0.]+ps,actns)):\n            layers += bn_drop_lin(n_in, n_out, bn=use_bn and i!=0, p=dp, actn=act)\n        if bn_final: layers.append(nn.BatchNorm1d(sizes[-1]))\n        self.layers = nn.Sequential(*layers)\n\n    def get_sizes(self, layers, out_sz):\n        return [self.n_emb + self.n_cont] + layers + [out_sz]\n\n    def forward(self, x_cat:Tensor, x_cont:Tensor) -> Tensor:\n        if self.n_emb != 0:\n            x = [e(x_cat[:,i]) for i,e in enumerate(self.embeds)]\n            x = torch.cat(x, 1)\n            x = self.emb_drop(x)\n        if self.n_cont != 0:\n            x_cont = self.bn_cont(x_cont)\n            x = torch.cat([x, x_cont], 1) if self.n_emb != 0 else x_cont\n        x = self.layers(x)\n        if self.y_range is not None:\n            x = (self.y_range[1]-self.y_range[0]) * torch.sigmoid(x) + self.y_range[0]\n        return 
x\n\n@classmethod\ndef _cl_int_from_learner(cls, learn:Learner, ds_type=DatasetType.Valid, activ:nn.Module=None):\n    \"Creates an instance of 'ClassificationInterpretation\"\n    preds = learn.get_preds(ds_type=ds_type, activ=activ, with_loss=True)\n    return cls(learn, *preds, ds_type=ds_type)\n\ndef _cl_int_plot_top_losses(self, k, largest:bool=True, return_table:bool=False)->Optional[plt.Figure]:\n    \"Generates a dataframe of 'top_losses' along with their prediction, actual, loss, and probability of the actual class.\"\n    tl_val, tl_idx = self.top_losses(k, largest)\n    classes = self.data.classes\n    cat_names = self.data.x.cat_names\n    cont_names = self.data.x.cont_names\n    df = pd.DataFrame(columns=[['Prediction', 'Actual', 'Loss', 'Probability'] + cat_names + cont_names])\n    for i, idx in enumerate(tl_idx):\n        da, cl = self.data.dl(self.ds_type).dataset[idx]\n        cl = int(cl)\n        t1 = str(da)\n        t1 = t1.split(';')\n        arr = []\n        arr.extend([classes[self.pred_class[idx]], classes[cl], f'{self.losses[idx]:.2f}',\n                    f'{self.preds[idx][cl]:.2f}'])\n        for x in range(len(t1)-1):\n            _, value = t1[x].rsplit(' ', 1)\n            arr.append(value)\n        df.loc[i] = arr\n    display(df)\n    return_fig = return_table\n    if ifnone(return_fig, defaults.return_fig): return df\n\n\nClassificationInterpretation.from_learner = _cl_int_from_learner\nClassificationInterpretation.plot_top_losses = _cl_int_plot_top_losses\n\ndef _learner_interpret(learn:Learner, ds_type:DatasetType = DatasetType.Valid):\n    \"Create a 'ClassificationInterpretation' object from 'learner' on 'ds_type'.\"\n    return ClassificationInterpretation.from_learner(learn, ds_type=ds_type)\n\nLearner.interpret = _learner_interpret\n"
  },
  {
    "path": "fastai/tabular/transform.py",
    "content": "\"Cleaning and feature engineering functions for structured data\"\nfrom ..torch_core import *\nfrom pandas.api.types import is_numeric_dtype\nfrom datetime import date, datetime\nimport calendar\n\n__all__ = ['add_datepart', 'cont_cat_split', 'Categorify', 'FillMissing', 'FillStrategy', 'Normalize', 'TabularProc',\n           'add_elapsed_times', 'make_date', 'add_cyclic_datepart']\n\ndef make_date(df:DataFrame, date_field:str):\n    \"Make sure `df[field_name]` is of the right date type.\"\n    field_dtype = df[date_field].dtype\n    if isinstance(field_dtype, pd.core.dtypes.dtypes.DatetimeTZDtype):\n        field_dtype = np.datetime64\n    if not np.issubdtype(field_dtype, np.datetime64):\n        df[date_field] = pd.to_datetime(df[date_field], infer_datetime_format=True)\n\ndef cyclic_dt_feat_names(time:bool=True, add_linear:bool=False)->List[str]:\n    \"Return feature names of date/time cycles as produced by `cyclic_dt_features`.\"\n    fs = ['cos','sin']\n    attr = [f'{r}_{f}' for r in 'weekday day_month month_year day_year'.split() for f in fs]\n    if time: attr += [f'{r}_{f}' for r in 'hour clock min sec'.split() for f in fs]\n    if add_linear: attr.append('year_lin')\n    return attr\n\ndef cyclic_dt_features(d:Union[date,datetime], time:bool=True, add_linear:bool=False)->List[float]:\n    \"Calculate the cos and sin of date/time cycles.\"\n    tt,fs = d.timetuple(), [np.cos, np.sin]\n    day_year,days_month = tt.tm_yday, calendar.monthrange(d.year, d.month)[1]\n    days_year = 366 if calendar.isleap(d.year) else 365\n    rs = d.weekday()/7, (d.day-1)/days_month, (d.month-1)/12, (day_year-1)/days_year\n    feats = [f(r * 2 * np.pi) for r in rs for f in fs]\n    if time and isinstance(d, datetime) and type(d) != date:\n        rs = tt.tm_hour/24, tt.tm_hour%12/12, tt.tm_min/60, tt.tm_sec/60\n        feats += [f(r * 2 * np.pi) for r in rs for f in fs]\n    if add_linear:\n        if type(d) == date: feats.append(d.year + rs[-1])\n        
else:\n            secs_in_year = (datetime(d.year+1, 1, 1) - datetime(d.year, 1, 1)).total_seconds()\n            feats.append(d.year + ((d - datetime(d.year, 1, 1)).total_seconds() / secs_in_year))\n    return feats\n\ndef add_cyclic_datepart(df:DataFrame, field_name:str, prefix:str=None, drop:bool=True, time:bool=False, add_linear:bool=False):\n    \"Helper function that adds trigonometric date/time features to a date in the column `field_name` of `df`.\"\n    make_date(df, field_name)\n    field = df[field_name]\n    prefix = ifnone(prefix, re.sub('[Dd]ate$', '', field_name))\n    series = field.apply(partial(cyclic_dt_features, time=time, add_linear=add_linear))\n    columns = [prefix + c for c in cyclic_dt_feat_names(time, add_linear)]\n    df_feats = pd.DataFrame([item for item in series], columns=columns, index=series.index)\n    for column in columns: df[column] = df_feats[column]\n    if drop: df.drop(field_name, axis=1, inplace=True)\n    return df\n\ndef add_datepart(df:DataFrame, field_name:str, prefix:str=None, drop:bool=True, time:bool=False):\n    \"Helper function that adds columns relevant to a date in the column `field_name` of `df`.\"\n    make_date(df, field_name)\n    field = df[field_name]\n    prefix = ifnone(prefix, re.sub('[Dd]ate$', '', field_name))\n    attr = ['Year', 'Month', 'Week', 'Day', 'Dayofweek', 'Dayofyear', 'Is_month_end', 'Is_month_start', \n            'Is_quarter_end', 'Is_quarter_start', 'Is_year_end', 'Is_year_start']\n    if time: attr = attr + ['Hour', 'Minute', 'Second']\n    for n in attr: df[prefix + n] = getattr(field.dt, n.lower())\n    df[prefix + 'Elapsed'] = field.astype(np.int64) // 10 ** 9\n    if drop: df.drop(field_name, axis=1, inplace=True)\n    return df\n\ndef _get_elapsed(df:DataFrame,field_names:Collection[str], date_field:str, base_field:str, prefix:str):\n    for f in field_names:\n        day1 = np.timedelta64(1, 'D')\n        last_date,last_base,res = np.datetime64(),None,[]\n        for b,v,d in 
zip(df[base_field].values, df[f].values, df[date_field].values):\n            if last_base is None or b != last_base:\n                last_date,last_base = np.datetime64(),b\n            if v: last_date = d\n            res.append(((d-last_date).astype('timedelta64[D]') / day1))\n        df[prefix + f] = res\n    return df\n\ndef add_elapsed_times(df:DataFrame, field_names:Collection[str], date_field:str, base_field:str):\n    field_names = listify(field_names)\n    #Make sure date_field is a date and base_field a bool\n    df[field_names] = df[field_names].astype('bool')\n    make_date(df, date_field)\n    \n    work_df = df[field_names + [date_field, base_field]]\n    work_df = work_df.sort_values([base_field, date_field])\n    work_df = _get_elapsed(work_df, field_names, date_field, base_field, 'After')\n    work_df = work_df.sort_values([base_field, date_field], ascending=[True, False])\n    work_df = _get_elapsed(work_df, field_names, date_field, base_field, 'Before')\n    \n    for a in ['After' + f for f in field_names] + ['Before' + f for f in field_names]:\n        work_df[a] = work_df[a].fillna(0).astype(int)  \n    \n    for a,s in zip([True, False], ['_bw', '_fw']):\n        work_df = work_df.set_index(date_field)\n        tmp = (work_df[[base_field] + field_names].sort_index(ascending=a)\n                      .groupby(base_field).rolling(7, min_periods=1).sum())\n        tmp.drop(base_field,1,inplace=True)\n        tmp.reset_index(inplace=True)\n        work_df.reset_index(inplace=True)\n        work_df = work_df.merge(tmp, 'left', [date_field, base_field], suffixes=['', s])\n    work_df.drop(field_names,1,inplace=True)\n    return df.merge(work_df, 'left', [date_field, base_field])\n\ndef cont_cat_split(df, max_card=20, dep_var=None)->Tuple[List,List]:\n    \"Helper function that returns column names of cont and cat variables from given df.\"\n    cont_names, cat_names = [], []\n    for label in df:\n        if label == dep_var: continue\n        if 
df[label].dtype == int and df[label].unique().shape[0] > max_card or df[label].dtype == float: cont_names.append(label)\n        else: cat_names.append(label)\n    return cont_names, cat_names\n        \n@dataclass\nclass TabularProc():\n    \"A processor for tabular dataframes.\"\n    cat_names:StrList\n    cont_names:StrList\n\n    def __call__(self, df:DataFrame, test:bool=False):\n        \"Apply the correct function to `df` depending on `test`.\"\n        func = self.apply_test if test else self.apply_train\n        func(df)\n\n    def apply_train(self, df:DataFrame):\n        \"Function applied to `df` if it's the train set.\"\n        raise NotImplementedError\n    def apply_test(self, df:DataFrame):\n        \"Function applied to `df` if it's the test set.\"\n        self.apply_train(df)\n\nclass Categorify(TabularProc):\n    \"Transform the categorical variables to that type.\"\n    def apply_train(self, df:DataFrame):\n        \"Transform `self.cat_names` columns in categorical.\"\n        self.categories = {}\n        for n in self.cat_names:\n            df.loc[:,n] = df.loc[:,n].astype('category').cat.as_ordered()\n            self.categories[n] = df[n].cat.categories\n\n    def apply_test(self, df:DataFrame):\n        \"Transform `self.cat_names` columns in categorical using the codes decided in `apply_train`.\"\n        for n in self.cat_names:\n            df.loc[:,n] = pd.Categorical(df[n], categories=self.categories[n], ordered=True)\n\nFillStrategy = IntEnum('FillStrategy', 'MEDIAN COMMON CONSTANT')\n\n@dataclass\nclass FillMissing(TabularProc):\n    \"Fill the missing values in continuous columns.\"\n    fill_strategy:FillStrategy=FillStrategy.MEDIAN\n    add_col:bool=True\n    fill_val:float=0.\n    def apply_train(self, df:DataFrame):\n        \"Fill missing values in `self.cont_names` according to `self.fill_strategy`.\"\n        self.na_dict = {}\n        for name in self.cont_names:\n            if pd.isnull(df[name]).sum():\n               
 if self.add_col:\n                    df[name+'_na'] = pd.isnull(df[name])\n                    if name+'_na' not in self.cat_names: self.cat_names.append(name+'_na')\n                if self.fill_strategy == FillStrategy.MEDIAN: filler = df[name].median()\n                elif self.fill_strategy == FillStrategy.CONSTANT: filler = self.fill_val\n                else: filler = df[name].dropna().value_counts().idxmax()\n                df[name] = df[name].fillna(filler)\n                self.na_dict[name] = filler\n\n    def apply_test(self, df:DataFrame):\n        \"Fill missing values in `self.cont_names` like in `apply_train`.\"\n        for name in self.cont_names:\n            if name in self.na_dict:\n                if self.add_col:\n                    df[name+'_na'] = pd.isnull(df[name])\n                    if name+'_na' not in self.cat_names: self.cat_names.append(name+'_na')\n                df[name] = df[name].fillna(self.na_dict[name])\n            elif pd.isnull(df[name]).sum() != 0:\n                raise Exception(f\"\"\"There are nan values in field {name} but there were none in the training set. 
\n                Please fix those manually.\"\"\")\n\nclass Normalize(TabularProc):\n    \"Normalize the continuous variables.\"\n    def apply_train(self, df:DataFrame):\n        \"Compute the means and stds of `self.cont_names` columns to normalize them.\"\n        self.means,self.stds = {},{}\n        for n in self.cont_names:\n            assert is_numeric_dtype(df[n]), (f\"\"\"Cannot normalize '{n}' column as it isn't numerical.\n                Are you sure it doesn't belong in the categorical set of columns?\"\"\")\n            self.means[n],self.stds[n] = df[n].mean(),df[n].std()\n            df[n] = (df[n]-self.means[n]) / (1e-7 + self.stds[n])\n\n    def apply_test(self, df:DataFrame):\n        \"Normalize `self.cont_names` with the same statistics as in `apply_train`.\"\n        for n in self.cont_names:\n            df[n] = (df[n]-self.means[n]) / (1e-7 + self.stds[n])\n"
  },
  {
    "path": "fastai/test_registry.json",
    "content": "{\n    \"fastai.basic_data.DataBunch\": [\n        {\n            \"file\": \"tests/test_data_block.py\",\n            \"line\": 152,\n            \"test\": \"test_custom_dataset\"\n        }\n    ],\n    \"fastai.basic_data.DataBunch.create\": [\n        {\n            \"file\": \"tests/test_basic_data.py\",\n            \"line\": 30,\n            \"test\": \"test_DataBunch_Create\"\n        },\n        {\n            \"file\": \"tests/test_basic_data.py\",\n            \"line\": 44,\n            \"test\": \"test_DataBunch_no_valid_dl\"\n        }\n    ],\n    \"fastai.basic_data.DataBunch.one_batch\": [\n        {\n            \"file\": \"tests/test_basic_data.py\",\n            \"line\": 58,\n            \"test\": \"test_DataBunch_onebatch\"\n        },\n        {\n            \"file\": \"tests/test_text_data.py\",\n            \"line\": 83,\n            \"test\": \"test_should_load_backwards_lm_1\"\n        },\n        {\n            \"file\": \"tests/test_text_data.py\",\n            \"line\": 99,\n            \"test\": \"test_should_load_backwards_lm_2\"\n        },\n        {\n            \"file\": \"tests/test_text_data.py\",\n            \"line\": 110,\n            \"test\": \"test_backwards_cls_databunch\"\n        },\n        {\n            \"file\": \"tests/test_basic_data.py\",\n            \"line\": 83,\n            \"test\": \"test_DataBunch_save_load\"\n        }\n    ],\n    \"fastai.basic_data.DataBunch.one_item\": [\n        {\n            \"file\": \"tests/test_basic_data.py\",\n            \"line\": 67,\n            \"test\": \"test_DataBunch_oneitem\"\n        }\n    ],\n    \"fastai.basic_data.DataBunch.save\": [\n        {\n            \"file\": \"tests/test_basic_data.py\",\n            \"line\": 83,\n            \"test\": \"test_DataBunch_save_load\"\n        }\n    ],\n    \"fastai.basic_data.DataBunch.show_batch\": [\n        {\n            \"file\": \"tests/test_basic_data.py\",\n            \"line\": 75,\n            
\"test\": \"test_DataBunch_show_batch\"\n        }\n    ],\n    \"fastai.basic_data.intercept_args\": [\n        {\n            \"file\": \"tests/test_basic_data.py\",\n            \"line\": 18,\n            \"test\": \"test_intercept_args\"\n        }\n    ],\n    \"fastai.basic_data.load_data\": [\n        {\n            \"file\": \"tests/test_text_data.py\",\n            \"line\": 129,\n            \"test\": \"test_load_and_save_test\"\n        },\n        {\n            \"file\": \"tests/test_basic_data.py\",\n            \"line\": 83,\n            \"test\": \"test_DataBunch_save_load\"\n        }\n    ],\n    \"fastai.basic_train.Learner.destroy\": [\n        {\n            \"file\": \"tests/test_basic_train.py\",\n            \"line\": 170,\n            \"test\": \"test_destroy\"\n        },\n        {\n            \"file\": \"tests/test_basic_train.py\",\n            \"line\": 213,\n            \"test\": \"test_memory\"\n        }\n    ],\n    \"fastai.basic_train.Learner.export\": [\n        {\n            \"file\": \"tests/test_basic_train.py\",\n            \"line\": 230,\n            \"test\": \"test_export_load_learner\"\n        }\n    ],\n    \"fastai.basic_train.Learner.fit\": [\n        {\n            \"file\": \"tests/test_train.py\",\n            \"line\": 28,\n            \"test\": \"test_fit\"\n        }\n    ],\n    \"fastai.basic_train.Learner.freeze\": [\n        {\n            \"file\": \"tests/test_basic_train.py\",\n            \"line\": 49,\n            \"test\": \"test_freeze\"\n        }\n    ],\n    \"fastai.basic_train.Learner.freeze_to\": [\n        {\n            \"file\": \"tests/test_basic_train.py\",\n            \"line\": 39,\n            \"test\": \"test_freeze_to\"\n        }\n    ],\n    \"fastai.basic_train.Learner.get_preds\": [\n        {\n            \"file\": \"tests/test_basic_train.py\",\n            \"line\": 32,\n            \"test\": \"test_get_preds\"\n        }\n    ],\n    \"fastai.basic_train.Learner.load\": [\n 
       {\n            \"file\": \"tests/test_basic_train.py\",\n            \"line\": 104,\n            \"test\": \"test_save_load\"\n        },\n        {\n            \"file\": \"tests/test_basic_train.py\",\n            \"line\": 213,\n            \"test\": \"test_memory\"\n        }\n    ],\n    \"fastai.basic_train.Learner.predict\": [\n        {\n            \"file\": \"tests/test_vision_train.py\",\n            \"line\": 63,\n            \"test\": \"test_preds\"\n        },\n        {\n            \"file\": \"tests/test_vision_train.py\",\n            \"line\": 89,\n            \"test\": \"test_models_meta\"\n        }\n    ],\n    \"fastai.basic_train.Learner.purge\": [\n        {\n            \"file\": \"tests/test_basic_train.py\",\n            \"line\": 76,\n            \"test\": \"test_purge\"\n        },\n        {\n            \"file\": \"tests/test_basic_train.py\",\n            \"line\": 104,\n            \"test\": \"test_save_load\"\n        },\n        {\n            \"file\": \"tests/test_basic_train.py\",\n            \"line\": 213,\n            \"test\": \"test_memory\"\n        }\n    ],\n    \"fastai.basic_train.Learner.save\": [\n        {\n            \"file\": \"tests/test_basic_train.py\",\n            \"line\": 104,\n            \"test\": \"test_save_load\"\n        },\n        {\n            \"file\": \"tests/test_basic_train.py\",\n            \"line\": 213,\n            \"test\": \"test_memory\"\n        }\n    ],\n    \"fastai.basic_train.Learner.unfreeze\": [\n        {\n            \"file\": \"tests/test_basic_train.py\",\n            \"line\": 58,\n            \"test\": \"test_unfreeze\"\n        }\n    ],\n    \"fastai.basic_train.Learner.validate\": [\n        {\n            \"file\": \"tests/test_collab_train.py\",\n            \"line\": 16,\n            \"test\": \"test_val_loss\"\n        },\n        {\n            \"file\": \"tests/test_text_train.py\",\n            \"line\": 56,\n            \"test\": \"test_val_loss\"\n    
    }\n    ],\n    \"fastai.basic_train.Recorder\": [\n        {\n            \"file\": \"tests/test_vision_train.py\",\n            \"line\": 49,\n            \"test\": \"test_1cycle_lrs\"\n        },\n        {\n            \"file\": \"tests/test_vision_train.py\",\n            \"line\": 56,\n            \"test\": \"test_1cycle_moms\"\n        }\n    ],\n    \"fastai.basic_train.load_learner\": [\n        {\n            \"file\": \"tests/test_basic_train.py\",\n            \"line\": 230,\n            \"test\": \"test_export_load_learner\"\n        }\n    ],\n    \"fastai.basic_train.validate\": [\n        {\n            \"file\": \"tests/test_tabular_train.py\",\n            \"line\": 26,\n            \"test\": \"test_accuracy\"\n        }\n    ],\n    \"fastai.callback.AverageMetric\": [\n        {\n            \"file\": \"tests/test_metrics.py\",\n            \"line\": 213,\n            \"test\": \"test_average_metric_naming\"\n        }\n    ],\n    \"fastai.callback.Callback\": [\n        {\n            \"file\": \"tests/test_callback.py\",\n            \"line\": 33,\n            \"test\": \"test_callbacks_learner\"\n        },\n        {\n            \"file\": \"tests/test_callback.py\",\n            \"line\": 64,\n            \"test\": \"test_callbacks_fit\"\n        }\n    ],\n    \"fastai.callbacks.csv_logger.CSVLogger\": [\n        {\n            \"file\": \"tests/test_callbacks_csv_logger.py\",\n            \"line\": 37,\n            \"test\": \"test_logger\"\n        }\n    ],\n    \"fastai.callbacks.hooks.hook_output\": [\n        {\n            \"file\": \"tests/test_callbacks_hooks.py\",\n            \"line\": 74,\n            \"test\": \"test_hook_output_basics\"\n        }\n    ],\n    \"fastai.callbacks.hooks.model_summary\": [\n        {\n            \"file\": \"tests/test_callbacks_hooks.py\",\n            \"line\": 18,\n            \"test\": \"test_model_summary_vision\"\n        },\n        {\n            \"file\": 
\"tests/test_callbacks_hooks.py\",\n            \"line\": 26,\n            \"test\": \"test_model_summary_text\"\n        },\n        {\n            \"file\": \"tests/test_callbacks_hooks.py\",\n            \"line\": 33,\n            \"test\": \"test_model_summary_tabular\"\n        },\n        {\n            \"file\": \"tests/test_callbacks_hooks.py\",\n            \"line\": 48,\n            \"test\": \"test_model_summary_collab\"\n        },\n        {\n            \"file\": \"tests/test_basic_train.py\",\n            \"line\": 230,\n            \"test\": \"test_export_load_learner\"\n        }\n    ],\n    \"fastai.callbacks.mem.PeakMemMetric\": [\n        {\n            \"file\": \"tests/test_callbacks_mem.py\",\n            \"line\": 8,\n            \"test\": \"test_peak_mem_metric\"\n        }\n    ],\n    \"fastai.callbacks.misc.StopAfterNBatches\": [\n        {\n            \"file\": \"tests/test_callbacks_misc.py\",\n            \"line\": 22,\n            \"test\": \"test_stop_after_n_batches\"\n        }\n    ],\n    \"fastai.core.Category\": [\n        {\n            \"file\": \"tests/test_core.py\",\n            \"line\": 242,\n            \"test\": \"test_itembase_eq\"\n        }\n    ],\n    \"fastai.core.Category.__hash__\": [\n        {\n            \"file\": \"tests/test_core.py\",\n            \"line\": 304,\n            \"test\": \"test_itembase_hash\"\n        }\n    ],\n    \"fastai.core.FloatItem\": [\n        {\n            \"file\": \"tests/test_core.py\",\n            \"line\": 242,\n            \"test\": \"test_itembase_eq\"\n        }\n    ],\n    \"fastai.core.FloatItem.__hash__\": [\n        {\n            \"file\": \"tests/test_core.py\",\n            \"line\": 304,\n            \"test\": \"test_itembase_hash\"\n        }\n    ],\n    \"fastai.core.ItemBase.__eq__\": [\n        {\n            \"file\": \"tests/test_core.py\",\n            \"line\": 242,\n            \"test\": \"test_itembase_eq\"\n        },\n        {\n            
\"file\": \"tests/test_core.py\",\n            \"line\": 304,\n            \"test\": \"test_itembase_hash\"\n        }\n    ],\n    \"fastai.core.MultiCategory\": [\n        {\n            \"file\": \"tests/test_core.py\",\n            \"line\": 242,\n            \"test\": \"test_itembase_eq\"\n        }\n    ],\n    \"fastai.core.MultiCategory.__hash__\": [\n        {\n            \"file\": \"tests/test_core.py\",\n            \"line\": 304,\n            \"test\": \"test_itembase_hash\"\n        }\n    ],\n    \"fastai.core.arrays_split\": [\n        {\n            \"file\": \"tests/test_core.py\",\n            \"line\": 141,\n            \"test\": \"test_arrays_split\"\n        }\n    ],\n    \"fastai.core.camel2snake\": [\n        {\n            \"file\": \"tests/test_core.py\",\n            \"line\": 164,\n            \"test\": \"test_camel2snake\"\n        }\n    ],\n    \"fastai.core.chunks\": [\n        {\n            \"file\": \"tests/test_core.py\",\n            \"line\": 46,\n            \"test\": \"test_chunks\"\n        }\n    ],\n    \"fastai.core.df_names_to_idx\": [\n        {\n            \"file\": \"tests/test_core.py\",\n            \"line\": 213,\n            \"test\": \"test_df_names_to_idx\"\n        }\n    ],\n    \"fastai.core.download_url\": [\n        {\n            \"file\": \"tests/test_core.py\",\n            \"line\": 193,\n            \"test\": \"test_download_url\"\n        }\n    ],\n    \"fastai.core.even_mults\": [\n        {\n            \"file\": \"tests/test_core.py\",\n            \"line\": 178,\n            \"test\": \"test_even_mults\"\n        }\n    ],\n    \"fastai.core.find_classes\": [\n        {\n            \"file\": \"tests/test_core.py\",\n            \"line\": 131,\n            \"test\": \"test_find_classes\"\n        }\n    ],\n    \"fastai.core.idx_dict\": [\n        {\n            \"file\": \"tests/test_core.py\",\n            \"line\": 125,\n            \"test\": \"test_idx_dict\"\n        }\n    ],\n    
\"fastai.core.ifnone\": [\n        {\n            \"file\": \"tests/test_core.py\",\n            \"line\": 39,\n            \"test\": \"test_ifnone\"\n        }\n    ],\n    \"fastai.core.is1d\": [\n        {\n            \"file\": \"tests/test_core.py\",\n            \"line\": 235,\n            \"test\": \"test_is1d\"\n        }\n    ],\n    \"fastai.core.is_dict\": [\n        {\n            \"file\": \"tests/test_core.py\",\n            \"line\": 76,\n            \"test\": \"test_dict\"\n        }\n    ],\n    \"fastai.core.is_listy\": [\n        {\n            \"file\": \"tests/test_core.py\",\n            \"line\": 59,\n            \"test\": \"test_listy\"\n        }\n    ],\n    \"fastai.core.is_tuple\": [\n        {\n            \"file\": \"tests/test_core.py\",\n            \"line\": 70,\n            \"test\": \"test_tuple\"\n        }\n    ],\n    \"fastai.core.join_path\": [\n        {\n            \"file\": \"tests/test_core.py\",\n            \"line\": 206,\n            \"test\": \"test_join_paths\"\n        }\n    ],\n    \"fastai.core.listify\": [\n        {\n            \"file\": \"tests/test_core.py\",\n            \"line\": 25,\n            \"test\": \"test_listify\"\n        }\n    ],\n    \"fastai.core.noop\": [\n        {\n            \"file\": \"tests/test_core.py\",\n            \"line\": 82,\n            \"test\": \"test_noop\"\n        }\n    ],\n    \"fastai.core.num_cpus\": [\n        {\n            \"file\": \"tests/test_core.py\",\n            \"line\": 8,\n            \"test\": \"test_cpus\"\n        }\n    ],\n    \"fastai.core.one_hot\": [\n        {\n            \"file\": \"tests/test_core.py\",\n            \"line\": 218,\n            \"test\": \"test_one_hot\"\n        }\n    ],\n    \"fastai.core.partition\": [\n        {\n            \"file\": \"tests/test_core.py\",\n            \"line\": 94,\n            \"test\": \"test_partition_functionality\"\n        }\n    ],\n    \"fastai.core.random_split\": [\n        {\n            
\"file\": \"tests/test_core.py\",\n            \"line\": 154,\n            \"test\": \"test_random_split\"\n        }\n    ],\n    \"fastai.core.recurse\": [\n        {\n            \"file\": \"tests/test_core.py\",\n            \"line\": 29,\n            \"test\": \"test_recurse\"\n        }\n    ],\n    \"fastai.core.series2cat\": [\n        {\n            \"file\": \"tests/test_core.py\",\n            \"line\": 184,\n            \"test\": \"test_series2cat\"\n        }\n    ],\n    \"fastai.core.subplots\": [\n        {\n            \"file\": \"tests/test_core.py\",\n            \"line\": 222,\n            \"test\": \"test_subplots_multi_row_cols\"\n        },\n        {\n            \"file\": \"tests/test_core.py\",\n            \"line\": 229,\n            \"test\": \"test_subplots_single\"\n        }\n    ],\n    \"fastai.core.to_int\": [\n        {\n            \"file\": \"tests/test_core.py\",\n            \"line\": 86,\n            \"test\": \"test_to_int\"\n        }\n    ],\n    \"fastai.core.uniqueify\": [\n        {\n            \"file\": \"tests/test_core.py\",\n            \"line\": 53,\n            \"test\": \"test_uniqueify\"\n        }\n    ],\n    \"fastai.data_block.CategoryProcessor.process_one\": [\n        {\n            \"file\": \"tests/test_data_block.py\",\n            \"line\": 80,\n            \"test\": \"test_category_processor_existing_class\"\n        },\n        {\n            \"file\": \"tests/test_data_block.py\",\n            \"line\": 91,\n            \"test\": \"test_category_processor_non_existing_class\"\n        }\n    ],\n    \"fastai.data_block.ItemList.filter_by_folder\": [\n        {\n            \"file\": \"tests/test_data_block.py\",\n            \"line\": 161,\n            \"test\": \"test_filter_by_folder\"\n        }\n    ],\n    \"fastai.data_block.ItemList.filter_by_rand\": [\n        {\n            \"file\": \"tests/test_data_block.py\",\n            \"line\": 112,\n            \"test\": \"test_filter_by_rand\"\n  
      }\n    ],\n    \"fastai.data_block.ItemList.label_from_folder\": [\n        {\n            \"file\": \"tests/test_text_data.py\",\n            \"line\": 30,\n            \"test\": \"test_from_folder\"\n        },\n        {\n            \"file\": \"tests/test_text_data.py\",\n            \"line\": 42,\n            \"test\": \"test_filter_classes\"\n        }\n    ],\n    \"fastai.data_block.ItemList.split_by_rand_pct\": [\n        {\n            \"file\": \"tests/test_data_block.py\",\n            \"line\": 103,\n            \"test\": \"test_splitdata_datasets\"\n        }\n    ],\n    \"fastai.data_block.ItemList.split_subsets\": [\n        {\n            \"file\": \"tests/test_data_block.py\",\n            \"line\": 121,\n            \"test\": \"test_split_subsets\"\n        }\n    ],\n    \"fastai.data_block.LabelLists.databunch\": [\n        {\n            \"file\": \"tests/test_vision_data.py\",\n            \"line\": 217,\n            \"test\": \"test_vision_datasets\"\n        }\n    ],\n    \"fastai.datasets.Config\": [\n        {\n            \"file\": \"tests/test_datasets.py\",\n            \"line\": 15,\n            \"test\": \"test_creates_config\"\n        },\n        {\n            \"file\": \"tests/test_datasets.py\",\n            \"line\": 26,\n            \"test\": \"test_load_config\"\n        },\n        {\n            \"file\": \"tests/test_datasets.py\",\n            \"line\": 29,\n            \"test\": \"test_default_config\"\n        },\n        {\n            \"file\": \"tests/test_datasets.py\",\n            \"line\": 42,\n            \"test\": \"test_user_config\"\n        }\n    ],\n    \"fastai.datasets.datapath4file\": [\n        {\n            \"file\": \"tests/test_datasets.py\",\n            \"line\": 26,\n            \"test\": \"test_load_config\"\n        },\n        {\n            \"file\": \"tests/test_datasets.py\",\n            \"line\": 42,\n            \"test\": \"test_user_config\"\n        }\n    ],\n    
\"fastai.datasets.download_data\": [\n        {\n            \"file\": \"tests/test_datasets.py\",\n            \"line\": 26,\n            \"test\": \"test_load_config\"\n        },\n        {\n            \"file\": \"tests/test_datasets.py\",\n            \"line\": 42,\n            \"test\": \"test_user_config\"\n        }\n    ],\n    \"fastai.datasets.untar_data\": [\n        {\n            \"file\": \"tests/test_vision_data.py\",\n            \"line\": 165,\n            \"test\": \"test_trunc_download\"\n        },\n        {\n            \"file\": \"tests/test_datasets.py\",\n            \"line\": 26,\n            \"test\": \"test_load_config\"\n        },\n        {\n            \"file\": \"tests/test_datasets.py\",\n            \"line\": 42,\n            \"test\": \"test_user_config\"\n        }\n    ],\n    \"fastai.datasets.url2path\": [\n        {\n            \"file\": \"tests/test_datasets.py\",\n            \"line\": 26,\n            \"test\": \"test_load_config\"\n        },\n        {\n            \"file\": \"tests/test_datasets.py\",\n            \"line\": 42,\n            \"test\": \"test_user_config\"\n        }\n    ],\n    \"fastai.gen_doc.doctest.merge_registries\": [\n        {\n            \"file\": \"tests/test_gen_doc_nbtest.py\",\n            \"line\": 199,\n            \"test\": \"test_merge_registries\"\n        }\n    ],\n    \"fastai.gen_doc.doctest.this_tests\": [\n        {\n            \"file\": \"tests/test_gen_doc_nbtest.py\",\n            \"line\": 75,\n            \"test\": \"test_this_tests\"\n        }\n    ],\n    \"fastai.gen_doc.nbtest._fuzzy_line_match\": [\n        {\n            \"file\": \"tests/test_gen_doc_nbtest.py\",\n            \"line\": 61,\n            \"test\": \"test_fuzzy_line_match\"\n        }\n    ],\n    \"fastai.gen_doc.nbtest._is_file_match\": [\n        {\n            \"file\": \"tests/test_gen_doc_nbtest.py\",\n            \"line\": 16,\n            \"test\": \"test_is_file_match\"\n        }\n    
],\n    \"fastai.gen_doc.nbtest._submodule_name\": [\n        {\n            \"file\": \"tests/test_gen_doc_nbtest.py\",\n            \"line\": 7,\n            \"test\": \"test_submodule_name\"\n        }\n    ],\n    \"fastai.gen_doc.nbtest.direct_test_match\": [\n        {\n            \"file\": \"tests/test_gen_doc_nbtest.py\",\n            \"line\": 38,\n            \"test\": \"test_direct_test_match\"\n        },\n        {\n            \"file\": \"tests/test_gen_doc_nbtest.py\",\n            \"line\": 46,\n            \"test\": \"test_direct_test_match_class_methods\"\n        }\n    ],\n    \"fastai.gen_doc.nbtest.fuzzy_test_match\": [\n        {\n            \"file\": \"tests/test_gen_doc_nbtest.py\",\n            \"line\": 38,\n            \"test\": \"test_fuzzy_test_match\"\n        }\n    ],\n    \"fastai.gen_doc.nbtest.get_file\": [\n        {\n            \"file\": \"tests/test_gen_doc_nbtest.py\",\n            \"line\": 26,\n            \"test\": \"test_wrapped_functions\"\n        }\n    ],\n    \"fastai.gen_doc.nbtest.get_tests_dir\": [\n        {\n            \"file\": \"tests/test_gen_doc_nbtest.py\",\n            \"line\": 70,\n            \"test\": \"test_get_tests_dir\"\n        }\n    ],\n    \"fastai.layers.SelfAttention\": [\n        {\n            \"file\": \"tests/test_torch_core.py\",\n            \"line\": 269,\n            \"test\": \"test_keep_parameter\"\n        }\n    ],\n    \"fastai.metrics.accuracy\": [\n        {\n            \"file\": \"tests/test_metrics.py\",\n            \"line\": 44,\n            \"test\": \"test_accuracy\"\n        },\n        {\n            \"file\": \"tests/test_vision_train.py\",\n            \"line\": 41,\n            \"test\": \"test_accuracy\"\n        }\n    ],\n    \"fastai.metrics.accuracy_thresh\": [\n        {\n            \"file\": \"tests/test_metrics.py\",\n            \"line\": 99,\n            \"test\": \"test_accuracy_thresh\"\n        }\n    ],\n    \"fastai.metrics.dice\": [\n        {\n 
           \"file\": \"tests/test_metrics.py\",\n            \"line\": 108,\n            \"test\": \"test_dice\"\n        },\n        {\n            \"file\": \"tests/test_metrics.py\",\n            \"line\": 118,\n            \"test\": \"test_dice_iou\"\n        }\n    ],\n    \"fastai.metrics.error_rate\": [\n        {\n            \"file\": \"tests/test_metrics.py\",\n            \"line\": 86,\n            \"test\": \"test_error_rate\"\n        },\n        {\n            \"file\": \"tests/test_vision_train.py\",\n            \"line\": 45,\n            \"test\": \"test_error_rate\"\n        }\n    ],\n    \"fastai.metrics.exp_rmspe\": [\n        {\n            \"file\": \"tests/test_metrics.py\",\n            \"line\": 90,\n            \"test\": \"test_exp_rmspe\"\n        },\n        {\n            \"file\": \"tests/test_metrics.py\",\n            \"line\": 94,\n            \"test\": \"test_exp_rmspe_num_of_ele\"\n        }\n    ],\n    \"fastai.metrics.explained_variance\": [\n        {\n            \"file\": \"tests/test_metrics.py\",\n            \"line\": 174,\n            \"test\": \"test_explained_variance\"\n        }\n    ],\n    \"fastai.metrics.fbeta\": [\n        {\n            \"file\": \"tests/test_metrics.py\",\n            \"line\": 126,\n            \"test\": \"test_fbeta\"\n        }\n    ],\n    \"fastai.metrics.foreground_acc\": [\n        {\n            \"file\": \"tests/test_metrics.py\",\n            \"line\": 78,\n            \"test\": \"test_foreground_acc\"\n        }\n    ],\n    \"fastai.metrics.mean_absolute_error\": [\n        {\n            \"file\": \"tests/test_metrics.py\",\n            \"line\": 135,\n            \"test\": \"test_mae\"\n        }\n    ],\n    \"fastai.metrics.mean_squared_error\": [\n        {\n            \"file\": \"tests/test_metrics.py\",\n            \"line\": 144,\n            \"test\": \"test_mse\"\n        }\n    ],\n    \"fastai.metrics.mean_squared_logarithmic_error\": [\n        {\n            
\"file\": \"tests/test_metrics.py\",\n            \"line\": 163,\n            \"test\": \"test_msle\"\n        }\n    ],\n    \"fastai.metrics.r2_score\": [\n        {\n            \"file\": \"tests/test_metrics.py\",\n            \"line\": 185,\n            \"test\": \"test_r2_score\"\n        }\n    ],\n    \"fastai.metrics.root_mean_squared_error\": [\n        {\n            \"file\": \"tests/test_metrics.py\",\n            \"line\": 153,\n            \"test\": \"test_rmse\"\n        }\n    ],\n    \"fastai.metrics.top_k_accuracy\": [\n        {\n            \"file\": \"tests/test_metrics.py\",\n            \"line\": 69,\n            \"test\": \"test_top_k_accuracy\"\n        }\n    ],\n    \"fastai.tabular.data.TabularList.from_df\": [\n        {\n            \"file\": \"tests/test_tabular_data.py\",\n            \"line\": 5,\n            \"test\": \"test_from_df\"\n        }\n    ],\n    \"fastai.tabular.models._cl_int_from_learner\": [\n        {\n            \"file\": \"tests/test_vision_train.py\",\n            \"line\": 72,\n            \"test\": \"test_interp\"\n        }\n    ],\n    \"fastai.tabular.models._learner_interpret\": [\n        {\n            \"file\": \"tests/test_vision_train.py\",\n            \"line\": 78,\n            \"test\": \"test_interp_shortcut\"\n        }\n    ],\n    \"fastai.tabular.transform.Categorify\": [\n        {\n            \"file\": \"tests/test_tabular_transform.py\",\n            \"line\": 6,\n            \"test\": \"test_categorify\"\n        }\n    ],\n    \"fastai.tabular.transform.FillMissing\": [\n        {\n            \"file\": \"tests/test_tabular_transform.py\",\n            \"line\": 30,\n            \"test\": \"test_default_fill_strategy_is_median\"\n        }\n    ],\n    \"fastai.tabular.transform.FillMissing.apply_test\": [\n        {\n            \"file\": \"tests/test_tabular_transform.py\",\n            \"line\": 36,\n            \"test\": \"test_fill_missing_leaves_no_na_values\"\n        },\n       
 {\n            \"file\": \"tests/test_tabular_transform.py\",\n            \"line\": 49,\n            \"test\": \"test_fill_missing_returns_correct_medians\"\n        }\n    ],\n    \"fastai.tabular.transform.FillMissing.apply_train\": [\n        {\n            \"file\": \"tests/test_tabular_transform.py\",\n            \"line\": 36,\n            \"test\": \"test_fill_missing_leaves_no_na_values\"\n        },\n        {\n            \"file\": \"tests/test_tabular_transform.py\",\n            \"line\": 49,\n            \"test\": \"test_fill_missing_returns_correct_medians\"\n        }\n    ],\n    \"fastai.tabular.transform.cont_cat_split\": [\n        {\n            \"file\": \"tests/test_tabular_transform.py\",\n            \"line\": 64,\n            \"test\": \"test_cont_cat_split\"\n        }\n    ],\n    \"fastai.text.data.SortSampler\": [\n        {\n            \"file\": \"tests/test_text_data.py\",\n            \"line\": 158,\n            \"test\": \"test_sampler\"\n        },\n        {\n            \"file\": \"tests/test_text_data.py\",\n            \"line\": 158,\n            \"test\": \"test_sort_sampler\"\n        }\n    ],\n    \"fastai.text.data.SortishSampler\": [\n        {\n            \"file\": \"tests/test_text_data.py\",\n            \"line\": 143,\n            \"test\": \"test_sortish_sampler\"\n        }\n    ],\n    \"fastai.text.data.TextDataBunch.from_csv\": [\n        {\n            \"file\": \"tests/test_text_data.py\",\n            \"line\": 57,\n            \"test\": \"test_from_csv_and_from_df\"\n        }\n    ],\n    \"fastai.text.data.TextDataBunch.from_df\": [\n        {\n            \"file\": \"tests/test_text_data.py\",\n            \"line\": 57,\n            \"test\": \"test_from_csv_and_from_df\"\n        }\n    ],\n    \"fastai.text.data.TextDataBunch.from_ids\": [\n        {\n            \"file\": \"tests/test_text_data.py\",\n            \"line\": 173,\n            \"test\": 
\"test_from_ids_works_for_equally_length_sentences\"\n        },\n        {\n            \"file\": \"tests/test_text_data.py\",\n            \"line\": 181,\n            \"test\": \"test_from_ids_works_for_variable_length_sentences\"\n        },\n        {\n            \"file\": \"tests/test_text_data.py\",\n            \"line\": 189,\n            \"test\": \"test_from_ids_exports_classes\"\n        }\n    ],\n    \"fastai.text.learner.language_model_learner\": [\n        {\n            \"file\": \"tests/test_text_train.py\",\n            \"line\": 61,\n            \"test\": \"test_qrnn_works_with_no_split\"\n        },\n        {\n            \"file\": \"tests/test_text_train.py\",\n            \"line\": 73,\n            \"test\": \"test_qrnn_works_if_split_fn_provided\"\n        }\n    ],\n    \"fastai.text.learner.text_classifier_learner\": [\n        {\n            \"file\": \"tests/test_text_train.py\",\n            \"line\": 100,\n            \"test\": \"test_classifier\"\n        },\n        {\n            \"file\": \"tests/test_text_train.py\",\n            \"line\": 139,\n            \"test\": \"test_order_preds\"\n        }\n    ],\n    \"fastai.text.models.qrnn.BwdForgetMultGPU\": [\n        {\n            \"file\": \"tests/test_text_qrnn.py\",\n            \"line\": 28,\n            \"test\": \"test_forget_mult_cuda\"\n        }\n    ],\n    \"fastai.text.models.qrnn.ForgetMultGPU\": [\n        {\n            \"file\": \"tests/test_text_qrnn.py\",\n            \"line\": 7,\n            \"test\": \"test_forget_mult_forward_gpu\"\n        },\n        {\n            \"file\": \"tests/test_text_qrnn.py\",\n            \"line\": 27,\n            \"test\": \"test_compare_forget_mult_forward_implementations\"\n        },\n        {\n            \"file\": \"tests/test_text_qrnn.py\",\n            \"line\": 28,\n            \"test\": \"test_forget_mult_cuda\"\n        }\n    ],\n    \"fastai.text.models.qrnn.QRNN\": [\n        {\n            \"file\": 
\"tests/test_text_qrnn.py\",\n            \"line\": 105,\n            \"test\": \"test_qrnn_bidir\"\n        }\n    ],\n    \"fastai.text.models.qrnn.QRNNLayer\": [\n        {\n            \"file\": \"tests/test_text_qrnn.py\",\n            \"line\": 89,\n            \"test\": \"test_qrnn_layer\"\n        }\n    ],\n    \"fastai.text.models.qrnn.forget_mult_CPU\": [\n        {\n            \"file\": \"tests/test_text_qrnn.py\",\n            \"line\": 75,\n            \"test\": \"test_forget_mult\"\n        }\n    ],\n    \"fastai.text.transform.Tokenizer\": [\n        {\n            \"file\": \"tests/test_text_transform.py\",\n            \"line\": 15,\n            \"test\": \"test_tokenize\"\n        },\n        {\n            \"file\": \"tests/test_text_transform.py\",\n            \"line\": 24,\n            \"test\": \"test_tokenize_handles_empty_lines\"\n        },\n        {\n            \"file\": \"tests/test_text_transform.py\",\n            \"line\": 32,\n            \"test\": \"test_tokenize_ignores_extraneous_space\"\n        }\n    ],\n    \"fastai.text.transform.Vocab.numericalize\": [\n        {\n            \"file\": \"tests/test_text_transform.py\",\n            \"line\": 39,\n            \"test\": \"test_numericalize_and_textify\"\n        }\n    ],\n    \"fastai.text.transform.Vocab.textify\": [\n        {\n            \"file\": \"tests/test_text_transform.py\",\n            \"line\": 39,\n            \"test\": \"test_numericalize_and_textify\"\n        }\n    ],\n    \"fastai.text.transform.deal_caps\": [\n        {\n            \"file\": \"tests/test_text_transform.py\",\n            \"line\": 5,\n            \"test\": \"test_rules\"\n        }\n    ],\n    \"fastai.text.transform.fix_html\": [\n        {\n            \"file\": \"tests/test_text_transform.py\",\n            \"line\": 5,\n            \"test\": \"test_rules\"\n        }\n    ],\n    \"fastai.text.transform.replace_all_caps\": [\n        {\n            \"file\": 
\"tests/test_text_transform.py\",\n            \"line\": 5,\n            \"test\": \"test_rules\"\n        }\n    ],\n    \"fastai.text.transform.replace_rep\": [\n        {\n            \"file\": \"tests/test_text_transform.py\",\n            \"line\": 5,\n            \"test\": \"test_rules\"\n        }\n    ],\n    \"fastai.text.transform.replace_wrep\": [\n        {\n            \"file\": \"tests/test_text_transform.py\",\n            \"line\": 5,\n            \"test\": \"test_rules\"\n        }\n    ],\n    \"fastai.text.transform.rm_useless_spaces\": [\n        {\n            \"file\": \"tests/test_text_transform.py\",\n            \"line\": 5,\n            \"test\": \"test_rules\"\n        }\n    ],\n    \"fastai.text.transform.spec_add_spaces\": [\n        {\n            \"file\": \"tests/test_text_transform.py\",\n            \"line\": 5,\n            \"test\": \"test_rules\"\n        }\n    ],\n    \"fastai.torch_core.NoneReduceOnCPU\": [\n        {\n            \"file\": \"tests/test_torch_core.py\",\n            \"line\": 249,\n            \"test\": \"test_none_reduce_on_cpu\"\n        }\n    ],\n    \"fastai.torch_core.apply_init\": [\n        {\n            \"file\": \"tests/test_torch_core.py\",\n            \"line\": 47,\n            \"test\": \"test_apply_init\"\n        }\n    ],\n    \"fastai.torch_core.apply_leaf\": [\n        {\n            \"file\": \"tests/test_torch_core.py\",\n            \"line\": 47,\n            \"test\": \"test_apply_init\"\n        }\n    ],\n    \"fastai.torch_core.batch_to_half\": [\n        {\n            \"file\": \"tests/test_fp16.py\",\n            \"line\": 32,\n            \"test\": \"test_batch_to_half\"\n        }\n    ],\n    \"fastai.torch_core.children\": [\n        {\n            \"file\": \"tests/test_torch_core.py\",\n            \"line\": 197,\n            \"test\": \"test_children\"\n        }\n    ],\n    \"fastai.torch_core.first_layer\": [\n        {\n            \"file\": 
\"tests/test_torch_core.py\",\n            \"line\": 213,\n            \"test\": \"test_first_layer\"\n        }\n    ],\n    \"fastai.torch_core.in_channels\": [\n        {\n            \"file\": \"tests/test_torch_core.py\",\n            \"line\": 59,\n            \"test\": \"test_in_channels\"\n        },\n        {\n            \"file\": \"tests/test_torch_core.py\",\n            \"line\": 64,\n            \"test\": \"test_in_channels_no_weights\"\n        }\n    ],\n    \"fastai.torch_core.last_layer\": [\n        {\n            \"file\": \"tests/test_torch_core.py\",\n            \"line\": 220,\n            \"test\": \"test_last_layer\"\n        }\n    ],\n    \"fastai.torch_core.model2half\": [\n        {\n            \"file\": \"tests/test_fp16.py\",\n            \"line\": 6,\n            \"test\": \"test_model2half\"\n        },\n        {\n            \"file\": \"tests/test_fp16.py\",\n            \"line\": 16,\n            \"test\": \"test_model2half_forward\"\n        }\n    ],\n    \"fastai.torch_core.model_type\": [\n        {\n            \"file\": \"tests/test_torch_core.py\",\n            \"line\": 227,\n            \"test\": \"test_model_type\"\n        }\n    ],\n    \"fastai.torch_core.np2model_tensor\": [\n        {\n            \"file\": \"tests/test_torch_core.py\",\n            \"line\": 94,\n            \"test\": \"test_np2model_tensor\"\n        }\n    ],\n    \"fastai.torch_core.np_address\": [\n        {\n            \"file\": \"tests/test_torch_core.py\",\n            \"line\": 100,\n            \"test\": \"test_np_address\"\n        }\n    ],\n    \"fastai.torch_core.num_children\": [\n        {\n            \"file\": \"tests/test_torch_core.py\",\n            \"line\": 206,\n            \"test\": \"test_num_children\"\n        }\n    ],\n    \"fastai.torch_core.range_children\": [\n        {\n            \"file\": \"tests/test_torch_core.py\",\n            \"line\": 70,\n            \"test\": \"test_range_children\"\n        }\n    
],\n    \"fastai.torch_core.requires_grad\": [\n        {\n            \"file\": \"tests/test_torch_core.py\",\n            \"line\": 32,\n            \"test\": \"test_requires_grad\"\n        },\n        {\n            \"file\": \"tests/test_torch_core.py\",\n            \"line\": 37,\n            \"test\": \"test_requires_grad_set\"\n        }\n    ],\n    \"fastai.torch_core.set_bn_eval\": [\n        {\n            \"file\": \"tests/test_torch_core.py\",\n            \"line\": 87,\n            \"test\": \"test_set_bn_eval\"\n        }\n    ],\n    \"fastai.torch_core.split_model\": [\n        {\n            \"file\": \"tests/test_torch_core.py\",\n            \"line\": 75,\n            \"test\": \"test_split_model\"\n        }\n    ],\n    \"fastai.torch_core.split_no_wd_params\": [\n        {\n            \"file\": \"tests/test_torch_core.py\",\n            \"line\": 81,\n            \"test\": \"test_split_no_wd_params\"\n        }\n    ],\n    \"fastai.torch_core.tensor\": [\n        {\n            \"file\": \"tests/test_torch_core.py\",\n            \"line\": 13,\n            \"test\": \"test_tensor_with_list\"\n        },\n        {\n            \"file\": \"tests/test_torch_core.py\",\n            \"line\": 18,\n            \"test\": \"test_tensor_with_ndarray\"\n        },\n        {\n            \"file\": \"tests/test_torch_core.py\",\n            \"line\": 25,\n            \"test\": \"test_tensor_with_tensor\"\n        }\n    ],\n    \"fastai.torch_core.to_cpu\": [\n        {\n            \"file\": \"tests/test_torch_core.py\",\n            \"line\": 154,\n            \"test\": \"test_to_cpu\"\n        }\n    ],\n    \"fastai.torch_core.to_data\": [\n        {\n            \"file\": \"tests/test_torch_core.py\",\n            \"line\": 106,\n            \"test\": \"test_to_data\"\n        }\n    ],\n    \"fastai.torch_core.to_detach\": [\n        {\n            \"file\": \"tests/test_torch_core.py\",\n            \"line\": 131,\n            \"test\": 
\"test_to_detach\"\n        }\n    ],\n    \"fastai.torch_core.to_float\": [\n        {\n            \"file\": \"tests/test_torch_core.py\",\n            \"line\": 184,\n            \"test\": \"test_to_float\"\n        }\n    ],\n    \"fastai.torch_core.to_half\": [\n        {\n            \"file\": \"tests/test_fp16.py\",\n            \"line\": 25,\n            \"test\": \"test_to_half\"\n        },\n        {\n            \"file\": \"tests/test_torch_core.py\",\n            \"line\": 171,\n            \"test\": \"test_to_half\"\n        }\n    ],\n    \"fastai.torch_core.to_np\": [\n        {\n            \"file\": \"tests/test_torch_core.py\",\n            \"line\": 244,\n            \"test\": \"test_to_np\"\n        }\n    ],\n    \"fastai.torch_core.trange_of\": [\n        {\n            \"file\": \"tests/test_torch_core.py\",\n            \"line\": 236,\n            \"test\": \"test_trange_of\"\n        }\n    ],\n    \"fastai.train.ClassificationInterpretation\": [\n        {\n            \"file\": \"tests/test_vision_train.py\",\n            \"line\": 95,\n            \"test\": \"test_ClassificationInterpretation\"\n        }\n    ],\n    \"fastai.train.ClassificationInterpretation.confusion_matrix\": [\n        {\n            \"file\": \"tests/test_tabular_train.py\",\n            \"line\": 84,\n            \"test\": \"test_confusion_tabular\"\n        }\n    ],\n    \"fastai.train.fit_one_cycle\": [\n        {\n            \"file\": \"tests/test_train.py\",\n            \"line\": 36,\n            \"test\": \"test_fit_one_cycle\"\n        }\n    ],\n    \"fastai.train.lr_find\": [\n        {\n            \"file\": \"tests/test_train.py\",\n            \"line\": 16,\n            \"test\": \"test_lr_find\"\n        },\n        {\n            \"file\": \"tests/test_vision_train.py\",\n            \"line\": 84,\n            \"test\": \"test_lrfind\"\n        }\n    ],\n    \"fastai.utils.collect_env.check_perf\": [\n        {\n            \"file\": 
\"tests/test_utils.py\",\n            \"line\": 18,\n            \"test\": \"test_check_perf\"\n        }\n    ],\n    \"fastai.utils.collect_env.show_install\": [\n        {\n            \"file\": \"tests/test_utils.py\",\n            \"line\": 8,\n            \"test\": \"test_show_install\"\n        }\n    ],\n    \"fastai.utils.mem.GPUMemTrace\": [\n        {\n            \"file\": \"tests/test_utils_mem.py\",\n            \"line\": 76,\n            \"test\": \"test_gpu_mem_trace\"\n        },\n        {\n            \"file\": \"tests/test_utils_mem.py\",\n            \"line\": 137,\n            \"test\": \"test_gpu_mem_trace_ctx\"\n        }\n    ],\n    \"fastai.utils.mem.gpu_mem_get\": [\n        {\n            \"file\": \"tests/test_utils_mem.py\",\n            \"line\": 25,\n            \"test\": \"test_gpu_mem_by_id\"\n        }\n    ],\n    \"fastai.utils.mem.gpu_mem_get_all\": [\n        {\n            \"file\": \"tests/test_utils_mem.py\",\n            \"line\": 35,\n            \"test\": \"test_gpu_mem_all\"\n        }\n    ],\n    \"fastai.utils.mem.gpu_mem_get_used\": [\n        {\n            \"file\": \"tests/test_utils_mem.py\",\n            \"line\": 56,\n            \"test\": \"test_gpu_mem_measure_consumed_reclaimed\"\n        }\n    ],\n    \"fastai.utils.mem.gpu_mem_trace\": [\n        {\n            \"file\": \"tests/test_utils_mem.py\",\n            \"line\": 178,\n            \"test\": \"test_gpu_mem_trace_decorator\"\n        }\n    ],\n    \"fastai.utils.mem.gpu_with_max_free_mem\": [\n        {\n            \"file\": \"tests/test_utils_mem.py\",\n            \"line\": 44,\n            \"test\": \"test_gpu_with_max_free_mem\"\n        }\n    ],\n    \"fastai.utils.mod_display.progress_disabled_ctx\": [\n        {\n            \"file\": \"tests/test_mod_display.py\",\n            \"line\": 16,\n            \"test\": \"test_progress_disabled_ctx\"\n        }\n    ],\n    \"fastai.vision.data.ImageDataBunch.from_csv\": [\n        {\n        
    \"file\": \"tests/test_vision_data.py\",\n            \"line\": 22,\n            \"test\": \"test_path_can_be_str_type\"\n        },\n        {\n            \"file\": \"tests/test_vision_data.py\",\n            \"line\": 54,\n            \"test\": \"test_from_csv_and_from_df\"\n        }\n    ],\n    \"fastai.vision.data.ImageDataBunch.from_df\": [\n        {\n            \"file\": \"tests/test_vision_data.py\",\n            \"line\": 54,\n            \"test\": \"test_from_csv_and_from_df\"\n        }\n    ],\n    \"fastai.vision.data.ImageDataBunch.from_folder\": [\n        {\n            \"file\": \"tests/test_vision_data.py\",\n            \"line\": 26,\n            \"test\": \"test_from_folder\"\n        }\n    ],\n    \"fastai.vision.data.ImageDataBunch.from_lists\": [\n        {\n            \"file\": \"tests/test_vision_data.py\",\n            \"line\": 39,\n            \"test\": \"test_from_lists\"\n        }\n    ],\n    \"fastai.vision.data.ImageDataBunch.from_name_re\": [\n        {\n            \"file\": \"tests/test_vision_data.py\",\n            \"line\": 32,\n            \"test\": \"test_from_name_re\"\n        },\n        {\n            \"file\": \"tests/test_vision_data.py\",\n            \"line\": 70,\n            \"test\": \"test_image_resize\"\n        }\n    ],\n    \"fastai.vision.data.ImageDataBunch.normalize\": [\n        {\n            \"file\": \"tests/test_vision_data.py\",\n            \"line\": 120,\n            \"test\": \"test_normalize\"\n        }\n    ],\n    \"fastai.vision.data.ImageList.from_csv\": [\n        {\n            \"file\": \"tests/test_vision_data.py\",\n            \"line\": 227,\n            \"test\": \"test_multi\"\n        }\n    ],\n    \"fastai.vision.data.ImageList.from_folder\": [\n        {\n            \"file\": \"tests/test_vision_data.py\",\n            \"line\": 217,\n            \"test\": \"test_vision_datasets\"\n        },\n        {\n            \"file\": \"tests/test_vision_gan.py\",\n            
\"line\": 30,\n            \"test\": \"test_gan_datasets\"\n        }\n    ],\n    \"fastai.vision.data.ObjectItemList\": [\n        {\n            \"file\": \"tests/test_vision_data.py\",\n            \"line\": 267,\n            \"test\": \"test_coco\"\n        },\n        {\n            \"file\": \"tests/test_vision_data.py\",\n            \"line\": 280,\n            \"test\": \"test_coco_same_size\"\n        },\n        {\n            \"file\": \"tests/test_vision_data.py\",\n            \"line\": 297,\n            \"test\": \"test_coco_pickle\"\n        }\n    ],\n    \"fastai.vision.data.PointsItemList\": [\n        {\n            \"file\": \"tests/test_vision_data.py\",\n            \"line\": 254,\n            \"test\": \"test_points\"\n        }\n    ],\n    \"fastai.vision.data.SegmentationItemList\": [\n        {\n            \"file\": \"tests/test_vision_data.py\",\n            \"line\": 238,\n            \"test\": \"test_camvid\"\n        }\n    ],\n    \"fastai.vision.data.denormalize\": [\n        {\n            \"file\": \"tests/test_vision_data.py\",\n            \"line\": 134,\n            \"test\": \"test_denormalize\"\n        }\n    ],\n    \"fastai.vision.data.download_images\": [\n        {\n            \"file\": \"tests/test_vision_data.py\",\n            \"line\": 144,\n            \"test\": \"test_download_images\"\n        }\n    ],\n    \"fastai.vision.data.verify_image\": [\n        {\n            \"file\": \"tests/test_vision_data.py\",\n            \"line\": 201,\n            \"test\": \"test_verify_image\"\n        }\n    ],\n    \"fastai.vision.data.verify_images\": [\n        {\n            \"file\": \"tests/test_vision_data.py\",\n            \"line\": 190,\n            \"test\": \"test_verify_images\"\n        }\n    ],\n    \"fastai.vision.gan.GANModule\": [\n        {\n            \"file\": \"tests/test_vision_gan.py\",\n            \"line\": 67,\n            \"test\": \"test_gan_module\"\n        }\n    ],\n    
\"fastai.vision.gan.GANTrainer\": [\n        {\n            \"file\": \"tests/test_vision_gan.py\",\n            \"line\": 80,\n            \"test\": \"test_gan_trainer\"\n        }\n    ],\n    \"fastai.vision.gan.NoisyItem\": [\n        {\n            \"file\": \"tests/test_vision_gan.py\",\n            \"line\": 37,\n            \"test\": \"test_noisy_item\"\n        }\n    ],\n    \"fastai.vision.gan.basic_critic\": [\n        {\n            \"file\": \"tests/test_vision_gan.py\",\n            \"line\": 56,\n            \"test\": \"test_basic_critic\"\n        }\n    ],\n    \"fastai.vision.gan.basic_generator\": [\n        {\n            \"file\": \"tests/test_vision_gan.py\",\n            \"line\": 46,\n            \"test\": \"test_basic_generator\"\n        }\n    ],\n    \"fastai.vision.image.Image\": [\n        {\n            \"file\": \"tests/test_vision_transform.py\",\n            \"line\": 58,\n            \"test\": \"test_mask_data_aug\"\n        }\n    ],\n    \"fastai.vision.image.Image.resize\": [\n        {\n            \"file\": \"tests/test_vision_image.py\",\n            \"line\": 49,\n            \"test\": \"test_image_resize_same_size_shortcut\"\n        }\n    ],\n    \"fastai.vision.image.ImageBBox\": [\n        {\n            \"file\": \"tests/test_vision_transform.py\",\n            \"line\": 37,\n            \"test\": \"test_bbox_data_aug\"\n        }\n    ],\n    \"fastai.vision.image.ImagePoints\": [\n        {\n            \"file\": \"tests/test_vision_transform.py\",\n            \"line\": 22,\n            \"test\": \"test_points_data_aug\"\n        }\n    ],\n    \"fastai.vision.image.ImageSegment\": [\n        {\n            \"file\": \"tests/test_vision_transform.py\",\n            \"line\": 58,\n            \"test\": \"test_mask_data_aug\"\n        }\n    ],\n    \"fastai.vision.image.pil2tensor\": [\n        {\n            \"file\": \"tests/test_vision_data.py\",\n            \"line\": 348,\n            \"test\": 
\"test_vision_pil2tensor\"\n        },\n        {\n            \"file\": \"tests/test_vision_data.py\",\n            \"line\": 379,\n            \"test\": \"test_vision_pil2tensor_16bit\"\n        },\n        {\n            \"file\": \"tests/test_vision_data.py\",\n            \"line\": 386,\n            \"test\": \"test_vision_pil2tensor_numpy\"\n        }\n    ],\n    \"fastai.vision.image.rle_decode\": [\n        {\n            \"file\": \"tests/test_vision_image.py\",\n            \"line\": 17,\n            \"test\": \"test_rle_decode_with_str\"\n        },\n        {\n            \"file\": \"tests/test_vision_image.py\",\n            \"line\": 23,\n            \"test\": \"test_rle_decode_empty_str\"\n        }\n    ],\n    \"fastai.vision.image.rle_encode\": [\n        {\n            \"file\": \"tests/test_vision_image.py\",\n            \"line\": 5,\n            \"test\": \"test_rle_encode_with_array\"\n        },\n        {\n            \"file\": \"tests/test_vision_image.py\",\n            \"line\": 11,\n            \"test\": \"test_rle_encode_all_zero_array\"\n        }\n    ],\n    \"fastai.vision.image.tis2hw\": [\n        {\n            \"file\": \"tests/test_vision_image.py\",\n            \"line\": 29,\n            \"test\": \"test_tis2hw_int\"\n        },\n        {\n            \"file\": \"tests/test_vision_image.py\",\n            \"line\": 34,\n            \"test\": \"test_tis2hw_3dims\"\n        },\n        {\n            \"file\": \"tests/test_vision_image.py\",\n            \"line\": 39,\n            \"test\": \"test_tis2hw_2dims\"\n        },\n        {\n            \"file\": \"tests/test_vision_image.py\",\n            \"line\": 44,\n            \"test\": \"test_tis2hw_str_raises_an_error\"\n        }\n    ],\n    \"fastai.vision.learner._cl_int_from_learner\": [\n        {\n            \"file\": \"tests/test_vision_train.py\",\n            \"line\": 72,\n            \"test\": \"test_interp\"\n        }\n    ],\n    
\"fastai.vision.learner._learner_interpret\": [\n        {\n            \"file\": \"tests/test_vision_train.py\",\n            \"line\": 78,\n            \"test\": \"test_interp_shortcut\"\n        }\n    ],\n    \"fastai.vision.learner.create_body\": [\n        {\n            \"file\": \"tests/test_vision_learner.py\",\n            \"line\": 16,\n            \"test\": \"test_create_body\"\n        }\n    ],\n    \"fastai.vision.learner.create_head\": [\n        {\n            \"file\": \"tests/test_vision_learner.py\",\n            \"line\": 39,\n            \"test\": \"test_create_head\"\n        }\n    ],\n    \"fastai.vision.learner.has_pool_type\": [\n        {\n            \"file\": \"tests/test_vision_learner.py\",\n            \"line\": 45,\n            \"test\": \"test_has_pool_type\"\n        }\n    ],\n    \"fastai.vision.models.unet.DynamicUnet\": [\n        {\n            \"file\": \"tests/test_vision_models_unet.py\",\n            \"line\": 39,\n            \"test\": \"test_dynamic_unet_shape\"\n        },\n        {\n            \"file\": \"tests/test_vision_models_unet.py\",\n            \"line\": 45,\n            \"test\": \"test_unet_block_shapes\"\n        }\n    ],\n    \"fastai.vision.transform._crop\": [\n        {\n            \"file\": \"tests/test_vision_transform.py\",\n            \"line\": 111,\n            \"test\": \"test_deterministic_transforms\"\n        },\n        {\n            \"file\": \"tests/test_vision_transform.py\",\n            \"line\": 123,\n            \"test\": \"test_crop_without_size\"\n        },\n        {\n            \"file\": \"tests/test_vision_transform.py\",\n            \"line\": 131,\n            \"test\": \"test_crops_with_tensor_image_sizes\"\n        }\n    ],\n    \"fastai.vision.transform._dihedral\": [\n        {\n            \"file\": \"tests/test_vision_transform.py\",\n            \"line\": 102,\n            \"test\": \"test_all_dihedral\"\n        }\n    ],\n    
\"fastai.vision.transform._flip_affine\": [\n        {\n            \"file\": \"tests/test_vision_transform.py\",\n            \"line\": 111,\n            \"test\": \"test_deterministic_transforms\"\n        }\n    ],\n    \"fastai.vision.transform._flip_lr\": [\n        {\n            \"file\": \"tests/test_vision_transform.py\",\n            \"line\": 111,\n            \"test\": \"test_deterministic_transforms\"\n        }\n    ],\n    \"fastai.vision.transform._pad\": [\n        {\n            \"file\": \"tests/test_vision_transform.py\",\n            \"line\": 111,\n            \"test\": \"test_deterministic_transforms\"\n        }\n    ],\n    \"fastai.vision.transform._perspective_warp\": [\n        {\n            \"file\": \"tests/test_vision_transform.py\",\n            \"line\": 83,\n            \"test\": \"test_all_warps\"\n        }\n    ],\n    \"fastai.vision.transform._rotate\": [\n        {\n            \"file\": \"tests/test_vision_transform.py\",\n            \"line\": 111,\n            \"test\": \"test_deterministic_transforms\"\n        }\n    ],\n    \"fastai.vision.transform._skew\": [\n        {\n            \"file\": \"tests/test_vision_transform.py\",\n            \"line\": 83,\n            \"test\": \"test_all_warps\"\n        }\n    ],\n    \"fastai.vision.transform._squish\": [\n        {\n            \"file\": \"tests/test_vision_transform.py\",\n            \"line\": 111,\n            \"test\": \"test_deterministic_transforms\"\n        }\n    ],\n    \"fastai.vision.transform._tilt\": [\n        {\n            \"file\": \"tests/test_vision_transform.py\",\n            \"line\": 83,\n            \"test\": \"test_all_warps\"\n        }\n    ],\n    \"fastai.vision.transform._zoom\": [\n        {\n            \"file\": \"tests/test_vision_transform.py\",\n            \"line\": 111,\n            \"test\": \"test_deterministic_transforms\"\n        }\n    ],\n    \"fastai.vision.transform.get_transforms\": [\n        {\n            
\"file\": \"tests/test_vision_data.py\",\n            \"line\": 313,\n            \"test\": \"test_image_to_image_different_y_size\"\n        },\n        {\n            \"file\": \"tests/test_vision_data.py\",\n            \"line\": 328,\n            \"test\": \"test_image_to_image_different_tfms\"\n        }\n    ],\n    \"fastai.widgets.image_cleaner.ImageCleaner\": [\n        {\n            \"file\": \"tests/test_widgets_image_cleaner.py\",\n            \"line\": 16,\n            \"test\": \"test_image_cleaner_index_length_mismatch\"\n        },\n        {\n            \"file\": \"tests/test_widgets_image_cleaner.py\",\n            \"line\": 23,\n            \"test\": \"test_image_cleaner_length_correct\"\n        },\n        {\n            \"file\": \"tests/test_widgets_image_cleaner.py\",\n            \"line\": 30,\n            \"test\": \"test_image_cleaner_wrong_input_type\"\n        }\n    ],\n    \"fastai.widgets.image_downloader.ImageDownloader\": [\n        {\n            \"file\": \"tests/test_widgets_image_cleaner.py\",\n            \"line\": 36,\n            \"test\": \"test_image_downloader_with_path\"\n        }\n    ]\n}"
  },
  {
    "path": "fastai/text/__init__.py",
    "content": "from .. import basics\nfrom ..basics import *\nfrom .learner import *\nfrom .data import *\nfrom .transform import *\nfrom .models import *\nfrom .. import text\n\n__all__ =  [*basics.__all__, *learner.__all__, *data.__all__, *transform.__all__, *models.__all__, 'text']\n\n"
  },
  {
    "path": "fastai/text/data.py",
    "content": "\"NLP data loading pipeline. Supports csv, folders, and preprocessed data.\"\nfrom ..torch_core import *\nfrom .transform import *\nfrom ..basic_data import *\nfrom ..data_block import *\nfrom ..layers import *\nfrom ..callback import Callback\n\n__all__ = ['LanguageModelPreLoader', 'SortSampler', 'SortishSampler', 'TextList', 'pad_collate', 'TextDataBunch',\n           'TextLMDataBunch', 'TextClasDataBunch', 'Text', 'open_text', 'TokenizeProcessor', 'NumericalizeProcessor',\n           'OpenFileProcessor', 'LMLabelList', 'LMTextList', 'SPProcessor']\n\nTextMtd = IntEnum('TextMtd', 'DF TOK IDS')\ntext_extensions = {'.txt'}\n\nclass LanguageModelPreLoader(Callback):\n    \"Transforms the tokens in `dataset` to a stream of contiguous batches for language modelling.\"\n\n    class CircularIndex():\n        \"Handles shuffle, direction of indexing, wraps around to head tail in the ragged array as needed\"\n        def __init__(self, length:int, forward:bool): self.idx, self.forward = np.arange(length), forward\n        def __getitem__(self, i):\n            return self.idx[ i%len(self.idx) if self.forward else len(self.idx)-1-i%len(self.idx)]\n        def __len__(self) -> int: return len(self.idx)\n        def shuffle(self): np.random.shuffle(self.idx)\n\n    def __init__(self, dataset:LabelList, lengths:Collection[int]=None, bs:int=32, bptt:int=70, backwards:bool=False,\n                 shuffle:bool=False):\n        self.dataset,self.bs,self.bptt,self.shuffle,self.backwards,self.lengths = dataset,bs,bptt,shuffle,backwards,lengths\n        self.bs *= num_distrib() or 1\n        self.totalToks,self.ite_len,self.idx = int(0),None,None\n\n    def __len__(self):\n        if self.ite_len is None:\n            if self.lengths is None: self.lengths = np.array([len(item) for item in self.dataset.x.items])\n            self.totalToks = self.lengths.sum()\n            self.ite_len   = self.bs*int( math.ceil( self.totalToks/(self.bptt*self.bs) )) if self.item is 
None else 1\n        return self.ite_len\n\n    def __getattr__(self,k:str)->Any: return getattr(self.dataset, k)\n\n    def allocate_buffers(self):\n        \"Create the ragged array that will be filled when we ask for items.\"\n        if self.ite_len is None: len(self)\n        self.idx   = LanguageModelPreLoader.CircularIndex(len(self.dataset.x.items), not self.backwards)\n        self.batch = np.zeros((self.bs, self.bptt+1), dtype=np.int64)\n        self.batch_x, self.batch_y = self.batch[:,0:self.bptt], self.batch[:,1:self.bptt+1]\n        #ro: index of the text we're at inside our datasets for the various batches\n        self.ro    = np.zeros(self.bs, dtype=np.int64)\n        #ri: index of the token we're at inside our current text for the various batches\n        self.ri    = np.zeros(self.bs, dtype=np.int)\n\n    def on_epoch_begin(self, **kwargs):\n        if self.idx is None or len(self.idx) != len(self.dataset.x.items): self.allocate_buffers()\n        elif self.shuffle:   self.idx.shuffle()\n        self.idx.forward = not self.backwards\n\n        step = self.totalToks / self.bs\n        ln_rag, countTokens, i_rag = 0, 0, -1\n        for i in range(0,self.bs):\n            #Compute the initial values for ro and ri\n            while ln_rag + countTokens <= int(step * i):\n                countTokens += ln_rag\n                i_rag       += 1\n                ln_rag       = self.lengths[self.idx[i_rag]]\n            self.ro[i] = i_rag\n            self.ri[i] = ( ln_rag - int(step * i - countTokens) ) if self.backwards else int(step * i - countTokens)\n\n    #Training dl gets on_epoch_begin called, val_dl, on_epoch_end\n    def on_epoch_end(self, **kwargs): self.on_epoch_begin()\n\n    def __getitem__(self, k:int):\n        j = k % self.bs\n        if self.item is not None: return self.dataset[0]\n        if self.idx is None: self.on_epoch_begin()\n        self.ro[j],self.ri[j] = self.fill_row(not self.backwards, self.dataset.x.items, self.idx, 
self.batch[j],\n                                              self.ro[j], self.ri[j], overlap=1, lengths=self.lengths)\n        return self.batch_x[j], self.batch_y[j]\n\n    def fill_row(self, forward, items, idx, row, ro, ri, overlap,lengths):\n        \"Fill the row with tokens from the ragged array. --OBS-- overlap != 1 has not been implemented\"\n        ibuf = n = 0\n        ro  -= 1\n        while ibuf < row.size:\n            ro   += 1\n            ix    = idx[ro]\n            rag   = items[ix]\n            if forward:\n                ri = 0 if ibuf else ri\n                n  = min(lengths[ix] - ri, row.size - ibuf)\n                row[ibuf:ibuf+n] = rag[ri:ri+n]\n            else:\n                ri = lengths[ix] if ibuf else ri\n                n  = min(ri, row.size - ibuf)\n                row[ibuf:ibuf+n] = rag[ri-n:ri][::-1]\n            ibuf += n\n        return ro, ri + ((n-overlap) if forward else -(n-overlap))\n\nclass SortSampler(Sampler):\n    \"Go through the text data by order of length.\"\n\n    def __init__(self, data_source:NPArrayList, key:KeyFunc): self.data_source,self.key = data_source,key\n    def __len__(self) -> int: return len(self.data_source)\n    def __iter__(self):\n        return iter(sorted(range_of(self.data_source), key=self.key, reverse=True))\n\nclass SortishSampler(Sampler):\n    \"Go through the text data by order of length with a bit of randomness.\"\n\n    def __init__(self, data_source:NPArrayList, key:KeyFunc, bs:int):\n        self.data_source,self.key,self.bs = data_source,key,bs\n\n    def __len__(self) -> int: return len(self.data_source)\n\n    def __iter__(self):\n        idxs = np.random.permutation(len(self.data_source))\n        sz = self.bs*50\n        ck_idx = [idxs[i:i+sz] for i in range(0, len(idxs), sz)]\n        sort_idx = np.concatenate([sorted(s, key=self.key, reverse=True) for s in ck_idx])\n        sz = self.bs\n        ck_idx = [sort_idx[i:i+sz] for i in range(0, len(sort_idx), sz)]\n        
max_ck = np.argmax([self.key(ck[0]) for ck in ck_idx])  # find the chunk with the largest key,\n        ck_idx[0],ck_idx[max_ck] = ck_idx[max_ck],ck_idx[0]     # then make sure it goes first.\n        sort_idx = np.concatenate(np.random.permutation(ck_idx[1:])) if len(ck_idx) > 1 else np.array([],dtype=np.int)\n        sort_idx = np.concatenate((ck_idx[0], sort_idx))\n        return iter(sort_idx)\n\ndef pad_collate(samples:BatchSamples, pad_idx:int=1, pad_first:bool=True, backwards:bool=False) -> Tuple[LongTensor, LongTensor]:\n    \"Function that collect samples and adds padding. Flips token order if needed\"\n    samples = to_data(samples)\n    max_len = max([len(s[0]) for s in samples])\n    res = torch.zeros(len(samples), max_len).long() + pad_idx\n    if backwards: pad_first = not pad_first\n    for i,s in enumerate(samples):\n        if pad_first: res[i,-len(s[0]):] = LongTensor(s[0])\n        else:         res[i,:len(s[0]):] = LongTensor(s[0])\n    if backwards: res = res.flip(1)\n    return res, tensor(np.array([s[1] for s in samples]))\n\ndef _get_processor(tokenizer:Tokenizer=None, vocab:Vocab=None, chunksize:int=10000, max_vocab:int=60000,\n                   min_freq:int=2, mark_fields:bool=False, include_bos:bool=True, include_eos:bool=False):\n    return [TokenizeProcessor(tokenizer=tokenizer, chunksize=chunksize, \n                              mark_fields=mark_fields, include_bos=include_bos, include_eos=include_eos),\n            NumericalizeProcessor(vocab=vocab, max_vocab=max_vocab, min_freq=min_freq)]\n\nclass TextDataBunch(DataBunch):\n    \"General class to get a `DataBunch` for NLP. 
Subclassed by `TextLMDataBunch` and `TextClasDataBunch`.\"\n\n    @classmethod\n    def from_ids(cls, path:PathOrStr, vocab:Vocab, train_ids:Collection[Collection[int]], valid_ids:Collection[Collection[int]],\n                 test_ids:Collection[Collection[int]]=None, train_lbls:Collection[Union[int,float]]=None,\n                 valid_lbls:Collection[Union[int,float]]=None, classes:Collection[Any]=None,\n                 processor:PreProcessor=None, **kwargs) -> DataBunch:\n        \"Create a `TextDataBunch` from ids, labels and a `vocab`. `kwargs` are passed to the dataloader creation.\"\n        src = ItemLists(path, TextList(train_ids, vocab, path=path, processor=[]),\n                        TextList(valid_ids, vocab, path=path, processor=[]))\n        src = src.label_for_lm() if cls==TextLMDataBunch else src.label_from_lists(train_lbls, valid_lbls, classes=classes, processor=[])\n        if not is1d(train_lbls): src.train.y.one_hot,src.valid.y.one_hot = True,True\n        if test_ids is not None: src.add_test(TextList(test_ids, vocab, path=path), label=train_lbls[0])\n        src.valid.x.processor = ifnone(processor, [TokenizeProcessor(), NumericalizeProcessor(vocab=vocab)])\n        if classes is not None: src.valid.y.processor = ifnone(processor, [CategoryProcessor(src.valid.y)])\n        return src.databunch(**kwargs)\n\n    @classmethod\n    def load(cls, path:PathOrStr, cache_name:PathOrStr='tmp', processor:PreProcessor=None, **kwargs):\n        \"Load a `TextDataBunch` from `path/cache_name`. 
`kwargs` are passed to the dataloader creation.\"\n        warn(\"\"\"This method is deprecated and only kept to load data serialized in v1.0.43 or earlier.\n                Use `load_data` for data saved with v1.0.44 or later.\"\"\", DeprecationWarning)\n        cache_path = Path(path)/cache_name\n        vocab = Vocab(pickle.load(open(cache_path/'itos.pkl','rb')))\n        train_ids,train_lbls = np.load(cache_path/f'train_ids.npy'), np.load(cache_path/f'train_lbl.npy')\n        valid_ids,valid_lbls = np.load(cache_path/f'valid_ids.npy'), np.load(cache_path/f'valid_lbl.npy')\n        test_ids = np.load(cache_path/f'test_ids.npy') if os.path.isfile(cache_path/f'test_ids.npy') else None\n        classes = loadtxt_str(cache_path/'classes.txt') if os.path.isfile(cache_path/'classes.txt') else None\n        return cls.from_ids(path, vocab, train_ids, valid_ids, test_ids, train_lbls, valid_lbls, classes, processor, **kwargs)\n\n    @classmethod#TODO: test\n    def from_tokens(cls, path:PathOrStr, trn_tok:Collection[Collection[str]], trn_lbls:Collection[Union[int,float]],\n                 val_tok:Collection[Collection[str]], val_lbls:Collection[Union[int,float]], vocab:Vocab=None,\n                 tst_tok:Collection[Collection[str]]=None, classes:Collection[Any]=None, max_vocab:int=60000, min_freq:int=3,\n                 **kwargs) -> DataBunch:\n        \"Create a `TextDataBunch` from tokens and labels. 
`kwargs` are passed to the dataloader creation.\"\n        processor = NumericalizeProcessor(vocab=vocab, max_vocab=max_vocab, min_freq=min_freq)\n        src = ItemLists(path, TextList(trn_tok, path=path, processor=processor),\n                        TextList(val_tok, path=path, processor=processor))\n        src = src.label_for_lm() if cls==TextLMDataBunch else src.label_from_lists(trn_lbls, val_lbls, classes=classes)\n        if tst_tok is not None: src.add_test(TextList(tst_tok, path=path))\n        return src.databunch(**kwargs)\n\n    @classmethod\n    def from_df(cls, path:PathOrStr, train_df:DataFrame, valid_df:DataFrame, test_df:Optional[DataFrame]=None,\n                tokenizer:Tokenizer=None, vocab:Vocab=None, classes:Collection[str]=None, text_cols:IntsOrStrs=1,\n                label_cols:IntsOrStrs=0, label_delim:str=None, chunksize:int=10000, max_vocab:int=60000,\n                min_freq:int=2, mark_fields:bool=False, include_bos:bool=True, include_eos:bool=False, **kwargs) -> DataBunch:\n        \"Create a `TextDataBunch` from DataFrames. 
`kwargs` are passed to the dataloader creation.\"\n        processor = _get_processor(tokenizer=tokenizer, vocab=vocab, chunksize=chunksize, max_vocab=max_vocab,\n                                   min_freq=min_freq, mark_fields=mark_fields, \n                                   include_bos=include_bos, include_eos=include_eos)\n        if classes is None and is_listy(label_cols) and len(label_cols) > 1: classes = label_cols\n        src = ItemLists(path, TextList.from_df(train_df, path, cols=text_cols, processor=processor),\n                        TextList.from_df(valid_df, path, cols=text_cols, processor=processor))\n        if cls==TextLMDataBunch: src = src.label_for_lm()\n        else: \n            if label_delim is not None: src = src.label_from_df(cols=label_cols, classes=classes, label_delim=label_delim)\n            else: src = src.label_from_df(cols=label_cols, classes=classes)\n        if test_df is not None: src.add_test(TextList.from_df(test_df, path, cols=text_cols))\n        return src.databunch(**kwargs)\n\n    @classmethod\n    def from_csv(cls, path:PathOrStr, csv_name, valid_pct:float=0.2, test:Optional[str]=None,\n                 tokenizer:Tokenizer=None, vocab:Vocab=None, classes:Collection[str]=None, delimiter:str=None, header='infer',\n                 text_cols:IntsOrStrs=1, label_cols:IntsOrStrs=0, label_delim:str=None,\n                 chunksize:int=10000, max_vocab:int=60000, min_freq:int=2, \n                 mark_fields:bool=False, include_bos:bool=True, include_eos:bool=False, **kwargs) -> DataBunch:\n        \"Create a `TextDataBunch` from texts in csv files. 
`kwargs` are passed to the dataloader creation.\"\n        df = pd.read_csv(Path(path)/csv_name, header=header, delimiter=delimiter)\n        df = df.iloc[np.random.permutation(len(df))]\n        cut = int(valid_pct * len(df)) + 1\n        train_df, valid_df = df[cut:], df[:cut]\n        test_df = None if test is None else pd.read_csv(Path(path)/test, header=header, delimiter=delimiter)\n        return cls.from_df(path, train_df, valid_df, test_df, tokenizer=tokenizer, vocab=vocab, classes=classes, text_cols=text_cols,\n                           label_cols=label_cols, label_delim=label_delim, chunksize=chunksize, max_vocab=max_vocab,\n                           min_freq=min_freq, mark_fields=mark_fields, \n                           include_bos=include_bos, include_eos=include_eos, **kwargs)\n\n    @classmethod\n    def from_folder(cls, path:PathOrStr, train:str='train', valid:str='valid', test:Optional[str]=None,\n                    classes:Collection[Any]=None, tokenizer:Tokenizer=None, vocab:Vocab=None, chunksize:int=10000, max_vocab:int=60000,\n                    min_freq:int=2, mark_fields:bool=False, include_bos:bool=True, include_eos:bool=False, **kwargs):\n        \"Create a `TextDataBunch` from text files in folders.\"\n        path = Path(path).absolute()\n        processor = [OpenFileProcessor()] + _get_processor(tokenizer=tokenizer, vocab=vocab, chunksize=chunksize, max_vocab=max_vocab,\n                                   min_freq=min_freq, mark_fields=mark_fields, include_bos=include_bos, include_eos=include_eos)\n        src = (TextList.from_folder(path, processor=processor)\n                       .split_by_folder(train=train, valid=valid))\n        src = src.label_for_lm() if cls==TextLMDataBunch else src.label_from_folder(classes=classes)\n        if test is not None: src.add_test_folder(path/test)\n        return src.databunch(**kwargs)\n\nclass TextLMDataBunch(TextDataBunch):\n    \"Create a `TextDataBunch` suitable for training a language 
model.\"\n    @classmethod\n    def create(cls, train_ds, valid_ds, test_ds=None, path:PathOrStr='.', no_check:bool=False, bs=64, val_bs:int=None,\n               num_workers:int=0, device:torch.device=None, collate_fn:Callable=data_collate,\n               dl_tfms:Optional[Collection[Callable]]=None, bptt:int=70, backwards:bool=False, **dl_kwargs) -> DataBunch:\n        \"Create a `TextDataBunch` in `path` from the `datasets` for language modelling. Passes `**dl_kwargs` on to `DataLoader()`\"\n        datasets = cls._init_ds(train_ds, valid_ds, test_ds)\n        val_bs = ifnone(val_bs, bs)\n        datasets = [LanguageModelPreLoader(ds, shuffle=(i==0), bs=(bs if i==0 else val_bs), bptt=bptt, backwards=backwards)\n                    for i,ds in enumerate(datasets)]\n        val_bs = bs\n        dls = [DataLoader(d, b, shuffle=False, **dl_kwargs) for d,b in zip(datasets, (bs,val_bs,val_bs,val_bs)) if d is not None]\n        return cls(*dls, path=path, device=device, dl_tfms=dl_tfms, collate_fn=collate_fn, no_check=no_check)\n\nclass TextClasDataBunch(TextDataBunch):\n    \"Create a `TextDataBunch` suitable for training an RNN classifier.\"\n    @classmethod\n    def create(cls, train_ds, valid_ds, test_ds=None, path:PathOrStr='.', bs:int=32, val_bs:int=None, pad_idx=1,\n               pad_first=True, device:torch.device=None, no_check:bool=False, backwards:bool=False, \n               dl_tfms:Optional[Collection[Callable]]=None, **dl_kwargs) -> DataBunch:\n        \"Function that transform the `datasets` in a `DataBunch` for classification. 
Passes `**dl_kwargs` on to `DataLoader()`\"\n        datasets = cls._init_ds(train_ds, valid_ds, test_ds)\n        val_bs = ifnone(val_bs, bs)\n        collate_fn = partial(pad_collate, pad_idx=pad_idx, pad_first=pad_first, backwards=backwards)\n        train_sampler = SortishSampler(datasets[0].x, key=lambda t: len(datasets[0][t][0].data), bs=bs)\n        train_dl = DataLoader(datasets[0], batch_size=bs, sampler=train_sampler, drop_last=True, **dl_kwargs)\n        dataloaders = [train_dl]\n        for ds in datasets[1:]:\n            lengths = [len(t) for t in ds.x.items]\n            sampler = SortSampler(ds.x, key=lengths.__getitem__)\n            dataloaders.append(DataLoader(ds, batch_size=val_bs, sampler=sampler, **dl_kwargs))\n        return cls(*dataloaders, path=path, device=device, dl_tfms=dl_tfms, collate_fn=collate_fn, no_check=no_check)\n\ndef open_text(fn:PathOrStr, enc='utf-8'):\n    \"Read the text in `fn`.\"\n    with open(fn,'r', encoding = enc) as f: return ''.join(f.readlines())\n\nclass Text(ItemBase):\n    \"Basic item for <code>text</code> data in numericalized `ids`.\"\n    def __init__(self, ids, text): self.data,self.text = np.array(ids, dtype=np.int64),text\n    def __str__(self):  return str(self.text)\n\nclass TokenizeProcessor(PreProcessor):\n    \"`PreProcessor` that tokenizes the texts in `ds`.\"\n    def __init__(self, ds:ItemList=None, tokenizer:Tokenizer=None, chunksize:int=10000, \n                 mark_fields:bool=False, include_bos:bool=True, include_eos:bool=False):\n        self.tokenizer,self.chunksize,self.mark_fields = ifnone(tokenizer, Tokenizer()),chunksize,mark_fields\n        self.include_bos, self.include_eos = include_bos, include_eos\n\n    def process_one(self, item):\n        return self.tokenizer._process_all_1(_join_texts([item], self.mark_fields, self.include_bos, self.include_eos))[0]\n\n    def process(self, ds):\n        ds.items = _join_texts(ds.items, self.mark_fields, self.include_bos, self.include_eos)\n 
       tokens = []\n        for i in progress_bar(range(0,len(ds),self.chunksize), leave=False):\n            tokens += self.tokenizer.process_all(ds.items[i:i+self.chunksize])\n        ds.items = tokens\n\nclass NumericalizeProcessor(PreProcessor):\n    \"`PreProcessor` that numericalizes the tokens in `ds`.\"\n    def __init__(self, ds:ItemList=None, vocab:Vocab=None, max_vocab:int=60000, min_freq:int=3):\n        vocab = ifnone(vocab, ds.vocab if ds is not None else None)\n        self.vocab,self.max_vocab,self.min_freq = vocab,max_vocab,min_freq\n\n    def process_one(self,item): return np.array(self.vocab.numericalize(item), dtype=np.int64)\n    def process(self, ds):\n        if self.vocab is None: self.vocab = Vocab.create(ds.items, self.max_vocab, self.min_freq)\n        ds.vocab = self.vocab\n        super().process(ds)\n\nclass OpenFileProcessor(PreProcessor):\n    \"`PreProcessor` that opens the filenames and read the texts.\"\n    def process(self, ds:Collection): ds.items = array([self.process_one(item) for item in ds.items], dtype=np.object)\n    def process_one(self,item): return open_text(item) if isinstance(item, Path) else item\n\nclass TextList(ItemList):\n    \"Basic `ItemList` for text data.\"\n    _bunch = TextClasDataBunch\n    _processor = [TokenizeProcessor, NumericalizeProcessor]\n    _is_lm = False\n\n    def __init__(self, items:Iterator, vocab:Vocab=None, pad_idx:int=1, sep=' ', **kwargs):\n        super().__init__(items, **kwargs)\n        self.vocab,self.pad_idx,self.sep = vocab,pad_idx,sep\n        self.copy_new += ['vocab', 'pad_idx', 'sep']\n\n    def get(self, i):\n        o = super().get(i)\n        return o if self.vocab is None else Text(o, self.vocab.textify(o, self.sep))\n\n    def label_for_lm(self, **kwargs):\n        \"A special labelling method for language models.\"\n        self.__class__ = LMTextList\n        kwargs['label_cls'] = LMLabelList\n        return self.label_const(0, **kwargs)\n\n    def reconstruct(self, 
t:Tensor):\n        idx_min = (t != self.pad_idx).nonzero().min()\n        idx_max = (t != self.pad_idx).nonzero().max()\n        return Text(t[idx_min:idx_max+1], self.vocab.textify(t[idx_min:idx_max+1]))\n\n    @classmethod\n    def from_folder(cls, path:PathOrStr='.', extensions:Collection[str]=text_extensions, vocab:Vocab=None,\n                    processor:PreProcessor=None, **kwargs)->'TextList':\n        \"Get the list of files in `path` that have a text suffix. `recurse` determines if we search subfolders.\"\n        processor = ifnone(processor, [OpenFileProcessor(), TokenizeProcessor(), NumericalizeProcessor(vocab=vocab)])\n        return super().from_folder(path=path, extensions=extensions, processor=processor, **kwargs)\n\n    def show_xys(self, xs, ys, max_len:int=70)->None:\n        \"Show the `xs` (inputs) and `ys` (targets). `max_len` is the maximum number of tokens displayed.\"\n        from IPython.display import display, HTML\n        names = ['idx','text'] if self._is_lm else ['text','target']\n        items = []\n        for i, (x,y) in enumerate(zip(xs,ys)):\n            txt_x = ' '.join(x.text.split(' ')[:max_len]) if max_len is not None else x.text\n            items.append([i, txt_x] if self._is_lm else [txt_x, y])\n        items = np.array(items)\n        df = pd.DataFrame({n:items[:,i] for i,n in enumerate(names)}, columns=names)\n        with pd.option_context('display.max_colwidth', -1):\n            display(HTML(df.to_html(index=False)))\n\n    def show_xyzs(self, xs, ys, zs, max_len:int=70):\n        \"Show `xs` (inputs), `ys` (targets) and `zs` (predictions). 
`max_len` is the maximum number of tokens displayed.\"\n        from IPython.display import display, HTML\n        items,names = [],['text','target','prediction']\n        for i, (x,y,z) in enumerate(zip(xs,ys,zs)):\n            txt_x = ' '.join(x.text.split(' ')[:max_len]) if max_len is not None else x.text\n            items.append([txt_x, y, z])\n        items = np.array(items)\n        df = pd.DataFrame({n:items[:,i] for i,n in enumerate(names)}, columns=names)\n        with pd.option_context('display.max_colwidth', -1):\n            display(HTML(df.to_html(index=False)))\n\nclass LMLabelList(EmptyLabelList):\n    \"Basic `ItemList` for dummy labels.\"\n    def __init__(self, items:Iterator, **kwargs):\n        super().__init__(items, **kwargs)\n        self.loss_func = CrossEntropyFlat()\n\nclass LMTextList(TextList):\n    \"Special `TextList` for a language model.\"\n    _bunch = TextLMDataBunch\n    _is_lm = True\n\ndef _join_texts(texts:Collection[str], mark_fields:bool=False, include_bos:bool=True, include_eos:bool=False):\n    if not isinstance(texts, np.ndarray): texts = np.array(texts)\n    if is1d(texts): texts = texts[:,None]\n    df = pd.DataFrame({i:texts[:,i] for i in range(texts.shape[1])})\n    bos_tok = f'{BOS} ' if include_bos else ''\n    text_col = f'{bos_tok}{FLD} {1} ' + df[0].astype(str) if mark_fields else f'{bos_tok}' + df[0].astype(str)\n    for i in range(1,len(df.columns)):\n        text_col += (f' {FLD} {i+1} ' if mark_fields else ' ') + df[i].astype(str)\n    if include_eos: text_col = text_col + f' {EOS}'\n    return text_col.values\n\ndef apply_rules(text, pre_rules=None, post_rules=None):\n    \"Apply `pre_rules` and `post_rules` to `text`\"\n    text = text.strip(' ')\n    for r in ifnone(pre_rules, defaults.text_pre_rules): text = r(text)\n    toks = text.split()\n    for r in ifnone(post_rules, defaults.text_post_rules): toks = r(toks)\n    return ' '.join(toks) \n\ndef get_default_size(texts, max_vocab_sz):\n    \"Either 
max_vocab_sz or one quarter of the number of unique words in `texts`\"\n    cnt = Counter()\n    for t in texts: \n        cnt.update(t.split())\n        if len(cnt)//4 > max_vocab_sz: return max_vocab_sz\n    res = len(cnt)//4\n    while res%8 != 0: res+=1\n    return res\n\nfull_char_coverage_langs = [\"bg\", \"cs\", \"da\", \"de\", \"el\", \"en\", \"es\", \"et\", \"fi\", \"fr\", \"ga\", \"hr\", \"hu\",\n                       \"it\",\"lt\",\"lv\",\"mt\",\"nl\",\"pl\",\"pt\",\"ro\",\"sk\",\"sl\",\"sv\"] # all European langs\n\ndef train_sentencepiece(texts:Collection[str], path:PathOrStr, pre_rules: ListRules=None, post_rules:ListRules=None, \n    vocab_sz:int=None, max_vocab_sz:int=30000, model_type:str='unigram', max_sentence_len:int=20480, lang='en',\n    char_coverage=None, tmp_dir='tmp'):\n    \"Train a sentencepiece tokenizer on `texts` and save it in `path/tmp_dir`\"\n    from sentencepiece import SentencePieceTrainer\n    cache_dir = Path(path)/tmp_dir\n    os.makedirs(cache_dir, exist_ok=True)\n    if vocab_sz is None: vocab_sz=get_default_size(texts, max_vocab_sz)\n    raw_text_path = cache_dir / 'all_text.out'\n    with open(raw_text_path, 'w') as f: f.write(\"\\n\".join(texts))\n    spec_tokens = ['\\u2581'+s for s in defaults.text_spec_tok]\n    SentencePieceTrainer.Train(\" \".join([\n        f\"--input={raw_text_path} --max_sentence_length={max_sentence_len}\",\n        f\"--character_coverage={ifnone(char_coverage, 0.99999 if lang in full_char_coverage_langs else 0.9998)}\",\n        f\"--unk_id={len(defaults.text_spec_tok)} --pad_id=-1 --bos_id=-1 --eos_id=-1\",\n        f\"--user_defined_symbols={','.join(spec_tokens)}\",\n        f\"--model_prefix={cache_dir/'spm'} --vocab_size={vocab_sz} --model_type={model_type}\"]))\n    raw_text_path.unlink()\n    return cache_dir\n\nclass SPProcessor(PreProcessor):\n    \"`PreProcessor` that tokenizes and numericalizes with `sentencepiece`\"\n    def __init__(self, ds:ItemList=None, pre_rules: 
ListRules=None, post_rules:ListRules=None, vocab_sz:int=None,\n                 max_vocab_sz:int=30000, model_type:str='unigram', max_sentence_len:int=20480, lang='en',\n                 char_coverage=None, tmp_dir='tmp', mark_fields:bool=False, include_bos:bool=True, \n                 include_eos:bool=False, sp_model=None, sp_vocab=None, n_cpus:int=None):\n        try: from sentencepiece import SentencePieceTrainer,SentencePieceProcessor\n        except ImportError:\n            raise Exception('sentencepiece module is missing: run `pip install sentencepiece`')\n        self.pre_rules,self.post_rules = pre_rules,post_rules\n        self.mark_fields,self.include_bos,self.include_eos = mark_fields,include_bos,include_eos\n        self.sp_model,self.sp_vocab,self.n_cpus = sp_model,sp_vocab,ifnone(n_cpus,defaults.cpus)\n        self.train_func = partial(train_sentencepiece, pre_rules=pre_rules, post_rules=post_rules, vocab_sz=vocab_sz,\n                max_vocab_sz=max_vocab_sz, model_type=model_type, max_sentence_len=max_sentence_len, lang=lang,\n                char_coverage=char_coverage, tmp_dir=tmp_dir)\n\n    def process_one(self, item, join=True):\n        if join: text = _join_texts([item], self.mark_fields, self.include_bos, self.include_eos)[0]\n        text = apply_rules(text, pre_rules=self.pre_rules, post_rules=self.post_rules)\n        return self._encode_batch([text])[0]\n\n    def process(self, ds):\n        ds.items = _join_texts(ds.items, self.mark_fields, self.include_bos, self.include_eos)\n        ds.items = [apply_rules(t, pre_rules=self.pre_rules, post_rules=self.post_rules) \n                    for t in progress_bar(ds.items, leave=False)]\n        if self.sp_model is None or self.sp_vocab is None:\n            cache_dir = self.train_func(ds.items, ds.path)\n            self.sp_model,self.sp_vocab = cache_dir/'spm.model',cache_dir/'spm.vocab'\n        if not getattr(self, 'vocab', False): \n            with open(self.sp_vocab, 'r') as f: 
self.vocab = Vocab([line.split('\\t')[0] for line in f.readlines()])\n        if self.n_cpus <= 1: ds.items = self._encode_batch(ds.items)\n        else:\n            with ProcessPoolExecutor(self.n_cpus) as e:\n                ds.items = np.array(sum(e.map(self._encode_batch, partition_by_cores(ds.items, self.n_cpus)), []))\n        ds.vocab = self.vocab\n\n    def _encode_batch(self, texts):\n        from sentencepiece import SentencePieceProcessor\n        tok = SentencePieceProcessor()\n        tok.Load(str(self.sp_model))\n        return [np.array(tok.EncodeAsIds(t)) for t in texts]\n\n    @classmethod\n    def load(cls, path:PathOrStr, tmp_dir:PathOrStr='tmp', name:str='spm'):\n        cache_dir = Path(path)/tmp_dir\n        return cls(sp_model=cache_dir/f'{name}.model', sp_vocab=cache_dir/f'{name}.vocab')\n"
  },
  {
    "path": "fastai/text/interpret.py",
    "content": "from ..torch_core import *\nfrom ..basic_data import *\nfrom ..basic_train import *\nfrom ..train import ClassificationInterpretation\nimport matplotlib.cm as cm\n\n__all__ = ['TextClassificationInterpretation']\n\ndef value2rgba(x:float, cmap:Callable=cm.RdYlGn, alpha_mult:float=1.0)->Tuple:\n    \"Convert a value `x` from 0 to 1 (inclusive) to an RGBA tuple according to `cmap` times transparency `alpha_mult`.\"\n    c = cmap(x)\n    rgb = (np.array(c[:-1]) * 255).astype(int)\n    a = c[-1] * alpha_mult\n    return tuple(rgb.tolist() + [a])\n\ndef piece_attn_html(pieces:List[str], attns:List[float], sep:str=' ', **kwargs)->str:\n    html_code,spans = ['<span style=\"font-family: monospace;\">'], []\n    for p, a in zip(pieces, attns):\n        p = html.escape(p)\n        c = str(value2rgba(a, alpha_mult=0.5, **kwargs))\n        spans.append(f'<span title=\"{a:.3f}\" style=\"background-color: rgba{c};\">{p}</span>')\n    html_code.append(sep.join(spans))\n    html_code.append('</span>')\n    return ''.join(html_code)\n\ndef show_piece_attn(*args, **kwargs):\n    from IPython.display import display, HTML\n    display(HTML(piece_attn_html(*args, **kwargs)))\n\ndef _eval_dropouts(mod):\n        module_name =  mod.__class__.__name__\n        if 'Dropout' in module_name or 'BatchNorm' in module_name: mod.training = False\n        for module in mod.children(): _eval_dropouts(module)\n\nclass TextClassificationInterpretation(ClassificationInterpretation):\n    \"\"\"Provides an interpretation of classification based on input sensitivity.\n    This was designed for AWD-LSTM only for the moment, because Transformer already has its own attentional model.\n    \"\"\"\n\n    def __init__(self, learn: Learner, preds: Tensor, y_true: Tensor, losses: Tensor, ds_type: DatasetType = DatasetType.Valid):\n        super(TextClassificationInterpretation, self).__init__(learn,preds,y_true,losses,ds_type)\n        self.model = learn.model\n\n    @classmethod\n    def 
from_learner(cls, learn: Learner,  ds_type:DatasetType=DatasetType.Valid, activ:nn.Module=None):\n        \"Gets preds, y_true, losses to construct base class from a learner\"\n        preds_res = learn.get_preds(ds_type=ds_type, activ=activ, with_loss=True, ordered=True)\n        return cls(learn, *preds_res)\n\n    def intrinsic_attention(self, text:str, class_id:int=None):\n        \"\"\"Calculate the intrinsic attention of the input w.r.t to an output `class_id`, or the classification given by the model if `None`.\n        For reference, see the Sequential Jacobian session at https://www.cs.toronto.edu/~graves/preprint.pdf\n        \"\"\"\n        self.model.train()\n        _eval_dropouts(self.model)\n        self.model.zero_grad()\n        self.model.reset()\n        ids = self.data.one_item(text)[0]\n        emb = self.model[0].module.encoder(ids).detach().requires_grad_(True)\n        lstm_output = self.model[0].module(emb, from_embeddings=True)\n        self.model.eval()\n        cl = self.model[1](lstm_output + (torch.zeros_like(ids).byte(),))[0].softmax(dim=-1)\n        if class_id is None: class_id = cl.argmax()\n        cl[0][class_id].backward()\n        attn = emb.grad.squeeze().abs().sum(dim=-1)\n        attn /= attn.max()\n        tokens = self.data.single_ds.reconstruct(ids[0])\n        return tokens, attn\n\n    def html_intrinsic_attention(self, text:str, class_id:int=None, **kwargs)->str:\n        text, attn = self.intrinsic_attention(text, class_id)\n        return piece_attn_html(text.text.split(), to_np(attn), **kwargs)\n\n    def show_intrinsic_attention(self, text:str, class_id:int=None, **kwargs)->None:\n        text, attn = self.intrinsic_attention(text, class_id)\n        show_piece_attn(text.text.split(), to_np(attn), **kwargs)\n\n    def show_top_losses(self, k:int, max_len:int=70)->None:\n        \"\"\"\n        Create a tabulation showing the first `k` texts in top_losses along with their prediction, actual,loss, and probability 
of\n        actual class. `max_len` is the maximum number of tokens displayed.\n        \"\"\"\n        from IPython.display import display, HTML\n        items = []\n        tl_val,tl_idx = self.top_losses()\n        for i,idx in enumerate(tl_idx):\n            if k <= 0: break\n            k -= 1\n            tx,cl = self.data.dl(self.ds_type).dataset[idx]\n            cl = cl.data\n            classes = self.data.classes\n            txt = ' '.join(tx.text.split(' ')[:max_len]) if max_len is not None else tx.text\n            tmp = [txt, f'{classes[self.pred_class[idx]]}', f'{classes[cl]}', f'{self.losses[idx]:.2f}',\n                   f'{self.preds[idx][cl]:.2f}']\n            items.append(tmp)\n        items = np.array(items)\n        names = ['Text', 'Prediction', 'Actual', 'Loss', 'Probability']\n        df = pd.DataFrame({n:items[:,i] for i,n in enumerate(names)}, columns=names)\n        with pd.option_context('display.max_colwidth', -1):\n            display(HTML(df.to_html(index=False)))\n"
  },
  {
    "path": "fastai/text/learner.py",
    "content": "'Model training for NLP'\nfrom ..torch_core import *\nfrom ..basic_train import *\nfrom ..callbacks import *\nfrom ..data_block import CategoryList\nfrom ..basic_data import *\nfrom ..datasets import *\nfrom ..metrics import accuracy\nfrom ..train import GradientClipping\nfrom ..layers import *\nfrom .models import *\nfrom .transform import *\nfrom .data import *\n\n__all__ = ['RNNLearner', 'LanguageLearner', 'convert_weights', 'decode_spec_tokens', 'get_language_model', 'language_model_learner',\n           'MultiBatchEncoder', 'get_text_classifier', 'text_classifier_learner', 'PoolingLinearClassifier']\n\n_model_meta = {AWD_LSTM: {'hid_name':'emb_sz', 'url':URLs.WT103_FWD, 'url_bwd':URLs.WT103_BWD,\n                          'config_lm':awd_lstm_lm_config, 'split_lm': awd_lstm_lm_split,\n                          'config_clas':awd_lstm_clas_config, 'split_clas': awd_lstm_clas_split},\n               Transformer: {'hid_name':'d_model', 'url':URLs.OPENAI_TRANSFORMER,\n                             'config_lm':tfmer_lm_config, 'split_lm': tfmer_lm_split,\n                             'config_clas':tfmer_clas_config, 'split_clas': tfmer_clas_split},\n               TransformerXL: {'hid_name':'d_model',\n                              'config_lm':tfmerXL_lm_config, 'split_lm': tfmerXL_lm_split,\n                              'config_clas':tfmerXL_clas_config, 'split_clas': tfmerXL_clas_split}}\n\ndef convert_weights(wgts:Weights, stoi_wgts:Dict[str,int], itos_new:Collection[str]) -> Weights:\n    \"Convert the model `wgts` to go with a new vocabulary.\"\n    dec_bias, enc_wgts = wgts.get('1.decoder.bias', None), wgts['0.encoder.weight']\n    wgts_m = enc_wgts.mean(0)\n    if dec_bias is not None: bias_m = dec_bias.mean(0)\n    new_w = enc_wgts.new_zeros((len(itos_new),enc_wgts.size(1))).zero_()\n    if dec_bias is not None: new_b = dec_bias.new_zeros((len(itos_new),)).zero_()\n    for i,w in enumerate(itos_new):\n        r = stoi_wgts[w] if w in 
stoi_wgts else -1\n        new_w[i] = enc_wgts[r] if r>=0 else wgts_m\n        if dec_bias is not None: new_b[i] = dec_bias[r] if r>=0 else bias_m\n    wgts['0.encoder.weight'] = new_w\n    if '0.encoder_dp.emb.weight' in wgts: wgts['0.encoder_dp.emb.weight'] = new_w.clone()\n    wgts['1.decoder.weight'] = new_w.clone()\n    if dec_bias is not None: wgts['1.decoder.bias'] = new_b\n    return wgts\n\nclass RNNLearner(Learner):\n    \"Basic class for a `Learner` in NLP.\"\n    def __init__(self, data:DataBunch, model:nn.Module, split_func:OptSplitFunc=None, clip:float=None,\n                 alpha:float=2., beta:float=1., metrics=None, **learn_kwargs):\n        is_class = (hasattr(data.train_ds, 'y') and (isinstance(data.train_ds.y, CategoryList) or\n                                                     isinstance(data.train_ds.y, LMLabelList)))\n        metrics = ifnone(metrics, ([accuracy] if is_class else []))\n        super().__init__(data, model, metrics=metrics, **learn_kwargs)\n        self.callbacks.append(RNNTrainer(self, alpha=alpha, beta=beta))\n        if clip: self.callback_fns.append(partial(GradientClipping, clip=clip))\n        if split_func: self.split(split_func)\n\n    def save_encoder(self, name:str):\n        \"Save the encoder to `name` inside the model directory.\"\n        if is_pathlike(name): self._test_writeable_path()\n        encoder = get_model(self.model)[0]\n        if hasattr(encoder, 'module'): encoder = encoder.module\n        torch.save(encoder.state_dict(), self.path/self.model_dir/f'{name}.pth')\n\n    def load_encoder(self, name:str, device:torch.device=None):\n        \"Load the encoder `name` from the model directory.\"\n        encoder = get_model(self.model)[0]\n        if device is None: device = self.data.device\n        if hasattr(encoder, 'module'): encoder = encoder.module\n        encoder.load_state_dict(torch.load(self.path/self.model_dir/f'{name}.pth', map_location=device))\n        self.freeze()\n\n    def 
load_pretrained(self, wgts_fname:str, itos_fname:str, strict:bool=True):\n        \"Load a pretrained model and adapts it to the data vocabulary.\"\n        old_itos = pickle.load(open(itos_fname, 'rb'))\n        old_stoi = {v:k for k,v in enumerate(old_itos)}\n        wgts = torch.load(wgts_fname, map_location=lambda storage, loc: storage)\n        if 'model' in wgts: wgts = wgts['model']\n        wgts = convert_weights(wgts, old_stoi, self.data.train_ds.vocab.itos)\n        self.model.load_state_dict(wgts, strict=strict)\n\n    def get_preds(self, ds_type:DatasetType=DatasetType.Valid, activ:nn.Module=None, with_loss:bool=False, n_batch:Optional[int]=None,\n                  pbar:Optional[PBar]=None, ordered:bool=False) -> List[Tensor]:\n        \"Return predictions and targets on the valid, train, or test set, depending on `ds_type`.\"\n        self.model.reset()\n        if ordered: np.random.seed(42)\n        preds = super().get_preds(ds_type=ds_type, activ=activ, with_loss=with_loss, n_batch=n_batch, pbar=pbar)\n        if ordered and hasattr(self.dl(ds_type), 'sampler'):\n            np.random.seed(42)\n            sampler = [i for i in self.dl(ds_type).sampler]\n            reverse_sampler = np.argsort(sampler)\n            preds = [p[reverse_sampler] for p in preds]\n        return(preds)\n\ndef decode_spec_tokens(tokens):\n    new_toks,rule,arg = [],None,None\n    for t in tokens:\n        if t in [TK_MAJ, TK_UP, TK_REP, TK_WREP]: rule = t\n        elif rule is None: new_toks.append(t)\n        elif rule == TK_MAJ:\n            new_toks.append(t[:1].upper() + t[1:].lower())\n            rule = None\n        elif rule == TK_UP:\n            new_toks.append(t.upper())\n            rule = None\n        elif arg is None:\n            try:    arg = int(t)\n            except: rule = None\n        else:\n            if rule == TK_REP: new_toks.append(t * arg)\n            else:              new_toks += [t] * arg\n    return new_toks\n\nclass 
LanguageLearner(RNNLearner):\n    \"Subclass of RNNLearner for predictions.\"\n\n    def predict(self, text:str, n_words:int=1, no_unk:bool=True, temperature:float=1., min_p:float=None, sep:str=' ',\n                decoder=decode_spec_tokens):\n        \"Return the `n_words` that come after `text`.\"\n        ds = self.data.single_dl.dataset\n        self.model.reset()\n        xb,yb = self.data.one_item(text)\n        new_idx = []\n        for _ in range(n_words): #progress_bar(range(n_words), leave=False):\n            res = self.pred_batch(batch=(xb,yb))[0][-1]\n            #if len(new_idx) == 0: self.model[0].select_hidden([0])\n            if no_unk: res[self.data.vocab.stoi[UNK]] = 0.\n            if min_p is not None:\n                if (res >= min_p).float().sum() == 0:\n                    warn(f\"There is no item with probability >= {min_p}, try a lower value.\")\n                else: res[res < min_p] = 0.\n            if temperature != 1.: res.pow_(1 / temperature)\n            idx = torch.multinomial(res, 1).item()\n            new_idx.append(idx)\n            xb = xb.new_tensor([idx])[None]\n        return text + sep + sep.join(decoder(self.data.vocab.textify(new_idx, sep=None)))\n\n    def beam_search(self, text:str, n_words:int, no_unk:bool=True, top_k:int=10, beam_sz:int=1000, temperature:float=1.,\n                    sep:str=' ', decoder=decode_spec_tokens):\n        \"Return the `n_words` that come after `text` using beam search.\"\n        ds = self.data.single_dl.dataset\n        self.model.reset()\n        self.model.eval()\n        xb, yb = self.data.one_item(text)\n        nodes = None\n        nodes = xb.clone()\n        scores = xb.new_zeros(1).float()\n        with torch.no_grad():\n            for k in progress_bar(range(n_words), leave=False):\n                out = F.log_softmax(self.model(xb)[0][:,-1], dim=-1)\n                if no_unk: out[:,self.data.vocab.stoi[UNK]] = -float('Inf')\n                values, indices = 
out.topk(top_k, dim=-1)\n                scores = (-values + scores[:,None]).view(-1)\n                indices_idx = torch.arange(0,nodes.size(0))[:,None].expand(nodes.size(0), top_k).contiguous().view(-1)\n                sort_idx = scores.argsort()[:beam_sz]\n                scores = scores[sort_idx]\n                nodes = torch.cat([nodes[:,None].expand(nodes.size(0),top_k,nodes.size(1)),\n                                indices[:,:,None].expand(nodes.size(0),top_k,1),], dim=2)\n                nodes = nodes.view(-1, nodes.size(2))[sort_idx]\n                self.model[0].select_hidden(indices_idx[sort_idx])\n                xb = nodes[:,-1][:,None]\n        if temperature != 1.: scores.div_(temperature)\n        node_idx = torch.multinomial(torch.exp(-scores), 1).item()\n        return text + sep + sep.join(decoder(self.data.vocab.textify([i.item() for i in nodes[node_idx][1:] ], sep=None)))\n\n    def show_results(self, ds_type=DatasetType.Valid, rows:int=5, max_len:int=20):\n        from IPython.display import display, HTML\n        \"Show `rows` result of predictions on `ds_type` dataset.\"\n        ds = self.dl(ds_type).dataset\n        x,y = self.data.one_batch(ds_type, detach=False, denorm=False)\n        preds = self.pred_batch(batch=(x,y))\n        y = y.view(*x.size())\n        z = preds.view(*x.size(),-1).argmax(dim=2)\n        xs = [ds.x.reconstruct(grab_idx(x, i)) for i in range(rows)]\n        ys = [ds.x.reconstruct(grab_idx(y, i)) for i in range(rows)]\n        zs = [ds.x.reconstruct(grab_idx(z, i)) for i in range(rows)]\n        items,names = [],['text', 'target', 'pred']\n        for i, (x,y,z) in enumerate(zip(xs,ys,zs)):\n            txt_x = ' '.join(x.text.split(' ')[:max_len])\n            txt_y = ' '.join(y.text.split(' ')[max_len-1:2*max_len-1])\n            txt_z = ' '.join(z.text.split(' ')[max_len-1:2*max_len-1])\n            items.append([txt_x, txt_y, txt_z])\n        items = np.array(items)\n        df = pd.DataFrame({n:items[:,i] 
for i,n in enumerate(names)}, columns=names)\n        with pd.option_context('display.max_colwidth', -1):\n            display(HTML(df.to_html(index=False)))\n\ndef get_language_model(arch:Callable, vocab_sz:int, config:dict=None, drop_mult:float=1.):\n    \"Create a language model from `arch` and its `config`, maybe `pretrained`.\"\n    meta = _model_meta[arch]\n    config = ifnone(config, meta['config_lm']).copy()\n    for k in config.keys():\n        if k.endswith('_p'): config[k] *= drop_mult\n    tie_weights,output_p,out_bias = map(config.pop, ['tie_weights', 'output_p', 'out_bias'])\n    init = config.pop('init') if 'init' in config else None\n    encoder = arch(vocab_sz, **config)\n    enc = encoder.encoder if tie_weights else None\n    decoder = LinearDecoder(vocab_sz, config[meta['hid_name']], output_p, tie_encoder=enc, bias=out_bias)\n    model = SequentialRNN(encoder, decoder)\n    return model if init is None else model.apply(init)\n\ndef language_model_learner(data:DataBunch, arch, config:dict=None, drop_mult:float=1., pretrained:bool=True,\n                           pretrained_fnames:OptStrTuple=None, **learn_kwargs) -> 'LanguageLearner':\n    \"Create a `Learner` with a language model from `data` and `arch`.\"\n    model = get_language_model(arch, len(data.vocab.itos), config=config, drop_mult=drop_mult)\n    meta = _model_meta[arch]\n    learn = LanguageLearner(data, model, split_func=meta['split_lm'], **learn_kwargs)\n    url = 'url_bwd' if data.backwards else 'url'\n    if pretrained or pretrained_fnames:\n        if pretrained_fnames is not None:\n            fnames = [learn.path/learn.model_dir/f'{fn}.{ext}' for fn,ext in zip(pretrained_fnames, ['pth', 'pkl'])]\n        else:\n            if url not in meta:\n                warn(\"There are no pretrained weights for that architecture yet!\")\n                return learn\n            model_path = untar_data(meta[url] , data=False)\n            fnames = [list(model_path.glob(f'*.{ext}'))[0] for 
ext in ['pth', 'pkl']]\n        learn.load_pretrained(*fnames)\n        learn.freeze()\n    return learn\n\ndef masked_concat_pool(outputs, mask):\n    \"Pool MultiBatchEncoder outputs into one vector [last_hidden, max_pool, avg_pool].\"\n    output = outputs[-1]\n    avg_pool = output.masked_fill(mask[:, :, None], 0).mean(dim=1)\n    avg_pool *= output.size(1) / (output.size(1)-mask.type(avg_pool.dtype).sum(dim=1))[:,None]\n    max_pool = output.masked_fill(mask[:,:,None], -float('inf')).max(dim=1)[0]\n    x = torch.cat([output[:,-1], max_pool, avg_pool], 1)\n    return x\n\nclass PoolingLinearClassifier(Module):\n    \"Create a linear classifier with pooling.\"\n    def __init__(self, layers:Collection[int], drops:Collection[float]):\n        mod_layers = []\n        if len(drops) != len(layers)-1: raise ValueError(\"Number of layers and dropout values do not match.\")\n        activs = [nn.ReLU(inplace=True)] * (len(layers) - 2) + [None]\n        for n_in, n_out, p, actn in zip(layers[:-1], layers[1:], drops, activs):\n            mod_layers += bn_drop_lin(n_in, n_out, p=p, actn=actn)\n        self.layers = nn.Sequential(*mod_layers)\n\n    def forward(self, input:Tuple[Tensor,Tensor, Tensor])->Tuple[Tensor,Tensor,Tensor]:\n        raw_outputs,outputs,mask = input\n        x = masked_concat_pool(outputs, mask)\n        x = self.layers(x)\n        return x, raw_outputs, outputs\n\nclass MultiBatchEncoder(Module):\n    \"Create an encoder over `module` that can process a full sentence.\"\n    def __init__(self, bptt:int, max_len:int, module:nn.Module, pad_idx:int=1):\n        self.max_len,self.bptt,self.module,self.pad_idx = max_len,bptt,module,pad_idx\n\n    def concat(self, arrs:Collection[Tensor])->Tensor:\n        \"Concatenate the `arrs` along the batch dimension.\"\n        return [torch.cat([l[si] for l in arrs], dim=1) for si in range_of(arrs[0])]\n\n    def reset(self):\n        if hasattr(self.module, 'reset'): self.module.reset()\n\n    def 
forward(self, input:LongTensor)->Tuple[Tensor,Tensor]:\n        bs,sl = input.size()\n        self.reset()\n        raw_outputs,outputs,masks = [],[],[]\n        for i in range(0, sl, self.bptt):\n            r, o = self.module(input[:,i: min(i+self.bptt, sl)])\n            if i>(sl-self.max_len):\n                masks.append(input[:,i: min(i+self.bptt, sl)] == self.pad_idx)\n                raw_outputs.append(r)\n                outputs.append(o)\n        return self.concat(raw_outputs),self.concat(outputs),torch.cat(masks,dim=1)\n\ndef get_text_classifier(arch:Callable, vocab_sz:int, n_class:int, bptt:int=70, max_len:int=20*70, config:dict=None,\n                        drop_mult:float=1., lin_ftrs:Collection[int]=None, ps:Collection[float]=None,\n                        pad_idx:int=1) -> nn.Module:\n    \"Create a text classifier from `arch` and its `config`, maybe `pretrained`.\"\n    meta = _model_meta[arch]\n    config = ifnone(config, meta['config_clas']).copy()\n    for k in config.keys():\n        if k.endswith('_p'): config[k] *= drop_mult\n    if lin_ftrs is None: lin_ftrs = [50]\n    if ps is None:  ps = [0.1]*len(lin_ftrs)\n    layers = [config[meta['hid_name']] * 3] + lin_ftrs + [n_class]\n    ps = [config.pop('output_p')] + ps\n    init = config.pop('init') if 'init' in config else None\n    encoder = MultiBatchEncoder(bptt, max_len, arch(vocab_sz, **config), pad_idx=pad_idx)\n    model = SequentialRNN(encoder, PoolingLinearClassifier(layers, ps))\n    return model if init is None else model.apply(init)\n\ndef text_classifier_learner(data:DataBunch, arch:Callable, bptt:int=70, max_len:int=70*20, config:dict=None,\n                            pretrained:bool=True, drop_mult:float=1., lin_ftrs:Collection[int]=None,\n                            ps:Collection[float]=None, **learn_kwargs) -> 'TextClassifierLearner':\n    \"Create a `Learner` with a text classifier from `data` and `arch`.\"\n    model = get_text_classifier(arch, len(data.vocab.itos), 
data.c, bptt=bptt, max_len=max_len,\n                                config=config, drop_mult=drop_mult, lin_ftrs=lin_ftrs, ps=ps)\n    meta = _model_meta[arch]\n    learn = RNNLearner(data, model, split_func=meta['split_clas'], **learn_kwargs)\n    if pretrained:\n        if 'url' not in meta:\n            warn(\"There are no pretrained weights for that architecture yet!\")\n            return learn\n        model_path = untar_data(meta['url'], data=False)\n        fnames = [list(model_path.glob(f'*.{ext}'))[0] for ext in ['pth', 'pkl']]\n        learn.load_pretrained(*fnames, strict=False)\n        learn.freeze()\n    return learn\n"
  },
  {
    "path": "fastai/text/models/__init__.py",
    "content": "from .awd_lstm import *\nfrom .transformer import *\n__all__ = [*awd_lstm.__all__, *transformer.__all__]\n"
  },
  {
    "path": "fastai/text/models/awd_lstm.py",
    "content": "from ...torch_core import *\nfrom ...layers import *\nfrom ...train import ClassificationInterpretation\nfrom ...basic_train import *\nfrom ...basic_data import *\nfrom ..data import TextClasDataBunch\nimport matplotlib.cm as cm\n\n__all__ = ['EmbeddingDropout', 'LinearDecoder', 'AWD_LSTM', 'RNNDropout',\n           'SequentialRNN', 'WeightDropout', 'dropout_mask', 'awd_lstm_lm_split', 'awd_lstm_clas_split',\n           'awd_lstm_lm_config', 'awd_lstm_clas_config', 'TextClassificationInterpretation']\n\ndef dropout_mask(x:Tensor, sz:Collection[int], p:float):\n    \"Return a dropout mask of the same type as `x`, size `sz`, with probability `p` to cancel an element.\"\n    return x.new(*sz).bernoulli_(1-p).div_(1-p)\n\nclass RNNDropout(Module):\n    \"Dropout with probability `p` that is consistent on the seq_len dimension.\"\n\n    def __init__(self, p:float=0.5): self.p=p\n\n    def forward(self, x:Tensor)->Tensor:\n        if not self.training or self.p == 0.: return x\n        m = dropout_mask(x.data, (x.size(0), 1, x.size(2)), self.p)\n        return x * m\n\nclass WeightDropout(Module):\n    \"A module that warps another layer in which some weights will be replaced by 0 during training.\"\n\n    def __init__(self, module:nn.Module, weight_p:float, layer_names:Collection[str]=['weight_hh_l0']):\n        self.module,self.weight_p,self.layer_names = module,weight_p,layer_names\n        for layer in self.layer_names:\n            #Makes a copy of the weights of the selected layers.\n            w = getattr(self.module, layer)\n            self.register_parameter(f'{layer}_raw', nn.Parameter(w.data))\n            self.module._parameters[layer] = F.dropout(w, p=self.weight_p, training=False)\n\n    def _setweights(self):\n        \"Apply dropout to the raw weights.\"\n        for layer in self.layer_names:\n            raw_w = getattr(self, f'{layer}_raw')\n            self.module._parameters[layer] = F.dropout(raw_w, p=self.weight_p, 
training=self.training)\n\n    def forward(self, *args:ArgStar):\n        self._setweights()\n        with warnings.catch_warnings():\n            #To avoid the warning that comes because the weights aren't flattened.\n            warnings.simplefilter(\"ignore\")\n            return self.module.forward(*args)\n\n    def reset(self):\n        for layer in self.layer_names:\n            raw_w = getattr(self, f'{layer}_raw')\n            self.module._parameters[layer] = F.dropout(raw_w, p=self.weight_p, training=False)\n        if hasattr(self.module, 'reset'): self.module.reset()\n\nclass EmbeddingDropout(Module):\n    \"Apply dropout with probabily `embed_p` to an embedding layer `emb`.\"\n\n    def __init__(self, emb:nn.Module, embed_p:float):\n        self.emb,self.embed_p = emb,embed_p\n        self.pad_idx = self.emb.padding_idx\n        if self.pad_idx is None: self.pad_idx = -1\n\n    def forward(self, words:LongTensor, scale:Optional[float]=None)->Tensor:\n        if self.training and self.embed_p != 0:\n            size = (self.emb.weight.size(0),1)\n            mask = dropout_mask(self.emb.weight.data, size, self.embed_p)\n            masked_embed = self.emb.weight * mask\n        else: masked_embed = self.emb.weight\n        if scale: masked_embed.mul_(scale)\n        return F.embedding(words, masked_embed, self.pad_idx, self.emb.max_norm,\n                           self.emb.norm_type, self.emb.scale_grad_by_freq, self.emb.sparse)\n\nclass AWD_LSTM(Module):\n    \"AWD-LSTM/QRNN inspired by https://arxiv.org/abs/1708.02182.\"\n\n    initrange=0.1\n\n    def __init__(self, vocab_sz:int, emb_sz:int, n_hid:int, n_layers:int, pad_token:int=1, hidden_p:float=0.2,\n                 input_p:float=0.6, embed_p:float=0.1, weight_p:float=0.5, qrnn:bool=False, bidir:bool=False):\n        self.bs,self.qrnn,self.emb_sz,self.n_hid,self.n_layers = 1,qrnn,emb_sz,n_hid,n_layers\n        self.n_dir = 2 if bidir else 1\n        self.encoder = nn.Embedding(vocab_sz, emb_sz, 
padding_idx=pad_token)\n        self.encoder_dp = EmbeddingDropout(self.encoder, embed_p)\n        if self.qrnn:\n            #Using QRNN requires an installation of cuda\n            from .qrnn import QRNN\n            self.rnns = [QRNN(emb_sz if l == 0 else n_hid, (n_hid if l != n_layers - 1 else emb_sz)//self.n_dir, 1,\n                              save_prev_x=True, zoneout=0, window=2 if l == 0 else 1, output_gate=True, bidirectional=bidir) \n                         for l in range(n_layers)]\n            for rnn in self.rnns: \n                rnn.layers[0].linear = WeightDropout(rnn.layers[0].linear, weight_p, layer_names=['weight'])\n        else:\n            self.rnns = [nn.LSTM(emb_sz if l == 0 else n_hid, (n_hid if l != n_layers - 1 else emb_sz)//self.n_dir, 1,\n                                 batch_first=True, bidirectional=bidir) for l in range(n_layers)]\n            self.rnns = [WeightDropout(rnn, weight_p) for rnn in self.rnns]\n        self.rnns = nn.ModuleList(self.rnns)\n        self.encoder.weight.data.uniform_(-self.initrange, self.initrange)\n        self.input_dp = RNNDropout(input_p)\n        self.hidden_dps = nn.ModuleList([RNNDropout(hidden_p) for l in range(n_layers)])\n\n    def forward(self, input:Tensor, from_embeddings:bool=False)->Tuple[Tensor,Tensor]:\n        if from_embeddings: bs,sl,es = input.size()\n        else: bs,sl = input.size()\n        if bs!=self.bs:\n            self.bs=bs\n            self.reset()\n        raw_output = self.input_dp(input if from_embeddings else self.encoder_dp(input))\n        new_hidden,raw_outputs,outputs = [],[],[]\n        for l, (rnn,hid_dp) in enumerate(zip(self.rnns, self.hidden_dps)):\n            raw_output, new_h = rnn(raw_output, self.hidden[l])\n            new_hidden.append(new_h)\n            raw_outputs.append(raw_output)\n            if l != self.n_layers - 1: raw_output = hid_dp(raw_output)\n            outputs.append(raw_output)\n        self.hidden = to_detach(new_hidden, 
cpu=False)\n        return raw_outputs, outputs\n\n    def _one_hidden(self, l:int)->Tensor:\n        \"Return one hidden state.\"\n        nh = (self.n_hid if l != self.n_layers - 1 else self.emb_sz) // self.n_dir\n        return one_param(self).new(self.n_dir, self.bs, nh).zero_()\n\n    def select_hidden(self, idxs):\n        if self.qrnn: self.hidden = [h[:,idxs,:] for h in self.hidden]\n        else: self.hidden = [(h[0][:,idxs,:],h[1][:,idxs,:]) for h in self.hidden]\n        self.bs = len(idxs)\n\n    def reset(self):\n        \"Reset the hidden states.\"\n        [r.reset() for r in self.rnns if hasattr(r, 'reset')]\n        if self.qrnn: self.hidden = [self._one_hidden(l) for l in range(self.n_layers)]\n        else: self.hidden = [(self._one_hidden(l), self._one_hidden(l)) for l in range(self.n_layers)]\n\nclass LinearDecoder(Module):\n    \"To go on top of a RNNCore module and create a Language Model.\"\n    initrange=0.1\n\n    def __init__(self, n_out:int, n_hid:int, output_p:float, tie_encoder:nn.Module=None, bias:bool=True):\n        self.decoder = nn.Linear(n_hid, n_out, bias=bias)\n        self.decoder.weight.data.uniform_(-self.initrange, self.initrange)\n        self.output_dp = RNNDropout(output_p)\n        if bias: self.decoder.bias.data.zero_()\n        if tie_encoder: self.decoder.weight = tie_encoder.weight\n\n    def forward(self, input:Tuple[Tensor,Tensor])->Tuple[Tensor,Tensor,Tensor]:\n        raw_outputs, outputs = input\n        output = self.output_dp(outputs[-1])\n        decoded = self.decoder(output)\n        return decoded, raw_outputs, outputs\n\nclass SequentialRNN(nn.Sequential):\n    \"A sequential module that passes the reset call to its children.\"\n    def reset(self):\n        for c in self.children():\n            if hasattr(c, 'reset'): c.reset()\n\ndef awd_lstm_lm_split(model:nn.Module) -> List[nn.Module]:\n    \"Split a RNN `model` in groups for differential learning rates.\"\n    groups = [[rnn, dp] for rnn, dp in 
zip(model[0].rnns, model[0].hidden_dps)]\n    return groups + [[model[0].encoder, model[0].encoder_dp, model[1]]]\n\ndef awd_lstm_clas_split(model:nn.Module) -> List[nn.Module]:\n    \"Split a RNN `model` in groups for differential learning rates.\"\n    groups = [[model[0].module.encoder, model[0].module.encoder_dp]]\n    groups += [[rnn, dp] for rnn, dp in zip(model[0].module.rnns, model[0].module.hidden_dps)]\n    return groups + [[model[1]]]\n\nawd_lstm_lm_config = dict(emb_sz=400, n_hid=1152, n_layers=3, pad_token=1, qrnn=False, bidir=False, output_p=0.1,\n                          hidden_p=0.15, input_p=0.25, embed_p=0.02, weight_p=0.2, tie_weights=True, out_bias=True)\n\nawd_lstm_clas_config = dict(emb_sz=400, n_hid=1152, n_layers=3, pad_token=1, qrnn=False, bidir=False, output_p=0.4,\n                       hidden_p=0.3, input_p=0.4, embed_p=0.05, weight_p=0.5)\n\ndef value2rgba(x:float, cmap:Callable=cm.RdYlGn, alpha_mult:float=1.0)->Tuple:\n    \"Convert a value `x` from 0 to 1 (inclusive) to an RGBA tuple according to `cmap` times transparency `alpha_mult`.\"\n    c = cmap(x)\n    rgb = (np.array(c[:-1]) * 255).astype(int)\n    a = c[-1] * alpha_mult\n    return tuple(rgb.tolist() + [a])\n\ndef piece_attn_html(pieces:List[str], attns:List[float], sep:str=' ', **kwargs)->str:\n    html_code,spans = ['<span style=\"font-family: monospace;\">'], []\n    for p, a in zip(pieces, attns):\n        p = html.escape(p)\n        c = str(value2rgba(a, alpha_mult=0.5, **kwargs))\n        spans.append(f'<span title=\"{a:.3f}\" style=\"background-color: rgba{c};\">{p}</span>')\n    html_code.append(sep.join(spans))\n    html_code.append('</span>')\n    return ''.join(html_code)\n\ndef show_piece_attn(*args, **kwargs):\n    from IPython.display import display, HTML\n    display(HTML(piece_attn_html(*args, **kwargs)))\n\ndef _eval_dropouts(mod):\n        module_name =  mod.__class__.__name__\n        if 'Dropout' in module_name or 'BatchNorm' in module_name: mod.training 
= False\n        for module in mod.children(): _eval_dropouts(module)\n            \nclass TextClassificationInterpretation(ClassificationInterpretation):\n    \"\"\"Provides an interpretation of classification based on input sensitivity.\n    This was designed for AWD-LSTM only for the moment, because Transformer already has its own attentional model.\n    \"\"\"\n\n    def __init__(self, learn: Learner, preds: Tensor, y_true: Tensor, losses: Tensor, ds_type: DatasetType = DatasetType.Valid):\n        super().__init__(learn,preds,y_true,losses,ds_type)\n        self.model = learn.model\n\n    def intrinsic_attention(self, text:str, class_id:int=None):\n        \"\"\"Calculate the intrinsic attention of the input w.r.t to an output `class_id`, or the classification given by the model if `None`.\n        For reference, see the Sequential Jacobian session at https://www.cs.toronto.edu/~graves/preprint.pdf\n        \"\"\"\n        self.model.train()\n        _eval_dropouts(self.model)\n        self.model.zero_grad()\n        self.model.reset()\n        ids = self.data.one_item(text)[0]\n        emb = self.model[0].module.encoder(ids).detach().requires_grad_(True)\n        lstm_output = self.model[0].module(emb, from_embeddings=True)\n        self.model.eval()\n        cl = self.model[1](lstm_output + (torch.zeros_like(ids).byte(),))[0].softmax(dim=-1)\n        if class_id is None: class_id = cl.argmax()\n        cl[0][class_id].backward()\n        attn = emb.grad.squeeze().abs().sum(dim=-1)\n        attn /= attn.max()\n        tokens = self.data.single_ds.reconstruct(ids[0])\n        return tokens, attn\n\n    def html_intrinsic_attention(self, text:str, class_id:int=None, **kwargs)->str:\n        text, attn = self.intrinsic_attention(text, class_id)\n        return piece_attn_html(text.text.split(), to_np(attn), **kwargs)\n\n    def show_intrinsic_attention(self, text:str, class_id:int=None, **kwargs)->None:\n        text, attn = self.intrinsic_attention(text, 
class_id)\n        show_piece_attn(text.text.split(), to_np(attn), **kwargs)\n\n    def show_top_losses(self, k:int, max_len:int=70)->None:\n        \"\"\"\n        Create a tabulation showing the first `k` texts in top_losses along with their prediction, actual,loss, and probability of\n        actual class. `max_len` is the maximum number of tokens displayed.\n        \"\"\"\n        from IPython.display import display, HTML\n        items = []\n        tl_val,tl_idx = self.top_losses()\n        for i,idx in enumerate(tl_idx):\n            if k <= 0: break\n            k -= 1\n            tx,cl = self.data.dl(self.ds_type).dataset[idx]\n            cl = cl.data\n            classes = self.data.classes\n            txt = ' '.join(tx.text.split(' ')[:max_len]) if max_len is not None else tx.text\n            tmp = [txt, f'{classes[self.pred_class[idx]]}', f'{classes[cl]}', f'{self.losses[idx]:.2f}',\n                   f'{self.preds[idx][cl]:.2f}']\n            items.append(tmp)\n        items = np.array(items)\n        names = ['Text', 'Prediction', 'Actual', 'Loss', 'Probability']\n        df = pd.DataFrame({n:items[:,i] for i,n in enumerate(names)}, columns=names)\n        with pd.option_context('display.max_colwidth', -1):\n            display(HTML(df.to_html(index=False)))\n"
  },
  {
    "path": "fastai/text/models/bwd_forget_mult_cuda.cpp",
    "content": "#include <torch/torch.h>\n\n#include <vector>\n\n// CUDA forward declarations\nat::Tensor bwd_forget_mult_cuda_forward(at::Tensor x, at::Tensor f, at::Tensor output, bool batch_first);\n\n// C++ interface\n\n#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x \" must be contiguous\")\n#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)\n\nat::Tensor bwd_forget_mult_forward(at::Tensor x, at::Tensor f, at::Tensor output, bool batch_first) {\n  CHECK_INPUT(x); CHECK_INPUT(f); CHECK_INPUT(output);\n  return bwd_forget_mult_cuda_forward(x, f, output, batch_first);\n}\n\nstd::vector<at::Tensor> bwd_forget_mult_cuda_backward(at::Tensor x, at::Tensor f, at::Tensor output,\n                at::Tensor grad_output, bool batch_first);\n\nstd::vector<at::Tensor> bwd_forget_mult_backward(at::Tensor x, at::Tensor f, at::Tensor output,\n                at::Tensor grad_output, bool batch_first) {\n  CHECK_INPUT(x); CHECK_INPUT(f); CHECK_INPUT(output);\n  return bwd_forget_mult_cuda_backward(x, f, output, grad_output, batch_first);\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"forward\", &bwd_forget_mult_forward, \"BwdForgetMult forward (CUDA)\");\n  m.def(\"backward\", &bwd_forget_mult_backward, \"BwdForgetMult backward (CUDA)\");\n}\n"
  },
  {
    "path": "fastai/text/models/bwd_forget_mult_cuda_kernel.cu",
    "content": "#include <ATen/ATen.h>\n#include <THC/THC.h>\n\n#include <cuda.h>\n#include <cuda_runtime.h>\n\n#include <vector>\n\ntemplate <typename scalar_t>\n__global__ void bwd_forget_mult_cuda_forward_kernel(const scalar_t* __restrict__ x,\n                const scalar_t* __restrict__ f, scalar_t* __restrict__ output,\n                size_t batch_size, size_t seq_length, size_t n_hidden, bool batch_first) {\n  /*\n  Note: output is assumed to be one timestep longer than f or x where output[seq_length] = h_{+1}\n  This means output array has a size of seq_length+1 on the word dimension\n  */\n  const int hid = blockIdx.x * blockDim.x + threadIdx.x;\n  const int bid = blockIdx.y * blockDim.y + threadIdx.y;\n  if (hid < n_hidden && bid < batch_size){\n    for (int ts = seq_length-1; ts >= 0; ts--) {\n      int i          = 0;\n      int dst_i      = 0;\n      int dst_iplus1 = 0;\n      if (batch_first){\n        i          = bid * n_hidden * seq_length     + (ts+0) * n_hidden + hid;\n        dst_i      = bid * n_hidden * (seq_length+1) + (ts+0) * n_hidden + hid;\n        dst_iplus1 = bid * n_hidden * (seq_length+1) + (ts+1) * n_hidden + hid;\n      }\n      else {\n        i          = (ts+0) * n_hidden * batch_size  + bid * n_hidden + hid;\n        dst_i      = (ts+0) * n_hidden * batch_size  + bid * n_hidden + hid;\n        dst_iplus1 = (ts+1) * n_hidden * batch_size  + bid * n_hidden + hid;\n      }\n      output[dst_i]   = f[i] * x[i];\n      output[dst_i]  += (1 - f[i]) * output[dst_iplus1];\n    }\n  }\n}\n\ntemplate <typename scalar_t>\n__global__ void bwd_forget_mult_cuda_backward_kernel(const scalar_t* __restrict__ x,\n                const scalar_t* __restrict__ f, const scalar_t* __restrict__ output,\n                const scalar_t* __restrict__ grad_output, scalar_t* __restrict__ grad_x,\n                scalar_t* __restrict__ grad_f, scalar_t* __restrict__ grad_h,\n                size_t batch_size, size_t seq_length, size_t n_hidden, bool 
batch_first) {\n  const int hid = blockIdx.x * blockDim.x + threadIdx.x;\n  const int bid = blockIdx.y * blockDim.y + threadIdx.y;\n  double running_f = 0;\n  if(hid < n_hidden && bid < batch_size){\n    for (int ts = 0; ts < seq_length; ts++) {\n      int i          = 0;\n      int dst_iplus1 = 0;\n      if (batch_first){\n        i          = bid * n_hidden * seq_length     + (ts+0) * n_hidden + hid;\n        dst_iplus1 = bid * n_hidden * (seq_length+1) + (ts+1) * n_hidden + hid;\n      }\n      else {\n        i          = (ts+0) * n_hidden * batch_size  + bid * n_hidden + hid;\n        dst_iplus1 = (ts+1) * n_hidden * batch_size  + bid * n_hidden + hid;\n      }\n      running_f       += grad_output[i];\n      grad_x[i]       = f[i] * running_f;\n      grad_f[i]       = (x[i] - output[dst_iplus1]) * running_f;\n      // The line below is likely more numerically stable than (1 - f[i]) * running_f;\n      running_f       = running_f - f[i] * running_f;\n    }\n    grad_h[bid * n_hidden + hid] = running_f;\n  }\n}\n\nat::Tensor bwd_forget_mult_cuda_forward(at::Tensor x, at::Tensor f, at::Tensor output, bool batch_first) {\n  const auto batch_size = (batch_first) ? x.size(0) : x.size(1);\n  const auto seq_length = (batch_first) ? x.size(1) : x.size(0);\n  const auto n_hidden   = x.size(2);\n  \n  const int threads = 1024;\n  const dim3 blocks((n_hidden + threads - 1) / threads, batch_size);\n  AT_DISPATCH_FLOATING_TYPES(x.type(), \"bwd_forget_mult_cuda_forward\", ([&] {\n    bwd_forget_mult_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(\n        x.data<scalar_t>(), f.data<scalar_t>(), output.data<scalar_t>(), batch_size,\n        seq_length, n_hidden, batch_first);\n  }));\n\n  THCudaCheck(cudaGetLastError());\n  return output;\n}\n\nstd::vector<at::Tensor> bwd_forget_mult_cuda_backward(at::Tensor x, at::Tensor f,\n                at::Tensor output, at::Tensor grad_output, bool batch_first) {\n  const auto batch_size = (batch_first) ? 
x.size(0) : x.size(1);\n  const auto seq_length = (batch_first) ? x.size(1) : x.size(0);\n  const auto n_hidden   = x.size(2);\n\n  auto grad_x = at::zeros_like(x);\n  auto grad_f = at::zeros_like(x);\n  auto grad_h = at::zeros({batch_size, n_hidden}, x.options());\n  \n  const int threads = 1024;\n  const dim3 blocks((n_hidden + threads - 1) / threads, batch_size);\n  AT_DISPATCH_FLOATING_TYPES(x.type(), \"bwd_forget_mult_cuda_forward\", ([&] {\n    bwd_forget_mult_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(\n        x.data<scalar_t>(), f.data<scalar_t>(), output.data<scalar_t>(), grad_output.data<scalar_t>(),\n        grad_x.data<scalar_t>(), grad_f.data<scalar_t>(), grad_h.data<scalar_t>(), batch_size,\n        seq_length, n_hidden, batch_first);\n  }));\n\n  THCudaCheck(cudaGetLastError());\n  return {grad_x, grad_f, grad_h};\n}\n\n"
  },
  {
    "path": "fastai/text/models/forget_mult_cuda.cpp",
    "content": "#include <torch/torch.h>\n\n#include <vector>\n\n// CUDA forward declarations\nat::Tensor forget_mult_cuda_forward(at::Tensor x, at::Tensor f, at::Tensor output, bool batch_first);\n\n// C++ interface\n\n#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x \" must be contiguous\")\n#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)\n\nat::Tensor forget_mult_forward(at::Tensor x, at::Tensor f, at::Tensor output, bool batch_first) {\n  CHECK_INPUT(x); CHECK_INPUT(f); CHECK_INPUT(output);\n  return forget_mult_cuda_forward(x, f, output, batch_first);\n}\n\nstd::vector<at::Tensor> forget_mult_cuda_backward(at::Tensor x, at::Tensor f, at::Tensor output,\n                at::Tensor grad_output, bool batch_first);\n\nstd::vector<at::Tensor> forget_mult_backward(at::Tensor x, at::Tensor f, at::Tensor output,\n                at::Tensor grad_output, bool batch_first) {\n  CHECK_INPUT(x); CHECK_INPUT(f); CHECK_INPUT(output);\n  return forget_mult_cuda_backward(x, f, output, grad_output, batch_first);\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"forward\", &forget_mult_forward, \"ForgetMult forward (CUDA)\");\n  m.def(\"backward\", &forget_mult_backward, \"ForgetMult backward (CUDA)\");\n}\n"
  },
  {
    "path": "fastai/text/models/forget_mult_cuda_kernel.cu",
    "content": "#include <ATen/ATen.h>\n#include <THC/THC.h>\n\n#include <cuda.h>\n#include <cuda_runtime.h>\n\n#include <vector>\n\ntemplate <typename scalar_t>\n__global__ void forget_mult_cuda_forward_kernel(const scalar_t* __restrict__ x,\n                const scalar_t* __restrict__ f, scalar_t* __restrict__ output,\n                size_t batch_size, size_t seq_length, size_t n_hidden, bool batch_first) {\n  /*\n  Note: output is assumed to be one timestep longer than f or x where output[0] = h_{-1}\n  This means output array has a size of seq_length+1 on the word dimension\n  */\n  const int hid = blockIdx.x * blockDim.x + threadIdx.x;\n  const int bid = blockIdx.y * blockDim.y + threadIdx.y;\n  if (hid < n_hidden && bid < batch_size){\n    for (int ts = 1; ts < seq_length + 1; ts++) {\n      int i           = 0;\n      int dst_i       = 0;\n      int dst_iminus1 = 0;\n      if (batch_first){\n        i           = bid * n_hidden * seq_length     + (ts-1) * n_hidden + hid;\n        dst_i       = bid * n_hidden * (seq_length+1) + (ts-0) * n_hidden + hid;\n        dst_iminus1 = bid * n_hidden * (seq_length+1) + (ts-1) * n_hidden + hid;\n      }\n      else {\n        i           = (ts-1) * n_hidden * batch_size  + bid * n_hidden + hid;\n        dst_i       = (ts-0) * n_hidden * batch_size  + bid * n_hidden + hid;\n        dst_iminus1 = (ts-1) * n_hidden * batch_size  + bid * n_hidden + hid;\n      }\n      output[dst_i]   = f[i] * x[i];\n      output[dst_i]  += (1 - f[i]) * output[dst_iminus1];\n    }\n  }\n}\n\ntemplate <typename scalar_t>\n__global__ void forget_mult_cuda_backward_kernel(const scalar_t* __restrict__ x,\n                const scalar_t* __restrict__ f, const scalar_t* __restrict__ output,\n                const scalar_t* __restrict__ grad_output, scalar_t* __restrict__ grad_x,\n                scalar_t* __restrict__ grad_f, scalar_t* __restrict__ grad_h,\n                size_t batch_size, size_t seq_length, size_t n_hidden, bool batch_first) 
{\n  const int hid = blockIdx.x * blockDim.x + threadIdx.x;\n  const int bid = blockIdx.y * blockDim.y + threadIdx.y;\n  double running_f = 0;\n  if(hid < n_hidden && bid < batch_size){\n    for (int ts = seq_length; ts >= 0 + 1; ts--) {\n      int i           = 0;\n      int dst_i       = 0;\n      int dst_iminus1 = 0;\n      if (batch_first){\n        i           = bid * n_hidden * seq_length     + (ts-1) * n_hidden + hid;\n        dst_i       = bid * n_hidden * (seq_length+1) + (ts-0) * n_hidden + hid;\n        dst_iminus1 = bid * n_hidden * (seq_length+1) + (ts-1) * n_hidden + hid;\n      }\n      else {\n        i           = (ts-1) * n_hidden * batch_size  + bid * n_hidden + hid;\n        dst_i       = (ts-0) * n_hidden * batch_size  + bid * n_hidden + hid;\n        dst_iminus1 = (ts-1) * n_hidden * batch_size  + bid * n_hidden + hid;\n      }\n      running_f       += grad_output[i];\n      grad_x[i]       = f[i] * running_f;\n      grad_f[i]       = (x[i] - output[dst_iminus1]) * running_f;\n      // The line below is likely more numerically stable than (1 - f[i]) * running_f;\n      running_f       = running_f - f[i] * running_f;\n    }\n    grad_h[bid * n_hidden + hid] = running_f;\n  }\n}\n\nat::Tensor forget_mult_cuda_forward(at::Tensor x, at::Tensor f, at::Tensor output, bool batch_first) {\n  const auto batch_size = (batch_first) ? x.size(0) : x.size(1);\n  const auto seq_length = (batch_first) ? 
x.size(1) : x.size(0);\n  const auto n_hidden   = x.size(2);\n  \n  const int threads = 1024;\n  const dim3 blocks((n_hidden + threads - 1) / threads, batch_size);\n  AT_DISPATCH_FLOATING_TYPES(x.type(), \"forget_mult_cuda_forward\", ([&] {\n    forget_mult_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(\n        x.data<scalar_t>(), f.data<scalar_t>(), output.data<scalar_t>(), batch_size,\n        seq_length, n_hidden, batch_first);\n  }));\n\n  THCudaCheck(cudaGetLastError());\n  return output;\n}\n\nstd::vector<at::Tensor> forget_mult_cuda_backward(at::Tensor x, at::Tensor f,\n                at::Tensor output, at::Tensor grad_output, bool batch_first) {\n  const auto batch_size = (batch_first) ? x.size(0) : x.size(1);\n  const auto seq_length = (batch_first) ? x.size(1) : x.size(0);\n  const auto n_hidden   = x.size(2);\n\n  auto grad_x = at::zeros_like(x);\n  auto grad_f = at::zeros_like(x);\n  auto grad_h = at::zeros({batch_size, n_hidden}, x.options());\n  \n  const int threads = 1024;\n  const dim3 blocks((n_hidden + threads - 1) / threads, batch_size);\n  AT_DISPATCH_FLOATING_TYPES(x.type(), \"forget_mult_cuda_forward\", ([&] {\n    forget_mult_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(\n        x.data<scalar_t>(), f.data<scalar_t>(), output.data<scalar_t>(), grad_output.data<scalar_t>(),\n        grad_x.data<scalar_t>(), grad_f.data<scalar_t>(), grad_h.data<scalar_t>(), batch_size,\n        seq_length, n_hidden, batch_first);\n  }));\n\n  THCudaCheck(cudaGetLastError());\n  return {grad_x, grad_f, grad_h};\n}\n\n"
  },
  {
    "path": "fastai/text/models/qrnn.py",
    "content": "from ...torch_core import *\nfrom torch.utils.cpp_extension import load\nfrom torch.autograd import Function\n\n__all__ = ['QRNNLayer', 'QRNN']\n\nimport fastai\nif torch.cuda.is_available():\n    fastai_path = Path(fastai.__path__[0])/'text'/'models'\n    files = ['forget_mult_cuda.cpp', 'forget_mult_cuda_kernel.cu']\n    forget_mult_cuda = load(name='forget_mult_cuda', sources=[fastai_path/f for f in files])\n    files = ['bwd_forget_mult_cuda.cpp', 'bwd_forget_mult_cuda_kernel.cu']\n    bwd_forget_mult_cuda = load(name='bwd_forget_mult_cuda', sources=[fastai_path/f for f in files])\n\ndef dispatch_cuda(cuda_class, cpu_func, x):\n    return cuda_class.apply if x.device.type == 'cuda' else cpu_func\n    \nclass ForgetMultGPU(Function):\n    \n    @staticmethod\n    def forward(ctx, x:Tensor, f:Tensor, hidden_init:Optional[Tensor]=None, batch_first:bool=True):\n        if batch_first:\n            batch_size, seq_size, hidden_size = f.size()\n            output = f.new_zeros(batch_size, seq_size + 1, hidden_size)\n            if hidden_init is not None: output[:, 0] = hidden_init\n            else: output.zero_()\n        else: \n            seq_size, batch_size, hidden_size = f.size()\n            output = f.new(seq_size + 1, batch_size, hidden_size)\n            if hidden_init is not None: output[0] = hidden_init\n            else: output.zero_()\n        output = forget_mult_cuda.forward(x, f, output, batch_first)\n        ctx.save_for_backward(x, f, hidden_init, output)\n        ctx.batch_first = batch_first\n        return output[:,1:] if batch_first else output[1:]\n    \n    @staticmethod\n    def backward(ctx, grad_output):\n        x, f, hidden_init, output = ctx.saved_tensors\n        grad_x, grad_f, grad_h = forget_mult_cuda.backward(x, f, output, grad_output, ctx.batch_first)\n        return (grad_x, grad_f, (None if hidden_init is None else grad_h), None)\n    \nclass BwdForgetMultGPU(Function):\n    \n    @staticmethod\n    def 
forward(ctx, x:Tensor, f:Tensor, hidden_init:Optional[Tensor]=None, batch_first:bool=True):\n        if batch_first:\n            batch_size, seq_size, hidden_size = f.size()\n            output = f.new(batch_size, seq_size + 1, hidden_size)\n            if hidden_init is not None: output[:, -1] = hidden_init\n            else: output.zero_()\n        else: \n            seq_size, batch_size, hidden_size = f.size()\n            output = f.new(seq_size + 1, batch_size, hidden_size)\n            if hidden_init is not None: output[-1] = hidden_init\n            else: output.zero_()\n        output = bwd_forget_mult_cuda.forward(x, f, output, batch_first)\n        ctx.save_for_backward(x, f, hidden_init, output)\n        ctx.batch_first = batch_first\n        return output[:,:-1] if batch_first else output[:-1]\n    \n    @staticmethod\n    def backward(ctx, grad_output:Tensor):\n        x, f, hidden_init, output = ctx.saved_tensors\n        grad_x, grad_f, grad_h = bwd_forget_mult_cuda.backward(x, f, output, grad_output, ctx.batch_first)\n        return (grad_x, grad_f, (None if hidden_init is None else grad_h), None)\n    \ndef forget_mult_CPU(x:Tensor, f:Tensor, hidden_init:Optional[Tensor]=None, batch_first:bool=True, backward:bool=False):\n    result = []\n    dim = (1 if batch_first else 0)\n    forgets = f.split(1, dim=dim)\n    inputs =  x.split(1, dim=dim)\n    prev_h = None if hidden_init is None else hidden_init.unsqueeze(1 if batch_first else 0)\n    idx_range = range(len(inputs)-1,-1,-1) if backward else range(len(inputs))\n    for i in idx_range:\n        prev_h = inputs[i] * forgets[i] if prev_h is None else inputs[i] * forgets[i] + (1-forgets[i]) * prev_h\n        if backward: result.insert(0, prev_h)\n        else:        result.append(prev_h)\n    return torch.cat(result, dim=dim)\n\nclass QRNNLayer(Module):\n    \"Apply a single layer Quasi-Recurrent Neural Network (QRNN) to an input sequence.\"\n\n    def __init__(self, input_size:int, 
hidden_size:int=None, save_prev_x:bool=False, zoneout:float=0, window:int=1, \n                 output_gate:bool=True, batch_first:bool=True, backward:bool=False):\n        super().__init__()\n        assert window in [1, 2], \"This QRNN implementation currently only handles convolutional window of size 1 or size 2\"\n        self.save_prev_x,self.zoneout,self.window = save_prev_x,zoneout,window\n        self.output_gate,self.batch_first,self.backward = output_gate,batch_first,backward\n        hidden_size = ifnone(hidden_size, input_size)\n        #One large matmul with concat is faster than N small matmuls and no concat\n        mult = (3 if output_gate else 2)\n        self.linear = nn.Linear(window * input_size, mult * hidden_size)\n        self.prevX = None\n\n    def reset(self):\n        # If you are saving the previous value of x, you should call this when starting with a new state\n        self.prevX = None\n        \n    def forward(self, inp, hid=None):\n        y = self.linear(self._get_source(inp))\n        if self.output_gate: z_gate,f_gate,o_gate = y.chunk(3, dim=2)\n        else:                z_gate,f_gate        = y.chunk(2, dim=2)\n        z_gate.tanh_()\n        f_gate.sigmoid_()\n        if self.zoneout and self.training:\n            mask = dropout_mask(f_gate, f_gate.size(), self.zoneout).requires_grad_(False)\n            f_gate = f_gate * mask\n        z_gate,f_gate = z_gate.contiguous(),f_gate.contiguous()\n        if self.backward: forget_mult = dispatch_cuda(BwdForgetMultGPU, partial(forget_mult_CPU, backward=True), inp)\n        else:             forget_mult = dispatch_cuda(ForgetMultGPU, forget_mult_CPU, inp)\n        c_gate = forget_mult(z_gate, f_gate, hid, self.batch_first)\n        output = torch.sigmoid(o_gate) * c_gate if self.output_gate else c_gate\n        if self.window > 1 and self.save_prev_x: \n            if self.backward: self.prevX = (inp[:, :1] if self.batch_first else inp[:1]).detach()\n            else:             
self.prevX = (inp[:, -1:] if self.batch_first else inp[-1:]).detach()\n        idx = 0 if self.backward else -1\n        return output, (c_gate[:, idx] if self.batch_first else c_gate[idx])\n\n    def _get_source(self, inp):\n        if self.window == 1: return inp\n        dim = (1 if self.batch_first else 0)\n        inp_shift = [torch.zeros_like(inp[:,:1] if self.batch_first else inp[:1]) if self.prevX is None else self.prevX]\n        if self.backward: inp_shift.insert(0,inp[:,1:] if self.batch_first else inp[1:])\n        else:             inp_shift.append(inp[:,:-1] if self.batch_first else inp[:-1])\n        inp_shift = torch.cat(inp_shift, dim)\n        return torch.cat([inp, inp_shift], 2)\n    \nclass QRNN(Module):\n    \"Apply a multiple layer Quasi-Recurrent Neural Network (QRNN) to an input sequence.\"\n\n    def __init__(self, input_size:int, hidden_size:int, n_layers:int=1, bias:bool=True, batch_first:bool=True,\n                 dropout:float=0, bidirectional:bool=False, save_prev_x:bool=False, zoneout:float=0, window:int=None, \n                 output_gate:bool=True):\n        assert not (save_prev_x and bidirectional), \"Can't save the previous X with bidirectional.\"\n        assert bias == True, 'Removing underlying bias is not yet supported'\n        super().__init__()\n        kwargs = dict(batch_first=batch_first, zoneout=zoneout, output_gate=output_gate)\n        self.layers = nn.ModuleList([QRNNLayer(input_size if l == 0 else hidden_size, hidden_size, save_prev_x=save_prev_x, \n                                               window=((2 if l ==0 else 1) if window is None else window), **kwargs) \n                                     for l in range(n_layers)])\n        if bidirectional:\n            self.layers_bwd = nn.ModuleList([QRNNLayer(input_size if l == 0 else hidden_size, hidden_size, \n                                                       backward=True, window=((2 if l ==0 else 1) if window is None else window), \n                   
                                    **kwargs) for l in range(n_layers)])\n        self.n_layers,self.batch_first,self.dropout,self.bidirectional = n_layers,batch_first,dropout,bidirectional\n        \n    def reset(self):\n        \"If your convolutional window is greater than 1 and you save previous xs, you must reset at the beginning of each new sequence.\"\n        for layer in self.layers:     layer.reset()\n        if self.bidirectional:\n            for layer in self.layers_bwd: layer.reset()    \n\n    def forward(self, inp, hid=None):\n        new_hid = []\n        if self.bidirectional: inp_bwd = inp.clone()\n        for i, layer in enumerate(self.layers):\n            inp, h = layer(inp, None if hid is None else hid[2*i if self.bidirectional else i])\n            new_hid.append(h)\n            if self.bidirectional:\n                inp_bwd, h_bwd = self.layers_bwd[i](inp_bwd, None if hid is None else hid[2*i+1])\n                new_hid.append(h_bwd)\n            if self.dropout != 0 and i < len(self.layers) - 1:\n                for o in ([inp, inp_bwd] if self.bidirectional else [inp]):\n                    o = F.dropout(o, p=self.dropout, training=self.training, inplace=False)\n        if self.bidirectional: inp = torch.cat([inp, inp_bwd], dim=2)\n        return inp, torch.stack(new_hid, 0)"
  },
  {
    "path": "fastai/text/models/transformer.py",
    "content": "from ...torch_core import *\nfrom ...layers import *\nfrom .awd_lstm import RNNDropout, LinearDecoder, SequentialRNN\n\n__all__ = ['Activation', 'PositionalEncoding', 'GeLU', 'Swish', 'feed_forward', 'MultiHeadAttention', 'MultiHeadRelativeAttention',\n           'DecoderLayer', 'Transformer', 'TransformerXL', 'tfmer_lm_config', 'tfmer_clas_config', 'tfmer_lm_split', 'tfmer_clas_split',\n           'tfmerXL_lm_config', 'tfmerXL_clas_config', 'tfmerXL_lm_split', 'tfmerXL_clas_split']\n\nActivation = Enum('Activation', 'ReLU Swish GeLU')\n\nclass PositionalEncoding(Module):\n    \"Encode the position with a sinusoid.\"\n    def __init__(self, d:int): self.register_buffer('freq', 1 / (10000 ** (torch.arange(0., d, 2.)/d)))\n\n    def forward(self, pos:Tensor):\n        inp = torch.ger(pos, self.freq)\n        enc = torch.cat([inp.sin(), inp.cos()], dim=-1)\n        return enc\n\nclass GeLU(Module):\n    def forward(self, x): return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))\n\nclass Swish(Module):\n    def forward(self, x): return x * torch.sigmoid(x)\n\n_activ_func = {Activation.ReLU:nn.ReLU(inplace=True), Activation.GeLU:GeLU(), Activation.Swish: Swish()}\n\ndef feed_forward(d_model:int, d_ff:int, ff_p:float=0., act:Activation=Activation.ReLU, double_drop:bool=True):\n    layers = [nn.Linear(d_model, d_ff), _activ_func[act]]\n    if double_drop: layers.append(nn.Dropout(ff_p))\n    return SequentialEx(*layers, nn.Linear(d_ff, d_model), nn.Dropout(ff_p), MergeLayer(), nn.LayerNorm(d_model))\n\nclass MultiHeadAttention(Module):\n    \"MutiHeadAttention.\"\n    def __init__(self, n_heads:int, d_model:int, d_head:int=None, resid_p:float=0., attn_p:float=0., bias:bool=True,\n                 scale:bool=True):\n        d_head = ifnone(d_head, d_model//n_heads)\n        self.n_heads,self.d_head,self.scale = n_heads,d_head,scale\n        self.attention = nn.Linear(d_model, 3 * n_heads * d_head, bias=bias)\n        
self.out = nn.Linear(n_heads * d_head, d_model, bias=bias)\n        self.drop_att,self.drop_res = nn.Dropout(attn_p),nn.Dropout(resid_p)\n        self.ln = nn.LayerNorm(d_model)\n\n    def forward(self, x:Tensor, mask:Tensor=None, **kwargs):\n        return self.ln(x + self.drop_res(self.out(self._apply_attention(x, mask=mask, **kwargs))))\n\n    def _apply_attention(self, x:Tensor, mask:Tensor=None):\n        bs,x_len = x.size(0),x.size(1)\n        wq,wk,wv = torch.chunk(self.attention(x), 3, dim=-1)\n        wq,wk,wv = map(lambda x:x.view(bs, x.size(1), self.n_heads, self.d_head), (wq,wk,wv))\n        wq,wk,wv = wq.permute(0, 2, 1, 3),wk.permute(0, 2, 3, 1),wv.permute(0, 2, 1, 3)\n        attn_score = torch.matmul(wq, wk)\n        if self.scale: attn_score.div_(self.d_head ** 0.5)\n        if mask is not None:\n            attn_score = attn_score.float().masked_fill(mask, -float('inf')).type_as(attn_score)\n        attn_prob = self.drop_att(F.softmax(attn_score, dim=-1))\n        attn_vec = torch.matmul(attn_prob, wv)\n        return attn_vec.permute(0, 2, 1, 3).contiguous().contiguous().view(bs, x_len, -1)\n\n    def _attention_einsum(self, x, mask=None):\n        # Permute and matmul is a little bit faster but this implementation is more readable\n        bs,x_len = x.size(0),x.size(1)\n        wq,wk,wv = torch.chunk(self.attention(x), 3, dim=-1)\n        wq,wk,wv = map(lambda x:x.view(bs, x.size(1), self.n_heads, self.d_head), (wq,wk,wv))\n        attn_score = torch.einsum('bind,bjnd->bijn', (wq, wk))\n        if self.scale: attn_score.mul_(1/(self.d_head ** 0.5))\n        if mask is not None:\n            attn_score = attn_score.float().masked_fill(mask, -float('inf')).type_as(attn_score)\n        attn_prob = self.drop_att(F.softmax(attn_score, dim=2))\n        attn_vec = torch.einsum('bijn,bjnd->bind', (attn_prob, wv))\n        return attn_vec.contiguous().view(bs, x_len, -1)\n\n#def _line_shift1(x:Tensor, mask:bool=False):\n#    \"Shift the line i of `x` by 
p-i elements to the left, is `mask` puts 0s on the diagonal.\"\n#    bs,n,p,nh = x.size()\n#    x_pad = torch.cat([x.new_zeros(bs,n,1,nh), x], dim=2)\n#    x_shift = x_pad.view(bs,p + 1,n,nh)[:,1:].view_as(x)\n#    if mask: x_shift.mul_(torch.tril(x.new_ones(n,p), p-n)[None,:,:,None])\n#    return x_shift\n\ndef _line_shift(x:Tensor, mask:bool=False):\n    \"Shift the line i of `x` by p-i elements to the left, is `mask` puts 0s on the diagonal.\"\n    bs,nh,n,p = x.size()\n    x_pad = torch.cat([x.new_zeros(bs,nh,n,1), x], dim=3)\n    x_shift = x_pad.view(bs,nh,p + 1,n)[:,:,1:].view_as(x)\n    if mask: x_shift.mul_(torch.tril(x.new_ones(n,p), p-n)[None,None,])\n    return x_shift\n\nclass MultiHeadRelativeAttention(MultiHeadAttention):\n    \"MutiHeadAttention with relative positional encoding.\"\n\n    def __init__(self, n_heads:int, d_model:int, d_head:int, resid_p:float=0., attn_p:float=0., bias:bool=True,\n                 scale:bool=True):\n        super().__init__(n_heads, d_model, d_head, resid_p=resid_p, attn_p=attn_p, bias=bias, scale=scale)\n        self.r_attn = nn.Linear(d_model, n_heads * d_head, bias=bias)\n\n    def _apply_attention(self, x:Tensor, r:Tensor=None, u:Tensor=None, v:Tensor=None, mask:Tensor=None, mem:Tensor=None):\n        #Notations from the paper: x input, r vector of relative distance between two elements, u et v learnable\n        #parameters of the model common between all layers, mask to avoid cheating and mem the previous hidden states.\n        bs,x_len,seq_len = x.size(0),x.size(1),r.size(0)\n        context = x if mem is None else torch.cat([mem, x], dim=1)\n        wq,wk,wv = torch.chunk(self.attention(context), 3, dim=-1)\n        wq = wq[:,-x_len:]\n        wq,wk,wv = map(lambda x:x.view(bs, x.size(1), self.n_heads, self.d_head), (wq,wk,wv))\n        wq,wk,wv = wq.permute(0, 2, 1, 3),wk.permute(0, 2, 3, 1),wv.permute(0, 2, 1, 3)\n        wkr = self.r_attn(r)\n        wkr = wkr.view(seq_len, self.n_heads, self.d_head)\n      
  wkr = wkr.permute(1,2,0)\n        #### compute attention score (AC is (a) + (c) and BS is (b) + (d) in the paper)\n        AC = torch.matmul(wq+u,wk)\n        BD = _line_shift(torch.matmul(wq+v, wkr))\n        if self.scale: attn_score = (AC + BD).mul_(1/(self.d_head ** 0.5))\n        if mask is not None:\n            attn_score = attn_score.float().masked_fill(mask, -float('inf')).type_as(attn_score)\n        attn_prob = self.drop_att(F.softmax(attn_score, dim=-1))\n        attn_vec = torch.matmul(attn_prob, wv)\n        return attn_vec.permute(0, 2, 1, 3).contiguous().view(bs, x_len, -1)\n\n    def _attention_einsum(self, x:Tensor, r:Tensor=None, u:Tensor=None, v:Tensor=None, mask:Tensor=None, mem:Tensor=None):\n        # Permute and matmul is a little bit faster but this implementation is more readable\n        bs,x_len,seq_len = x.size(0),x.size(1),r.size(0)\n        context = x if mem is None else torch.cat([mem, x], dim=1)\n        wq,wk,wv = torch.chunk(self.attention(context), 3, dim=-1)\n        wq = wq[:,-x_len:]\n        wkr = self.r_attn(r)\n        wq,wk,wv = map(lambda x:x.view(bs, x.size(1), self.n_heads, self.d_head), (wq,wk,wv))\n        wkr = wkr.view(seq_len, self.n_heads, self.d_head)\n        #### compute attention score (AC is (a) + (c) and BS is (b) + (d) in the paper)\n        AC = torch.einsum('bind,bjnd->bijn', (wq+u, wk))\n        BD = _line_shift1(torch.einsum('bind,jnd->bijn', (wq+v, wkr)))\n        attn_score = (AC + BD).mul_(1/(self.d_head ** 0.5))\n        if mask is not None:\n            attn_score = attn_score.float().masked_fill(mask, -float('inf')).type_as(attn_score)\n        attn_prob = self.drop_att(F.softmax(attn_score, dim=2))\n        attn_vec = torch.einsum('bijn,bjnd->bind', (attn_prob, wv))\n        return attn_vec.contiguous().view(bs, x_len, -1)\n\nclass DecoderLayer(Module):\n    \"Basic block of a Transformer model.\"\n    #Can't use Sequential directly cause more than one input...\n    def __init__(self, 
n_heads:int, d_model:int, d_head:int, d_inner:int, resid_p:float=0., attn_p:float=0., ff_p:float=0.,\n                 bias:bool=True, scale:bool=True, act:Activation=Activation.ReLU, double_drop:bool=True,\n                 attn_cls:Callable=MultiHeadAttention):\n        self.mhra = attn_cls(n_heads, d_model, d_head, resid_p=resid_p, attn_p=attn_p, bias=bias, scale=scale)\n        self.ff   = feed_forward(d_model, d_inner, ff_p=ff_p, act=act, double_drop=double_drop)\n\n    def forward(self, x:Tensor, mask:Tensor=None, **kwargs): return self.ff(self.mhra(x, mask=mask, **kwargs))\n\nclass Transformer(Module):\n    \"Transformer model: https://arxiv.org/abs/1706.03762.\"\n    def __init__(self, vocab_sz:int, ctx_len:int, n_layers:int, n_heads:int, d_model:int, d_head:int, d_inner:int,\n                 resid_p:float=0., attn_p:float=0., ff_p:float=0., embed_p:float=0., bias:bool=True, scale:bool=True,\n                 act:Activation=Activation.ReLU, double_drop:bool=True, attn_cls:Callable=MultiHeadAttention,\n                 learned_pos_enc:bool=True, mask:bool=True):\n        self.mask = mask\n        self.encoder = nn.Embedding(vocab_sz, d_model)\n        self.pos_enc = nn.Embedding(ctx_len, d_model) if learned_pos_enc else PositionalEncoding(d_model)\n        self.drop_emb = nn.Dropout(embed_p)\n        self.layers = nn.ModuleList([DecoderLayer(n_heads, d_model, d_head, d_inner, resid_p=resid_p, attn_p=attn_p,\n                      ff_p=ff_p, bias=bias, scale=scale, act=act, double_drop=double_drop,\n                      attn_cls=attn_cls) for k in range(n_layers)])\n\n    def reset(self): pass\n\n    def forward(self, x):\n        bs, x_len = x.size()\n        pos = torch.arange(0, x_len, device=x.device, dtype=x.dtype)\n        inp = self.drop_emb(self.encoder(x) + self.pos_enc(pos)[None]) #.mul_(self.d_model ** 0.5)\n        mask = torch.triu(x.new_ones(x_len, x_len), diagonal=1).byte()[None,None] if self.mask else None\n        #[None,:,:None] for einsum 
implementation of attention\n        for layer in self.layers: inp = layer(inp, mask=mask)\n        return ([inp],[inp]) #For the LinearDecoder\n\nclass TransformerXL(Module):\n    \"TransformerXL model: https://arxiv.org/abs/1901.02860.\"\n    def __init__(self, vocab_sz:int, ctx_len:int, n_layers:int, n_heads:int, d_model:int, d_head:int, d_inner:int,\n                 resid_p:float=0., attn_p:float=0., ff_p:float=0., embed_p:float=0., bias:bool=False, scale:bool=True,\n                 act:Activation=Activation.ReLU, double_drop:bool=True, attn_cls:Callable=MultiHeadRelativeAttention,\n                 learned_pos_enc:bool=False, mask:bool=True, mem_len:int=0):\n        self.encoder = nn.Embedding(vocab_sz, d_model)\n        self.pos_enc = nn.Embedding(ctx_len, d_model) if learned_pos_enc else PositionalEncoding(d_model)\n        self.drop_emb = nn.Dropout(embed_p)\n        self.u = nn.Parameter(torch.Tensor(n_heads, 1, d_head)) #Remove 1 for einsum implementation of attention\n        self.v = nn.Parameter(torch.Tensor(n_heads, 1, d_head)) #Remove 1 for einsum implementation of attention\n        self.mem_len,self.n_layers,self.d_model,self.mask = mem_len,n_layers,d_model,mask\n        self.init = False\n        self.layers = nn.ModuleList([DecoderLayer(n_heads, d_model, d_head, d_inner, resid_p=resid_p, attn_p=attn_p,\n                      ff_p=ff_p, bias=bias, scale=scale, act=act, double_drop=double_drop,\n                      attn_cls=attn_cls) for k in range(n_layers)])\n\n    def reset(self):\n        \"Reset the internal memory.\"\n        self.hidden = [next(self.parameters()).data.new(0) for i in range(self.n_layers+1)]\n\n    def _update_mems(self, hids):\n        if not getattr(self, 'hidden', False): return None\n        assert len(hids) == len(self.hidden), 'len(hids) != len(self.hidden)'\n        with torch.no_grad():\n            for i in range(len(hids)):\n                cat = torch.cat([self.hidden[i], hids[i]], dim=1)\n                
self.hidden[i] = cat[:,-self.mem_len:].detach()\n\n    def select_hidden(self, idxs): self.hidden = [h[idxs] for h in self.hidden]\n\n    def forward(self, x):\n        #The hidden state has to be initiliazed in the forward pass for nn.DataParallel\n        if self.mem_len > 0 and not self.init:\n            self.reset()\n            self.init = True\n        bs,x_len = x.size()\n        inp = self.drop_emb(self.encoder(x)) #.mul_(self.d_model ** 0.5)\n        m_len = self.hidden[0].size(1) if hasattr(self, 'hidden') and len(self.hidden[0].size()) > 1 else 0\n        seq_len = m_len + x_len\n        mask = torch.triu(x.new_ones(x_len, seq_len), diagonal=1+m_len).byte()[None,None] if self.mask else None\n        #[None,:,:None] for einsum implementation of attention\n        hids = []\n        pos = torch.arange(seq_len-1, -1, -1, device=inp.device, dtype=inp.dtype)\n        pos_enc = self.pos_enc(pos)\n        hids.append(inp)\n        for i, layer in enumerate(self.layers):\n            mem = self.hidden[i] if self.mem_len > 0 else None\n            inp = layer(inp, r=pos_enc, u=self.u, v=self.v, mask=mask, mem=mem)\n            hids.append(inp)\n        core_out = inp[:,-x_len:]\n        if self.mem_len > 0 : self._update_mems(hids)\n        return (self.hidden if self.mem_len > 0 else [core_out]),[core_out]\n\ndef init_transformer(m):\n    classname = m.__class__.__name__\n    if classname.find('Linear') != -1:\n        if hasattr(m, 'weight') and m.weight is not None: nn.init.normal_(m.weight, 0., 0.02)\n        if hasattr(m, 'bias') and m.bias is not None:     nn.init.constant_(m.bias, 0.)\n    elif classname.find('LayerNorm') != -1:\n        if hasattr(m, 'weight') and m.weight is not None: nn.init.normal_(m.weight, 1., 0.02)\n        if hasattr(m, 'bias') and m.bias is not None:     nn.init.constant_(m.bias, 0.)\n    elif classname.find('TransformerXL') != -1:\n        if hasattr(m, 'u'): nn.init.normal_(m.u, 0., 0.02)\n        if hasattr(m, 'v'): 
nn.init.normal_(m.v, 0., 0.02)\n\ntfmer_lm_config = dict(ctx_len=512, n_layers=12, n_heads=12, d_model=768, d_head=64, d_inner=3072, resid_p=0.1, attn_p=0.1,\n                         ff_p=0.1, embed_p=0.1, output_p=0., bias=True, scale=True, act=Activation.GeLU, double_drop=False,\n                         tie_weights=True, out_bias=False, init=init_transformer, mask=True)\n\ntfmer_clas_config = dict(ctx_len=512, n_layers=12, n_heads=12, d_model=768, d_head=64, d_inner=3072, resid_p=0.1, attn_p=0.1,\n                         ff_p=0.1, embed_p=0.1, output_p=0., bias=True, scale=True, act=Activation.GeLU, double_drop=False,\n                         init=init_transformer, mask=False)\n\ndef tfmer_lm_split(model:nn.Module) -> List[nn.Module]:\n    \"Split a RNN `model` in groups for differential learning rates.\"\n    encoder = model[0]\n    n = len(encoder.layers)//3\n    groups = [list(encoder.layers[:n]), list(encoder.layers[n:2*n]), list(encoder.layers[2*n:])]\n    return groups + [[encoder.encoder, model[1]]]\n\ndef tfmer_clas_split(model:nn.Module) -> List[nn.Module]:\n    \"Split a RNN `model` in groups for differential learning rates.\"\n    encoder = model[0].module\n    n = len(encoder.layers)//3\n    groups = [[encoder.encoder], list(encoder.layers[:n]), list(encoder.layers[n:2*n]), list(encoder.layers[2*n:])]\n    return groups + [[model[1]]]\n\ntfmerXL_lm_config = dict(ctx_len=150, n_layers=12, n_heads=10, d_model=410, d_head=41, d_inner=2100, resid_p=0.1, attn_p=0.1,\n                         ff_p=0.1, embed_p=0.1, output_p=0.1, bias=False, scale=True, act=Activation.ReLU, double_drop=True,\n                         tie_weights=True, out_bias=True, init=init_transformer, mem_len=150, mask=True)\n\ntfmerXL_clas_config = dict(ctx_len=150, n_layers=12, n_heads=10, d_model=410, d_head=41, d_inner=2100, resid_p=0.1, attn_p=0.1,\n                         ff_p=0.1, embed_p=0.1, output_p=0.1, bias=False, scale=True, act=Activation.ReLU, double_drop=True,\n      
                   init=init_transformer, mem_len=150, mask=False)\n\ndef tfmerXL_lm_split(model:nn.Module) -> List[nn.Module]:\n    \"Split a RNN `model` in groups for differential learning rates.\"\n    encoder = model[0]\n    n = len(encoder.layers)//3\n    groups = [list(encoder.layers[:n]) + [ParameterModule(encoder.u), ParameterModule(encoder.v)]]\n    return groups + [list(encoder.layers[n:2*n]), list(encoder.layers[2*n:]), [encoder.encoder, model[1]]]\n\ndef tfmerXL_clas_split(model:nn.Module) -> List[nn.Module]:\n    \"Split a RNN `model` in groups for differential learning rates.\"\n    encoder = model[0].module\n    n = len(encoder.layers)//3\n    groups = [[encoder.encoder], list(encoder.layers[:n]) + [ParameterModule(encoder.u), ParameterModule(encoder.v)]]\n    return groups + [list(encoder.layers[n:2*n]), list(encoder.layers[2*n:]), [model[1]]]\n"
  },
  {
    "path": "fastai/text/transform.py",
    "content": "\"NLP data processing; tokenizes text and creates vocab indexes\"\nfrom ..torch_core import *\n\nimport spacy\nfrom spacy.symbols import ORTH\n\n__all__ = ['BaseTokenizer', 'SpacyTokenizer', 'Tokenizer', 'Vocab', 'fix_html', 'replace_all_caps', 'replace_rep', 'replace_wrep',\n           'rm_useless_spaces', 'spec_add_spaces', 'BOS', 'EOS', 'FLD', 'UNK', 'PAD', 'TK_MAJ', 'TK_UP', 'TK_REP', 'TK_REP', 'TK_WREP',\n           'deal_caps']\n\nBOS,EOS,FLD,UNK,PAD = 'xxbos','xxeos','xxfld','xxunk','xxpad'\nTK_MAJ,TK_UP,TK_REP,TK_WREP = 'xxmaj','xxup','xxrep','xxwrep'\ndefaults.text_spec_tok = [UNK,PAD,BOS,EOS,FLD,TK_MAJ,TK_UP,TK_REP,TK_WREP]\n\n\nclass BaseTokenizer():\n    \"Basic class for a tokenizer function.\"\n    def __init__(self, lang:str):                      self.lang = lang\n    def tokenizer(self, t:str) -> List[str]:           return t.split(' ')\n    def add_special_cases(self, toks:Collection[str]): pass\n\nclass SpacyTokenizer(BaseTokenizer):\n    \"Wrapper around a spacy tokenizer to make it a `BaseTokenizer`.\"\n    def __init__(self, lang:str):\n        self.tok = spacy.blank(lang, disable=[\"parser\",\"tagger\",\"ner\"])\n\n    def tokenizer(self, t:str) -> List[str]:\n        return [t.text for t in self.tok.tokenizer(t)]\n\n    def add_special_cases(self, toks:Collection[str]):\n        for w in toks:\n            self.tok.tokenizer.add_special_case(w, [{ORTH: w}])\n\ndef spec_add_spaces(t:str) -> str:\n    \"Add spaces around / and # in `t`. 
\\n\"\n    return re.sub(r'([/#\\n])', r' \\1 ', t)\n\ndef rm_useless_spaces(t:str) -> str:\n    \"Remove multiple spaces in `t`.\"\n    return re.sub(' {2,}', ' ', t)\n\ndef replace_rep(t:str) -> str:\n    \"Replace repetitions at the character level in `t`.\"\n    def _replace_rep(m:Collection[str]) -> str:\n        c,cc = m.groups()\n        return f' {TK_REP} {len(cc)+1} {c} '\n    re_rep = re.compile(r'(\\S)(\\1{3,})')\n    return re_rep.sub(_replace_rep, t)\n\ndef replace_wrep(t:str) -> str:\n    \"Replace word repetitions in `t`.\"\n    def _replace_wrep(m:Collection[str]) -> str:\n        c,cc = m.groups()\n        return f' {TK_WREP} {len(cc.split())+1} {c} '\n    re_wrep = re.compile(r'(\\b\\w+\\W+)(\\1{3,})')\n    return re_wrep.sub(_replace_wrep, t)\n\ndef fix_html(x:str) -> str:\n    \"List of replacements from html strings in `x`.\"\n    re1 = re.compile(r'  +')\n    x = x.replace('#39;', \"'\").replace('amp;', '&').replace('#146;', \"'\").replace(\n        'nbsp;', ' ').replace('#36;', '$').replace('\\\\n', \"\\n\").replace('quot;', \"'\").replace(\n        '<br />', \"\\n\").replace('\\\\\"', '\"').replace('<unk>',UNK).replace(' @.@ ','.').replace(\n        ' @-@ ','-').replace(' @,@ ',',').replace('\\\\', ' \\\\ ')\n    return re1.sub(' ', html.unescape(x))\n\ndef replace_all_caps(x:Collection[str]) -> Collection[str]:\n    \"Replace tokens in ALL CAPS in `x` by their lower version and add `TK_UP` before.\"\n    res = []\n    for t in x:\n        if t.isupper() and len(t) > 1: res.append(TK_UP); res.append(t.lower())\n        else: res.append(t)\n    return res\n\ndef deal_caps(x:Collection[str]) -> Collection[str]:\n    \"Replace all Capitalized tokens in `x` by their lower version and add `TK_MAJ` before.\"\n    res = []\n    for t in x:\n        if t == '': continue\n        if t[0].isupper() and len(t) > 1 and t[1:].islower(): res.append(TK_MAJ)\n        res.append(t.lower())\n    return res\n\ndefaults.text_pre_rules = [fix_html, replace_rep, 
replace_wrep, spec_add_spaces, rm_useless_spaces]\ndefaults.text_post_rules = [replace_all_caps, deal_caps]\n\nclass Tokenizer():\n    \"Put together rules and a tokenizer function to tokenize text with multiprocessing.\"\n    def __init__(self, tok_func:Callable=SpacyTokenizer, lang:str='en', pre_rules:ListRules=None,\n                 post_rules:ListRules=None, special_cases:Collection[str]=None, n_cpus:int=None):\n        self.tok_func,self.lang,self.special_cases = tok_func,lang,special_cases\n        self.pre_rules  = ifnone(pre_rules,  defaults.text_pre_rules )\n        self.post_rules = ifnone(post_rules, defaults.text_post_rules)\n        self.special_cases = special_cases if special_cases else defaults.text_spec_tok\n        self.n_cpus = ifnone(n_cpus, defaults.cpus)\n\n    def __repr__(self) -> str:\n        res = f'Tokenizer {self.tok_func.__name__} in {self.lang} with the following rules:\\n'\n        for rule in self.pre_rules: res += f' - {rule.__name__}\\n'\n        for rule in self.post_rules: res += f' - {rule.__name__}\\n'\n        return res\n\n    def process_text(self, t:str, tok:BaseTokenizer) -> List[str]:\n        \"Process one text `t` with tokenizer `tok`.\"\n        for rule in self.pre_rules: t = rule(t)\n        toks = tok.tokenizer(t)\n        for rule in self.post_rules: toks = rule(toks)\n        return toks\n\n    def _process_all_1(self, texts:Collection[str]) -> List[List[str]]:\n        \"Process a list of `texts` in one process.\"\n        tok = self.tok_func(self.lang)\n        if self.special_cases: tok.add_special_cases(self.special_cases)\n        return [self.process_text(str(t), tok) for t in texts]\n\n    def process_all(self, texts:Collection[str]) -> List[List[str]]:\n        \"Process a list of `texts`.\"\n        if self.n_cpus <= 1: return self._process_all_1(texts)\n        with ProcessPoolExecutor(self.n_cpus) as e:\n            return sum(e.map(self._process_all_1, partition_by_cores(texts, self.n_cpus)), 
[])\n\nclass Vocab():\n    \"Contain the correspondence between numbers and tokens and numericalize.\"\n    def __init__(self, itos:Collection[str]):\n        self.itos = itos\n        self.stoi = collections.defaultdict(int,{v:k for k,v in enumerate(self.itos)})\n\n    def numericalize(self, t:Collection[str]) -> List[int]:\n        \"Convert a list of tokens `t` to their ids.\"\n        return [self.stoi[w] for w in t]\n\n    def textify(self, nums:Collection[int], sep=' ') -> List[str]:\n        \"Convert a list of `nums` to their tokens.\"\n        return sep.join([self.itos[i] for i in nums]) if sep is not None else [self.itos[i] for i in nums]\n\n    def __getstate__(self):\n        return {'itos':self.itos}\n\n    def __setstate__(self, state:dict):\n        self.itos = state['itos']\n        self.stoi = collections.defaultdict(int,{v:k for k,v in enumerate(self.itos)})\n\n    def save(self, path):\n        \"Save `self.itos` in `path`\"\n        pickle.dump(self.itos, open(path, 'wb'))\n\n    @classmethod\n    def create(cls, tokens:Tokens, max_vocab:int, min_freq:int) -> 'Vocab':\n        \"Create a vocabulary from a set of `tokens`.\"\n        freq = Counter(p for o in tokens for p in o)\n        itos = [o for o,c in freq.most_common(max_vocab) if c >= min_freq]\n        for o in reversed(defaults.text_spec_tok):\n            if o in itos: itos.remove(o)\n            itos.insert(0, o)\n        itos = itos[:max_vocab]\n        if len(itos) < max_vocab: #Make sure vocab size is a multiple of 8 for fast mixed precision training\n            while len(itos)%8 !=0: itos.append('xxfake')\n        return cls(itos)\n    \n    @classmethod\n    def load(cls, path):\n        \"Load the `Vocab` contained in `path`\"\n        itos = pickle.load(open(path, 'rb'))\n        return cls(itos)\n"
  },
  {
    "path": "fastai/torch_core.py",
    "content": "\"Utility functions to help deal with tensors\"\nfrom .imports.torch import *\nfrom .core import *\nfrom collections import OrderedDict\nfrom torch.nn.parallel import DistributedDataParallel\n\nAffineMatrix = Tensor\nBoolOrTensor = Union[bool,Tensor]\nFloatOrTensor = Union[float,Tensor]\nIntOrTensor = Union[int,Tensor]\nItemsList = Collection[Union[Tensor,ItemBase,'ItemsList',float,int]]\nLambdaFunc = Callable[[Tensor],Tensor]\nLayerFunc = Callable[[nn.Module],None]\nModuleList = Collection[nn.Module]\nNPArray = np.ndarray\nOptOptimizer = Optional[optim.Optimizer]\nParamList = Collection[nn.Parameter]\nRank0Tensor = NewType('OneEltTensor', Tensor)\nSplitFunc = Callable[[nn.Module], List[nn.Module]]\nSplitFuncOrIdxList = Union[Callable, Collection[ModuleList]]\nTensorOrNumber = Union[Tensor,Number]\nTensorOrNumList = Collection[TensorOrNumber]\nTensorImage = Tensor\nTensorImageSize = Tuple[int,int,int]\nTensors = Union[Tensor, Collection['Tensors']]\nWeights = Dict[str,Tensor]\n\nAffineFunc = Callable[[KWArgs], AffineMatrix]\nHookFunc = Callable[[nn.Module, Tensors, Tensors], Any]\nLogitTensorImage = TensorImage\nLossFunction = Callable[[Tensor, Tensor], Rank0Tensor]\nMetricFunc = Callable[[Tensor,Tensor],TensorOrNumber]\nMetricFuncList = Collection[MetricFunc]\nMetricsList = Collection[TensorOrNumber]\nOptLossFunc = Optional[LossFunction]\nOptMetrics = Optional[MetricsList]\nOptSplitFunc = Optional[SplitFunc]\nPixelFunc = Callable[[TensorImage, ArgStar, KWArgs], TensorImage]\n\nLightingFunc = Callable[[LogitTensorImage, ArgStar, KWArgs], LogitTensorImage]\n\nfastai_types = {\n    AnnealFunc:'AnnealFunc', ArgStar:'ArgStar', BatchSamples:'BatchSamples',\n    FilePathList:'FilePathList', Floats:'Floats', ImgLabel:'ImgLabel', ImgLabels:'ImgLabels', KeyFunc:'KeyFunc',\n    KWArgs:'KWArgs', ListOrItem:'ListOrItem', ListRules:'ListRules', ListSizes:'ListSizes',\n    NPArrayableList:'NPArrayableList', NPArrayList:'NPArrayList', NPArrayMask:'NPArrayMask', 
NPImage:'NPImage',\n    OptDataFrame:'OptDataFrame', OptListOrItem:'OptListOrItem', OptRange:'OptRange', OptStrTuple:'OptStrTuple',\n    OptStats:'OptStats', PathOrStr:'PathOrStr', PBar:'PBar', Point:'Point', Points:'Points', Sizes:'Sizes',\n    SplitArrayList:'SplitArrayList', StartOptEnd:'StartOptEnd', StrList:'StrList', Tokens:'Tokens',\n    OptStrList:'OptStrList', AffineMatrix:'AffineMatrix', BoolOrTensor:'BoolOrTensor', FloatOrTensor:'FloatOrTensor',\n    IntOrTensor:'IntOrTensor', ItemsList:'ItemsList', LambdaFunc:'LambdaFunc',\n    LayerFunc:'LayerFunc', ModuleList:'ModuleList', OptOptimizer:'OptOptimizer', ParamList:'ParamList',\n    Rank0Tensor:'Rank0Tensor', SplitFunc:'SplitFunc', SplitFuncOrIdxList:'SplitFuncOrIdxList',\n    TensorOrNumber:'TensorOrNumber', TensorOrNumList:'TensorOrNumList', TensorImage:'TensorImage',\n    TensorImageSize:'TensorImageSize', Tensors:'Tensors', Weights:'Weights', AffineFunc:'AffineFunc',\n    HookFunc:'HookFunc', LogitTensorImage:'LogitTensorImage', LossFunction:'LossFunction', MetricFunc:'MetricFunc',\n    MetricFuncList:'MetricFuncList', MetricsList:'MetricsList', OptLossFunc:'OptLossFunc', OptMetrics:'OptMetrics',\n    OptSplitFunc:'OptSplitFunc', PixelFunc:'PixelFunc', LightingFunc:'LightingFunc', IntsOrStrs:'IntsOrStrs',\n    PathLikeOrBinaryStream:'PathLikeOrBinaryStream'\n}\n\nbn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)\nbias_types = (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)\ndef is_pool_type(l:Callable): return re.search(r'Pool[123]d$', l.__class__.__name__)\nno_wd_types = bn_types + (nn.LayerNorm,)\ndefaults.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\nAdamW = partial(optim.Adam, betas=(0.9,0.99))\n\n#Monkey-patch `torch.cuda.set_device` so that it updates `defaults.device`\n_old_torch_cuda_set_device = torch.cuda.set_device\ndef _new_torch_cuda_set_device(device):\n    
_old_torch_cuda_set_device(device)\n    defaults.device = torch.device('cuda', device) if isinstance(device, int) else device\ntorch.cuda.set_device = _new_torch_cuda_set_device\n\ndef tensor(x:Any, *rest)->Tensor:\n    \"Like `torch.as_tensor`, but handle lists too, and can pass multiple vector elements directly.\"\n    if len(rest): x = (x,)+rest\n    # XXX: Pytorch bug in dataloader using num_workers>0; TODO: create repro and report\n    if is_listy(x) and len(x)==0: return tensor(0)\n    res = torch.tensor(x) if is_listy(x) else as_tensor(x)\n    if res.dtype is torch.int32:\n        warn('Tensor is int32: upgrading to int64; for better performance use int64 input')\n        return res.long()\n    return res\n\nclass Module(nn.Module, metaclass=PrePostInitMeta):\n    \"Same as `nn.Module`, but no need for subclasses to call `super().__init__`\"\n    def __pre_init__(self): super().__init__()\n    def __init__(self): pass\n\ndef np_address(x:np.ndarray)->int:\n    \"Address of `x` in memory.\"\n    return x.__array_interface__['data'][0]\n\ndef to_detach(b:Tensors, cpu:bool=True):\n    \"Recursively detach lists of tensors in `b `; put them on the CPU if `cpu=True`.\"\n    def _inner(x, cpu=True):\n        if not isinstance(x,Tensor): return x\n        x = x.detach()\n        return x.cpu() if cpu else x\n    return recurse(_inner, b, cpu=cpu)\n\ndef to_data(b:ItemsList):\n    \"Recursively map lists of items in `b ` to their wrapped data.\"\n    return recurse(lambda x: x.data if isinstance(x,ItemBase) else x, b)\n\ndef to_cpu(b:ItemsList):\n    \"Recursively map lists of tensors in `b ` to the cpu.\"\n    return recurse(lambda x: x.cpu() if isinstance(x,Tensor) else x, b)\n\ndef to_half(b:Collection[Tensor])->Collection[Tensor]:\n    \"Recursively map lists of tensors in `b ` to FP16.\"\n    return recurse(lambda x: x.half() if x.dtype not in [torch.int64, torch.int32, torch.int16] else x, b)\n\ndef to_float(b:Collection[Tensor])->Collection[Tensor]:\n    
\"Recursively map lists of tensors in `b ` to FP16.\"\n    return recurse(lambda x: x.float() if x.dtype not in [torch.int64, torch.int32, torch.int16] else x, b)\n\ndef to_device(b:Tensors, device:torch.device):\n    \"Recursively put `b` on `device`.\"\n    device = ifnone(device, defaults.device)\n    return recurse(lambda x: x.to(device, non_blocking=True), b)\n\ndef data_collate(batch:ItemsList)->Tensor:\n    \"Convert `batch` items to tensor data.\"\n    return torch.utils.data.dataloader.default_collate(to_data(batch))\n\ndef requires_grad(m:nn.Module, b:Optional[bool]=None)->Optional[bool]:\n    \"If `b` is not set return `requires_grad` of first param, else set `requires_grad` on all params as `b`\"\n    ps = list(m.parameters())\n    if not ps: return None\n    if b is None: return ps[0].requires_grad\n    for p in ps: p.requires_grad=b\n\ndef trainable_params(m:nn.Module)->ParamList:\n    \"Return list of trainable params in `m`.\"\n    res = filter(lambda p: p.requires_grad, m.parameters())\n    return res\n\ndef children(m:nn.Module)->ModuleList:\n    \"Get children of `m`.\"\n    return list(m.children())\n\ndef num_children(m:nn.Module)->int:\n    \"Get number of children modules in `m`.\"\n    return len(children(m))\n\ndef range_children(m:nn.Module)->Iterator[int]:\n    \"Return iterator of len of children of `m`.\"\n    return range(num_children(m))\n\nclass ParameterModule(Module):\n    \"Register a lone parameter `p` in a module.\"\n    def __init__(self, p:nn.Parameter): self.val = p\n    def forward(self, x): return x\n\ndef children_and_parameters(m:nn.Module):\n    \"Return the children of `m` and its direct parameters not registered in modules.\"\n    children = list(m.children())\n    children_p = sum([[id(p) for p in c.parameters()] for c in m.children()],[])\n    for p in m.parameters():\n        if id(p) not in children_p: children.append(ParameterModule(p))\n    return children\n\ndef flatten_model(m:nn.Module):\n    if 
num_children(m):\n        mapped = map(flatten_model,children_and_parameters(m))\n        return sum(mapped,[])\n    else:\n        return [m]\n\n#flatten_model = lambda m: sum(map(flatten_model,children_and_parameters(m)),[]) if num_children(m) else [m]\n\ndef first_layer(m:nn.Module)->nn.Module:\n    \"Retrieve first layer in a module `m`.\"\n    return flatten_model(m)[0]\n\ndef last_layer(m:nn.Module)->nn.Module:\n    \"Retrieve last layer in a module `m`.\"\n    return flatten_model(m)[-1]\n\ndef split_model_idx(model:nn.Module, idxs:Collection[int])->ModuleList:\n    \"Split `model` according to the indexes in `idxs`.\"\n    layers = flatten_model(model)\n    if idxs[0] != 0: idxs = [0] + idxs\n    if idxs[-1] != len(layers): idxs.append(len(layers))\n    return [nn.Sequential(*layers[i:j]) for i,j in zip(idxs[:-1],idxs[1:])]\n\ndef split_model(model:nn.Module=None, splits:Collection[Union[nn.Module,ModuleList]]=None):\n    \"Split `model` according to the layers in `splits`.\"\n    splits = listify(splits)\n    if isinstance(splits[0], nn.Module):\n        layers = flatten_model(model)\n        idxs = [layers.index(first_layer(s)) for s in splits]\n        return split_model_idx(model, idxs)\n    return [nn.Sequential(*s) for s in splits]\n\ndef get_param_groups(layer_groups:Collection[nn.Module])->List[List[nn.Parameter]]:\n    return [sum([list(trainable_params(c)) for c in l.children()], []) for l in layer_groups]\n\ndef split_no_wd_params(layer_groups:Collection[nn.Module])->List[List[nn.Parameter]]:\n    \"Separate the parameters in `layer_groups` between `no_wd_types` and  bias (`bias_types`) from the rest.\"\n    split_params = []\n    for l in layer_groups:\n        l1,l2 = [],[]\n        for c in l.children():\n            if isinstance(c, no_wd_types): l2 += list(trainable_params(c))\n            elif isinstance(c, bias_types):\n                bias = c.bias if hasattr(c, 'bias') else None\n                l1 += [p for p in trainable_params(c) if 
not (p is bias)]\n                if bias is not None: l2.append(bias)\n            else: l1 += list(trainable_params(c))\n        #Since we scan the children separately, we might get duplicates (tied weights). We need to preserve the order\n        #for the optimizer load of state_dict\n        l1,l2 = uniqueify(l1),uniqueify(l2)\n        split_params += [l1, l2]\n    return split_params\n\ndef set_bn_eval(m:nn.Module)->None:\n    \"Set bn layers in eval mode for all recursive children of `m`.\"\n    for l in m.children():\n        if isinstance(l, bn_types) and not next(l.parameters()).requires_grad:\n            l.eval()\n        set_bn_eval(l)\n\ndef batch_to_half(b:Collection[Tensor])->Collection[Tensor]:\n    \"Set the input of batch `b` to half precision.\"\n    return [to_half(b[0]), b[1]]\n\ndef bn2float(module:nn.Module)->nn.Module:\n    \"If `module` is batchnorm don't use half precision.\"\n    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): module.float()\n    for child in module.children(): bn2float(child)\n    return module\n\ndef model2half(model:nn.Module)->nn.Module:\n    \"Convert `model` to half precision except the batchnorm layers.\"\n    return bn2float(model.half())\n\ndef init_default(m:nn.Module, func:LayerFunc=nn.init.kaiming_normal_)->nn.Module:\n    \"Initialize `m` weights with `func` and set `bias` to 0.\"\n    if func:\n        if hasattr(m, 'weight'): func(m.weight)\n        if hasattr(m, 'bias') and hasattr(m.bias, 'data'): m.bias.data.fill_(0.)\n    return m\n\ndef cond_init(m:nn.Module, init_func:LayerFunc):\n    \"Initialize the non-batchnorm layers of `m` with `init_func`.\"\n    if (not isinstance(m, bn_types)) and requires_grad(m): init_default(m, init_func)\n\ndef apply_leaf(m:nn.Module, f:LayerFunc):\n    \"Apply `f` to children of `m`.\"\n    c = children(m)\n    if isinstance(m, nn.Module): f(m)\n    for l in c: apply_leaf(l,f)\n\ndef apply_init(m, init_func:LayerFunc):\n    \"Initialize all non-batchnorm 
layers of `m` with `init_func`.\"\n    apply_leaf(m, partial(cond_init, init_func=init_func))\n\ndef in_channels(m:nn.Module) -> List[int]:\n    \"Return the shape of the first weight layer in `m`.\"\n    for l in flatten_model(m):\n        if hasattr(l, 'weight'): return l.weight.shape[1]\n    raise Exception('No weight layer')\n\nclass ModelOnCPU():\n    \"A context manager to evaluate `model` on the CPU inside.\"\n    def __init__(self, model:nn.Module): self.model = model       \n    def __enter__(self):\n        self.device = one_param(self.model).device\n        return self.model.cpu()\n    def __exit__(self, type, value, traceback):\n        self.model = self.model.to(self.device)\n    \nclass NoneReduceOnCPU():\n    \"A context manager to evaluate `loss_func` with none reduce and weights on the CPU inside.\"\n    def __init__(self, loss_func:LossFunction): \n        self.loss_func,self.device,self.old_red = loss_func,None,None\n        \n    def __enter__(self):\n        if hasattr(self.loss_func, 'weight') and self.loss_func.weight is not None:\n            self.device = self.loss_func.weight.device\n            self.loss_func.weight = self.loss_func.weight.cpu()\n        if hasattr(self.loss_func, 'reduction'):\n            self.old_red = getattr(self.loss_func, 'reduction')\n            setattr(self.loss_func, 'reduction', 'none')\n            return self.loss_func\n        else: return partial(self.loss_func, reduction='none')\n        \n    def __exit__(self, type, value, traceback):\n        if self.device is not None:  self.loss_func.weight = self.loss_func.weight.to(self.device)\n        if self.old_red is not None: setattr(self.loss_func, 'reduction', self.old_red)    \n    \ndef model_type(dtype):\n    \"Return the torch type corresponding to `dtype`.\"\n    return (torch.float32 if np.issubdtype(dtype, np.floating) else\n            torch.int64 if np.issubdtype(dtype, np.integer)\n            else None)\n\ndef np2model_tensor(a):\n    \"Tranform 
numpy array `a` to a tensor of the same type.\"\n    dtype = model_type(a.dtype)\n    res = as_tensor(a)\n    if not dtype: return res\n    return res.type(dtype)\n\ndef _pca(x, k=2):\n    \"Compute PCA of `x` with `k` dimensions.\"\n    x = x-torch.mean(x,0)\n    U,S,V = torch.svd(x.t())\n    return torch.mm(x,U[:,:k])\ntorch.Tensor.pca = _pca\n\ndef trange_of(x): \n    \"Create a tensor from `range_of(x)`.\"\n    return torch.arange(len(x))\n\ndef to_np(x): \n    \"Convert a tensor to a numpy array.\"\n    return x.data.cpu().numpy()\n\n# monkey patching to allow matplotlib to plot tensors\ndef tensor__array__(self, dtype=None):\n    res = to_np(self)\n    if dtype is None: return res\n    else: return res.astype(dtype, copy=False)\nTensor.__array__ = tensor__array__\nTensor.ndim = property(lambda x: len(x.shape))\n\ndef grab_idx(x,i,batch_first:bool=True):\n    \"Grab the `i`-th batch in `x`, `batch_first` stating the batch dimension.\"\n    if batch_first: return ([o[i].cpu() for o in x]   if is_listy(x) else x[i].cpu())\n    else:           return ([o[:,i].cpu() for o in x] if is_listy(x) else x[:,i].cpu())\n\ndef logit(x:Tensor)->Tensor:\n    \"Logit of `x`, clamped to avoid inf.\"\n    x = x.clamp(1e-7, 1-1e-7)\n    return -(1/x-1).log()\n\ndef logit_(x:Tensor)->Tensor:\n    \"Inplace logit of `x`, clamped to avoid inf\"\n    x.clamp_(1e-7, 1-1e-7)\n    return (x.reciprocal_().sub_(1)).log_().neg_()\n\ndef set_all_seed(seed:int)->None:\n    \"Sets the seeds for all pseudo random generators in fastai lib\"\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    random.seed(seed)\n\ndef uniform(low:Number, high:Number=None, size:Optional[List[int]]=None)->FloatOrTensor:\n    \"Draw 1 or shape=`size` random floats from uniform dist: min=`low`, max=`high`.\"\n    if high is None: high=low\n    return random.uniform(low,high) if size is None else torch.FloatTensor(*listify(size)).uniform_(low,high)\n\ndef log_uniform(low, high, 
size:Optional[List[int]]=None)->FloatOrTensor:\n    \"Draw 1 or shape=`size` random floats from uniform dist: min=log(`low`), max=log(`high`).\"\n    res = uniform(log(low), log(high), size)\n    return exp(res) if size is None else res.exp_()\n\ndef rand_bool(p:float, size:Optional[List[int]]=None)->BoolOrTensor:\n    \"Draw 1 or shape=`size` random booleans (`True` occuring with probability `p`).\"\n    return uniform(0,1,size)<p\n\ndef uniform_int(low:int, high:int, size:Optional[List[int]]=None)->IntOrTensor:\n    \"Generate int or tensor `size` of ints between `low` and `high` (included).\"\n    return random.randint(low,high) if size is None else torch.randint(low,high+1,size)\n\ndef one_param(m: nn.Module)->Tensor: \n    \"Return the first parameter of `m`.\"\n    return next(m.parameters())\n\ndef try_int(o:Any)->Any:\n    \"Try to convert `o` to int, default to `o` if not possible.\"\n    # NB: single-item rank-1 array/tensor can be converted to int, but we don't want to do this\n    if isinstance(o, (np.ndarray,Tensor)): return o if o.ndim else int(o)\n    if isinstance(o, collections.abc.Sized) or getattr(o,'__array_interface__',False): return o\n    try: return int(o)\n    except: return o\n\ndef get_model(model:nn.Module):\n    \"Return the model maybe wrapped inside `model`.\"\n    return model.module if isinstance(model, (DistributedDataParallel, nn.DataParallel)) else model\n\ndef flatten_check(out:Tensor, targ:Tensor) -> Tensor:\n    \"Check that `out` and `targ` have the same number of elements and flatten them.\"\n    out,targ = out.contiguous().view(-1),targ.contiguous().view(-1)\n    assert len(out) == len(targ), f\"Expected output and target to have the same number of elements but got {len(out)} and {len(targ)}.\"\n    return out,targ\n\n#Monkey-patch nn.DataParallel.reset\ndef _data_parallel_reset(self): \n    if hasattr(self.module, 'reset'): self.module.reset()\nnn.DataParallel.reset = _data_parallel_reset\n\ndef 
remove_module_load(state_dict):\n    \"\"\"create new OrderedDict that does not contain `module.`\"\"\"\n    new_state_dict = OrderedDict()\n    for k, v in state_dict.items(): new_state_dict[k[7:]] = v\n    return new_state_dict\n\ndef num_distrib():\n    \"Return the number of processes in distributed training (if applicable).\"\n    return int(os.environ.get('WORLD_SIZE', 0))\n\ndef rank_distrib():\n    \"Return the distributed rank of this process (if applicable).\"\n    return int(os.environ.get('RANK', 0))\n\ndef add_metrics(last_metrics:Collection[Rank0Tensor], mets:Union[Rank0Tensor, Collection[Rank0Tensor]]):\n    \"Return a dictionary for updating `last_metrics` with `mets`.\"\n    last_metrics,mets = listify(last_metrics),listify(mets)\n    return {'last_metrics': last_metrics + mets}\n\ndef try_save(state:Dict, path:Path=None, file:PathLikeOrBinaryStream=None):\n    target = open(path/file, 'wb') if is_pathlike(file) else file\n    try: torch.save(state, target)\n    except OSError as e:\n        raise Exception(f\"{e}\\n Can't write {path/file}. Pass an absolute writable pathlib obj `fname`.\")\n\ndef np_func(f):\n    \"Convert a function taking and returning numpy arrays to one taking and returning tensors\"\n    def _inner(*args, **kwargs):\n        nargs = [to_np(arg) if isinstance(arg,Tensor) else arg for arg in args]\n        return tensor(f(*nargs, **kwargs))\n    functools.update_wrapper(_inner, f)\n    return _inner\n\n"
  },
  {
    "path": "fastai/train.py",
    "content": "\"Provides advanced training extensions to `fastai.basic_train`. Includes half-precision, learning rate finder, mixup, and one-cycle\"\nfrom .torch_core import *\nfrom .callbacks import *\nfrom .basic_data import *\nfrom .basic_train import *\n\n__all__ = ['BnFreeze', 'GradientClipping', 'ShowGraph', 'Interpretation', 'ClassificationInterpretation', 'MultiLabelClassificationInterpretation',\n 'fit_one_cycle', 'lr_find', 'one_cycle_scheduler', 'to_fp16', 'to_fp32', 'mixup', 'AccumulateScheduler']\n\ndef one_cycle_scheduler(lr_max:float, **kwargs:Any)->OneCycleScheduler:\n    \"Instantiate a `OneCycleScheduler` with `lr_max`.\"\n    return partial(OneCycleScheduler, lr_max=lr_max, **kwargs)\n\ndef fit_one_cycle(learn:Learner, cyc_len:int, max_lr:Union[Floats,slice]=defaults.lr,\n                  moms:Tuple[float,float]=(0.95,0.85), div_factor:float=25., pct_start:float=0.3, final_div:float=None,\n                  wd:float=None, callbacks:Optional[CallbackList]=None, tot_epochs:int=None, start_epoch:int=None,\n                  batch_multiplier:int=1)->None:\n    \"Fit a model following the 1cycle policy.\"\n    max_lr = learn.lr_range(max_lr)\n    callbacks = listify(callbacks)\n    callbacks.append(OneCycleScheduler(learn, max_lr, moms=moms, div_factor=div_factor, pct_start=pct_start,\n                                       final_div=final_div, tot_epochs=tot_epochs, start_epoch=start_epoch))\n    learn.fit(cyc_len, max_lr, wd=wd, callbacks=callbacks, batch_multiplier=batch_multiplier)\n\ndef lr_find(learn:Learner, start_lr:Floats=1e-7, end_lr:Floats=10, num_it:int=100, stop_div:bool=True, wd:float=None,\n        batch_multiplier:int=1):\n    \"Explore lr from `start_lr` to `end_lr` over `num_it` iterations in `learn`. 
If `stop_div`, stops when loss diverges.\"\n    start_lr = learn.lr_range(start_lr)\n    start_lr = np.array(start_lr) if is_listy(start_lr) else start_lr\n    end_lr = learn.lr_range(end_lr)\n    end_lr = np.array(end_lr) if is_listy(end_lr) else end_lr\n    cb = LRFinder(learn, start_lr, end_lr, num_it, stop_div)\n    epochs = int(np.ceil(num_it/len(learn.data.train_dl)))\n    learn.fit(epochs, start_lr, callbacks=[cb], wd=wd, batch_multiplier=batch_multiplier)\n\ndef to_fp16(learn:Learner, loss_scale:float=None, max_noskip:int=1000, dynamic:bool=True, clip:float=None,\n            flat_master:bool=False, max_scale:float=2**24)->Learner:\n    \"Put `learn` in FP16 precision mode.\"\n    learn.to_fp32()\n    learn.model = model2half(learn.model)\n    learn.data.add_tfm(batch_to_half)\n    learn.mp_cb = MixedPrecision(learn, loss_scale=loss_scale, max_noskip=max_noskip, dynamic=dynamic, clip=clip,\n                                 flat_master=flat_master, max_scale=max_scale)\n    learn.callbacks.append(learn.mp_cb)\n    return learn\n\ndef to_fp32(learn:Learner):\n    \"Put `learn` back to FP32 precision mode.\"\n    learn.data.remove_tfm(batch_to_half)\n    for cb in learn.callbacks:\n        if isinstance(cb, MixedPrecision): learn.callbacks.remove(cb)\n    learn.model = learn.model.float()\n    return learn\n\ndef mixup(learn:Learner, alpha:float=0.4, stack_x:bool=False, stack_y:bool=True) -> Learner:\n    \"Add mixup https://arxiv.org/abs/1710.09412 to `learn`.\"\n    learn.callback_fns.append(partial(MixUpCallback, alpha=alpha, stack_x=stack_x, stack_y=stack_y))\n    return learn\n\nLearner.fit_one_cycle = fit_one_cycle\nLearner.lr_find = lr_find\nLearner.to_fp16 = to_fp16\nLearner.to_fp32 = to_fp32\nLearner.mixup = mixup\n\nclass ShowGraph(LearnerCallback):\n    \"Update a graph of learner stats and metrics after each epoch.\"\n    def on_epoch_end(self, n_epochs:int, last_metrics:MetricsList, **kwargs)->bool:\n        \"If we have `last_metrics` plot them 
in our pbar graph\"\n        if last_metrics is not None and last_metrics[0] is not None:\n            rec = self.learn.recorder\n            iters = range_of(rec.losses)\n            val_iter = np.array(rec.nb_batches).cumsum()\n            x_bounds = (0, (n_epochs - len(rec.nb_batches)) * rec.nb_batches[-1] + len(rec.losses))\n            y_bounds = (0, max((max(Tensor(rec.losses)), max(Tensor(rec.val_losses)))))\n            rec.pbar.update_graph([(iters, rec.losses), (val_iter, rec.val_losses)], x_bounds, y_bounds)\n        return {}\n\nclass BnFreeze(LearnerCallback):\n    \"Freeze moving average statistics in all non-trainable batchnorm layers.\"\n    def on_epoch_begin(self, **kwargs:Any)->None:\n        \"Put bn layers in eval mode just after `model.train()`.\"\n        set_bn_eval(self.learn.model)\n\nclass GradientClipping(LearnerCallback):\n    \"Gradient clipping during training.\"\n    def __init__(self, learn:Learner, clip:float = 0.):\n        super().__init__(learn)\n        self.clip = clip\n\n    def on_backward_end(self, **kwargs):\n        \"Clip the gradient before the optimizer step.\"\n        if self.clip: nn.utils.clip_grad_norm_(self.learn.model.parameters(), self.clip)\n\ndef clip_grad(learn:Learner, clip:float=0.1)->Learner:\n    \"Add gradient clipping of `clip` during training.\"\n    learn.callback_fns.append(partial(GradientClipping, clip=clip))\n    return learn\nLearner.clip_grad = clip_grad\n\nclass AccumulateScheduler(LearnerCallback):\n    \"Does accumlated step every nth step by accumulating gradients\"\n\n    def __init__(self, learn:Learner, n_step:int = 1, drop_last:bool = False):\n        super().__init__(learn)\n        self.n_step,self.drop_last = n_step,drop_last\n\n    def on_train_begin(self, **kwargs):\n        \"check if loss is reduction\"\n        if hasattr(self.loss_func, \"reduction\") and (self.loss_func.reduction != \"sum\"):\n             warn(\"For better gradients consider 'reduction=sum'\")\n\n    def 
on_epoch_begin(self, **kwargs):\n        \"init samples and batches, change optimizer\"\n        self.acc_samples, self.acc_batches = 0., 0.\n\n    def on_batch_begin(self, last_input, last_target, **kwargs):\n        \"accumulate samples and batches\"\n        self.acc_samples += last_input.shape[0]\n        self.acc_batches += 1\n\n    def on_backward_end(self, **kwargs):\n        \"accumulated step and reset samples, True will result in no stepping\"\n        if (self.acc_batches % self.n_step) == 0:\n            for p in (self.learn.model.parameters()):\n                if p.requires_grad: p.grad.div_(self.acc_samples)\n            self.acc_samples = 0\n        else: return {'skip_step':True, 'skip_zero':True}\n\n    def on_epoch_end(self, **kwargs):\n        \"step the rest of the accumulated grads if not perfectly divisible\"\n        for p in (self.learn.model.parameters()):\n                if p.requires_grad: p.grad.div_(self.acc_samples)\n        if not self.drop_last: self.learn.opt.step()\n        self.learn.opt.zero_grad()\n\n\nclass Interpretation():\n    \"Interpretation base class, can be inherited for task specific Interpretation classes\"\n    def __init__(self, learn:Learner, preds:Tensor, y_true:Tensor, losses:Tensor, ds_type:DatasetType=DatasetType.Valid):\n        self.data,self.preds,self.y_true,self.losses,self.ds_type, self.learn = \\\n                                 learn.data,preds,y_true,losses,ds_type,learn\n        self.ds = (self.data.train_ds if ds_type == DatasetType.Train else\n                   self.data.test_ds if ds_type == DatasetType.Test else\n                   self.data.valid_ds if ds_type == DatasetType.Valid else\n                   self.data.single_ds if ds_type == DatasetType.Single else\n                   self.data.fix_ds)\n\n    @classmethod\n    def from_learner(cls, learn: Learner,  ds_type:DatasetType=DatasetType.Valid, activ:nn.Module=None):\n        \"Gets preds, y_true, losses to construct base class from a 
learner\"\n        preds_res = learn.get_preds(ds_type=ds_type, activ=activ, with_loss=True)\n        return cls(learn, *preds_res)\n\n    def top_losses(self, k:int=None, largest=True):\n        \"`k` largest(/smallest) losses and indexes, defaulting to all losses (sorted by `largest`).\"\n        return self.losses.topk(ifnone(k, len(self.losses)), largest=largest)\n\n    # def top_scores(self, metric:Callable=None, k:int=None, largest=True):\n    #     \"`k` largest(/smallest) metric scores and indexes, defaulting to all scores (sorted by `largest`).\"\n    #     self.scores = metric(self.preds, self.y_true)\n    #     return self.scores.topk(ifnone(k, len(self.scores)), largest=largest)\n\n\nclass ClassificationInterpretation(Interpretation):\n    \"Interpretation methods for classification models.\"\n    def __init__(self, learn:Learner, preds:Tensor, y_true:Tensor, losses:Tensor, ds_type:DatasetType=DatasetType.Valid):\n        super(ClassificationInterpretation, self).__init__(learn,preds,y_true,losses,ds_type)\n        self.pred_class = self.preds.argmax(dim=1)\n\n    def confusion_matrix(self, slice_size:int=1):\n        \"Confusion matrix as an `np.ndarray`.\"\n        x=torch.arange(0,self.data.c)\n        if slice_size is None: cm = ((self.pred_class==x[:,None]) & (self.y_true==x[:,None,None])).sum(2)\n        else:\n            cm = torch.zeros(self.data.c, self.data.c, dtype=x.dtype)\n            for i in range(0, self.y_true.shape[0], slice_size):\n                cm_slice = ((self.pred_class[i:i+slice_size]==x[:,None])\n                            & (self.y_true[i:i+slice_size]==x[:,None,None])).sum(2)\n                torch.add(cm, cm_slice, out=cm)\n        return to_np(cm)\n\n    def plot_confusion_matrix(self, normalize:bool=False, title:str='Confusion matrix', cmap:Any=\"Blues\", slice_size:int=1,\n                              norm_dec:int=2, plot_txt:bool=True, return_fig:bool=None, **kwargs)->Optional[plt.Figure]:\n        \"Plot the 
confusion matrix, with `title` and using `cmap`.\"\n        # This function is mainly copied from the sklearn docs\n        cm = self.confusion_matrix(slice_size=slice_size)\n        if normalize: cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]\n        fig = plt.figure(**kwargs)\n        plt.imshow(cm, interpolation='nearest', cmap=cmap)\n        plt.title(title)\n        tick_marks = np.arange(self.data.c)\n        plt.xticks(tick_marks, self.data.y.classes, rotation=90)\n        plt.yticks(tick_marks, self.data.y.classes, rotation=0)\n\n        if plot_txt:\n            thresh = cm.max() / 2.\n            for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):\n                coeff = f'{cm[i, j]:.{norm_dec}f}' if normalize else f'{cm[i, j]}'\n                plt.text(j, i, coeff, horizontalalignment=\"center\", verticalalignment=\"center\", color=\"white\" if cm[i, j] > thresh else \"black\")\n\n        plt.tight_layout()\n        plt.ylabel('Actual')\n        plt.xlabel('Predicted')\n        plt.grid(False)\n        if ifnone(return_fig, defaults.return_fig): return fig\n\n    def most_confused(self, min_val:int=1, slice_size:int=1)->Collection[Tuple[str,str,int]]:\n        \"Sorted descending list of largest non-diagonal entries of confusion matrix, presented as actual, predicted, number of occurrences.\"\n        cm = self.confusion_matrix(slice_size=slice_size)\n        np.fill_diagonal(cm, 0)\n        res = [(self.data.classes[i],self.data.classes[j],cm[i,j])\n                for i,j in zip(*np.where(cm>=min_val))]\n        return sorted(res, key=itemgetter(2), reverse=True)\n\n\ndef _learner_interpret(learn:Learner, ds_type:DatasetType=DatasetType.Valid):\n    \"Create a `ClassificationInterpretation` object from `learner` on `ds_type` with `tta`.\"\n    return ClassificationInterpretation.from_learner(learn, ds_type=ds_type)\nLearner.interpret = _learner_interpret\n\nclass MultiLabelClassificationInterpretation(Interpretation):\n   
 \"Interpretation methods for classification models.\"\n    def __init__(self, learn:Learner, preds:Tensor, y_true:Tensor, losses:Tensor, ds_type:DatasetType=DatasetType.Valid,\n                     sigmoid:bool=True, thresh:float=0.3):\n        raise NotImplementedError\n        super(MultiLabelClassificationInterpretation, self).__init__(learn,preds,y_true,losses,ds_type)\n        self.pred_class = self.preds.sigmoid(dim=1)>thresh if sigmoid else self.preds>thresh\n"
  },
  {
    "path": "fastai/utils/__init__.py",
    "content": "from .collect_env import *\n\n__all__ = [*collect_env.__all__]\n"
  },
  {
    "path": "fastai/utils/check_perf.py",
    "content": "from ..script import *\nfrom .collect_env import *\n\n# Temporary POC for module-based script\ncall_parse(check_perf)\n\n"
  },
  {
    "path": "fastai/utils/collect_env.py",
    "content": "\"Utility functions to help deal with user environment\"\n\nfrom ..imports.torch import *\nfrom ..core import *\nfrom ..script import *\nfrom .pynvml_gate import *\nimport fastprogress, subprocess, platform\n\n__all__ = ['show_install', 'check_perf']\n\ndef get_env(name):\n    \"Return env var value if it's defined and not an empty string, or return Unknown\"\n    res = os.environ.get(name,'')\n    return res if len(res) else \"Unknown\"\n\ndef show_install(show_nvidia_smi:bool=False):\n    \"Print user's setup information\"\n\n    import platform, fastai.version\n\n    rep = []\n    opt_mods = []\n\n    rep.append([\"=== Software ===\", None])\n    rep.append([\"python\", platform.python_version()])\n    rep.append([\"fastai\", fastai.__version__])\n    rep.append([\"fastprogress\", fastprogress.__version__])\n    rep.append([\"torch\",  torch.__version__])\n\n    # nvidia-smi\n    cmd = \"nvidia-smi\"\n    have_nvidia_smi = False\n    try: result = subprocess.run(cmd.split(), shell=False, check=False, stdout=subprocess.PIPE)\n    except: pass\n    else:\n        if result.returncode == 0 and result.stdout: have_nvidia_smi = True\n\n    # XXX: if nvidia-smi is not available, another check could be:\n    # /proc/driver/nvidia/version on most systems, since it's the\n    # currently active version\n\n    if have_nvidia_smi:\n        smi = result.stdout.decode('utf-8')\n        # matching: \"Driver Version: 396.44\"\n        match = re.findall(r'Driver Version: +(\\d+\\.\\d+)', smi)\n        if match: rep.append([\"nvidia driver\", match[0]])\n\n    available = \"available\" if torch.cuda.is_available() else \"**Not available** \"\n    rep.append([\"torch cuda\", f\"{torch.version.cuda} / is {available}\"])\n\n    # no point reporting on cudnn if cuda is not available, as it\n    # seems to be enabled at times even on cpu-only setups\n    if torch.cuda.is_available():\n        enabled = \"enabled\" if torch.backends.cudnn.enabled else \"**Not 
enabled** \"\n        rep.append([\"torch cudnn\", f\"{torch.backends.cudnn.version()} / is {enabled}\"])\n\n    rep.append([\"\\n=== Hardware ===\", None])\n\n    # it's possible that torch might not see what nvidia-smi sees?\n    gpu_total_mem = []\n    nvidia_gpu_cnt = 0\n    if have_nvidia_smi:\n        try:\n            cmd = \"nvidia-smi --query-gpu=memory.total --format=csv,nounits,noheader\"\n            result = subprocess.run(cmd.split(), shell=False, check=False, stdout=subprocess.PIPE)\n        except:\n            print(\"have nvidia-smi, but failed to query it\")\n        else:\n            if result.returncode == 0 and result.stdout:\n                output = result.stdout.decode('utf-8')\n                gpu_total_mem = [int(x) for x in output.strip().split('\\n')]\n                nvidia_gpu_cnt = len(gpu_total_mem)\n\n\n    if nvidia_gpu_cnt: rep.append([\"nvidia gpus\", nvidia_gpu_cnt])\n\n    torch_gpu_cnt = torch.cuda.device_count()\n    if torch_gpu_cnt:\n        rep.append([\"torch devices\", torch_gpu_cnt])\n        # information for each gpu\n        for i in range(torch_gpu_cnt):\n            rep.append([f\"  - gpu{i}\", (f\"{gpu_total_mem[i]}MB | \" if gpu_total_mem else \"\") + torch.cuda.get_device_name(i)])\n    else:\n        if nvidia_gpu_cnt:\n            rep.append([f\"Have {nvidia_gpu_cnt} GPU(s), but torch can't use them (check nvidia driver)\", None])\n        else:\n            rep.append([f\"No GPUs available\", None])\n\n\n    rep.append([\"\\n=== Environment ===\", None])\n\n    rep.append([\"platform\", platform.platform()])\n\n    if platform.system() == 'Linux':\n        distro = try_import('distro')\n        if distro:\n            # full distro info\n            rep.append([\"distro\", ' '.join(distro.linux_distribution())])\n        else:\n            opt_mods.append('distro');\n            # partial distro info\n            rep.append([\"distro\", platform.uname().version])\n\n    rep.append([\"conda env\", 
get_env('CONDA_DEFAULT_ENV')])\n    rep.append([\"python\", sys.executable])\n    rep.append([\"sys.path\", \"\\n\".join(sys.path)])\n\n    print(\"\\n\\n```text\")\n\n    keylen = max([len(e[0]) for e in rep if e[1] is not None])\n    for e in rep:\n        print(f\"{e[0]:{keylen}}\", (f\": {e[1]}\" if e[1] is not None else \"\"))\n\n    if have_nvidia_smi:\n        if show_nvidia_smi: print(f\"\\n{smi}\")\n    else:\n        if torch_gpu_cnt: print(\"no nvidia-smi is found\")\n        else: print(\"no supported gpus found on this system\")\n\n    print(\"```\\n\")\n\n    print(\"Please make sure to include opening/closing ``` when you paste into forums/github to make the reports appear formatted as code sections.\\n\")\n\n    if opt_mods:\n        print(\"Optional package(s) to enhance the diagnostics can be installed with:\")\n        print(f\"pip install {' '.join(opt_mods)}\")\n        print(\"Once installed, re-run this utility to get the additional information\")\n\ndef pypi_module_version_is_available(module, version):\n    \"Check whether module==version is available on pypi\"\n    # returns True/False (or None if failed to execute the check)\n\n    # using a hack that when passing \"module==\" w/ no version number to pip\n    # it \"fails\" and returns all the available versions in stderr\n    try:\n        cmd = f\"pip install {module}==\"\n        result = subprocess.run(cmd.split(), shell=False, check=False,\n                                stdout=subprocess.PIPE, stderr=subprocess.PIPE)\n    except Exception as e:\n        print(f\"Error: {e}\")\n        return None\n    else:\n        if result.returncode == 1 and result.stderr:\n            output = result.stderr.decode('utf-8')\n            return True if version in output else False\n        else:\n            print(f\"Some error in {cmd}\")\n            return None\n\ndef check_perf():\n    \"Suggest how to improve the setup to speed things up\"\n\n    from PIL import features, Image\n    from 
packaging import version\n\n    print(\"Running performance checks.\")\n\n    # libjpeg_turbo check\n    print(\"\\n*** libjpeg-turbo status\")\n    if version.parse(Image.PILLOW_VERSION) >= version.parse(\"5.3.9\"):\n        if features.check_feature('libjpeg_turbo'):\n            print(\"✔ libjpeg-turbo is on\")\n        else:\n            print(\"✘ libjpeg-turbo is not on. It's recommended you install libjpeg-turbo to speed up JPEG decoding. See https://docs.fast.ai/performance.html#libjpeg-turbo\")\n    else:\n        print(f\"❓ libjpeg-turbo's status can't be derived - need Pillow(-SIMD)? >= 5.4.0 to tell, current version {Image.PILLOW_VERSION}\")\n        # XXX: remove this check/note once Pillow and Pillow-SIMD 5.4.0 is available\n        pillow_ver_5_4_is_avail = pypi_module_version_is_available(\"Pillow\", \"5.4.0\")\n        if pillow_ver_5_4_is_avail == False:\n            print(\"5.4.0 is not yet available, other than the dev version on github, which can be installed via pip from git+https://github.com/python-pillow/Pillow. See https://docs.fast.ai/performance.html#libjpeg-turbo\")\n\n    # Pillow-SIMD check\n    print(\"\\n*** Pillow-SIMD status\")\n    if re.search(r'\\.post\\d+', Image.PILLOW_VERSION):\n        print(f\"✔ Running Pillow-SIMD {Image.PILLOW_VERSION}\")\n    else:\n        print(f\"✘ Running Pillow {Image.PILLOW_VERSION}; It's recommended you install Pillow-SIMD to speed up image resizing and other operations. 
See https://docs.fast.ai/performance.html#pillow-simd\")\n\n    # CUDA version check\n    # compatibility table: k: min nvidia ver is required for v: cuda ver\n    # note: windows nvidia driver version is slightly higher, see:\n    # https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html\n    # note: add new entries if pytorch starts supporting new cudaXX\n    nvidia2cuda = {\n        \"410.00\": \"10.0\",\n        \"384.81\":  \"9.0\",\n        \"367.48\":  \"8.0\",\n    }\n    print(\"\\n*** CUDA status\")\n    if torch.cuda.is_available():\n        pynvml = load_pynvml_env()\n        nvidia_ver = (pynvml.nvmlSystemGetDriverVersion().decode('utf-8') if platform.system() != \"Darwin\" else \"Cannot be determined on OSX yet\")\n        cuda_ver   = torch.version.cuda\n        max_cuda = \"8.0\"\n        for k in sorted(nvidia2cuda.keys()):\n            if version.parse(nvidia_ver) > version.parse(k): max_cuda = nvidia2cuda[k]\n        if version.parse(str(max_cuda)) <= version.parse(cuda_ver):\n            print(f\"✔ Running the latest CUDA {cuda_ver} with NVIDIA driver {nvidia_ver}\")\n        else:\n            print(f\"✘ You are running pytorch built against cuda {cuda_ver}, your NVIDIA driver {nvidia_ver} supports cuda10. See https://pytorch.org/get-started/locally/ to install pytorch built against the faster CUDA version.\")\n    else:\n        print(f\"❓ Running cpu-only torch version, CUDA check is not relevant\")\n\n    print(\"\\nRefer to https://docs.fast.ai/performance.html to make sense out of these checks and suggestions.\")\n"
  },
  {
    "path": "fastai/utils/ipython.py",
    "content": "\"ipython utils\"\n\nimport os, functools, traceback, gc\n\ndef is_in_ipython():\n    \"Is the code running in the ipython environment (jupyter including)\"\n\n    program_name = os.path.basename(os.getenv('_', ''))\n\n    if ('jupyter-notebook' in program_name or # jupyter-notebook\n        'ipython'          in program_name or # ipython\n        'JPY_PARENT_PID'   in os.environ):    # ipython-notebook\n        return True\n    else:\n        return False\n\nIS_IN_IPYTHON = is_in_ipython()\n\ndef is_in_colab():\n    \"Is the code running in Google Colaboratory?\"\n    if not IS_IN_IPYTHON: return False\n    try:\n        from google import colab\n        return True\n    except: return False\n\nIS_IN_COLAB = is_in_colab()\n\ndef get_ref_free_exc_info():\n    \"Free traceback from references to locals() in each frame to avoid circular reference leading to gc.collect() unable to reclaim memory\"\n    type, val, tb = sys.exc_info()\n    traceback.clear_frames(tb)\n    return (type, val, tb)\n\ndef gpu_mem_restore(func):\n    \"Reclaim GPU RAM if CUDA out of memory happened, or execution was interrupted\"\n    @functools.wraps(func)\n    def wrapper(*args, **kwargs):\n        tb_clear_frames = os.environ.get('FASTAI_TB_CLEAR_FRAMES', None)\n        if not IS_IN_IPYTHON or tb_clear_frames==\"0\":\n            return func(*args, **kwargs)\n\n        try:\n            return func(*args, **kwargs)\n        except Exception as e:\n            if (\"CUDA out of memory\" in str(e) or\n                \"device-side assert triggered\" in str(e) or\n                tb_clear_frames == \"1\"):\n                type, val, tb = get_ref_free_exc_info() # must!\n                gc.collect()\n                if \"device-side assert triggered\" in str(e):\n                    warn(\"\"\"When 'device-side assert triggered' error happens, it's not possible to recover and you must restart the kernel to continue. 
Use os.environ['CUDA_LAUNCH_BLOCKING']=\"1\" before restarting to debug\"\"\")\n                raise type(val).with_traceback(tb) from None\n            else: raise # re-raises the exact last exception\n    return wrapper\n\nclass gpu_mem_restore_ctx():\n    \"context manager to reclaim RAM if an exception happened under ipython\"\n    def __enter__(self): return self\n    def __exit__(self, exc_type, exc_val, exc_tb):\n        if not exc_val: return True\n        traceback.clear_frames(exc_tb)\n        gc.collect()\n        raise exc_type(exc_val).with_traceback(exc_tb) from None\n"
  },
  {
    "path": "fastai/utils/mem.py",
    "content": "\"Utility functions for memory management\"\n\nfrom ..imports.torch import *\nfrom ..core import *\nfrom ..script import *\nimport functools, threading, time\nfrom .pynvml_gate import *\nfrom collections import namedtuple\n\n#is_osx = platform.system() == \"Darwin\"\nuse_gpu = torch.cuda.is_available()\n\nGPUMemory = namedtuple('GPUMemory', ['total', 'free', 'used'])\n\nif use_gpu:\n    pynvml = load_pynvml_env()\n\ndef preload_pytorch():\n    torch.ones((1, 1)).cuda()\n\ndef b2mb(num):\n    \"\"\" convert Bs to MBs and round down \"\"\"\n    return int(num/2**20)\n\ndef gpu_mem_get(id=None):\n    \"get total, used and free memory (in MBs) for gpu `id`. if `id` is not passed, currently selected torch device is used\"\n    if not use_gpu: return GPUMemory(0, 0, 0)\n    if id is None: id = torch.cuda.current_device()\n    try:\n        handle = pynvml.nvmlDeviceGetHandleByIndex(id)\n        info = pynvml.nvmlDeviceGetMemoryInfo(handle)\n        return GPUMemory(*(map(b2mb, [info.total, info.free, info.used])))\n    except:\n        return GPUMemory(0, 0, 0)\n\ndef gpu_mem_get_all():\n    \"get total, used and free memory (in MBs) for each available gpu\"\n    if not use_gpu: return []\n    return list(map(gpu_mem_get, range(pynvml.nvmlDeviceGetCount())))\n\ndef gpu_mem_get_free():\n    \"get free memory (in MBs) for the currently selected gpu id, w/o emptying the cache\"\n    return gpu_mem_get().free\n\ndef gpu_mem_get_free_no_cache():\n    \"get free memory (in MBs) for the currently selected gpu id, after emptying the cache\"\n    torch.cuda.empty_cache()\n    return gpu_mem_get().free\n\ndef gpu_mem_get_used():\n    \"get used memory (in MBs) for the currently selected gpu id, w/o emptying the cache\"\n    return gpu_mem_get().used\n\ndef gpu_mem_get_used_fast(gpu_handle):\n    \"get used memory (in MBs) for the currently selected gpu id, w/o emptying the cache, and needing the `gpu_handle` arg\"\n    info = 
pynvml.nvmlDeviceGetMemoryInfo(gpu_handle)\n    return b2mb(info.used)\n\ndef gpu_mem_get_used_no_cache():\n    \"get used memory (in MBs) for the currently selected gpu id, after emptying the cache\"\n    torch.cuda.empty_cache()\n    return gpu_mem_get().used\n\ndef gpu_with_max_free_mem():\n    \"get [gpu_id, its_free_ram] for the first gpu with highest available RAM\"\n    mem_all = gpu_mem_get_all()\n    if not len(mem_all): return None, 0\n    free_all = np.array([x.free for x in mem_all])\n    id = np.argmax(free_all)\n    return id, free_all[id]\n\nclass GPUMemTrace():\n    \"Trace allocated and peaked GPU memory usage (deltas).\"\n    def __init__(self, silent=False, ctx=None, on_exit_report=True):\n        assert torch.cuda.is_available(), \"pytorch CUDA is required\"\n        self.silent = silent # shortcut to turn off all reports from constructor\n        self.ctx    = ctx    # default context note in report\n        self.on_exit_report = on_exit_report # auto-report on ctx manager exit (default: True)\n        self.start()\n\n    def reset(self):\n        self.used_start = gpu_mem_get_used_no_cache()\n        self.used_peak  = self.used_start\n\n    def data_set(self):\n        # delta_used is the difference between current used mem and used mem at the start\n        self.delta_used = gpu_mem_get_used_no_cache() - self.used_start\n\n        # delta_peaked is the overhead if any. It is calculated as follows:\n        #\n        # 1. The difference between the peak memory and the used memory at the\n        # start is measured:\n        # 2a. If it's negative, then delta_peaked is 0\n        # 2b. 
Otherwise, if used_delta is positive it gets subtracted from delta_peaked\n        # XXX: 2a shouldn't be needed once we have a reliable peak counter\n        self.delta_peaked = self.used_peak - self.used_start\n        if self.delta_peaked < 0: self.delta_peaked = 0\n        elif self.delta_used > 0: self.delta_peaked -= self.delta_used\n\n    def data(self):\n        if self.is_running: self.data_set()\n        return self.delta_used, self.delta_peaked\n\n    def start(self):\n        self.is_running = True\n        self.reset()\n        self.peak_monitor_start()\n\n    def stop(self):\n        self.peak_monitor_stop()\n        self.data_set()\n        self.is_running = False\n\n    def __enter__(self):\n        self.start()\n        return self\n\n    def __exit__(self, *exc):\n        self.stop()\n        if self.on_exit_report: self.report('exit')\n\n    def __del__(self):\n        self.stop()\n\n    def __repr__(self):\n        delta_used, delta_peaked = self.data()\n        return f\"△Used Peaked MB: {delta_used:6,.0f} {delta_peaked:6,.0f}\"\n\n    def _get_ctx(self, subctx=None):\n        \"Return ' (ctx: subctx)' or ' (ctx)' or ' (subctx)' or '' depending on this and constructor arguments\"\n        l = []\n        if self.ctx is not None:      l.append(self.ctx)\n        if subctx is not None:        l.append(subctx)\n        return '' if len(l) == 0 else f\" ({': '.join(l)})\"\n\n    def silent(self, silent=True):\n        self.silent = silent\n\n    def report(self, subctx=None):\n        \"Print delta used+peaked, and an optional context note, which can also be preset in constructor\"\n        if self.silent: return\n        print(f\"{ self.__repr__() }{ self._get_ctx(subctx) }\")\n\n    def report_n_reset(self, subctx=None):\n        \"Print delta used+peaked, and an optional context note. 
Then reset counters\"\n        self.report(subctx)\n        self.reset()\n\n    def peak_monitor_start(self):\n        self.peak_monitoring = True\n\n        # continually sample GPU RAM usage\n        peak_monitor_thread = threading.Thread(target=self.peak_monitor_func)\n        peak_monitor_thread.daemon = True\n        peak_monitor_thread.start()\n\n    def peak_monitor_stop(self):\n        self.peak_monitoring = False\n\n    # XXX: this is an unreliable function, since there is no thread priority\n    # control and it may not run enough or not run at all\n    def peak_monitor_func(self):\n        gpu_handle = pynvml.nvmlDeviceGetHandleByIndex(torch.cuda.current_device())\n        while True:\n            self.used_peak = max(gpu_mem_get_used_fast(gpu_handle), self.used_peak)\n            if not self.peak_monitoring: break\n            time.sleep(0.001) # 1msec\n\ndef gpu_mem_trace(func):\n    \"A decorator that runs `GPUMemTrace` w/ report on func\"\n    @functools.wraps(func)\n    def wrapper(*args, **kwargs):\n        with GPUMemTrace(ctx=func.__qualname__, on_exit_report=True):\n            return func(*args, **kwargs)\n    return wrapper\n\ndef reduce_mem_usage(df):\n    \"\"\" iterate through all the columns of a dataframe and modify the data type\n        to reduce memory usage.\n    \"\"\"\n    start_mem = df.memory_usage().sum() / 1024**2\n    print('Memory usage of dataframe is {:.2f} MB'.format(start_mem))\n\n    #Removed from debugging\n    columns = df.columns\n    #.drop('index')\n\n    for col in columns:\n        col_type = df[col].dtype\n        if str(col_type) != 'category' and col_type != 'datetime64[ns]' and col_type != bool:\n            if col_type != object:\n                c_min = df[col].min()\n                c_max = df[col].max()\n                if str(col_type)[:3] == 'int':\n                    if c_min > np.iinfo(np.int8).min and c_max < np.iinfo(np.int8).max:\n                        df[col] = df[col].astype(np.int8)\n           
         elif c_min > np.iinfo(np.int16).min and c_max < np.iinfo(np.int16).max:\n                        df[col] = df[col].astype(np.int16)\n                    elif c_min > np.iinfo(np.int32).min and c_max < np.iinfo(np.int32).max:\n                        df[col] = df[col].astype(np.int32)\n                    elif c_min > np.iinfo(np.int64).min and c_max < np.iinfo(np.int64).max:\n                        df[col] = df[col].astype(np.int64)\n                else:\n                    #if c_min > np.finfo(np.float16).min and c_max < np.finfo(np.float16).max:\n                        #df[col] = df[col].astype(np.float16)\n                    #Sometimes causes and error and had to remove\n                    if c_min > np.finfo(np.float32).min and c_max < np.finfo(np.float32).max:\n                        df[col] = df[col].astype(np.float32)\n                    else:\n                        print('Error '+col+' Value would be a float64. Disregarding.')\n            else:\n                df[col] = df[col].astype('category')\n\n    end_mem = df.memory_usage().sum() / 1024**2\n    print('Memory usage after optimization is: {:.2f} MB'.format(end_mem))\n    print('Decreased by {:.1f}%'.format(100 * (start_mem - end_mem) / start_mem))\n\n    return df\n"
  },
  {
    "path": "fastai/utils/mod_display.py",
    "content": "\" Utils for modifying what is displayed in notebooks and command line\"\nimport fastai\nimport fastprogress\n\nfrom ..basic_train import *\nfrom ..core import *\n\n__all__ = ['progress_disabled_ctx']\n\nclass progress_disabled_ctx():\n    \"Context manager to disable the progress update bar and Recorder print.\"\n    def __init__(self,learn:Learner):\n        self.learn = learn\n\n    def __enter__(self):\n        #silence progress bar\n        fastprogress.fastprogress.NO_BAR = True\n        fastai.basic_train.master_bar,fastai.basic_train.progress_bar = fastprogress.force_console_behavior()\n        self.orig_callback_fns = copy(self.learn.callback_fns)\n        rec_name = [x for x in self.learn.callback_fns if hasattr(x, 'func') and x.func == Recorder]\n        if len(rec_name):\n            rec_idx = self.learn.callback_fns.index(rec_name[0])\n            self.learn.callback_fns[rec_idx] = partial(Recorder, add_time=True, silent=True) #silence recorder\n        return self.learn\n\n    def __exit__(self, *args):\n        fastai.basic_train.master_bar,fastai.basic_train.progress_bar = master_bar,progress_bar\n        self.learn.callback_fns = self.orig_callback_fns\n"
  },
  {
    "path": "fastai/utils/pynvml_gate.py",
    "content": "\"\"\"Get OS specific nvml wrapper. On OSX we use pynvx as drop in replacement for pynvml\"\"\"\n\nimport platform\nfrom ..script import *\n\n#\n# BEGIN: Temporary workaround for nvml.dll load issue in Win10\n#\n# Remove once nicolargo/nvidia-ml-py3#2 and a new version of the module is released \n# (OR fbcotter/py3nvml#10 but will require extra work to rename things)\n# Refer https://forums.fast.ai/t/nvml-dll-loading-issue-in-nvidia-ml-py3-7-352-0-py-0/39684/8\nimport threading\nfrom ctypes import *\n\nnvmlLib = None\nlibLoadLock = threading.Lock()\n\ndef _LoadNvmlLibrary():\n    '''\n    Load the library if it isn't loaded already\n    '''\n\n    global nvmlLib\n\n    if (nvmlLib == None):\n        libLoadLock.acquire()\n\n        try:\n            if (nvmlLib == None):\n                try:\n                    if (sys.platform[:3] == \"win\"):\n                        searchPaths = [\n                            os.path.join(os.getenv(\"ProgramFiles\", r\"C:\\Program Files\"), r\"NVIDIA Corporation\\NVSMI\\nvml.dll\"),\n                            os.path.join(os.getenv(\"WinDir\", r\"C:\\Windows\"), r\"System32\\nvml.dll\"),\n                        ]\n                        nvmlPath = next((x for x in searchPaths if os.path.isfile(x)), None)\n                        if (nvmlPath == None):\n                            nvmlLib = None\n                        else:\n                            nvmlLib = CDLL(nvmlPath)\n                    else:\n                        nvmlLib = None\n                except OSError as ose:\n                    nvmlLib = None\n        finally:\n            libLoadLock.release()\n#\n# END: Temporary workaround for nvml.dll load issue in Win10\n#\n\ndef load_pynvml_env():\n    import pynvml # nvidia-ml-py3\n\n    #\n    # BEGIN: Temporary workaround for nvml.dll load issue in Win10 (continued)\n    _LoadNvmlLibrary()\n    pynvml.nvmlLib = nvmlLib\n    #\n    # END: Temporary workaround for nvml.dll load issue in 
Win10\n    #\n\n    if platform.system() == \"Darwin\":\n        try:\n            from pynvx import pynvml\n        except:\n            print(\"please install pynvx on OSX: pip install pynvx\")\n            sys.exit(1)\n\n        pynvml.nvmlInit()\n        return pynvml\n\n    pynvml.nvmlInit()\n\n    return pynvml\n"
  },
  {
    "path": "fastai/utils/show_install.py",
    "content": "from ..script import *\nfrom .collect_env import *\n\n# Temporary POC for module-based script\n@call_parse\ndef main(show_nvidia_smi:Param(opt=False, nargs='?', type=bool)=False):\n    return show_install(show_nvidia_smi)\n\n"
  },
  {
    "path": "fastai/version.py",
    "content": "__all__ = ['__version__']\n__version__ = '1.0.56.dev0'\n"
  },
  {
    "path": "fastai/vision/__init__.py",
    "content": "from .. import basics\nfrom ..basics import *\nfrom .learner import *\nfrom .image import *\nfrom .data import *\nfrom .transform import *\nfrom .tta import *\nfrom . import models\n\nfrom .. import vision\n\n__all__ = [*basics.__all__, *learner.__all__, *data.__all__, *image.__all__, *transform.__all__, *tta.__all__, 'models', 'vision']\n\n"
  },
  {
    "path": "fastai/vision/cyclegan.py",
    "content": "from ..torch_core import *\nfrom ..layers import *\nfrom ..callback import *\nfrom ..basic_train import Learner, LearnerCallback\n\n__all__ = ['CycleGAN', 'CycleGanLoss', 'AdaptiveLoss', 'CycleGANTrainer']\n\ndef convT_norm_relu(ch_in:int, ch_out:int, norm_layer:nn.Module, ks:int=3, stride:int=2, bias:bool=True):\n    return [nn.ConvTranspose2d(ch_in, ch_out, kernel_size=ks, stride=stride, padding=1, output_padding=1, bias=bias),\n            norm_layer(ch_out), nn.ReLU(True)]\n\ndef pad_conv_norm_relu(ch_in:int, ch_out:int, pad_mode:str, norm_layer:nn.Module, ks:int=3, bias:bool=True,\n                       pad=1, stride:int=1, activ:bool=True, init:Callable=nn.init.kaiming_normal_)->List[nn.Module]:\n    layers = []\n    if pad_mode == 'reflection': layers.append(nn.ReflectionPad2d(pad))\n    elif pad_mode == 'border':   layers.append(nn.ReplicationPad2d(pad))\n    p = pad if pad_mode == 'zeros' else 0\n    conv = nn.Conv2d(ch_in, ch_out, kernel_size=ks, padding=p, stride=stride, bias=bias)\n    if init:\n        init(conv.weight)\n        if hasattr(conv, 'bias') and hasattr(conv.bias, 'data'): conv.bias.data.fill_(0.)\n    layers += [conv, norm_layer(ch_out)]\n    if activ: layers.append(nn.ReLU(inplace=True))\n    return layers\n\nclass ResnetBlock(Module):\n    def __init__(self, dim:int, pad_mode:str='reflection', norm_layer:nn.Module=None, dropout:float=0., bias:bool=True):\n        assert pad_mode in ['zeros', 'reflection', 'border'], f'padding {pad_mode} not implemented.'\n        norm_layer = ifnone(norm_layer, nn.InstanceNorm2d)\n        layers = pad_conv_norm_relu(dim, dim, pad_mode, norm_layer, bias=bias)\n        if dropout != 0: layers.append(nn.Dropout(dropout))\n        layers += pad_conv_norm_relu(dim, dim, pad_mode, norm_layer, bias=bias, activ=False)\n        self.conv_block = nn.Sequential(*layers)\n\n    def forward(self, x): return x + self.conv_block(x)\n\ndef resnet_generator(ch_in:int, ch_out:int, n_ftrs:int=64, 
norm_layer:nn.Module=None,\n                     dropout:float=0., n_blocks:int=6, pad_mode:str='reflection')->nn.Module:\n    norm_layer = ifnone(norm_layer, nn.InstanceNorm2d)\n    bias = (norm_layer == nn.InstanceNorm2d)\n    layers = pad_conv_norm_relu(ch_in, n_ftrs, 'reflection', norm_layer, pad=3, ks=7, bias=bias)\n    for i in range(2):\n        layers += pad_conv_norm_relu(n_ftrs, n_ftrs *2, 'zeros', norm_layer, stride=2, bias=bias)\n        n_ftrs *= 2\n    layers += [ResnetBlock(n_ftrs, pad_mode, norm_layer, dropout, bias) for _ in range(n_blocks)]\n    for i in range(2):\n        layers += convT_norm_relu(n_ftrs, n_ftrs//2, norm_layer, bias=bias)\n        n_ftrs //= 2\n    layers += [nn.ReflectionPad2d(3), nn.Conv2d(n_ftrs, ch_out, kernel_size=7, padding=0), nn.Tanh()]\n    return nn.Sequential(*layers)\n\ndef conv_norm_lr(ch_in:int, ch_out:int, norm_layer:nn.Module=None, ks:int=3, bias:bool=True, pad:int=1, stride:int=1,\n                 activ:bool=True, slope:float=0.2, init:Callable=nn.init.kaiming_normal_)->List[nn.Module]:\n    conv = nn.Conv2d(ch_in, ch_out, kernel_size=ks, padding=pad, stride=stride, bias=bias)\n    if init:\n        init(conv.weight)\n        if hasattr(conv, 'bias') and hasattr(conv.bias, 'data'): conv.bias.data.fill_(0.)\n    layers = [conv]\n    if norm_layer is not None: layers.append(norm_layer(ch_out))\n    if activ: layers.append(nn.LeakyReLU(slope, inplace=True))\n    return layers\n\ndef critic(ch_in:int, n_ftrs:int=64, n_layers:int=3, norm_layer:nn.Module=None, sigmoid:bool=False)->nn.Module:\n    norm_layer = ifnone(norm_layer, nn.InstanceNorm2d)\n    bias = (norm_layer == nn.InstanceNorm2d)\n    layers = conv_norm_lr(ch_in, n_ftrs, ks=4, stride=2, pad=1)\n    for i in range(n_layers-1):\n        new_ftrs = 2*n_ftrs if i <= 3 else n_ftrs\n        layers += conv_norm_lr(n_ftrs, new_ftrs, norm_layer, ks=4, stride=2, pad=1, bias=bias)\n        n_ftrs = new_ftrs\n    new_ftrs = 2*n_ftrs if n_layers <=3 else n_ftrs\n    
layers += conv_norm_lr(n_ftrs, new_ftrs, norm_layer, ks=4, stride=1, pad=1, bias=bias)\n    layers.append(nn.Conv2d(new_ftrs, 1, kernel_size=4, stride=1, padding=1))\n    if sigmoid: layers.append(nn.Sigmoid())\n    return nn.Sequential(*layers)\n\nclass CycleGAN(Module):\n\n    def __init__(self, ch_in:int, ch_out:int, n_features:int=64, disc_layers:int=3, gen_blocks:int=6, lsgan:bool=True,\n                 drop:float=0., norm_layer:nn.Module=None):\n        self.D_A = critic(ch_in, n_features, disc_layers, norm_layer, sigmoid=not lsgan)\n        self.D_B = critic(ch_in, n_features, disc_layers, norm_layer, sigmoid=not lsgan)\n        self.G_A = resnet_generator(ch_in, ch_out, n_features, norm_layer, drop, gen_blocks)\n        self.G_B = resnet_generator(ch_in, ch_out, n_features, norm_layer, drop, gen_blocks)\n        #G_A: takes real input B and generates fake input A\n        #G_B: takes real input A and generates fake input B\n        #D_A: trained to make the difference between real input A and fake input A\n        #D_B: trained to make the difference between real input B and fake input B\n\n    def forward(self, real_A, real_B):\n        fake_A, fake_B = self.G_A(real_B), self.G_B(real_A)\n        if not self.training: return torch.cat([fake_A[:,None],fake_B[:,None]], 1)\n        idt_A, idt_B = self.G_A(real_A), self.G_B(real_B)\n        return [fake_A, fake_B, idt_A, idt_B]\n\nclass AdaptiveLoss(Module):\n    def __init__(self, crit): self.crit = crit\n\n    def forward(self, output, target:bool):\n        targ = output.new_ones(*output.size()) if target else output.new_zeros(*output.size())\n        return self.crit(output, targ)\n\nclass CycleGanLoss(Module):\n    def __init__(self, cgan:nn.Module, lambda_A:float=10., lambda_B:float=10, lambda_idt:float=0.5, lsgan:bool=True):\n        self.cgan,self.l_A,self.l_B,self.l_idt = cgan,lambda_A,lambda_B,lambda_idt\n        #self.crit = F.mse_loss if lsgan else F.binary_cross_entropy\n        self.crit = 
AdaptiveLoss(F.mse_loss if lsgan else F.binary_cross_entropy)\n\n    def set_input(self, input):\n        self.real_A,self.real_B = input\n\n    def forward(self, output, target):\n        fake_A, fake_B, idt_A, idt_B = output\n        #Generators should return identity on the datasets they try to convert to\n        idt_loss = self.l_idt * (self.l_B * F.l1_loss(idt_A, self.real_B) + self.l_A * F.l1_loss(idt_B, self.real_A))\n        #Generators are trained to trick the critics so the following should be ones\n        gen_loss = self.crit(self.cgan.D_A(fake_A), True) + self.crit(self.cgan.D_B(fake_B), True)\n        #Cycle loss\n        cycle_loss = self.l_A * F.l1_loss(self.cgan.G_A(fake_B), self.real_A)\n        cycle_loss += self.l_B * F.l1_loss(self.cgan.G_B(fake_A), self.real_B)\n        self.metrics = [idt_loss, gen_loss, cycle_loss]\n        return idt_loss + gen_loss + cycle_loss\n\nclass CycleGANTrainer(LearnerCallback):\n    \"`LearnerCallback` that handles cycleGAN Training.\"\n    _order=-20\n    def _set_trainable(self, D_A=False, D_B=False):\n        gen = (not D_A) and (not D_B)\n        requires_grad(self.learn.model.G_A, gen)\n        requires_grad(self.learn.model.G_B, gen)\n        requires_grad(self.learn.model.D_A, D_A)\n        requires_grad(self.learn.model.D_B, D_B)\n        if not gen:\n            self.opt_D_A.lr, self.opt_D_A.mom = self.learn.opt.lr, self.learn.opt.mom\n            self.opt_D_A.wd, self.opt_D_A.beta = self.learn.opt.wd, self.learn.opt.beta\n            self.opt_D_B.lr, self.opt_D_B.mom = self.learn.opt.lr, self.learn.opt.mom\n            self.opt_D_B.wd, self.opt_D_B.beta = self.learn.opt.wd, self.learn.opt.beta\n\n    def on_train_begin(self, **kwargs):\n        \"Create the various optimizers.\"\n        self.G_A,self.G_B = self.learn.model.G_A,self.learn.model.G_B\n        self.D_A,self.D_B = self.learn.model.D_A,self.learn.model.D_B\n        self.crit = self.learn.loss_func.crit\n        self.opt_G = 
self.learn.opt.new([nn.Sequential(*flatten_model(self.G_A), *flatten_model(self.G_B))])\n        self.opt_D_A = self.learn.opt.new([nn.Sequential(*flatten_model(self.D_A))])\n        self.opt_D_B = self.learn.opt.new([nn.Sequential(*flatten_model(self.D_B))])\n        self.learn.opt.opt = self.opt_G.opt\n        self._set_trainable()\n        self.names = ['idt_loss', 'gen_loss', 'cyc_loss', 'da_loss', 'db_loss']\n        self.learn.recorder.no_val=True\n        self.learn.recorder.add_metric_names(self.names)\n        self.smootheners = {n:SmoothenValue(0.98) for n in self.names}\n\n    def on_batch_begin(self, last_input, **kwargs):\n        \"Register the `last_input` in the loss function.\"\n        self.learn.loss_func.set_input(last_input)\n\n    def on_batch_end(self, last_input, last_output, **kwargs):\n        \"Steps through the generators then each of the critics.\"\n        self.G_A.zero_grad(); self.G_B.zero_grad()\n        fake_A, fake_B = last_output[0].detach(), last_output[1].detach()\n        real_A, real_B = last_input\n        self._set_trainable(D_A=True)\n        self.D_A.zero_grad()\n        loss_D_A = 0.5 * (self.crit(self.D_A(real_A), True) + self.crit(self.D_A(fake_A), False))\n        loss_D_A.backward()\n        self.opt_D_A.step()\n        self._set_trainable(D_B=True)\n        self.D_B.zero_grad()\n        loss_D_B = 0.5 * (self.crit(self.D_B(real_B), True) + self.crit(self.D_B(fake_B), False))\n        loss_D_B.backward()\n        self.opt_D_B.step()\n        self._set_trainable()\n        metrics = self.learn.loss_func.metrics + [loss_D_A, loss_D_B]\n        for n,m in zip(self.names,metrics): self.smootheners[n].add_value(m)\n\n    def on_epoch_end(self, last_metrics, **kwargs):\n        \"Put the various losses in the recorder.\"\n        return add_metrics(last_metrics, [s.smooth for k,s in self.smootheners.items()])\n\n"
  },
  {
    "path": "fastai/vision/data.py",
    "content": "\"Manages data input pipeline - folderstransformbatch input. Includes support for classification, segmentation and bounding boxes\"\nfrom numbers import Integral\nfrom ..torch_core import *\nfrom .image import *\nfrom .transform import *\nfrom ..data_block import *\nfrom ..basic_data import *\nfrom ..layers import *\nfrom .learner import *\nfrom torchvision import transforms as tvt\n\n__all__ = ['get_image_files', 'denormalize', 'get_annotations', 'ImageDataBunch',\n           'ImageList', 'normalize', 'normalize_funcs', 'resize_to',\n           'channel_view', 'mnist_stats', 'cifar_stats', 'imagenet_stats', 'imagenet_stats_inception', 'download_images',\n           'verify_images', 'bb_pad_collate', 'ImageImageList', 'PointsLabelList',\n           'ObjectCategoryList', 'ObjectItemList', 'SegmentationLabelList', 'SegmentationItemList', 'PointsItemList']\n\nimage_extensions = set(k for k,v in mimetypes.types_map.items() if v.startswith('image/'))\n\ndef get_image_files(c:PathOrStr, check_ext:bool=True, recurse=False)->FilePathList:\n    \"Return list of files in `c` that are images. 
`check_ext` will filter to `image_extensions`.\"\n    return get_files(c, extensions=(image_extensions if check_ext else None), recurse=recurse)\n\ndef get_annotations(fname, prefix=None):\n    \"Open a COCO style json in `fname` and returns the lists of filenames (with maybe `prefix`) and labelled bboxes.\"\n    annot_dict = json.load(open(fname))\n    id2images, id2bboxes, id2cats = {}, collections.defaultdict(list), collections.defaultdict(list)\n    classes = {}\n    for o in annot_dict['categories']:\n        classes[o['id']] = o['name']\n    for o in annot_dict['annotations']:\n        bb = o['bbox']\n        id2bboxes[o['image_id']].append([bb[1],bb[0], bb[3]+bb[1], bb[2]+bb[0]])\n        id2cats[o['image_id']].append(classes[o['category_id']])\n    for o in annot_dict['images']:\n        if o['id'] in id2bboxes:\n            id2images[o['id']] = ifnone(prefix, '') + o['file_name']\n    ids = list(id2images.keys())\n    return [id2images[k] for k in ids], [[id2bboxes[k], id2cats[k]] for k in ids]\n\ndef bb_pad_collate(samples:BatchSamples, pad_idx:int=0) -> Tuple[FloatTensor, Tuple[LongTensor, LongTensor]]:\n    \"Function that collect `samples` of labelled bboxes and adds padding with `pad_idx`.\"\n    if isinstance(samples[0][1], int): return data_collate(samples)\n    max_len = max([len(s[1].data[1]) for s in samples])\n    bboxes = torch.zeros(len(samples), max_len, 4)\n    labels = torch.zeros(len(samples), max_len).long() + pad_idx\n    imgs = []\n    for i,s in enumerate(samples):\n        imgs.append(s[0].data[None])\n        bbs, lbls = s[1].data\n        if not (bbs.nelement() == 0):\n            bboxes[i,-len(lbls):] = bbs\n            labels[i,-len(lbls):] = tensor(lbls)\n    return torch.cat(imgs,0), (bboxes,labels)\n\ndef normalize(x:TensorImage, mean,std:Tensor)->TensorImage:\n    \"Normalize `x` with `mean` and `std`.\"\n    return (x-mean[...,None,None]) / std[...,None,None]\n\ndef denormalize(x:TensorImage, mean,std:Tensor, 
do_x:bool=True)->TensorImage:\n    \"Denormalize `x` with `mean` and `std`.\"\n    return x.cpu().float()*std[...,None,None] + mean[...,None,None] if do_x else x.cpu()\n\ndef _normalize_batch(b:Tuple[Tensor,Tensor], mean:Tensor, std:Tensor, do_x:bool=True, do_y:bool=False)->Tuple[Tensor,Tensor]:\n    \"`b` = `x`,`y` - normalize `x` array of imgs and `do_y` optionally `y`.\"\n    x,y = b\n    mean,std = mean.to(x.device),std.to(x.device)\n    if do_x: x = normalize(x,mean,std)\n    if do_y and len(y.shape) == 4: y = normalize(y,mean,std)\n    return x,y\n\ndef normalize_funcs(mean:Tensor, std:Tensor, do_x:bool=True, do_y:bool=False)->Tuple[Callable,Callable]:\n    \"Create normalize/denormalize func using `mean` and `std`, can specify `do_y` and `device`.\"\n    mean,std = tensor(mean),tensor(std)\n    return (partial(_normalize_batch, mean=mean, std=std, do_x=do_x, do_y=do_y),\n            partial(denormalize,      mean=mean, std=std, do_x=do_x))\n\ncifar_stats = ([0.491, 0.482, 0.447], [0.247, 0.243, 0.261])\nimagenet_stats = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\nimagenet_stats_inception = ([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])\nmnist_stats = ([0.15]*3, [0.15]*3)\n\ndef channel_view(x:Tensor)->Tensor:\n    \"Make channel the first axis of `x` and flatten remaining axes\"\n    return x.transpose(0,1).contiguous().view(x.shape[1],-1)\n\nclass ImageDataBunch(DataBunch):\n    \"DataBunch suitable for computer vision.\"\n    _square_show = True\n\n    @classmethod\n    def create_from_ll(cls, lls:LabelLists, bs:int=64, val_bs:int=None, ds_tfms:Optional[TfmList]=None,\n                num_workers:int=defaults.cpus, dl_tfms:Optional[Collection[Callable]]=None, device:torch.device=None,\n                test:Optional[PathOrStr]=None, collate_fn:Callable=data_collate, size:int=None, no_check:bool=False,\n                resize_method:ResizeMethod=None, mult:int=None, padding_mode:str='reflection',\n                mode:str='bilinear', 
tfm_y:bool=False)->'ImageDataBunch':\n        \"Create an `ImageDataBunch` from `LabelLists` `lls` with potential `ds_tfms`.\"\n        lls = lls.transform(tfms=ds_tfms, size=size, resize_method=resize_method, mult=mult, padding_mode=padding_mode,\n                            mode=mode, tfm_y=tfm_y)\n        if test is not None: lls.add_test_folder(test)\n        return lls.databunch(bs=bs, val_bs=val_bs, dl_tfms=dl_tfms, num_workers=num_workers, collate_fn=collate_fn,\n                             device=device, no_check=no_check)\n\n    @classmethod\n    def from_folder(cls, path:PathOrStr, train:PathOrStr='train', valid:PathOrStr='valid',\n                    valid_pct=None, seed:int=None, classes:Collection=None, **kwargs:Any)->'ImageDataBunch':\n        \"Create from imagenet style dataset in `path` with `train`,`valid`,`test` subfolders (or provide `valid_pct`).\"\n        path=Path(path)\n        il = ImageList.from_folder(path)\n        if valid_pct is None: src = il.split_by_folder(train=train, valid=valid)\n        else: src = il.split_by_rand_pct(valid_pct, seed)\n        src = src.label_from_folder(classes=classes)\n        return cls.create_from_ll(src, **kwargs)\n\n    @classmethod\n    def from_df(cls, path:PathOrStr, df:pd.DataFrame, folder:PathOrStr=None, label_delim:str=None, valid_pct:float=0.2,\n                seed:int=None, fn_col:IntsOrStrs=0, label_col:IntsOrStrs=1, suffix:str='', **kwargs:Any)->'ImageDataBunch':\n        \"Create from a `DataFrame` `df`.\"\n        src = (ImageList.from_df(df, path=path, folder=folder, suffix=suffix, cols=fn_col)\n                .split_by_rand_pct(valid_pct, seed)\n                .label_from_df(label_delim=label_delim, cols=label_col))\n        return cls.create_from_ll(src, **kwargs)\n\n    @classmethod\n    def from_csv(cls, path:PathOrStr, folder:PathOrStr=None, label_delim:str=None, csv_labels:PathOrStr='labels.csv',\n                 valid_pct:float=0.2, seed:int=None, fn_col:int=0, label_col:int=1, 
suffix:str='', delimiter:str=None,\n                 header:Optional[Union[int,str]]='infer', **kwargs:Any)->'ImageDataBunch':\n        \"Create from a csv file in `path/csv_labels`.\"\n        path = Path(path)\n        df = pd.read_csv(path/csv_labels, header=header, delimiter=delimiter)\n        return cls.from_df(path, df, folder=folder, label_delim=label_delim, valid_pct=valid_pct, seed=seed,\n                fn_col=fn_col, label_col=label_col, suffix=suffix, **kwargs)\n\n    @classmethod\n    def from_lists(cls, path:PathOrStr, fnames:FilePathList, labels:Collection[str], valid_pct:float=0.2, seed:int=None,\n                   item_cls:Callable=None, **kwargs):\n        \"Create from list of `fnames` in `path`.\"\n        item_cls = ifnone(item_cls, ImageList)\n        fname2label = {f:l for (f,l) in zip(fnames, labels)}\n        src = (item_cls(fnames, path=path).split_by_rand_pct(valid_pct, seed)\n                                .label_from_func(lambda x:fname2label[x]))\n        return cls.create_from_ll(src, **kwargs)\n\n    @classmethod\n    def from_name_func(cls, path:PathOrStr, fnames:FilePathList, label_func:Callable, valid_pct:float=0.2, seed:int=None,\n                       **kwargs):\n        \"Create from list of `fnames` in `path` with `label_func`.\"\n        src = ImageList(fnames, path=path).split_by_rand_pct(valid_pct, seed)\n        return cls.create_from_ll(src.label_from_func(label_func), **kwargs)\n\n    @classmethod\n    def from_name_re(cls, path:PathOrStr, fnames:FilePathList, pat:str, valid_pct:float=0.2, **kwargs):\n        \"Create from list of `fnames` in `path` with re expression `pat`.\"\n        pat = re.compile(pat)\n        def _get_label(fn):\n            if isinstance(fn, Path): fn = fn.as_posix()\n            res = pat.search(str(fn))\n            assert res,f'Failed to find \"{pat}\" in \"{fn}\"'\n            return res.group(1)\n        return cls.from_name_func(path, fnames, _get_label, valid_pct=valid_pct, 
**kwargs)\n\n    @staticmethod\n    def single_from_classes(path:Union[Path, str], classes:Collection[str], ds_tfms:TfmList=None, **kwargs):\n        \"Create an empty `ImageDataBunch` in `path` with `classes`. Typically used for inference.\"\n        warn(\"\"\"This method is deprecated and will be removed in a future version, use `load_learner` after\n             `Learner.export()`\"\"\", DeprecationWarning)\n        sd = ImageList([], path=path, ignore_empty=True).split_none()\n        return sd.label_const(0, label_cls=CategoryList, classes=classes).transform(ds_tfms, **kwargs).databunch()\n\n    def batch_stats(self, funcs:Collection[Callable]=None, ds_type:DatasetType=DatasetType.Train)->Tensor:\n        \"Grab a batch of data and call reduction function `func` per channel\"\n        funcs = ifnone(funcs, [torch.mean,torch.std])\n        x = self.one_batch(ds_type=ds_type, denorm=False)[0].cpu()\n        return [func(channel_view(x), 1) for func in funcs]\n\n    def normalize(self, stats:Collection[Tensor]=None, do_x:bool=True, do_y:bool=False)->None:\n        \"Add normalize transform using `stats` (defaults to `DataBunch.batch_stats`)\"\n        if getattr(self,'norm',False): raise Exception('Can not call normalize twice')\n        if stats is None: self.stats = self.batch_stats()\n        else:             self.stats = stats\n        self.norm,self.denorm = normalize_funcs(*self.stats, do_x=do_x, do_y=do_y)\n        self.add_tfm(self.norm)\n        return self\n\ndef download_image(url,dest, timeout=4):\n    try: r = download_url(url, dest, overwrite=True, show_progress=False, timeout=timeout)\n    except Exception as e: print(f\"Error {url} {e}\")\n\ndef _download_image_inner(dest, url, i, timeout=4):\n    suffix = re.findall(r'\\.\\w+?(?=(?:\\?|$))', url)\n    suffix = suffix[0] if len(suffix)>0  else '.jpg'\n    download_image(url, dest/f\"{i:08d}{suffix}\", timeout=timeout)\n\ndef download_images(urls:Collection[str], dest:PathOrStr, 
max_pics:int=1000, max_workers:int=8, timeout=4):\n    \"Download images listed in text file `urls` to path `dest`, at most `max_pics`\"\n    urls = open(urls).read().strip().split(\"\\n\")[:max_pics]\n    dest = Path(dest)\n    dest.mkdir(exist_ok=True)\n    parallel(partial(_download_image_inner, dest, timeout=timeout), urls, max_workers=max_workers)\n\ndef resize_to(img, targ_sz:int, use_min:bool=False):\n    \"Size to resize to, to hit `targ_sz` at same aspect ratio, in PIL coords (i.e w*h)\"\n    w,h = img.size\n    min_sz = (min if use_min else max)(w,h)\n    ratio = targ_sz/min_sz\n    return int(w*ratio),int(h*ratio)\n\ndef verify_image(file:Path, idx:int, delete:bool, max_size:Union[int,Tuple[int,int]]=None, dest:Path=None, n_channels:int=3,\n                 interp=PIL.Image.BILINEAR, ext:str=None, img_format:str=None, resume:bool=False, **kwargs):\n    \"Check if the image in `file` exists, maybe resize it and copy it in `dest`.\"\n    try:\n        # deal with partially broken images as indicated by PIL warnings\n        with warnings.catch_warnings():\n            warnings.filterwarnings('error')\n            try:\n                with open(file, 'rb') as img_file: PIL.Image.open(img_file)\n            except Warning as w:\n                if \"Possibly corrupt EXIF data\" in str(w):\n                    if delete: # green light to modify files\n                        print(f\"{file}: Removing corrupt EXIF data\")\n                        warnings.simplefilter(\"ignore\")\n                        # save EXIF-cleaned up image, which happens automatically\n                        PIL.Image.open(file).save(file)\n                    else: # keep user's files intact\n                        print(f\"{file}: Not removing corrupt EXIF data, pass `delete=True` to do that\")\n                else: warnings.warn(w)\n\n        img = PIL.Image.open(file)\n        imgarr = np.array(img)\n        img_channels = 1 if len(imgarr.shape) == 2 else imgarr.shape[2]\n    
    if (max_size is not None and (img.height > max_size or img.width > max_size)) or img_channels != n_channels:\n            assert isinstance(dest, Path), \"You should provide `dest` Path to save resized image\"\n            dest_fname = dest/file.name\n            if ext is not None: dest_fname=dest_fname.with_suffix(ext)\n            if resume and os.path.isfile(dest_fname): return\n            if max_size is not None:\n                new_sz = resize_to(img, max_size)\n                img = img.resize(new_sz, resample=interp)\n            if n_channels == 3: img = img.convert(\"RGB\")\n            img.save(dest_fname, img_format, **kwargs)\n    except Exception as e:\n        print(f'{e}')\n        if delete: file.unlink()\n\ndef verify_images(path:PathOrStr, delete:bool=True, max_workers:int=4, max_size:Union[int]=None, recurse:bool=False,\n                  dest:PathOrStr='.', n_channels:int=3, interp=PIL.Image.BILINEAR, ext:str=None, img_format:str=None,\n                  resume:bool=None, **kwargs):\n    \"Check if the images in `path` aren't broken, maybe resize them and copy it in `dest`.\"\n    path = Path(path)\n    if resume is None and dest == '.': resume=False\n    dest = path/Path(dest)\n    os.makedirs(dest, exist_ok=True)\n    files = get_image_files(path, recurse=recurse)\n    func = partial(verify_image, delete=delete, max_size=max_size, dest=dest, n_channels=n_channels, interp=interp,\n                   ext=ext, img_format=img_format, resume=resume, **kwargs)\n    parallel(func, files, max_workers=max_workers)\n\nclass ImageList(ItemList):\n    \"`ItemList` suitable for computer vision.\"\n    _bunch,_square_show,_square_show_res = ImageDataBunch,True,True\n    def __init__(self, *args, convert_mode='RGB', after_open:Callable=None, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.convert_mode,self.after_open = convert_mode,after_open\n        self.copy_new += ['convert_mode', 'after_open']\n        self.c,self.sizes = 
3,{}\n\n    def open(self, fn):\n        \"Open image in `fn`, subclass and overwrite for custom behavior.\"\n        return open_image(fn, convert_mode=self.convert_mode, after_open=self.after_open)\n\n    def get(self, i):\n        fn = super().get(i)\n        res = self.open(fn)\n        self.sizes[i] = res.size\n        return res\n    \n    @classmethod\n    def from_folder(cls, path:PathOrStr='.', extensions:Collection[str]=None, **kwargs)->ItemList:\n        \"Get the list of files in `path` that have an image suffix. `recurse` determines if we search subfolders.\"\n        extensions = ifnone(extensions, image_extensions)\n        return super().from_folder(path=path, extensions=extensions, **kwargs)\n\n    @classmethod\n    def from_df(cls, df:DataFrame, path:PathOrStr, cols:IntsOrStrs=0, folder:PathOrStr=None, suffix:str='', **kwargs)->'ItemList':\n        \"Get the filenames in `cols` of `df` with `folder` in front of them, `suffix` at the end.\"\n        suffix = suffix or ''\n        res = super().from_df(df, path=path, cols=cols, **kwargs)\n        pref = f'{res.path}{os.path.sep}'\n        if folder is not None: pref += f'{folder}{os.path.sep}'\n        res.items = np.char.add(np.char.add(pref, res.items.astype(str)), suffix)\n        return res\n\n    @classmethod\n    def from_csv(cls, path:PathOrStr, csv_name:str, header:str='infer', delimiter:str=None, **kwargs)->'ItemList':\n        \"Get the filenames in `path/csv_name` opened with `header`.\"\n        path = Path(path)\n        df = pd.read_csv(path/csv_name, header=header, delimiter=delimiter)\n        return cls.from_df(df, path=path, **kwargs)\n\n    def reconstruct(self, t:Tensor): return Image(t.float().clamp(min=0,max=1))\n\n    def show_xys(self, xs, ys, imgsize:int=4, figsize:Optional[Tuple[int,int]]=None, **kwargs):\n        \"Show the `xs` (inputs) and `ys` (targets) on a figure of `figsize`.\"\n        rows = int(np.ceil(math.sqrt(len(xs))))\n        axs = subplots(rows, rows, 
imgsize=imgsize, figsize=figsize)\n        for x,y,ax in zip(xs, ys, axs.flatten()): x.show(ax=ax, y=y, **kwargs)\n        for ax in axs.flatten()[len(xs):]: ax.axis('off')\n        plt.tight_layout()\n\n    def show_xyzs(self, xs, ys, zs, imgsize:int=4, figsize:Optional[Tuple[int,int]]=None, **kwargs):\n        \"Show `xs` (inputs), `ys` (targets) and `zs` (predictions) on a figure of `figsize`.\"\n        if self._square_show_res:\n            title = 'Ground truth\\nPredictions'\n            rows = int(np.ceil(math.sqrt(len(xs))))\n            axs = subplots(rows, rows, imgsize=imgsize, figsize=figsize, title=title, weight='bold', size=12)\n            for x,y,z,ax in zip(xs,ys,zs,axs.flatten()): x.show(ax=ax, title=f'{str(y)}\\n{str(z)}', **kwargs)\n            for ax in axs.flatten()[len(xs):]: ax.axis('off')\n        else:\n            title = 'Ground truth/Predictions'\n            axs = subplots(len(xs), 2, imgsize=imgsize, figsize=figsize, title=title, weight='bold', size=14)\n            for i,(x,y,z) in enumerate(zip(xs,ys,zs)):\n                x.show(ax=axs[i,0], y=y, **kwargs)\n                x.show(ax=axs[i,1], y=z, **kwargs)\n\nclass ObjectCategoryProcessor(MultiCategoryProcessor):\n    \"`PreProcessor` for labelled bounding boxes.\"\n    def __init__(self, ds:ItemList, pad_idx:int=0):\n        super().__init__(ds)\n        self.pad_idx = pad_idx\n        self.state_attrs.append('pad_idx')\n\n    def process(self, ds:ItemList):\n        ds.pad_idx = self.pad_idx\n        super().process(ds)\n\n    def process_one(self,item): return [item[0], [self.c2i.get(o,None) for o in item[1]]]\n\n    def generate_classes(self, items):\n        \"Generate classes from unique `items` and add `background`.\"\n        classes = super().generate_classes([o[1] for o in items])\n        classes = ['background'] + list(classes)\n        return classes\n\ndef _get_size(xs,i):\n    size = xs.sizes.get(i,None)\n    if size is None:\n        # Image hasn't been accessed 
yet, so we don't know its size\n        _ = xs[i]\n        size = xs.sizes[i]\n    return size\n\nclass ObjectCategoryList(MultiCategoryList):\n    \"`ItemList` for labelled bounding boxes.\"\n    _processor = ObjectCategoryProcessor\n\n    def get(self, i):\n        return ImageBBox.create(*_get_size(self.x,i), *self.items[i], classes=self.classes, pad_idx=self.pad_idx)\n\n    def analyze_pred(self, pred): return pred\n\n    def reconstruct(self, t, x):\n        (bboxes, labels) = t\n        if len((labels - self.pad_idx).nonzero()) == 0: return\n        i = (labels - self.pad_idx).nonzero().min()\n        bboxes,labels = bboxes[i:],labels[i:]\n        return ImageBBox.create(*x.size, bboxes, labels=labels, classes=self.classes, scale=False)\n\nclass ObjectItemList(ImageList):\n    \"`ItemList` suitable for object detection.\"\n    _label_cls,_square_show_res = ObjectCategoryList,False\n\nclass SegmentationProcessor(PreProcessor):\n    \"`PreProcessor` that stores the classes for segmentation.\"\n    def __init__(self, ds:ItemList): self.classes = ds.classes\n    def process(self, ds:ItemList):  ds.classes,ds.c = self.classes,len(self.classes)\n\nclass SegmentationLabelList(ImageList):\n    \"`ItemList` for segmentation masks.\"\n    _processor=SegmentationProcessor\n    def __init__(self, items:Iterator, classes:Collection=None, **kwargs):\n        super().__init__(items, **kwargs)\n        self.copy_new.append('classes')\n        self.classes,self.loss_func = classes,CrossEntropyFlat(axis=1)\n\n    def open(self, fn): return open_mask(fn)\n    def analyze_pred(self, pred, thresh:float=0.5): return pred.argmax(dim=0)[None]\n    def reconstruct(self, t:Tensor): return ImageSegment(t)\n\nclass SegmentationItemList(ImageList):\n    \"`ItemList` suitable for segmentation tasks.\"\n    _label_cls,_square_show_res = SegmentationLabelList,False\n\nclass PointsProcessor(PreProcessor):\n    \"`PreProcessor` that stores the number of targets for point regression.\"\n    
def __init__(self, ds:ItemList): self.c = len(ds.items[0].reshape(-1))\n    def process(self, ds:ItemList):  ds.c = self.c\n\nclass PointsLabelList(ItemList):\n    \"`ItemList` for points.\"\n    _processor = PointsProcessor\n    def __init__(self, items:Iterator, **kwargs):\n        super().__init__(items, **kwargs)\n        self.loss_func = MSELossFlat()\n\n    def get(self, i):\n        o = super().get(i)\n        return ImagePoints(FlowField(_get_size(self.x,i), o), scale=True)\n\n    def analyze_pred(self, pred, thresh:float=0.5): return pred.view(-1,2)\n    def reconstruct(self, t, x): return ImagePoints(FlowField(x.size, t), scale=False)\n\nclass PointsItemList(ImageList):\n    \"`ItemList` for `Image` to `ImagePoints` tasks.\"\n    _label_cls,_square_show_res = PointsLabelList,False\n\nclass ImageImageList(ImageList):\n    \"`ItemList` suitable for `Image` to `Image` tasks.\"\n    _label_cls,_square_show,_square_show_res = ImageList,False,False\n\n    def show_xys(self, xs, ys, imgsize:int=4, figsize:Optional[Tuple[int,int]]=None, **kwargs):\n        \"Show the `xs` (inputs) and `ys`(targets)  on a figure of `figsize`.\"\n        axs = subplots(len(xs), 2, imgsize=imgsize, figsize=figsize)\n        for i, (x,y) in enumerate(zip(xs,ys)):\n            x.show(ax=axs[i,0], **kwargs)\n            y.show(ax=axs[i,1], **kwargs)\n        plt.tight_layout()\n\n    def show_xyzs(self, xs, ys, zs, imgsize:int=4, figsize:Optional[Tuple[int,int]]=None, **kwargs):\n        \"Show `xs` (inputs), `ys` (targets) and `zs` (predictions) on a figure of `figsize`.\"\n        title = 'Input / Prediction / Target'\n        axs = subplots(len(xs), 3, imgsize=imgsize, figsize=figsize, title=title, weight='bold', size=14)\n        for i,(x,y,z) in enumerate(zip(xs,ys,zs)):\n            x.show(ax=axs[i,0], **kwargs)\n            y.show(ax=axs[i,2], **kwargs)\n            z.show(ax=axs[i,1], **kwargs)\n\n\ndef _ll_pre_transform(self, train_tfm:List[Callable], 
valid_tfm:List[Callable]):\n    \"Call `train_tfm` and `valid_tfm` after opening image, before converting from `PIL.Image`\"\n    self.train.x.after_open = compose(train_tfm)\n    self.valid.x.after_open = compose(valid_tfm)\n    return self\n\ndef _db_pre_transform(self, train_tfm:List[Callable], valid_tfm:List[Callable]):\n    \"Call `train_tfm` and `valid_tfm` after opening image, before converting from `PIL.Image`\"\n    self.train_ds.x.after_open = compose(train_tfm)\n    self.valid_ds.x.after_open = compose(valid_tfm)\n    return self\n\ndef _presize(self, size:int, val_xtra_size:int=32, scale:Tuple[float]=(0.08, 1.0), ratio:Tuple[float]=(0.75, 4./3.),\n             interpolation:int=2):\n    \"Resize images to `size` using `RandomResizedCrop`, passing along `kwargs` to train transform\"\n    return self.pre_transform(\n        tvt.RandomResizedCrop(size, scale=scale, ratio=ratio, interpolation=interpolation), \n        [tvt.Resize(size+val_xtra_size), tvt.CenterCrop(size)])\n\nLabelLists.pre_transform = _ll_pre_transform\nDataBunch.pre_transform = _db_pre_transform\nLabelLists.presize = _presize\nDataBunch.presize = _presize\n\n"
  },
  {
    "path": "fastai/vision/gan.py",
    "content": "from ..torch_core import *\nfrom ..layers import *\nfrom ..callback import *\nfrom ..basic_data import *\nfrom ..basic_train import Learner, LearnerCallback\nfrom .image import Image\nfrom .data import ImageList\n\n__all__ = ['basic_critic', 'basic_generator', 'GANModule', 'GANLoss', 'GANTrainer', 'FixedGANSwitcher', 'AdaptiveGANSwitcher',\n           'GANLearner', 'NoisyItem', 'GANItemList', 'gan_critic', 'AdaptiveLoss', 'accuracy_thresh_expand',\n           'GANDiscriminativeLR']\n\ndef AvgFlatten():\n    \"Takes the average of the input.\"\n    return Lambda(lambda x: x.mean(0).view(1))\n\ndef basic_critic(in_size:int, n_channels:int, n_features:int=64, n_extra_layers:int=0, **conv_kwargs):\n    \"A basic critic for images `n_channels` x `in_size` x `in_size`.\"\n    layers = [conv_layer(n_channels, n_features, 4, 2, 1, leaky=0.2, norm_type=None, **conv_kwargs)]#norm_type=None?\n    cur_size, cur_ftrs = in_size//2, n_features\n    layers.append(nn.Sequential(*[conv_layer(cur_ftrs, cur_ftrs, 3, 1, leaky=0.2, **conv_kwargs) for _ in range(n_extra_layers)]))\n    while cur_size > 4:\n        layers.append(conv_layer(cur_ftrs, cur_ftrs*2, 4, 2, 1, leaky=0.2, **conv_kwargs))\n        cur_ftrs *= 2 ; cur_size //= 2\n    layers += [conv2d(cur_ftrs, 1, 4, padding=0), AvgFlatten()]\n    return nn.Sequential(*layers)\n\ndef basic_generator(in_size:int, n_channels:int, noise_sz:int=100, n_features:int=64, n_extra_layers=0, **conv_kwargs):\n    \"A basic generator from `noise_sz` to images `n_channels` x `in_size` x `in_size`.\"\n    cur_size, cur_ftrs = 4, n_features//2\n    while cur_size < in_size:  cur_size *= 2; cur_ftrs *= 2\n    layers = [conv_layer(noise_sz, cur_ftrs, 4, 1, transpose=True, **conv_kwargs)]\n    cur_size = 4\n    while cur_size < in_size // 2:\n        layers.append(conv_layer(cur_ftrs, cur_ftrs//2, 4, 2, 1, transpose=True, **conv_kwargs))\n        cur_ftrs //= 2; cur_size *= 2\n    layers += [conv_layer(cur_ftrs, cur_ftrs, 3, 1, 1, 
transpose=True, **conv_kwargs) for _ in range(n_extra_layers)]\n    layers += [conv2d_trans(cur_ftrs, n_channels, 4, 2, 1, bias=False), nn.Tanh()]\n    return nn.Sequential(*layers)\n\nclass GANModule(Module):\n    \"Wrapper around a `generator` and a `critic` to create a GAN.\"\n    def __init__(self, generator:nn.Module=None, critic:nn.Module=None, gen_mode:bool=False):\n        self.gen_mode = gen_mode\n        if generator: self.generator,self.critic = generator,critic\n\n    def forward(self, *args):\n        return self.generator(*args) if self.gen_mode else self.critic(*args)\n\n    def switch(self, gen_mode:bool=None):\n        \"Put the model in generator mode if `gen_mode`, in critic mode otherwise.\"\n        self.gen_mode = (not self.gen_mode) if gen_mode is None else gen_mode\n\nclass GANLoss(GANModule):\n    \"Wrapper around `loss_funcC` (for the critic) and `loss_funcG` (for the generator).\"\n    def __init__(self, loss_funcG:Callable, loss_funcC:Callable, gan_model:GANModule):\n        super().__init__()\n        self.loss_funcG,self.loss_funcC,self.gan_model = loss_funcG,loss_funcC,gan_model\n\n    def generator(self, output, target):\n        \"Evaluate the `output` with the critic then uses `self.loss_funcG` to combine it with `target`.\"\n        fake_pred = self.gan_model.critic(output)\n        return self.loss_funcG(fake_pred, target, output)\n\n    def critic(self, real_pred, input):\n        \"Create some `fake_pred` with the generator from `input` and compare them to `real_pred` in `self.loss_funcD`.\"\n        fake = self.gan_model.generator(input.requires_grad_(False)).requires_grad_(True)\n        fake_pred = self.gan_model.critic(fake)\n        return self.loss_funcC(real_pred, fake_pred)\n\nclass GANTrainer(LearnerCallback):\n    \"Handles GAN Training.\"\n    _order=-20\n    def __init__(self, learn:Learner, switch_eval:bool=False, clip:float=None, beta:float=0.98, gen_first:bool=False,\n                 show_img:bool=True):\n       
 super().__init__(learn)\n        self.switch_eval,self.clip,self.beta,self.gen_first,self.show_img = switch_eval,clip,beta,gen_first,show_img\n        self.generator,self.critic = self.model.generator,self.model.critic\n\n    def _set_trainable(self):\n        train_model = self.generator if     self.gen_mode else self.critic\n        loss_model  = self.generator if not self.gen_mode else self.critic\n        requires_grad(train_model, True)\n        requires_grad(loss_model, False)\n        if self.switch_eval:\n            train_model.train()\n            loss_model.eval()\n\n    def on_train_begin(self, **kwargs):\n        \"Create the optimizers for the generator and critic if necessary, initialize smootheners.\"\n        if not getattr(self,'opt_gen',None):\n            self.opt_gen = self.opt.new([nn.Sequential(*flatten_model(self.generator))])\n        else: self.opt_gen.lr,self.opt_gen.wd = self.opt.lr,self.opt.wd\n        if not getattr(self,'opt_critic',None):\n            self.opt_critic = self.opt.new([nn.Sequential(*flatten_model(self.critic))])\n        else: self.opt_critic.lr,self.opt_critic.wd = self.opt.lr,self.opt.wd\n        self.gen_mode = self.gen_first\n        self.switch(self.gen_mode)\n        self.closses,self.glosses = [],[]\n        self.smoothenerG,self.smoothenerC = SmoothenValue(self.beta),SmoothenValue(self.beta)\n        #self.recorder.no_val=True\n        self.recorder.add_metric_names(['gen_loss', 'disc_loss'])\n        self.imgs,self.titles = [],[]\n\n    def on_train_end(self, **kwargs):\n        \"Switch in generator mode for showing results.\"\n        self.switch(gen_mode=True)\n\n    def on_batch_begin(self, last_input, last_target, **kwargs):\n        \"Clamp the weights with `self.clip` if it's not None, return the correct input.\"\n        if self.clip is not None:\n            for p in self.critic.parameters(): p.data.clamp_(-self.clip, self.clip)\n        return {'last_input':last_input,'last_target':last_target} if 
self.gen_mode else {'last_input':last_target,'last_target':last_input}\n\n    def on_backward_begin(self, last_loss, last_output, **kwargs):\n        \"Record `last_loss` in the proper list.\"\n        last_loss = last_loss.detach().cpu()\n        if self.gen_mode:\n            self.smoothenerG.add_value(last_loss)\n            self.glosses.append(self.smoothenerG.smooth)\n            self.last_gen = last_output.detach().cpu()\n        else:\n            self.smoothenerC.add_value(last_loss)\n            self.closses.append(self.smoothenerC.smooth)\n\n    def on_epoch_begin(self, epoch, **kwargs):\n        \"Put the critic or the generator back to eval if necessary.\"\n        self.switch(self.gen_mode)\n\n    def on_epoch_end(self, pbar, epoch, last_metrics, **kwargs):\n        \"Put the various losses in the recorder and show a sample image.\"\n        if not hasattr(self, 'last_gen') or not self.show_img: return\n        data = self.learn.data\n        img = self.last_gen[0]\n        norm = getattr(data,'norm',False)\n        if norm and norm.keywords.get('do_y',False): img = data.denorm(img)\n        img = data.train_ds.y.reconstruct(img)\n        self.imgs.append(img)\n        self.titles.append(f'Epoch {epoch}')\n        pbar.show_imgs(self.imgs, self.titles)\n        return add_metrics(last_metrics, [getattr(self.smoothenerG,'smooth',None),getattr(self.smoothenerC,'smooth',None)])\n\n    def switch(self, gen_mode:bool=None):\n        \"Switch the model, if `gen_mode` is provided, in the desired mode.\"\n        self.gen_mode = (not self.gen_mode) if gen_mode is None else gen_mode\n        self.opt.opt = self.opt_gen.opt if self.gen_mode else self.opt_critic.opt\n        self._set_trainable()\n        self.model.switch(gen_mode)\n        self.loss_func.switch(gen_mode)\n\nclass FixedGANSwitcher(LearnerCallback):\n    \"Switcher to do `n_crit` iterations of the critic then `n_gen` iterations of the generator.\"\n    def __init__(self, learn:Learner, 
n_crit:Union[int,Callable]=1, n_gen:Union[int,Callable]=1):\n        super().__init__(learn)\n        self.n_crit,self.n_gen = n_crit,n_gen\n\n    def on_train_begin(self, **kwargs):\n        \"Initiate the iteration counts.\"\n        self.n_c,self.n_g = 0,0\n\n    def on_batch_end(self, iteration, **kwargs):\n        \"Switch the model if necessary.\"\n        if self.learn.gan_trainer.gen_mode:\n            self.n_g += 1\n            n_iter,n_in,n_out = self.n_gen,self.n_c,self.n_g\n        else:\n            self.n_c += 1\n            n_iter,n_in,n_out = self.n_crit,self.n_g,self.n_c\n        target = n_iter if isinstance(n_iter, int) else n_iter(n_in)\n        if target == n_out:\n            self.learn.gan_trainer.switch()\n            self.n_c,self.n_g = 0,0\n\n@dataclass\nclass AdaptiveGANSwitcher(LearnerCallback):\n    \"Switcher that goes back to generator/critic when the loss goes below `gen_thresh`/`crit_thresh`.\"\n    def __init__(self, learn:Learner, gen_thresh:float=None, critic_thresh:float=None):\n        super().__init__(learn)\n        self.gen_thresh,self.critic_thresh = gen_thresh,critic_thresh\n\n    def on_batch_end(self, last_loss, **kwargs):\n        \"Switch the model if necessary.\"\n        if self.gan_trainer.gen_mode:\n            if self.gen_thresh  is None:      self.gan_trainer.switch()\n            elif last_loss < self.gen_thresh: self.gan_trainer.switch()\n        else:\n            if self.critic_thresh is None:       self.gan_trainer.switch()\n            elif last_loss < self.critic_thresh: self.gan_trainer.switch()\n\ndef gan_loss_from_func(loss_gen, loss_crit, weights_gen:Tuple[float,float]=None):\n    \"Define loss functions for a GAN from `loss_gen` and `loss_crit`.\"\n    def _loss_G(fake_pred, output, target, weights_gen=weights_gen):\n        ones = fake_pred.new_ones(fake_pred.shape[0])\n        weights_gen = ifnone(weights_gen, (1.,1.))\n        return weights_gen[0] * loss_crit(fake_pred, ones) + weights_gen[1] * 
loss_gen(output, target)\n\n    def _loss_C(real_pred, fake_pred):\n        ones  = real_pred.new_ones (real_pred.shape[0])\n        zeros = fake_pred.new_zeros(fake_pred.shape[0])\n        return (loss_crit(real_pred, ones) + loss_crit(fake_pred, zeros)) / 2\n\n    return _loss_G, _loss_C\n\nclass GANLearner(Learner):\n    \"A `Learner` suitable for GANs.\"\n    def __init__(self, data:DataBunch, generator:nn.Module, critic:nn.Module, gen_loss_func:LossFunction,\n                 crit_loss_func:LossFunction, switcher:Callback=None, gen_first:bool=False, switch_eval:bool=True,\n                 show_img:bool=True, clip:float=None, **learn_kwargs):\n        gan = GANModule(generator, critic)\n        loss_func = GANLoss(gen_loss_func, crit_loss_func, gan)\n        switcher = ifnone(switcher, partial(FixedGANSwitcher, n_crit=5, n_gen=1))\n        super().__init__(data, gan, loss_func=loss_func, callback_fns=[switcher], **learn_kwargs)\n        trainer = GANTrainer(self, clip=clip, switch_eval=switch_eval, show_img=show_img)\n        self.gan_trainer = trainer\n        self.callbacks.append(trainer)\n\n    @classmethod\n    def from_learners(cls, learn_gen:Learner, learn_crit:Learner, switcher:Callback=None,\n                      weights_gen:Tuple[float,float]=None, **learn_kwargs):\n        \"Create a GAN from `learn_gen` and `learn_crit`.\"\n        losses = gan_loss_from_func(learn_gen.loss_func, learn_crit.loss_func, weights_gen=weights_gen)\n        return cls(learn_gen.data, learn_gen.model, learn_crit.model, *losses, switcher=switcher, **learn_kwargs)\n\n    @classmethod\n    def wgan(cls, data:DataBunch, generator:nn.Module, critic:nn.Module, switcher:Callback=None, clip:float=0.01, **learn_kwargs):\n        \"Create a WGAN from `data`, `generator` and `critic`.\"\n        return cls(data, generator, critic, NoopLoss(), WassersteinLoss(), switcher=switcher, clip=clip, **learn_kwargs)\n\nclass NoisyItem(ItemBase):\n    \"An random `ItemBase` of size 
`noise_sz`.\"\n    def __init__(self, noise_sz): self.obj,self.data = noise_sz,torch.randn(noise_sz, 1, 1)\n    def __str__(self):  return ''\n    def apply_tfms(self, tfms, **kwargs): return self\n\nclass GANItemList(ImageList):\n    \"`ItemList` suitable for GANs.\"\n    _label_cls = ImageList\n\n    def __init__(self, items, noise_sz:int=100, **kwargs):\n        super().__init__(items, **kwargs)\n        self.noise_sz = noise_sz\n        self.copy_new.append('noise_sz')\n\n    def get(self, i): return NoisyItem(self.noise_sz)\n    def reconstruct(self, t): return NoisyItem(t.size(0))\n\n    def show_xys(self, xs, ys, imgsize:int=4, figsize:Optional[Tuple[int,int]]=None, **kwargs):\n        \"Shows `ys` (target images) on a figure of `figsize`.\"\n        super().show_xys(ys, xs, imgsize=imgsize, figsize=figsize, **kwargs)\n\n    def show_xyzs(self, xs, ys, zs, imgsize:int=4, figsize:Optional[Tuple[int,int]]=None, **kwargs):\n        \"Shows `zs` (generated images) on a figure of `figsize`.\"\n        super().show_xys(zs, xs, imgsize=imgsize, figsize=figsize, **kwargs)\n\n_conv_args = dict(leaky=0.2, norm_type=NormType.Spectral)\n\ndef _conv(ni:int, nf:int, ks:int=3, stride:int=1, **kwargs):\n    return conv_layer(ni, nf, ks=ks, stride=stride, **_conv_args, **kwargs)\n\ndef gan_critic(n_channels:int=3, nf:int=128, n_blocks:int=3, p:int=0.15):\n    \"Critic to train a `GAN`.\"\n    layers = [\n        _conv(n_channels, nf, ks=4, stride=2),\n        nn.Dropout2d(p/2),\n        res_block(nf, dense=True,**_conv_args)]\n    nf *= 2 # after dense block\n    for i in range(n_blocks):\n        layers += [\n            nn.Dropout2d(p),\n            _conv(nf, nf*2, ks=4, stride=2, self_attention=(i==0))]\n        nf *= 2\n    layers += [\n        _conv(nf, 1, ks=4, bias=False, padding=0, use_activ=False),\n        Flatten()]\n    return nn.Sequential(*layers)\n\nclass GANDiscriminativeLR(LearnerCallback):\n    \"`Callback` that handles multiplying the learning rate by 
`mult_lr` for the critic.\"\n    def __init__(self, learn:Learner, mult_lr:float = 5.):\n        super().__init__(learn)\n        self.mult_lr = mult_lr\n\n    def on_batch_begin(self, train, **kwargs):\n        \"Multiply the current lr if necessary.\"\n        if not self.learn.gan_trainer.gen_mode and train: self.learn.opt.lr *= self.mult_lr\n\n    def on_step_end(self, **kwargs):\n        \"Put the LR back to its value if necessary.\"\n        if not self.learn.gan_trainer.gen_mode: self.learn.opt.lr /= self.mult_lr\n\nclass AdaptiveLoss(Module):\n    \"Expand the `target` to match the `output` size before applying `crit`.\"\n    def __init__(self, crit):\n        self.crit = crit\n\n    def forward(self, output, target):\n        return self.crit(output, target[:,None].expand_as(output).float())\n\ndef accuracy_thresh_expand(y_pred:Tensor, y_true:Tensor, thresh:float=0.5, sigmoid:bool=True)->Rank0Tensor:\n    \"Compute accuracy after expanding `y_true` to the size of `y_pred`.\"\n    if sigmoid: y_pred = y_pred.sigmoid()\n    return ((y_pred>thresh)==y_true[:,None].expand_as(y_pred).byte()).float().mean()\n"
  },
  {
    "path": "fastai/vision/image.py",
    "content": "\"`Image` provides support to convert, transform and show images\"\nfrom ..torch_core import *\nfrom ..basic_data import *\nfrom ..layers import MSELossFlat\nfrom io import BytesIO\nimport PIL\n\n__all__ = ['PIL', 'Image', 'ImageBBox', 'ImageSegment', 'ImagePoints', 'FlowField', 'RandTransform', 'TfmAffine', 'TfmCoord',\n           'TfmCrop', 'TfmLighting', 'TfmPixel', 'Transform', 'bb2hw', 'image2np', 'open_image', 'open_mask', 'tis2hw',\n           'pil2tensor', 'scale_flow', 'show_image', 'CoordFunc', 'TfmList', 'open_mask_rle', 'rle_encode',\n           'rle_decode', 'ResizeMethod', 'plot_flat', 'plot_multi', 'show_multi', 'show_all']\n\nResizeMethod = IntEnum('ResizeMethod', 'CROP PAD SQUISH NO')\ndef pil2tensor(image:Union[NPImage,NPArray],dtype:np.dtype)->TensorImage:\n    \"Convert PIL style `image` array to torch style image tensor.\"\n    a = np.asarray(image)\n    if a.ndim==2 : a = np.expand_dims(a,2)\n    a = np.transpose(a, (1, 0, 2))\n    a = np.transpose(a, (2, 1, 0))\n    return torch.from_numpy(a.astype(dtype, copy=False) )\n\ndef image2np(image:Tensor)->np.ndarray:\n    \"Convert from torch style `image` to numpy/matplotlib style.\"\n    res = image.cpu().permute(1,2,0).numpy()\n    return res[...,0] if res.shape[2]==1 else res\n\ndef bb2hw(a:Collection[int])->np.ndarray:\n    \"Convert bounding box points from (width,height,center) to (height,width,top,left).\"\n    return np.array([a[1],a[0],a[3]-a[1],a[2]-a[0]])\n\ndef tis2hw(size:Union[int,TensorImageSize]) -> Tuple[int,int]:\n    \"Convert `int` or `TensorImageSize` to (height,width) of an image.\"\n    if type(size) is str: raise RuntimeError(\"Expected size to be an int or a tuple, got a string.\")\n    return listify(size, 2) if isinstance(size, int) else listify(size[-2:],2)\n\ndef _draw_outline(o:Patch, lw:int):\n    \"Outline bounding box onto image `Patch`.\"\n    o.set_path_effects([patheffects.Stroke(\n        linewidth=lw, foreground='black'), 
patheffects.Normal()])\n\ndef _draw_rect(ax:plt.Axes, b:Collection[int], color:str='white', text=None, text_size=14):\n    \"Draw bounding box on `ax`.\"\n    patch = ax.add_patch(patches.Rectangle(b[:2], *b[-2:], fill=False, edgecolor=color, lw=2))\n    _draw_outline(patch, 4)\n    if text is not None:\n        patch = ax.text(*b[:2], text, verticalalignment='top', color=color, fontsize=text_size, weight='bold')\n        _draw_outline(patch,1)\n\ndef _get_default_args(func:Callable):\n    return {k: v.default\n            for k, v in inspect.signature(func).parameters.items()\n            if v.default is not inspect.Parameter.empty}\n\n@dataclass\nclass FlowField():\n    \"Wrap together some coords `flow` with a `size`.\"\n    size:Tuple[int,int]\n    flow:Tensor\n\nCoordFunc = Callable[[FlowField, ArgStar, KWArgs], LogitTensorImage]\n\nclass Image(ItemBase):\n    \"Support applying transforms to image data in `px`.\"\n    def __init__(self, px:Tensor):\n        self._px = px\n        self._logit_px=None\n        self._flow=None\n        self._affine_mat=None\n        self.sample_kwargs = {}\n\n    def set_sample(self, **kwargs)->'ImageBase':\n        \"Set parameters that control how we `grid_sample` the image after transforms are applied.\"\n        self.sample_kwargs = kwargs\n        return self\n\n    def clone(self):\n        \"Mimic the behavior of torch.clone for `Image` objects.\"\n        return self.__class__(self.px.clone())\n\n    @property\n    def shape(self)->Tuple[int,int,int]: return self._px.shape\n    @property\n    def size(self)->Tuple[int,int]: return self.shape[-2:]\n    @property\n    def device(self)->torch.device: return self._px.device\n\n    def __repr__(self): return f'{self.__class__.__name__} {tuple(self.shape)}'\n    def _repr_png_(self): return self._repr_image_format('png')\n    def _repr_jpeg_(self): return self._repr_image_format('jpeg')\n\n    def _repr_image_format(self, format_str):\n        with BytesIO() as str_buffer:\n   
         plt.imsave(str_buffer, image2np(self.px), format=format_str)\n            return str_buffer.getvalue()\n\n    def apply_tfms(self, tfms:TfmList, do_resolve:bool=True, xtra:Optional[Dict[Callable,dict]]=None,\n                   size:Optional[Union[int,TensorImageSize]]=None, resize_method:ResizeMethod=None,\n                   mult:int=None, padding_mode:str='reflection', mode:str='bilinear', remove_out:bool=True,\n                   is_x:bool=True, x_frames:int=1, y_frames:int=1)->TensorImage:\n        \"Apply all `tfms` to the `Image`, if `do_resolve` picks value for random args.\"\n        if not (tfms or xtra or size): return self\n\n        if size is not None and isinstance(size, int):\n            num_frames = x_frames if is_x else y_frames\n            if num_frames > 1:\n                size = (size, size*num_frames)\n\n        tfms = listify(tfms)\n        xtra = ifnone(xtra, {})\n        default_rsz = ResizeMethod.SQUISH if (size is not None and is_listy(size)) else ResizeMethod.CROP\n        resize_method = ifnone(resize_method, default_rsz)\n        if resize_method <= 2 and size is not None: tfms = self._maybe_add_crop_pad(tfms)\n        tfms = sorted(tfms, key=lambda o: o.tfm.order)\n        if do_resolve: _resolve_tfms(tfms)\n        x = self.clone()\n        x.set_sample(padding_mode=padding_mode, mode=mode, remove_out=remove_out)\n        if size is not None:\n            crop_target = _get_crop_target(size, mult=mult)\n            if resize_method in (ResizeMethod.CROP,ResizeMethod.PAD):\n                target = _get_resize_target(x, crop_target, do_crop=(resize_method==ResizeMethod.CROP))\n                x.resize(target)\n            elif resize_method==ResizeMethod.SQUISH: x.resize((x.shape[0],) + crop_target)\n        else: size = x.size\n        size_tfms = [o for o in tfms if isinstance(o.tfm,TfmCrop)]\n        for tfm in tfms:\n            if tfm.tfm in xtra: x = tfm(x, **xtra[tfm.tfm])\n            elif tfm in size_tfms:\n       
         if resize_method in (ResizeMethod.CROP,ResizeMethod.PAD):\n                    x = tfm(x, size=_get_crop_target(size,mult=mult), padding_mode=padding_mode)\n            else: x = tfm(x)\n        return x.refresh()\n\n    def refresh(self)->None:\n        \"Apply any logit, flow, or affine transfers that have been sent to the `Image`.\"\n        if self._logit_px is not None:\n            self._px = self._logit_px.sigmoid_()\n            self._logit_px = None\n        if self._affine_mat is not None or self._flow is not None:\n            self._px = _grid_sample(self._px, self.flow, **self.sample_kwargs)\n            self.sample_kwargs = {}\n            self._flow = None\n        return self\n\n    def save(self, fn:PathOrStr):\n        \"Save the image to `fn`.\"\n        x = image2np(self.data*255).astype(np.uint8)\n        PIL.Image.fromarray(x).save(fn)\n\n    @property\n    def px(self)->TensorImage:\n        \"Get the tensor pixel buffer.\"\n        self.refresh()\n        return self._px\n    @px.setter\n    def px(self,v:TensorImage)->None:\n        \"Set the pixel buffer to `v`.\"\n        self._px=v\n\n    @property\n    def flow(self)->FlowField:\n        \"Access the flow-field grid after applying queued affine transforms.\"\n        if self._flow is None:\n            self._flow = _affine_grid(self.shape)\n        if self._affine_mat is not None:\n            self._flow = _affine_mult(self._flow,self._affine_mat)\n            self._affine_mat = None\n        return self._flow\n\n    @flow.setter\n    def flow(self,v:FlowField): self._flow=v\n\n    def lighting(self, func:LightingFunc, *args:Any, **kwargs:Any):\n        \"Equivalent to `image = sigmoid(func(logit(image)))`.\"\n        self.logit_px = func(self.logit_px, *args, **kwargs)\n        return self\n\n    def pixel(self, func:PixelFunc, *args, **kwargs)->'Image':\n        \"Equivalent to `image.px = func(image.px)`.\"\n        self.px = func(self.px, *args, **kwargs)\n        return 
self\n\n    def coord(self, func:CoordFunc, *args, **kwargs)->'Image':\n        \"Equivalent to `image.flow = func(image.flow, image.size)`.\"\n        self.flow = func(self.flow, *args, **kwargs)\n        return self\n\n    def affine(self, func:AffineFunc, *args, **kwargs)->'Image':\n        \"Equivalent to `image.affine_mat = image.affine_mat @ func()`.\"\n        m = tensor(func(*args, **kwargs)).to(self.device)\n        self.affine_mat = self.affine_mat @ m\n        return self\n\n    def resize(self, size:Union[int,TensorImageSize])->'Image':\n        \"Resize the image to `size`, size can be a single int.\"\n        assert self._flow is None\n        if isinstance(size, int): size=(self.shape[0], size, size)\n        if tuple(size)==tuple(self.shape): return self\n        self.flow = _affine_grid(size)\n        return self\n\n    @property\n    def affine_mat(self)->AffineMatrix:\n        \"Get the affine matrix that will be applied by `refresh`.\"\n        if self._affine_mat is None:\n            self._affine_mat = torch.eye(3).to(self.device)\n        return self._affine_mat\n    @affine_mat.setter\n    def affine_mat(self,v)->None: self._affine_mat=v\n\n    @property\n    def logit_px(self)->LogitTensorImage:\n        \"Get logit(image.px).\"\n        if self._logit_px is None: self._logit_px = logit_(self.px)\n        return self._logit_px\n    @logit_px.setter\n    def logit_px(self,v:LogitTensorImage)->None: self._logit_px=v\n\n    @property\n    def data(self)->TensorImage:\n        \"Return this images pixels as a tensor.\"\n        return self.px\n\n    def show(self, ax:plt.Axes=None, figsize:tuple=(3,3), title:Optional[str]=None, hide_axis:bool=True,\n              cmap:str=None, y:Any=None, **kwargs):\n        \"Show image on `ax` with `title`, using `cmap` if single-channel, overlaid with optional `y`\"\n        cmap = ifnone(cmap, defaults.cmap)\n        ax = show_image(self, ax=ax, hide_axis=hide_axis, cmap=cmap, figsize=figsize)\n        if 
y is not None: y.show(ax=ax, **kwargs)\n        if title is not None: ax.set_title(title)\n\nclass ImageSegment(Image):\n    \"Support applying transforms to segmentation masks data in `px`.\"\n    def lighting(self, func:LightingFunc, *args:Any, **kwargs:Any)->'Image': return self\n\n    def refresh(self):\n        self.sample_kwargs['mode'] = 'nearest'\n        return super().refresh()\n\n    @property\n    def data(self)->TensorImage:\n        \"Return this image pixels as a `LongTensor`.\"\n        return self.px.long()\n\n    def show(self, ax:plt.Axes=None, figsize:tuple=(3,3), title:Optional[str]=None, hide_axis:bool=True,\n        cmap:str='tab20', alpha:float=0.5, **kwargs):\n        \"Show the `ImageSegment` on `ax`.\"\n        ax = show_image(self, ax=ax, hide_axis=hide_axis, cmap=cmap, figsize=figsize,\n                        interpolation='nearest', alpha=alpha, vmin=0, **kwargs)\n        if title: ax.set_title(title)\n\n    def reconstruct(self, t:Tensor): return ImageSegment(t)\n\nclass ImagePoints(Image):\n    \"Support applying transforms to a `flow` of points.\"\n    def __init__(self, flow:FlowField, scale:bool=True, y_first:bool=True):\n        if scale: flow = scale_flow(flow)\n        if y_first: flow.flow = flow.flow.flip(1)\n        self._flow = flow\n        self._affine_mat = None\n        self.flow_func = []\n        self.sample_kwargs = {}\n        self.transformed = False\n        self.loss_func = MSELossFlat()\n\n    def clone(self):\n        \"Mimic the behavior of torch.clone for `ImagePoints` objects.\"\n        return self.__class__(FlowField(self.size, self.flow.flow.clone()), scale=False, y_first=False)\n\n    @property\n    def shape(self)->Tuple[int,int,int]: return (1, *self._flow.size)\n    @property\n    def size(self)->Tuple[int,int]: return self._flow.size\n    @size.setter\n    def size(self, sz:int): self._flow.size=sz\n    @property\n    def device(self)->torch.device: return self._flow.flow.device\n\n    def 
__repr__(self): return f'{self.__class__.__name__} {tuple(self.size)}'\n    def _repr_image_format(self, format_str): return None\n\n    @property\n    def flow(self)->FlowField:\n        \"Access the flow-field grid after applying queued affine and coord transforms.\"\n        if self._affine_mat is not None:\n            self._flow = _affine_inv_mult(self._flow, self._affine_mat)\n            self._affine_mat = None\n            self.transformed = True\n        if len(self.flow_func) != 0:\n            for f in self.flow_func[::-1]: self._flow = f(self._flow)\n            self.transformed = True\n            self.flow_func = []\n        return self._flow\n\n    @flow.setter\n    def flow(self,v:FlowField):  self._flow=v\n\n    def coord(self, func:CoordFunc, *args, **kwargs)->'ImagePoints':\n        \"Put `func` with `args` and `kwargs` in `self.flow_func` for later.\"\n        if 'invert' in kwargs: kwargs['invert'] = True\n        else: warn(f\"{func.__name__} isn't implemented for {self.__class__}.\")\n        self.flow_func.append(partial(func, *args, **kwargs))\n        return self\n\n    def lighting(self, func:LightingFunc, *args:Any, **kwargs:Any)->'ImagePoints': return self\n\n    def pixel(self, func:PixelFunc, *args, **kwargs)->'ImagePoints':\n        \"Equivalent to `self = func_flow(self)`.\"\n        self = func(self, *args, **kwargs)\n        self.transformed=True\n        return self\n\n    def refresh(self) -> 'ImagePoints':\n        return self\n\n    def resize(self, size:Union[int,TensorImageSize]) -> 'ImagePoints':\n        \"Resize the image to `size`, size can be a single int.\"\n        if isinstance(size, int): size=(1, size, size)\n        self._flow.size = size[1:]\n        return self\n\n    @property\n    def data(self)->Tensor:\n        \"Return the points associated to this object.\"\n        flow = self.flow #This updates flow before we test if some transforms happened\n        if self.transformed:\n            if 'remove_out' not 
in self.sample_kwargs or self.sample_kwargs['remove_out']:\n                flow = _remove_points_out(flow)\n            self.transformed=False\n        return flow.flow.flip(1)\n\n    def show(self, ax:plt.Axes=None, figsize:tuple=(3,3), title:Optional[str]=None, hide_axis:bool=True, **kwargs):\n        \"Show the `ImagePoints` on `ax`.\"\n        if ax is None: _,ax = plt.subplots(figsize=figsize)\n        pnt = scale_flow(FlowField(self.size, self.data), to_unit=False).flow.flip(1)\n        params = {'s': 10, 'marker': '.', 'c': 'r', **kwargs}\n        ax.scatter(pnt[:, 0], pnt[:, 1], **params)\n        if hide_axis: ax.axis('off')\n        if title: ax.set_title(title)\n\nclass ImageBBox(ImagePoints):\n    \"Support applying transforms to a `flow` of bounding boxes.\"\n    def __init__(self, flow:FlowField, scale:bool=True, y_first:bool=True, labels:Collection=None,\n                 classes:dict=None, pad_idx:int=0):\n        super().__init__(flow, scale, y_first)\n        self.pad_idx = pad_idx\n        if labels is not None and len(labels)>0 and not isinstance(labels[0],Category):\n            labels = array([Category(l,classes[l]) for l in labels])\n        self.labels = labels\n\n    def clone(self) -> 'ImageBBox':\n        \"Mimic the behavior of torch.clone for `Image` objects.\"\n        flow = FlowField(self.size, self.flow.flow.clone())\n        return self.__class__(flow, scale=False, y_first=False, labels=self.labels, pad_idx=self.pad_idx)\n\n    @classmethod\n    def create(cls, h:int, w:int, bboxes:Collection[Collection[int]], labels:Collection=None, classes:dict=None,\n               pad_idx:int=0, scale:bool=True)->'ImageBBox':\n        \"Create an ImageBBox object from `bboxes`.\"\n        if isinstance(bboxes, np.ndarray) and bboxes.dtype == np.object: bboxes = np.array([bb for bb in bboxes])\n        bboxes = tensor(bboxes).float()\n        tr_corners = torch.cat([bboxes[:,0][:,None], bboxes[:,3][:,None]], 1)\n        bl_corners = 
bboxes[:,1:3].flip(1)\n        bboxes = torch.cat([bboxes[:,:2], tr_corners, bl_corners, bboxes[:,2:]], 1)\n        flow = FlowField((h,w), bboxes.view(-1,2))\n        return cls(flow, labels=labels, classes=classes, pad_idx=pad_idx, y_first=True, scale=scale)\n\n    def _compute_boxes(self) -> Tuple[LongTensor, LongTensor]:\n        bboxes = self.flow.flow.flip(1).view(-1, 4, 2).contiguous().clamp(min=-1, max=1)\n        mins, maxes = bboxes.min(dim=1)[0], bboxes.max(dim=1)[0]\n        bboxes = torch.cat([mins, maxes], 1)\n        mask = (bboxes[:,2]-bboxes[:,0] > 0) * (bboxes[:,3]-bboxes[:,1] > 0)\n        if len(mask) == 0: return tensor([self.pad_idx] * 4), tensor([self.pad_idx])\n        res = bboxes[mask]\n        if self.labels is None: return res,None\n        return res, self.labels[to_np(mask).astype(bool)]\n\n    @property\n    def data(self)->Union[FloatTensor, Tuple[FloatTensor,LongTensor]]:\n        bboxes,lbls = self._compute_boxes()\n        lbls = np.array([o.data for o in lbls]) if lbls is not None else None\n        return bboxes if lbls is None else (bboxes, lbls)\n\n    def show(self, y:Image=None, ax:plt.Axes=None, figsize:tuple=(3,3), title:Optional[str]=None, hide_axis:bool=True,\n        color:str='white', **kwargs):\n        \"Show the `ImageBBox` on `ax`.\"\n        if ax is None: _,ax = plt.subplots(figsize=figsize)\n        bboxes, lbls = self._compute_boxes()\n        h,w = self.flow.size\n        bboxes.add_(1).mul_(torch.tensor([h/2, w/2, h/2, w/2])).long()\n        for i, bbox in enumerate(bboxes):\n            if lbls is not None: text = str(lbls[i])\n            else: text=None\n            _draw_rect(ax, bb2hw(bbox), text=text, color=color)\n\ndef open_image(fn:PathOrStr, div:bool=True, convert_mode:str='RGB', cls:type=Image,\n        after_open:Callable=None)->Image:\n    \"Return `Image` object created from image in file `fn`.\"\n    with warnings.catch_warnings():\n        warnings.simplefilter(\"ignore\", UserWarning) # EXIF 
warning from TiffPlugin\n        x = PIL.Image.open(fn).convert(convert_mode)\n    if after_open: x = after_open(x)\n    x = pil2tensor(x,np.float32)\n    if div: x.div_(255)\n    return cls(x)\n\ndef open_mask(fn:PathOrStr, div=False, convert_mode='L', after_open:Callable=None)->ImageSegment:\n    \"Return `ImageSegment` object created from mask in file `fn`. If `div`, divides pixel values by 255.\"\n    return open_image(fn, div=div, convert_mode=convert_mode, cls=ImageSegment, after_open=after_open)\n\ndef open_mask_rle(mask_rle:str, shape:Tuple[int, int])->ImageSegment:\n    \"Return `ImageSegment` object created from run-length encoded string in `mask_rle` with size in `shape`.\"\n    x = FloatTensor(rle_decode(str(mask_rle), shape).astype(np.uint8))\n    x = x.view(shape[1], shape[0], -1)\n    return ImageSegment(x.permute(2,0,1))\n\ndef rle_encode(img:NPArrayMask)->str:\n    \"Return run-length encoding string from `img`.\"\n    pixels = np.concatenate([[0], img.flatten() , [0]])\n    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1\n    runs[1::2] -= runs[::2]\n    return ' '.join(str(x) for x in runs)\n\ndef rle_decode(mask_rle:str, shape:Tuple[int,int])->NPArrayMask:\n    \"Return an image array from run-length encoded string `mask_rle` with `shape`.\"\n    s = mask_rle.split()\n    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]\n    starts -= 1\n    ends = starts + lengths\n    img = np.zeros(shape[0]*shape[1], dtype=np.uint)\n    for low, up in zip(starts, ends): img[low:up] = 1\n    return img.reshape(shape)\n\ndef show_image(img:Image, ax:plt.Axes=None, figsize:tuple=(3,3), hide_axis:bool=True, cmap:str='binary',\n                alpha:float=None, **kwargs)->plt.Axes:\n    \"Display `Image` in notebook.\"\n    if ax is None: fig,ax = plt.subplots(figsize=figsize)\n    ax.imshow(image2np(img.data), cmap=cmap, alpha=alpha, **kwargs)\n    if hide_axis: ax.axis('off')\n    return ax\n\ndef scale_flow(flow, to_unit=True):\n   
 \"Scale the coords in `flow` to -1/1 or the image size depending on `to_unit`.\"\n    s = tensor([flow.size[0]/2,flow.size[1]/2])[None]\n    if to_unit: flow.flow = flow.flow/s-1\n    else:       flow.flow = (flow.flow+1)*s\n    return flow\n\ndef _remove_points_out(flow:FlowField):\n    pad_mask = (flow.flow[:,0] >= -1) * (flow.flow[:,0] <= 1) * (flow.flow[:,1] >= -1) * (flow.flow[:,1] <= 1)\n    flow.flow = flow.flow[pad_mask]\n    return flow\n\nclass Transform():\n    \"Utility class for adding probability and wrapping support to transform `func`.\"\n    _wrap=None\n    order=0\n    def __init__(self, func:Callable, order:Optional[int]=None):\n        \"Create a transform for `func` and assign it an priority `order`, attach to `Image` class.\"\n        if order is not None: self.order=order\n        self.func=func\n        self.func.__name__ = func.__name__[1:] #To remove the _ that begins every transform function.\n        functools.update_wrapper(self, self.func)\n        self.func.__annotations__['return'] = Image\n        self.params = copy(func.__annotations__)\n        self.def_args = _get_default_args(func)\n        setattr(Image, func.__name__,\n                lambda x, *args, **kwargs: self.calc(x, *args, **kwargs))\n\n    def __call__(self, *args:Any, p:float=1., is_random:bool=True, use_on_y:bool=True, **kwargs:Any)->Image:\n        \"Calc now if `args` passed; else create a transform called prob `p` if `random`.\"\n        if args: return self.calc(*args, **kwargs)\n        else: return RandTransform(self, kwargs=kwargs, is_random=is_random, use_on_y=use_on_y, p=p)\n\n    def calc(self, x:Image, *args:Any, **kwargs:Any)->Image:\n        \"Apply to image `x`, wrapping it if necessary.\"\n        if self._wrap: return getattr(x, self._wrap)(self.func, *args, **kwargs)\n        else:          return self.func(x, *args, **kwargs)\n\n    @property\n    def name(self)->str: return self.__class__.__name__\n\n    def __repr__(self)->str: return 
f'{self.name} ({self.func.__name__})'\n\n@dataclass\nclass RandTransform():\n    \"Wrap `Transform` to add randomized execution.\"\n    tfm:Transform\n    kwargs:dict\n    p:float=1.0\n    resolved:dict = field(default_factory=dict)\n    do_run:bool = True\n    is_random:bool = True\n    use_on_y:bool = True\n    def __post_init__(self): functools.update_wrapper(self, self.tfm)\n\n    def resolve(self)->None:\n        \"Bind any random variables in the transform.\"\n        if not self.is_random:\n            self.resolved = {**self.tfm.def_args, **self.kwargs}\n            return\n\n        self.resolved = {}\n        # for each param passed to tfm...\n        for k,v in self.kwargs.items():\n            # ...if it's annotated, call that fn...\n            if k in self.tfm.params:\n                rand_func = self.tfm.params[k]\n                self.resolved[k] = rand_func(*listify(v))\n            # ...otherwise use the value directly\n            else: self.resolved[k] = v\n        # use defaults for any args not filled in yet\n        for k,v in self.tfm.def_args.items():\n            if k not in self.resolved: self.resolved[k]=v\n        # anything left over must be callable without params\n        for k,v in self.tfm.params.items():\n            if k not in self.resolved and k!='return': self.resolved[k]=v()\n\n        self.do_run = rand_bool(self.p)\n\n    @property\n    def order(self)->int: return self.tfm.order\n\n    def __call__(self, x:Image, *args, **kwargs)->Image:\n        \"Randomly execute our tfm on `x`.\"\n        return self.tfm(x, *args, **{**self.resolved, **kwargs}) if self.do_run else x\n\ndef _resolve_tfms(tfms:TfmList):\n    \"Resolve every tfm in `tfms`.\"\n    for f in listify(tfms): f.resolve()\n\ndef _grid_sample(x:TensorImage, coords:FlowField, mode:str='bilinear', padding_mode:str='reflection', remove_out:bool=True)->TensorImage:\n    \"Resample pixels in `coords` from `x` by `mode`, with `padding_mode` in 
('reflection','border','zeros').\"\n    coords = coords.flow.permute(0, 3, 1, 2).contiguous().permute(0, 2, 3, 1) # optimize layout for grid_sample\n    if mode=='bilinear': # hack to get smoother downwards resampling\n        mn,mx = coords.min(),coords.max()\n        # max amount we're affine zooming by (>1 means zooming in)\n        z = 1/(mx-mn).item()*2\n        # amount we're resizing by, with 100% extra margin\n        d = min(x.shape[1]/coords.shape[1], x.shape[2]/coords.shape[2])/2\n        # If we're resizing up by >200%, and we're zooming less than that, interpolate first\n        if d>1 and d>z: x = F.interpolate(x[None], scale_factor=1/d, mode='area')[0]\n    return F.grid_sample(x[None], coords, mode=mode, padding_mode=padding_mode)[0]\n\ndef _affine_grid(size:TensorImageSize)->FlowField:\n    size = ((1,)+size)\n    N, C, H, W = size\n    grid = FloatTensor(N, H, W, 2)\n    linear_points = torch.linspace(-1, 1, W) if W > 1 else tensor([-1])\n    grid[:, :, :, 0] = torch.ger(torch.ones(H), linear_points).expand_as(grid[:, :, :, 0])\n    linear_points = torch.linspace(-1, 1, H) if H > 1 else tensor([-1])\n    grid[:, :, :, 1] = torch.ger(linear_points, torch.ones(W)).expand_as(grid[:, :, :, 1])\n    return FlowField(size[2:], grid)\n\ndef _affine_mult(c:FlowField,m:AffineMatrix)->FlowField:\n    \"Multiply `c` by `m` - can adjust for rectangular shaped `c`.\"\n    if m is None: return c\n    size = c.flow.size()\n    h,w = c.size\n    m[0,1] *= h/w\n    m[1,0] *= w/h\n    c.flow = c.flow.view(-1,2)\n    c.flow = torch.addmm(m[:2,2], c.flow,  m[:2,:2].t()).view(size)\n    return c\n\ndef _affine_inv_mult(c, m):\n    \"Applies the inverse affine transform described in `m` to `c`.\"\n    size = c.flow.size()\n    h,w = c.size\n    m[0,1] *= h/w\n    m[1,0] *= w/h\n    c.flow = c.flow.view(-1,2)\n    a = torch.inverse(m[:2,:2].t())\n    c.flow = torch.mm(c.flow - m[:2,2], a).view(size)\n    return c\n\nclass TfmAffine(Transform):\n    \"Decorator for 
affine tfm funcs.\"\n    order,_wrap = 5,'affine'\nclass TfmPixel(Transform):\n    \"Decorator for pixel tfm funcs.\"\n    order,_wrap = 10,'pixel'\nclass TfmCoord(Transform):\n    \"Decorator for coord tfm funcs.\"\n    order,_wrap = 4,'coord'\nclass TfmCrop(TfmPixel):\n    \"Decorator for crop tfm funcs.\"\n    order=99\nclass TfmLighting(Transform):\n    \"Decorator for lighting tfm funcs.\"\n    order,_wrap = 8,'lighting'\n\ndef _round_multiple(x:int, mult:int=None)->int:\n    \"Calc `x` to nearest multiple of `mult`.\"\n    return (int(x/mult+0.5)*mult) if mult is not None else x\n\ndef _get_crop_target(target_px:Union[int,TensorImageSize], mult:int=None)->Tuple[int,int]:\n    \"Calc crop shape of `target_px` to nearest multiple of `mult`.\"\n    target_r,target_c = tis2hw(target_px)\n    return _round_multiple(target_r,mult),_round_multiple(target_c,mult)\n\ndef _get_resize_target(img, crop_target, do_crop=False)->TensorImageSize:\n    \"Calc size of `img` to fit in `crop_target` - adjust based on `do_crop`.\"\n    if crop_target is None: return None\n    ch,r,c = img.shape\n    target_r,target_c = crop_target\n    ratio = (min if do_crop else max)(r/target_r, c/target_c)\n    return ch,int(round(r/ratio)),int(round(c/ratio)) #Sometimes those are numpy numbers and round doesn't return an int.\n\ndef plot_flat(r, c, figsize):\n    \"Shortcut for `enumerate(subplots.flatten())`\"\n    return enumerate(plt.subplots(r, c, figsize=figsize)[1].flatten())\n\ndef plot_multi(func:Callable[[int,int,plt.Axes],None], r:int=1, c:int=1, figsize:Tuple=(12,6)):\n    \"Call `func` for every combination of `r,c` on a subplot\"\n    axes = plt.subplots(r, c, figsize=figsize)[1]\n    for i in range(r):\n        for j in range(c): func(i,j,axes[i,j])\n\ndef show_multi(func:Callable[[int,int],Image], r:int=1, c:int=1, figsize:Tuple=(9,9)):\n    \"Call `func(i,j).show(ax)` for every combination of `r,c`\"\n    plot_multi(lambda i,j,ax: func(i,j).show(ax), r, c, 
figsize=figsize)\n\ndef show_all(imgs:Collection[Image], r:int=1, c:Optional[int]=None, figsize=(12,6)):\n    \"Show all `imgs` using `r` rows\"\n    imgs = listify(imgs)\n    if c is None: c = len(imgs)//r\n    for i,ax in plot_flat(r,c,figsize): imgs[i].show(ax)\n"
  },
  {
    "path": "fastai/vision/interpret.py",
    "content": "from ..torch_core import *\nfrom ..basic_data import *\nfrom ..basic_train import *\nfrom .image import *\nfrom ..train import Interpretation\nfrom textwrap import wrap\n\n__all__ = ['SegmentationInterpretation', 'ObjectDetectionInterpretation']\n\nclass SegmentationInterpretation(Interpretation):\n    \"Interpretation methods for segmentation models.\"\n    def __init__(self, learn:Learner, preds:Tensor, y_true:Tensor, losses:Tensor,\n                 ds_type:DatasetType=DatasetType.Valid):\n        super(SegmentationInterpretation, self).__init__(learn,preds,y_true,losses,ds_type)\n        self.pred_class = self.preds.argmax(dim=1)\n        self.c2i = {c:i for i,c in enumerate(self.data.classes)}\n        self.i2c = {i:c for c,i in self.c2i.items()}\n    \n    def top_losses(self, sizes:Tuple, k:int=None, largest=True):\n        \"Reduce the flattened loss to give a single loss value for each image\"\n        losses = self.losses.view(-1, np.prod(sizes)).mean(-1)\n        return losses.topk(ifnone(k, len(losses)), largest=largest)\n    \n    def _interp_show(self, ims:ImageSegment, classes:Collection=None, sz:int=20, cmap='tab20',\n                    title_suffix:str=None):\n        \"Show ImageSegment with color mapping labels\"\n        fig,axes=plt.subplots(1,2,figsize=(sz,sz))\n        np_im = to_np(ims.data).copy()\n        # tab20 - qualitative colormaps support max of 20 distinct colors\n        # if len(classes) > 20 close idxs map to same color\n        # image\n        if classes is not None:\n            class_idxs = [self.c2i[c] for c in classes]\n            mask = np.max(np.stack([np_im==i for i in class_idxs]),axis=0)\n            np_im = (np_im*mask).astype(np.float)\n            np_im[np.where(mask==0)] = np.nan\n        im=axes[0].imshow(np_im[0], cmap=cmap)\n\n        # labels\n        np_im_labels = list(np.unique(np_im[~np.isnan(np_im)]))\n        c = len(np_im_labels); n = math.ceil(np.sqrt(c))\n        label_im = 
np.array(np_im_labels + [np.nan]*(n**2-c)).reshape(n,n)\n        axes[1].imshow(label_im, cmap=cmap)\n        for i,l in enumerate([self.i2c[l] for l in np_im_labels]):\n            div,mod=divmod(i,n)\n            l = \"\\n\".join(wrap(l,10)) if len(l) > 10 else l\n            axes[1].text(mod, div, f\"{l}\", ha='center', color='white', fontdict={'size':sz})\n\n        if title_suffix:\n            axes[0].set_title(f\"{title_suffix}_imsegment\")\n            axes[1].set_title(f\"{title_suffix}_labels\")\n\n    def show_xyz(self, i, classes:list=None, sz=10):\n        'show (image, true and pred) from self.ds with color mappings, optionally only plot'\n        x,y = self.ds[i]\n        self.ds.show_xys([x],[y], figsize=(sz/2,sz/2))\n        self._interp_show(ImageSegment(self.y_true[i]), classes, sz=sz, title_suffix='true')\n        self._interp_show(ImageSegment(self.pred_class[i][None,:]), classes, sz=sz, title_suffix='pred')\n\n    def _generate_confusion(self):\n        \"Average and Per Image Confusion: intersection of pixels given a true label, true label sums to 1\"\n        single_img_confusion = []\n        mean_confusion = []\n        n =  self.pred_class.shape[0]\n        for c_j in range(self.data.c):\n            true_binary = self.y_true.squeeze(1) == c_j\n            total_true = true_binary.view(n,-1).sum(dim=1).float()\n            for c_i in range(self.data.c):\n                pred_binary = self.pred_class == c_i\n                total_intersect = (true_binary*pred_binary).view(n,-1).sum(dim=1).float()\n                p_given_t = (total_intersect / (total_true))\n                p_given_t_mean = p_given_t[~torch.isnan(p_given_t)].mean()\n                single_img_confusion.append(p_given_t)\n                mean_confusion.append(p_given_t_mean)\n        self.single_img_cm = to_np(torch.stack(single_img_confusion).permute(1,0).view(-1, self.data.c, self.data.c))\n        self.mean_cm = to_np(torch.tensor(mean_confusion).view(self.data.c, 
self.data.c))\n        return self.mean_cm, self.single_img_cm\n\n    def _plot_intersect_cm(self, cm, title=\"Intersection with Predict given True\"):\n        \"Plot confusion matrices: self.mean_cm or self.single_img_cm generated by `_generate_confusion`\"\n        from IPython.display import display, HTML\n        fig,ax=plt.subplots(1,1,figsize=(10,10))\n        im=ax.imshow(cm, cmap=\"Blues\")\n        ax.set_xlabel(\"Predicted\")\n        ax.set_ylabel(\"True\")\n        ax.set_title(f\"{title}\")\n        ax.set_xticks(range(self.data.c))\n        ax.set_yticks(range(self.data.c))\n        ax.set_xticklabels(self.data.classes, rotation='vertical')\n        ax.set_yticklabels(self.data.classes)\n        fig.colorbar(im)\n        \n        df = (pd.DataFrame([self.data.classes, cm.diagonal()], index=['label', 'score'])\n            .T.sort_values('score', ascending=False))\n        with pd.option_context('display.max_colwidth', -1):\n            display(HTML(df.to_html(index=False)))\n        return df\n\n\n\nclass ObjectDetectionInterpretation(Interpretation):\n    \"Interpretation methods for object detection models.\"\n    def __init__(self, learn:Learner, preds:Tensor, y_true:Tensor, losses:Tensor, ds_type:DatasetType=DatasetType.Valid):\n        raise NotImplementedError\n        super(ObjectDetectionInterpretation, self).__init__(learn,preds,y_true,losses,ds_type)\n        "
  },
  {
    "path": "fastai/vision/learner.py",
    "content": "\"`Learner` support for computer vision\"\nfrom ..torch_core import *\nfrom ..basic_train import *\nfrom ..basic_data import *\nfrom .image import *\nfrom . import models\nfrom ..callback import *\nfrom ..layers import *\nfrom ..callbacks.hooks import *\nfrom ..train import ClassificationInterpretation\n\n__all__ = ['cnn_learner', 'create_cnn', 'create_cnn_model', 'create_body', 'create_head', 'unet_learner']\n# By default split models between first and second layer\ndef _default_split(m:nn.Module): return (m[1],)\n# Split a resnet style model\ndef _resnet_split(m:nn.Module): return (m[0][6],m[1])\n# Split squeezenet model on maxpool layers\ndef _squeezenet_split(m:nn.Module): return (m[0][0][5], m[0][0][8], m[1])\ndef _densenet_split(m:nn.Module): return (m[0][0][7],m[1])\ndef _vgg_split(m:nn.Module): return (m[0][0][22],m[1])\ndef _alexnet_split(m:nn.Module): return (m[0][0][6],m[1])\n\n_default_meta    = {'cut':None, 'split':_default_split}\n_resnet_meta     = {'cut':-2, 'split':_resnet_split }\n_squeezenet_meta = {'cut':-1, 'split': _squeezenet_split}\n_densenet_meta   = {'cut':-1, 'split':_densenet_split}\n_vgg_meta        = {'cut':-1, 'split':_vgg_split}\n_alexnet_meta    = {'cut':-1, 'split':_alexnet_split}\n\nmodel_meta = {\n    models.resnet18 :{**_resnet_meta}, models.resnet34: {**_resnet_meta},\n    models.resnet50 :{**_resnet_meta}, models.resnet101:{**_resnet_meta},\n    models.resnet152:{**_resnet_meta},\n\n    models.squeezenet1_0:{**_squeezenet_meta},\n    models.squeezenet1_1:{**_squeezenet_meta},\n\n    models.densenet121:{**_densenet_meta}, models.densenet169:{**_densenet_meta},\n    models.densenet201:{**_densenet_meta}, models.densenet161:{**_densenet_meta},\n    models.vgg16_bn:{**_vgg_meta}, models.vgg19_bn:{**_vgg_meta},\n    models.alexnet:{**_alexnet_meta}}\n\ndef cnn_config(arch):\n    \"Get the metadata associated with `arch`.\"\n    #torch.backends.cudnn.benchmark = True\n    return model_meta.get(arch, 
_default_meta)\n\ndef has_pool_type(m):\n    if is_pool_type(m): return True\n    for l in m.children():\n        if has_pool_type(l): return True\n    return False\n\ndef create_body(arch:Callable, pretrained:bool=True, cut:Optional[Union[int, Callable]]=None):\n    \"Cut off the body of a typically pretrained `model` at `cut` (int) or cut the model as specified by `cut(model)` (function).\"\n    model = arch(pretrained=pretrained)\n    cut = ifnone(cut, cnn_config(arch)['cut'])\n    if cut is None:\n        ll = list(enumerate(model.children()))\n        cut = next(i for i,o in reversed(ll) if has_pool_type(o))\n    if   isinstance(cut, int):      return nn.Sequential(*list(model.children())[:cut])\n    elif isinstance(cut, Callable): return cut(model)\n    else:                           raise NamedError(\"cut must be either integer or a function\")\n\n\ndef create_head(nf:int, nc:int, lin_ftrs:Optional[Collection[int]]=None, ps:Floats=0.5,\n                concat_pool:bool=True, bn_final:bool=False):\n    \"Model head that takes `nf` features, runs through `lin_ftrs`, and outputs `nc` classes.\"\n    lin_ftrs = [nf, 512, nc] if lin_ftrs is None else [nf] + lin_ftrs + [nc]\n    ps = listify(ps)\n    if len(ps) == 1: ps = [ps[0]/2] * (len(lin_ftrs)-2) + ps\n    actns = [nn.ReLU(inplace=True)] * (len(lin_ftrs)-2) + [None]\n    pool = AdaptiveConcatPool2d() if concat_pool else nn.AdaptiveAvgPool2d(1)\n    layers = [pool, Flatten()]\n    for ni,no,p,actn in zip(lin_ftrs[:-1], lin_ftrs[1:], ps, actns):\n        layers += bn_drop_lin(ni, no, True, p, actn)\n    if bn_final: layers.append(nn.BatchNorm1d(lin_ftrs[-1], momentum=0.01))\n    return nn.Sequential(*layers)\n\ndef create_cnn_model(base_arch:Callable, nc:int, cut:Union[int,Callable]=None, pretrained:bool=True,\n                     lin_ftrs:Optional[Collection[int]]=None, ps:Floats=0.5, custom_head:Optional[nn.Module]=None,\n                     bn_final:bool=False, concat_pool:bool=True):\n    \"Create custom 
convnet architecture\"\n    body = create_body(base_arch, pretrained, cut)\n    if custom_head is None:\n        nf = num_features_model(nn.Sequential(*body.children())) * (2 if concat_pool else 1)\n        head = create_head(nf, nc, lin_ftrs, ps=ps, concat_pool=concat_pool, bn_final=bn_final)\n    else: head = custom_head\n    return nn.Sequential(body, head)\n\ndef cnn_learner(data:DataBunch, base_arch:Callable, cut:Union[int,Callable]=None, pretrained:bool=True,\n                lin_ftrs:Optional[Collection[int]]=None, ps:Floats=0.5, custom_head:Optional[nn.Module]=None,\n                split_on:Optional[SplitFuncOrIdxList]=None, bn_final:bool=False, init=nn.init.kaiming_normal_,\n                concat_pool:bool=True, **kwargs:Any)->Learner:\n    \"Build convnet style learner.\"\n    meta = cnn_config(base_arch)\n    model = create_cnn_model(base_arch, data.c, cut, pretrained, lin_ftrs, ps=ps, custom_head=custom_head,\n        bn_final=bn_final, concat_pool=concat_pool)\n    learn = Learner(data, model, **kwargs)\n    learn.split(split_on or meta['split'])\n    if pretrained: learn.freeze()\n    if init: apply_init(model[1], init)\n    return learn\n\ndef create_cnn(data, base_arch, **kwargs):\n    warn(\"`create_cnn` is deprecated and is now named `cnn_learner`.\")\n    return cnn_learner(data, base_arch, **kwargs)\n\ndef unet_learner(data:DataBunch, arch:Callable, pretrained:bool=True, blur_final:bool=True,\n                 norm_type:Optional[NormType]=NormType, split_on:Optional[SplitFuncOrIdxList]=None, blur:bool=False,\n                 self_attention:bool=False, y_range:Optional[Tuple[float,float]]=None, last_cross:bool=True,\n                 bottle:bool=False, cut:Union[int,Callable]=None, **learn_kwargs:Any)->Learner:\n    \"Build Unet learner from `data` and `arch`.\"\n    meta = cnn_config(arch)\n    body = create_body(arch, pretrained, cut)\n    try:    size = data.train_ds[0][0].size\n    except: size = next(iter(data.train_dl))[0].shape[-2:]\n   
 model = to_device(models.unet.DynamicUnet(body, n_classes=data.c, img_size=size, blur=blur, blur_final=blur_final,\n          self_attention=self_attention, y_range=y_range, norm_type=norm_type, last_cross=last_cross,\n          bottle=bottle), data.device)\n    learn = Learner(data, model, **learn_kwargs)\n    learn.split(ifnone(split_on, meta['split']))\n    if pretrained: learn.freeze()\n    apply_init(model[2], nn.init.kaiming_normal_)\n    return learn\n\n@classmethod\ndef _cl_int_from_learner(cls, learn:Learner, ds_type:DatasetType=DatasetType.Valid, activ:nn.Module=None, tta=False):\n    \"Create an instance of `ClassificationInterpretation`. `tta` indicates if we want to use Test Time Augmentation.\"\n    preds = learn.TTA(ds_type=ds_type, with_loss=True) if tta else learn.get_preds(ds_type=ds_type, activ=activ, with_loss=True)\n\n    return cls(learn, *preds, ds_type=ds_type)\n\ndef _test_cnn(m):\n    if not isinstance(m, nn.Sequential) or not len(m) == 2: return False\n    return isinstance(m[1][0], (AdaptiveConcatPool2d, nn.AdaptiveAvgPool2d))\n\ndef _cl_int_gradcam(self, idx, heatmap_thresh:int=16, image:bool=True):\n    m = self.learn.model.eval()\n    im,cl = self.learn.data.dl(DatasetType.Valid).dataset[idx]\n    cl = int(cl)\n    xb,_ = self.data.one_item(im, detach=False, denorm=False) #put into a minibatch of batch size = 1\n    with hook_output(m[0]) as hook_a: \n        with hook_output(m[0], grad=True) as hook_g:\n            preds = m(xb)\n            preds[0,int(cl)].backward() \n    acts  = hook_a.stored[0].cpu() #activation maps\n    if (acts.shape[-1]*acts.shape[-2]) >= heatmap_thresh:\n        grad = hook_g.stored[0][0].cpu()\n        grad_chan = grad.mean(1).mean(1)\n        mult = F.relu(((acts*grad_chan[...,None,None])).sum(0))\n        if image:\n            xb_im = Image(xb[0])\n            _,ax = plt.subplots()\n            sz = list(xb_im.shape[-2:])\n            xb_im.show(ax,title=f\"pred. 
class: {self.pred_class[idx]}, actual class: {self.learn.data.classes[cl]}\")\n            ax.imshow(mult, alpha=0.4, extent=(0,*sz[::-1],0),\n              interpolation='bilinear', cmap='magma')\n        return mult\n\nClassificationInterpretation.GradCAM =_cl_int_gradcam\n\ndef _cl_int_plot_top_losses(self, k, largest=True, figsize=(12,12), heatmap:bool=False, heatmap_thresh:int=16,\n                            return_fig:bool=None)->Optional[plt.Figure]:\n    \"Show images in `top_losses` along with their prediction, actual, loss, and probability of actual class.\"\n    assert not heatmap or _test_cnn(self.learn.model), \"`heatmap=True` requires a model like `cnn_learner` produces.\"\n    if heatmap is None: heatmap = _test_cnn(self.learn.model)\n    tl_val,tl_idx = self.top_losses(k, largest)\n    classes = self.data.classes\n    cols = math.ceil(math.sqrt(k))\n    rows = math.ceil(k/cols)\n    fig,axes = plt.subplots(rows, cols, figsize=figsize)\n    fig.suptitle('prediction/actual/loss/probability', weight='bold', size=14)\n    for i,idx in enumerate(tl_idx):\n        im,cl = self.data.dl(self.ds_type).dataset[idx]\n        cl = int(cl)\n        im.show(ax=axes.flat[i], title=\n            f'{classes[self.pred_class[idx]]}/{classes[cl]} / {self.losses[idx]:.2f} / {self.preds[idx][cl]:.2f}')\n        if heatmap:\n            mult = self.GradCAM(idx,heatmap_thresh,image=False)\n            if mult is not None:\n                sz = list(im.shape[-2:])\n                axes.flat[i].imshow(mult, alpha=0.6, extent=(0,*sz[::-1],0), interpolation='bilinear', cmap='magma')                \n    if ifnone(return_fig, defaults.return_fig): return fig\n\ndef _cl_int_plot_multi_top_losses(self, samples:int=3, figsize:Tuple[int,int]=(8,8), save_misclassified:bool=False):\n    \"Show images in `top_losses` along with their prediction, actual, loss, and probability of predicted class in a multilabeled dataset.\"\n    if samples >20:\n        print(\"Max 20 samples\")\n      
  return\n    losses, idxs = self.top_losses(self.data.c)\n    l_dim = len(losses.size())\n    if l_dim == 1: losses, idxs = self.top_losses()\n    infolist, ordlosses_idxs, mismatches_idxs, mismatches, losses_mismatches, mismatchescontainer = [],[],[],[],[],[]\n    truthlabels = np.asarray(self.y_true, dtype=int)\n    classes_ids = [k for k in enumerate(self.data.classes)]\n    predclass = np.asarray(self.pred_class)\n    for i,pred in enumerate(predclass):\n        where_truth = np.nonzero((truthlabels[i]>0))[0]\n        mismatch = np.all(pred!=where_truth)\n        if mismatch:\n            mismatches_idxs.append(i)\n            if l_dim > 1 : losses_mismatches.append((losses[i][pred], i))\n            else: losses_mismatches.append((losses[i], i))\n        if l_dim > 1: infotup = (i, pred, where_truth, losses[i][pred], np.round(self.preds[i], decimals=3)[pred], mismatch)\n        else: infotup = (i, pred, where_truth, losses[i], np.round(self.preds[i], decimals=3)[pred], mismatch)\n        infolist.append(infotup)\n    ds = self.data.dl(self.ds_type).dataset\n    mismatches = ds[mismatches_idxs]\n    ordlosses = sorted(losses_mismatches, key = lambda x: x[0], reverse=True)\n    for w in ordlosses: ordlosses_idxs.append(w[1])\n    mismatches_ordered_byloss = ds[ordlosses_idxs]\n    print(f'{str(len(mismatches))} misclassified samples over {str(len(self.data.valid_ds))} samples in the validation set.')\n    samples = min(samples, len(mismatches))\n    for ima in range(len(mismatches_ordered_byloss)):\n        mismatchescontainer.append(mismatches_ordered_byloss[ima][0])\n    for sampleN in range(samples):\n        actualclasses = ''\n        for clas in infoList[ordlosses_idxs[sampleN]][2]:\n            actualclasses = f'{actualclasses} -- {str(classes_ids[clas][1])}'\n        imag = mismatches_ordered_byloss[sampleN][0]\n        imag = show_image(imag, figsize=figsize)\n        imag.set_title(f\"\"\"Predicted: 
{classes_ids[infoList[ordlosses_idxs[sampleN]][1]][1]} \\nActual: {actualclasses}\\nLoss: {infoList[ordlosses_idxs[sampleN]][3]}\\nProbability: {infoList[ordlosses_idxs[sampleN]][4]}\"\"\",\n                        loc='left')\n        plt.show()\n        if save_misclassified: return mismatchescontainer\n\nClassificationInterpretation.from_learner          = _cl_int_from_learner\nClassificationInterpretation.plot_top_losses       = _cl_int_plot_top_losses\nClassificationInterpretation.plot_multi_top_losses = _cl_int_plot_multi_top_losses\n \n\ndef _learner_interpret(learn:Learner, ds_type:DatasetType=DatasetType.Valid, tta=False):\n    \"Create a `ClassificationInterpretation` object from `learner` on `ds_type` with `tta`.\"\n    return ClassificationInterpretation.from_learner(learn, ds_type=ds_type, tta=tta)\nLearner.interpret = _learner_interpret\n"
  },
  {
    "path": "fastai/vision/models/__init__.py",
    "content": "from .xresnet import *\nfrom torchvision.models import ResNet,resnet18,resnet34,resnet50,resnet101,resnet152\nfrom torchvision.models import SqueezeNet,squeezenet1_0,squeezenet1_1\nfrom torchvision.models import densenet121,densenet169,densenet201,densenet161\nfrom torchvision.models import vgg16_bn,vgg19_bn,alexnet\nfrom .darknet import *\nfrom .unet import *\nfrom .wrn import *\nfrom .xception import *\n"
  },
  {
    "path": "fastai/vision/models/cadene_models.py",
    "content": "#These models are dowloaded via the repo https://github.com/Cadene/pretrained-models.pytorch\n#See licence here: https://github.com/Cadene/pretrained-models.pytorch/blob/master/LICENSE.txt\nfrom torch import nn\nfrom ..learner import model_meta\nfrom ...core import *\n\npretrainedmodels = try_import('pretrainedmodels')\nif not pretrainedmodels:\n    raise Exception('Error: `pretrainedmodels` is needed. `pip install pretrainedmodels`')\n\n__all__ = ['inceptionv4', 'inceptionresnetv2', 'nasnetamobile', 'dpn92', 'xception_cadene', 'se_resnet50',\n           'se_resnet101', 'se_resnext50_32x4d', 'senet154', 'pnasnet5large', 'se_resnext101_32x4d']\n\ndef get_model(model_name:str, pretrained:bool, seq:bool=False, pname:str='imagenet', **kwargs):\n    pretrained = pname if pretrained else None\n    model = getattr(pretrainedmodels, model_name)(pretrained=pretrained, **kwargs)\n    return nn.Sequential(*model.children()) if seq else model\n\ndef inceptionv4(pretrained:bool=False):\n    model = get_model('inceptionv4', pretrained)\n    all_layers = list(model.children())\n    return nn.Sequential(*all_layers[0], *all_layers[1:])\nmodel_meta[inceptionv4] = {'cut': -2, 'split': lambda m: (m[0][11], m[1])}\n\ndef nasnetamobile(pretrained:bool=False):\n    model = get_model('nasnetamobile', pretrained, num_classes=1000)\n    model.logits = noop\n    return nn.Sequential(model)\nmodel_meta[nasnetamobile] = {'cut': noop, 'split': lambda m: (list(m[0][0].children())[8], m[1])}\n\ndef pnasnet5large(pretrained:bool=False):\n    model = get_model('pnasnet5large', pretrained, num_classes=1000)\n    model.logits = noop\n    return nn.Sequential(model)\nmodel_meta[pnasnet5large] = {'cut': noop, 'split': lambda m: (list(m[0][0].children())[8], m[1])}\n\ndef inceptionresnetv2(pretrained:bool=False):   return get_model('inceptionresnetv2', pretrained, seq=True)\ndef dpn92(pretrained:bool=False):               return get_model('dpn92', pretrained, pname='imagenet+5k', 
seq=True)\ndef xception_cadene(pretrained=False):          return get_model('xception', pretrained, seq=True)\ndef se_resnet50(pretrained:bool=False):         return get_model('se_resnet50', pretrained)\ndef se_resnet101(pretrained:bool=False):        return get_model('se_resnet101', pretrained)\ndef se_resnext50_32x4d(pretrained:bool=False):  return get_model('se_resnext50_32x4d', pretrained)\ndef se_resnext101_32x4d(pretrained:bool=False): return get_model('se_resnext101_32x4d', pretrained)\ndef senet154(pretrained:bool=False):            return get_model('senet154', pretrained)\n\nmodel_meta[inceptionresnetv2] = {'cut': -2, 'split': lambda m: (m[0][9],     m[1])}\nmodel_meta[dpn92]             = {'cut': -1, 'split': lambda m: (m[0][0][16], m[1])}\nmodel_meta[xception_cadene]   = {'cut': -1, 'split': lambda m: (m[0][11],    m[1])}\nmodel_meta[senet154]          = {'cut': -3, 'split': lambda m: (m[0][3],     m[1])}\n_se_resnet_meta               = {'cut': -2, 'split': lambda m: (m[0][3],     m[1])}\nmodel_meta[se_resnet50]         = _se_resnet_meta\nmodel_meta[se_resnet101]        = _se_resnet_meta\nmodel_meta[se_resnext50_32x4d]  = _se_resnet_meta\nmodel_meta[se_resnext101_32x4d] = _se_resnet_meta\n\n# TODO: add \"resnext101_32x4d\" \"resnext101_64x4d\" after serialization issue is fixed:\n# https://github.com/Cadene/pretrained-models.pytorch/pull/128\n"
  },
  {
    "path": "fastai/vision/models/darknet.py",
    "content": "from ...torch_core import *\nfrom ...layers import *\n\n__all__ = ['Darknet', 'ResLayer']\n\ndef conv_bn_lrelu(ni:int, nf:int, ks:int=3, stride:int=1)->nn.Sequential:\n    \"Create a sequence Conv2d->BatchNorm2d->LeakyReLU layer.\"\n    return nn.Sequential(\n        nn.Conv2d(ni, nf, kernel_size=ks, bias=False, stride=stride, padding=ks//2),\n        nn.BatchNorm2d(nf),\n        nn.LeakyReLU(negative_slope=0.1, inplace=True))\n\nclass ResLayer(Module):\n    \"Resnet style layer with `ni` inputs.\"\n    def __init__(self, ni:int):\n        self.conv1 = conv_bn_lrelu(ni, ni//2, ks=1)\n        self.conv2 = conv_bn_lrelu(ni//2, ni, ks=3)\n\n    def forward(self, x): return x + self.conv2(self.conv1(x))\n\nclass Darknet(Module):\n    \"https://github.com/pjreddie/darknet\"\n    def make_group_layer(self, ch_in:int, num_blocks:int, stride:int=1):\n        \"starts with conv layer - `ch_in` channels in - then has `num_blocks` `ResLayer`\"\n        return [conv_bn_lrelu(ch_in, ch_in*2,stride=stride)\n               ] + [(ResLayer(ch_in*2)) for i in range(num_blocks)]\n\n    def __init__(self, num_blocks:Collection[int], num_classes:int, nf=32):\n        \"create darknet with `nf` and `num_blocks` layers\"\n        layers = [conv_bn_lrelu(3, nf, ks=3, stride=1)]\n        for i,nb in enumerate(num_blocks):\n            layers += self.make_group_layer(nf, nb, stride=2-(i==1))\n            nf *= 2\n        layers += [nn.AdaptiveAvgPool2d(1), Flatten(), nn.Linear(nf, num_classes)]\n        self.layers = nn.Sequential(*layers)\n\n    def forward(self, x): return self.layers(x)\n"
  },
  {
    "path": "fastai/vision/models/presnet.py",
    "content": "from pdb import set_trace\nimport torch.nn.functional as F\nimport torch.nn as nn\nimport torch\nimport math\nimport torch.utils.model_zoo as model_zoo\n\n__all__ = ['PResNet', 'presnet18', 'presnet34', 'presnet50', 'presnet101', 'presnet152']\n\nact_fn = nn.ReLU\n\ndef init_cnn(m):\n    if getattr(m, 'bias', None) is not None: nn.init.constant_(m.bias, 0)\n    if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight)\n    elif isinstance(m, nn.Linear): m.weight.data.normal_(0, 0.01)\n    for l in m.children(): init_cnn(l)\n\ndef conv(ni, nf, ks=3, stride=1, bias=False):\n    return nn.Conv2d(ni, nf, kernel_size=ks, stride=stride, padding=ks//2, bias=bias)\n\ndef conv_layer(conv_1st, ni, nf, ks=3, stride=1, zero_bn=False, bias=False):\n    bn = nn.BatchNorm2d(nf if conv_1st else ni)\n    nn.init.constant_(bn.weight, 0. if zero_bn else 1.)\n    res = [act_fn(), bn]\n    cn = conv(ni, nf, ks, stride=stride, bias=bias)\n    res.insert(0 if conv_1st else 2, cn)\n    return nn.Sequential(*res)\n\ndef conv_act(*args, **kwargs): return conv_layer(True , *args, **kwargs)\ndef act_conv(*args, **kwargs): return conv_layer(False, *args, **kwargs)\n\nclass BasicBlock(Module):\n    expansion = 1\n\n    def __init__(self, ni, nf, stride=1, downsample=None):\n        super(BasicBlock, self).__init__()\n        self.conv1 = act_conv(ni, nf, stride=stride)\n        self.conv2 = act_conv(nf, nf, zero_bn=True)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        identity = x if self.downsample is None else self.downsample(x)\n        x = self.conv1(x)\n        x = self.conv2(x)\n        x += identity\n        return x\n\nclass Bottleneck(Module):\n    expansion = 4\n\n    def __init__(self, ni, nf, stride=1, downsample=None):\n        super(Bottleneck, self).__init__()\n        self.conv1 = act_conv(ni, nf, 1)\n        self.conv2 = act_conv(nf, nf, stride=stride)\n        self.conv3 = act_conv(nf, 
nf*self.expansion, 1)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        identity = x if self.downsample is None else self.downsample(x)\n        x = self.conv1(x)\n        x = self.conv2(x)\n        x = self.conv3(x)\n        x += identity\n        return x\n\nclass PResNet(Module):\n\n    def __init__(self, block, layers, num_classes=1000):\n        self.ni = 64\n        super().__init__()\n        self.conv1 = conv_act(3, 16, stride=2)\n        self.conv2 = conv_act(16, 32)\n        self.conv3 = conv_act(32, 64)\n        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n        self.layer1 = self._make_layer(block, 64, layers[0])\n        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)\n        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)\n        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)\n        ni = 512*block.expansion\n        self.avgpool = nn.Sequential(\n            act_fn(), nn.BatchNorm2d(ni), nn.AdaptiveAvgPool2d(1))\n        self.fc = nn.Linear(ni, num_classes)\n\n        init_cnn(self)\n\n    def _make_layer(self, block, nf, blocks, stride=1):\n        downsample = None\n        if stride != 1 or self.ni != nf*block.expansion:\n            layers = [act_fn(), nn.BatchNorm2d(self.ni),\n                      nn.AvgPool2d(kernel_size=2)] if stride==2 else []\n            layers.append(conv(self.ni, nf*block.expansion))\n            downsample = nn.Sequential(*layers)\n\n        layers = [block(self.ni, nf, stride, downsample)]\n        self.ni = nf*block.expansion\n        for i in range(1, blocks): layers.append(block(self.ni, nf))\n        return nn.Sequential(*layers)\n\n    def forward(self, x):\n        x = self.conv1(x)\n        x = self.conv2(x)\n        x = self.conv3(x)\n        x = self.maxpool(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.layer4(x)\n\n       
 x = self.avgpool(x)\n        x = x.view(x.size(0), -1)\n        x = self.fc(x)\n\n        return x\n\nmodel_urls = dict(presnet34='presnet34', presnet50='presnet50')\n\ndef presnet(block, n_layers, name, pre=False, **kwargs):\n    model = PResNet(block, n_layers, **kwargs)\n    #if pre: model.load_state_dict(model_zoo.load_url(model_urls[name]))\n    if pre: model.load_state_dict(torch.load(model_urls[name]))\n    return model\n\ndef presnet18(pretrained=False, **kwargs):\n    return presnet(BasicBlock, [2, 2, 2, 2], 'presnet18', pre=pretrained, **kwargs)\n\ndef presnet34(pretrained=False, **kwargs):\n    return presnet(BasicBlock, [3, 4, 6, 3], 'presnet34', pre=pretrained, **kwargs)\n\ndef presnet50(pretrained=False, **kwargs):\n    return presnet(Bottleneck, [3, 4, 6, 3], 'presnet50', pre=pretrained, **kwargs)\n\ndef presnet101(pretrained=False, **kwargs):\n    return presnet(Bottleneck, [3, 4, 23, 3], 'presnet101', pre=pretrained, **kwargs)\n\ndef presnet152(pretrained=False, **kwargs):\n    return presnet(Bottleneck, [3, 8, 36, 3], 'presnet152', pre=pretrained, **kwargs)\n\n"
  },
  {
    "path": "fastai/vision/models/unet.py",
    "content": "from ...torch_core import *\nfrom ...layers import *\nfrom ...callbacks.hooks import *\n\n__all__ = ['DynamicUnet', 'UnetBlock']\n\ndef _get_sfs_idxs(sizes:Sizes) -> List[int]:\n    \"Get the indexes of the layers where the size of the activation changes.\"\n    feature_szs = [size[-1] for size in sizes]\n    sfs_idxs = list(np.where(np.array(feature_szs[:-1]) != np.array(feature_szs[1:]))[0])\n    if feature_szs[0] != feature_szs[1]: sfs_idxs = [0] + sfs_idxs\n    return sfs_idxs\n\nclass UnetBlock(Module):\n    \"A quasi-UNet block, using `PixelShuffle_ICNR upsampling`.\"\n    def __init__(self, up_in_c:int, x_in_c:int, hook:Hook, final_div:bool=True, blur:bool=False, leaky:float=None,\n                 self_attention:bool=False, **kwargs):\n        self.hook = hook\n        self.shuf = PixelShuffle_ICNR(up_in_c, up_in_c//2, blur=blur, leaky=leaky, **kwargs)\n        self.bn = batchnorm_2d(x_in_c)\n        ni = up_in_c//2 + x_in_c\n        nf = ni if final_div else ni//2\n        self.conv1 = conv_layer(ni, nf, leaky=leaky, **kwargs)\n        self.conv2 = conv_layer(nf, nf, leaky=leaky, self_attention=self_attention, **kwargs)\n        self.relu = relu(leaky=leaky)\n\n    def forward(self, up_in:Tensor) -> Tensor:\n        s = self.hook.stored\n        up_out = self.shuf(up_in)\n        ssh = s.shape[-2:]\n        if ssh != up_out.shape[-2:]:\n            up_out = F.interpolate(up_out, s.shape[-2:], mode='nearest')\n        cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1))\n        return self.conv2(self.conv1(cat_x))\n\n\nclass DynamicUnet(SequentialEx):\n    \"Create a U-Net from a given architecture.\"\n    def __init__(self, encoder:nn.Module, n_classes:int, img_size:Tuple[int,int]=(256,256), blur:bool=False, blur_final=True, self_attention:bool=False,\n                 y_range:Optional[Tuple[float,float]]=None,\n                 last_cross:bool=True, bottle:bool=False, **kwargs):\n        imsize = img_size\n        sfs_szs = 
model_sizes(encoder, size=imsize)\n        sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))\n        self.sfs = hook_outputs([encoder[i] for i in sfs_idxs])\n        x = dummy_eval(encoder, imsize).detach()\n\n        ni = sfs_szs[-1][1]\n        middle_conv = nn.Sequential(conv_layer(ni, ni*2, **kwargs),\n                                    conv_layer(ni*2, ni, **kwargs)).eval()\n        x = middle_conv(x)\n        layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv]\n\n        for i,idx in enumerate(sfs_idxs):\n            not_final = i!=len(sfs_idxs)-1\n            up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])\n            do_blur = blur and (not_final or blur_final)\n            sa = self_attention and (i==len(sfs_idxs)-3)\n            unet_block = UnetBlock(up_in_c, x_in_c, self.sfs[i], final_div=not_final, blur=do_blur, self_attention=sa,\n                                   **kwargs).eval()\n            layers.append(unet_block)\n            x = unet_block(x)\n\n        ni = x.shape[1]\n        if imsize != sfs_szs[0][-2:]: layers.append(PixelShuffle_ICNR(ni, **kwargs))\n        x = PixelShuffle_ICNR(ni)(x)\n        if imsize != x.shape[-2:]: layers.append(Lambda(lambda x: F.interpolate(x, imsize, mode='nearest')))\n        if last_cross:\n            layers.append(MergeLayer(dense=True))\n            ni += in_channels(encoder)\n            layers.append(res_block(ni, bottle=bottle, **kwargs))\n        layers += [conv_layer(ni, n_classes, ks=1, use_activ=False, **kwargs)]\n        if y_range is not None: layers.append(SigmoidRange(*y_range))\n        super().__init__(*layers)\n\n    def __del__(self):\n        if hasattr(self, \"sfs\"): self.sfs.remove()\n\n"
  },
  {
    "path": "fastai/vision/models/wrn.py",
    "content": "from ...layers import *\nfrom ...torch_core import *\n\n__all__ = ['BasicBlock', 'WideResNet', 'wrn_22']\n\ndef _bn(ni, init_zero=False):\n    \"Batchnorm layer with 0 initialization\"\n    m = nn.BatchNorm2d(ni)\n    m.weight.data.fill_(0 if init_zero else 1)\n    m.bias.data.zero_()\n    return m\n\ndef bn_relu_conv(ni, nf, ks, stride, init_zero=False):\n    bn_initzero = _bn(ni, init_zero=init_zero)\n    return nn.Sequential(bn_initzero, nn.ReLU(inplace=True), conv2d(ni, nf, ks, stride))\n\nclass BasicBlock(Module):\n    \"Block to from a wide ResNet.\"\n    def __init__(self, ni, nf, stride, drop_p=0.0):\n        self.bn = nn.BatchNorm2d(ni)\n        self.conv1 = conv2d(ni, nf, 3, stride)\n        self.conv2 = bn_relu_conv(nf, nf, 3, 1)\n        self.drop = nn.Dropout(drop_p, inplace=True) if drop_p else None\n        self.shortcut = conv2d(ni, nf, 1, stride) if ni != nf else noop\n\n    def forward(self, x):\n        x2 = F.relu(self.bn(x), inplace=True)\n        r = self.shortcut(x2)\n        x = self.conv1(x2)\n        if self.drop: x = self.drop(x)\n        x = self.conv2(x) * 0.2\n        return x.add_(r)\n\ndef _make_group(N, ni, nf, block, stride, drop_p):\n    return [block(ni if i == 0 else nf, nf, stride if i == 0 else 1, drop_p) for i in range(N)]\n\nclass WideResNet(Module):\n    \"Wide ResNet with `num_groups` and a width of `k`.\"\n    def __init__(self, num_groups:int, N:int, num_classes:int, k:int=1, drop_p:float=0.0, start_nf:int=16, n_in_channels:int=3):\n        n_channels = [start_nf]\n        for i in range(num_groups): n_channels.append(start_nf*(2**i)*k)\n\n        layers = [conv2d(n_in_channels, n_channels[0], 3, 1)]  # conv1\n        for i in range(num_groups):\n            layers += _make_group(N, n_channels[i], n_channels[i+1], BasicBlock, (1 if i==0 else 2), drop_p)\n\n        layers += [nn.BatchNorm2d(n_channels[num_groups]), nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d(1),\n                   Flatten(), 
nn.Linear(n_channels[num_groups], num_classes)]\n        self.features = nn.Sequential(*layers)\n\n    def forward(self, x): return self.features(x)\n\n\ndef wrn_22(): \n    \"Wide ResNet with 22 layers.\"\n    return WideResNet(num_groups=3, N=3, num_classes=10, k=6, drop_p=0.)\n"
  },
  {
    "path": "fastai/vision/models/xception.py",
    "content": "from ...vision import *\n\n__all__ = ['xception']\n\ndef sep_conv(ni,nf,pad=None,pool=False,act=True):\n    layers =  [nn.ReLU()] if act else []\n    layers += [\n        nn.Conv2d(ni,ni,3,1,1,groups=ni,bias=False),\n        nn.Conv2d(ni,nf,1,bias=False),\n        nn.BatchNorm2d(nf)\n    ]\n    if pool: layers.append(nn.MaxPool2d(2))\n    return nn.Sequential(*layers)\n\ndef conv(ni,nf,ks=1,stride=1, pad=None, act=True):\n    if pad is None: pad=ks//2\n    layers = [\n        nn.Conv2d(ni,nf,ks,stride,pad,bias=False),\n        nn.BatchNorm2d(nf),\n    ]\n    if act: layers.append(nn.ReLU())\n    return nn.Sequential(*layers)\n\nclass ConvSkip(Module):\n    def __init__(self,ni,nf=None,act=True):\n        self.nf,self.ni = nf,ni\n        if self.nf is None: self.nf = ni\n        self.conv = conv(ni,nf,stride=2, act=False)\n        self.m = nn.Sequential(\n            sep_conv(ni,ni,act=act),\n            sep_conv(ni,nf,pool=True)\n        )\n\n    def forward(self,x): return self.conv(x) + self.m(x)\n\ndef middle_flow(nf):\n    layers = [sep_conv(nf,nf) for i in range(3)]\n    return SequentialEx(*layers, MergeLayer())\n\ndef xception(c, k=8, n_middle=8):\n    \"Preview version of Xception network. Not tested yet - use at own risk. No pretrained model yet.\"\n    layers = [\n        conv(3, k*4, 3, 2),\n        conv(k*4, k*8, 3),\n        ConvSkip(k*8, k*16, act=False),\n        ConvSkip(k*16, k*32),\n        ConvSkip(k*32, k*91),\n    ]\n    for i in range(n_middle): layers.append(middle_flow(k*91))\n    layers += [\n        ConvSkip(k*91,k*128),\n        sep_conv(k*128,k*192,act=False),\n        sep_conv(k*192,k*256),\n        nn.ReLU(),\n        nn.AdaptiveAvgPool2d(1),\n        Flatten(),\n        nn.Linear(k*256,c)\n    ]\n    return nn.Sequential(*layers)\n\n"
  },
  {
    "path": "fastai/vision/models/xresnet.py",
    "content": "import torch.nn as nn\nimport torch,math,sys\nimport torch.utils.model_zoo as model_zoo\nfrom functools import partial\nfrom ...torch_core import Module\n\n__all__ = ['XResNet', 'xresnet18', 'xresnet34', 'xresnet50', 'xresnet101', 'xresnet152']\n\n# or: ELU+init (a=0.54; gain=1.55)\nact_fn = nn.ReLU(inplace=True)\n\nclass Flatten(Module):\n    def forward(self, x): return x.view(x.size(0), -1)\n\ndef init_cnn(m):\n    if getattr(m, 'bias', None) is not None: nn.init.constant_(m.bias, 0)\n    if isinstance(m, (nn.Conv2d,nn.Linear)): nn.init.kaiming_normal_(m.weight)\n    for l in m.children(): init_cnn(l)\n\ndef conv(ni, nf, ks=3, stride=1, bias=False):\n    return nn.Conv2d(ni, nf, kernel_size=ks, stride=stride, padding=ks//2, bias=bias)\n\ndef noop(x): return x\n\ndef conv_layer(ni, nf, ks=3, stride=1, zero_bn=False, act=True):\n    bn = nn.BatchNorm2d(nf)\n    nn.init.constant_(bn.weight, 0. if zero_bn else 1.)\n    layers = [conv(ni, nf, ks, stride=stride), bn]\n    if act: layers.append(act_fn)\n    return nn.Sequential(*layers)\n\nclass ResBlock(Module):\n    def __init__(self, expansion, ni, nh, stride=1):\n        nf,ni = nh*expansion,ni*expansion\n        layers  = [conv_layer(ni, nh, 3, stride=stride),\n                   conv_layer(nh, nf, 3, zero_bn=True, act=False)\n        ] if expansion == 1 else [\n                   conv_layer(ni, nh, 1),\n                   conv_layer(nh, nh, 3, stride=stride),\n                   conv_layer(nh, nf, 1, zero_bn=True, act=False)\n        ]\n        self.convs = nn.Sequential(*layers)\n        # TODO: check whether act=True works better\n        self.idconv = noop if ni==nf else conv_layer(ni, nf, 1, act=False)\n        self.pool = noop if stride==1 else nn.AvgPool2d(2, ceil_mode=True)\n\n    def forward(self, x): return act_fn(self.convs(x) + self.idconv(self.pool(x)))\n\ndef filt_sz(recep): return min(64, 2**math.floor(math.log2(recep*0.75)))\n\nclass XResNet(nn.Sequential):\n    def __init__(self, 
expansion, layers, c_in=3, c_out=1000):\n        stem = []\n        sizes = [c_in,32,32,64]\n        for i in range(3):\n            stem.append(conv_layer(sizes[i], sizes[i+1], stride=2 if i==0 else 1))\n            #nf = filt_sz(c_in*9)\n            #stem.append(conv_layer(c_in, nf, stride=2 if i==1 else 1))\n            #c_in = nf\n\n        block_szs = [64//expansion,64,128,256,512]\n        blocks = [self._make_layer(expansion, block_szs[i], block_szs[i+1], l, 1 if i==0 else 2)\n                  for i,l in enumerate(layers)]\n        super().__init__(\n            *stem,\n            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),\n            *blocks,\n            nn.AdaptiveAvgPool2d(1), Flatten(),\n            nn.Linear(block_szs[-1]*expansion, c_out),\n        )\n        init_cnn(self)\n\n    def _make_layer(self, expansion, ni, nf, blocks, stride):\n        return nn.Sequential(\n            *[ResBlock(expansion, ni if i==0 else nf, nf, stride if i==0 else 1)\n              for i in range(blocks)])\n\ndef xresnet(expansion, n_layers, name, pretrained=False, **kwargs):\n    model = XResNet(expansion, n_layers, **kwargs)\n    if pretrained: model.load_state_dict(model_zoo.load_url(model_urls[name]))\n    return model\n\nme = sys.modules[__name__]\nfor n,e,l in [\n    [ 18 , 1, [2,2,2 ,2] ],\n    [ 34 , 1, [3,4,6 ,3] ],\n    [ 50 , 4, [3,4,6 ,3] ],\n    [ 101, 4, [3,4,23,3] ],\n    [ 152, 4, [3,8,36,3] ],\n]:\n    name = f'xresnet{n}'\n    setattr(me, name, partial(xresnet, expansion=e, n_layers=l, name=name))\n\n"
  },
  {
    "path": "fastai/vision/models/xresnet2.py",
    "content": "import torch.nn as nn\nimport torch\nimport math\nimport torch.utils.model_zoo as model_zoo\nfrom ...torch_core import Module\n\n\n__all__ = ['XResNet', 'xresnet18', 'xresnet34_2', 'xresnet50_2', 'xresnet101', 'xresnet152']\n\n\ndef conv3x3(in_planes, out_planes, stride=1):\n    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)\n\n\nclass BasicBlock(Module):\n    expansion = 1\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None):\n        super(BasicBlock, self).__init__()\n        self.conv1 = conv3x3(inplanes, planes, stride)\n        self.bn1 = nn.BatchNorm2d(planes)\n        self.relu = nn.ReLU(inplace=True)\n        self.conv2 = conv3x3(planes, planes)\n        self.bn2 = nn.BatchNorm2d(planes)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        residual = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n\n        if self.downsample is not None: residual = self.downsample(x)\n\n        out += residual\n        out = self.relu(out)\n\n        return out\n\n\nclass Bottleneck(Module):\n    expansion = 4\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None):\n        super(Bottleneck, self).__init__()\n        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)\n        self.bn1 = nn.BatchNorm2d(planes)\n        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,\n                               padding=1, bias=False)\n        self.bn2 = nn.BatchNorm2d(planes)\n        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)\n        self.bn3 = nn.BatchNorm2d(planes * self.expansion)\n        self.relu = nn.ReLU(inplace=True)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        residual = x\n\n        out 
= self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.relu(out)\n\n        out = self.conv3(out)\n        out = self.bn3(out)\n\n        if self.downsample is not None: residual = self.downsample(x)\n\n        out += residual\n        out = self.relu(out)\n\n        return out\n\ndef conv2d(ni, nf, stride):\n    return nn.Sequential(nn.Conv2d(ni, nf, kernel_size=3, stride=stride, padding=1, bias=False),\n                         nn.BatchNorm2d(nf), nn.ReLU(inplace=True))\n\nclass XResNet(Module):\n\n    def __init__(self, block, layers, c_out=1000):\n        self.inplanes = 64\n        super(XResNet, self).__init__()\n        self.conv1 = conv2d(3, 32, 2)\n        self.conv2 = conv2d(32, 32, 1)\n        self.conv3 = conv2d(32, 64, 1)\n        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n        self.layer1 = self._make_layer(block, 64, layers[0])\n        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)\n        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)\n        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)\n        self.avgpool = nn.AdaptiveAvgPool2d(1)\n        self.fc = nn.Linear(512 * block.expansion, c_out)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n            elif isinstance(m, nn.BatchNorm2d):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n\n        for m in self.modules():\n            if isinstance(m, BasicBlock): m.bn2.weight = nn.Parameter(torch.zeros_like(m.bn2.weight))\n            if isinstance(m, Bottleneck): m.bn3.weight = nn.Parameter(torch.zeros_like(m.bn3.weight))\n            if isinstance(m, nn.Linear): m.weight.data.normal_(0, 0.01)\n\n    def _make_layer(self, block, planes, blocks, stride=1):\n    
    downsample = None\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            layers = []\n            if stride==2: layers.append(nn.AvgPool2d(kernel_size=2, stride=2))\n            layers += [\n                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=1, bias=False),\n                nn.BatchNorm2d(planes * block.expansion) ]\n            downsample = nn.Sequential(*layers)\n\n        layers = []\n        layers.append(block(self.inplanes, planes, stride, downsample))\n        self.inplanes = planes * block.expansion\n        for i in range(1, blocks): layers.append(block(self.inplanes, planes))\n        return nn.Sequential(*layers)\n\n    def forward(self, x):\n        x = self.conv1(x)\n        x = self.conv2(x)\n        x = self.conv3(x)\n        x = self.maxpool(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.layer4(x)\n\n        x = self.avgpool(x)\n        x = x.view(x.size(0), -1)\n        x = self.fc(x)\n\n        return x\n\n\ndef xresnet18(pretrained=False, **kwargs):\n    \"\"\"Constructs a XResNet-18 model.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n    \"\"\"\n    model = XResNet(BasicBlock, [2, 2, 2, 2], **kwargs)\n    if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['xresnet18']))\n    return model\n\n\ndef xresnet34_2(pretrained=False, **kwargs):\n    \"\"\"Constructs a XResNet-34 model.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n    \"\"\"\n    model = XResNet(BasicBlock, [3, 4, 6, 3], **kwargs)\n    if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['xresnet34']))\n    return model\n\n\ndef xresnet50_2(pretrained=False, **kwargs):\n    \"\"\"Constructs a XResNet-50 model.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n    \"\"\"\n    model = 
XResNet(Bottleneck, [3, 4, 6, 3], **kwargs)\n    if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['xresnet50']))\n    return model\n\n\ndef xresnet101(pretrained=False, **kwargs):\n    \"\"\"Constructs a XResNet-101 model.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n    \"\"\"\n    model = XResNet(Bottleneck, [3, 4, 23, 3], **kwargs)\n    if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['xresnet101']))\n    return model\n\n\ndef xresnet152(pretrained=False, **kwargs):\n    \"\"\"Constructs a XResNet-152 model.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n    \"\"\"\n    model = XResNet(Bottleneck, [3, 8, 36, 3], **kwargs)\n    if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['xresnet152']))\n    return model\n\n"
  },
  {
    "path": "fastai/vision/transform.py",
    "content": "\"Image transformations for data augmentation. All transforms are done on the tensor level\"\nfrom ..torch_core import *\nfrom .image import *\nfrom .image import _affine_mult\n\n__all__ = ['brightness', 'contrast', 'crop', 'crop_pad', 'cutout', 'dihedral', 'dihedral_affine', 'flip_affine', 'flip_lr',\n           'get_transforms', 'jitter', 'pad', 'perspective_warp', 'rand_pad', 'rand_crop', 'rand_zoom', 'rgb_randomize', 'rotate', 'skew', 'squish',\n           'rand_resize_crop', 'symmetric_warp', 'tilt', 'zoom', 'zoom_crop']\n\n_pad_mode_convert = {'reflection':'reflect', 'zeros':'constant', 'border':'replicate'}\n\n#NB: Although TfmLighting etc can be used as decorators, that doesn't work in Windows,\n#    so we do it manually for now.\n\ndef _brightness(x, change:uniform):\n    \"Apply `change` in brightness of image `x`.\"\n    return x.add_(scipy.special.logit(change))\nbrightness = TfmLighting(_brightness)\n\ndef _contrast(x, scale:log_uniform):\n    \"Apply `scale` to contrast of image `x`.\"\n    return x.mul_(scale)\ncontrast = TfmLighting(_contrast)\n\ndef _rotate(degrees:uniform):\n    \"Rotate image by `degrees`.\"\n    angle = degrees * math.pi / 180\n    return [[float(cos(angle)), float(-sin(angle)), 0.],\n            [float(sin(angle)),  float(cos(angle)), 0.],\n            [0.        ,  0.        , 1.]]\nrotate = TfmAffine(_rotate)\n\ndef _get_zoom_mat(sw:float, sh:float, c:float, r:float)->AffineMatrix:\n    \"`sw`,`sh` scale width,height - `c`,`r` focus col,row.\"\n    return [[sw, 0,  c],\n            [0, sh,  r],\n            [0,  0, 1.]]\n\ndef _zoom(scale:uniform=1.0, row_pct:uniform=0.5, col_pct:uniform=0.5):\n    \"Zoom image by `scale`. 
`row_pct`,`col_pct` select focal point of zoom.\"\n    s = 1-1/scale\n    col_c = s * (2*col_pct - 1)\n    row_c = s * (2*row_pct - 1)\n    return _get_zoom_mat(1/scale, 1/scale, col_c, row_c)\nzoom = TfmAffine(_zoom)\n\ndef _squish(scale:uniform=1.0, row_pct:uniform=0.5, col_pct:uniform=0.5):\n    \"Squish image by `scale`. `row_pct`,`col_pct` select focal point of zoom.\"\n    if scale <= 1:\n        col_c = (1-scale) * (2*col_pct - 1)\n        return _get_zoom_mat(scale, 1, col_c, 0.)\n    else:\n        row_c = (1-1/scale) * (2*row_pct - 1)\n        return _get_zoom_mat(1, 1/scale, 0., row_c)\nsquish = TfmAffine(_squish)\n\ndef _jitter(c, magnitude:uniform):\n    \"Replace pixels by random neighbors at `magnitude`.\"\n    c.flow.add_((torch.rand_like(c.flow)-0.5)*magnitude*2)\n    return c\njitter = TfmCoord(_jitter)\n\ndef _flip_lr(x):\n    \"Flip `x` horizontally.\"\n    #return x.flip(2)\n    if isinstance(x, ImagePoints):\n        x.flow.flow[...,0] *= -1\n        return x\n    return tensor(np.ascontiguousarray(np.array(x)[...,::-1]))\nflip_lr = TfmPixel(_flip_lr)\n\ndef _flip_affine() -> TfmAffine:\n    \"Flip `x` horizontally.\"\n    return [[-1, 0, 0.],\n            [0,  1, 0],\n            [0,  0, 1.]]\nflip_affine = TfmAffine(_flip_affine)\n\ndef _dihedral(x, k:partial(uniform_int,0,7)):\n    \"Randomly flip `x` image based on `k`.\"\n    flips=[]\n    if k&1: flips.append(1)\n    if k&2: flips.append(2)\n    if flips: x = torch.flip(x,flips)\n    if k&4: x = x.transpose(1,2)\n    return x.contiguous()\ndihedral = TfmPixel(_dihedral)\n\ndef _dihedral_affine(k:partial(uniform_int,0,7)):\n    \"Randomly flip `x` image based on `k`.\"\n    x = -1 if k&1 else 1\n    y = -1 if k&2 else 1\n    if k&4: return [[0, x, 0.],\n                    [y, 0, 0],\n                    [0, 0, 1.]]\n    return [[x, 0, 0.],\n            [0, y, 0],\n            [0, 0, 1.]]\ndihedral_affine = TfmAffine(_dihedral_affine)\n\ndef _pad_coord(x, row_pad:int, col_pad:int, 
mode='zeros'):\n    #TODO: implement other padding modes than zeros?\n    h,w = x.size\n    pad = torch.Tensor([w/(w + 2*col_pad), h/(h + 2*row_pad)])\n    x.flow = FlowField((h+2*row_pad, w+2*col_pad) , x.flow.flow * pad[None])\n    return x\n\ndef _pad_default(x, padding:int, mode='reflection'):\n    \"Pad `x` with `padding` pixels. `mode` fills in space ('zeros','reflection','border').\"\n    mode = _pad_mode_convert[mode]\n    return F.pad(x[None], (padding,)*4, mode=mode)[0]\n\ndef _pad_image_points(x, padding:int, mode='reflection'):\n    return _pad_coord(x, padding, padding, mode)\n\ndef _pad(x, padding:int, mode='reflection'):\n    f_pad = _pad_image_points if isinstance(x, ImagePoints) else  _pad_default\n    return f_pad(x, padding, mode)\n\npad = TfmPixel(_pad, order=-10)\n\ndef _cutout(x, n_holes:uniform_int=1, length:uniform_int=40):\n    \"Cut out `n_holes` number of square holes of size `length` in image at random locations.\"\n    h,w = x.shape[1:]\n    for n in range(n_holes):\n        h_y = np.random.randint(0, h)\n        h_x = np.random.randint(0, w)\n        y1 = int(np.clip(h_y - length / 2, 0, h))\n        y2 = int(np.clip(h_y + length / 2, 0, h))\n        x1 = int(np.clip(h_x - length / 2, 0, w))\n        x2 = int(np.clip(h_x + length / 2, 0, w))\n        x[:, y1:y2, x1:x2] = 0\n    return x\n\ncutout = TfmPixel(_cutout, order=20)\n\ndef _rgb_randomize(x, channel:int=None, thresh:float=0.3):\n    \"Randomize one of the channels of the input image\"\n    if channel is None: channel = np.random.randint(0, x.shape[0] - 1)\n    x[channel] = torch.rand(x.shape[1:]) * np.random.uniform(0, thresh)\n    return x\n\nrgb_randomize = TfmPixel(_rgb_randomize)\n\ndef _minus_epsilon(row_pct:float, col_pct:float, eps:float=1e-7):\n    if row_pct==1.: row_pct -= 1e-7\n    if col_pct==1.: col_pct -= 1e-7\n    return row_pct,col_pct\n\ndef _crop_default(x, size, row_pct:uniform=0.5, col_pct:uniform=0.5):\n    \"Crop `x` to `size` pixels. 
`row_pct`,`col_pct` select focal point of crop.\"\n    rows,cols = tis2hw(size)\n    row_pct,col_pct = _minus_epsilon(row_pct,col_pct)\n    row = int((x.size(1)-rows+1) * row_pct)\n    col = int((x.size(2)-cols+1) * col_pct)\n    return x[:, row:row+rows, col:col+cols].contiguous()\n\ndef _crop_image_points(x, size, row_pct=0.5, col_pct=0.5):\n    h,w = x.size\n    rows,cols = tis2hw(size)\n    row_pct,col_pct = _minus_epsilon(row_pct,col_pct)\n    x.flow.flow.mul_(torch.Tensor([w/cols, h/rows])[None])\n    row = int((h-rows+1) * row_pct)\n    col = int((w-cols+1) * col_pct)\n    x.flow.flow.add_(-1 + torch.Tensor([w/cols-2*col/cols, h/rows-2*row/rows])[None])\n    x.size = (rows, cols)\n    return x\n\ndef _crop(x, size, row_pct:uniform=0.5, col_pct:uniform=0.5):\n    f_crop = _crop_image_points if isinstance(x, ImagePoints) else _crop_default\n    return f_crop(x, size, row_pct, col_pct)\n\ncrop = TfmPixel(_crop)\n\ndef _crop_pad_default(x, size, padding_mode='reflection', row_pct:uniform = 0.5, col_pct:uniform = 0.5):\n    \"Crop and pad tfm - `row_pct`,`col_pct` sets focal point.\"\n    padding_mode = _pad_mode_convert[padding_mode]\n    size = tis2hw(size)\n    if x.shape[1:] == torch.Size(size): return x\n    rows,cols = size\n    row_pct,col_pct = _minus_epsilon(row_pct,col_pct)\n    if x.size(1)<rows or x.size(2)<cols:\n        row_pad = max((rows-x.size(1)+1)//2, 0)\n        col_pad = max((cols-x.size(2)+1)//2, 0)\n        x = F.pad(x[None], (col_pad,col_pad,row_pad,row_pad), mode=padding_mode)[0]\n    row = int((x.size(1)-rows+1)*row_pct)\n    col = int((x.size(2)-cols+1)*col_pct)\n    x = x[:, row:row+rows, col:col+cols]\n    return x.contiguous() # without this, get NaN later - don't know why\n\ndef _crop_pad_image_points(x, size, padding_mode='reflection', row_pct = 0.5, col_pct = 0.5):\n    size = tis2hw(size)\n    rows,cols = size\n    if x.size[0]<rows or x.size[1]<cols:\n        row_pad = max((rows-x.size[0]+1)//2, 0)\n        col_pad = 
max((cols-x.size[1]+1)//2, 0)\n        x = _pad_coord(x, row_pad, col_pad)\n    return crop(x,(rows,cols), row_pct, col_pct)\n\ndef _crop_pad(x, size, padding_mode='reflection', row_pct:uniform = 0.5, col_pct:uniform = 0.5):\n    f_crop_pad = _crop_pad_image_points if isinstance(x, ImagePoints) else _crop_pad_default\n    return f_crop_pad(x, size, padding_mode, row_pct, col_pct)\n\ncrop_pad = TfmCrop(_crop_pad)\n\ndef _image_maybe_add_crop_pad(img, tfms):\n    tfm_names = [tfm.__name__ for tfm in tfms]\n    return [crop_pad()] + tfms if 'crop_pad' not in tfm_names else tfms\nImage._maybe_add_crop_pad = _image_maybe_add_crop_pad\n\nrand_pos = {'row_pct':(0,1), 'col_pct':(0,1)}\n\ndef rand_pad(padding:int, size:int, mode:str='reflection'):\n    \"Fixed `mode` `padding` and random crop of `size`\"\n    return [pad(padding=padding,mode=mode),\n            crop(size=size, **rand_pos)]\n\ndef rand_zoom(scale:uniform=1.0, p:float=1.):\n    \"Randomized version of `zoom`.\"\n    return zoom(scale=scale, **rand_pos, p=p)\n\ndef rand_crop(*args, padding_mode='reflection', p:float=1.):\n    \"Randomized version of `crop_pad`.\"\n    return crop_pad(*args, **rand_pos, padding_mode=padding_mode, p=p)\n\ndef zoom_crop(scale:float, do_rand:bool=False, p:float=1.0):\n    \"Randomly zoom and/or crop.\"\n    zoom_fn = rand_zoom if do_rand else zoom\n    crop_fn = rand_crop if do_rand else crop_pad\n    return [zoom_fn(scale=scale, p=p), crop_fn()]\n\ndef _find_coeffs(orig_pts:Points, targ_pts:Points)->Tensor:\n    \"Find 8 coeff mentioned [here](https://web.archive.org/web/20150222120106/xenia.media.mit.edu/~cwren/interpolator/).\"\n    matrix = []\n    #The equations we'll need to solve.\n    for p1, p2 in zip(targ_pts, orig_pts):\n        matrix.append([p1[0], p1[1], 1, 0, 0, 0, -p2[0]*p1[0], -p2[0]*p1[1]])\n        matrix.append([0, 0, 0, p1[0], p1[1], 1, -p2[1]*p1[0], -p2[1]*p1[1]])\n\n    A = FloatTensor(matrix)\n    B = FloatTensor(orig_pts).view(8, 1)\n    #The 8 scalars we 
seek are solution of AX = B\n    return torch.linalg.solve(A,B)[:,0]\n\ndef _apply_perspective(coords:FlowField, coeffs:Points)->FlowField:\n    \"Transform `coords` with `coeffs`.\"\n    size = coords.flow.size()\n    #compress all the dims expect the last one ang adds ones, coords become N * 3\n    coords.flow = coords.flow.view(-1,2)\n    #Transform the coeffs in a 3*3 matrix with a 1 at the bottom left\n    coeffs = torch.cat([coeffs, FloatTensor([1])]).view(3,3)\n    coords.flow = torch.addmm(coeffs[:,2], coords.flow, coeffs[:,:2].t())\n    coords.flow.mul_(1/coords.flow[:,2].unsqueeze(1))\n    coords.flow = coords.flow[:,:2].view(size)\n    return coords\n\n_orig_pts = [[-1,-1], [-1,1], [1,-1], [1,1]]\n\ndef _do_perspective_warp(c:FlowField, targ_pts:Points, invert=False):\n    \"Apply warp to `targ_pts` from `_orig_pts` to `c` `FlowField`.\"\n    if invert: return _apply_perspective(c, _find_coeffs(targ_pts, _orig_pts))\n    return _apply_perspective(c, _find_coeffs(_orig_pts, targ_pts))\n\ndef _perspective_warp(c, magnitude:partial(uniform,size=8)=0, invert=False):\n    \"Apply warp of `magnitude` to `c`.\"\n    magnitude = magnitude.view(4,2)\n    targ_pts = [[x+m for x,m in zip(xs, ms)] for xs, ms in zip(_orig_pts, magnitude)]\n    return _do_perspective_warp(c, targ_pts, invert)\nperspective_warp = TfmCoord(_perspective_warp)\n\ndef _symmetric_warp(c, magnitude:partial(uniform,size=4)=0, invert=False):\n    \"Apply symmetric warp of `magnitude` to `c`.\"\n    m = listify(magnitude, 4)\n    targ_pts = [[-1-m[3],-1-m[1]], [-1-m[2],1+m[1]], [1+m[3],-1-m[0]], [1+m[2],1+m[0]]]\n    return _do_perspective_warp(c, targ_pts, invert)\nsymmetric_warp = TfmCoord(_symmetric_warp)\n\ndef _tilt(c, direction:uniform_int, magnitude:uniform=0, invert=False):\n    \"Tilt `c` field with random `direction` and `magnitude`.\"\n    orig_pts = [[-1,-1], [-1,1], [1,-1], [1,1]]\n    if direction == 0:   targ_pts = [[-1,-1], [-1,1], [1,-1-magnitude], [1,1+magnitude]]\n    elif 
direction == 1: targ_pts = [[-1,-1-magnitude], [-1,1+magnitude], [1,-1], [1,1]]\n    elif direction == 2: targ_pts = [[-1,-1], [-1-magnitude,1], [1,-1], [1+magnitude,1]]\n    elif direction == 3: targ_pts = [[-1-magnitude,-1], [-1,1], [1+magnitude,-1], [1,1]]\n    coeffs = _find_coeffs(targ_pts, _orig_pts) if invert else _find_coeffs(_orig_pts, targ_pts)\n    return _apply_perspective(c, coeffs)\ntilt = TfmCoord(_tilt)\n\ndef _skew(c, direction:uniform_int, magnitude:uniform=0, invert=False):\n    \"Skew `c` field with random `direction` and `magnitude`.\"\n    orig_pts = [[-1,-1], [-1,1], [1,-1], [1,1]]\n    if direction == 0:   targ_pts = [[-1-magnitude,-1], [-1,1], [1,-1], [1,1]]\n    elif direction == 1: targ_pts = [[-1,-1-magnitude], [-1,1], [1,-1], [1,1]]\n    elif direction == 2: targ_pts = [[-1,-1], [-1-magnitude,1], [1,-1], [1,1]]\n    elif direction == 3: targ_pts = [[-1,-1], [-1,1+magnitude], [1,-1], [1,1]]\n    elif direction == 4: targ_pts = [[-1,-1], [-1,1], [1+magnitude,-1], [1,1]]\n    elif direction == 5: targ_pts = [[-1,-1], [-1,1], [1,-1-magnitude], [1,1]]\n    elif direction == 6: targ_pts = [[-1,-1], [-1,1], [1,-1], [1+magnitude,1]]\n    elif direction == 7: targ_pts = [[-1,-1], [-1,1], [1,-1], [1,1+magnitude]]\n    coeffs = _find_coeffs(targ_pts, _orig_pts) if invert else _find_coeffs(_orig_pts, targ_pts)\n    return _apply_perspective(c, coeffs)\nskew = TfmCoord(_skew)\n\ndef get_transforms(do_flip:bool=True, flip_vert:bool=False, max_rotate:float=10., max_zoom:float=1.1,\n                   max_lighting:float=0.2, max_warp:float=0.2, p_affine:float=0.75,\n                   p_lighting:float=0.75, xtra_tfms:Optional[Collection[Transform]]=None)->Collection[Transform]:\n    \"Utility func to easily create a list of flip, rotate, `zoom`, warp, lighting transforms.\"\n    res = [rand_crop()]\n    if do_flip:    res.append(dihedral_affine() if flip_vert else flip_lr(p=0.5))\n    if max_warp:   
res.append(symmetric_warp(magnitude=(-max_warp,max_warp), p=p_affine))\n    if max_rotate: res.append(rotate(degrees=(-max_rotate,max_rotate), p=p_affine))\n    if max_zoom>1: res.append(rand_zoom(scale=(1.,max_zoom), p=p_affine))\n    if max_lighting:\n        res.append(brightness(change=(0.5*(1-max_lighting), 0.5*(1+max_lighting)), p=p_lighting))\n        res.append(contrast(scale=(1-max_lighting, 1/(1-max_lighting)), p=p_lighting))\n    #       train                   , valid\n    return (res + listify(xtra_tfms), [crop_pad()])\n\ndef _compute_zs_mat(sz:TensorImageSize, scale:float, squish:float,\n                   invert:bool, row_pct:float, col_pct:float)->AffineMatrix:\n    \"Utility routine to compute zoom/squish matrix.\"\n    orig_ratio = math.sqrt(sz[1]/sz[0])\n    for s,r,i in zip(scale,squish, invert):\n        s,r = 1/math.sqrt(s),math.sqrt(r)\n        if s * r <= 1 and s / r <= 1: #Test if we are completely inside the picture\n            w,h = (s/r, s*r) if i else (s*r,s/r)\n            col_c = (1-w) * (2*col_pct - 1)\n            row_c = (1-h) * (2*row_pct - 1)\n            return _get_zoom_mat(w, h, col_c, row_c)\n\n    #Fallback, hack to emulate a center crop without cropping anything yet.\n    if orig_ratio > 1: return _get_zoom_mat(1/orig_ratio**2, 1, 0, 0.)\n    else:              return _get_zoom_mat(1, orig_ratio**2, 0, 0.)\n\ndef _zoom_squish(c, scale:uniform=1.0, squish:uniform=1.0, invert:rand_bool=False,\n                row_pct:uniform=0.5, col_pct:uniform=0.5):\n    #This is intended for scale, squish and invert to be of size 10 (or whatever) so that the transform\n    #can try a few zoom/squishes before falling back to center crop (like torchvision.RandomResizedCrop)\n    m = _compute_zs_mat(c.size, scale, squish, invert, row_pct, col_pct)\n    return _affine_mult(c, FloatTensor(m))\nzoom_squish = TfmCoord(_zoom_squish)\n\ndef rand_resize_crop(size:int, max_scale:float=2., ratios:Tuple[float,float]=(0.75,1.33)):\n    \"Randomly 
resize and crop the image to a ratio in `ratios` after a zoom of `max_scale`.\"\n    return [zoom_squish(scale=(1.,max_scale,8), squish=(*ratios,8), invert=(0.5,8), row_pct=(0.,1.), col_pct=(0.,1.)),\n            crop(size=size)]\n"
  },
  {
    "path": "fastai/vision/tta.py",
    "content": "\"Brings TTA (Test Time Functionality) to the `Learner` class. Use `learner.TTA()` instead\"\nfrom ..torch_core import *\nfrom ..basic_train import *\nfrom ..basic_train import _loss_func2activ\nfrom ..basic_data import DatasetType\nfrom .transform import *\n\n__all__ = []\n\ndef _tta_only(learn:Learner, ds_type:DatasetType=DatasetType.Valid, activ:nn.Module=None, scale:float=1.35) -> Iterator[List[Tensor]]:\n    \"Computes the outputs for several augmented inputs for TTA\"\n    dl = learn.dl(ds_type)\n    ds = dl.dataset\n    old = ds.tfms\n    activ = ifnone(activ, _loss_func2activ(learn.loss_func))\n    augm_tfm = [o for o in learn.data.train_ds.tfms if o.tfm not in\n               (crop_pad, flip_lr, dihedral, zoom)]\n    try:\n        pbar = master_bar(range(8))\n        for i in pbar:\n            row = 1 if i&1 else 0\n            col = 1 if i&2 else 0\n            flip = i&4\n            d = {'row_pct':row, 'col_pct':col, 'is_random':False}\n            tfm = [*augm_tfm, zoom(scale=scale, **d), crop_pad(**d)]\n            if flip: tfm.append(flip_lr(p=1.))\n            ds.tfms = tfm\n            yield get_preds(learn.model, dl, pbar=pbar, activ=activ)[0]\n    finally: ds.tfms = old\n\nLearner.tta_only = _tta_only\n\ndef _TTA(learn:Learner, beta:float=0.4, scale:float=1.35, ds_type:DatasetType=DatasetType.Valid, activ:nn.Module=None, with_loss:bool=False) -> Tensors:\n    \"Applies TTA to predict on `ds_type` dataset.\"\n    preds,y = learn.get_preds(ds_type, activ=activ)\n    all_preds = list(learn.tta_only(ds_type=ds_type, activ=activ, scale=scale))\n    avg_preds = torch.stack(all_preds).mean(0)\n    if beta is None: return preds,avg_preds,y\n    else:\n        final_preds = preds*beta + avg_preds*(1-beta)\n        if with_loss:\n            with NoneReduceOnCPU(learn.loss_func) as lf: loss = lf(final_preds, y)\n            return final_preds, y, loss\n        return final_preds, y\n\nLearner.TTA = _TTA\n"
  },
  {
    "path": "fastai/widgets/__init__.py",
    "content": "from .class_confusion import *\nfrom .image_cleaner import *\nfrom .image_downloader import *\n"
  },
  {
    "path": "fastai/widgets/class_confusion.py",
    "content": "import math\nimport pandas as pd\nimport matplotlib.pyplot as plt\nfrom tqdm import tqdm\nfrom itertools import permutations\nfrom ..tabular import TabularDataBunch\nfrom ..train import ClassificationInterpretation\nimport ipywidgets as widgets\n\nclass ClassConfusion():\n    \"Plot the most confused datapoints and statistics for the models misses.\" \n    def __init__(self, interp:ClassificationInterpretation, classlist:list, \n               is_ordered:bool=False, cut_off:int=100, varlist:list=None,\n               figsize:tuple=(8,8)):\n        self.interp = interp\n        self._is_tab = isinstance(interp.learn.data, TabularDataBunch)\n        if self._is_tab:\n            if interp.learn.data.train_ds.x.cont_names != []: \n                for x in range(len(interp.learn.data.procs)):\n                      if \"Normalize\" in str(interp.learn.data.procs[x]):\n                            self.means = interp.learn.data.train_ds.x.processor[0].procs[x].means\n                            self.stds = interp.learn.data.train_ds.x.processor[0].procs[x].stds\n        self.is_ordered = is_ordered\n        self.cut_off = cut_off\n        self.figsize = figsize\n        self.varlist = varlist\n        self.classl = classlist\n        self._show_losses(classlist)\n        \n    def _show_losses(self, classl:list, **kwargs):\n        \"Checks if the model is for Tabular or Images and gathers top losses\"\n        _, self.tl_idx = self.interp.top_losses(len(self.interp.losses))\n        self._tab_losses() if self._is_tab else self._create_tabs()\n        \n    def _create_tabs(self):\n        \"Creates a tab for each variable\"\n        self.lis = self.classl if self.is_ordered else list(permutations(self.classl, 2))\n        if self._is_tab:\n            self._boxes = len(self.df_list)\n            self._cols = math.ceil(math.sqrt(self._boxes))\n            self._rows = math.ceil(self._boxes/self._cols)\n            self.tbnames = 
list(self.df_list[0].columns)[:-1] if self.varlist is None else self.varlist\n        else:\n            vals = self.interp.most_confused()\n            self._ranges = []\n            self.tbnames = []\n            self._boxes = int(input('Please enter a value for `k`, or the top images you will see: '))\n            for x in iter(vals):\n                for y in range(len(self.lis)):\n                    if x[0:2] == self.lis[y]:\n                        self._ranges.append(x[2])\n                        self.tbnames.append(str(x[0] + ' | ' + x[1]))\n        items = [widgets.Output() for i, tab in enumerate(self.tbnames)]\n        self.tabs = widgets.Tab()\n        self.tabs.children = items\n        for i in range(len(items)):\n            self.tabs.set_title(i, self.tbnames[i])\n        self._populate_tabs()\n        \n    def _populate_tabs(self):\n        \"Adds relevant graphs to each tab\"\n        with tqdm(total=len(self.tbnames)) as pbar:\n            for i, tab in enumerate(self.tbnames):\n                with self.tabs.children[i]:\n                    self._plot_tab(tab) if self._is_tab else self._plot_imgs(tab, i)\n                pbar.update(1)\n        display(self.tabs)\n        \n    def _plot_tab(self, tab:str):\n        \"Generates graphs\"\n        if self._boxes is not None:\n            fig, ax = plt.subplots(self._boxes, figsize=self.figsize)\n        else:\n            fig, ax = plt.subplots(self._cols, self._rows, figsize=self.figsize)\n        fig.subplots_adjust(hspace=.5)\n        for j, x in enumerate(self.df_list):\n            title = f'{\"\".join(x.columns[-1])} {tab} distribution'\n            \n            if self._boxes is None:\n                row = int(j / self._cols)\n                col = j % row\n            if tab in self.cat_names:\n                vals = pd.value_counts(x[tab].values)\n                if self._boxes is not None:\n                    if vals.nunique() < 10:\n                        fig = 
vals.plot(kind='bar', title=title,  ax=ax[j], rot=0, width=.75)\n                    elif vals.nunique() > self.cut_off:\n                        print(f'Number of values is above {self.cut_off}')\n                    else:\n                        fig = vals.plot(kind='barh', title=title,  ax=ax[j], width=.75)   \n                else:\n                    fig = vals.plot(kind='barh', title=title,  ax=ax[row, col], width=.75)\n            else:\n                vals = x[tab]\n                if self._boxes is not None:\n                    axs = vals.plot(kind='hist', ax=ax[j], title=title, y='Frequency')\n                else:\n                    axs = vals.plot(kind='hist', ax=ax[row, col], title=title, y='Frequency')\n                axs.set_ylabel('Frequency')\n                if len(set(vals)) > 1:\n                    vals.plot(kind='kde', ax=axs, title=title, secondary_y=True)\n                else:\n                    print('Less than two unique values, cannot graph the KDE')\n        plt.show(fig)\n        plt.tight_layout()\n\n    def _plot_imgs(self, tab:str, i:int ,**kwargs):\n        \"Plots the most confused images\"\n        classes_gnd = self.interp.data.classes\n        x = 0\n        if self._ranges[i] < self._boxes:\n            cols = math.ceil(math.sqrt(self._ranges[i]))\n            rows = math.ceil(self._ranges[i]/cols)\n        if self._ranges[i] < 4 or self._boxes < 4:\n            cols = 2\n            rows = 2\n        else:\n            cols = math.ceil(math.sqrt(self._boxes))\n            rows = math.ceil(self._boxes/cols)\n        fig, ax = plt.subplots(rows, cols, figsize=self.figsize)\n        [axi.set_axis_off() for axi in ax.ravel()]\n        for j, idx in enumerate(self.tl_idx):\n            if self._boxes < x+1 or x > self._ranges[i]:\n                break\n            da, cl = self.interp.data.dl(self.interp.ds_type).dataset[idx]\n            row = (int)(x / cols)\n            col = x % cols\n            if str(cl) == 
tab.split(' ')[0] and str(classes_gnd[self.interp.pred_class[idx]]) == tab.split(' ')[2]:\n                img, lbl = self.interp.data.valid_ds[idx]\n                fn = self.interp.data.valid_ds.x.items[idx]\n                fn = re.search('([^/*]+)_\\d+.*$', str(fn)).group(0)\n                img.show(ax=ax[row, col])\n                ax[row,col].set_title(fn)\n                x += 1\n        plt.show(fig)\n        plt.tight_layout()\n\n    def _tab_losses(self, **kwargs):\n        \"Gathers dataframes of the combinations data\"\n        classes = self.interp.data.classes\n        cat_names = self.interp.data.x.cat_names\n        cont_names = self.interp.data.x.cont_names\n        comb = self.classl if self.is_ordered else list(permutations(self.classl,2))\n        self.df_list = []\n        arr = []\n        for i, idx in enumerate(self.tl_idx):\n            da, _ = self.interp.data.dl(self.interp.ds_type).dataset[idx]\n            res = ''\n            for c, n in zip(da.cats, da.names[:len(da.cats)]):\n                string = f'{da.classes[n][c]}'\n                if string == 'True' or string == 'False':\n                    string += ';'\n                    res += string\n                else:\n                    string = string[1:]\n                    res += string + ';'\n            for c, n in zip(da.conts, da.names[len(da.cats):]):\n                res += f'{c:.4f};'\n            arr.append(res)\n        f = pd.DataFrame([ x.split(';')[:-1] for x in arr], columns=da.names)\n        for i, var in enumerate(self.interp.data.cont_names):\n            f[var] = f[var].apply(lambda x: float(x) * self.stds[var] + self.means[var])\n        f['Original'] = 'Original'\n        self.df_list.append(f)\n        for j, x in enumerate(comb):\n            arr = []\n            for i, idx in enumerate(self.tl_idx):\n                da, cl = self.interp.data.dl(self.interp.ds_type).dataset[idx]\n                cl = int(cl)\n                if 
classes[self.interp.pred_class[idx]] == comb[j][0] and classes[cl] == comb[j][1]:\n                    res = ''\n                    for c, n in zip(da.cats, da.names[:len(da.cats)]):\n                        string = f'{da.classes[n][c]}'\n                        if string == 'True' or string == 'False':\n                            string += ';'\n                            res += string\n                        else:\n                            string = string[1:]\n                            res += string + ';'\n                    for c, n in zip(da.conts, da.names[len(da.cats):]):\n                        res += f'{c:.4f};'\n                    arr.append(res)      \n            f = pd.DataFrame([ x.split(';')[:-1] for x in arr], columns=da.names)\n            for i, var in enumerate(self.interp.data.cont_names):\n                f[var] = f[var].apply(lambda x: float(x) * self.stds[var] + self.means[var])\n            f[str(x)] = str(x)\n            self.df_list.append(f)\n        self.cat_names = cat_names\n        self._create_tabs()\n"
  },
  {
    "path": "fastai/widgets/image_cleaner.py",
    "content": "from ..torch_core import *\nfrom ..basic_train import *\nfrom ..basic_data import *\nfrom ..vision.data import *\nfrom ..vision.transform import *\nfrom ..vision.image import *\nfrom ..callbacks.hooks import *\nfrom ..layers import *\nfrom ipywidgets import widgets, Layout\nfrom IPython.display import clear_output, display\n\n__all__ = ['DatasetFormatter', 'ImageCleaner']\n\nclass DatasetFormatter():\n    \"Returns a dataset with the appropriate format and file indices to be displayed.\"\n    @classmethod\n    def from_toplosses(cls, learn, n_imgs=None, **kwargs):\n        \"Gets indices with top losses.\"\n        train_ds, train_idxs = cls.get_toplosses_idxs(learn, n_imgs, **kwargs)\n        return train_ds, train_idxs\n\n    @classmethod\n    def get_toplosses_idxs(cls, learn, n_imgs, **kwargs):\n        \"Sorts `ds_type` dataset by top losses and returns dataset and sorted indices.\"\n        dl = learn.data.fix_dl\n        if not n_imgs: n_imgs = len(dl.dataset)\n        _,_,top_losses = learn.get_preds(ds_type=DatasetType.Fix, with_loss=True)\n        idxs = torch.topk(top_losses, n_imgs)[1]\n        return cls.padded_ds(dl.dataset, **kwargs), idxs\n\n    def padded_ds(ll_input, size=(250, 300), resize_method=ResizeMethod.CROP, padding_mode='zeros', **kwargs):\n        \"For a LabelList `ll_input`, resize each image to `size` using `resize_method` and `padding_mode`.\"\n        return ll_input.transform(tfms=crop_pad(), size=size, resize_method=resize_method, padding_mode=padding_mode)\n    \n    @classmethod\n    def from_similars(cls, learn, layer_ls:list=[0, 7, 2], **kwargs):\n        \"Gets the indices for the most similar images.\"\n        train_ds, train_idxs = cls.get_similars_idxs(learn, layer_ls, **kwargs)\n        return train_ds, train_idxs\n\n    @classmethod\n    def get_similars_idxs(cls, learn, layer_ls, **kwargs):\n        \"Gets the indices for the most similar images in `ds_type` dataset\"\n        hook = 
hook_output(learn.model[layer_ls[0]][layer_ls[1]][layer_ls[2]])\n        dl = learn.data.fix_dl\n\n        ds_actns = cls.get_actns(learn, hook=hook, dl=dl, **kwargs)\n        similarities = cls.comb_similarity(ds_actns, ds_actns, **kwargs)\n        idxs = cls.sort_idxs(similarities)\n        return cls.padded_ds(dl, **kwargs), idxs\n\n    @staticmethod\n    def get_actns(learn, hook:Hook, dl:DataLoader, pool=AdaptiveConcatPool2d, pool_dim:int=4, **kwargs):\n        \"Gets activations at the layer specified by `hook`, applies `pool` of dim `pool_dim` and concatenates\"\n        print('Getting activations...')\n\n        actns = []\n        learn.model.eval()\n        with torch.no_grad():\n            for (xb,yb) in progress_bar(dl):\n                learn.model(xb)\n                actns.append((hook.stored).cpu())\n\n        if pool:\n            pool = pool(pool_dim)\n            return pool(torch.cat(actns)).view(len(dl.x),-1)\n        else: return torch.cat(actns).view(len(dl.x),-1)\n\n\n    @staticmethod\n    def comb_similarity(t1: torch.Tensor, t2: torch.Tensor, **kwargs):\n        # https://github.com/pytorch/pytorch/issues/11202\n        \"Computes the similarity function between each embedding of `t1` and `t2` matrices.\"\n        print('Computing similarities...')\n\n        w1 = t1.norm(p=2, dim=1, keepdim=True)\n        w2 = w1 if t2 is t1 else t2.norm(p=2, dim=1, keepdim=True)\n\n        t = torch.mm(t1, t2.t()) / (w1 * w2.t()).clamp(min=1e-8)\n        return torch.tril(t, diagonal=-1) \n\n    def largest_indices(arr, n):\n        \"Returns the `n` largest indices from a numpy array `arr`.\"\n        #https://stackoverflow.com/questions/6910641/how-do-i-get-indices-of-n-maximum-values-in-a-numpy-array\n        flat = arr.flatten()\n        indices = np.argpartition(flat, -n)[-n:]\n        indices = indices[np.argsort(-flat[indices])]\n        return np.unravel_index(indices, arr.shape)\n\n    @classmethod\n    def sort_idxs(cls, similarities):\n      
  \"Sorts `similarities` and return the indexes in pairs ordered by highest similarity.\"\n        idxs = cls.largest_indices(similarities, len(similarities))\n        idxs = [(idxs[0][i], idxs[1][i]) for i in range(len(idxs[0]))]\n        return [e for l in idxs for e in l]\n\nclass ImageCleaner():\n    \"Displays images for relabeling or deletion and saves changes in `path` as 'cleaned.csv'.\"\n    def __init__(self, dataset, fns_idxs, path, batch_size:int=5, duplicates=False):\n        self._all_images,self._batch = [],[]\n        self._path = Path(path)\n        self._batch_size = batch_size\n        if duplicates: self._batch_size = 2\n        self._duplicates = duplicates\n        self._labels = dataset.classes\n        self._all_images = self.create_image_list(dataset, fns_idxs)\n        self._csv_dict = {dataset.x.items[i]: dataset.y[i] for i in range(len(dataset))}\n        self._deleted_fns = []\n        self._skipped = 0\n        self.render()\n\n    @classmethod\n    def make_img_widget(cls, img, layout=Layout(), format='jpg'):\n        \"Returns an image widget for specified file name `img`.\"\n        return widgets.Image(value=img, format=format, layout=layout)\n\n    @classmethod\n    def make_button_widget(cls, label, file_path=None, handler=None, style=None, layout=Layout(width='auto')):\n        \"Return a Button widget with specified `handler`.\"\n        btn = widgets.Button(description=label, layout=layout)\n        if handler is not None: btn.on_click(handler)\n        if style is not None: btn.button_style = style\n        btn.file_path = file_path\n        btn.flagged_for_delete = False\n        return btn\n\n    @classmethod\n    def make_dropdown_widget(cls, description='Description', options=['Label 1', 'Label 2'], value='Label 1',\n                            file_path=None, layout=Layout(), handler=None):\n        \"Return a Dropdown widget with specified `handler`.\"\n        dd = widgets.Dropdown(description=description, 
options=options, value=value, layout=layout)\n        if file_path is not None: dd.file_path = file_path\n        if handler is not None: dd.observe(handler, names=['value'])\n        return dd\n\n    @classmethod\n    def make_horizontal_box(cls, children, layout=Layout()):\n        \"Make a horizontal box with `children` and `layout`.\"\n        return widgets.HBox(children, layout=layout)\n\n    @classmethod\n    def make_vertical_box(cls, children, layout=Layout(), duplicates=False):\n        \"Make a vertical box with `children` and `layout`.\"\n        if not duplicates: return widgets.VBox(children, layout=layout)\n        else: return widgets.VBox([children[0], children[2]], layout=layout)\n\n    def create_image_list(self, dataset, fns_idxs):\n        \"Create a list of images, filenames and labels but first removing files that are not supposed to be displayed.\"\n        items = dataset.x.items\n        if self._duplicates:\n            chunked_idxs = chunks(fns_idxs, 2)\n            chunked_idxs = [chunk for chunk in chunked_idxs if Path(items[chunk[0]]).is_file() and Path(items[chunk[1]]).is_file()]\n            return  [(dataset.x[i]._repr_jpeg_(), items[i], self._labels[dataset.y[i].data]) for chunk in chunked_idxs for i in chunk]\n        else:\n            return [(dataset.x[i]._repr_jpeg_(), items[i], self._labels[dataset.y[i].data]) for i in fns_idxs if\n                    Path(items[i]).is_file()]\n\n    def relabel(self, change):\n        \"Relabel images by moving from parent dir with old label `class_old` to parent dir with new label `class_new`.\"\n        class_new,class_old,file_path = change.new,change.old,change.owner.file_path\n        fp = Path(file_path)\n        parent = fp.parents[1]\n        self._csv_dict[fp] = class_new\n\n    def next_batch(self, _):\n        \"Handler for 'Next Batch' button click. 
Delete all flagged images and renders next batch.\"\n        for img_widget, delete_btn, fp, in self._batch:\n            fp = delete_btn.file_path\n            if (delete_btn.flagged_for_delete == True):\n                self.delete_image(fp)\n                self._deleted_fns.append(fp)\n        self._all_images = self._all_images[self._batch_size:]\n        self.empty_batch()\n        self.render()\n\n    def on_delete(self, btn):\n        \"Flag this image as delete or keep.\"\n        btn.button_style = \"\" if btn.flagged_for_delete else \"danger\"\n        btn.flagged_for_delete = not btn.flagged_for_delete\n\n    def empty_batch(self): self._batch[:] = []\n\n    def delete_image(self, file_path):\n        del self._csv_dict[file_path]\n\n    def empty(self):\n        return len(self._all_images) == 0\n\n    def get_widgets(self, duplicates):\n        \"Create and format widget set.\"\n        widgets = []\n        for (img,fp,human_readable_label) in self._all_images[:self._batch_size]:\n            img_widget = self.make_img_widget(img, layout=Layout(height='250px', width='300px'))\n            dropdown = self.make_dropdown_widget(description='', options=self._labels, value=human_readable_label,\n                                                 file_path=fp, handler=self.relabel, layout=Layout(width='auto'))\n            delete_btn = self.make_button_widget('Delete', file_path=fp, handler=self.on_delete)\n            widgets.append(self.make_vertical_box([img_widget, dropdown, delete_btn],\n                                                  layout=Layout(width='auto', height='300px',\n                                                      overflow_x=\"hidden\"), duplicates=duplicates))\n            self._batch.append((img_widget, delete_btn, fp))\n        return widgets\n\n    def batch_contains_deleted(self):\n        \"Check if current batch contains already deleted images.\"\n        if not self._duplicates: return False\n        imgs = 
[self._all_images[:self._batch_size][0][1], self._all_images[:self._batch_size][1][1]]\n        return any(img in self._deleted_fns for img in imgs)\n\n    def write_csv(self):\n        # Get first element's file path so we write CSV to same directory as our data\n        csv_path = self._path/'cleaned.csv'\n        with open(csv_path, 'w') as f:\n            csv_writer = csv.writer(f)\n            csv_writer.writerow(['name','label'])\n            for pair in self._csv_dict.items():\n                pair = [os.path.relpath(pair[0], self._path), pair[1]]\n                csv_writer.writerow(pair)\n        return csv_path\n\n    def render(self):\n        \"Re-render Jupyter cell for batch of images.\"\n        clear_output()\n        self.write_csv()\n        if self.empty() and self._skipped>0:\n            return display(f'No images to show :). {self._skipped} pairs were '\n                    f'skipped since at least one of the images was deleted by the user.')\n        elif self.empty():\n            return display('No images to show :)')\n        if self.batch_contains_deleted():\n            self.next_batch(None)\n            self._skipped += 1\n        else:\n            display(self.make_horizontal_box(self.get_widgets(self._duplicates)))\n            display(self.make_button_widget('Next Batch', handler=self.next_batch, style=\"primary\"))\n"
  },
  {
    "path": "fastai/widgets/image_downloader.py",
    "content": "from ..core import *\nfrom ..vision.data import *\nfrom ipywidgets import widgets, Layout, Output, HBox, VBox, Text, BoundedIntText, Button, Dropdown, Box\nfrom IPython.display import clear_output, display\nfrom urllib.parse import quote\nfrom bs4 import BeautifulSoup\nimport time\n\n__all__ = ['ImageDownloader', 'download_google_images']\n\n_img_sizes = {'>400*300':'isz:lt,islt:qsvga','>640*480':'isz:lt,islt:vga','>800*600':'isz:lt,islt:svga',\n              '>1024*768':'isz:lt,islt:xga', '>2MP':'isz:lt,islt:2mp','>4MP':'isz:lt,islt:4mp','>6MP':'isz:lt,islt:6mp',\n              '>8MP':'isz:lt,islt:8mp', '>10MP':'isz:lt,islt:10mp','>12MP':'isz:lt,islt:12mp','>15MP':'isz:lt,islt:15mp',\n              '>20MP':'isz:lt,islt:20mp','>40MP':'isz:lt,islt:40mp','>70MP':'isz:lt,islt:70mp'}\n\nclass ImageDownloader():\n    \"\"\"\n    Displays a widget that allows searching and downloading images from google images search\n    in a Jupyter Notebook or Lab.\n    \"\"\"\n    def __init__(self, path:Union[Path,str]='data'):\n        \"Setup path to save images to, init the UI, and render the widgets.\"\n        self._path = Path(path)\n        self._ui = self._init_ui()\n        self.render()\n\n    def _init_ui(self) -> VBox:\n        \"Initialize the widget UI and return the UI.\"\n        self._search_input = Text(placeholder=\"What images to search for?\")\n        self._count_input = BoundedIntText(placeholder=\"How many pics?\", value=10, min=1, max=5000, step=1,\n                                           layout=Layout(width='60px'))\n        self._size_input = Dropdown(options= _img_sizes.keys(), value='>400*300', layout=Layout(width='120px'))\n        self._download_button = Button(description=\"Search & Download\", icon=\"download\", layout=Layout(width='200px'))\n        self._download_button.on_click(self.on_download_button_click)\n        self._output = Output()\n        self._controls_pane  = HBox([self._search_input, self._count_input, 
self._size_input, self._download_button],\n                                    layout=Layout(width='auto', height='40px'))\n        self._heading = \"\"\n        self._download_complete_heading = \"<h3>Download complete. Here are a few images</h3>\"\n        self._preview_header = widgets.HTML(self._heading, layout=Layout(height='60px'))\n        self._img_pane = Box(layout=Layout(display='inline'))\n        return VBox([self._controls_pane, self._preview_header, self._img_pane])\n\n    def render(self) -> None:\n        clear_output()\n        display(self._ui)\n\n    def clear_imgs(self) -> None:\n        \"Clear the widget's images preview pane.\"\n        self._preview_header.value = self._heading\n        self._img_pane.children = tuple()\n\n    def validate_search_input(self) -> bool:\n        \"Check if input value is empty.\"\n        input = self._search_input\n        if input.value == str(): input.layout = Layout(border=\"solid 2px red\", height='auto')\n        else:                    self._search_input.layout = Layout()\n        return input.value != str()\n\n    def on_download_button_click(self, btn) -> None:\n        \"Download button click handler: validate search term and download images.\"\n        term = self._search_input.value\n        limit = int(self._count_input.value)\n        size = self._size_input.value\n        if not self.validate_search_input(): return\n        self.clear_imgs()\n        downloaded_images = download_google_images(self._path, term, n_images=limit, size=size)\n        self.display_images_widgets(downloaded_images[:min(limit, 12)])\n        self._preview_header.value = self._download_complete_heading\n        self.render()\n\n    def display_images_widgets(self, fnames:list) -> None:\n        \"Display a few preview images in the notebook\"\n        imgs = [widgets.Image(value=open(f, 'rb').read(), width='200px') for f in fnames]\n        self._img_pane.children = tuple(imgs)\n\n\ndef 
download_google_images(path:PathOrStr, search_term:str, size:str='>400*300', n_images:int=10, format:str='jpg',\n                            max_workers:int=defaults.cpus, timeout:int=4) -> FilePathList:\n    \"\"\"\n    Search for `n_images` images on Google, matching `search_term` and `size` requirements,\n    download them into `path`/`search_term` and verify them, using `max_workers` threads.\n    \"\"\"\n    label_path = Path(path)/search_term\n    search_url = _search_url(search_term, size=size, format=format)\n    if n_images <= 100: img_tuples = _fetch_img_tuples(search_url, format=format, n_images=n_images)\n    else:               img_tuples = _fetch_img_tuples_webdriver(search_url, format=format, n_images=n_images)\n    downloaded_images = _download_images(label_path, img_tuples, max_workers=max_workers, timeout=timeout)\n    if len(downloaded_images) == 0: raise RuntimeError(f\"Couldn't download any images.\")\n    verify_images(label_path, max_workers=max_workers)\n    return get_image_files(label_path)\n    \ndef _url_params(size:str='>400*300', format:str='jpg') -> str:\n    \"Build Google Images Search Url params and return them as a string.\"\n    _fmts = {'jpg':'ift:jpg','gif':'ift:gif','png':'ift:png','bmp':'ift:bmp', 'svg':'ift:svg','webp':'ift:webp','ico':'ift:ico'}\n    if size not in _img_sizes: \n        raise RuntimeError(f\"\"\"Unexpected size argument value: {size}.\n                    See `widgets.image_downloader._img_sizes` for supported sizes.\"\"\") \n    if format not in _fmts: \n        raise RuntimeError(f\"Unexpected image file format: {format}. 
Use jpg, gif, png, bmp, svg, webp, or ico.\")\n    return \"&tbs=\" + _img_sizes[size] + \",\" + _fmts[format]\n\ndef _search_url(search_term:str, size:str='>400*300', format:str='jpg') -> str:\n    \"Return a Google Images Search URL for a given search term.\"\n    return ('https://www.google.com/search?q=' + quote(search_term) +\n            '&espv=2&biw=1366&bih=667&site=webhp&source=lnms&tbm=isch' +\n            _url_params(size, format) + '&sa=X&ei=XosDVaCXD8TasATItgE&ved=0CAcQ_AUoAg')\n\ndef _img_fname(img_url:str) -> str:\n    \"Return image file name including the extension given its url.\"\n    return img_url.split('/')[-1]\n\ndef _fetch_img_tuples(url:str, format:str='jpg', n_images:int=10) -> list:\n    \"Parse the Google Images Search for urls and return the image metadata as tuples (fname, url).\"\n    headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 6.1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/41.0.2228.0 Safari/537.36'}\n    html = requests.get(url, headers=headers).text\n    return _html_to_img_tuples(html, format=format, n_images=n_images)\n\ndef _html_to_img_tuples(html:str, format:str='jpg', n_images:int=10) -> list:    \n    \"Parse the google images html to img tuples containining `(fname, url)`\"\n    bs = BeautifulSoup(html, 'html.parser')\n    img_tags = bs.find_all('div', {'class': 'rg_meta'})\n    metadata_dicts = (json.loads(e.text) for e in img_tags)\n    img_tuples = ((_img_fname(d['ou']), d['ou']) for d in metadata_dicts if d['ity'] == format)\n    return list(itertools.islice(img_tuples, n_images))\n\ndef _fetch_img_tuples_webdriver(url:str, format:str='jpg', n_images:int=150) -> list:\n    \"\"\"\n    Parse the Google Images Search for urls and return the image metadata as tuples (fname, url).\n    Use this for downloads of >100 images. 
Requires `selenium`.\n    \"\"\"\n    try:\n        from selenium import webdriver\n        from selenium.webdriver.common.keys import Keys\n    except:\n        print(\"\"\"Looks like you're trying to download > 100 images and `selenium`\n                is not installed. Try running `pip install selenium` to fix this. \n                You'll also need chrome and `chromedriver` installed.\"\"\")\n    options = webdriver.ChromeOptions()\n    options.add_argument(\"--headless\")\n    try: driver = webdriver.Chrome(chrome_options=options)\n    except: print(\"\"\"Error initializing chromedriver. \n                    Check if it's in your path by running `which chromedriver`\"\"\")\n    driver.set_window_size(1440, 900)\n    driver.get(url)\n\n    for i in range(n_images // 100 + 1):\n        driver.execute_script(\"window.scrollTo(0, document.body.scrollHeight)\")\n        time.sleep(0.5 + random.random()/2.0)\n\n    n_available = len(driver.find_elements_by_css_selector(\"div.rg_meta\"))\n    if n_available < n_images:\n        raise ValueError(f\"Requested {n_images} images, but only found {n_available}.\")\n\n    html = driver.page_source\n    driver.close()\n    return _html_to_img_tuples(html, format=format, n_images=n_images)\n\ndef _download_images(label_path:PathOrStr, img_tuples:list, max_workers:int=defaults.cpus, timeout:int=4) -> FilePathList:\n    \"\"\"\n    Downloads images in `img_tuples` to `label_path`. 
\n    If the directory doesn't exist, it'll be created automatically.\n    Uses `parallel` to speed things up in `max_workers` when the system has enough CPU cores.\n    If something doesn't work, try setting up `max_workers=0` to debug.\n    \"\"\"\n    os.makedirs(Path(label_path), exist_ok=True)\n    parallel( partial(_download_single_image, label_path, timeout=timeout), img_tuples, max_workers=max_workers)\n    return get_image_files(label_path)\n\ndef _download_single_image(label_path:Path, img_tuple:tuple, i:int, timeout:int=4) -> None:\n    \"\"\"\n    Downloads a single image from Google Search results to `label_path`\n    given an `img_tuple` that contains `(fname, url)` of an image to download.\n    `i` is just an iteration number `int`. \n    \"\"\"\n    suffix = re.findall(r'\\.\\w+?(?=(?:\\?|$))', img_tuple[1])\n    suffix = suffix[0].lower() if len(suffix)>0  else '.jpg'\n    fname = f\"{i:08d}{suffix}\"\n    download_url(img_tuple[1], label_path/fname, timeout=timeout)\n"
  },
  {
    "path": "fid/LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "fid/fid_score.py",
    "content": "#!/usr/bin/env python3\n\n# Code adapted and modified from https://github.com/mseitzer/pytorch-fid.  Licensing\n# and description duplicated below.\n\n\"\"\"Calculates the Frechet Inception Distance (FID) to evaluate GANs\n\nThe FID metric calculates the distance between two distributions of images.\nTypically, we have summary statistics (mean & covariance matrix) of one\nof these distributions, while the 2nd distribution is given by a GAN.\n\nWhen run as a stand-alone program, it compares the distribution of\nimages that are stored as PNG/JPEG at a specified location with a\ndistribution given by summary statistics (in .npz format).\n\nThe FID is calculated by assuming that X_1 and X_2 are the activations of\nthe pool_3 layer of the inception net for generated samples and real world\nsamples respectively.\n\nSee --help for further details.\n\nCode adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead\nof Tensorflow\n\nCopyright 2018 Institute of Bioinformatics, JKU Linz\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n   http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport pathlib\nfrom argparse import ArgumentParser, ArgumentDefaultsHelpFormatter\n\nimport numpy as np\nimport torch\nfrom scipy import linalg\nfrom torch.nn.functional import adaptive_avg_pool2d\nimport cv2\nimport imageio\n\ntry:\n    from tqdm import tqdm\nexcept ImportError:\n    # If tqdm is not available, provide a mock version of it\n    def tqdm(x):\n        return x\n\n\nfrom .inception import 
InceptionV3\n\nparser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)\nparser.add_argument(\n    'path',\n    type=str,\n    nargs=2,\n    help=('Path to the generated images or ' 'to .npz statistic files'),\n)\nparser.add_argument('--batch-size', type=int, default=50, help='Batch size to use')\nparser.add_argument(\n    '--dims',\n    type=int,\n    default=2048,\n    choices=list(InceptionV3.BLOCK_INDEX_BY_DIM),\n    help=(\n        'Dimensionality of Inception features to use. '\n        'By default, uses pool3 features'\n    ),\n)\nparser.add_argument(\n    '-c', '--gpu', default='', type=str, help='GPU to use (leave blank for CPU only)'\n)\n\n\ndef load_image_resized(fn, sz):\n    return cv2.resize(\n        imageio.imread(str(fn)), dsize=(sz, sz), interpolation=cv2.INTER_CUBIC\n    ).astype(np.float32)\n\n\ndef get_activations(\n    files,\n    model,\n    batch_size=50,\n    dims=2048,\n    cuda=False,\n    verbose=False,\n    eval_size: int = 299,\n):\n    \"\"\"Calculates the activations of the pool_3 layer for all images.\n\n    Params:\n    -- files       : List of image files paths\n    -- model       : Instance of inception model\n    -- batch_size  : Batch size of images for the model to process at once.\n                     Make sure that the number of samples is a multiple of\n                     the batch size, otherwise some samples are ignored. 
This\n                     behavior is retained to match the original FID score\n                     implementation.\n    -- dims        : Dimensionality of features returned by Inception\n    -- cuda        : If set to True, use GPU\n    -- verbose     : If set to True and parameter out_step is given, the number\n                     of calculated batches is reported.\n    Returns:\n    -- A numpy array of dimension (num images, dims) that contains the\n       activations of the given tensor when feeding inception with the\n       query tensor.\n    \"\"\"\n    model.eval()\n\n    if len(files) % batch_size != 0:\n        print(\n            (\n                'Warning: number of images is not a multiple of the '\n                'batch size. Some samples are going to be ignored.'\n            )\n        )\n    if batch_size > len(files):\n        print(\n            (\n                'Warning: batch size is bigger than the data size. '\n                'Setting batch size to data size'\n            )\n        )\n        batch_size = len(files)\n\n    n_batches = len(files) // batch_size\n    n_used_imgs = n_batches * batch_size\n\n    pred_arr = np.empty((n_used_imgs, dims))\n\n    for i in tqdm(range(n_batches)):\n        if verbose:\n            print('\\rPropagating batch %d/%d' % (i + 1, n_batches), end='', flush=True)\n        start = i * batch_size\n        end = start + batch_size\n\n        images = np.array(\n            [load_image_resized(fn, eval_size) for fn in files[start:end]]\n        )\n        # images = np.array([imageio.imread(str(f)).astype(np.float32)\n        # for f in files[start:end]])\n\n        # Reshape to (n_images, 3, height, width)\n        images = images.transpose((0, 3, 1, 2))\n        images /= 255\n\n        batch = torch.from_numpy(images).type(torch.FloatTensor)\n        if cuda:\n            batch = batch.cuda()\n\n        pred = model(batch)[0]\n\n        # If model output is not scalar, apply global spatial average 
pooling.\n        # This happens if you choose a dimensionality not equal 2048.\n        if pred.shape[2] != 1 or pred.shape[3] != 1:\n            pred = adaptive_avg_pool2d(pred, output_size=(1, 1))\n\n        pred_arr[start:end] = pred.cpu().data.numpy().reshape(batch_size, -1)\n\n    if verbose:\n        print(' done')\n\n    return pred_arr\n\n\ndef calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):\n    \"\"\"Numpy implementation of the Frechet Distance.\n    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)\n    and X_2 ~ N(mu_2, C_2) is\n            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).\n\n    Stable version by Dougal J. Sutherland.\n\n    Params:\n    -- mu1   : Numpy array containing the activations of a layer of the\n               inception net (like returned by the function 'get_predictions')\n               for generated samples.\n    -- mu2   : The sample mean over activations, precalculated on an\n               representative data set.\n    -- sigma1: The covariance matrix over activations for generated samples.\n    -- sigma2: The covariance matrix over activations, precalculated on an\n               representative data set.\n\n    Returns:\n    --   : The Frechet Distance.\n    \"\"\"\n\n    mu1 = np.atleast_1d(mu1)\n    mu2 = np.atleast_1d(mu2)\n\n    sigma1 = np.atleast_2d(sigma1)\n    sigma2 = np.atleast_2d(sigma2)\n\n    assert (\n        mu1.shape == mu2.shape\n    ), 'Training and test mean vectors have different lengths'\n    assert (\n        sigma1.shape == sigma2.shape\n    ), 'Training and test covariances have different dimensions'\n\n    diff = mu1 - mu2\n\n    # Product might be almost singular\n    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)\n    if not np.isfinite(covmean).all():\n        msg = (\n            'fid calculation produces singular product; '\n            'adding %s to diagonal of cov estimates'\n        ) % eps\n        print(msg)\n        offset 
= np.eye(sigma1.shape[0]) * eps\n        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))\n\n    # Numerical error might give slight imaginary component\n    if np.iscomplexobj(covmean):\n        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):\n            m = np.max(np.abs(covmean.imag))\n            raise ValueError('Imaginary component {}'.format(m))\n        covmean = covmean.real\n\n    tr_covmean = np.trace(covmean)\n\n    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean\n\n\ndef calculate_activation_statistics(\n    files, model, batch_size=50, dims=2048, cuda=False, verbose=False\n):\n    \"\"\"Calculation of the statistics used by the FID.\n    Params:\n    -- files       : List of image files paths\n    -- model       : Instance of inception model\n    -- batch_size  : The images numpy array is split into batches with\n                     batch size batch_size. A reasonable batch size\n                     depends on the hardware.\n    -- dims        : Dimensionality of features returned by Inception\n    -- cuda        : If set to True, use GPU\n    -- verbose     : If set to True and parameter out_step is given, the\n                     number of calculated batches is reported.\n    Returns:\n    -- mu    : The mean over samples of the activations of the pool_3 layer of\n               the inception model.\n    -- sigma : The covariance matrix of the activations of the pool_3 layer of\n               the inception model.\n    \"\"\"\n    act = get_activations(files, model, batch_size, dims, cuda, verbose)\n    mu = np.mean(act, axis=0)\n    sigma = np.cov(act, rowvar=False)\n    return mu, sigma\n\n\ndef _compute_statistics_of_path(path, model, batch_size, dims, cuda):\n    if path.endswith('.npz'):\n        f = np.load(path)\n        m, s = f['mu'][:], f['sigma'][:]\n        f.close()\n    else:\n        path = pathlib.Path(path)\n        files = list(path.glob('*.jpg')) + 
list(path.glob('*.png'))\n        m, s = calculate_activation_statistics(files, model, batch_size, dims, cuda)\n\n    return m, s\n\n\ndef calculate_fid_given_paths(paths, batch_size, cuda, dims):\n    \"\"\"Calculates the FID of two paths\"\"\"\n    for p in paths:\n        if not os.path.exists(p):\n            raise RuntimeError('Invalid path: %s' % p)\n\n    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]\n\n    model = InceptionV3([block_idx])\n    if cuda:\n        model.cuda()\n\n    m1, s1 = _compute_statistics_of_path(paths[0], model, batch_size, dims, cuda)\n    m2, s2 = _compute_statistics_of_path(paths[1], model, batch_size, dims, cuda)\n    fid_value = calculate_frechet_distance(m1, s1, m2, s2)\n\n    return fid_value\n\n\nif __name__ == '__main__':\n    args = parser.parse_args()\n    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu\n\n    fid_value = calculate_fid_given_paths(\n        args.path, args.batch_size, args.gpu != '', args.dims\n    )\n    print('FID: ', fid_value)\n"
  },
  {
    "path": "fid/inception.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torchvision import models\n\ntry:\n    from torchvision.models.utils import load_state_dict_from_url\nexcept ImportError:\n    from torch.utils.model_zoo import load_url as load_state_dict_from_url\n\n# Inception weights ported to Pytorch from\n# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz\nFID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'\n\n\nclass InceptionV3(nn.Module):\n    \"\"\"Pretrained InceptionV3 network returning feature maps\"\"\"\n\n    # Index of default block of inception to return,\n    # corresponds to output of final average pooling\n    DEFAULT_BLOCK_INDEX = 3\n\n    # Maps feature dimensionality to their output blocks indices\n    BLOCK_INDEX_BY_DIM = {\n        64: 0,  # First max pooling features\n        192: 1,  # Second max pooling featurs\n        768: 2,  # Pre-aux classifier features\n        2048: 3,  # Final average pooling features\n    }\n\n    def __init__(\n        self,\n        output_blocks=[DEFAULT_BLOCK_INDEX],\n        resize_input=True,\n        normalize_input=True,\n        requires_grad=False,\n        use_fid_inception=True,\n    ):\n        \"\"\"Build pretrained InceptionV3\n\n        Parameters\n        ----------\n        output_blocks : list of int\n            Indices of blocks to return features of. Possible values are:\n                - 0: corresponds to output of first max pooling\n                - 1: corresponds to output of second max pooling\n                - 2: corresponds to output which is fed to aux classifier\n                - 3: corresponds to output of final average pooling\n        resize_input : bool\n            If true, bilinearly resizes input to width and height 299 before\n            feeding input to model. 
As the network without fully connected\n            layers is fully convolutional, it should be able to handle inputs\n            of arbitrary size, so resizing might not be strictly needed\n        normalize_input : bool\n            If true, scales the input from range (0, 1) to the range the\n            pretrained Inception network expects, namely (-1, 1)\n        requires_grad : bool\n            If true, parameters of the model require gradients. Possibly useful\n            for finetuning the network\n        use_fid_inception : bool\n            If true, uses the pretrained Inception model used in Tensorflow's\n            FID implementation. If false, uses the pretrained Inception model\n            available in torchvision. The FID Inception model has different\n            weights and a slightly different structure from torchvision's\n            Inception model. If you want to compute FID scores, you are\n            strongly advised to set this parameter to true to get comparable\n            results.\n        \"\"\"\n        super(InceptionV3, self).__init__()\n\n        self.resize_input = resize_input\n        self.normalize_input = normalize_input\n        self.output_blocks = sorted(output_blocks)\n        self.last_needed_block = max(output_blocks)\n\n        assert self.last_needed_block <= 3, 'Last possible output block index is 3'\n\n        self.blocks = nn.ModuleList()\n\n        if use_fid_inception:\n            inception = fid_inception_v3()\n        else:\n            inception = models.inception_v3(pretrained=True)\n\n        # Block 0: input to maxpool1\n        block0 = [\n            inception.Conv2d_1a_3x3,\n            inception.Conv2d_2a_3x3,\n            inception.Conv2d_2b_3x3,\n            nn.MaxPool2d(kernel_size=3, stride=2),\n        ]\n        self.blocks.append(nn.Sequential(*block0))\n\n        # Block 1: maxpool1 to maxpool2\n        if self.last_needed_block >= 1:\n            block1 = [\n                
inception.Conv2d_3b_1x1,\n                inception.Conv2d_4a_3x3,\n                nn.MaxPool2d(kernel_size=3, stride=2),\n            ]\n            self.blocks.append(nn.Sequential(*block1))\n\n        # Block 2: maxpool2 to aux classifier\n        if self.last_needed_block >= 2:\n            block2 = [\n                inception.Mixed_5b,\n                inception.Mixed_5c,\n                inception.Mixed_5d,\n                inception.Mixed_6a,\n                inception.Mixed_6b,\n                inception.Mixed_6c,\n                inception.Mixed_6d,\n                inception.Mixed_6e,\n            ]\n            self.blocks.append(nn.Sequential(*block2))\n\n        # Block 3: aux classifier to final avgpool\n        if self.last_needed_block >= 3:\n            block3 = [\n                inception.Mixed_7a,\n                inception.Mixed_7b,\n                inception.Mixed_7c,\n                nn.AdaptiveAvgPool2d(output_size=(1, 1)),\n            ]\n            self.blocks.append(nn.Sequential(*block3))\n\n        for param in self.parameters():\n            param.requires_grad = requires_grad\n\n    def forward(self, inp):\n        \"\"\"Get Inception feature maps\n\n        Parameters\n        ----------\n        inp : torch.autograd.Variable\n            Input tensor of shape Bx3xHxW. 
Values are expected to be in\n            range (0, 1)\n\n        Returns\n        -------\n        List of torch.autograd.Variable, corresponding to the selected output\n        block, sorted ascending by index\n        \"\"\"\n        outp = []\n        x = inp\n\n        if self.resize_input:\n            x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False)\n\n        if self.normalize_input:\n            x = 2 * x - 1  # Scale from range (0, 1) to range (-1, 1)\n\n        for idx, block in enumerate(self.blocks):\n            x = block(x)\n            if idx in self.output_blocks:\n                outp.append(x)\n\n            if idx == self.last_needed_block:\n                break\n\n        return outp\n\n\ndef fid_inception_v3():\n    \"\"\"Build pretrained Inception model for FID computation\n\n    The Inception model for FID computation uses a different set of weights\n    and has a slightly different structure than torchvision's Inception.\n\n    This method first constructs torchvision's Inception and then patches the\n    necessary parts that are different in the FID Inception model.\n    \"\"\"\n    inception = models.inception_v3(\n        num_classes=1008, aux_logits=False, pretrained=False\n    )\n    inception.Mixed_5b = FIDInceptionA(192, pool_features=32)\n    inception.Mixed_5c = FIDInceptionA(256, pool_features=64)\n    inception.Mixed_5d = FIDInceptionA(288, pool_features=64)\n    inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)\n    inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)\n    inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)\n    inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)\n    inception.Mixed_7b = FIDInceptionE_1(1280)\n    inception.Mixed_7c = FIDInceptionE_2(2048)\n\n    state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)\n    inception.load_state_dict(state_dict)\n    return inception\n\n\nclass FIDInceptionA(models.inception.InceptionA):\n    
\"\"\"InceptionA block patched for FID computation\"\"\"\n\n    def __init__(self, in_channels, pool_features):\n        super(FIDInceptionA, self).__init__(in_channels, pool_features)\n\n    def forward(self, x):\n        branch1x1 = self.branch1x1(x)\n\n        branch5x5 = self.branch5x5_1(x)\n        branch5x5 = self.branch5x5_2(branch5x5)\n\n        branch3x3dbl = self.branch3x3dbl_1(x)\n        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)\n        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)\n\n        # Patch: Tensorflow's average pool does not use the padded zero's in\n        # its average calculation\n        branch_pool = F.avg_pool2d(\n            x, kernel_size=3, stride=1, padding=1, count_include_pad=False\n        )\n        branch_pool = self.branch_pool(branch_pool)\n\n        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]\n        return torch.cat(outputs, 1)\n\n\nclass FIDInceptionC(models.inception.InceptionC):\n    \"\"\"InceptionC block patched for FID computation\"\"\"\n\n    def __init__(self, in_channels, channels_7x7):\n        super(FIDInceptionC, self).__init__(in_channels, channels_7x7)\n\n    def forward(self, x):\n        branch1x1 = self.branch1x1(x)\n\n        branch7x7 = self.branch7x7_1(x)\n        branch7x7 = self.branch7x7_2(branch7x7)\n        branch7x7 = self.branch7x7_3(branch7x7)\n\n        branch7x7dbl = self.branch7x7dbl_1(x)\n        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)\n        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)\n        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)\n        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)\n\n        # Patch: Tensorflow's average pool does not use the padded zero's in\n        # its average calculation\n        branch_pool = F.avg_pool2d(\n            x, kernel_size=3, stride=1, padding=1, count_include_pad=False\n        )\n        branch_pool = self.branch_pool(branch_pool)\n\n        outputs = [branch1x1, branch7x7, branch7x7dbl, 
branch_pool]\n        return torch.cat(outputs, 1)\n\n\nclass FIDInceptionE_1(models.inception.InceptionE):\n    \"\"\"First InceptionE block patched for FID computation\"\"\"\n\n    def __init__(self, in_channels):\n        super(FIDInceptionE_1, self).__init__(in_channels)\n\n    def forward(self, x):\n        branch1x1 = self.branch1x1(x)\n\n        branch3x3 = self.branch3x3_1(x)\n        branch3x3 = [\n            self.branch3x3_2a(branch3x3),\n            self.branch3x3_2b(branch3x3),\n        ]\n        branch3x3 = torch.cat(branch3x3, 1)\n\n        branch3x3dbl = self.branch3x3dbl_1(x)\n        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)\n        branch3x3dbl = [\n            self.branch3x3dbl_3a(branch3x3dbl),\n            self.branch3x3dbl_3b(branch3x3dbl),\n        ]\n        branch3x3dbl = torch.cat(branch3x3dbl, 1)\n\n        # Patch: Tensorflow's average pool does not use the padded zero's in\n        # its average calculation\n        branch_pool = F.avg_pool2d(\n            x, kernel_size=3, stride=1, padding=1, count_include_pad=False\n        )\n        branch_pool = self.branch_pool(branch_pool)\n\n        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]\n        return torch.cat(outputs, 1)\n\n\nclass FIDInceptionE_2(models.inception.InceptionE):\n    \"\"\"Second InceptionE block patched for FID computation\"\"\"\n\n    def __init__(self, in_channels):\n        super(FIDInceptionE_2, self).__init__(in_channels)\n\n    def forward(self, x):\n        branch1x1 = self.branch1x1(x)\n\n        branch3x3 = self.branch3x3_1(x)\n        branch3x3 = [\n            self.branch3x3_2a(branch3x3),\n            self.branch3x3_2b(branch3x3),\n        ]\n        branch3x3 = torch.cat(branch3x3, 1)\n\n        branch3x3dbl = self.branch3x3dbl_1(x)\n        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)\n        branch3x3dbl = [\n            self.branch3x3dbl_3a(branch3x3dbl),\n            self.branch3x3dbl_3b(branch3x3dbl),\n        ]\n        
branch3x3dbl = torch.cat(branch3x3dbl, 1)\n\n        # Patch: The FID Inception model uses max pooling instead of average\n        # pooling. This is likely an error in this specific Inception\n        # implementation, as other Inception models use average pooling here\n        # (which matches the description in the paper).\n        branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)\n        branch_pool = self.branch_pool(branch_pool)\n\n        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]\n        return torch.cat(outputs, 1)\n"
  },
  {
    "path": "models/.gitkeep",
    "content": "\n"
  },
  {
    "path": "requirements-colab.txt",
    "content": "fastai==1.0.60\ntensorboardX>=1.6\nffmpeg-python\nyt-dlp\nopencv-python>=4.2.0.32\nPillow\ntornado\nimgaug==0.2.6\n"
  },
  {
    "path": "requirements-dev.txt",
    "content": "black\npre-commit\n"
  },
  {
    "path": "requirements.txt",
    "content": "wandb\nfastai==1.0.60\ntensorboardX>=1.6\nffmpeg\nffmpeg-python\nyt-dlp\njupyterlab\nopencv-python>=4.2.0.32\nPillow==9.3.0\n--extra-index-url https://download.pytorch.org/whl/cu113\ntorch==1.11.0\ntorchvision==0.12.0\nipywidgets\n"
  },
  {
    "path": "setup.py",
    "content": "from setuptools import setup, find_packages\n\n\ndef get_description():\n    return \"Deep Learning library for colorizing and restoring old images and video\"\n\n\n# def get_long_description():\n#     with open(\"README.md\") as f:\n#         return f.read()\n\n\ndef get_requirements():\n    with open(\"requirements.txt\") as f:\n        return f.read().splitlines()\n\n\nsetup(\n    name=\"DeOldify\",\n    version=\"0.0.1\",\n    packages=find_packages(exclude=[\"tests\"]),\n    url=\"https://github.com/jantic/DeOldify\",\n    license=\"MIT License\",\n    description=get_description(),\n    # long_description=get_long_description(),\n    # long_description_content_type=\"text/markdown\",\n    classifiers=[\n        \"Development Status :: 4 - Beta\",\n        \"Framework :: Jupyter\",\n        \"Intended Audience :: Developers\",\n        \"Intended Audience :: Science/Research\",\n        \"License :: OSI Approved :: MIT License\",\n        \"Programming Language :: Python :: 3.6\",\n        \"Programming Language :: Python :: 3.7\",\n        \"Topic :: Scientific/Engineering :: Artificial Intelligence\",\n        \"Topic :: Software Development :: Libraries :: Python Modules\",\n    ],\n    install_requires=get_requirements(),\n    python_requires=\">=3.6\",\n)\n"
  },
  {
    "path": "test_images/.gitkeep",
    "content": ""
  },
  {
    "path": "tox.ini",
    "content": "[tox]\nenvlist=static,format\nskipsdist=True\n\n[testenv]\nwhitelist_externals=\n\t/usr/bin/sh\n\t/usr/bin/test\n\n[testenv:format]\ndeps=\n\tblack\ncommands=\n\tblack -S --check deoldify\n\n[testenv:static]\ndeps=\n\t-rrequirements.txt\n\tpylint\ncommands=\n\tsh -c 'pylint --disable=W deoldify; test $(( $? & (1|2|4|32) )) = 0'\n"
  }
]